petitRADTRANS.retrieval.runtime
Runtime data model and evaluation helpers for retrievals.
This module defines the runtime-facing representation used by the retrieval package when executing sampler evaluations, especially for JAX-compatible and sampler-agnostic workflows.
The core responsibilities of the module are:
- represent parameter values in a pytree-friendly form
- normalize observations and likelihood inputs into immutable runtime state
- group shared model evaluations across observations
- evaluate retrieval likelihoods through scalar, JAX, and batched JAX paths
- translate between sampler coordinates and the legacy Parameter objects
The classes in this file are intentionally small, immutable containers with a few conversion helpers. Together they provide the bridge between the legacy retrieval configuration objects and the newer runtime-native execution model.
Classes

| Class | Description |
|---|---|
| PhysicalParams | Immutable parameter values used by the runtime evaluator. |
| ObservationState | Immutable observation description used by the runtime. |
| LikelihoodState | Precomputed likelihood inputs associated with one observation. |
| ModelContext | Execution context passed to runtime-native model functions. |
| ModelInputs | Structured physical inputs consumed by low-level spectrum calculators. |
| ModelResult | Normalized output emitted by runtime model evaluation. |
| ModelGroupState | Description of one shared model evaluation group. |
| RetrievalRuntime | Immutable runtime view of a retrieval configuration. |
| ParameterLayout | Translation layer between sampler coordinates and runtime parameters. |
Module Contents
- class petitRADTRANS.retrieval.runtime.PhysicalParams
Immutable parameter values used by the runtime evaluator.
The runtime stores free and fixed parameter values separately so that the dynamic portion can participate cleanly in JAX pytree flattening while the original parameter ordering remains preserved.
- ordered_names: tuple[str, ...]
- dynamic_names: tuple[str, ...]
- dynamic_values: tuple[Any, ...]
- static_items: tuple[tuple[str, Any], ...] = ()
- classmethod from_mapping(mapping: Mapping[str, Any]) -> PhysicalParams
Build runtime parameters from an ordered name-to-value mapping.
- tree_flatten()
Return the pytree representation used by JAX transforms.
- classmethod tree_unflatten(aux_data, children)
Reconstruct a PhysicalParams instance from pytree parts.
- as_dict() -> dict[str, Any]
Return all parameters as a plain dictionary in original order.
- get(name: str, default: Any = None) -> Any
Return a parameter value by name with an optional default.
- __getitem__(name: str) -> Any
Return a parameter value by name, raising KeyError if missing.
- __contains__(name: object) -> bool
Return True if name is a known parameter.
- keys()
Return the parameter names in original order.
- ordered_items() -> tuple[tuple[str, Any], ...]
Return ordered (name, value) pairs for all parameters.
- as_legacy_dict() -> dict[str, Any]
Return parameters wrapped as SimpleNamespace(value=...) objects. This provides backward compatibility with shared helper functions (chemistry, clouds, pressure initialization) that expect the legacy parameters['name'].value access pattern.
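The split between dynamic and static values is what lets PhysicalParams act as a JAX pytree. Only the free values become leaves, while names and fixed values travel as static auxiliary data. The following is a minimal pure-Python sketch of that contract, not the library's implementation. MiniPhysicalParams and its static_names argument are hypothetical. The real from_mapping takes only the mapping, and how it decides which values are static is not stated here. In JAX code a class like this would be decorated with jax.tree_util.register_pytree_node_class. That step is omitted so the sketch runs without JAX.

```python
from types import SimpleNamespace


class MiniPhysicalParams:
    """Hypothetical stand-in for PhysicalParams; not the library implementation."""

    def __init__(self, ordered_names, dynamic_names, dynamic_values, static_items=()):
        self.ordered_names = tuple(ordered_names)
        self.dynamic_names = tuple(dynamic_names)
        self.dynamic_values = tuple(dynamic_values)
        self.static_items = tuple(static_items)

    @classmethod
    def from_mapping(cls, mapping, static_names=()):
        # Hypothetical rule: names listed in static_names are fixed, the rest are free.
        dynamic = [(name, value) for name, value in mapping.items() if name not in static_names]
        static = [(name, value) for name, value in mapping.items() if name in static_names]
        return cls(
            ordered_names=tuple(mapping),
            dynamic_names=[name for name, _ in dynamic],
            dynamic_values=[value for _, value in dynamic],
            static_items=static,
        )

    def tree_flatten(self):
        # Leaves: the free values only. Aux data: names and fixed values (static under jit).
        aux_data = (self.ordered_names, self.dynamic_names, self.static_items)
        return self.dynamic_values, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        ordered_names, dynamic_names, static_items = aux_data
        return cls(ordered_names, dynamic_names, children, static_items)

    def as_dict(self):
        # Merge both halves, then restore the original parameter order.
        merged = dict(zip(self.dynamic_names, self.dynamic_values))
        merged.update(self.static_items)
        return {name: merged[name] for name in self.ordered_names}

    def as_legacy_dict(self):
        # Legacy helpers read parameters['name'].value, so wrap every value.
        return {name: SimpleNamespace(value=value) for name, value in self.as_dict().items()}
```

Only dynamic_values are leaves, so a jax.grad or jax.vmap over such an object would differentiate or batch the free values alone. The fixed values and the name ordering stay constant across the trace.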
- class petitRADTRANS.retrieval.runtime.ObservationState
Immutable observation description used by the runtime.
This mirrors the pieces of a retrieval Data object that are needed when evaluating models and likelihoods, while stripping away mutable behavior.
- name: str
- wavelengths: Any
- spectrum: Any
- mask: Any
- uncertainties: Any = None
- covariance: Any = None
- radtrans_object: Any = None
- model_generating_function: Any = None
- model_resolution: Any = None
- data_resolution: Any = None
- data_resolution_array_model: Any = None
- wavelength_bin_widths: Any = None
- external_radtrans_reference: str | None = None
- line_opacity_mode: str | None = None
- model_contract: str = 'legacy'
- scale_flux: bool = False
- scale_uncertainties: bool = False
- fit_flux_offset: bool = False
- fit_instrumental_resolution: bool = False
- subtract_continuum: bool = False
- photometric_transformation_function: Any = None
- fit_covariance: bool = False
- covariance_mode: str = 'none'
- global_covariance_kernel: str = 'squared_exponential'
- local_covariance_kernel: str = 'squared_exponential'
- n_local_covariance_kernels: int = 0
- covariance_jitter: float = 0.0
- variability_atmospheric_column_model_flux_return_mode: bool = False
- metadata: Mapping[str, Any]
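One way to get the immutability described above in plain Python is a frozen dataclass whose mapping fields are wrapped in read-only views. The sketch below assumes that approach. MiniObservationState is hypothetical, keeps only a handful of the documented fields, and is not the library's class. Changes are made by building a new instance with dataclasses.replace rather than by mutating the existing one.

```python
from dataclasses import dataclass, field, replace
from types import MappingProxyType
from typing import Any, Mapping


@dataclass(frozen=True)
class MiniObservationState:
    """Hypothetical, trimmed-down stand-in for ObservationState."""

    name: str
    wavelengths: tuple
    spectrum: tuple
    scale_flux: bool = False
    metadata: Mapping[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # frozen=True only blocks attribute assignment, so also freeze the mapping itself.
        # object.__setattr__ is the usual way to set a field inside a frozen dataclass.
        object.__setattr__(self, "metadata", MappingProxyType(dict(self.metadata)))


# Usage: derive a modified copy instead of mutating the original.
original = MiniObservationState("spectro", (1.0, 2.0), (0.5, 0.6), metadata={"instrument": "demo"})
scaled = replace(original, scale_flux=True)
```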
- class petitRADTRANS.retrieval.runtime.LikelihoodState
Precomputed likelihood inputs associated with one observation.
The runtime stores the uncertainty representation separately from the raw observation so samplers can reuse the same normalized bookkeeping across repeated evaluations.
- data_uncertainties: Any = None
- data_covariance: Any = None
- mask: Any = None
- inv_covariance: Any = None
- covariance_cholesky_factor: Any = None
- log_covariance_determinant: Any = None
- log_likelihood_normalization: Any = None
- n_free_parameters: int = 0
- metadata: Mapping[str, Any]
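Storing covariance_cholesky_factor and log_covariance_determinant ahead of time means each evaluation only needs a triangular solve for the residuals. The sketch below assumes the standard multivariate Gaussian log-likelihood, ln L = -1/2 r^T C^-1 r - 1/2 ln det C - (n/2) ln 2π, where r is the data-minus-model residual. How petitRADTRANS defines log_likelihood_normalization, or when it includes it, is not stated here, and all function names are hypothetical. The sketch uses only the standard library, so the Cholesky factorization is written out by hand.

```python
import math


def cholesky_lower(matrix):
    """Return lower-triangular L with L @ L^T == matrix (symmetric positive definite)."""
    n = len(matrix)
    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            partial = sum(lower[i][k] * lower[j][k] for k in range(j))
            if i == j:
                lower[i][j] = math.sqrt(matrix[i][i] - partial)
            else:
                lower[i][j] = (matrix[i][j] - partial) / lower[j][j]
    return lower


def precompute_likelihood_terms(covariance):
    """Hypothetical one-off setup: Cholesky factor, log det C, and normalization."""
    lower = cholesky_lower(covariance)
    # det C = prod(L_ii)^2, so log det C = 2 * sum(log L_ii).
    log_det = 2.0 * sum(math.log(lower[i][i]) for i in range(len(lower)))
    normalization = -0.5 * (log_det + len(lower) * math.log(2.0 * math.pi))
    return lower, log_det, normalization


def gaussian_log_likelihood(residuals, lower, normalization):
    """Per-evaluation work: one forward substitution, no explicit matrix inverse."""
    n = len(residuals)
    z = [0.0] * n
    for i in range(n):
        # Solve L z = r row by row; afterwards z . z == r^T C^-1 r.
        z[i] = (residuals[i] - sum(lower[i][k] * z[k] for k in range(i))) / lower[i][i]
    chi_squared = sum(value * value for value in z)
    return -0.5 * chi_squared + normalization
```

Only gaussian_log_likelihood runs per sampler step. The factorization and log-determinant are computed once per observation, which is the reuse that LikelihoodState is meant to enable.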
- class petitRADTRANS.retrieval.runtime.ModelContext
Execution context passed to runtime-native model functions.
A runtime-native model receives this context instead of the legacy mutable Parameter dictionary plus a collection of loosely coupled flags.
- name: str
- mode: str
- radtrans: petitRADTRANS.radtrans.Radtrans
- adaptive_mesh_refinement: bool
- variability_atmospheric_column_model_flux_return_mode: bool = False
- return_contribution: bool = False
- return_opacities: bool = False
- return_photosphere_radius: bool = False
- return_rosseland_optical_depths: bool = False
- return_radius_hydrostatic_equilibrium: bool = False
- return_cloud_contribution: bool = False
- return_abundances: bool = False
- model_metadata: Mapping[str, Any]
- class petitRADTRANS.retrieval.runtime.ModelInputs
Structured physical inputs consumed by low-level spectrum calculators.
Runtime-native model builders populate this container once they have derived thermodynamic, chemical, and cloud state from the current parameters.
- pressures: Any
- temperatures: Any
- abundances: Mapping[str, Any]
- mean_molar_masses: Any
- gravity: Any
- planet_radius: Any
- reference_pressure: Any = 0.1
- opaque_cloud_top_pressure: Any = None
- sigma_lnorm: Any = None
- cloud_particle_mean_radii: Any = None
- cloud_f_sed: Any = None
- eddy_diffusion_coefficients: Any = None
- haze_factor: Any = 1.0
- power_law_opacity_coefficient: Any = None
- power_law_opacity_350nm: Any = None
- cloud_hansen_b: Any = None
- cloud_fraction: Any = 1.0
- patchy_clouds: tuple[str, ...] | None = None
- cloud_particle_radius_distribution: str | None = None
- distance_to_system: Any = None
- stellar_radius: Any = None
- disk_blackbody_temperature: Any = None
- disk_radius: Any = None
- v_band_extinction: Any = None
- v_band_reddening: Any = None
- additional_inputs: Mapping[str, Any]
- __post_init__()
- tree_flatten()
Return the pytree representation used by JAX transforms.
- classmethod tree_unflatten(aux_data, children)
Reconstruct a ModelInputs instance from pytree parts.
- class petitRADTRANS.retrieval.runtime.ModelResult
Normalized output emitted by runtime model evaluation.
The result can represent a valid spectrum, a pressure-temperature profile, or an invalid evaluation. Named fields replace the legacy tuple protocol so downstream code can handle auxiliary outputs consistently.
For JIT compatibility, invalid results carry NaN-filled spectrum arrays of the expected shape rather than None, so that downstream tracing always sees consistent types and shapes.
The class is registered as a JAX pytree so that jax.vmap and jax.jit can propagate batched spectra through model evaluation pipelines. The kind field and the auxiliary_outputs keys are treated as static (aux_data); all array-valued fields, including the values inside auxiliary_outputs, are dynamic leaves.
- kind: str
- wavelengths: Any = None
- spectrum: Any = None
- pressures: Any = None
- temperatures: Any = None
- additional_log_likelihood: Any = 0.0
- auxiliary_outputs: Mapping[str, Any]
- atmospheric_column_fluxes: Any = None
- tree_flatten()
Flatten for JAX pytree operations (jit, vmap, grad, …).
kind and the keys of auxiliary_outputs go into aux_data, which is static and must be identical across a vmap batch. All array-valued fields, including the values of auxiliary_outputs, are leaves.
- classmethod tree_unflatten(aux_data, children)
Reconstruct a ModelResult from pytree parts.
- static make_nan_result(wavelengths: Any) -> ModelResult
Create an invalid result with a NaN-filled spectrum of matching shape.
wavelengths should be a JAX array (typically from the Radtrans wavelength grid). The returned result has kind='invalid' for backward compatibility, but the all-NaN spectrum is the primary JIT-safe sentinel.
- spectrum_is_valid() -> Any
JIT-safe check: True when the spectrum is a real, non-NaN array.
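The NaN sentinel works because an invalid result has the same structure as a valid one, so a traced function never has to branch on kind. The sketch below imitates that behavior with plain Python lists. MiniModelResult and score are hypothetical. Mapping an invalid spectrum to -inf in score is an assumption made for illustration. Under JAX, the final selection would be written with jnp.where rather than a Python conditional.

```python
import math
from dataclasses import dataclass, field
from typing import Any, Mapping


@dataclass(frozen=True)
class MiniModelResult:
    """Hypothetical stand-in for ModelResult (pure Python, no JAX)."""

    kind: str
    wavelengths: Any = None
    spectrum: Any = None
    additional_log_likelihood: Any = 0.0
    auxiliary_outputs: Mapping[str, Any] = field(default_factory=dict)

    @staticmethod
    def make_nan_result(wavelengths):
        # Same length as a real spectrum, so callers always see one structure.
        return MiniModelResult(
            kind="invalid", wavelengths=wavelengths, spectrum=[math.nan] * len(wavelengths)
        )

    def spectrum_is_valid(self):
        # Valid means a spectrum exists and contains no NaN entries.
        return self.spectrum is not None and not any(math.isnan(v) for v in self.spectrum)

    def tree_flatten(self):
        # kind and the (sorted) auxiliary keys are static; every value is a leaf.
        keys = tuple(sorted(self.auxiliary_outputs))
        children = (
            self.wavelengths,
            self.spectrum,
            self.additional_log_likelihood,
            tuple(self.auxiliary_outputs[k] for k in keys),
        )
        return children, (self.kind, keys)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        kind, keys = aux_data
        wavelengths, spectrum, extra, aux_values = children
        return cls(kind, wavelengths, spectrum, extra, dict(zip(keys, aux_values)))


def score(result, observed, sigma):
    # NaN propagates through the arithmetic; the validity flag selects the fallback.
    chi_squared = sum(((o - m) / sigma) ** 2 for o, m in zip(observed, result.spectrum))
    return -0.5 * chi_squared if result.spectrum_is_valid() else -math.inf
```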
- class petitRADTRANS.retrieval.runtime.ModelGroupState
Description of one shared model evaluation group.
Multiple observations can reuse the same underlying model call when they point at the same radiative-transfer source. The runtime tracks that shared evaluation through this grouping object.
- source_name: str
- observation_names: tuple[str, ...]
- model_contract: str
- model_generating_function: Any = None
- radtrans_object: Any = None
- metadata: Mapping[str, Any]
- model_context: Any = None
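Grouping amounts to bucketing observation names by the source whose model they share. The model then runs once per group, and each member is scored against that single result. The sketch below assumes that an observation with no external_radtrans_reference owns its own source and that any other observation joins the source it references. That rule, group_observations, and MiniGroup are all hypothetical.

```python
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class MiniGroup:
    """Hypothetical stand-in for ModelGroupState: one source, many observations."""

    source_name: str
    observation_names: tuple


def group_observations(references: Mapping[str, Optional[str]]) -> tuple:
    """Bucket observation names by the source whose model they reuse.

    Hypothetical rule: a reference of None means the observation is its own
    source; otherwise it joins the source named by the reference.
    """
    buckets = {}
    for name, reference in references.items():
        source = name if reference is None else reference
        buckets.setdefault(source, []).append(name)
    # Plain dicts preserve insertion order, so groups come out in first-seen order.
    return tuple(MiniGroup(source, tuple(names)) for source, names in buckets.items())
```

With three observations falling into two groups, the model would be evaluated twice per likelihood call instead of three times.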
- class petitRADTRANS.retrieval.runtime.RetrievalRuntime
Immutable runtime view of a retrieval configuration.
RetrievalRuntime centralizes the normalized observation graph, model grouping, and likelihood evaluation code used by scalar and JAX-backed samplers.
- parameter_layout: ParameterLayout
- observations: Mapping[str, ObservationState]
- likelihoods: Mapping[str, LikelihoodState]
- model_groups: tuple[ModelGroupState, ...] = ()
- adaptive_mesh_refinement: bool = False
- classmethod from_data(parameter_layout: ParameterLayout, data: Mapping[str, Any], n_free_parameters: int | None = None, adaptive_mesh_refinement: bool = False) -> RetrievalRuntime
Build a runtime from retrieval data objects.
This normalizes observation state, resolves model contracts, and groups observations that share a single model evaluation.
- classmethod from_configuration(configuration: Any, parameter_layout: ParameterLayout | None = None, n_free_parameters: int | None = None) -> RetrievalRuntime
Build a runtime directly from a retrieval configuration object.
- _evaluate_model_group(group: ModelGroupState, physical_params: PhysicalParams) -> ModelResult
- _score_dataset(data_name: str, parameters: Mapping[str, Any], wavelengths_model: Any, spectrum_model: Any, additional_log_likelihood: Any, model_auxiliary_outputs: Mapping[str, Any] | None = None) -> float
Score one observation against a model spectrum using LogLikelihood.
This uses the pre-built ObservationState stored on the runtime rather than reaching back into mutable Data objects. The method decomposes the work into three composable stages:
1. Prepare observation data: apply flux scaling, offset, and uncertainty inflation (_apply_flux_scaling, _offset_flux, _scale_and_inflate_uncertainties).
2. Project model: convolve, rebin, and RV-shift the model onto the observation wavelength grid (project_model_to_observation_space).
3. Compute log-likelihood: delegate to compute_observation_log_likelihood, which handles NaN-sentinel detection internally.
Multi-dimensional observations (2-D orders, 3-D detectors × orders, and object-dtype ragged arrays) iterate over the static outer dimensions before invoking the same composable primitives.
- evaluate_scalar(physical_params: PhysicalParams, *, uncertainties_mode: str = 'default', print_log_likelihood_for_debugging: bool = False) -> float
Evaluate the posterior contribution for one parameter point.
This path is primarily used by non-JAX samplers and debugging code. It returns the log likelihood plus any runtime-provided prior weight.
- evaluate_jax(physical_params: PhysicalParams, *, uncertainties_mode: str = 'default', print_log_likelihood_for_debugging: bool = False) -> float
Evaluate the posterior through a JAX-compatible execution path.
All control flow uses lax.cond / jnp.where, so the function is safe to pass through jax.jit and jax.grad. Invalid model evaluations are detected via the NaN-spectrum sentinel produced by ModelResult.make_nan_result rather than by Python-level string comparisons.
- evaluate_vectorized_jax(free_parameter_values: Any, *, uncertainties_mode: str = 'default', print_log_likelihood_for_debugging: bool = False) -> Any
Evaluate one or many parameter vectors with JAX vectorization.
A one-dimensional input is evaluated as a single sample. Higher-rank inputs are mapped over the leading dimension with vmap.
- evaluate_jax_batched(free_parameter_values: Any, *, uncertainties_mode: str = 'default', print_log_likelihood_for_debugging: bool = False) -> Any
Evaluate a batch of parameter vectors by vmapping the scalar JAX path.
For model groups using MODEL_CONTRACT_DIFFERENTIABLE, this method applies jax.vmap to the single-sample evaluate_jax() path. This keeps model generation and scoring adjacent inside each mapped sample, which avoids retaining a fully materialized batched ModelResult tree before likelihood accumulation.
For non-native groups, the method falls back to evaluate_vectorized_jax(), which applies a single vmap over the full evaluate_jax pipeline.
A one-dimensional input is handled as a single sample (no batching). Higher-rank inputs are mapped over the leading dimension.
- Args:
  - free_parameter_values: Shape (n_samples, n_free_params), or (n_free_params,) for a single sample.
- Returns:
  Log-posterior contributions with shape (n_samples,), or a scalar for a single sample.
- _project_dataset_observation(data_name: str, parameters: PhysicalParams, wavelengths_model: Any, spectrum_model: Any) -> Any
Project one model spectrum into the observation space of a dataset.
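The three stages listed for _score_dataset compose naturally as plain functions. The one-or-many dispatch used by evaluate_vectorized_jax reduces to a check on the input's rank. The sketch below shows both, using only standard-library Python. Every function in it is hypothetical:

- linear interpolation stands in for the convolve, rebin, and RV-shift projection
- invalid likelihoods are mapped to -inf by assumption
- the Python loop in evaluate_batch stands in for jax.vmap

```python
import math
from bisect import bisect_left


def prepare_observation(flux, sigma, scale=1.0, offset=0.0, inflation=1.0):
    """Stage 1 (hypothetical): scale and offset the data, inflate the uncertainties."""
    data = [scale * value + offset for value in flux]
    errors = [inflation * value for value in sigma]
    return data, errors


def project_model(model_wavelengths, model_flux, data_wavelengths):
    """Stage 2 (hypothetical): linear interpolation onto the data wavelength grid.

    Stands in for convolution, rebinning and RV shifting. Data points outside
    the model grid become NaN so that stage 3 can flag the evaluation.
    """
    projected = []
    for wavelength in data_wavelengths:
        i = bisect_left(model_wavelengths, wavelength)
        if i == 0 and wavelength == model_wavelengths[0]:
            projected.append(model_flux[0])
        elif i == 0 or i == len(model_wavelengths):
            projected.append(math.nan)
        else:
            w0, w1 = model_wavelengths[i - 1], model_wavelengths[i]
            f0, f1 = model_flux[i - 1], model_flux[i]
            projected.append(f0 + (f1 - f0) * (wavelength - w0) / (w1 - w0))
    return projected


def log_likelihood(data, errors, model):
    """Stage 3 (hypothetical): diagonal chi-squared with NaN-sentinel detection."""
    chi_squared = sum(((d - m) / e) ** 2 for d, e, m in zip(data, errors, model))
    return -math.inf if math.isnan(chi_squared) else -0.5 * chi_squared


def score_dataset(data_wavelengths, flux, sigma, model_wavelengths, model_flux, **scaling):
    data, errors = prepare_observation(flux, sigma, **scaling)
    model = project_model(model_wavelengths, model_flux, data_wavelengths)
    return log_likelihood(data, errors, model)


def evaluate_batch(evaluate_one, values):
    """Rank-based dispatch: a flat vector is one sample, a list of vectors is a batch.

    With JAX the batch branch would be jax.vmap(evaluate_one)(values); a plain
    loop keeps this sketch free of dependencies.
    """
    if values and isinstance(values[0], (list, tuple)):
        return [evaluate_one(row) for row in values]
    return evaluate_one(values)
```

The design described for evaluate_jax_batched can be mirrored by passing an evaluate_one that both builds the model and scores it. Each mapped sample then generates and scores in one step, instead of first materializing every model result.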
- class petitRADTRANS.retrieval.runtime.ParameterLayout
Translation layer between sampler coordinates and runtime parameters.
ParameterLayout owns the ordered parameter metadata needed to transform from unit-cube or unconstrained sampler coordinates into physical values, and to reconstruct legacy Parameter dictionaries when required.
- parameter_names: tuple[str, ...]
- free_parameter_names: tuple[str, ...]
- free_parameter_indices: tuple[int, ...]
- parameter_templates: tuple[petitRADTRANS.retrieval.parameter.Parameter, ...]
- cube_epsilon: float
- classmethod from_configuration(configuration: Any) -> ParameterLayout
Build a parameter layout from a retrieval configuration.
- classmethod from_parameters(parameters: Mapping[str, petitRADTRANS.retrieval.parameter.Parameter]) -> ParameterLayout
Build a parameter layout from an ordered parameter mapping.
- property n_free_parameters: int
Return the number of free parameters in the layout.
- _transform_parameter(parameter: petitRADTRANS.retrieval.parameter.Parameter, cube_value: Any) -> Any
- _validate_free_parameter_count(values: numpy.typing.ArrayLike) -> numpy.ndarray
- prior_transform_numpy(cube: numpy.typing.ArrayLike) -> numpy.ndarray
Transform unit-cube samples into physical parameter values.
Both single samples and batched samples are supported. The return value is always a NumPy array to keep non-JAX sampler interfaces simple.
- prior_transform_inplace(cube: numpy.typing.ArrayLike) -> numpy.typing.ArrayLike
Apply prior_transform_numpy() and write the result back in place.
- transform_cube_to_physical(cube_position: Any) -> jax.numpy.ndarray
Transform one unit-cube position into JAX physical coordinates.
- transform_cube_to_physical_batch(cube_positions: Any) -> jax.numpy.ndarray
Transform one or many unit-cube positions into physical coordinates.
- cube_from_unconstrained(unconstrained_position: Any) -> jax.numpy.ndarray
Map unconstrained coordinates onto the open unit cube.
- cube_from_unconstrained_batch(unconstrained_positions: Any) -> jax.numpy.ndarray
Map one or many unconstrained positions onto the open unit cube.
- transform_unconstrained(unconstrained_position: Any) -> jax.numpy.ndarray
Transform unconstrained coordinates directly into physical values.
- transform_unconstrained_batch(unconstrained_positions: Any) -> jax.numpy.ndarray
Transform one or many unconstrained positions into physical values.
- log_jacobian_single(unconstrained_position: Any) -> jax.numpy.ndarray
Return the log-Jacobian term for one unconstrained position.
- log_jacobian(unconstrained_positions: Any) -> jax.numpy.ndarray
Return log-Jacobian terms for one or many unconstrained positions.
- physical_params_from_vector(values: numpy.typing.ArrayLike) -> PhysicalParams
Build PhysicalParams from a flat vector of free values.
- physical_params_from_matrix(values: numpy.typing.ArrayLike) -> tuple[PhysicalParams, ...]
Build PhysicalParams objects from a batch of free-parameter vectors.
- physical_params_from_free_values(values: Any) -> PhysicalParams
Build PhysicalParams from ordered free-parameter values.
- physical_params_from_cube(cube_position: numpy.typing.ArrayLike) -> PhysicalParams
Build PhysicalParams from a unit-cube position.
- physical_params_from_unconstrained(unconstrained_position: Any) -> PhysicalParams
Build PhysicalParams from unconstrained sampler coordinates.
- free_parameter_vector(physical_params: PhysicalParams) -> jax.numpy.ndarray
Extract the ordered free-parameter vector from PhysicalParams.
- build_legacy_parameter_dict(physical_params: PhysicalParams) -> dict[str, petitRADTRANS.retrieval.parameter.Parameter]
Reconstruct legacy Parameter objects for compatibility paths.
- serialize_metadata() -> dict[str, Any]
Return a deterministic metadata payload describing the layout.
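ParameterLayout describes a chain of transforms: an unconstrained coordinate is mapped onto the open unit cube, then to a physical value, with a log-Jacobian correction. This chain can be sketched for a single parameter. The sketch assumes three things:

- a logistic sigmoid for the unconstrained-to-cube step
- clipping by cube_epsilon
- a uniform prior on [low, high]

The squashing function and prior transforms that petitRADTRANS actually uses are not stated here, and the function names only echo the documented methods. The log-Jacobian is ln(high - low) + ln s + ln(1 - s), where s is the sigmoid of the input. It ignores the clipping, which only matters far in the tails.

```python
import math


def sigmoid(x):
    # Logistic function, split by sign so math.exp never overflows.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def softplus(x):
    # log(1 + exp(x)) without overflow for large |x|.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def cube_from_unconstrained(x, cube_epsilon=1e-12):
    # Real line -> open unit cube, clipped so the cube value never reaches 0 or 1.
    return min(max(sigmoid(x), cube_epsilon), 1.0 - cube_epsilon)


def uniform_prior_transform(u, low, high):
    # Unit-cube coordinate -> physical value for a uniform prior on [low, high].
    return low + (high - low) * u


def transform_unconstrained(x, low, high, cube_epsilon=1e-12):
    return uniform_prior_transform(cube_from_unconstrained(x, cube_epsilon), low, high)


def log_jacobian_single(x, low, high):
    # d(value)/dx = (high - low) * s * (1 - s) with s = sigmoid(x), and
    # log s = -softplus(-x), log(1 - s) = -softplus(x). Clipping is ignored.
    return math.log(high - low) - softplus(-x) - softplus(x)
```

Gradient-based samplers working in unconstrained space would add this log-Jacobian to the log-likelihood, so that the target density is correct in the coordinates they actually move in.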