petitRADTRANS.sbi.simulator#
Simulation interfaces for prior-predictive and proposal-driven generation.
Attributes#
Classes#
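SimulationBatch | Container for one simulated batch.
SimulationRuntime | Extension of RetrievalRuntime that adds deterministic forward-model simulation methods.
ProposalSampler | Interface for simulation proposals beyond the prior distribution.
Simulator | Base simulator for SBI dataset generation and validation.
BatchedSimulator | Runtime-backed simulator that prefers vectorized task execution.
RuntimeSimulator | Concrete simulator backed by the retrieval runtime.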
Functions#
_stack_simulation_output_payloads(payloads) | Stack per-sample observation payloads from simulate_observations_scalar.
Module Contents#
- petitRADTRANS.sbi.simulator.logger#
- class petitRADTRANS.sbi.simulator.SimulationBatch#
Container for one simulated batch.
- Attributes:
- parameters:
Array-like free-parameter matrix with leading dimension equal to the number of simulated samples.
- observations:
Task-conditioned simulated observations.
- log_likelihood:
Optional scalar likelihood values associated with each sample.
- diagnostics:
Additional runtime diagnostics such as clipping flags or NaN counts.
- parameters: Any#
- observations: Any#
- cube_parameters: Any = None#
- unconstrained_parameters: Any = None#
- log_likelihood: Any = None#
- diagnostics: Mapping[str, Any]#
- property n_samples: int#
Return the number of simulated samples represented by the batch.
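The field layout above can be pictured with a small stand-in. This is a minimal sketch, not the library's implementation: `BatchSketch` is a hypothetical name, and deriving `n_samples` from the leading dimension of `parameters` is an assumption based on the attribute description.

```python
from dataclasses import dataclass, field
from typing import Any, Mapping

import numpy as np


@dataclass
class BatchSketch:
    """Hypothetical stand-in mirroring the documented SimulationBatch fields."""

    parameters: Any
    observations: Any
    cube_parameters: Any = None
    unconstrained_parameters: Any = None
    log_likelihood: Any = None
    diagnostics: Mapping[str, Any] = field(default_factory=dict)

    @property
    def n_samples(self) -> int:
        # Assumption: the sample count is the leading dimension of `parameters`.
        return int(np.shape(self.parameters)[0])


batch = BatchSketch(
    parameters=np.zeros((4, 3)),  # 4 samples, 3 free parameters
    observations={"spectrum": {"values": np.zeros((4, 10))}},
    diagnostics={"nan_count": 0},
)
print(batch.n_samples)
```

The optional fields default to `None`, matching the documented signatures, so a batch built from pre-specified parameters does not need prior-space coordinates.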
- petitRADTRANS.sbi.simulator._stack_simulation_output_payloads(payloads: list[dict]) dict#
Stack per-sample observation payloads from simulate_observations_scalar.
Array fields are stacked into a leading (n_samples, …) dimension. The "metadata" field is kept as a list (one dict per sample). Fields that are None for every sample remain None.
- class petitRADTRANS.sbi.simulator.SimulationRuntime#
Bases: petitRADTRANS.retrieval.runtime.RetrievalRuntime
Extension of RetrievalRuntime that adds deterministic forward-model simulation methods.
RetrievalRuntime covers likelihood evaluation. This subclass layers the simulation surface on top, projecting parameters into observation space without noise so that RuntimeSimulator can inject noise afterwards and assemble SimulationBatch objects.
- simulate_observations_scalar(physical_params: petitRADTRANS.retrieval.runtime.PhysicalParams) dict[str, dict[str, Any]]#
Return deterministic observation-space model outputs for one parameter point.
- _run_batched_rt_kernel(parameter_matrix: Any) tuple[Any, list[tuple[Any, Any]]]#
Run the expensive vmapped RT kernel and return raw results.
Returns#
- tuple[batched_physical_params, list[(group, batched_model_result)]]
The batched physical parameters and per-group raw model outputs before observation projection.
- _run_batched_rt_kernel_multi_device(parameter_matrix: Any) tuple[Any, list[tuple[Any, Any]]]#
Multi-device version of _run_batched_rt_kernel().
Distributes the batch across all available JAX devices by placing each chunk on a separate device with jax.device_put and then calling the single-device _run_batched_rt_kernel per chunk. JAX's asynchronous GPU dispatch allows the per-device computations to overlap. Results are gathered and concatenated before return.
The parameter matrix must have its leading dimension equal to a multiple of the device count (the caller is responsible for padding).
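The padding requirement can be illustrated without any accelerator. The sketch below uses NumPy only and is an assumption-laden illustration: `pad_to_multiple` and `run_chunked` are hypothetical helpers, padding by repeating the last row is one possible choice, and `kernel` stands in for the per-chunk RT kernel (no device placement is performed here).

```python
import numpy as np


def pad_to_multiple(matrix: np.ndarray, n_devices: int) -> tuple[np.ndarray, int]:
    """Repeat the last row until the row count is a multiple of n_devices."""
    n_rows = matrix.shape[0]
    n_pad = (-n_rows) % n_devices
    if n_pad:
        filler = np.repeat(matrix[-1:], n_pad, axis=0)
        matrix = np.concatenate([matrix, filler], axis=0)
    return matrix, n_rows


def run_chunked(matrix: np.ndarray, n_devices: int, kernel) -> np.ndarray:
    """Split the padded batch into equal chunks, run each, then drop the padding."""
    padded, n_valid = pad_to_multiple(matrix, n_devices)
    chunks = np.split(padded, n_devices, axis=0)  # one chunk per (hypothetical) device
    results = [kernel(chunk) for chunk in chunks]
    return np.concatenate(results, axis=0)[:n_valid]


params = np.arange(15, dtype=float).reshape(5, 3)  # 5 samples, not divisible by 4
padded, n_valid = pad_to_multiple(params, 4)
out = run_chunked(params, 4, lambda chunk: chunk.sum(axis=1))
```

Trimming back to `n_valid` rows after concatenation keeps the padded samples from leaking into the returned batch.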
- _project_batched_rt_results(batched_physical_params: Any, group_batched_results: list[tuple[Any, Any]], n_samples: int) dict[str, dict[str, Any]]#
Project raw RT kernel outputs into observation space.
This is the cheap projection step (interpolation, scaling, binning). When possible, the projection and flux-transform functions are vectorized with jax.vmap so that the Python-level loop over samples is replaced by a single traced call.
- simulate_observations_batched(parameter_matrix: Any, multi_device: bool = False) dict[str, dict[str, Any]]#
Return deterministic batched observation outputs for a parameter matrix.
For model groups using MODEL_CONTRACT_DIFFERENTIABLE, this method vmaps the expensive RT kernel over the whole batch in one XLA call, then performs the cheap projection step (interpolation, scaling, etc.) in a Python loop. Non-native groups fall back to a scalar-loop path equivalent to calling simulate_observations_scalar() per row.
- Args:
parameter_matrix: Shape (n_samples, n_free_params).
multi_device: When True and more than one JAX device is available, distribute the vmapped RT kernel across devices using jax.pmap.
- Returns:
Mapping from observation name to a payload dict where "values" has shape (n_samples, n_wavelengths) and "mask", "coordinates", and "metadata" are static / per-sample lists.
- class petitRADTRANS.sbi.simulator.ProposalSampler#
Bases: abc.ABC
Interface for simulation proposals beyond the prior distribution.
- abstractmethod sample(n_samples: int, task: petitRADTRANS.sbi.task.SBITask) Any#
Draw free-parameter vectors for the given task.
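A proposal only has to provide `sample(n_samples, task)`. The following sketch shows one way such a subclass might look. `ProposalSamplerSketch` is a local stand-in for the documented interface (so the example runs without petitRADTRANS), and `GaussianBallProposal`, its constructor arguments, and the assumption that the return value is an `(n_samples, n_free_params)` array are all hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Any

import numpy as np


class ProposalSamplerSketch(ABC):
    """Local stand-in for the documented ProposalSampler interface."""

    @abstractmethod
    def sample(self, n_samples: int, task: Any) -> Any:
        """Draw free-parameter vectors for the given task."""


class GaussianBallProposal(ProposalSamplerSketch):
    """Hypothetical proposal: a Gaussian ball around a reference point."""

    def __init__(self, center, scale, seed: int = 0):
        self.center = np.asarray(center, dtype=float)
        self.scale = np.asarray(scale, dtype=float)
        self._rng = np.random.default_rng(seed)

    def sample(self, n_samples: int, task: Any = None) -> np.ndarray:
        # `task` is ignored in this toy; a real proposal might use it to
        # look up parameter names or prior bounds.
        noise = self._rng.standard_normal((n_samples, self.center.size))
        return self.center + self.scale * noise


proposal = GaussianBallProposal(center=[1200.0, -3.5], scale=[50.0, 0.2])
draws = proposal.sample(5, task=None)
```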
- class petitRADTRANS.sbi.simulator.Simulator(task: petitRADTRANS.sbi.task.SBITask)#
Bases: abc.ABC
Base simulator for SBI dataset generation and validation.
- task#
- abstractmethod sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) Any#
Sample free-parameter vectors from the task prior or a proposal.
- abstractmethod simulate_from_parameters(parameters: Any) SimulationBatch#
Run the forward model and noise pipeline for pre-specified parameters.
- simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) SimulationBatch#
Sample parameters and simulate one batch in a single call.
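The relationship between the three methods can be sketched as a template method: `simulate_batch` draws parameters and passes them to `simulate_from_parameters`. This is an illustrative reading of the docstrings, not the library's code. `SimulatorSketch` and `LineSimulator` are hypothetical, the toy forward model is a straight line, and the return value is a plain dict rather than a `SimulationBatch`.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional

import numpy as np


class SimulatorSketch(ABC):
    """Local stand-in mirroring the documented Simulator interface."""

    def __init__(self, task: Any):
        self.task = task

    @abstractmethod
    def sample_parameters(self, n_samples: int, proposal: Optional[Any] = None) -> Any:
        ...

    @abstractmethod
    def simulate_from_parameters(self, parameters: Any) -> Any:
        ...

    def simulate_batch(self, n_samples: int, proposal: Optional[Any] = None) -> Any:
        # Assumed composition: draw parameters, then run the forward model.
        parameters = self.sample_parameters(n_samples, proposal)
        return self.simulate_from_parameters(parameters)


class LineSimulator(SimulatorSketch):
    """Toy model: y = slope * x + intercept on a fixed grid, without noise."""

    def __init__(self, task: Any = None, seed: int = 0):
        super().__init__(task)
        self._rng = np.random.default_rng(seed)
        self._grid = np.linspace(0.0, 1.0, 5)

    def sample_parameters(self, n_samples, proposal=None):
        if proposal is not None:
            return np.asarray(proposal.sample(n_samples, self.task))
        return self._rng.uniform(0.0, 1.0, size=(n_samples, 2))

    def simulate_from_parameters(self, parameters):
        parameters = np.asarray(parameters, dtype=float)
        slope, intercept = parameters[:, :1], parameters[:, 1:2]
        values = slope * self._grid[None, :] + intercept
        return {"parameters": parameters, "observations": values}


batch = LineSimulator().simulate_batch(3)
fixed = LineSimulator().simulate_from_parameters([[2.0, 1.0]])
```

Keeping `simulate_from_parameters` separate from sampling is what makes validation runs at fixed, user-chosen parameter points possible.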
- class petitRADTRANS.sbi.simulator.BatchedSimulator(task: petitRADTRANS.sbi.task.SBITask, runtime: Any | None = None)#
Bases: Simulator
Runtime-backed simulator that prefers vectorized task execution.
The class intentionally exposes only interfaces at this stage. Concrete implementations can back this with RetrievalRuntime.evaluate_vectorized_jax or a fallback scalar loop when JAX batching is unavailable.
- runtime = None#
- sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) Any#
Sample free-parameter vectors from the task prior or a proposal.
- abstractmethod simulate_from_parameters(parameters: Any) SimulationBatch#
Run the forward model and noise pipeline for pre-specified parameters.
- class petitRADTRANS.sbi.simulator.RuntimeSimulator(task: petitRADTRANS.sbi.task.SBITask, runtime: Any | None = None, seed: int | None = None, data_parallel: bool | None = None)#
Bases: BatchedSimulator
Concrete simulator backed by the retrieval runtime.
The simulator samples from the retrieval prior using JAX random utilities and projects deterministic forward-model outputs into the observation space using RetrievalRuntime.
- _rng_key#
- _validate_task_support() None#
- _next_key() jax.Array#
- static _row_invalid_value_mask(values: Any, mask: Any, constraint: Any) numpy.ndarray#
- _invalid_sample_mask(observations: Mapping[str, Mapping[str, Any]], n_samples: int) tuple[numpy.ndarray, dict[str, int]]#
- static _slice_batch_rows(batch: SimulationBatch, row_indices: numpy.ndarray) SimulationBatch#
- static _concatenate_simulation_batches(batches: list[SimulationBatch], diagnostics: Mapping[str, Any] | None = None) SimulationBatch#
- _sample_prior_parameter_batch(n_samples: int) dict[str, jax.numpy.ndarray]#
- sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) Any#
Sample free-parameter vectors from the task prior or a proposal.
- _advance_rng_for_batch(n_samples: int) None#
Advance the PRNG state as if simulate_batch(n_samples) were called.
This performs only cheap key-splitting, with no forward-model evaluation, making it suitable for fast dataset-resume skipping.
The number of _next_key calls matches the pattern inside simulate_batch → _sample_prior_parameter_batch (1 key) and _simulate_parameter_matrix (1 key per observation that adds noise).
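The resume-skipping idea can be demonstrated with a toy key chain. The sketch below does not use JAX: it replaces the real key type with a hash-based stand-in, and `KeyedSimulatorSketch`, its toy forward model, and the noise scale are all hypothetical. What it does preserve is the documented key budget per batch: one key for the prior draw plus one per noisy observation.

```python
import hashlib

import numpy as np


class KeyedSimulatorSketch:
    """Toy simulator whose randomness is driven by a splittable key chain."""

    def __init__(self, seed: int, noisy_observations: list[str]):
        self._state = seed.to_bytes(8, "little")
        self.noisy_observations = noisy_observations
        self.forward_calls = 0

    def _next_key(self) -> int:
        # Hash-based stand-in for key splitting: advance the state, emit a subkey.
        digest = hashlib.sha256(self._state).digest()
        self._state = digest[:8]
        return int.from_bytes(digest[8:16], "little")

    def simulate_batch(self, n_samples: int):
        prior_rng = np.random.default_rng(self._next_key())  # 1 key: prior draw
        parameters = prior_rng.uniform(size=(n_samples, 2))
        self.forward_calls += 1  # counts the "expensive" model evaluations
        model = parameters.sum(axis=1, keepdims=True) * np.ones((1, 4))
        observations = {}
        for name in self.noisy_observations:  # 1 key per noisy observation
            noise_rng = np.random.default_rng(self._next_key())
            observations[name] = model + 0.1 * noise_rng.standard_normal(model.shape)
        return parameters, observations

    def advance_rng_for_batch(self, n_samples: int) -> None:
        # Same number of key draws as simulate_batch, but no model evaluation.
        # In this sketch the count does not depend on n_samples.
        self._next_key()
        for _ in self.noisy_observations:
            self._next_key()


reference = KeyedSimulatorSketch(7, ["nirspec", "miri"])
reference.simulate_batch(3)  # batch 0
ref_params, ref_obs = reference.simulate_batch(3)  # batch 1

resumed = KeyedSimulatorSketch(7, ["nirspec", "miri"])
resumed.advance_rng_for_batch(3)  # skip batch 0 cheaply
new_params, new_obs = resumed.simulate_batch(3)  # reproduces batch 1
```

If the skip used a different number of key draws than a real batch, every batch after the resume point would silently differ from an uninterrupted run.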
- _apply_noise(values: Any, uncertainties: Any, covariance: Any) tuple[Any, Any, Any]#
- _apply_noise_batched(values_batch: Any, uncertainties_batch: Any, covariance_batch: Any) tuple[Any, Any, Any]#
Apply noise to a batch of simulation outputs.
- Args:
values_batch: Shape (n_samples, n_wavelengths).
uncertainties_batch: Shape (n_samples, n_wavelengths) or None.
covariance_batch: Shape (n_samples, n_wl, n_wl) or None.
- Returns:
Noisy values, uncertainties, and covariance with the same leading batch dimension as values_batch.
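The documented shapes suggest how such a step could be written. The sketch below is an assumption throughout: the docstring does not state the noise model, so Gaussian noise is assumed, with a per-sample Cholesky factor when a covariance is given and independent per-wavelength noise otherwise. `apply_noise_batched_sketch` and its `rng` argument are hypothetical, and NumPy stands in for JAX.

```python
import numpy as np


def apply_noise_batched_sketch(values_batch, uncertainties_batch=None,
                               covariance_batch=None, rng=None):
    """Add assumed-Gaussian noise to a (n_samples, n_wavelengths) batch."""
    rng = np.random.default_rng() if rng is None else rng
    values = np.asarray(values_batch, dtype=float)
    white = rng.standard_normal(values.shape)
    if covariance_batch is not None:
        # Correlated noise: L @ z per sample, with L the Cholesky factor.
        chol = np.linalg.cholesky(np.asarray(covariance_batch, dtype=float))
        noise = np.einsum("nij,nj->ni", chol, white)
    elif uncertainties_batch is not None:
        # Independent noise scaled by the per-wavelength uncertainties.
        noise = np.asarray(uncertainties_batch, dtype=float) * white
    else:
        # No noise model available: return the values unchanged.
        noise = np.zeros_like(values)
    return values + noise, uncertainties_batch, covariance_batch


clean = np.ones((3, 4))
sigma = np.full((3, 4), 0.1)
cov = np.tile(0.01 * np.eye(4), (3, 1, 1))  # shape (3, 4, 4)

noisy_diag, _, _ = apply_noise_batched_sketch(clean, sigma, rng=np.random.default_rng(0))
noisy_cov, _, cov_out = apply_noise_batched_sketch(
    clean, covariance_batch=cov, rng=np.random.default_rng(0)
)
untouched, _, _ = apply_noise_batched_sketch(clean)
```

With a diagonal covariance of variance 0.01 the Cholesky path reduces to the same scaling as uncertainties of 0.1, which gives a quick consistency check between the two branches.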
- static _stack_observation_payloads(payloads: list[dict[str, Any]]) dict[str, Any]#
- _simulate_parameter_matrix(parameter_matrix: Any, cube_parameters: Any = None, unconstrained_parameters: Any = None, data_parallel: bool = False) SimulationBatch#
- simulate_from_parameters(parameters: Any) SimulationBatch#
Run the forward model and noise pipeline for pre-specified parameters.
- simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) SimulationBatch#
Sample parameters and preserve prior-space coordinates when available.