petitRADTRANS.retrieval.utils#
This module collects utility functions that do not fit elsewhere: flux conversions, prior functions, mean molecular weight calculations, conversions from mass fractions to number fractions, and FITS file output. It also provides helper functions for JAX-compatible log-likelihood computation.
These helpers are designed to be JIT-compilable and to improve the clarity and maintainability of the log_likelihood_jax method.
Attributes#
SQRT2 |
Functions#
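_get_parameter_value | Extract a scalar parameter value from a mapping, unwrapping legacy Parameter objects.
log_prior |
uniform_prior |
gaussian_prior |
log_gaussian_prior |
delta_prior |
inverse_gamma_prior |
b_range |
a_b_range |
get_pymultinest_sample_dict |
get_calculate_flux_return_values |
unpack_model_output | Unified handler for parsing model_returned_values from the model generating function.
extract_spectrum_parameters | Extract and pre-compute spectrum-related parameters to avoid repeated dictionary lookups in loops.
process_spectrum_pipeline | Execute the spectrum processing pipeline: convolution, rebin, RV shift, continuum subtraction.
broadcast_spectrum_to_mask | Apply mask to spectrum, handling different dimensionality combinations.
compute_log_likelihood_for_spectrum | Compute log likelihood for a single spectrum, handling different approaches.
handle_dimenstionality_dispatch | Dispatch log likelihood computation based on spectrum dimensionality.
validate_model_output | Validate model output for NaNs and None values.
ensure_jax_array | Ensure array is in JAX format for JIT compilation.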
Module Contents#
- petitRADTRANS.retrieval.utils.SQRT2#
- petitRADTRANS.retrieval.utils._get_parameter_value(parameters: object, name: str, default: object = None) → object#
Extract a scalar parameter value from a mapping, unwrapping legacy Parameter objects.
This is the canonical helper for reading parameter values from either a legacy parameter dictionary (where values are Parameter objects with a .value attribute) or a PhysicalParams mapping (where values are already unwrapped scalars / JAX arrays). It accepts any mapping that supports .get().
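The unwrapping behaviour described above can be illustrated with a minimal sketch. It is not the library's implementation: `LegacyParameter` and `get_parameter_value_sketch` are hypothetical names, and the sketch assumes that any entry exposing a `.value` attribute is a legacy wrapper.

```python
from dataclasses import dataclass
from typing import Any, Mapping


@dataclass
class LegacyParameter:
    """Hypothetical stand-in for a legacy Parameter object."""
    value: Any


def get_parameter_value_sketch(parameters: Mapping[str, Any], name: str, default: Any = None) -> Any:
    """Read parameters[name], unwrapping entries that expose `.value`."""
    entry = parameters.get(name)
    if entry is None:
        # Missing key: fall back to the caller-supplied default.
        return default
    # Legacy entries wrap the number in an object; unwrapped entries are returned as-is.
    return getattr(entry, "value", entry)


legacy = {"temperature": LegacyParameter(1200.0)}
unwrapped = {"temperature": 1200.0}
```

Both `legacy` and `unwrapped` yield the same scalar for `"temperature"`, which is the point of routing every lookup through one helper.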
- petitRADTRANS.retrieval.utils.log_prior(cube, lx1, lx2)#
- petitRADTRANS.retrieval.utils.uniform_prior(cube, x1, x2)#
- petitRADTRANS.retrieval.utils.gaussian_prior(cube, mu, sigma)#
- petitRADTRANS.retrieval.utils.log_gaussian_prior(cube, mu, sigma)#
- petitRADTRANS.retrieval.utils.delta_prior(cube, x1, x2)#
- petitRADTRANS.retrieval.utils.inverse_gamma_prior(cube, a, b)#
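The prior functions above are listed without descriptions. In nested-sampling retrievals, prior functions conventionally map a unit-cube sample `cube` in [0, 1] onto a physical parameter value. Assuming that convention applies here (an assumption this page does not confirm), three of them could be sketched as follows. The `*_sketch` names are hypothetical, and the Gaussian transform uses the standard library's inverse CDF rather than whatever the module uses internally.

```python
from statistics import NormalDist


def uniform_prior_sketch(cube: float, x1: float, x2: float) -> float:
    """Map cube in [0, 1] linearly onto [x1, x2]."""
    return x1 + cube * (x2 - x1)


def log_prior_sketch(cube: float, lx1: float, lx2: float) -> float:
    """Map cube onto a log-uniform value; lx1 and lx2 are assumed to be base-10 exponents."""
    return 10.0 ** (lx1 + cube * (lx2 - lx1))


def gaussian_prior_sketch(cube: float, mu: float, sigma: float) -> float:
    """Map cube in (0, 1) onto a normal distribution via the inverse CDF."""
    return NormalDist(mu, sigma).inv_cdf(cube)


midpoint = uniform_prior_sketch(0.5, 0.0, 10.0)    # halfway through [0, 10]
log_midpoint = log_prior_sketch(0.5, -2.0, 2.0)    # 10 ** 0
median = gaussian_prior_sketch(0.5, 3.0, 2.0)      # median of N(3, 2)
```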
- petitRADTRANS.retrieval.utils.b_range(x, b)#
- petitRADTRANS.retrieval.utils.a_b_range(x, a, b)#
- petitRADTRANS.retrieval.utils.get_pymultinest_sample_dict(output_dir, name=None, add_log_likelihood=False, add_stats=False)#
- petitRADTRANS.retrieval.utils.get_calculate_flux_return_values(parameters)#
- petitRADTRANS.retrieval.utils.unpack_model_output(model_returned_values: tuple | list, retrieve_uncertainties: bool, variability_atmospheric_column_model_flux_return_mode: bool) → Tuple[jax.numpy.ndarray, jax.numpy.ndarray, float, float, jax.numpy.ndarray | None]#
Unified handler for parsing model_returned_values from the model generating function.
The model function can return different numbers and types of values depending on configuration. This function normalizes the output to a standard tuple.
- Args:
model_returned_values: The raw output from model_generating_function
retrieve_uncertainties: Whether uncertainties are being retrieved
variability_atmospheric_column_model_flux_return_mode: Whether variability column data is returned
- Returns:
Tuple of (wavelengths_model, spectrum_model, beta, additional_log_l, atmospheric_model_column_fluxes) where atmospheric_model_column_fluxes is None if not needed
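The normalization idea can be sketched as below. The layout of the raw output is an assumption made for illustration only: `(wavelengths, spectrum[, beta][, additional_log_l][, column_fluxes])`. The defaults `beta = 1.0` and `additional_log_l = 0.0` are likewise assumed, and `unpack_model_output_sketch` is a hypothetical name.

```python
from typing import Any, Optional, Sequence, Tuple


def unpack_model_output_sketch(
    model_returned_values: Sequence[Any],
    retrieve_uncertainties: bool,
    return_column_fluxes: bool,
) -> Tuple[Any, Any, Any, Any, Optional[Any]]:
    """Normalize a variable-length model output to a fixed 5-tuple."""
    values = list(model_returned_values)
    wavelengths, spectrum = values[0], values[1]
    extras = values[2:]

    # Assumed defaults when the model does not return the optional values.
    beta = 1.0
    additional_log_l = 0.0
    column_fluxes = None

    if return_column_fluxes and extras:
        column_fluxes = extras.pop()      # assumed to come last
    if retrieve_uncertainties and extras:
        beta = extras.pop(0)              # assumed to come right after the spectrum
    if extras:
        additional_log_l = extras.pop(0)  # whatever scalar remains

    return wavelengths, spectrum, beta, additional_log_l, column_fluxes
```

Callers can then always destructure five values, regardless of how many the model function returned.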
- petitRADTRANS.retrieval.utils.extract_spectrum_parameters(parameters: dict, data_name: str) → dict#
Extract and pre-compute spectrum-related parameters to avoid repeated dictionary lookups in loops.
- Args:
parameters: Full parameters dictionary
data_name: Name of the data object
- Returns:
Dictionary with extracted parameters
- petitRADTRANS.retrieval.utils.process_spectrum_pipeline(wavelengths_model: jax.numpy.ndarray, spectrum_model: jax.numpy.ndarray, data, parameters: dict, data_name: str, spectrum_params: dict | None = None, *, convolve_function: Callable | None = None, variable_resolution_binning_function: Callable | None = None, rebin_spectrum_bin_function: Callable | None = None, continuum_subtract_function: Callable | None = None) → jax.numpy.ndarray#
Execute the spectrum processing pipeline: convolution, rebin, RV shift, continuum subtraction.
This consolidates all the repetitive spectrum processing logic that appears multiple times in log_likelihood_jax.
- Args:
wavelengths_model: Model wavelength array
spectrum_model: Model spectrum array
data: Data object containing wavelengths and processing configuration
parameters: Full parameters dictionary
data_name: Name of the data object
spectrum_params: Pre-extracted spectrum parameters (from extract_spectrum_parameters)
- Returns:
Processed spectrum (rebinned, convolved, continuum-subtracted as needed)
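The pattern of keyword-only, optional processing hooks can be illustrated with a simplified sketch. It is not the library's pipeline: each hook here takes and returns a plain list, the RV shift step is omitted, and `run_pipeline_sketch` and the toy hooks are hypothetical.

```python
from typing import Callable, List, Optional, Sequence

Hook = Optional[Callable[[List[float]], List[float]]]


def run_pipeline_sketch(
    spectrum: Sequence[float],
    *,
    convolve_function: Hook = None,
    rebin_function: Hook = None,
    continuum_subtract_function: Hook = None,
) -> List[float]:
    """Apply each supplied hook in a fixed order; skip hooks left as None."""
    result = list(spectrum)
    for step in (convolve_function, rebin_function, continuum_subtract_function):
        if step is not None:
            result = step(result)
    return result


def boxcar_3(values: List[float]) -> List[float]:
    """Toy convolution: 3-point running mean, edge points left unchanged."""
    smoothed = list(values)
    for i in range(1, len(values) - 1):
        smoothed[i] = (values[i - 1] + values[i] + values[i + 1]) / 3.0
    return smoothed


def rebin_pairs(values: List[float]) -> List[float]:
    """Toy rebinning: average adjacent pairs of points."""
    return [(values[i] + values[i + 1]) / 2.0 for i in range(0, len(values) - 1, 2)]


def subtract_mean(values: List[float]) -> List[float]:
    """Toy continuum subtraction: remove the mean level."""
    mean = sum(values) / len(values)
    return [v - mean for v in values]


processed = run_pipeline_sketch(
    [1.0, 1.0, 4.0, 1.0, 1.0, 1.0],
    convolve_function=boxcar_3,
    rebin_function=rebin_pairs,
    continuum_subtract_function=subtract_mean,
)
```

Passing the steps as arguments keeps the ordering logic in one place while letting each data set enable only the steps it needs.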
- petitRADTRANS.retrieval.utils.broadcast_spectrum_to_mask(spectrum_model: jax.numpy.ndarray, mask: jax.numpy.ndarray) → jax.numpy.ndarray#
Apply mask to spectrum, handling different dimensionality combinations.
JAX-compatible masking that handles 1D, 2D, and 3D spectra without conditional branching.
- Args:
spectrum_model: Model spectrum of any dimensionality
mask: Mask array of same dimensionality
- Returns:
Masked spectrum
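Branch-free masking of this kind is usually written with `where`, which relies on array broadcasting instead of `if` statements on the number of dimensions. The sketch below uses NumPy for portability (`jax.numpy.where` has the same call signature). It assumes that `True` marks a masked, invalid point and that masked points are set to zero; both conventions are assumptions, and `apply_mask_sketch` is a hypothetical name.

```python
import numpy as np


def apply_mask_sketch(spectrum, mask):
    """Zero out masked points without branching on dimensionality."""
    spectrum = np.asarray(spectrum, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    # Broadcasting aligns trailing axes, so a 1D spectrum can meet a 2D or 3D mask.
    return np.where(mask, 0.0, spectrum)


spectrum_1d = [1.0, 2.0, 3.0]
mask_2d = [[False, True, False],
           [True, False, False]]
masked = apply_mask_sketch(spectrum_1d, mask_2d)  # result has the mask's shape, (2, 3)
```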
- petitRADTRANS.retrieval.utils.compute_log_likelihood_for_spectrum(spectrum_model: jax.numpy.ndarray, data_spectrum: jax.numpy.ndarray, data_uncertainties: jax.numpy.ndarray, data_covariance: jax.numpy.ndarray | None, data_mask: jax.numpy.ndarray | None, n_free_parameters: int, use_log_likelihood_object: bool = False, beta: float = 1.0, beta_mode: str = 'multiply') → float#
Compute log likelihood for a single spectrum, handling different approaches.
This function handles both direct log_likelihood method calls (for object data) and LogLikelihood class instantiation (for simple data).
- Args:
spectrum_model: Model spectrum
data_spectrum: Observed spectrum
data_uncertainties: Uncertainties on observations
data_covariance: Covariance matrix (optional)
data_mask: Mask for valid data points (optional)
n_free_parameters: Number of free parameters in fit
use_log_likelihood_object: If True, use LogLikelihood class; else use data.log_likelihood
beta: Uncertainty scaling parameter
beta_mode: Mode for applying beta ('multiply' or 'add')
- Returns:
Log likelihood value
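To make the role of `beta` concrete, here is a sketch of a Gaussian log likelihood with independent uncertainties. It ignores the covariance matrix and the mask, and the two `beta_mode` rules are assumptions for illustration: 'multiply' scales each uncertainty by `beta`, and 'add' adds `beta` in quadrature. The library's actual rules may differ.

```python
import math
from typing import Sequence


def gaussian_log_likelihood_sketch(
    model: Sequence[float],
    data: Sequence[float],
    uncertainties: Sequence[float],
    beta: float = 1.0,
    beta_mode: str = "multiply",
) -> float:
    """Sum of independent Gaussian log densities with beta-adjusted uncertainties."""
    if beta_mode == "multiply":
        sigmas = [beta * s for s in uncertainties]                    # assumed rule
    elif beta_mode == "add":
        sigmas = [math.sqrt(s * s + beta * beta) for s in uncertainties]  # assumed rule
    else:
        raise ValueError(f"unknown beta_mode: {beta_mode!r}")

    log_l = 0.0
    for m, d, s in zip(model, data, sigmas):
        # Chi-square term plus the normalization term, which depends on beta.
        log_l += -0.5 * ((d - m) / s) ** 2 - 0.5 * math.log(2.0 * math.pi * s * s)
    return log_l


perfect_fit = gaussian_log_likelihood_sketch([1.0, 2.0], [1.0, 2.0], [1.0, 1.0])
poor_fit = gaussian_log_likelihood_sketch([0.0], [3.0], [1.0])
poor_fit_inflated = gaussian_log_likelihood_sketch([0.0], [3.0], [1.0], beta=3.0)
```

The normalization term matters once `beta` is a free parameter: without it, the likelihood would always favour arbitrarily large uncertainties.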
- petitRADTRANS.retrieval.utils.handle_dimenstionality_dispatch(data_spectrum: jax.numpy.ndarray, spectrum_model: jax.numpy.ndarray, data_mask: jax.numpy.ndarray, data_uncertainties: jax.numpy.ndarray, use_data_log_likelihood_method: bool = False, beta: float = 1.0, beta_mode: str = 'multiply') → float#
Dispatch log likelihood computation based on spectrum dimensionality.
Replaces the long if/elif chains for handling 1D, 2D, 3D spectra with a more scalable approach. Currently computes straightforward contributions but can be extended for more complex logic.
- Args:
data_spectrum: Observed spectrum (can be 1D, 2D, or 3D)
spectrum_model: Model spectrum (can be 1D, 2D, or 3D)
data_mask: Mask for valid data points (matching dimensionality)
data_uncertainties: Uncertainties (matching dimensionality)
use_data_log_likelihood_method: Whether to use data.log_likelihood method
beta: Uncertainty scaling parameter
beta_mode: Mode for applying beta
- Returns:
Total log likelihood across all dimensions
- petitRADTRANS.retrieval.utils.validate_model_output(spectrum_model: jax.numpy.ndarray | None, wavelengths_model: jax.numpy.ndarray | None) → bool#
Validate model output for NaNs and None values.
- Returns:
Tuple of (is_valid, error_message)
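A check of this kind might look like the sketch below, which follows the Returns description and produces an `(is_valid, error_message)` pair, whereas the signature above is annotated as returning `bool`. The function name and the error strings are hypothetical, and plain Python sequences stand in for JAX arrays.

```python
import math
from typing import Optional, Sequence, Tuple


def validate_model_output_sketch(
    spectrum_model: Optional[Sequence[float]],
    wavelengths_model: Optional[Sequence[float]],
) -> Tuple[bool, str]:
    """Return (is_valid, error_message) for a model spectrum."""
    if spectrum_model is None or wavelengths_model is None:
        return False, "model returned None"
    if any(math.isnan(v) for v in spectrum_model):
        return False, "model spectrum contains NaN"
    return True, ""


ok = validate_model_output_sketch([1.0, 2.0], [0.5, 0.6])
missing = validate_model_output_sketch(None, [0.5, 0.6])
has_nan = validate_model_output_sketch([1.0, float("nan")], [0.5, 0.6])
```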
- petitRADTRANS.retrieval.utils.ensure_jax_array(array: numpy.ndarray | jax.numpy.ndarray) → jax.numpy.ndarray#
Ensure array is in JAX format for JIT compilation.
- Args:
array: Input array (numpy or JAX)
- Returns:
JAX array