petitRADTRANS.sbi.simulator

Simulation interfaces for prior-predictive and proposal-driven generation.

Attributes#

Classes#

SimulationBatch

Container for one simulated batch.

SimulationRuntime

Extension of RetrievalRuntime that adds deterministic forward-model simulation methods.

ProposalSampler

Interface for simulation proposals beyond the prior distribution.

Simulator

Base simulator for SBI dataset generation and validation.

BatchedSimulator

Runtime-backed simulator that prefers vectorized task execution.

RuntimeSimulator

Concrete simulator backed by the retrieval runtime.

Functions#

_stack_simulation_output_payloads(→ dict)

Stack per-sample observation payloads from simulate_observations_scalar.

Module Contents#

petitRADTRANS.sbi.simulator.logger#
class petitRADTRANS.sbi.simulator.SimulationBatch#

Container for one simulated batch.

Attributes:
parameters:

Array-like free-parameter matrix with leading dimension equal to the number of simulated samples.

observations:

Task-conditioned simulated observations.

log_likelihood:

Optional scalar likelihood values associated with each sample.

diagnostics:

Additional runtime diagnostics such as clipping flags or NaN counts.

parameters: Any#
observations: Any#
cube_parameters: Any = None#
unconstrained_parameters: Any = None#
log_likelihood: Any = None#
diagnostics: Mapping[str, Any]#
property n_samples: int#

Return the number of simulated samples represented by the batch.
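The container layout described above can be sketched with a plain dataclass. This is a hypothetical stand-in (`SimulationBatchSketch`), not the library class: the default for `diagnostics` and the rule that `n_samples` equals the leading dimension of `parameters` are assumptions based on the attribute descriptions.

```python
from dataclasses import dataclass, field
from typing import Any, Mapping


@dataclass
class SimulationBatchSketch:
    """Hypothetical stand-in mirroring the documented SimulationBatch fields."""

    parameters: Any
    observations: Any
    cube_parameters: Any = None
    unconstrained_parameters: Any = None
    log_likelihood: Any = None
    # Assumption: an empty mapping when no diagnostics are recorded.
    diagnostics: Mapping[str, Any] = field(default_factory=dict)

    @property
    def n_samples(self) -> int:
        # Assumption: the sample count is the leading dimension of `parameters`.
        return len(self.parameters)


batch = SimulationBatchSketch(
    parameters=[[1200.0, 3.5], [900.0, 4.0], [1500.0, 3.0]],
    observations={"spectrum": {"values": [[1.0, 1.1], [0.9, 1.0], [1.2, 1.3]]}},
    diagnostics={"n_nan": 0},
)
print(batch.n_samples)
```

Optional fields such as `log_likelihood` stay `None` unless the producer fills them in.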

petitRADTRANS.sbi.simulator._stack_simulation_output_payloads(payloads: list[dict]) dict#

Stack per-sample observation payloads from simulate_observations_scalar.

Array fields are stacked into a leading (n_samples, …) dimension. The "metadata" field is kept as a list (one dict per sample). Fields that are None for every sample remain None.
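The stacking rules can be illustrated with a small NumPy re-implementation. The function name `stack_payloads_sketch` and the payload keys other than "metadata" are illustrative assumptions; fields that are None for only some samples are not handled here.

```python
import numpy as np


def stack_payloads_sketch(payloads: list) -> dict:
    """Illustrative re-implementation of the documented stacking rules."""
    stacked = {}
    for key in payloads[0]:
        column = [payload[key] for payload in payloads]
        if key == "metadata":
            # Metadata stays a list with one dict per sample.
            stacked[key] = column
        elif all(item is None for item in column):
            # Fields that are None for every sample remain None.
            stacked[key] = None
        else:
            # Array fields gain a leading (n_samples, ...) dimension.
            stacked[key] = np.stack([np.asarray(item) for item in column])
    return stacked


payloads = [
    {"values": np.array([1.0, 2.0, 3.0]), "uncertainties": None, "metadata": {"row": 0}},
    {"values": np.array([4.0, 5.0, 6.0]), "uncertainties": None, "metadata": {"row": 1}},
]
stacked = stack_payloads_sketch(payloads)
print(stacked["values"].shape)
```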

class petitRADTRANS.sbi.simulator.SimulationRuntime#

Bases: petitRADTRANS.retrieval.runtime.RetrievalRuntime

Extension of RetrievalRuntime that adds deterministic forward-model simulation methods.

RetrievalRuntime covers likelihood evaluation. This subclass layers the simulation surface on top: it projects parameters into observation space without noise, so that RuntimeSimulator can inject noise afterwards and assemble SimulationBatch objects.

simulate_observations_scalar(physical_params: petitRADTRANS.retrieval.runtime.PhysicalParams) dict[str, dict[str, Any]]#

Return deterministic observation-space model outputs for one parameter point.

_run_batched_rt_kernel(parameter_matrix: Any) tuple[Any, list[tuple[Any, Any]]]#

Run the expensive vmapped RT kernel and return raw results.

Returns#

tuple[batched_physical_params, list[(group, batched_model_result)]]

The batched physical parameters and per-group raw model outputs before observation projection.

_run_batched_rt_kernel_multi_device(parameter_matrix: Any) tuple[Any, list[tuple[Any, Any]]]#

Multi-device version of _run_batched_rt_kernel().

Distributes the batch across all available JAX devices by placing each chunk on a separate device with jax.device_put and then calling the single-device _run_batched_rt_kernel per chunk. JAX’s asynchronous GPU dispatch allows the per-device computations to overlap. Results are gathered and concatenated before return.

The parameter matrix must have its leading dimension equal to a multiple of the device count (caller is responsible for padding).
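Because padding is left to the caller, a helper along the following lines may be useful. It is a minimal NumPy sketch: the name `pad_to_device_multiple` and the choice to repeat the last row are assumptions, and the padded rows would need to be dropped from the results afterwards.

```python
import numpy as np


def pad_to_device_multiple(parameter_matrix: np.ndarray, n_devices: int):
    """Pad rows so the leading dimension is a multiple of n_devices."""
    n_samples = parameter_matrix.shape[0]
    n_pad = (-n_samples) % n_devices
    if n_pad == 0:
        return parameter_matrix, n_samples
    # Repeat the last row; the caller discards the padded rows afterwards.
    padding = np.repeat(parameter_matrix[-1:], n_pad, axis=0)
    return np.concatenate([parameter_matrix, padding], axis=0), n_samples


matrix = np.arange(10.0).reshape(5, 2)
padded, n_valid = pad_to_device_multiple(matrix, n_devices=4)
# One equally sized chunk per device.
chunks = np.split(padded, 4, axis=0)
print(padded.shape, n_valid)
```

Keeping `n_valid` alongside the padded matrix lets the caller slice the gathered outputs back to the original batch size.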

_project_batched_rt_results(batched_physical_params: Any, group_batched_results: list[tuple[Any, Any]], n_samples: int) dict[str, dict[str, Any]]#

Project raw RT kernel outputs into observation space.

This is the cheap projection step (interpolation, scaling, binning). When possible the projection and flux-transform functions are vectorized with jax.vmap so that the Python-level loop over samples is replaced by a single traced call.

simulate_observations_batched(parameter_matrix: Any, multi_device: bool = False) dict[str, dict[str, Any]]#

Return deterministic batched observation outputs for a parameter matrix.

For model groups using MODEL_CONTRACT_DIFFERENTIABLE this method vmaps the expensive RT kernel over the whole batch in one XLA call, then performs the cheap projection step (interpolation, scaling, etc.), which is vectorized with jax.vmap when possible and otherwise runs in a Python loop over samples. Non-native groups fall back to a scalar-loop path equivalent to calling simulate_observations_scalar() per row.

Args:

parameter_matrix: Shape (n_samples, n_free_params).

multi_device: When True and more than one JAX device is available, distribute the vmapped RT kernel across devices (see _run_batched_rt_kernel_multi_device()).

Returns:

Mapping from observation name to a payload dict where "values" has shape (n_samples, n_wavelengths) and "mask", "coordinates", and "metadata" are static / per-sample lists.
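The relationship between the batched path and the scalar-loop fallback can be sketched with a toy forward model in NumPy. Everything here (`toy_model_scalar`, `toy_model_batched`, `simulate_values`, the `supports_batching` flag) is hypothetical and only illustrates that both paths should produce the same (n_samples, n_wavelengths) array.

```python
import numpy as np


def toy_model_scalar(params: np.ndarray, wavelengths: np.ndarray) -> np.ndarray:
    # Hypothetical forward model for one parameter point: scaled power law.
    amplitude, slope = params
    return amplitude * wavelengths ** slope


def toy_model_batched(parameter_matrix: np.ndarray, wavelengths: np.ndarray) -> np.ndarray:
    # Same model evaluated for all rows at once via broadcasting.
    amplitude = parameter_matrix[:, 0:1]
    slope = parameter_matrix[:, 1:2]
    return amplitude * wavelengths[None, :] ** slope


def simulate_values(parameter_matrix, wavelengths, supports_batching: bool):
    if supports_batching:
        return toy_model_batched(parameter_matrix, wavelengths)
    # Fallback: one scalar call per row, stacked to (n_samples, n_wavelengths).
    return np.stack([toy_model_scalar(row, wavelengths) for row in parameter_matrix])


wavelengths = np.linspace(1.0, 2.0, 4)
parameter_matrix = np.array([[1.0, -1.0], [2.0, 0.5], [0.5, 2.0]])
batched = simulate_values(parameter_matrix, wavelengths, supports_batching=True)
looped = simulate_values(parameter_matrix, wavelengths, supports_batching=False)
print(batched.shape, np.allclose(batched, looped))
```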

class petitRADTRANS.sbi.simulator.ProposalSampler#

Bases: abc.ABC

Interface for simulation proposals beyond the prior distribution.

abstractmethod sample(n_samples: int, task: petitRADTRANS.sbi.task.SBITask) Any#

Draw free-parameter vectors for the given task.
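A custom proposal only needs to implement sample(). The sketch below uses a hypothetical stand-in base class (`ProposalSamplerSketch`) so that it runs without petitRADTRANS installed; `BoxProposal`, its bounds, and the list-of-lists return type are illustrative assumptions, since the interface only documents a return type of Any.

```python
import random
from abc import ABC, abstractmethod
from typing import Any


class ProposalSamplerSketch(ABC):
    """Hypothetical stand-in for the documented ProposalSampler interface."""

    @abstractmethod
    def sample(self, n_samples: int, task: Any) -> Any:
        """Draw free-parameter vectors for the given task."""


class BoxProposal(ProposalSamplerSketch):
    """Uniform proposal restricted to a box, e.g. a narrowed parameter region."""

    def __init__(self, lower, upper, seed=0):
        self.lower, self.upper = list(lower), list(upper)
        self._rng = random.Random(seed)

    def sample(self, n_samples, task=None):
        # One row per sample, one column per free parameter.
        return [
            [self._rng.uniform(lo, hi) for lo, hi in zip(self.lower, self.upper)]
            for _ in range(n_samples)
        ]


proposal = BoxProposal(lower=[800.0, 2.5], upper=[1600.0, 5.0], seed=42)
draws = proposal.sample(4, task=None)
print(len(draws), len(draws[0]))
```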

class petitRADTRANS.sbi.simulator.Simulator(task: petitRADTRANS.sbi.task.SBITask)#

Bases: abc.ABC

Base simulator for SBI dataset generation and validation.

task#
abstractmethod sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) Any#

Sample free-parameter vectors from the task prior or a proposal.

abstractmethod simulate_from_parameters(parameters: Any) SimulationBatch#

Run the forward model and noise pipeline for pre-specified parameters.

simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) SimulationBatch#

Sample parameters and simulate one batch in a single call.
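The division of labour between the three methods can be sketched as follows. `SimulatorSketch` and `LinearToySimulator` are hypothetical stand-ins, and the assumption that simulate_batch simply chains sample_parameters and simulate_from_parameters is inferred from the docstrings rather than from the implementation.

```python
from abc import ABC, abstractmethod


class SimulatorSketch(ABC):
    """Hypothetical stand-in for the documented Simulator base class."""

    def __init__(self, task):
        self.task = task

    @abstractmethod
    def sample_parameters(self, n_samples, proposal=None): ...

    @abstractmethod
    def simulate_from_parameters(self, parameters): ...

    def simulate_batch(self, n_samples, proposal=None):
        # Assumed composition: sample first, then simulate.
        parameters = self.sample_parameters(n_samples, proposal=proposal)
        return self.simulate_from_parameters(parameters)


class LinearToySimulator(SimulatorSketch):
    def sample_parameters(self, n_samples, proposal=None):
        if proposal is not None:
            return proposal.sample(n_samples, self.task)
        # Deterministic placeholder for a prior draw.
        return [[float(i)] for i in range(n_samples)]

    def simulate_from_parameters(self, parameters):
        # Toy forward model y = 2x + 1, with no noise added.
        observations = [[2.0 * p[0] + 1.0] for p in parameters]
        return {"parameters": parameters, "observations": observations}


toy_batch = LinearToySimulator(task=None).simulate_batch(3)
print(toy_batch["observations"])
```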

class petitRADTRANS.sbi.simulator.BatchedSimulator(task: petitRADTRANS.sbi.task.SBITask, runtime: Any | None = None)#

Bases: Simulator

Runtime-backed simulator that prefers vectorized task execution.

The class intentionally exposes only interfaces at this stage. Concrete implementations can back this with RetrievalRuntime.evaluate_vectorized_jax or a fallback scalar loop when JAX batching is unavailable.

runtime = None#
sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) Any#

Sample free-parameter vectors from the task prior or a proposal.

abstractmethod simulate_from_parameters(parameters: Any) SimulationBatch#

Run the forward model and noise pipeline for pre-specified parameters.

class petitRADTRANS.sbi.simulator.RuntimeSimulator(task: petitRADTRANS.sbi.task.SBITask, runtime: Any | None = None, seed: int | None = None, data_parallel: bool | None = None)#

Bases: BatchedSimulator

Concrete simulator backed by the retrieval runtime.

The simulator samples from the retrieval prior using JAX random utilities and projects deterministic forward-model outputs into the observation space using RetrievalRuntime.

_rng_key#
_validate_task_support() None#
_next_key() jax.Array#
static _row_invalid_value_mask(values: Any, mask: Any, constraint: Any) numpy.ndarray#
_invalid_sample_mask(observations: Mapping[str, Mapping[str, Any]], n_samples: int) tuple[numpy.ndarray, dict[str, int]]#
static _slice_batch_rows(batch: SimulationBatch, row_indices: numpy.ndarray) SimulationBatch#
static _concatenate_simulation_batches(batches: list[SimulationBatch], diagnostics: Mapping[str, Any] | None = None) SimulationBatch#
_sample_prior_parameter_batch(n_samples: int) dict[str, jax.numpy.ndarray]#
sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) Any#

Sample free-parameter vectors from the task prior or a proposal.

_advance_rng_for_batch(n_samples: int) None#

Advance the PRNG state as if simulate_batch(n_samples) were called.

This performs only cheap key-splitting – no forward-model evaluation – making it suitable for fast dataset-resume skipping.

The number of _next_key calls matches the pattern inside simulate_batch: _sample_prior_parameter_batch (1 key) and _simulate_parameter_matrix (1 key per observation that adds noise).
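The resume-skipping idea can be demonstrated without JAX by using a toy splittable key. In the sketch below, `split` is a hash-based stand-in for jax.random.split, and `KeyStream` with its `n_noisy_observations` argument is hypothetical; it only shows that advancing the key state by the same number of splits reproduces the keys a full simulation would have reached.

```python
import hashlib


def split(key: bytes) -> tuple:
    """Toy stand-in for jax.random.split: derive two child keys from one key."""
    return (hashlib.sha256(key + b"0").digest(), hashlib.sha256(key + b"1").digest())


class KeyStream:
    def __init__(self, seed: int):
        self._key = hashlib.sha256(str(seed).encode()).digest()

    def next_key(self) -> bytes:
        # Keep one child as the new state and hand out the other.
        self._key, subkey = split(self._key)
        return subkey

    def simulate_batch(self, n_noisy_observations: int) -> list:
        # One key for prior sampling, then one per observation that adds noise.
        return [self.next_key() for _ in range(1 + n_noisy_observations)]

    def advance_for_batch(self, n_noisy_observations: int) -> None:
        # Same number of splits, but no forward-model work.
        for _ in range(1 + n_noisy_observations):
            self.next_key()


full = KeyStream(seed=0)
full.simulate_batch(2)            # first batch, fully simulated
expected = full.simulate_batch(2)  # keys used by the second batch

resumed = KeyStream(seed=0)
resumed.advance_for_batch(2)       # skip the first batch cheaply
got = resumed.simulate_batch(2)
print(got == expected)
```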

_apply_noise(values: Any, uncertainties: Any, covariance: Any) tuple[Any, Any, Any]#
_apply_noise_batched(values_batch: Any, uncertainties_batch: Any, covariance_batch: Any) tuple[Any, Any, Any]#

Apply noise to a batch of simulation outputs.

Args:

values_batch: Shape (n_samples, n_wavelengths).

uncertainties_batch: Shape (n_samples, n_wavelengths) or None.

covariance_batch: Shape (n_samples, n_wl, n_wl) or None.

Returns:

Noisy values, uncertainties, and covariance with the same leading batch dimension as values_batch.
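The documented shapes can be illustrated with a NumPy sketch. The noise model is an assumption, not something the docstring states: zero-mean Gaussian noise, drawn through a per-sample Cholesky factor when a covariance is given and scaled by the uncertainties otherwise. The function name `apply_gaussian_noise_batched` and the `seed` argument are hypothetical.

```python
import numpy as np


def apply_gaussian_noise_batched(values, uncertainties=None, covariance=None, seed=0):
    """Assumed noise model: zero-mean Gaussian, correlated if covariance is given."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    white = rng.standard_normal(values.shape)
    if covariance is not None:
        # Correlated noise: L @ z with covariance = L @ L.T, per sample.
        cholesky = np.linalg.cholesky(np.asarray(covariance, dtype=float))
        noise = np.einsum("nij,nj->ni", cholesky, white)
    elif uncertainties is not None:
        # Independent noise scaled by the per-point standard deviation.
        noise = np.asarray(uncertainties, dtype=float) * white
    else:
        noise = np.zeros_like(values)
    return values + noise, uncertainties, covariance


values = np.ones((3, 4))
sigma = np.full((3, 4), 0.1)
noisy_diag, _, _ = apply_gaussian_noise_batched(values, uncertainties=sigma)

cov = np.tile(0.01 * np.eye(4), (3, 1, 1))  # shape (n_samples, n_wl, n_wl)
noisy_cov, _, cov_out = apply_gaussian_noise_batched(values, covariance=cov)

clean, _, _ = apply_gaussian_noise_batched(values)
print(noisy_diag.shape, noisy_cov.shape)
```

With a diagonal covariance equal to the squared uncertainties, the two branches draw identical noise for the same seed, which makes for a quick consistency check.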

static _stack_observation_payloads(payloads: list[dict[str, Any]]) dict[str, Any]#
_simulate_parameter_matrix(parameter_matrix: Any, cube_parameters: Any = None, unconstrained_parameters: Any = None, data_parallel: bool = False) SimulationBatch#
simulate_from_parameters(parameters: Any) SimulationBatch#

Run the forward model and noise pipeline for pre-specified parameters.

simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) SimulationBatch#

Sample parameters and preserve prior-space coordinates when available.