petitRADTRANS.retrieval.utils
=============================

.. py:module:: petitRADTRANS.retrieval.utils

.. autoapi-nested-parse::

   This module contains a set of useful functions that don't fit neatly anywhere
   else, including flux conversions, prior functions, mean molecular weight
   calculations, transforms from mass to number fractions, and FITS file output.
   It also provides helper functions for JAX-compatible log-likelihood computation.

   These helpers are designed to be JIT-compilable and to improve the clarity
   and maintainability of the ``log_likelihood_jax`` method.



Attributes
----------

.. autoapisummary::

   petitRADTRANS.retrieval.utils.SQRT2


Functions
---------

.. autoapisummary::

   petitRADTRANS.retrieval.utils._get_parameter_value
   petitRADTRANS.retrieval.utils.log_prior
   petitRADTRANS.retrieval.utils.uniform_prior
   petitRADTRANS.retrieval.utils.gaussian_prior
   petitRADTRANS.retrieval.utils.log_gaussian_prior
   petitRADTRANS.retrieval.utils.delta_prior
   petitRADTRANS.retrieval.utils.inverse_gamma_prior
   petitRADTRANS.retrieval.utils.b_range
   petitRADTRANS.retrieval.utils.a_b_range
   petitRADTRANS.retrieval.utils.get_pymultinest_sample_dict
   petitRADTRANS.retrieval.utils.get_calculate_flux_return_values
   petitRADTRANS.retrieval.utils.unpack_model_output
   petitRADTRANS.retrieval.utils.extract_spectrum_parameters
   petitRADTRANS.retrieval.utils.process_spectrum_pipeline
   petitRADTRANS.retrieval.utils.broadcast_spectrum_to_mask
   petitRADTRANS.retrieval.utils.compute_log_likelihood_for_spectrum
   petitRADTRANS.retrieval.utils.handle_dimenstionality_dispatch
   petitRADTRANS.retrieval.utils.validate_model_output
   petitRADTRANS.retrieval.utils.ensure_jax_array


Module Contents
---------------

.. py:data:: SQRT2

.. py:function:: _get_parameter_value(parameters: object, name: str, default: object = None) -> object

   Extract a scalar parameter value from a mapping, unwrapping legacy ``Parameter`` objects.

   This is the canonical helper for reading parameter values from either a legacy
   parameter dictionary (where values are ``Parameter`` objects with a ``.value``
   attribute) or a ``PhysicalParams`` mapping (where values are already unwrapped
   scalars / JAX arrays).  It accepts any mapping that supports ``.get()``.
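
   A minimal sketch of the unwrapping behaviour described above. ``LegacyParameter``
   is a hypothetical stand-in for the legacy ``Parameter`` class (only its ``.value``
   attribute is modelled), ``get_parameter_value_sketch`` is an illustrative
   re-implementation rather than the library function, and the key names are made up.

   .. code-block:: python

      from typing import Any, Mapping


      class LegacyParameter:
          """Hypothetical stand-in for a legacy ``Parameter``: only ``.value`` is modelled."""

          def __init__(self, value: Any) -> None:
              self.value = value


      def get_parameter_value_sketch(parameters: Mapping, name: str, default: Any = None) -> Any:
          """Look up ``name`` and unwrap it if it is a legacy parameter object."""
          value = parameters.get(name, default)
          # Legacy dicts hold wrapper objects; PhysicalParams-style mappings
          # already hold plain scalars or arrays, which are returned unchanged.
          if isinstance(value, LegacyParameter):
              return value.value
          return value


      legacy = {"T_int": LegacyParameter(200.0)}
      physical = {"T_int": 200.0}
      print(get_parameter_value_sketch(legacy, "T_int"))    # 200.0
      print(get_parameter_value_sketch(physical, "T_int"))  # 200.0
      print(get_parameter_value_sketch(physical, "log_g", default=3.5))  # 3.5

   Both mapping styles go through the same call, which is what lets downstream code
   stay agnostic about where the parameters came from.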


.. py:function:: log_prior(cube, lx1, lx2)
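
   This entry has no docstring. Assuming ``log_prior`` is the usual log-uniform
   unit-cube transform used with nested samplers such as MultiNest (``cube`` in
   [0, 1], bounds ``lx1`` and ``lx2`` given in log10), it would behave like the
   hypothetical sketch below, which is not the library's code.

   .. code-block:: python

      import math


      def log_uniform_transform(cube: float, lx1: float, lx2: float) -> float:
          """Map cube in [0, 1] to a value uniform in log10 space on [10**lx1, 10**lx2]."""
          # Draw uniformly in the exponent, then exponentiate.
          return 10.0 ** (lx1 + cube * (lx2 - lx1))


      print(log_uniform_transform(0.0, -6.0, 0.0))  # approximately 1e-06
      print(math.log10(log_uniform_transform(0.5, -6.0, 0.0)))  # approximately -3.0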

.. py:function:: uniform_prior(cube, x1, x2)
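
   No docstring is provided here either. Assuming the standard linear unit-cube
   transform, a sample ``cube`` in [0, 1] maps to ``x1 + cube * (x2 - x1)``. The
   sketch below is hypothetical and uses Python's ``random`` module only to
   generate example cube values.

   .. code-block:: python

      import random


      def uniform_transform(cube: float, x1: float, x2: float) -> float:
          """Linearly map cube in [0, 1] onto the interval [x1, x2]."""
          return x1 + cube * (x2 - x1)


      print(uniform_transform(0.25, 100.0, 500.0))  # 200.0

      # Every transformed draw stays inside the prior bounds.
      rng = random.Random(0)
      samples = [uniform_transform(rng.random(), 100.0, 500.0) for _ in range(1000)]
      print(min(samples) >= 100.0 and max(samples) <= 500.0)  # True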

.. py:function:: gaussian_prior(cube, mu, sigma)
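
   Also undocumented. Assuming ``gaussian_prior`` maps ``cube`` through the quantile
   function (inverse CDF) of a normal distribution with mean ``mu`` and standard
   deviation ``sigma``, the standard-library sketch below is equivalent under that
   assumption. The same quantile can be written in closed form as
   ``mu + sigma * sqrt(2) * erfinv(2 * cube - 1)``. Note that ``cube`` must lie
   strictly between 0 and 1.

   .. code-block:: python

      from statistics import NormalDist


      def gaussian_transform(cube: float, mu: float, sigma: float) -> float:
          """Inverse-CDF transform of cube in (0, 1) to a draw from N(mu, sigma**2)."""
          return NormalDist(mu, sigma).inv_cdf(cube)


      print(gaussian_transform(0.5, 1.2, 0.1))  # 1.2 (the median equals the mean)
      # cube = Phi(1), about 0.8413, lands one standard deviation above the mean.
      print(round(gaussian_transform(0.8413447460685429, 1.2, 0.1), 6))  # 1.3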

.. py:function:: log_gaussian_prior(cube, mu, sigma)

.. py:function:: delta_prior(cube, x1, x2)

.. py:function:: inverse_gamma_prior(cube, a, b)

.. py:function:: b_range(x, b)

.. py:function:: a_b_range(x, a, b)

.. py:function:: get_pymultinest_sample_dict(output_dir, name=None, add_log_likelihood=False, add_stats=False)

.. py:function:: get_calculate_flux_return_values(parameters)

.. py:function:: unpack_model_output(model_returned_values: Union[tuple, list], retrieve_uncertainties: bool, variability_atmospheric_column_model_flux_return_mode: bool) -> Tuple[jax.numpy.ndarray, jax.numpy.ndarray, float, float, Optional[jax.numpy.ndarray]]

   Parse ``model_returned_values`` from the model-generating function into a standard form.

   The model function can return a different number and type of values depending on
   its configuration. This function normalizes that output to a fixed five-element tuple.

   Args:
       model_returned_values: The raw output from model_generating_function
       retrieve_uncertainties: Whether uncertainties are being retrieved
       variability_atmospheric_column_model_flux_return_mode: Whether variability column data is returned

   Returns:
       Tuple of (wavelengths_model, spectrum_model, beta, additional_log_l, atmospheric_model_column_fluxes)
       where atmospheric_model_column_fluxes is None if not needed
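
   A minimal sketch of this kind of normalization. The positional layout assumed
   here, ``(wavelengths, spectrum[, beta][, column_fluxes][, additional_log_l])``,
   is hypothetical and chosen only for illustration; it is not the documented
   contract of the model-generating function. Missing optional values fall back to
   ``beta = 1.0`` (the default used by ``compute_log_likelihood_for_spectrum``),
   ``additional_log_l = 0.0``, and ``None``.

   .. code-block:: python

      from typing import Any, Optional, Sequence, Tuple


      def unpack_model_output_sketch(
          returned: Sequence[Any],
          retrieve_uncertainties: bool,
          return_column_fluxes: bool,
      ) -> Tuple[Any, Any, float, float, Optional[Any]]:
          """Normalize a variable-length model output to a fixed 5-tuple.

          Hypothetical layout (an assumption, not petitRADTRANS's contract):
          (wavelengths, spectrum[, beta][, column_fluxes][, additional_log_l])
          """
          wavelengths, spectrum, *rest = returned
          beta = 1.0
          if retrieve_uncertainties:
              beta = float(rest.pop(0))
          column_fluxes = None
          if return_column_fluxes:
              column_fluxes = rest.pop(0)
          # Anything left over is treated as an extra log-likelihood term.
          additional_log_l = float(rest.pop(0)) if rest else 0.0
          return wavelengths, spectrum, beta, additional_log_l, column_fluxes


      print(unpack_model_output_sketch(([1.0, 2.0], [0.5, 0.6]), False, False))
      # ([1.0, 2.0], [0.5, 0.6], 1.0, 0.0, None)
      print(unpack_model_output_sketch(([1.0, 2.0], [0.5, 0.6], 1.3, -4.0), True, False))
      # ([1.0, 2.0], [0.5, 0.6], 1.3, -4.0, None)

   Returning a fixed-length tuple means callers can always unpack five values,
   regardless of how the model was configured.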


.. py:function:: extract_spectrum_parameters(parameters: dict, data_name: str) -> dict

   Extract and pre-compute spectrum-related parameters to avoid repeated dictionary lookups
   in loops.

   Args:
       parameters: Full parameters dictionary
       data_name: Name of the data object

   Returns:
       Dictionary with extracted parameters
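
   The idea is to build the per-dataset keys and read their values once, before any
   loop. The sketch below is hypothetical: the key names (``<data_name>_scale_factor``,
   ``<data_name>_offset``, ``radial_velocity``) and their defaults are illustrative
   assumptions, not petitRADTRANS's parameter names.

   .. code-block:: python

      from typing import Any, Dict


      def extract_spectrum_parameters_sketch(parameters: Dict[str, Any], data_name: str) -> Dict[str, Any]:
          """Resolve the per-dataset keys once and return a small flat dict."""
          return {
              "scale_factor": parameters.get(f"{data_name}_scale_factor", 1.0),
              "offset": parameters.get(f"{data_name}_offset", 0.0),
              "radial_velocity": parameters.get("radial_velocity", 0.0),
          }


      parameters = {"NIRSpec_scale_factor": 1.05, "radial_velocity": -12.3}
      spectrum_params = extract_spectrum_parameters_sketch(parameters, "NIRSpec")
      print(spectrum_params)
      # {'scale_factor': 1.05, 'offset': 0.0, 'radial_velocity': -12.3}

      # Inside the loop only the pre-built dict is consulted; no keys are rebuilt.
      fluxes = [1.0, 2.0, 3.0]
      scaled = [f * spectrum_params["scale_factor"] + spectrum_params["offset"] for f in fluxes]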


.. py:function:: process_spectrum_pipeline(wavelengths_model: jax.numpy.ndarray, spectrum_model: jax.numpy.ndarray, data, parameters: dict, data_name: str, spectrum_params: Optional[dict] = None, *, convolve_function: Callable | None = None, variable_resolution_binning_function: Callable | None = None, rebin_spectrum_bin_function: Callable | None = None, continuum_subtract_function: Callable | None = None) -> jax.numpy.ndarray

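   Execute the spectrum processing pipeline: convolution, rebinning, RV shift, and continuum subtraction.

   This consolidates the spectrum processing logic that would otherwise appear
   several times in ``log_likelihood_jax``.

   Args:
       wavelengths_model: Model wavelength array
       spectrum_model: Model spectrum array
       data: Data object containing wavelengths and processing configuration
       parameters: Full parameters dictionary
       data_name: Name of the data object
       spectrum_params: Pre-extracted spectrum parameters (from extract_spectrum_parameters)
       convolve_function: Optional callable for the convolution step
       variable_resolution_binning_function: Optional callable for variable-resolution binning
       rebin_spectrum_bin_function: Optional callable for rebinning the spectrum
       continuum_subtract_function: Optional callable for continuum subtraction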

   Returns:
       Processed spectrum (rebinned, convolved, continuum-subtracted as needed)
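
   A toy NumPy version of such a pipeline, for illustration only. It assumes a
   Gaussian instrumental kernel defined in model pixels, a non-relativistic Doppler
   shift applied to the model wavelength grid before interpolation, plain linear
   interpolation (``numpy.interp``, which is not flux-conserving) in place of the
   library's rebinning, and a polynomial continuum. None of these choices is taken
   from petitRADTRANS; ``numpy.convolve``, ``numpy.interp``, and ``numpy.polyfit``
   have ``jax.numpy`` counterparts.

   .. code-block:: python

      import numpy as np

      C_KM_S = 299_792.458  # speed of light in km/s


      def process_spectrum_sketch(wl_model, flux_model, wl_data, rv_km_s=0.0,
                                  kernel_sigma_px=0.0, continuum_degree=None):
          """Convolve, Doppler-shift, interpolate onto wl_data, then remove a continuum."""
          flux = np.asarray(flux_model, dtype=float)

          # 1. Convolve with a normalized Gaussian kernel (width in model pixels).
          #    mode="same" zero-pads, so the outermost model pixels are biased low.
          if kernel_sigma_px > 0:
              half_width = int(np.ceil(4 * kernel_sigma_px))
              x = np.arange(-half_width, half_width + 1)
              kernel = np.exp(-0.5 * (x / kernel_sigma_px) ** 2)
              flux = np.convolve(flux, kernel / kernel.sum(), mode="same")

          # 2. Shift the model wavelengths by the radial velocity (v << c).
          wl_shifted = np.asarray(wl_model, dtype=float) * (1.0 + rv_km_s / C_KM_S)

          # 3. Linear interpolation onto the data wavelength grid.
          flux_on_data = np.interp(wl_data, wl_shifted, flux)

          # 4. Fit and subtract a low-order polynomial continuum.
          if continuum_degree is not None:
              coeffs = np.polyfit(wl_data, flux_on_data, continuum_degree)
              flux_on_data = flux_on_data - np.polyval(coeffs, wl_data)

          return flux_on_data


      wl_model = np.linspace(1.0, 2.0, 2001)
      flux_model = 1.0 + 0.2 * wl_model  # a featureless, linearly sloped spectrum
      wl_data = np.linspace(1.1, 1.9, 50)
      residual = process_spectrum_sketch(wl_model, flux_model, wl_data,
                                         rv_km_s=10.0, kernel_sigma_px=2.0,
                                         continuum_degree=1)
      print(residual.shape)  # (50,)
      print(np.allclose(residual, 0.0, atol=1e-8))  # True: the linear continuum is removed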


.. py:function:: broadcast_spectrum_to_mask(spectrum_model: jax.numpy.ndarray, mask: jax.numpy.ndarray) -> jax.numpy.ndarray

   Apply a mask to a spectrum, handling different dimensionality combinations.

   The masking is JAX-compatible and handles 1D, 2D, and 3D spectra without
   conditional branching.

   Args:
       spectrum_model: Model spectrum of any supported dimensionality (1D, 2D, or 3D)
       mask: Mask array of the same dimensionality

   Returns:
       Masked spectrum
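
   One branch-free way to do this is sketched below with NumPy; ``jax.numpy.where``
   and ``jax.numpy.broadcast_to`` behave the same way. Instead of boolean indexing
   (``spectrum[mask]``), whose output size depends on the data and therefore cannot
   be traced by ``jax.jit``, it keeps the array shape fixed and replaces masked-out
   entries with a fill value. Treating ``True`` as "keep" and using ``0.0`` as the
   fill value are assumptions of this sketch, not documented behaviour.

   .. code-block:: python

      import numpy as np


      def mask_spectrum_sketch(spectrum, mask, fill_value=0.0):
          """Return an array shaped like ``mask``: spectrum where mask is True, else fill_value."""
          mask = np.asarray(mask, dtype=bool)
          # Broadcasting also lets a 1D model spectrum be masked against each row of a 2D mask.
          spectrum = np.broadcast_to(np.asarray(spectrum, dtype=float), mask.shape)
          return np.where(mask, spectrum, fill_value)


      print(mask_spectrum_sketch([1.0, 2.0, 3.0], [True, False, True]))  # [1. 0. 3.]

      spec_3d = np.arange(24.0).reshape(2, 3, 4)  # a 3D spectrum
      spec_mask = spec_3d % 2 == 0                # keep even-valued entries
      print(mask_spectrum_sketch(spec_3d, spec_mask).shape)  # (2, 3, 4)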


.. py:function:: compute_log_likelihood_for_spectrum(spectrum_model: jax.numpy.ndarray, data_spectrum: jax.numpy.ndarray, data_uncertainties: jax.numpy.ndarray, data_covariance: Optional[jax.numpy.ndarray], data_mask: Optional[jax.numpy.ndarray], n_free_parameters: int, use_log_likelihood_object: bool = False, beta: float = 1.0, beta_mode: str = 'multiply') -> float

   Compute the log-likelihood for a single spectrum.

   Two computation paths are supported: calling the data object's
   ``log_likelihood`` method directly, or instantiating the ``LogLikelihood``
   class (for simple data).

   Args:
       spectrum_model: Model spectrum
       data_spectrum: Observed spectrum
       data_uncertainties: Uncertainties on observations
       data_covariance: Covariance matrix (optional)
       data_mask: Mask for valid data points (optional)
       n_free_parameters: Number of free parameters in fit
       use_log_likelihood_object: If True, use LogLikelihood class; else use data.log_likelihood
       beta: Uncertainty scaling parameter
       beta_mode: Mode for applying beta ('multiply' or 'add')

   Returns:
       Log likelihood value
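
   For orientation, a Gaussian log-likelihood of this kind can be sketched as below.
   This is a hypothetical NumPy implementation, not the ``LogLikelihood`` class or any
   ``data.log_likelihood`` method. It assumes ``mask`` is ``True`` for points to use,
   that ``beta_mode="multiply"`` scales the uncertainties by ``beta`` while ``"add"``
   adds ``beta`` in quadrature, and it ignores ``n_free_parameters``. With a covariance
   matrix ``C`` and residual ``r`` it evaluates ``-0.5 * (r^T C^-1 r + ln det(2 pi C))``;
   otherwise it uses the equivalent diagonal form.

   .. code-block:: python

      import numpy as np


      def gaussian_log_likelihood_sketch(model, data, uncertainties, covariance=None,
                                         mask=None, beta=1.0, beta_mode="multiply"):
          """Gaussian log-likelihood of ``data`` given ``model`` (hypothetical conventions)."""
          if beta_mode not in ("multiply", "add"):
              raise ValueError(f"unknown beta_mode: {beta_mode!r}")
          model = np.asarray(model, dtype=float).ravel()
          data = np.asarray(data, dtype=float).ravel()
          keep = (np.ones(data.size, dtype=bool) if mask is None
                  else np.asarray(mask, dtype=bool).ravel())
          residual = (data - model)[keep]

          if covariance is None:
              sigma = np.asarray(uncertainties, dtype=float).ravel()[keep]
              if beta_mode == "multiply":
                  variance = (beta * sigma) ** 2
              else:
                  variance = sigma ** 2 + beta ** 2
              return float(-0.5 * np.sum(residual ** 2 / variance
                                         + np.log(2.0 * np.pi * variance)))

          # Full covariance restricted to the kept points.
          cov = np.asarray(covariance, dtype=float)[np.ix_(keep, keep)]
          if beta_mode == "multiply":
              cov = beta ** 2 * cov
          else:
              cov = cov + beta ** 2 * np.eye(cov.shape[0])
          _, log_det = np.linalg.slogdet(2.0 * np.pi * cov)
          chi2 = residual @ np.linalg.solve(cov, residual)
          return float(-0.5 * (chi2 + log_det))


      m = np.array([1.0, 2.0, 3.0])
      d = np.array([1.1, 1.9, 3.2])
      s = np.array([0.1, 0.2, 0.1])
      diag = gaussian_log_likelihood_sketch(m, d, s)
      full = gaussian_log_likelihood_sketch(m, d, s, covariance=np.diag(s ** 2))
      print(np.isclose(diag, full))  # True: a diagonal covariance reproduces the diagonal form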


.. py:function:: handle_dimenstionality_dispatch(data_spectrum: jax.numpy.ndarray, spectrum_model: jax.numpy.ndarray, data_mask: jax.numpy.ndarray, data_uncertainties: jax.numpy.ndarray, use_data_log_likelihood_method: bool = False, beta: float = 1.0, beta_mode: str = 'multiply') -> float

   Dispatch log likelihood computation based on spectrum dimensionality.

   Replaces long if/elif chains for 1D, 2D, and 3D spectra with a single,
   more scalable code path. It currently computes only straightforward
   contributions, but it can be extended to more complex logic.

   Args:
       data_spectrum: Observed spectrum (can be 1D, 2D, or 3D)
       spectrum_model: Model spectrum (can be 1D, 2D, or 3D)
       data_mask: Mask for valid data points (matching dimensionality)
       data_uncertainties: Uncertainties (matching dimensionality)
       use_data_log_likelihood_method: Whether to use data.log_likelihood method
       beta: Uncertainty scaling parameter
       beta_mode: Mode for applying beta

   Returns:
       Total log likelihood across all dimensions
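
   The dimension-independent idea can be illustrated as follows: compute each point's
   contribution elementwise, zero it where the mask excludes the point, and reduce over
   all axes at once, so 1D, 2D, and 3D inputs share one code path. This hypothetical
   NumPy sketch treats ``True`` in the mask as "use this point", supports only a
   multiplicative ``beta``, and drops the constant normalization terms.

   .. code-block:: python

      import numpy as np


      def masked_log_likelihood_sketch(data, model, mask, sigma, beta=1.0):
          """Sum of -0.5 * ((data - model) / (beta * sigma))**2 over all unmasked points."""
          residual = (np.asarray(data, dtype=float) - np.asarray(model, dtype=float)) / (
              beta * np.asarray(sigma, dtype=float))
          contribution = np.where(mask, -0.5 * residual ** 2, 0.0)
          # With no axis argument, np.sum reduces over every dimension, so no
          # per-dimensionality branching is needed.
          return float(np.sum(contribution))


      data_1d = np.arange(12.0)
      model_1d = data_1d + 1.0  # every residual equals -1
      mask_1d = np.ones(12, dtype=bool)
      sigma_1d = np.ones(12)

      for shape in [(12,), (3, 4), (2, 3, 2)]:
          print(shape, masked_log_likelihood_sketch(data_1d.reshape(shape), model_1d.reshape(shape),
                                                    mask_1d.reshape(shape), sigma_1d.reshape(shape)))
      # (12,) -6.0
      # (3, 4) -6.0
      # (2, 3, 2) -6.0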


.. py:function:: validate_model_output(spectrum_model: Optional[jax.numpy.ndarray], wavelengths_model: Optional[jax.numpy.ndarray]) -> bool

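   Validate model output, checking for ``None`` values and NaNs.

   Args:
       spectrum_model: Model spectrum (may be ``None``)
       wavelengths_model: Model wavelength array (may be ``None``)

   Returns:
       Tuple of (is_valid, error_message)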
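
   A minimal sketch following the documented ``(is_valid, error_message)`` return
   value. The specific checks and message strings are assumptions for illustration,
   not petitRADTRANS's behaviour.

   .. code-block:: python

      from typing import Optional, Tuple

      import numpy as np


      def validate_model_output_sketch(spectrum_model, wavelengths_model) -> Tuple[bool, Optional[str]]:
          """Reject missing outputs and spectra containing NaNs (wavelengths are only checked for None)."""
          if spectrum_model is None or wavelengths_model is None:
              return False, "model returned None"
          if np.any(np.isnan(np.asarray(spectrum_model, dtype=float))):
              return False, "model spectrum contains NaN"
          return True, None


      print(validate_model_output_sketch([1.0, 2.0], [0.5, 0.6]))           # (True, None)
      print(validate_model_output_sketch([1.0, float("nan")], [0.5, 0.6]))  # (False, 'model spectrum contains NaN')
      print(validate_model_output_sketch(None, [0.5, 0.6]))                 # (False, 'model returned None')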


.. py:function:: ensure_jax_array(array: Union[numpy.ndarray, jax.numpy.ndarray]) -> jax.numpy.ndarray

   Ensure array is in JAX format for JIT compilation.

   Args:
       array: Input array (numpy or JAX)

   Returns:
       JAX array


