petitRADTRANS.retrieval.runtime
===============================

.. py:module:: petitRADTRANS.retrieval.runtime

.. autoapi-nested-parse::

   Runtime data model and evaluation helpers for retrievals.

   This module defines the runtime-facing representation used by the retrieval
   package when executing sampler evaluations, especially for JAX-compatible and
   sampler-agnostic workflows.

   The core responsibilities of the module are:

   - represent parameter values in a pytree-friendly form
   - normalize observations and likelihood inputs into immutable runtime state
   - group shared model evaluations across observations
   - evaluate retrieval likelihoods through scalar, JAX, and batched JAX paths
   - translate between sampler coordinates and the legacy ``Parameter`` objects

   The classes in this file are intentionally small, immutable containers with a
   few conversion helpers. Together they provide the bridge between the legacy
   retrieval configuration objects and the newer runtime-native execution model.



Classes
-------

.. autoapisummary::

   petitRADTRANS.retrieval.runtime.PhysicalParams
   petitRADTRANS.retrieval.runtime.ObservationState
   petitRADTRANS.retrieval.runtime.LikelihoodState
   petitRADTRANS.retrieval.runtime.ModelContext
   petitRADTRANS.retrieval.runtime.ModelInputs
   petitRADTRANS.retrieval.runtime.ModelResult
   petitRADTRANS.retrieval.runtime.ModelGroupState
   petitRADTRANS.retrieval.runtime.RetrievalRuntime
   petitRADTRANS.retrieval.runtime.ParameterLayout


Module Contents
---------------

.. py:class:: PhysicalParams

   Immutable parameter values used by the runtime evaluator.

   The runtime stores free and fixed parameter values separately so that the
   dynamic portion can participate cleanly in JAX pytree flattening while the
   original parameter ordering remains preserved.
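
   A minimal sketch of that split, using only the standard library. The class
   name ``SplitParams`` is a hypothetical stand-in, not the actual
   implementation; a real pytree class would additionally be registered with
   ``jax.tree_util.register_pytree_node_class``.

   .. code-block:: python

      from dataclasses import dataclass
      from typing import Any


      @dataclass(frozen=True)
      class SplitParams:
          """Hypothetical container with a dynamic/static split."""

          ordered_names: tuple[str, ...]
          dynamic_names: tuple[str, ...]
          dynamic_values: tuple[Any, ...]
          static_items: tuple[tuple[str, Any], ...] = ()

          def tree_flatten(self):
              # Children are the traced leaves; everything else is static aux data.
              aux_data = (self.ordered_names, self.dynamic_names, self.static_items)
              return self.dynamic_values, aux_data

          @classmethod
          def tree_unflatten(cls, aux_data, children):
              ordered_names, dynamic_names, static_items = aux_data
              return cls(ordered_names, dynamic_names, tuple(children), static_items)

          def as_dict(self) -> dict[str, Any]:
              # Merge both groups, then restore the original ordering.
              merged = dict(zip(self.dynamic_names, self.dynamic_values))
              merged.update(self.static_items)
              return {name: merged[name] for name in self.ordered_names}


      params = SplitParams(
          ordered_names=("temperature", "log_g", "radius"),
          dynamic_names=("temperature", "radius"),
          dynamic_values=(1200.0, 1.1),
          static_items=(("log_g", 3.5),),
      )
      children, aux = params.tree_flatten()
      rebuilt = SplitParams.tree_unflatten(aux, children)

   Only ``dynamic_values`` ends up among the children, so a fixed parameter
   such as ``log_g`` in this example is never traced.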


   .. py:attribute:: ordered_names
      :type:  tuple[str, ...]


   .. py:attribute:: dynamic_names
      :type:  tuple[str, ...]


   .. py:attribute:: dynamic_values
      :type:  tuple[Any, ...]


   .. py:attribute:: static_items
      :type:  tuple[tuple[str, Any], ...]
      :value: ()



   .. py:method:: from_mapping(mapping: Mapping[str, Any]) -> PhysicalParams
      :classmethod:


      Build runtime parameters from an ordered name-to-value mapping.



   .. py:method:: tree_flatten()

      Return the pytree representation used by JAX transforms.



   .. py:method:: tree_unflatten(aux_data, children)
      :classmethod:


      Reconstruct a ``PhysicalParams`` instance from pytree parts.



   .. py:method:: as_dict() -> dict[str, Any]

      Return all parameters as a plain dictionary in original order.



   .. py:method:: get(name: str, default: Any = None) -> Any

      Return a parameter value by name with an optional default.



   .. py:method:: __getitem__(name: str) -> Any

      Return a parameter value by name, raising ``KeyError`` if missing.



   .. py:method:: __contains__(name: object) -> bool

      Return ``True`` if *name* is a known parameter.



   .. py:method:: keys()

      Return the parameter names in original order.



   .. py:method:: ordered_items() -> tuple[tuple[str, Any], ...]

      Return ordered ``(name, value)`` pairs for all parameters.



   .. py:method:: as_legacy_dict() -> dict[str, Any]

      Return parameters wrapped as ``SimpleNamespace(value=...)`` objects.

      This provides backward compatibility with shared helper functions
      (chemistry, clouds, pressure initialization) that expect the legacy
      ``parameters['name'].value`` access pattern.
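
      A minimal sketch of the wrapping, assuming plain values in an ordinary
      dictionary. The helper name ``wrap_legacy`` is hypothetical.

      .. code-block:: python

         from types import SimpleNamespace


         def wrap_legacy(values: dict) -> dict:
             # Wrap each value so that legacy code can read ``entry.value``.
             return {name: SimpleNamespace(value=value) for name, value in values.items()}


         parameters = wrap_legacy({"temperature": 1200.0, "log_g": 3.5})
         temperature = parameters["temperature"].value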



.. py:class:: ObservationState

   Immutable observation description used by the runtime.

   This mirrors the pieces of a retrieval ``Data`` object that are needed when
   evaluating models and likelihoods, while stripping away mutable behavior.


   .. py:attribute:: name
      :type:  str


   .. py:attribute:: wavelengths
      :type:  Any


   .. py:attribute:: spectrum
      :type:  Any


   .. py:attribute:: mask
      :type:  Any


   .. py:attribute:: uncertainties
      :type:  Any
      :value: None



   .. py:attribute:: covariance
      :type:  Any
      :value: None



   .. py:attribute:: radtrans_object
      :type:  Any
      :value: None



   .. py:attribute:: model_generating_function
      :type:  Any
      :value: None



   .. py:attribute:: model_resolution
      :type:  Any
      :value: None



   .. py:attribute:: data_resolution
      :type:  Any
      :value: None



   .. py:attribute:: data_resolution_array_model
      :type:  Any
      :value: None



   .. py:attribute:: wavelength_bin_widths
      :type:  Any
      :value: None



   .. py:attribute:: external_radtrans_reference
      :type:  str | None
      :value: None



   .. py:attribute:: line_opacity_mode
      :type:  str | None
      :value: None



   .. py:attribute:: model_contract
      :type:  str
      :value: 'legacy'



   .. py:attribute:: scale_flux
      :type:  bool
      :value: False



   .. py:attribute:: scale_uncertainties
      :type:  bool
      :value: False



   .. py:attribute:: fit_flux_offset
      :type:  bool
      :value: False



   .. py:attribute:: fit_instrumental_resolution
      :type:  bool
      :value: False



   .. py:attribute:: subtract_continuum
      :type:  bool
      :value: False



   .. py:attribute:: photometric_transformation_function
      :type:  Any
      :value: None



   .. py:attribute:: fit_covariance
      :type:  bool
      :value: False



   .. py:attribute:: covariance_mode
      :type:  str
      :value: 'none'



   .. py:attribute:: global_covariance_kernel
      :type:  str
      :value: 'squared_exponential'



   .. py:attribute:: local_covariance_kernel
      :type:  str
      :value: 'squared_exponential'



   .. py:attribute:: n_local_covariance_kernels
      :type:  int
      :value: 0



   .. py:attribute:: covariance_jitter
      :type:  float
      :value: 0.0



   .. py:attribute:: variability_atmospheric_column_model_flux_return_mode
      :type:  bool
      :value: False



   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


.. py:class:: LikelihoodState

   Precomputed likelihood inputs associated with one observation.

   The runtime stores the uncertainty representation separately from the raw
   observation so samplers can reuse the same normalized bookkeeping across
   repeated evaluations.


   .. py:attribute:: data_uncertainties
      :type:  Any
      :value: None



   .. py:attribute:: data_covariance
      :type:  Any
      :value: None



   .. py:attribute:: mask
      :type:  Any
      :value: None



   .. py:attribute:: inv_covariance
      :type:  Any
      :value: None



   .. py:attribute:: covariance_cholesky_factor
      :type:  Any
      :value: None



   .. py:attribute:: log_covariance_determinant
      :type:  Any
      :value: None



   .. py:attribute:: log_likelihood_normalization
      :type:  Any
      :value: None



   .. py:attribute:: n_free_parameters
      :type:  int
      :value: 0



   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


.. py:class:: ModelContext

   Execution context passed to runtime-native model functions.

   A runtime-native model receives this context instead of the legacy mutable
   ``Parameter`` dictionary plus a collection of loosely coupled flags.


   .. py:attribute:: name
      :type:  str


   .. py:attribute:: mode
      :type:  str


   .. py:attribute:: radtrans
      :type:  petitRADTRANS.radtrans.Radtrans


   .. py:attribute:: adaptive_mesh_refinement
      :type:  bool


   .. py:attribute:: variability_atmospheric_column_model_flux_return_mode
      :type:  bool
      :value: False



   .. py:attribute:: return_contribution
      :type:  bool
      :value: False



   .. py:attribute:: return_opacities
      :type:  bool
      :value: False



   .. py:attribute:: return_photosphere_radius
      :type:  bool
      :value: False



   .. py:attribute:: return_rosseland_optical_depths
      :type:  bool
      :value: False



   .. py:attribute:: return_radius_hydrostatic_equilibrium
      :type:  bool
      :value: False



   .. py:attribute:: return_cloud_contribution
      :type:  bool
      :value: False



   .. py:attribute:: return_abundances
      :type:  bool
      :value: False



   .. py:attribute:: model_metadata
      :type:  Mapping[str, Any]


.. py:class:: ModelInputs

   Structured physical inputs consumed by low-level spectrum calculators.

   Runtime-native model builders populate this container once they have
   derived thermodynamic, chemical, and cloud state from the current
   parameters.


   .. py:attribute:: pressures
      :type:  Any


   .. py:attribute:: temperatures
      :type:  Any


   .. py:attribute:: abundances
      :type:  Mapping[str, Any]


   .. py:attribute:: mean_molar_masses
      :type:  Any


   .. py:attribute:: gravity
      :type:  Any


   .. py:attribute:: planet_radius
      :type:  Any


   .. py:attribute:: reference_pressure
      :type:  Any
      :value: 0.1



   .. py:attribute:: opaque_cloud_top_pressure
      :type:  Any
      :value: None



   .. py:attribute:: sigma_lnorm
      :type:  Any
      :value: None



   .. py:attribute:: cloud_particle_mean_radii
      :type:  Any
      :value: None



   .. py:attribute:: cloud_f_sed
      :type:  Any
      :value: None



   .. py:attribute:: eddy_diffusion_coefficients
      :type:  Any
      :value: None



   .. py:attribute:: haze_factor
      :type:  Any
      :value: 1.0



   .. py:attribute:: power_law_opacity_coefficient
      :type:  Any
      :value: None



   .. py:attribute:: power_law_opacity_350nm
      :type:  Any
      :value: None



   .. py:attribute:: cloud_hansen_b
      :type:  Any
      :value: None



   .. py:attribute:: cloud_fraction
      :type:  Any
      :value: 1.0



   .. py:attribute:: patchy_clouds
      :type:  tuple[str, ...] | None
      :value: None



   .. py:attribute:: cloud_particle_radius_distribution
      :type:  str | None
      :value: None



   .. py:attribute:: distance_to_system
      :type:  Any
      :value: None



   .. py:attribute:: stellar_radius
      :type:  Any
      :value: None



   .. py:attribute:: disk_blackbody_temperature
      :type:  Any
      :value: None



   .. py:attribute:: disk_radius
      :type:  Any
      :value: None



   .. py:attribute:: v_band_extinction
      :type:  Any
      :value: None



   .. py:attribute:: v_band_reddening
      :type:  Any
      :value: None



   .. py:attribute:: additional_inputs
      :type:  Mapping[str, Any]


   .. py:method:: __post_init__()


   .. py:method:: tree_flatten()

      Return the pytree representation used by JAX transforms.



   .. py:method:: tree_unflatten(aux_data, children)
      :classmethod:


      Reconstruct a ``ModelInputs`` instance from pytree parts.



.. py:class:: ModelResult

   Normalized output emitted by runtime model evaluation.

   The result can represent a valid spectrum, a pressure-temperature profile,
   or an invalid evaluation. Named fields replace the legacy tuple protocol so
   downstream code can handle auxiliary outputs consistently.

   For JIT compatibility, invalid results carry NaN-filled spectrum arrays of
   the expected shape rather than ``None`` so that downstream tracing always
   sees consistent types.

   The class is registered as a JAX pytree so that ``jax.vmap`` and ``jax.jit``
   can propagate batched spectra through model evaluation pipelines. The
   ``kind`` field and the ``auxiliary_outputs`` keys are treated as static
   (``aux_data``); all array-valued fields, including the values inside
   ``auxiliary_outputs``, are dynamic leaves.


   .. py:attribute:: kind
      :type:  str


   .. py:attribute:: wavelengths
      :type:  Any
      :value: None



   .. py:attribute:: spectrum
      :type:  Any
      :value: None



   .. py:attribute:: pressures
      :type:  Any
      :value: None



   .. py:attribute:: temperatures
      :type:  Any
      :value: None



   .. py:attribute:: additional_log_likelihood
      :type:  Any
      :value: 0.0



   .. py:attribute:: auxiliary_outputs
      :type:  Mapping[str, Any]


   .. py:attribute:: atmospheric_column_fluxes
      :type:  Any
      :value: None



   .. py:method:: tree_flatten()

      Flatten the result for JAX pytree operations (``jit``, ``vmap``, ``grad``,
      and others).

      ``kind`` and the keys of ``auxiliary_outputs`` go into ``aux_data``
      (static — must be identical across a vmap batch).  All array-valued
      fields, including the *values* of ``auxiliary_outputs``, are leaves.
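
      The split can be sketched as follows. The function names and the reduced
      field set are hypothetical, and sorting the keys is an assumption of this
      sketch, made so that the static structure is deterministic.

      .. code-block:: python

         def flatten_result(kind, spectrum, auxiliary_outputs):
             # Keys are static structure; values travel with the leaves.
             keys = tuple(sorted(auxiliary_outputs))
             children = (spectrum, *(auxiliary_outputs[key] for key in keys))
             aux_data = (kind, keys)
             return children, aux_data


         def unflatten_result(aux_data, children):
             kind, keys = aux_data
             spectrum, *values = children
             return kind, spectrum, dict(zip(keys, values))


         children, aux_data = flatten_result(
             "spectrum", [1.0, 2.0], {"radius": 1.1, "contribution": [0.2, 0.8]}
         )
         restored = unflatten_result(aux_data, children)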



   .. py:method:: tree_unflatten(aux_data, children)
      :classmethod:


      Reconstruct a ``ModelResult`` from pytree parts.



   .. py:method:: make_nan_result(wavelengths: Any) -> ModelResult
      :staticmethod:


      Create an invalid result with NaN-filled spectrum of matching shape.

      ``wavelengths`` should be a JAX array (typically from the Radtrans
      wavelength grid).  The returned result has ``kind='invalid'`` for
      backward compatibility, but the all-NaN spectrum is the primary
      JIT-safe sentinel.
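
      A standard-library sketch of the sentinel idea, with Python lists standing
      in for JAX arrays. Both helper names are hypothetical.

      .. code-block:: python

         import math


         def make_nan_spectrum(wavelengths):
             # Same length as the wavelength grid, so shapes stay consistent.
             return [math.nan for _ in wavelengths]


         def spectrum_has_nan(spectrum):
             # The sentinel is detected from the values, not from a string tag.
             return any(math.isnan(value) for value in spectrum)


         wavelengths = [1.0, 1.5, 2.0]
         invalid = make_nan_spectrum(wavelengths)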



   .. py:method:: spectrum_is_valid() -> Any

      Return a JIT-safe flag that is true when ``spectrum`` is a real array
      containing no NaN values.



.. py:class:: ModelGroupState

   Description of one shared model evaluation group.

   Multiple observations can reuse the same underlying model call when they
   point at the same radiative-transfer source. The runtime tracks that shared
   evaluation through this grouping object.
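
   A minimal sketch of such grouping, assuming each observation either names
   an external source or acts as its own source. The function and its input
   mapping are hypothetical.

   .. code-block:: python

      def group_by_source(external_references: dict) -> dict:
          """Map each source name to the observations that reuse its model."""
          groups: dict = {}
          for name, reference in external_references.items():
              # An observation without a reference evaluates its own model.
              source = reference if reference is not None else name
              groups.setdefault(source, []).append(name)
          return {source: tuple(names) for source, names in groups.items()}


      groups = group_by_source({"nirspec": None, "miri": None, "photometry": "nirspec"})

   Here ``photometry`` reuses the model computed for ``nirspec``, so only two
   model evaluations are needed for three observations.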


   .. py:attribute:: source_name
      :type:  str


   .. py:attribute:: observation_names
      :type:  tuple[str, ...]


   .. py:attribute:: model_contract
      :type:  str


   .. py:attribute:: model_generating_function
      :type:  Any
      :value: None



   .. py:attribute:: radtrans_object
      :type:  Any
      :value: None



   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


   .. py:attribute:: model_context
      :type:  Any
      :value: None



.. py:class:: RetrievalRuntime

   Immutable runtime view of a retrieval configuration.

   ``RetrievalRuntime`` centralizes the normalized observation graph, model
   grouping, and likelihood evaluation code used by scalar and JAX-backed
   samplers.


   .. py:attribute:: parameter_layout
      :type:  ParameterLayout


   .. py:attribute:: observations
      :type:  Mapping[str, ObservationState]


   .. py:attribute:: likelihoods
      :type:  Mapping[str, LikelihoodState]


   .. py:attribute:: model_groups
      :type:  tuple[ModelGroupState, ...]
      :value: ()



   .. py:attribute:: adaptive_mesh_refinement
      :type:  bool
      :value: False



   .. py:method:: from_data(parameter_layout: ParameterLayout, data: Mapping[str, Any], n_free_parameters: int | None = None, adaptive_mesh_refinement: bool = False) -> RetrievalRuntime
      :classmethod:


      Build a runtime from retrieval data objects.

      This normalizes observation state, resolves model contracts, and groups
      observations that share a single model evaluation.



   .. py:method:: from_configuration(configuration: Any, parameter_layout: ParameterLayout | None = None, n_free_parameters: int | None = None) -> RetrievalRuntime
      :classmethod:


      Build a runtime directly from a retrieval configuration object.



   .. py:method:: _evaluate_model_group(group: ModelGroupState, physical_params: PhysicalParams) -> ModelResult


   .. py:method:: _score_dataset(data_name: str, parameters: Mapping[str, Any], wavelengths_model: Any, spectrum_model: Any, additional_log_likelihood: Any, model_auxiliary_outputs: Mapping[str, Any] | None = None) -> float

      Score one observation against a model spectrum using ``LogLikelihood``.

      This uses the pre-built ``ObservationState`` stored on the runtime
      rather than reaching back into mutable ``Data`` objects.

      The method decomposes the work into three composable stages:

      1. **Prepare observation data** — apply flux scaling, offset, and
         uncertainty inflation (``_apply_flux_scaling``, ``_offset_flux``,
         ``_scale_and_inflate_uncertainties``).
      2. **Project model** — convolve, rebin, and RV-shift the model onto the
         observation wavelength grid (``project_model_to_observation_space``).
      3. **Compute log-likelihood** — delegate to
         ``compute_observation_log_likelihood`` which handles NaN-sentinel
         detection internally.

      Multi-dimensional observations (2-D orders, 3-D detectors×orders, and
      object-dtype ragged arrays) iterate over the static outer dimensions
      before invoking the same composable primitives.
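
      The three stages can be sketched with the standard library as below. All
      function names are hypothetical, the form of the flux scaling and offset
      is an assumption, the projection is reduced to linear interpolation, and
      the likelihood assumes independent Gaussian uncertainties. The actual
      helpers also cover convolution, rebinning, RV shifts, and covariance.

      .. code-block:: python

         import math


         def prepare(flux, sigma, scale=1.0, offset=0.0):
             # Stage 1 (assumed form): scale and offset the data, scale the errors.
             return [scale * f + offset for f in flux], [scale * s for s in sigma]


         def project(wl_model, f_model, wl_obs):
             # Stage 2: linear interpolation onto the observed wavelengths.
             out = []
             for w in wl_obs:
                 for i in range(len(wl_model) - 1):
                     lo, hi = wl_model[i], wl_model[i + 1]
                     if lo <= w <= hi:
                         t = (w - lo) / (hi - lo)
                         out.append((1.0 - t) * f_model[i] + t * f_model[i + 1])
                         break
                 else:
                     out.append(math.nan)  # outside the model grid
             return out


         def log_likelihood(flux, sigma, model):
             # Stage 3: independent Gaussian log-likelihood.
             return sum(
                 -0.5 * ((f - m) / s) ** 2 - 0.5 * math.log(2.0 * math.pi * s * s)
                 for f, s, m in zip(flux, sigma, model)
             )


         flux, sigma = prepare([1.0, 2.0], [0.1, 0.1])
         model = project([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0])
         value = log_likelihood(flux, sigma, model)

      Because the model matches the data exactly in this example, only the
      normalization term contributes to ``value``.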



   .. py:method:: evaluate_scalar(physical_params: PhysicalParams, *, uncertainties_mode: str = 'default', print_log_likelihood_for_debugging: bool = False) -> float

      Evaluate the posterior contribution for one parameter point.

      This path is primarily used by non-JAX samplers and debugging code.
      It returns the log likelihood plus any runtime-provided prior weight.



   .. py:method:: evaluate_jax(physical_params: PhysicalParams, *, uncertainties_mode: str = 'default', print_log_likelihood_for_debugging: bool = False) -> float

      Evaluate the posterior through a JAX-compatible execution path.

      All control flow uses ``lax.cond`` / ``jnp.where`` so the function
      is safe to pass through ``jax.jit`` and ``jax.grad``.  Invalid model
      evaluations are detected via the NaN-spectrum sentinel produced by
      ``ModelResult.make_nan_result`` rather than Python-level string
      comparisons.



   .. py:method:: evaluate_vectorized_jax(free_parameter_values: Any, *, uncertainties_mode: str = 'default', print_log_likelihood_for_debugging: bool = False) -> Any

      Evaluate one or many parameter vectors with JAX vectorization.

      A one-dimensional input is evaluated as a single sample. Higher-rank
      inputs are mapped over the leading dimension with ``vmap``.
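
      The dispatch on input rank can be sketched without JAX as below. The
      function names are hypothetical, and a list comprehension stands in for
      ``vmap`` over the leading dimension.

      .. code-block:: python

         def evaluate_one_or_many(evaluate_single, values):
             # A flat sequence of numbers is treated as a single sample.
             if all(isinstance(v, (int, float)) for v in values):
                 return evaluate_single(values)
             # Otherwise map over the leading dimension.
             return [evaluate_one_or_many(evaluate_single, row) for row in values]


         def toy_log_likelihood(sample):
             # Placeholder objective used only for this illustration.
             return -0.5 * sum(x * x for x in sample)


         single = evaluate_one_or_many(toy_log_likelihood, [1.0, 2.0])
         batch = evaluate_one_or_many(toy_log_likelihood, [[1.0, 2.0], [0.0, 0.0]])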



   .. py:method:: evaluate_jax_batched(free_parameter_values: Any, *, uncertainties_mode: str = 'default', print_log_likelihood_for_debugging: bool = False) -> Any

      Evaluate a batch of parameter vectors by vmapping the scalar JAX path.

      For model groups using ``MODEL_CONTRACT_DIFFERENTIABLE`` this method
      applies ``jax.vmap`` to the single-sample :meth:`evaluate_jax` path.
      This keeps model generation and scoring adjacent inside each mapped
      sample, which avoids retaining a fully materialized batched
      ``ModelResult`` tree before likelihood accumulation.

      For non-native groups the method falls back to
      :meth:`evaluate_vectorized_jax`, which applies a single vmap over the
      full ``evaluate_jax`` pipeline.

      A one-dimensional input is handled as a single sample (no batching).
      Higher-rank inputs are mapped over the leading dimension.

      :param free_parameter_values: Array of shape ``(n_samples, n_free_params)``,
          or ``(n_free_params,)`` for a single sample.
      :returns: Log-posterior contributions with shape ``(n_samples,)``, or a
          scalar for a single sample.



   .. py:method:: _project_dataset_observation(data_name: str, parameters: PhysicalParams, wavelengths_model: Any, spectrum_model: Any) -> Any

      Project one model spectrum into the observation space of a dataset.



.. py:class:: ParameterLayout

   Translation layer between sampler coordinates and runtime parameters.

   ``ParameterLayout`` owns the ordered parameter metadata needed to transform
   from unit-cube or unconstrained sampler coordinates into physical values,
   and to reconstruct legacy ``Parameter`` dictionaries when required.


   .. py:attribute:: parameter_names
      :type:  tuple[str, ...]


   .. py:attribute:: free_parameter_names
      :type:  tuple[str, ...]


   .. py:attribute:: free_parameter_indices
      :type:  tuple[int, ...]


   .. py:attribute:: parameter_templates
      :type:  tuple[petitRADTRANS.retrieval.parameter.Parameter, ...]


   .. py:attribute:: cube_epsilon
      :type:  float


   .. py:method:: from_configuration(configuration: Any) -> ParameterLayout
      :classmethod:


      Build a parameter layout from a retrieval configuration.



   .. py:method:: from_parameters(parameters: Mapping[str, petitRADTRANS.retrieval.parameter.Parameter]) -> ParameterLayout
      :classmethod:


      Build a parameter layout from an ordered parameter mapping.



   .. py:property:: n_free_parameters
      :type: int


      Return the number of free parameters in the layout.



   .. py:method:: _transform_parameter(parameter: petitRADTRANS.retrieval.parameter.Parameter, cube_value: Any) -> Any


   .. py:method:: _validate_free_parameter_count(values: numpy.typing.ArrayLike) -> numpy.ndarray


   .. py:method:: prior_transform_numpy(cube: numpy.typing.ArrayLike) -> numpy.ndarray

      Transform unit-cube samples into physical parameter values.

      Both single samples and batched samples are supported. The return value
      is always a NumPy array to keep non-JAX sampler interfaces simple.
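
      A standard-library sketch of a unit-cube transform for uniform priors.
      The bounds, the function name, and the restriction to uniform priors are
      assumptions of this sketch; the actual transform is defined by each
      ``Parameter``, and lists stand in for NumPy arrays.

      .. code-block:: python

         def uniform_prior_transform(cube, bounds):
             """Map unit-cube coordinates onto ``[low, high]`` intervals."""
             if cube and isinstance(cube[0], (list, tuple)):
                 # Batched input: transform each sample independently.
                 return [uniform_prior_transform(row, bounds) for row in cube]
             return [low + u * (high - low) for u, (low, high) in zip(cube, bounds)]


         bounds = [(500.0, 2500.0), (2.0, 6.0)]
         single = uniform_prior_transform([0.5, 0.25], bounds)
         batch = uniform_prior_transform([[0.0, 0.0], [1.0, 1.0]], bounds)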



   .. py:method:: prior_transform_inplace(cube: numpy.typing.ArrayLike) -> numpy.typing.ArrayLike

      Apply :meth:`prior_transform_numpy` and write the result back in place.



   .. py:method:: transform_cube_to_physical(cube_position: Any) -> jax.numpy.ndarray

      Transform one unit-cube position into JAX physical coordinates.



   .. py:method:: transform_cube_to_physical_batch(cube_positions: Any) -> jax.numpy.ndarray

      Transform one or many unit-cube positions into physical coordinates.



   .. py:method:: cube_from_unconstrained(unconstrained_position: Any) -> jax.numpy.ndarray

      Map unconstrained coordinates onto the open unit cube.
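
      One common way to build such a map is the logistic function. The sketch
      below assumes that choice, and assumes that ``cube_epsilon`` keeps values
      away from the cube edges; neither is confirmed by this page. The function
      name ``logistic_to_cube`` is hypothetical.

      .. code-block:: python

         import math


         def logistic_to_cube(x, cube_epsilon=1e-9):
             # Numerically stable logistic function.
             if x >= 0.0:
                 u = 1.0 / (1.0 + math.exp(-x))
             else:
                 e = math.exp(x)
                 u = e / (1.0 + e)
             # Clip so the result stays strictly inside (0, 1).
             return min(max(u, cube_epsilon), 1.0 - cube_epsilon)


         center = logistic_to_cube(0.0)
         far_right = logistic_to_cube(800.0)
         far_left = logistic_to_cube(-800.0)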



   .. py:method:: cube_from_unconstrained_batch(unconstrained_positions: Any) -> jax.numpy.ndarray

      Map one or many unconstrained positions onto the open unit cube.



   .. py:method:: transform_unconstrained(unconstrained_position: Any) -> jax.numpy.ndarray

      Transform unconstrained coordinates directly into physical values.



   .. py:method:: transform_unconstrained_batch(unconstrained_positions: Any) -> jax.numpy.ndarray

      Transform one or many unconstrained positions into physical values.



   .. py:method:: log_jacobian_single(unconstrained_position: Any) -> jax.numpy.ndarray

      Return the log-Jacobian term for one unconstrained position.
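
      As an illustration only, assume the unconstrained-to-cube map is the
      logistic function ``u = 1 / (1 + exp(-x))``. Its derivative is
      ``u * (1 - u)``, so the log-Jacobian of that step is the sum of
      ``log(u) + log(1 - u)`` over coordinates. The helper names are
      hypothetical, and any contribution from the cube-to-physical transform is
      left out of this sketch.

      .. code-block:: python

         import math


         def log_sigmoid(x):
             # Stable evaluation of log(1 / (1 + exp(-x))).
             if x >= 0.0:
                 return -math.log1p(math.exp(-x))
             return x - math.log1p(math.exp(x))


         def logistic_log_jacobian(position):
             # 1 - sigmoid(x) equals sigmoid(-x), so both terms reuse log_sigmoid.
             return sum(log_sigmoid(x) + log_sigmoid(-x) for x in position)


         value = logistic_log_jacobian([0.0, 0.0])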



   .. py:method:: log_jacobian(unconstrained_positions: Any) -> jax.numpy.ndarray

      Return log-Jacobian terms for one or many unconstrained positions.



   .. py:method:: physical_params_from_vector(values: numpy.typing.ArrayLike) -> PhysicalParams

      Build ``PhysicalParams`` from a flat vector of free values.
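
      A sketch of how a flat vector of free values could be scattered back into
      the full, ordered parameter set. The helper and its arguments are
      hypothetical; they mirror ``parameter_names`` and
      ``free_parameter_indices``.

      .. code-block:: python

         def merge_free_values(parameter_names, free_indices, fixed_values, free_values):
             if len(free_values) != len(free_indices):
                 raise ValueError(
                     f"expected {len(free_indices)} free values, got {len(free_values)}"
                 )
             merged = dict(fixed_values)
             for index, value in zip(free_indices, free_values):
                 merged[parameter_names[index]] = value
             # Rebuild in the original parameter order.
             return {name: merged[name] for name in parameter_names}


         full = merge_free_values(
             parameter_names=("temperature", "log_g", "radius"),
             free_indices=(0, 2),
             fixed_values={"log_g": 3.5},
             free_values=[1200.0, 1.1],
         )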



   .. py:method:: physical_params_from_matrix(values: numpy.typing.ArrayLike) -> tuple[PhysicalParams, ...]

      Build ``PhysicalParams`` objects from a batch of free-parameter vectors.



   .. py:method:: physical_params_from_free_values(values: Any) -> PhysicalParams

      Build ``PhysicalParams`` from ordered free-parameter values.



   .. py:method:: physical_params_from_cube(cube_position: numpy.typing.ArrayLike) -> PhysicalParams

      Build ``PhysicalParams`` from a unit-cube position.



   .. py:method:: physical_params_from_unconstrained(unconstrained_position: Any) -> PhysicalParams

      Build ``PhysicalParams`` from unconstrained sampler coordinates.



   .. py:method:: free_parameter_vector(physical_params: PhysicalParams) -> jax.numpy.ndarray

      Extract the ordered free-parameter vector from ``PhysicalParams``.



   .. py:method:: build_legacy_parameter_dict(physical_params: PhysicalParams) -> dict[str, petitRADTRANS.retrieval.parameter.Parameter]

      Reconstruct legacy ``Parameter`` objects for compatibility paths.



   .. py:method:: serialize_metadata() -> dict[str, Any]

      Return a deterministic metadata payload describing the layout.



