petitRADTRANS.sbi

Simulation-based inference interfaces for petitRADTRANS.

The modules in petitRADTRANS.sbi provide a production-oriented architecture for amortized inference workflows built on top of the existing retrieval runtime. The package is intentionally lightweight at this stage: it defines the core task, simulation, dataset, posterior, and benchmarking interfaces without committing to one training backend or storage engine.

Exceptions#

TaskCompatibilityError

Raised when an observation or artifact is incompatible with a task.

Classes#

BenchmarkComparison

Compare an amortized result to one or more exact retrieval baselines.

BenchmarkMetrics

Metrics summarizing agreement and predictive performance.

RetrievalBenchmarkCase

One benchmark problem used to compare inference backends.

RetrievalBenchmarkSuite

Run standardized benchmark comparisons for SBI tasks.

LocalSensitivityPointReport

Local linear-identifiability summary around one representative point.

LocalSensitivityReport

Aggregate local information-content diagnostics for one observation.

PosteriorPredictiveReport

Aggregate posterior-predictive summaries for held-out observations.

SimulationBasedCalibrationReport

Rank-based SBC summary over a held-out set of observations.

HDF5SimulationDatasetStore

HDF5-backed store for simulation corpora.

SimulationDatasetStore

Backend-independent interface for reading and writing simulation data.

StoredSimulationDataset

Read-only handle for a stored chunked simulation dataset.

ZarrSimulationDatasetStore

Concrete Zarr-backed store for chunked simulation corpora.

DatasetSplit

Named dataset partitions used during training and evaluation.

NormalizedObservationDatasetReader

Lightweight reader yielding normalized ObservationBlock batches for training.

HierarchicalObservationEncoder

Learned hierarchical encoder dispatching over modalities.

PhotometryPointEncoder

Learned photometry encoder using per-point MLP features and pooling.

SpectralConv1DEncoder

Learned spectral encoder using 1D convolutions to retain positional information.

SpectralPatchEncoder

Learned spectral encoder based on patch summaries and MLP pooling.

AmortizedRetrieval

Serve trained SBI models through a retrieval-like interface.

AmortizedRetrievalResult

Return type for amortized inference queries.

OODDiagnostic

Describe whether an observation is inside the training support.

ObservationBlock

Represent one modality-specific observation block.

ObservationEncoder

Transform structured observation blocks into model-ready embeddings.

ObservationModality

Supported observation block types for SBI conditioning.

ConditionalAutoregressiveFlowPosterior

Posterior estimator specialized to the autoregressive flow backend.

ConditionalFlowPosterior

Concrete amortized posterior using a conditional flow backend.

ConditionalNeuralAutoregressiveFlowPosterior

Posterior estimator specialized to the neural autoregressive backend.

ConditionalSplineFlowPosterior

Posterior estimator specialized to the spline flow backend.

FlowMatchingPosterior

Conditional flow-matching posterior skeleton.

PersistentPosteriorEstimator

Shared persistence helper for estimator backends with on-disk artifacts.

PosteriorBatch

Training batch passed to amortized posterior estimators.

PosteriorEstimator

Backend-agnostic interface for amortized posterior models.

PosteriorSamples

Posterior samples and optional per-sample diagnostics.

TaskPreprocessingMetadata

Serializable preprocessing metadata for an SBI task family.

ProposalSampler

Interface for simulation proposals beyond the prior distribution.

RuntimeSimulator

Concrete simulator backed by the retrieval runtime.

SimulationBatch

Container for one simulated batch.

Simulator

Base simulator for SBI dataset generation and validation.

NoiseModelConfig

Describe how observational noise is injected during simulation.

ObservationSchema

Capture the supported observation family for an amortized task.

ObservationValueConstraint

Admissible range for simulated observation values.

SBITask

Bundle the immutable ingredients required for an SBI problem.

SimulationConfig

Control how a task generates prior-predictive simulations.

EarlyStoppingConfig

Early stopping policy for SBI training.

SBITrainer

Reusable optimization loop for amortized SBI posteriors.

TrainingConfig

Configuration for posterior optimization.

Functions#

generate_local_sensitivity_report(...)

Diagnose local physical identifiability around representative posterior points.

generate_posterior_predictive_report(...)

Generate posterior-predictive summaries for a held-out dataset split.

generate_sbc_report(..., max_cases, seed, data_parallel)

Run SBC over a dataset reader using normalized observation batches.

local_sensitivity_report_to_payload(→ dict[str, Any])

Convert a local sensitivity report into a JSON-serializable payload.

generate_simulation_dataset(→ GeneratedSimulationDataset)

Generate and persist a simulation corpus in one call.

load_posterior_estimator(...)

Load a saved posterior estimator without naming its concrete class.

build_observation_block(→ ObservationBlock)

Build one observation block with modality normalization.

build_observation_block_batch(...)

Build observation blocks for each sample in a batched payload.

build_observation_blocks_from_sample(...)

Build modality-aware observation blocks for one simulated sample.

plot_local_sensitivity_fisher_correlations(...)

Plot Fisher-correlation heatmaps for each representative point.

plot_local_sensitivity_jacobians(→ tuple[Any, ...)

Plot whitened Jacobian heatmaps for each representative point.

plot_local_sensitivity_singular_values(→ tuple[Any, ...)

Plot singular spectra of the whitened Jacobian for each point.

plot_posterior_corner(→ tuple[Any, numpy.ndarray])

Plot a lower-triangular corner view of posterior structure.

plot_posterior_marginals(→ tuple[Any, numpy.ndarray])

Plot one histogram per posterior dimension.

plot_posterior_predictive_report(→ tuple[Any, ...)

Plot observed values against posterior-predictive means and intervals.

plot_sbc_rank_histograms(→ tuple[Any, numpy.ndarray])

Plot one SBC rank histogram per inferred parameter.

fit_task_preprocessing(→ TaskPreprocessingMetadata)

Fit preprocessing statistics from training observation blocks.

normalize_observation_block(...)

Normalize one observation block with fitted preprocessing statistics.

normalize_observation_blocks(...)

Normalize a list of observation blocks.

Package Contents#

class petitRADTRANS.sbi.BenchmarkComparison#

Compare an amortized result to one or more exact retrieval baselines.

case_name: str#
amortized_result: petitRADTRANS.sbi.inference.AmortizedRetrievalResult#
exact_results: Mapping[str, Any]#
metrics: BenchmarkMetrics#
metadata: Mapping[str, Any]#
class petitRADTRANS.sbi.BenchmarkMetrics#

Metrics summarizing agreement and predictive performance.

calibration: Mapping[str, float]#
posterior_distance: Mapping[str, float]#
predictive_checks: Mapping[str, float]#
runtime: Mapping[str, float]#
class petitRADTRANS.sbi.RetrievalBenchmarkCase#

One benchmark problem used to compare inference backends.

name: str#
task: petitRADTRANS.sbi.task.SBITask#
observation: Any#
reference_posterior: Any = None#
metadata: Mapping[str, Any]#
class petitRADTRANS.sbi.RetrievalBenchmarkSuite(cases: list[RetrievalBenchmarkCase])#

Run standardized benchmark comparisons for SBI tasks.

cases#
abstractmethod run_case(case: RetrievalBenchmarkCase) → BenchmarkComparison#

Run one benchmark case and compute comparison metrics.

Parameters#

case:

Benchmark case describing the task, observation, and optional exact reference posterior to compare against.

Returns#

BenchmarkComparison

Comparison payload combining amortized and exact results together with any derived metrics.

Notes#

The base class is intentionally abstract. Concrete suites are expected to bind exact and amortized inference backends and define the metric computations appropriate for the comparison.

run_all() → list[BenchmarkComparison]#

Run all configured benchmark cases.
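
The abstract-suite pattern described above can be sketched as follows. This is a minimal illustration using only the standard library, not the package's implementation: `Case`, `Suite`, and `MeanOffsetSuite` are hypothetical stand-ins, and the assumption that `run_all` simply applies `run_case` to every case in order is not confirmed by this page.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Mapping


@dataclass(frozen=True)
class Case:
    """Hypothetical stand-in for RetrievalBenchmarkCase."""
    name: str
    observation: Any
    metadata: Mapping[str, Any] = field(default_factory=dict)


class Suite(ABC):
    """Hypothetical stand-in for RetrievalBenchmarkSuite."""

    def __init__(self, cases: list[Case]):
        self.cases = list(cases)

    @abstractmethod
    def run_case(self, case: Case) -> dict[str, Any]:
        """Run one case; concrete suites bind the inference backends."""

    def run_all(self) -> list[dict[str, Any]]:
        # Assumed behaviour: apply run_case to every case, in order.
        return [self.run_case(case) for case in self.cases]


class MeanOffsetSuite(Suite):
    """Toy suite whose only 'metric' is the mean of the observation."""

    def run_case(self, case: Case) -> dict[str, Any]:
        values = list(case.observation)
        return {"case_name": case.name, "mean": sum(values) / len(values)}


suite = MeanOffsetSuite([Case("a", [1.0, 3.0]), Case("b", [2.0, 2.0, 5.0])])
results = suite.run_all()
```

Because `run_case` is abstract, instantiating `Suite` directly raises `TypeError`; only subclasses that define the metric computation can be run.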

class petitRADTRANS.sbi.LocalSensitivityPointReport#

Local linear-identifiability summary around one representative point.

Attributes#

label:

Short human-readable label for the representative point.

parameters:

Physical parameter vector at which the Jacobian was evaluated.

finite_difference_steps:

Per-parameter finite-difference step sizes used during Jacobian construction.

finite_difference_schemes:

Per-parameter scheme labels such as 'central' or 'forward'.

whitened_jacobian:

Observation Jacobian divided by observational uncertainty, with shape (n_observation_values, n_parameters).

singular_values:

Singular values of the whitened Jacobian.

effective_rank:

Number of singular values larger than the configured relative cutoff.

condition_number:

Condition number inferred from the singular spectrum.

fisher_matrix:

Approximate Fisher information matrix J^T J.

fisher_covariance:

Damped pseudo-inverse of the Fisher matrix.

fisher_correlation:

Correlation matrix derived from the approximate Fisher covariance.

parameter_sensitivity_norm:

Per-parameter square-root Fisher diagonal.

local_sigma:

Approximate local standard deviation from the Fisher covariance.

metadata:

Auxiliary diagnostics such as the ridge term and failed columns.

label: str#
parameters: numpy.ndarray#
finite_difference_steps: numpy.ndarray#
finite_difference_schemes: tuple[str, ...]#
whitened_jacobian: numpy.ndarray#
singular_values: numpy.ndarray#
effective_rank: int#
condition_number: float#
fisher_matrix: numpy.ndarray#
fisher_covariance: numpy.ndarray#
fisher_correlation: numpy.ndarray#
parameter_sensitivity_norm: numpy.ndarray#
local_sigma: numpy.ndarray#
metadata: Mapping[str, Any]#
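
To make the relationship between these attributes concrete, here is a small worked sketch for a two-parameter toy problem in plain Python. The input numbers are made up, and the damping is assumed to be a simple ridge term added to the Fisher diagonal; the package may construct its damped pseudo-inverse differently.

```python
import math

# Toy inputs (made up): 3 observation values, 2 parameters.
jacobian = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]  # d(observation_i) / d(parameter_j)
sigma = [0.5, 1.0, 0.5]                          # observational uncertainties

# Whitening: divide each Jacobian row by the matching uncertainty.
whitened = [[value / s for value in row] for row, s in zip(jacobian, sigma)]

# Approximate Fisher matrix F = J^T J of the whitened Jacobian.
n_par = 2
fisher = [
    [sum(row[a] * row[b] for row in whitened) for b in range(n_par)]
    for a in range(n_par)
]

# Damped inverse (F + ridge * I)^-1 in closed form for a 2x2 matrix.
# The ridge value is an assumption chosen for illustration.
ridge = 1e-9
a, b = fisher[0][0] + ridge, fisher[0][1]
c, d = fisher[1][0], fisher[1][1] + ridge
det = a * d - b * c
covariance = [[d / det, -b / det], [-c / det, a / det]]

# Derived per-parameter summaries.
local_sigma = [math.sqrt(covariance[i][i]) for i in range(n_par)]
correlation = covariance[0][1] / (local_sigma[0] * local_sigma[1])
sensitivity_norm = [math.sqrt(fisher[i][i]) for i in range(n_par)]
```

In this toy case the two parameters end up anti-correlated (correlation close to -0.5) because the third observation responds to both of them equally.
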
class petitRADTRANS.sbi.LocalSensitivityReport#

Aggregate local information-content diagnostics for one observation.

Attributes#

parameter_names:

Parameter ordering used throughout the report.

posterior_mean:

Posterior mean in physical parameter space.

posterior_std:

Posterior standard deviation in physical parameter space.

posterior_median:

Posterior median in physical parameter space.

posterior_iqr:

Posterior interquartile range in physical parameter space.

representative_points:

Local sensitivity summaries evaluated at representative posterior points such as the posterior mean and highest-density sample.

aggregate_local_sigma:

Median local Fisher sigma across representative points.

aggregate_parameter_sensitivity_norm:

Median parameter sensitivity norm across representative points.

posterior_to_local_sigma_ratio:

Ratio between posterior standard deviation and local Fisher sigma.

parameter_diagnostics:

Per-parameter heuristic summary separating weak data constraints from broader-than-local posterior structure.

metadata:

Auxiliary metadata such as quantile levels and observation slices.

parameter_names: tuple[str, ...]#
posterior_mean: numpy.ndarray#
posterior_std: numpy.ndarray#
posterior_median: numpy.ndarray#
posterior_iqr: numpy.ndarray#
representative_points: tuple[LocalSensitivityPointReport, ...]#
aggregate_local_sigma: numpy.ndarray#
aggregate_parameter_sensitivity_norm: numpy.ndarray#
posterior_to_local_sigma_ratio: numpy.ndarray#
parameter_diagnostics: Mapping[str, Mapping[str, Any]]#
metadata: Mapping[str, Any]#
class petitRADTRANS.sbi.PosteriorPredictiveReport#

Aggregate posterior-predictive summaries for held-out observations.

Attributes#

observed_values:

Observed values kept on the original observation scale.

predictive_mean:

Posterior-predictive mean curves or vectors.

predictive_std:

Posterior-predictive standard deviations.

interval_lower, interval_upper:

Central predictive interval bounds for the requested level.

interval_coverage:

Per-dataset fraction of observed points covered by the predictive interval.

mean_absolute_error:

Mean absolute deviation between predictive mean and observed values.

metadata:

Auxiliary metadata such as split name, number of cases, and parameter space used during predictive generation.

observed_values: Mapping[str, numpy.ndarray]#
predictive_mean: Mapping[str, numpy.ndarray]#
predictive_std: Mapping[str, numpy.ndarray]#
interval_lower: Mapping[str, numpy.ndarray]#
interval_upper: Mapping[str, numpy.ndarray]#
interval_coverage: Mapping[str, float]#
mean_absolute_error: Mapping[str, float]#
mean_absolute_error_sigma: Mapping[str, float]#
median_interval_width_over_uncertainty: Mapping[str, float]#
crps: Mapping[str, float]#
metadata: Mapping[str, Any]#
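
As an illustration of how interval_coverage and mean_absolute_error relate to predictive draws, the sketch below computes both for one toy dataset. It assumes nearest-rank empirical quantiles and made-up draws; the package's own quantile convention is not documented here and may interpolate instead.

```python
def central_interval(samples, level):
    """Return (lower, upper) nearest-rank quantiles of a central interval."""
    ordered = sorted(samples)
    tail = (1.0 - level) / 2.0
    last = len(ordered) - 1
    return ordered[round(tail * last)], ordered[round((1.0 - tail) * last)]


# Three observed points, each with 11 made-up predictive draws.
observed = [5.0, 19.5, 24.0]
draws = [[offset + float(k) for k in range(11)] for offset in (0.0, 10.0, 20.0)]
level = 0.8

predictive_mean = [sum(d) / len(d) for d in draws]
intervals = [central_interval(d, level) for d in draws]

# Fraction of observed points that fall inside their predictive interval.
covered = [low <= obs <= high for obs, (low, high) in zip(observed, intervals)]
interval_coverage = sum(covered) / len(covered)

# Mean absolute deviation between predictive mean and observed values.
mean_absolute_error = sum(
    abs(mean - obs) for mean, obs in zip(predictive_mean, observed)
) / len(observed)
```

Here the second observed point lies just outside its interval, so two of the three points are covered.
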
class petitRADTRANS.sbi.SimulationBasedCalibrationReport#

Rank-based SBC summary over a held-out set of observations.

Attributes#

ranks:

Integer rank of the ground-truth parameter within posterior samples for each held-out case and parameter dimension.

rank_histogram_counts:

Per-parameter SBC histogram counts.

posterior_means:

Posterior mean for each held-out case.

truths:

Ground-truth parameters paired with each posterior sample set.

coverages:

Coverage summaries for the requested nominal interval levels.

mean_rank:

Mean empirical rank for each parameter dimension.

normalized_mean_rank_error:

Absolute mean-rank error normalized by the expected average rank.

metadata:

Auxiliary run metadata such as split name and number of posterior draws.

ranks: numpy.ndarray#
rank_histogram_counts: numpy.ndarray#
posterior_means: numpy.ndarray#
truths: numpy.ndarray#
coverages: tuple[CoverageLevelReport, ...]#
mean_rank: numpy.ndarray#
normalized_mean_rank_error: numpy.ndarray#
metadata: Mapping[str, Any]#
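
The rank statistics above can be sketched for a single parameter dimension as follows. The values are made up, and the normalization (expected average rank taken as half the number of posterior draws) is an assumption about how normalized_mean_rank_error is defined, not something this page states.

```python
def sbc_rank(truth, posterior_draws):
    """Count posterior draws strictly below the ground-truth value."""
    return sum(1 for draw in posterior_draws if draw < truth)


# Five held-out cases, each pairing one truth with four posterior draws.
shared_draws = [0.1, 0.2, 0.3, 0.4]
cases = [
    (0.5, [0.1, 0.2, 0.7, 0.9]),
    (0.05, shared_draws),
    (0.95, shared_draws),
    (0.35, shared_draws),
    (0.15, shared_draws),
]
n_draws = 4

ranks = [sbc_rank(truth, draws) for truth, draws in cases]

# With n_draws posterior samples the rank takes values 0..n_draws.
rank_histogram_counts = [ranks.count(r) for r in range(n_draws + 1)]

mean_rank = sum(ranks) / len(ranks)
expected_rank = n_draws / 2  # assumed reference value for a calibrated posterior
normalized_mean_rank_error = abs(mean_rank - expected_rank) / expected_rank
```

A flat histogram and a mean rank near the expected value are what a well-calibrated posterior should produce; systematic skew or a U shape points to bias or overconfidence.
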
petitRADTRANS.sbi.generate_local_sensitivity_report(task: petitRADTRANS.sbi.task.SBITask, posterior_samples: Any, observation_blocks: Sequence[Any], posterior_log_probabilities: Any = None, parameter_space: str = 'physical', simulator: petitRADTRANS.sbi.simulator.RuntimeSimulator | None = None, quantile_levels: Sequence[float] = (0.1, 0.5, 0.9), finite_difference_relative_step: float = 0.001, finite_difference_std_fraction: float = 0.1, finite_difference_absolute_floor: float = 1e-05, max_step_reduction_attempts: int = 6, svd_relative_tolerance: float = 0.001, posterior_underexploited_ratio_threshold: float = 1.5, weak_sensitivity_fraction_threshold: float = 0.15, seed: int | None = None) → LocalSensitivityReport#

Diagnose local physical identifiability around representative posterior points.

The report evaluates a deterministic simulator Jacobian at representative posterior points, whitens it by observational uncertainty, and derives a Fisher-style local covariance approximation for each point.
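
The Jacobian step described above can be illustrated with a central finite-difference sketch. Everything here is a simplified assumption: `toy_simulator` is a made-up forward model, and the step rule (relative step with an absolute floor) uses only two of the documented step parameters. The posterior-width fraction and the step-reduction retries are not modelled.

```python
def whitened_jacobian(simulate, theta, sigma,
                      relative_step=1e-3, absolute_floor=1e-5):
    """Central-difference Jacobian divided row-wise by the uncertainties."""
    columns = []
    for j, value in enumerate(theta):
        # Assumed step rule: relative step, bounded below by an absolute floor.
        step = max(relative_step * abs(value), absolute_floor)
        up, down = list(theta), list(theta)
        up[j] = value + step
        down[j] = value - step
        f_up, f_down = simulate(up), simulate(down)
        columns.append([
            (u - d) / (2.0 * step * s)
            for u, d, s in zip(f_up, f_down, sigma)
        ])
    # Transpose to shape (n_observation_values, n_parameters).
    return [list(row) for row in zip(*columns)]


def toy_simulator(theta):
    """Made-up deterministic forward model with three outputs."""
    a, b = theta
    return [a + b, a * b, b * b]


J = whitened_jacobian(toy_simulator, [2.0, 3.0], sigma=[1.0, 0.5, 2.0])
```

For this quadratic toy model the central difference is exact up to rounding, so `J` matches the analytic whitened Jacobian `[[1, 1], [6, 4], [0, 3]]`.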

petitRADTRANS.sbi.generate_posterior_predictive_report(task: petitRADTRANS.sbi.task.SBITask, posterior: Any, dataset_reader: petitRADTRANS.sbi.dataset.NormalizedObservationDatasetReader, split: petitRADTRANS.sbi.dataset.DatasetSplit = DatasetSplit.TEST, n_posterior_samples: int = 256, interval_level: float = 0.9, max_cases: int | None = None, seed: int | None = None, simulator: petitRADTRANS.sbi.simulator.RuntimeSimulator | None = None, n_predictive_forward_model_samples: int | None = None, checkpoint_directory: str | pathlib.Path | None = None, data_parallel: bool | None = None) → PosteriorPredictiveReport#

Generate posterior-predictive summaries for a held-out dataset split.

Parameters#

task:

SBI task defining parameter transforms and the simulator configuration.

posterior:

Trained posterior estimator used to sample held-out predictive draws.

dataset_reader:

Reader providing normalized observations and preprocessing metadata.

split:

Held-out split used for the predictive report.

n_posterior_samples:

Number of posterior draws generated per held-out observation.

interval_level:

Central predictive interval level reported for each dataset.

max_cases:

Optional cap on the number of held-out observations evaluated.

seed:

Optional base seed used to make predictive sampling reproducible.

simulator:

Optional runtime simulator override.

n_predictive_forward_model_samples:

Number of posterior draws passed through the forward model per held-out case. When None, all n_posterior_samples draws are forwarded. Subsampling here is the primary lever for keeping the total number of petitRADTRANS calls manageable when evaluating many held-out cases.

checkpoint_directory:

Optional directory for per-case checkpoints. When provided, each completed case is written to disk as a compressed .npz file and skipped on resume. This makes the expensive forward-model loop restartable after interruption.

Returns#

PosteriorPredictiveReport

Aggregate posterior-predictive summary over the requested split.

Notes#

Observations are normalized internally for posterior encoding but compared on the original observation scale in the returned report.
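
The restartable behaviour of checkpoint_directory can be sketched as a skip-if-present loop. This is an illustration under stated assumptions, not the package's code: it writes small JSON files instead of compressed .npz archives, and the file-name pattern and the `run_cases` helper are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def run_cases(case_ids, compute, checkpoint_directory):
    """Compute each case once, reusing any checkpoint found on disk."""
    directory = Path(checkpoint_directory)
    directory.mkdir(parents=True, exist_ok=True)
    results, computed = {}, []
    for case_id in case_ids:
        path = directory / f"case_{case_id:05d}.json"  # hypothetical naming
        if path.exists():
            # Completed in an earlier run: load and skip the expensive work.
            results[case_id] = json.loads(path.read_text())
            continue
        results[case_id] = compute(case_id)
        path.write_text(json.dumps(results[case_id]))
        computed.append(case_id)
    return results, computed


def expensive(case_id):
    """Stand-in for the per-case forward-model loop."""
    return {"mean": case_id * 1.5}


with tempfile.TemporaryDirectory() as tmp:
    _, first_run = run_cases([0, 1], expensive, tmp)         # interrupted run
    results, second_run = run_cases([0, 1, 2], expensive, tmp)  # resumed run
```

On the resumed call only case 2 is recomputed; cases 0 and 1 are read back from their checkpoints.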

petitRADTRANS.sbi.generate_sbc_report(posterior: Any, dataset_reader: petitRADTRANS.sbi.dataset.NormalizedObservationDatasetReader, split: petitRADTRANS.sbi.dataset.DatasetSplit = DatasetSplit.TEST, n_posterior_samples: int = 256, batch_size: int = 32, parameter_space: str | None = None, levels: Sequence[float] = (0.5, 0.8, 0.95), max_cases: int | None = None, seed: int | None = None, data_parallel: bool | None = None) → SimulationBasedCalibrationReport#

Run SBC over a dataset reader using normalized observation batches.

Parameters#

posterior:

Trained posterior estimator exposing encode_observation and sample_posterior.

dataset_reader:

Reader yielding normalized held-out observations and matched parameter values.

split:

Dataset split used for the SBC evaluation.

n_posterior_samples:

Number of posterior draws generated per held-out case.

batch_size:

Reader batch size used during report generation.

parameter_space:

Optional parameter space override. Defaults to the posterior’s own configured parameter space.

levels:

Coverage levels summarized in the returned report.

max_cases:

Optional cap on the number of held-out observations evaluated.

seed:

Optional base seed used to generate reproducible posterior draws.

Returns#

SimulationBasedCalibrationReport

SBC summary computed from the requested dataset split.

petitRADTRANS.sbi.local_sensitivity_report_to_payload(report: LocalSensitivityReport) → dict[str, Any]#

Convert a local sensitivity report into a JSON-serializable payload.

exception petitRADTRANS.sbi.TaskCompatibilityError#

Bases: ValueError

Raised when an observation or artifact is incompatible with a task.
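
Because the exception derives from ValueError, existing handlers that catch ValueError will also catch it. The sketch below uses a local stand-in class with the same base so it runs without the package; `check_wavelength_count` is a hypothetical validation helper, not part of the documented API.

```python
class TaskCompatibilityError(ValueError):
    """Local stand-in mirroring the documented base class."""


def check_wavelength_count(observation, expected):
    """Hypothetical check: reject observations of the wrong length."""
    if len(observation) != expected:
        raise TaskCompatibilityError(
            f"expected {expected} spectral points, got {len(observation)}"
        )


try:
    check_wavelength_count([1.0, 2.0], expected=3)
except ValueError as error:  # also catches TaskCompatibilityError
    caught = error
```

Catching the specific class instead of ValueError is preferable when a caller wants to separate task mismatches from other input errors.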

class petitRADTRANS.sbi.HDF5SimulationDatasetStore(chunk_size: int = 256)#

Bases: SimulationDatasetStore

HDF5-backed store for simulation corpora.

Stores all simulation data for a corpus in a single .h5 file, keeping the file count at 1 regardless of the number of simulations or splits. Requires h5py (already present in the jaxprt environment).

chunk_size = 256#
static _require_h5py() → None#
create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') → HDF5SimulationDatasetWriter#

Create a writer for a new simulation dataset.

open(manifest_or_uri: SimulationDatasetManifest | str) → HDF5StoredSimulationDataset#

Open a stored dataset for training or evaluation.

class petitRADTRANS.sbi.SimulationDatasetStore#

Bases: abc.ABC

Backend-independent interface for reading and writing simulation data.

abstractmethod create_writer(manifest: SimulationDatasetManifest) → SimulationDatasetWriter#

Create a writer for a new simulation dataset.

abstractmethod open(manifest_or_uri: SimulationDatasetManifest | str) → Any#

Open a stored dataset for training or evaluation.

class petitRADTRANS.sbi.StoredSimulationDataset#

Read-only handle for a stored chunked simulation dataset.

manifest: SimulationDatasetManifest#
storage_uri: str#
root: Any#
iter_split_chunks(split: DatasetSplit = DatasetSplit.TRAIN, chunk_size: int = 256, indices: Any = None, observation_fields: frozenset[str] | None = None) → Iterator[dict[str, Any]]#

Yield row-slices of one split without loading it all into RAM.

Parameters#

split:

Dataset split to iterate.

chunk_size:

Number of rows per yielded chunk.

indices:

Optional integer array of row indices to read. When supplied, only those rows are fetched, in ascending order. When None, every row in the split is visited sequentially.

observation_fields:

Optional set of observation field names to load. When None every dataset field is loaded. Pass _OBSERVATION_TRAINING_FIELDS to skip large unused fields such as covariance during training.

Yields#

dict[str, Any]

Same structure as read_split() but covering only chunk_size rows per iteration.
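
The chunking semantics can be sketched on an in-memory list. This is only an illustration: `iter_chunks` is a hypothetical helper, the dictionary keys are made up, and a real store would slice on-disk arrays instead of a Python list.

```python
def iter_chunks(rows, chunk_size=256, indices=None):
    """Yield chunk_size rows at a time, optionally restricted to indices."""
    # With indices given, rows are visited in ascending index order.
    selected = range(len(rows)) if indices is None else sorted(indices)
    for start in range(0, len(selected), chunk_size):
        window = selected[start:start + chunk_size]
        yield {"index": list(window), "values": [rows[i] for i in window]}


rows = [i * 10 for i in range(7)]

# Every row, three at a time: chunk sizes 3, 3, 1.
all_chunks = list(iter_chunks(rows, chunk_size=3))

# Only rows 5, 0, 3, fetched in ascending order, two at a time.
subset_chunks = list(iter_chunks(rows, chunk_size=2, indices=[5, 0, 3]))
```

Only one chunk is materialized at a time, which is what keeps peak memory bounded by chunk_size and not by the size of the split.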

read_split(split: DatasetSplit = DatasetSplit.TRAIN) → dict[str, Any]#
class petitRADTRANS.sbi.ZarrSimulationDatasetStore(chunk_size: int = 256)#

Bases: SimulationDatasetStore

Concrete Zarr-backed store for chunked simulation corpora.

chunk_size = 256#
static _require_zarr() → None#
create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') → SimulationDatasetWriter#

Create a writer for a new simulation dataset.

open(manifest_or_uri: SimulationDatasetManifest | str) → StoredSimulationDataset#

Open a stored dataset for training or evaluation.

petitRADTRANS.sbi.generate_simulation_dataset(task: petitRADTRANS.sbi.task.SBITask, storage_uri: str, n_simulations: int, chunk_size: int = 256, split_counts: Mapping[DatasetSplit | str, int] | None = None, split_fractions: Mapping[DatasetSplit | str, float] | None = None, split_policy: SplitSamplingPolicy | str = SplitSamplingPolicy.SEQUENTIAL, split_seed: int | None = None, simulator: Any = None, include_preprocessing_metadata: bool = True, dataset_version: str = '0.1.0', resume: bool = False, backend: str = 'hdf5', data_parallel: bool | None = None, store_covariance: bool = False) → GeneratedSimulationDataset#

Generate and persist a simulation corpus in one call.

Parameters#

backend:

Storage backend to use. 'hdf5' (default) writes all data into a single .h5 file, keeping the file count at 1 regardless of the number of simulations. 'zarr' uses the Zarr directory store, which produces one file per compressed chunk.

data_parallel:

When True (or None with multiple JAX devices), the vmapped RT kernel is distributed across devices using jax.pmap. The effective per-iteration batch size is automatically scaled by the device count so each device processes the configured simulation_config.batch_size samples.

store_covariance:

When True, the full covariance matrix for each simulated spectrum is written to disk. When False (default), only the covariance diagonal is stored under the covariance field to reduce storage pressure.

class petitRADTRANS.sbi.DatasetSplit#

Bases: str, enum.Enum

Named dataset partitions used during training and evaluation.

TRAIN = 'train'#
VALIDATION = 'validation'#
TEST = 'test'#
BENCHMARK = 'benchmark'#
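
Since the enum mixes in str, members compare equal to their string values and can be constructed from them. The sketch below re-declares the enum locally so it runs without the package; the member names and values are copied from the listing above.

```python
import enum


class DatasetSplit(str, enum.Enum):
    """Local replica of the documented enum, for illustration only."""
    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"
    BENCHMARK = "benchmark"


# Look a member up by its string value, e.g. when parsing a config file.
split = DatasetSplit("test")

# Members behave like plain strings in comparisons.
is_train_string = DatasetSplit.TRAIN == "train"

# Iteration follows definition order.
names = [member.value for member in DatasetSplit]
```

This is presumably why signatures such as split_counts accept `DatasetSplit | str` keys: a plain string can be normalized to a member with a single constructor call.
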
class petitRADTRANS.sbi.NormalizedObservationDatasetReader#

Lightweight reader yielding normalized ObservationBlock batches for training.

dataset: StoredSimulationDataset | HDF5StoredSimulationDataset#
preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata#
_split_cache: dict#
iter_batches(split: DatasetSplit = DatasetSplit.TRAIN, batch_size: int = 32, shuffle: bool = False, seed: int | None = None, parameter_space: str = 'physical', encoder: Any = None) → Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]#

Yield mini-batches of normalized observations and matched parameters.

Parameters#

split:

Dataset split to iterate over.

batch_size:

Number of samples yielded in each batch.

shuffle:

Whether to shuffle sample order within the requested split.

seed:

Optional random seed used when shuffle=True.

parameter_space:

Parameter representation returned in each batch. Supported values are 'physical', 'cube', and 'unconstrained'.

encoder:

Optional encoder used to convert block lists into dense embedding arrays before batches are yielded.

Returns#

Iterator[PosteriorBatch]

Iterator over batches containing parameters, observations, and small metadata dictionaries describing the batch provenance.

Notes#

Splits with fewer than _STREAMING_THRESHOLD rows are cached in RAM after the first read so that repeated epoch calls (e.g. validation) pay only slicing cost. Larger splits are streamed directly from disk one batch at a time to avoid loading the full dataset into memory.

_iter_batches_cached(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) → Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]#

Iterate using the full-split RAM cache (small splits).

_iter_batches_streaming(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) → Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]#

Iterate by streaming rows directly from disk (large splits).

When shuffle=True a global index permutation is computed in RAM (one integer per simulation row — negligible memory) and used to fetch HDF5 rows in sorted sub-windows of batch_size, satisfying h5py’s monotonic-index requirement while still presenting shuffled order to the training loop.
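
The shuffled streaming scheme can be sketched without h5py. The helper below is hypothetical and only produces the index windows: it permutes all row indices once, cuts the permutation into windows of batch_size, and sorts each window so that it would be a valid increasing-order selection for an HDF5 read.

```python
import random


def shuffled_sorted_windows(n_rows, batch_size, seed=None):
    """Yield ascending index windows drawn from one global permutation."""
    permutation = list(range(n_rows))
    random.Random(seed).shuffle(permutation)  # one integer per row in RAM
    for start in range(0, n_rows, batch_size):
        # Each window holds a random subset of rows; sorting it keeps the
        # on-disk read monotonic while batch membership stays shuffled.
        yield sorted(permutation[start:start + batch_size])


windows = list(shuffled_sorted_windows(10, batch_size=4, seed=0))
```

Every row appears in exactly one window per epoch, so the training loop still sees a full pass over the split; only the order inside a batch is ascending, which does not matter for a mini-batch gradient step.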

static _select_parameters(split_data: dict, parameter_space: str) → Any#

Return the parameter array for the requested coordinate space.

class petitRADTRANS.sbi.HierarchicalObservationEncoder(embedding_dim: int = 128, spectrum_embedding_dim: int = 64, photometry_embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 128, spectrum_encoder_type: str = 'conv1d', n_wavelengths: int = 233, key: jax.Array | None = None)#

Bases: equinox.Module, petitRADTRANS.sbi.observation.ObservationEncoder

Learned hierarchical encoder dispatching over modalities.

Parameters#

embedding_dim:

Size of the final joint observation embedding.

spectrum_embedding_dim:

Intermediate embedding size used by the spectral sub-encoder.

photometry_embedding_dim:

Intermediate embedding size used by the photometry sub-encoder.

patch_size:

Patch size used by the spectral encoder.

hidden_dim:

Hidden width shared across the component encoders and aggregator.

key:

Optional JAX random key used to initialize all submodules.

Notes#

Spectrum and photometry blocks are handled by dedicated sub-encoders and then merged with a permutation-invariant aggregator. Unsupported modalities currently fall back to resized raw-value vectors.

spectrum_encoder: SpectralPatchEncoder | SpectralConv1DEncoder#
photometry_encoder: PhotometryPointEncoder#
aggregator: DatasetSetAggregator#
embedding_dim: int#
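
The permutation-invariant merge mentioned in the Notes can be illustrated with mean pooling. This is a stand-in: the actual aggregator (DatasetSetAggregator) is a learned module whose internals are not described on this page, and the embeddings below are made-up numbers.

```python
def mean_pool(block_embeddings):
    """Average equal-width embeddings element-wise (order-independent)."""
    n_blocks = len(block_embeddings)
    width = len(block_embeddings[0])
    return [
        sum(embedding[k] for embedding in block_embeddings) / n_blocks
        for k in range(width)
    ]


# Made-up per-block embeddings from the two sub-encoders.
spectrum_embedding = [1.0, 2.0]
photometry_embedding = [3.0, 6.0]

joint = mean_pool([spectrum_embedding, photometry_embedding])
joint_reordered = mean_pool([photometry_embedding, spectrum_embedding])
```

Both calls give the same joint embedding, which is the property that lets an observation list its blocks in any order.
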
_encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) → jax.numpy.ndarray#

Encode one observation block with modality-aware dispatch.

Parameters#

block:

Observation block to encode.

Returns#

jnp.ndarray

Fixed-width embedding for the supplied block.

encode(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) → petitRADTRANS.sbi.observation.EncodedObservation#

Encode one structured observation made of multiple blocks.

Parameters#

blocks:

Observation blocks associated with one target system.

Returns#

EncodedObservation

Aggregated embedding and light metadata describing the block set.

encode_stacked_batch(blocks_batch: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) → jax.numpy.ndarray#

Encode a batch of identically-structured observations using vmap.

All observations must share the same block structure (same number of blocks, same array shapes per block). This is always the case for SBI tasks where every observation is produced by the same forward model.

Parameters#

blocks_batch:

Outer list indexes samples; inner list holds the per-block observation data for one sample.

Returns#

jnp.ndarray

Float32 array of shape (n_samples, embedding_dim).

encode_from_prestacked(obs: Any) → jax.numpy.ndarray#

Encode a batch of observations from pre-stacked arrays.

Accepts a PreStackedObservations instance whose array fields have already been extracted from ObservationBlock objects outside the JAX JIT boundary. Only the vmapped XLA computation runs here, enabling the training step to be compiled once and reused across all batches.

Parameters#

obs:

Pre-stacked observation container with two fields. stacked_blocks holds one (values, uncertainties, coordinates, mask, log_scale, absolute_values) tuple per block; each array has shape (batch_size, n_wl), except the scalar log_scale arrays. modalities is a static tuple of modality value strings.

Returns#

jnp.ndarray

Float32 array of shape (batch_size, embedding_dim).

class petitRADTRANS.sbi.PhotometryPointEncoder(embedding_dim: int = 64, hidden_dim: int = 96, key: jax.Array | None = None)#

Bases: equinox.Module

Learned photometry encoder using per-point MLP features and pooling.

Parameters#

embedding_dim:

Size of the returned photometric embedding.

hidden_dim:

Hidden width of the per-point MLP.

key:

Optional JAX random key used for initialization.

Notes#

Each photometric point is represented by value, uncertainty, coordinate, and an inferred width feature before permutation-invariant pooling.

point_mlp: equinox.nn.MLP#
output_projection: equinox.nn.Linear#
embedding_dim: int#
encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) → jax.numpy.ndarray#

Encode one photometric observation block.

Parameters#

block:

Photometry-like observation block to encode.

Returns#

jnp.ndarray

Dense embedding for the supplied photometric block.

_encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray) → jax.numpy.ndarray#

Encode pre-processed block arrays (vmappable — no ObservationBlock input).

Parameters#

values, uncertainties, coordinates:

Float32 1-D arrays already produced by _as_vector. uncertainties and coordinates may have length 0.

mask:

Boolean 1-D array already produced by _safe_mask.

Returns#

jnp.ndarray

Dense photometric embedding of shape (embedding_dim,).

class petitRADTRANS.sbi.SpectralConv1DEncoder(embedding_dim: int = 64, n_wavelengths: int = 233, hidden_channels: tuple[int, int] = (32, 64), key: jax.Array | None = None)#

Bases: equinox.Module

Learned spectral encoder using 1D convolutions to retain positional information.

Parameters#

embedding_dim:

Size of the spectral embedding returned for one observation block.

n_wavelengths:

Fixed number of spectral points per observation block. Must match the observation schema of the SBI task.

hidden_channels:

Tuple of intermediate channel widths for conv layers.

key:

Optional JAX random key used for weight initialization.

conv1: equinox.nn.Conv1d#
conv2: equinox.nn.Conv1d#
conv3: equinox.nn.Conv1d#
amplitude_projection: equinox.nn.MLP#
pool_projection: equinox.nn.Linear#
branch_projection: equinox.nn.Linear#
output_projection: equinox.nn.Linear#
scale_projection: equinox.nn.Linear#
embedding_dim: int#
n_wavelengths: int#
_forward_conv(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) → jax.numpy.ndarray#

Run 1D conv stack on pre-processed arrays.

Parameters#

values, uncertainties, coordinates:

Float32 1-D arrays of length n_wavelengths.

mask:

Boolean 1-D mask (True = invalid).

Returns#

jnp.ndarray

Embedding of shape (embedding_dim,).

_pad_to_fixed(arr: jax.numpy.ndarray) → jax.numpy.ndarray#

Pad or truncate a 1-D array to n_wavelengths.
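
A pad-or-truncate helper of this kind can be sketched on plain lists. The fill value (zero) and the choice to pad at the end are assumptions; the page does not say how the method fills missing entries.

```python
def pad_to_fixed(values, n_wavelengths, fill=0.0):
    """Return exactly n_wavelengths entries, truncating or end-padding."""
    clipped = list(values[:n_wavelengths])
    return clipped + [fill] * (n_wavelengths - len(clipped))


padded = pad_to_fixed([1.0, 2.0, 3.0], n_wavelengths=5)
truncated = pad_to_fixed([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], n_wavelengths=4)
```

A fixed output length is what allows the convolutional stack to be compiled once for a single input shape.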

encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) → jax.numpy.ndarray#

Encode one spectral observation block.

Parameters#

block:

Spectrum-like observation block.

Returns#

jnp.ndarray

Dense embedding of shape (embedding_dim,).

_encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) → jax.numpy.ndarray#

Encode pre-processed block arrays. Taking plain arrays instead of an ObservationBlock keeps the method vmappable.

Parameters#

values, uncertainties, coordinates:

Float32 1-D arrays already produced by _as_vector.

mask:

Boolean 1-D array already produced by _safe_mask.

log_median_flux:

Scalar absolute-flux feature derived from the raw spectrum median.

Returns#

jnp.ndarray

Dense spectral embedding of shape (embedding_dim,).

class petitRADTRANS.sbi.SpectralPatchEncoder(embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 96, key: jax.Array | None = None)#

Bases: equinox.Module

Learned spectral encoder based on patch summaries and MLP pooling.

Parameters#

embedding_dim:

Size of the spectral embedding returned for one observation block.

patch_size:

Number of spectral points summarized together before MLP encoding.

hidden_dim:

Hidden width of the internal summary MLP.

key:

Optional JAX random key used for weight initialization.

Notes#

The encoder summarizes values, uncertainties, and coordinates patch by patch instead of operating on a fully convolutional representation. This keeps the implementation lightweight and shape-agnostic.
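As a rough illustration of why patch summaries are shape-agnostic, the sketch below reduces spectra of different lengths to the same fixed-width vector. The per-patch statistics (mean and standard deviation) and the mean pooling are assumptions chosen for brevity; the real encoder feeds patch features through learned MLPs, which are omitted here.

```python
import numpy as np


def patch_summaries(values, patch_size):
    """Summarize a 1-D spectrum as per-patch (mean, std) features."""
    values = np.asarray(values, dtype=np.float64)
    n_patches = -(-values.shape[0] // patch_size)  # ceiling division
    pad = n_patches * patch_size - values.shape[0]
    # Pad the last patch with NaN so padding never enters a summary.
    padded = np.concatenate([values, np.full(pad, np.nan)])
    patches = padded.reshape(n_patches, patch_size)
    return np.stack(
        [np.nanmean(patches, axis=1), np.nanstd(patches, axis=1)], axis=1
    )


def pooled_embedding(values, patch_size):
    """Mean-pool the patch features into one fixed-width vector."""
    return patch_summaries(values, patch_size).mean(axis=0)


# Two spectra of different lengths yield embeddings of the same width.
emb_short = pooled_embedding(np.linspace(0.0, 1.0, 50), patch_size=32)
emb_long = pooled_embedding(np.linspace(0.0, 1.0, 233), patch_size=32)
```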

patch_mlp: equinox.nn.MLP#
amplitude_projection: equinox.nn.MLP#
branch_projection: equinox.nn.Linear#
output_projection: equinox.nn.Linear#
scale_projection: equinox.nn.Linear#
embedding_dim: int#
patch_size: int#
encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) jax.numpy.ndarray#

Encode one spectral observation block.

Parameters#

block:

Spectrum-like observation block containing values and optional uncertainties, coordinates, and masks.

Returns#

jnp.ndarray

Dense fixed-width embedding for the supplied spectral block.

_encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) jax.numpy.ndarray#

Encode pre-processed block arrays (vmappable — no ObservationBlock input).

Parameters#

values, uncertainties, coordinates:

Float32 1-D arrays already produced by _as_vector. uncertainties and coordinates may have length 0 to indicate absence.

mask:

Boolean 1-D array already produced by _safe_mask.

Returns#

jnp.ndarray

Dense spectral embedding of shape (embedding_dim,).

petitRADTRANS.sbi.load_posterior_estimator(input_directory: str) petitRADTRANS.sbi.posterior.PosteriorEstimator#

Load a saved posterior estimator without naming its concrete class.

class petitRADTRANS.sbi.AmortizedRetrieval(task: petitRADTRANS.sbi.task.SBITask, posterior_estimator: petitRADTRANS.sbi.posterior.PosteriorEstimator, simulator: petitRADTRANS.sbi.simulator.RuntimeSimulator | None = None, preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | Mapping[str, Any] | None = None)#

Serve trained SBI models through a retrieval-like interface.

Parameters#

task:

SBI task compatible with the trained posterior.

posterior_estimator:

Trained posterior estimator used for encoding and sampling.

simulator:

Optional simulator override used for posterior-predictive generation.

preprocessing_metadata:

Optional preprocessing metadata or payload. When omitted, the inference service attempts to recover the preprocessing payload saved inside the posterior artifact.

Notes#

Raw user-facing observation blocks are normalized internally before encoding when preprocessing metadata is available, but posterior-predictive comparisons remain on the original observation scale.

task#
posterior_estimator#
simulator = None#
preprocessing_metadata#
_resolve_preprocessing_metadata(preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | Mapping[str, Any] | None) petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | None#
prepare_observation_blocks(observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) list[petitRADTRANS.sbi.observation.ObservationBlock]#

Normalize user-facing observation blocks when preprocessing metadata is available.

Parameters#

observation_blocks:

Raw observation blocks provided by the caller.

Returns#

list[ObservationBlock]

Either the original blocks or normalized copies, depending on whether preprocessing metadata is available.

infer(observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock], n_posterior_samples: int = 1000, include_posterior_predictive: bool = False, posterior_predictive_interval_level: float = 0.9, n_predictive_forward_model_samples: int | None = None, seed: int | None = None) AmortizedRetrievalResult#

Infer a posterior for one structured observation.

Parameters#

observation_blocks:

Raw observation blocks for one target system.

n_posterior_samples:

Number of posterior draws to generate.

include_posterior_predictive:

Whether to also generate a posterior-predictive summary.

posterior_predictive_interval_level:

Predictive interval level used when generating the optional posterior-predictive report.

n_predictive_forward_model_samples:

Number of posterior draws passed through the forward model when generating the posterior-predictive summary. When None, all n_posterior_samples draws are forwarded. Set this to a small value (e.g. 50–200) to avoid a multi-hour run when n_posterior_samples is large.

seed:

Optional seed for posterior and posterior-predictive sampling.

Returns#

AmortizedRetrievalResult

Posterior samples and any optional predictive metadata.

posterior_predictive(posterior: petitRADTRANS.sbi.posterior.PosteriorSamples, observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock], interval_level: float = 0.9, n_predictive_forward_model_samples: int | None = None, seed: int | None = None) petitRADTRANS.sbi.calibration.PosteriorPredictiveReport#

Generate posterior-predictive draws for a fitted observation.

Parameters#

posterior:

Posterior samples associated with one observed system.

observation_blocks:

Original observation blocks used for user-facing comparison.

interval_level:

Central predictive interval level to report.

n_predictive_forward_model_samples:

Number of posterior draws passed through the forward model. When None, all samples in posterior are used. Set this to a small value (e.g. 50–200) to keep the number of expensive petitRADTRANS calls manageable.

seed:

Optional seed for predictive simulation.

Returns#

PosteriorPredictiveReport

Predictive summary for the supplied posterior samples.
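The interplay between interval_level and n_predictive_forward_model_samples can be sketched with a toy forward model. Everything below is a hypothetical stand-in: `predictive_interval` and `toy_forward_model` are not part of petitRADTRANS, and random subsampling without replacement is an assumption about how the draws are thinned.

```python
import numpy as np


def toy_forward_model(theta):
    """Hypothetical cheap forward model: a flat 4-point 'spectrum'."""
    return theta[0] * np.ones(4)


def predictive_interval(posterior_samples, forward_model,
                        interval_level=0.9, n_forward=None, seed=None):
    """Forward a subset of draws and return a central predictive interval."""
    rng = np.random.default_rng(seed)
    samples = np.asarray(posterior_samples)
    if n_forward is not None and n_forward < samples.shape[0]:
        # Thin the posterior before the expensive forward-model calls.
        idx = rng.choice(samples.shape[0], size=n_forward, replace=False)
        samples = samples[idx]
    predictions = np.stack([forward_model(theta) for theta in samples])
    # A central interval at level L leaves (1 - L) / 2 in each tail.
    tail = 0.5 * (1.0 - interval_level)
    lower, upper = np.quantile(predictions, [tail, 1.0 - tail], axis=0)
    return lower, upper, predictions.shape[0]


draws = np.linspace(0.0, 1.0, 1001)[:, None]
lower, upper, n_used = predictive_interval(draws, toy_forward_model)
_, _, n_sub = predictive_interval(draws, toy_forward_model, n_forward=100, seed=1)
```

The cost scales with the number of forwarded draws, so thinning to a few hundred draws trades some interval precision for a large saving in forward-model calls.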

diagnose_domain(observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) OODDiagnostic | None#

Estimate whether the observation is in-distribution for the model.

Parameters#

observation_blocks:

Observation blocks to diagnose.

Returns#

OODDiagnostic | None

Robust support-distance diagnostic derived from preprocessing statistics, or None when preprocessing metadata is unavailable.

class petitRADTRANS.sbi.AmortizedRetrievalResult#

Return type for amortized inference queries.

Attributes#

posterior:

Posterior samples and any attached diagnostics.

posterior_predictive:

Optional posterior-predictive report for the same observation.

ood_diagnostic:

Optional in/out-of-distribution assessment.

metadata:

Additional metadata such as preprocessing usage during inference.

posterior: petitRADTRANS.sbi.posterior.PosteriorSamples#
posterior_predictive: Any = None#
ood_diagnostic: OODDiagnostic | None = None#
metadata: Mapping[str, Any]#
class petitRADTRANS.sbi.OODDiagnostic#

Describe whether an observation is inside the training support.

Attributes#

score:

Scalar out-of-distribution score.

threshold:

Optional threshold used to convert the score into a pass/fail decision.

passed:

Optional boolean indicating whether the observation passed the OOD test.

metadata:

Auxiliary diagnostic metadata.

score: float#
threshold: float | None = None#
passed: bool | None = None#
metadata: Mapping[str, Any]#
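The relationship between score, threshold, and passed can be illustrated with a stand-in dataclass. The score definition below (largest robust z-score against training medians and MADs) and the convention that lower scores pass are assumptions for illustration only; the source does not specify the statistic used by diagnose_domain.

```python
from dataclasses import dataclass, field
from typing import Any, Mapping, Optional

import numpy as np


@dataclass(frozen=True)
class OODDiagnosticSketch:
    """Hypothetical mirror of the documented OODDiagnostic fields."""
    score: float
    threshold: Optional[float] = None
    passed: Optional[bool] = None
    metadata: Mapping[str, Any] = field(default_factory=dict)


def robust_support_distance(features, train_median, train_mad, threshold=None):
    """Score an observation by its largest robust z-score (an assumption)."""
    z = np.abs(np.asarray(features) - train_median) / np.maximum(train_mad, 1e-12)
    score = float(np.max(z))
    # Without a threshold there is no pass/fail decision, only a score.
    passed = None if threshold is None else bool(score <= threshold)
    return OODDiagnosticSketch(score, threshold, passed, {"n_features": int(z.size)})


median = np.array([1.0, 2.0])
mad = np.array([0.5, 0.5])
inside = robust_support_distance([1.0, 2.0], median, mad, threshold=5.0)
outside = robust_support_distance([10.0, 2.0], median, mad, threshold=5.0)
unscored = robust_support_distance([10.0, 2.0], median, mad)
```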
class petitRADTRANS.sbi.ObservationBlock#

Represent one modality-specific observation block.

Attributes#

name:

Stable identifier of the block within a task.

modality:

Semantic block type used to dispatch encoding logic.

values:

Observed values after any task-level preprocessing.

uncertainties:

Optional per-element uncertainty representation.

coordinates:

Optional coordinate arrays such as wavelengths or timestamps.

mask:

Optional mask applied to the values.

metadata:

Additional instrument and preprocessing metadata.

name: str#
modality: ObservationModality#
values: Any#
uncertainties: Any = None#
coordinates: Any = None#
mask: Any = None#
metadata: Mapping[str, Any]#
class petitRADTRANS.sbi.ObservationEncoder#

Bases: abc.ABC

Transform structured observation blocks into model-ready embeddings.

abstractmethod encode(blocks: list[ObservationBlock]) EncodedObservation#

Encode a list of observation blocks into a shared representation.

batch_encode(observations: list[list[ObservationBlock]]) list[EncodedObservation]#

Encode multiple observations using repeated single-item encoding.

class petitRADTRANS.sbi.ObservationModality#

Bases: str, enum.Enum

Supported observation block types for SBI conditioning.

SPECTRUM = 'spectrum'#
PHOTOMETRY = 'photometry'#
TIME_SERIES = 'time_series'#
AUXILIARY = 'auxiliary'#
petitRADTRANS.sbi.build_observation_block(name: str, modality: ObservationModality | str, values: Any, uncertainties: Any = None, coordinates: Any = None, mask: Any = None, metadata: Mapping[str, Any] | None = None) ObservationBlock#

Build one observation block with modality normalization.
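Modality normalization for a str-based enum can be sketched as below. The enum members match the documented ObservationModality values, but the class is re-declared locally, and the `normalize_modality` helper with its whitespace stripping and lower-casing is an assumption rather than the library's actual rule.

```python
from enum import Enum


class Modality(str, Enum):
    """Local copy of the documented ObservationModality values."""
    SPECTRUM = "spectrum"
    PHOTOMETRY = "photometry"
    TIME_SERIES = "time_series"
    AUXILIARY = "auxiliary"


def normalize_modality(modality):
    """Accept an enum member or a string and return the enum member."""
    if isinstance(modality, Modality):
        return modality
    # Assumption: strings are matched case-insensitively after stripping.
    return Modality(str(modality).strip().lower())


from_enum = normalize_modality(Modality.SPECTRUM)
from_text = normalize_modality(" Photometry ")
```

An unknown string raises ValueError from the enum constructor, which surfaces typos early instead of silently selecting the wrong encoder.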

petitRADTRANS.sbi.build_observation_block_batch(observation_payloads: Mapping[str, Mapping[str, Any]], modalities: Mapping[str, str]) list[list[ObservationBlock]]#

Build observation blocks for each sample in a batched payload.

petitRADTRANS.sbi.build_observation_blocks_from_sample(observation_payloads: Mapping[str, Mapping[str, Any]], modalities: Mapping[str, str], sample_index: int) list[ObservationBlock]#

Build modality-aware observation blocks for one simulated sample.

class petitRADTRANS.sbi.ConditionalAutoregressiveFlowPosterior(*args: Any, **kwargs: Any)#

Bases: ConditionalFlowPosterior

Posterior estimator specialized to the autoregressive flow backend.

class petitRADTRANS.sbi.ConditionalFlowPosterior(parameter_dim: int, embedding_dim: int = 128, num_coupling_layers: int = 4, hidden_dim: int = 128, conditioner_depth: int = 2, autoregressive_transform_units: int = 16, neural_autoregressive_min_slope: float = 0.001, neural_autoregressive_min_residual: float = 0.05, neural_autoregressive_inverse_bisection_steps: int = 48, learning_rate: float = 0.001, batch_size: int = 32, num_epochs: int = 5, parameter_space: str = 'unconstrained', flow_family: str = 'spline', num_spline_bins: int = 8, spline_bound: float = 10.0, base_distribution: str = 'gaussian', use_base_affine: bool = False, training_objective: str = 'npe', early_stopping_patience: int | None = None, early_stopping_min_delta: float = 0.0, checkpoint_directory: str | None = None, checkpoint_backend: str = 'auto', resume_from_checkpoint: bool = False, gradient_clip_norm: float | None = 1.0, embedding_noise_std: float = 0.0, embedding_noise_min_scale: float = 0.0, aux_scale_loss_weight: float = 0.0, aux_parameter_loss_weight: float = 0.0, parameter_noise_floor: float = 0.0, spline_small_bin_regularization_weight: float | None = None, spline_min_bin_ratio_target: float = 0.2, spline_entropy_regularization_weight: float = 0.0, spline_derivative_regularization_weight: float = 0.0, spline_entropy_floor: float = 0.6, spline_min_derivative_target: float = 0.25, spline_max_derivative_target: float = 5.0, weight_decay: float = 0.0, use_cosine_schedule: bool = False, warmup_fraction: float = 0.02, warmup_epochs: float | None = None, min_learning_rate: float = 1e-06, lr_schedule_total_epochs: float | None = None, stable_inverse_forward_max_abs_error_threshold: float | None = 0.0001, stable_inverse_forward_logdet_closure_max_abs_error_threshold: float | None = 0.0001, stable_cube_edge_hit_rate_threshold: float | None = 0.0, spectrum_encoder_type: str = 'conv1d', n_wavelengths: int = 233, spectrum_embedding_dim: int = 64, photometry_embedding_dim: int = 64, 
encoder_hidden_dim: int | None = None, encoder_patch_size: int = 32, seed: int = 0, task_metadata: Mapping[str, Any] | None = None, verbose_diagnostics: bool = False, diagnostics_output_directory: str | None = None, diagnostics_plot_interval: int = 1)#

Bases: petitRADTRANS.sbi.posterior_base.PersistentPosteriorEstimator

Concrete amortized posterior using a conditional flow backend.

Parameters#

parameter_dim:

Number of inferred free parameters represented by the posterior.

embedding_dim:

Size of the learned observation embedding consumed by the flow.

num_coupling_layers:

Number of conditional coupling or spline-transform layers in the flow.

hidden_dim:

Hidden width used by encoder-side and flow-side MLP conditioners.

learning_rate:

Optimizer learning rate used by SBITrainer.

batch_size:

Number of simulations processed per optimization step.

num_epochs:

Maximum number of passes through the training split.

parameter_space:

Parameter coordinates learned by the posterior. Supported values are 'physical', 'cube', and 'unconstrained'.

flow_family:

Conditional density-transform family. Supported values are 'spline', 'affine', 'autoregressive', and 'neural_autoregressive'.

num_spline_bins:

Number of rational-quadratic spline bins when flow_family='spline'.

spline_bound:

Finite support bound of each spline transform in latent space.

early_stopping_patience:

Optional number of non-improving epochs tolerated before stopping.

early_stopping_min_delta:

Minimum improvement required for early-stopping comparisons.

checkpoint_directory:

Optional directory used to persist resumable trainer checkpoints.

checkpoint_backend:

Checkpoint persistence backend name. 'auto' selects Orbax when available and otherwise falls back to Equinox serialization.

resume_from_checkpoint:

Whether fit should attempt to resume from the latest checkpoint.

seed:

Base random seed used for flow initialization and posterior sampling.

task_metadata:

Optional user-supplied metadata persisted alongside the trained model.
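How early_stopping_patience and early_stopping_min_delta interact can be sketched in a few lines. The helper name and the exact stopping rule (stop once the count of non-improving epochs reaches the patience) are assumptions; the trainer may count epochs differently.

```python
def early_stopping_epoch(validation_losses, patience, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, else None."""
    best = float("inf")
    stale = 0  # consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(validation_losses):
        if loss < best - min_delta:
            # Improvement larger than min_delta: reset the counter.
            best = loss
            stale = 0
        else:
            stale += 1
            if patience is not None and stale >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.79, 0.795, 0.79, 0.78]
stop_with_delta = early_stopping_epoch(losses, patience=2, min_delta=0.05)
stop_without_delta = early_stopping_epoch(losses, patience=2, min_delta=0.0)
never_stops = early_stopping_epoch(losses, patience=None)
```

A larger min_delta treats small gains as no improvement, so training stops earlier on a plateau.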

Notes#

The posterior stores task fingerprinting, observation schema, and preprocessing payload information when those are available from the training dataset reader. That metadata is later reused by inference and artifact registration paths.

estimator_family = 'conditional_flow'#
embedding_dim = 128#
num_coupling_layers = 4#
hidden_dim = 128#
conditioner_depth = 2#
autoregressive_transform_units = 16#
neural_autoregressive_min_slope#
neural_autoregressive_min_residual#
neural_autoregressive_inverse_bisection_steps = 48#
learning_rate#
batch_size = 32#
num_epochs = 5#
flow_family = ''#
effective_flow_family = ''#
base_distribution = ''#
use_base_affine = False#
training_objective = ''#
num_spline_bins = 8#
early_stopping_patience = None#
early_stopping_min_delta#
checkpoint_directory = None#
checkpoint_backend = 'auto'#
resume_from_checkpoint = False#
gradient_clip_norm = 1.0#
embedding_noise_std#
embedding_noise_min_scale#
aux_scale_loss_weight#
aux_parameter_loss_weight#
parameter_noise_floor#
spline_entropy_regularization_weight#
spline_small_bin_regularization_weight#
spline_derivative_regularization_weight#
spline_min_bin_ratio_target#
spline_entropy_floor#
spline_min_derivative_target#
spline_max_derivative_target#
weight_decay#
use_cosine_schedule = False#
warmup_fraction#
warmup_epochs = None#
min_learning_rate#
lr_schedule_total_epochs = None#
stable_inverse_forward_max_abs_error_threshold = None#
stable_inverse_forward_logdet_closure_max_abs_error_threshold = None#
stable_cube_edge_hit_rate_threshold = None#
spectrum_encoder_type = ''#
n_wavelengths = 233#
spectrum_embedding_dim = 64#
photometry_embedding_dim = 64#
encoder_hidden_dim#
encoder_patch_size = 32#
verbose_diagnostics = False#
diagnostics_output_directory = None#
diagnostics_plot_interval = 1#
model#
_build_flow(key: jax.Array) Any#
static _batch_embeddings(model: _PosteriorModel, observations: Any) jax.numpy.ndarray#
static _loss(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) jax.numpy.ndarray#
static _training_loss(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) jax.numpy.ndarray#
static _loss_from_observations(model: _PosteriorModel, parameters: jax.numpy.ndarray, observations: Any, parameter_space: str, *, include_spline_regularization: bool) jax.numpy.ndarray#
_validation_diagnostics(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) dict[str, float]#
static _make_parameter_noise_loss(base_loss_fn: Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray], noise_floor: float, parameter_space: str) Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]#

Wrap a training loss so the parameter targets are jittered.

Minimizing the conditional NLL in a very-low-noise regime with a highly parameter-predictive embedding drives the learned conditional density toward a delta (NLL -> -inf), which an over-sharp flow realizes as exploded transforms (and, for spline flows, cube-edge collapse). A small Gaussian jitter on the parameter targets gives the conditional a hard minimum width, capping sharpness so the NLL has a finite minimum and the flow’s inverse stays well-conditioned.

The jitter is applied in unconstrained (logit) space – the space the flow actually models – not in cube space. Additive cube-space noise is asymmetric near the [0, 1] bounds: clipping piles jittered targets onto the edges, and the cube->logit Jacobian rewards edge mass, so it worsens the very edge-collapse it was meant to prevent. Jittering in logit space is symmetric and boundary-free. Applied only during training; evaluation and checkpoint selection use the clean targets.
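The asymmetry argument can be checked numerically with a small NumPy sketch. Both helper functions are hypothetical illustrations, and the logit-space result is mapped back to the cube only so the two schemes can be compared on the same scale.

```python
import numpy as np


def jitter_in_logit_space(theta_cube, noise_floor, rng, eps=1e-6):
    """Add Gaussian jitter in logit space and map back to the unit cube."""
    u = np.clip(theta_cube, eps, 1.0 - eps)
    logits = np.log(u) - np.log1p(-u)
    noisy = logits + noise_floor * rng.standard_normal(logits.shape)
    return 1.0 / (1.0 + np.exp(-noisy))  # sigmoid never reaches 0 or 1 here


def jitter_in_cube_space(theta_cube, noise_floor, rng):
    """Add Gaussian jitter directly in the cube, then clip to [0, 1]."""
    noisy = theta_cube + noise_floor * rng.standard_normal(theta_cube.shape)
    return np.clip(noisy, 0.0, 1.0)


rng = np.random.default_rng(0)
near_edge = np.full(10_000, 0.999)  # targets close to the upper bound
cube_jittered = jitter_in_cube_space(near_edge, 0.01, rng)
logit_jittered = jitter_in_logit_space(near_edge, 0.01, rng)
fraction_on_edge = float(np.mean(cube_jittered == 1.0))
```

With targets this close to the bound, a large share of the cube-space draws is clipped exactly onto the edge, while the logit-space draws all stay strictly inside the cube.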

static _make_elbo_loss(log_likelihood_fn: Callable[[jax.numpy.ndarray, Any], jax.numpy.ndarray], parameter_space: str, *, num_samples: int = 1) Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]#

Return an amortized-variational (ELBO) training loss.

Maximizes, for observations x drawn from the (prior-predictive) dataset,

ELBO(x) = E_{q(theta|x)}[ log p(x|theta) + log p(theta) - log q(theta|x) ]

with a single- (or few-) sample reparameterized estimator. q(theta|x) is the conditional flow: latents z are drawn from the flow base and pushed through flow.inverse to reparameterized samples, so the gradient flows through both the flow and the encoder.

The -log q (entropy) term diverges to -inf as q collapses to a point mass, so a delta posterior is structurally impossible – the width is set by the likelihood/entropy balance rather than by the (near-deterministic) embedding. log p(x|theta) is supplied by the injected differentiable forward-model likelihood evaluated at the physical parameters reconstructed from the sampled cube coordinates. With parameter_space='cube' the prior is uniform on the unit hypercube, so log p(theta_cube) = 0 and is dropped.

Parameters#

log_likelihood_fn:

(theta_cube, observations) -> (batch,) returning the Gaussian observational log-likelihood log p(x | theta) for each batch element. Receives unit-cube coordinates (shape (batch, dim)) and the batch’s prestacked observations; it owns the cube->physical transform, the differentiable forward model, and the noise model. Must be JAX-differentiable w.r.t. theta_cube.

parameter_space:

Must be "cube".

num_samples:

Number of reparameterized posterior draws per observation used to estimate the ELBO expectation. 1 is standard for amortized VI; larger values reduce gradient variance at a proportional forward-model cost.
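The ELBO estimator described above can be illustrated on a one-dimensional conjugate toy problem where the exact answer is known. This is a sketch under stated assumptions: the prior is a standard normal rather than the uniform cube prior, q is a plain Gaussian rather than a conditional flow, and all function names are hypothetical.

```python
import numpy as np


def log_normal(x, mean, std):
    """Log-density of a univariate Gaussian."""
    return -0.5 * np.log(2.0 * np.pi * std**2) - 0.5 * ((x - mean) / std) ** 2


def elbo_estimate(x, q_mean, q_std, rng, num_samples=1):
    """Reparameterized Monte Carlo ELBO for theta ~ N(0, 1), x | theta ~ N(theta, 1)."""
    z = rng.standard_normal(num_samples)
    theta = q_mean + q_std * z  # reparameterized draw from q(theta | x)
    log_lik = log_normal(x, theta, 1.0)
    log_prior = log_normal(theta, 0.0, 1.0)
    log_q = log_normal(theta, q_mean, q_std)
    return float(np.mean(log_lik + log_prior - log_q))


rng = np.random.default_rng(0)
x_obs = 1.2
# Marginal likelihood of the toy model: x ~ N(0, 2).
log_evidence = float(log_normal(x_obs, 0.0, np.sqrt(2.0)))
# With q equal to the exact posterior N(x/2, 1/2) the bound is tight.
elbo_exact = elbo_estimate(x_obs, x_obs / 2.0, np.sqrt(0.5), rng)
# A nearly collapsed q is penalized by the -log q term.
elbo_collapsed = elbo_estimate(x_obs, x_obs / 2.0, 1e-3, rng, num_samples=256)
```

When q matches the exact posterior, the integrand is constant in theta, so even a single sample recovers the log-evidence. The collapsed q scores far lower, which is the width-preserving effect described above.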

static _make_noisy_loss(noise_std: float, preprocessing_metadata: Any = None, *, noise_min_scale: float = 0.0, include_spline_regularization: bool = True) Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]#

Return a loss function that injects on-the-fly Gaussian noise into the observation spectra before encoding.

For spectra preprocessed into a per-sample transformed value space, the uncertainty channel is assumed to already be expressed in the same transformed units as the value channel, so noise can be injected directly from that channel. When preprocessing metadata is provided, non-spectral blocks still reconstruct raw measurement uncertainties so noise is injected at the correct physical scale in normalized-value space. noise_std acts as a multiplier on the observational noise level (1.0 = exact observational noise, >1 = regularising over-noise), and noise_min_scale is treated as a minimum fraction of the typical transformed uncertainty for the block rather than as an absolute floor in normalized-value units.

When preprocessing metadata is unavailable, flat Gaussian noise with scale noise_std is applied as a fallback.
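One way to read the roles of noise_std and noise_min_scale is sketched below. The helper is hypothetical: using the block median as the "typical" uncertainty, combining the two terms with a maximum, and treating missing uncertainties as the fallback case are all assumptions about details the text leaves open.

```python
import numpy as np


def inject_observation_noise(values, uncertainties, noise_std, rng,
                             noise_min_scale=0.0):
    """Return noisy values and the per-point noise scale that was used."""
    values = np.asarray(values, dtype=np.float64)
    if uncertainties is None:
        # Fallback: flat Gaussian noise with scale noise_std.
        sigma = np.full(values.shape, float(noise_std))
    else:
        unc = np.asarray(uncertainties, dtype=np.float64)
        typical = np.median(unc)  # assumption: median as the typical level
        # noise_std multiplies the observational noise; noise_min_scale sets
        # a floor as a fraction of the typical uncertainty.
        sigma = np.maximum(noise_std * unc, noise_min_scale * typical)
    return values + sigma * rng.standard_normal(values.shape), sigma


rng = np.random.default_rng(0)
spectrum = np.array([1.0, 1.1, 0.9])
_, sigma_scaled = inject_observation_noise(
    spectrum, [0.1, 0.2, 0.4], noise_std=1.0, rng=rng, noise_min_scale=0.75
)
_, sigma_flat = inject_observation_noise(spectrum, None, noise_std=0.05, rng=rng)
```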

fit(dataset: Any, *, elbo_log_likelihood_fn: Callable[[jax.numpy.ndarray, Any], jax.numpy.ndarray] | None = None, elbo_num_samples: int = 1) petitRADTRANS.sbi.posterior_base.TrainingArtifacts#

Train the posterior on one normalized simulation dataset reader.

Parameters#

dataset:

Reader-like object that yields PosteriorBatch instances via iter_batches and exposes dataset manifest metadata when available.

elbo_log_likelihood_fn:

Required when training_objective='elbo'. A differentiable (theta_cube, observations) -> (batch,) callable returning the Gaussian observational log-likelihood through the forward model. See _make_elbo_loss(). Ignored for the default NPE objective.

elbo_num_samples:

Number of reparameterized posterior draws per observation for the ELBO estimator (default 1).

Returns#

TrainingArtifacts

Training history, validation metrics, and trainer metadata for the completed optimization run.

Notes#

If the reader exposes preprocessing metadata or a manifest fingerprint, those are cached on the posterior so they can be saved and reused by the inference and artifact layers.

_build_estimator_config() dict[str, Any]#

Return backend-specific configuration for metadata persistence.

_build_serialized_metadata(artifact_metadata) dict[str, Any]#
static _resolve_estimator_config(metadata: Mapping[str, Any]) dict[str, Any]#
hydrate_loaded_metadata(metadata: Mapping[str, Any]) None#
classmethod from_serialized_metadata(metadata: Mapping[str, Any]) ConditionalFlowPosterior#

Rebuild an estimator instance from persisted metadata only.

save_backend_state(output_path: pathlib.Path) None#

Persist backend-specific model state into the output directory.

load_backend_state(input_path: pathlib.Path) None#

Restore backend-specific model state from the input directory.

encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) petitRADTRANS.sbi.observation.EncodedObservation#

Encode one structured observation into the posterior context space.

Parameters#

blocks:

Observation blocks describing one spectral/photometric observation.

Returns#

EncodedObservation

Aggregated embedding and lightweight metadata describing the input observation family.

batch_encode_observation(blocks_list: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) list[petitRADTRANS.sbi.observation.EncodedObservation]#

Encode a batch of observations using a single vmapped forward pass.

Parameters#

blocks_list:

List of per-sample observation block lists.

Returns#

list[EncodedObservation]

One encoded observation per input sample.

sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) petitRADTRANS.sbi.posterior_base.PosteriorSamples#

Draw posterior samples conditioned on one encoded observation.

Parameters#

observation:

Encoded observation produced by encode_observation().

n_samples:

Number of posterior draws to generate.

seed:

Optional random seed overriding the model-level default seed.

Returns#

PosteriorSamples

Samples in the posterior’s configured parameter space. Non-finite outputs are clamped to large finite values for downstream stability.

batch_sample_posterior(embeddings: Any, n_samples: int, base_seed: int = 0) numpy.ndarray#

Draw posterior samples for a batch of encoded observations.

Runs a single JIT-compiled vmapped call over all contexts rather than calling sample_posterior once per observation in a Python loop, which removes the per-observation dispatch overhead.

Parameters#

embeddings:

Float32 array of shape (batch_size, embedding_dim) produced by stacking EncodedObservation.embedding vectors.

n_samples:

Number of posterior draws per observation.

base_seed:

Base random seed. Each observation in the batch receives a unique sub-key derived from this seed.

Returns#

np.ndarray

Array of shape (batch_size, n_samples, parameter_dim) with non-finite values clamped to large finite values.
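The seeding and clamping contract can be mimicked in NumPy. This sketch uses `numpy.random.SeedSequence.spawn` as an analogue of deriving per-observation sub-keys, a toy Gaussian sampler in place of the flow, and a Python loop instead of a vmapped call. The clamp value and mapping NaN to zero are assumptions.

```python
import numpy as np


def batch_sample(embeddings, n_samples, parameter_dim, base_seed=0):
    """Toy batched sampler returning (batch, n_samples, parameter_dim)."""
    embeddings = np.asarray(embeddings, dtype=np.float32)
    # One independent child seed per observation, all derived from base_seed.
    child_seeds = np.random.SeedSequence(base_seed).spawn(embeddings.shape[0])
    draws = []
    for context, child in zip(embeddings, child_seeds):
        rng = np.random.default_rng(child)
        # Stand-in for the conditional flow: Gaussian shifted by the context.
        draws.append(context.mean() + rng.standard_normal((n_samples, parameter_dim)))
    samples = np.stack(draws)
    # Replace non-finite values with large finite ones (NaN -> 0 is assumed).
    return np.nan_to_num(samples, nan=0.0, posinf=1e30, neginf=-1e30)


contexts = np.zeros((3, 4), dtype=np.float32)
first = batch_sample(contexts, n_samples=5, parameter_dim=2, base_seed=7)
second = batch_sample(contexts, n_samples=5, parameter_dim=2, base_seed=7)
broken = batch_sample(np.full((1, 4), np.inf), n_samples=5, parameter_dim=2)
```

Identical contexts still receive different draws because each observation has its own child seed, and repeating the call with the same base seed reproduces the batch.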

log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) Any#

Evaluate posterior log-density for one or many parameter vectors.

Parameters#

observation:

Encoded observation that defines the posterior context.

parameters:

One parameter vector or a batch of parameter vectors in the posterior’s configured parameter space.

Returns#

Any

Scalar log-density or a vector of log-densities matching the input batch structure.

class petitRADTRANS.sbi.ConditionalNeuralAutoregressiveFlowPosterior(*args: Any, **kwargs: Any)#

Bases: ConditionalFlowPosterior

Posterior estimator specialized to the neural autoregressive backend.

class petitRADTRANS.sbi.ConditionalSplineFlowPosterior(*args: Any, **kwargs: Any)#

Bases: ConditionalFlowPosterior

Posterior estimator specialized to the spline flow backend.

Notes#

This convenience subclass forces flow_family='spline' while keeping the rest of the ConditionalFlowPosterior API unchanged.

class petitRADTRANS.sbi.FlowMatchingPosterior(parameter_dim: int, embedding_dim: int = 128, hidden_dim: int = 128, num_velocity_layers: int = 3, learning_rate: float = 0.001, batch_size: int = 32, num_epochs: int = 5, parameter_space: str = 'unconstrained', integration_steps: int = 32, early_stopping_patience: int | None = None, early_stopping_min_delta: float = 0.0, checkpoint_directory: str | None = None, checkpoint_backend: str = 'auto', resume_from_checkpoint: bool = False, seed: int = 0, task_metadata: Mapping[str, Any] | None = None)#

Bases: petitRADTRANS.sbi.posterior_base.PersistentPosteriorEstimator

Conditional flow-matching posterior skeleton.

Warning

This estimator is experimental. It does not expose log_prob and the ODE integration scheme is a simple midpoint rule. Expect the API and numerical behaviour to change in future releases.

This estimator family trains a conditional vector field on straight-line interpolation paths between Gaussian noise and target parameters, then generates posterior samples by integrating the learned field from noise to the terminal parameter state.
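The two ingredients named above, straight-line interpolation paths and midpoint-rule integration, can be sketched without any learned network. The velocity field below is the exact conditional field for a single known target and stands in for the trained model; all names are hypothetical.

```python
import numpy as np


def conditional_path(noise, target, t):
    """Straight-line interpolation between noise (t=0) and target (t=1)."""
    return (1.0 - t) * noise + t * target


def integrate_midpoint(velocity_fn, x0, n_steps=32):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with the midpoint rule."""
    x = np.array(x0, dtype=np.float64)
    h = 1.0 / n_steps
    for step in range(n_steps):
        t = step * h
        x_mid = x + 0.5 * h * velocity_fn(x, t)      # half step
        x = x + h * velocity_fn(x_mid, t + 0.5 * h)  # full step, midpoint slope
    return x


target = np.array([0.3, -1.5])
noise = np.random.default_rng(0).standard_normal(2)


def toy_velocity(x, t):
    """Exact field that transports any point to ``target`` at t = 1."""
    return (target - x) / (1.0 - t)


sample = integrate_midpoint(toy_velocity, noise, n_steps=32)
halfway = conditional_path(noise, target, 0.5)
```

Along a straight-line path the velocity is the constant `target - noise`, so the midpoint rule reaches the target up to rounding error in this toy case. A learned field is only approximately straight, which is where the number of integration steps starts to matter.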

estimator_family = 'flow_matching'#
embedding_dim = 128#
hidden_dim = 128#
num_velocity_layers = 3#
learning_rate#
batch_size = 32#
num_epochs = 5#
integration_steps = 32#
early_stopping_patience = None#
early_stopping_min_delta#
checkpoint_directory = None#
checkpoint_backend = 'auto'#
resume_from_checkpoint = False#
model#
static _batch_embeddings(model: _FlowMatchingModel, observations: Any) jax.numpy.ndarray#
static _loss(model: _FlowMatchingModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) jax.numpy.ndarray#
fit(dataset: Any) petitRADTRANS.sbi.posterior_base.TrainingArtifacts#

Train the posterior estimator on a simulation dataset.

_build_estimator_config() dict[str, Any]#

Return backend-specific configuration for metadata persistence.

static _resolve_estimator_config(metadata: Mapping[str, Any]) dict[str, Any]#
_build_serialized_metadata(artifact_metadata) dict[str, Any]#
classmethod from_serialized_metadata(metadata: Mapping[str, Any]) FlowMatchingPosterior#

Rebuild an estimator instance from persisted metadata only.

save_backend_state(output_path: pathlib.Path) None#

Persist backend-specific model state into the output directory.

load_backend_state(input_path: pathlib.Path) None#

Restore backend-specific model state from the input directory.

encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) petitRADTRANS.sbi.observation.EncodedObservation#

Encode a structured observation into the estimator input space.

batch_encode_observation(blocks_list: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) list[petitRADTRANS.sbi.observation.EncodedObservation]#

Encode a batch of observations using a single vmapped forward pass.

sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) petitRADTRANS.sbi.posterior_base.PosteriorSamples#

Sample the amortized posterior for one encoded observation.

log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) Any#

Evaluate posterior log-density when supported by the backend.

class petitRADTRANS.sbi.PersistentPosteriorEstimator(parameter_dim: int, parameter_space: str = 'unconstrained', seed: int = 0, task_metadata: Mapping[str, Any] | None = None)#

Bases: PosteriorEstimator

Shared persistence helper for estimator backends with on-disk artifacts.

estimator_family = 'persistent_estimator'#
metadata_schema_version = '0.2.0'#
parameter_dim#
parameter_space = 'unconstrained'#
seed = 0#
task_metadata#
training_artifacts: TrainingArtifacts | None = None#
task_name: str | None#
task_version: str | None = None#
task_fingerprint: str | None = None#
observation_schema: Mapping[str, Any]#
preprocessing_metadata_payload: Mapping[str, Any]#
artifact_metadata: petitRADTRANS.sbi.artifacts.ArtifactMetadata | None = None#
abstractmethod _build_estimator_config() dict[str, Any]#

Return backend-specific configuration for metadata persistence.

abstractmethod classmethod from_serialized_metadata(metadata: Mapping[str, Any]) PersistentPosteriorEstimator#

Rebuild an estimator instance from persisted metadata only.

abstractmethod save_backend_state(output_path: pathlib.Path) None#

Persist backend-specific model state into the output directory.

abstractmethod load_backend_state(input_path: pathlib.Path) None#

Restore backend-specific model state from the input directory.

static _load_training_artifacts(metadata: Mapping[str, Any]) TrainingArtifacts | None#
hydrate_loaded_metadata(metadata: Mapping[str, Any]) None#
_build_artifact_metadata_payload() dict[str, Any]#
build_artifact_metadata(version: str) petitRADTRANS.sbi.artifacts.ArtifactMetadata#

Assemble registry metadata for the currently trained estimator.

_build_serialized_metadata(artifact_metadata: petitRADTRANS.sbi.artifacts.ArtifactMetadata) dict[str, Any]#
save(output_directory: str, artifact_registry: petitRADTRANS.sbi.artifacts.ArtifactRegistry | None = None, artifact_version: str = '0.1.0') None#

Persist model weights, metadata, and optional artifact registration.

classmethod load(input_directory: str) PersistentPosteriorEstimator#

Restore a saved persistent estimator from disk.

class petitRADTRANS.sbi.PosteriorBatch#

Training batch passed to amortized posterior estimators.

parameters: Any#
observations: Any#
metadata: Mapping[str, Any]#
class petitRADTRANS.sbi.PosteriorEstimator#

Bases: abc.ABC

Backend-agnostic interface for amortized posterior models.

abstractmethod fit(dataset: Any) TrainingArtifacts#

Train the posterior estimator on a simulation dataset.

abstractmethod encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) petitRADTRANS.sbi.observation.EncodedObservation#

Encode a structured observation into the estimator input space.

abstractmethod sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) PosteriorSamples#

Sample the amortized posterior for one encoded observation.

abstractmethod log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) Any#

Evaluate posterior log-density when supported by the backend.

abstractmethod save(output_directory: str) None#

Persist trained model weights and metadata.

abstractmethod classmethod load(input_directory: str) PosteriorEstimator#

Restore a saved estimator from disk.
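A minimal sketch of how a backend might satisfy this interface. The abstract base below is a local stand-in that mirrors only part of the documented method set (it omits encode_observation and does not import petitRADTRANS.sbi). ToyGaussianEstimator, its moment-matching fit, the dictionary it returns in place of TrainingArtifacts, and the JSON file name are all hypothetical.

```python
import abc
import json
import math
import random
import tempfile
from pathlib import Path


class PosteriorEstimatorSketch(abc.ABC):
    """Local stand-in mirroring part of the documented interface."""

    @abc.abstractmethod
    def fit(self, dataset): ...

    @abc.abstractmethod
    def sample_posterior(self, observation, n_samples, seed=None): ...

    @abc.abstractmethod
    def log_prob(self, observation, parameters): ...

    @abc.abstractmethod
    def save(self, output_directory): ...

    @classmethod
    @abc.abstractmethod
    def load(cls, input_directory): ...


class ToyGaussianEstimator(PosteriorEstimatorSketch):
    """Hypothetical 1-D estimator that ignores the observation."""

    def __init__(self, mean=0.0, std=1.0):
        self.mean, self.std = mean, std

    def fit(self, dataset):
        # dataset: list of (parameter, observation) pairs; moment matching.
        params = [p for p, _ in dataset]
        self.mean = sum(params) / len(params)
        variance = sum((p - self.mean) ** 2 for p in params) / len(params)
        self.std = math.sqrt(variance)
        return {"n_training_samples": len(params)}

    def sample_posterior(self, observation, n_samples, seed=None):
        rng = random.Random(seed)
        return [rng.gauss(self.mean, self.std) for _ in range(n_samples)]

    def log_prob(self, observation, parameters):
        z = (parameters - self.mean) / self.std
        return -0.5 * z * z - math.log(self.std * math.sqrt(2.0 * math.pi))

    def save(self, output_directory):
        path = Path(output_directory) / "toy_estimator.json"
        path.write_text(json.dumps({"mean": self.mean, "std": self.std}))

    @classmethod
    def load(cls, input_directory):
        path = Path(input_directory) / "toy_estimator.json"
        return cls(**json.loads(path.read_text()))


estimator = ToyGaussianEstimator()
estimator.fit([(1.0, None), (3.0, None)])
with tempfile.TemporaryDirectory() as directory:
    estimator.save(directory)
    restored = ToyGaussianEstimator.load(directory)
draws = restored.sample_posterior(None, n_samples=5, seed=0)
```

Keeping load a classmethod means a caller can rebuild an estimator from a directory without first constructing one, which matches the save/load pairing documented above.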

class petitRADTRANS.sbi.PosteriorSamples#

Posterior samples and optional per-sample diagnostics.

samples: Any#
log_probabilities: Any = None#
weights: Any = None#
metadata: Mapping[str, Any]#
petitRADTRANS.sbi.plot_local_sensitivity_fisher_correlations(report: petitRADTRANS.sbi.calibration.LocalSensitivityReport, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#

Plot Fisher-correlation heatmaps for each representative point.

petitRADTRANS.sbi.plot_local_sensitivity_jacobians(report: petitRADTRANS.sbi.calibration.LocalSensitivityReport, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#

Plot whitened Jacobian heatmaps for each representative point.

petitRADTRANS.sbi.plot_local_sensitivity_singular_values(report: petitRADTRANS.sbi.calibration.LocalSensitivityReport, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#

Plot singular spectra of the whitened Jacobian for each point.

petitRADTRANS.sbi.plot_posterior_corner(samples: Any, parameter_names: Sequence[str] | None = None, bins: int = 40, figsize: tuple[float, float] | None = None, max_points: int = 8192) tuple[Any, numpy.ndarray]#

Plot a lower-triangular corner view of posterior structure.

Parameters#

samples:

Posterior samples with shape (n_samples, n_dim) or a scalar-vector equivalent.

parameter_names:

Optional display names for each posterior dimension.

bins:

Number of bins used for one- and two-dimensional histograms.

figsize:

Optional Matplotlib figure size.

max_points:

Maximum number of posterior draws plotted. Larger sample sets are evenly subsampled to keep rendering costs bounded.

Returns#

tuple[Any, np.ndarray]

Matplotlib figure and axes array for further customization or saving.
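The max_points behavior can be illustrated with a short sketch. The exact index selection used by plot_posterior_corner is not documented here, so evenly_subsample below is a hypothetical helper that assumes evenly spaced indices from the first draw to the last.

```python
def evenly_subsample(samples, max_points):
    """Keep at most max_points rows, spaced evenly across the input order."""
    if max_points <= 0:
        raise ValueError("max_points must be positive")
    n = len(samples)
    if n <= max_points:
        return list(samples)  # small sample sets are plotted unchanged
    if max_points == 1:
        return [samples[0]]
    # Evenly spaced indices from the first to the last draw, inclusive.
    step = (n - 1) / (max_points - 1)
    return [samples[round(i * step)] for i in range(max_points)]


draws = list(range(10))
thinned = evenly_subsample(draws, max_points=4)
```

Deterministic thinning of this kind keeps the figure reproducible across calls, unlike random subsampling.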

petitRADTRANS.sbi.plot_posterior_marginals(samples: Any, parameter_names: Sequence[str] | None = None, bins: int = 40, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#

Plot one histogram per posterior dimension.

Parameters#

samples:

Posterior samples with shape (n_samples, n_dim) or a scalar-vector equivalent.

parameter_names:

Optional display names for each posterior dimension.

bins:

Number of histogram bins per dimension.

figsize:

Optional Matplotlib figure size.

Returns#

tuple[Any, np.ndarray]

Matplotlib figure and axes array for further customization or saving.

petitRADTRANS.sbi.plot_posterior_predictive_report(report: petitRADTRANS.sbi.calibration.PosteriorPredictiveReport, dataset_names: Sequence[str] | None = None, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#

Plot observed values against posterior-predictive means and intervals.

Parameters#

report:

Posterior-predictive report to visualize.

dataset_names:

Optional subset of dataset names to plot. Defaults to all datasets in the report.

figsize:

Optional Matplotlib figure size.

Returns#

tuple[Any, np.ndarray]

Matplotlib figure and axes array.

Notes#

For batched reports the helper currently visualizes only the first case for each dataset, which keeps the plot compact for quick inspection.

petitRADTRANS.sbi.plot_sbc_rank_histograms(report: petitRADTRANS.sbi.calibration.SimulationBasedCalibrationReport, parameter_names: Sequence[str] | None = None, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#

Plot one SBC rank histogram per inferred parameter.

Parameters#

report:

SBC report containing per-parameter rank histogram counts.

parameter_names:

Optional display names for each inferred parameter.

figsize:

Optional Matplotlib figure size.

Returns#

tuple[Any, np.ndarray]

Matplotlib figure and axes array.
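For readers unfamiliar with the quantity being plotted, here is a minimal sketch of how SBC rank counts are typically formed. It is a generic illustration under stated assumptions, not the package's implementation: sbc_rank and sbc_rank_histogram are hypothetical names, and the toy posterior simply redraws from the prior, which is calibrated when the observation carries no information.

```python
import random


def sbc_rank(true_value, posterior_draws):
    """Rank statistic: number of posterior draws below the true value."""
    return sum(1 for draw in posterior_draws if draw < true_value)


def sbc_rank_histogram(n_cases, n_draws, seed=0):
    """Histogram of ranks over many held-out cases for a toy 1-D model."""
    rng = random.Random(seed)
    counts = [0] * (n_draws + 1)  # ranks range over 0..n_draws
    for _ in range(n_cases):
        theta_true = rng.gauss(0.0, 1.0)
        # Calibrated stand-in posterior: draws from the same distribution.
        draws = [rng.gauss(0.0, 1.0) for _ in range(n_draws)]
        counts[sbc_rank(theta_true, draws)] += 1
    return counts


counts = sbc_rank_histogram(n_cases=2000, n_draws=9)
```

For a calibrated posterior the counts should be close to uniform across ranks; a U-shaped histogram usually signals an overconfident posterior and a dome shape an underconfident one.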

class petitRADTRANS.sbi.TaskPreprocessingMetadata#

Serializable preprocessing metadata for an SBI task family.

version: str#
blocks: Mapping[str, BlockNormalizationStats]#
metadata: Mapping[str, Any]#
to_payload() dict[str, Any]#
classmethod from_payload(payload: Mapping[str, Any]) TaskPreprocessingMetadata#
petitRADTRANS.sbi.fit_task_preprocessing(training_observations: list[list[petitRADTRANS.sbi.observation.ObservationBlock]], version: str = '0.5.0', metadata: Mapping[str, Any] | None = None) TaskPreprocessingMetadata#

Fit preprocessing statistics from training observation blocks.

petitRADTRANS.sbi.normalize_observation_block(block: petitRADTRANS.sbi.observation.ObservationBlock, preprocessing_metadata: TaskPreprocessingMetadata) petitRADTRANS.sbi.observation.ObservationBlock#

Normalize one observation block with fitted preprocessing statistics.

Uses robust (median/IQR) normalization for values and uncertainties when available, falling back to mean/std for backward compatibility with older preprocessing metadata. Coordinates always use mean/std normalization.
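The fallback rule described above can be sketched as follows. The statistics keys ("median", "iqr", "mean", "std"), the helper names, and the guard against a zero scale are assumptions made for illustration; the actual layout of BlockNormalizationStats is not shown in this section.

```python
import statistics


def normalize_values(values, stats):
    """Center and scale values, preferring robust statistics when present."""
    if "median" in stats and "iqr" in stats:
        center, scale = stats["median"], stats["iqr"]
    else:
        # Older metadata without robust statistics: fall back to mean/std.
        center, scale = stats["mean"], stats["std"]
    scale = scale if scale > 0 else 1.0  # assumed guard for constant inputs
    return [(value - center) / scale for value in values]


def fit_robust_stats(values):
    """Compute the median and interquartile range of one block of values."""
    q1, median, q3 = statistics.quantiles(values, n=4, method="inclusive")
    return {"median": median, "iqr": q3 - q1}


training_values = [1.0, 2.0, 3.0, 4.0, 100.0]
robust_stats = fit_robust_stats(training_values)
legacy_stats = {
    "mean": statistics.fmean(training_values),
    "std": statistics.pstdev(training_values),
}

robust = normalize_values(training_values, robust_stats)
legacy = normalize_values(training_values, legacy_stats)
```

With the outlier at 100.0, the robust path leaves the four typical values between -1 and 0.5, while the mean/std path compresses them into a narrow band because the outlier inflates both statistics.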

petitRADTRANS.sbi.normalize_observation_blocks(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock], preprocessing_metadata: TaskPreprocessingMetadata) list[petitRADTRANS.sbi.observation.ObservationBlock]#

Normalize a list of observation blocks.

class petitRADTRANS.sbi.ProposalSampler#

Bases: abc.ABC

Interface for simulation proposals beyond the prior distribution.

abstractmethod sample(n_samples: int, task: petitRADTRANS.sbi.task.SBITask) Any#

Draw free-parameter vectors for the given task.

class petitRADTRANS.sbi.RuntimeSimulator(task: petitRADTRANS.sbi.task.SBITask, runtime: Any | None = None, seed: int | None = None, data_parallel: bool | None = None)#

Bases: BatchedSimulator

Concrete simulator backed by the retrieval runtime.

The simulator samples from the retrieval prior using JAX random utilities and projects deterministic forward-model outputs into the observation space using RetrievalRuntime.

_rng_key#
_validate_task_support() None#
_next_key() jax.Array#
static _row_invalid_value_mask(values: Any, mask: Any, constraint: Any) numpy.ndarray#
_invalid_sample_mask(observations: Mapping[str, Mapping[str, Any]], n_samples: int) tuple[numpy.ndarray, dict[str, int]]#
static _slice_batch_rows(batch: SimulationBatch, row_indices: numpy.ndarray) SimulationBatch#
static _concatenate_simulation_batches(batches: list[SimulationBatch], diagnostics: Mapping[str, Any] | None = None) SimulationBatch#
_sample_prior_parameter_batch(n_samples: int) dict[str, jax.numpy.ndarray]#
sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) Any#

Sample free-parameter vectors from the task prior or a proposal.

_advance_rng_for_batch(n_samples: int) None#

Advance the PRNG state as if simulate_batch(n_samples) were called.

This performs only cheap key splitting, with no forward-model evaluation, which makes it suitable for quickly skipping already-generated batches when resuming dataset generation.

The number of _next_key calls matches the sequence inside simulate_batch: one key in _sample_prior_parameter_batch, then one key per noise-adding observation in _simulate_parameter_matrix.
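The resume-skipping idea can be sketched without JAX. In the example below, split is a hash-based stand-in for a splittable PRNG (it is not jax.random.split), and ToySimulator is hypothetical; only the key-consumption pattern, one key for prior sampling plus one per noisy observation, follows the description above.

```python
import hashlib


def split(key):
    """Stand-in for a splittable PRNG: derive (next_state, subkey) from key."""
    next_state = hashlib.sha256(key + b"/state").digest()
    subkey = hashlib.sha256(key + b"/subkey").digest()
    return next_state, subkey


class ToySimulator:
    """Hypothetical simulator consuming a fixed number of keys per batch."""

    def __init__(self, seed, n_noisy_observations):
        self.key = seed.to_bytes(8, "big")
        self.n_noisy_observations = n_noisy_observations
        self.forward_model_calls = 0

    def _next_key(self):
        self.key, subkey = split(self.key)
        return subkey

    def simulate_batch(self, n_samples):
        prior_key = self._next_key()  # 1 key for prior sampling
        noise_keys = [self._next_key() for _ in range(self.n_noisy_observations)]
        self.forward_model_calls += n_samples  # stands in for the costly part
        return prior_key, noise_keys

    def advance_rng_for_batch(self, n_samples):
        # Replay only the key consumption; no forward model is evaluated.
        for _ in range(1 + self.n_noisy_observations):
            self._next_key()


full = ToySimulator(seed=7, n_noisy_observations=2)
full.simulate_batch(256)  # batch 0, actually simulated

resumed = ToySimulator(seed=7, n_noisy_observations=2)
resumed.advance_rng_for_batch(256)  # batch 0, skipped on resume

next_full = full.simulate_batch(256)
next_resumed = resumed.simulate_batch(256)
```

Because both simulators consumed the same number of keys for batch 0, batch 1 receives identical keys in each, so a resumed run reproduces what an uninterrupted run would have generated.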

_apply_noise(values: Any, uncertainties: Any, covariance: Any) tuple[Any, Any, Any]#
_apply_noise_batched(values_batch: Any, uncertainties_batch: Any, covariance_batch: Any) tuple[Any, Any, Any]#

Apply noise to a batch of simulation outputs.

Args:
values_batch:

Shape (n_samples, n_wavelengths).

uncertainties_batch:

Shape (n_samples, n_wavelengths) or None.

covariance_batch:

Shape (n_samples, n_wavelengths, n_wavelengths) or None.

Returns:

Noisy values, uncertainties, and covariance with the same leading batch dimension as values_batch.

static _stack_observation_payloads(payloads: list[dict[str, Any]]) dict[str, Any]#
_simulate_parameter_matrix(parameter_matrix: Any, cube_parameters: Any = None, unconstrained_parameters: Any = None, data_parallel: bool = False) SimulationBatch#
simulate_from_parameters(parameters: Any) SimulationBatch#

Run the forward model and noise pipeline for pre-specified parameters.

simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) SimulationBatch#

Sample parameters and simulate one batch, preserving prior-space coordinates when available.

class petitRADTRANS.sbi.SimulationBatch#

Container for one simulated batch.

Attributes:
parameters:

Array-like free-parameter matrix with leading dimension equal to the number of simulated samples.

observations:

Task-conditioned simulated observations.

log_likelihood:

Optional scalar likelihood values associated with each sample.

diagnostics:

Additional runtime diagnostics such as clipping flags or NaN counts.

parameters: Any#
observations: Any#
cube_parameters: Any = None#
unconstrained_parameters: Any = None#
log_likelihood: Any = None#
diagnostics: Mapping[str, Any]#
property n_samples: int#

Return the number of simulated samples represented by the batch.

class petitRADTRANS.sbi.Simulator(task: petitRADTRANS.sbi.task.SBITask)#

Bases: abc.ABC

Base simulator for SBI dataset generation and validation.

task#
abstractmethod sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) Any#

Sample free-parameter vectors from the task prior or a proposal.

abstractmethod simulate_from_parameters(parameters: Any) SimulationBatch#

Run the forward model and noise pipeline for pre-specified parameters.

simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) SimulationBatch#

Sample parameters and simulate one batch in a single call.
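The relationship between the two abstract methods and simulate_batch can be sketched as a template method. Everything below is a local stand-in: BatchSketch carries only a few of the SimulationBatch fields, LineSimulator and its linear forward model are hypothetical, and the proposal is simplified to a plain callable rather than a ProposalSampler with a sample(n_samples, task) method.

```python
import abc
import random
from dataclasses import dataclass, field


@dataclass
class BatchSketch:
    """Stand-in for SimulationBatch with a reduced set of fields."""
    parameters: list
    observations: list
    diagnostics: dict = field(default_factory=dict)

    @property
    def n_samples(self):
        return len(self.parameters)


class SimulatorSketch(abc.ABC):
    @abc.abstractmethod
    def sample_parameters(self, n_samples, proposal=None): ...

    @abc.abstractmethod
    def simulate_from_parameters(self, parameters): ...

    def simulate_batch(self, n_samples, proposal=None):
        # Template method: sample first, then run the forward model.
        parameters = self.sample_parameters(n_samples, proposal)
        return self.simulate_from_parameters(parameters)


class LineSimulator(SimulatorSketch):
    """Hypothetical toy model: observation = 2 * theta + Gaussian noise."""

    def __init__(self, seed=0, noise_std=0.1):
        self.rng = random.Random(seed)
        self.noise_std = noise_std

    def sample_parameters(self, n_samples, proposal=None):
        if proposal is not None:
            return proposal(n_samples)  # simplified proposal: a callable
        return [self.rng.uniform(0.0, 1.0) for _ in range(n_samples)]

    def simulate_from_parameters(self, parameters):
        observations = [
            2.0 * p + self.rng.gauss(0.0, self.noise_std) for p in parameters
        ]
        return BatchSketch(parameters=list(parameters), observations=observations)


batch = LineSimulator(seed=1).simulate_batch(n_samples=4)
narrow = LineSimulator(seed=1).simulate_batch(3, proposal=lambda n: [0.5] * n)
```

Splitting sampling from simulation lets callers re-simulate fixed parameter sets, for example when validating a posterior, without touching the sampling logic.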

class petitRADTRANS.sbi.NoiseModelConfig#

Describe how observational noise is injected during simulation.

Attributes:
mode:

Short identifier of the noise model implementation.

parameters:

Backend-specific parameters used to instantiate the noise model.

seed:

Optional deterministic seed for repeatable simulation pipelines.

mode: str = 'observational'#
parameters: Mapping[str, Any]#
seed: int | None = None#
class petitRADTRANS.sbi.ObservationSchema#

Capture the supported observation family for an amortized task.

Attributes:
dataset_names:

Dataset identifiers included in the task.

modalities:

Mapping from dataset name to a short modality label such as 'spectrum', 'photometry', or 'time_series'.

metadata:

Additional task-level information required by encoders or benchmarks, such as instrument names or wavelength coverage.

dataset_names: tuple[str, Ellipsis]#
modalities: Mapping[str, str]#
metadata: Mapping[str, Any]#
class petitRADTRANS.sbi.ObservationValueConstraint#

Admissible range for simulated observation values.

Attributes:
min_value:

Optional lower bound for valid simulated values.

max_value:

Optional upper bound for valid simulated values.

min_inclusive:

Whether min_value is included in the valid interval.

max_inclusive:

Whether max_value is included in the valid interval.

min_value: float | None = None#
max_value: float | None = None#
min_inclusive: bool = True#
max_inclusive: bool = True#
to_payload() dict[str, Any]#
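A short sketch of how the four fields combine into an admissibility check. ValueConstraintSketch is a local stand-in, and is_valid is a hypothetical helper rather than a documented method; treating NaN as inadmissible is also an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ValueConstraintSketch:
    """Stand-in mirroring the four documented fields."""
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    min_inclusive: bool = True
    max_inclusive: bool = True

    def is_valid(self, value):
        if value != value:  # NaN is never admissible in this sketch
            return False
        if self.min_value is not None:
            below = value < self.min_value
            on_open_edge = value == self.min_value and not self.min_inclusive
            if below or on_open_edge:
                return False
        if self.max_value is not None:
            above = value > self.max_value
            on_open_edge = value == self.max_value and not self.max_inclusive
            if above or on_open_edge:
                return False
        return True


# Hypothetical constraint on the half-open interval (0, 1].
constraint = ValueConstraintSketch(min_value=0.0, max_value=1.0, min_inclusive=False)
simulated_row = [0.01, 0.0, 1.0, 1.2]
flags = [constraint.is_valid(value) for value in simulated_row]
row_is_invalid = not all(flags)
```

Leaving a bound as None disables that side of the check, so a constraint with only min_value set accepts arbitrarily large values.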
class petitRADTRANS.sbi.SBITask#

Bundle the immutable ingredients required for an SBI problem.

The task owns the retrieval configuration, canonical parameter layout, observation schema, and simulation policy needed to generate training data and evaluate amortized posteriors.

name: str#
retrieval_config: petitRADTRANS.retrieval.retrieval_config.RetrievalConfig#
parameter_layout: petitRADTRANS.retrieval.runtime.ParameterLayout#
observation_schema: ObservationSchema#
simulation_config: SimulationConfig#
task_version: str = '0.1.0'#
preprocessing_version: str = '0.5.0'#
forward_model_family: str = 'unknown'#
petitradtrans_version: str#
metadata: Mapping[str, Any]#
classmethod from_retrieval_config(retrieval_config: petitRADTRANS.retrieval.retrieval_config.RetrievalConfig, simulation_config: SimulationConfig | None = None, metadata: Mapping[str, Any] | None = None) SBITask#

Create an SBI task from an existing retrieval configuration.

Parameters#

retrieval_config:

Retrieval configuration describing the free parameters, observation datasets, and runtime-native forward model family.

simulation_config:

Optional simulation policy overriding the default prior-predictive batch configuration.

metadata:

Optional task-level metadata. Reserved keys such as task_version, preprocessing_version, and forward_model_family override the default inferred values.

Returns#

SBITask

Immutable SBI task description with parameter layout, observation schema, and deterministic task fingerprint payload.

Notes#

The task inspects the configured retrieval datasets to infer modalities and a forward-model-family string that later participates in artifact compatibility checks.

property parameter_names: tuple[str, Ellipsis]#

Return the free parameter names in canonical task order.

property line_opacity_modes: Mapping[str, str]#

Return the line-opacity mode for each configured dataset.

property fingerprint_payload: Mapping[str, Any]#

Return the deterministic payload used to fingerprint the task family.

property task_fingerprint: petitRADTRANS.sbi.compatibility.TaskFingerprint#

Return the deterministic fingerprint of the task family.
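One common way to obtain a deterministic fingerprint from a payload is to hash a canonical serialization. The sketch below assumes JSON with sorted keys and SHA-256; the actual TaskFingerprint construction is not described in this section, and the payload keys and values shown are illustrative.

```python
import hashlib
import json


def fingerprint(payload):
    """Hash a JSON-serializable payload independently of key order."""
    canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


payload_a = {"parameter_names": ["temperature", "log_g"], "task_version": "0.1.0"}
payload_b = {"task_version": "0.1.0", "parameter_names": ["temperature", "log_g"]}
payload_c = {"task_version": "0.1.0", "parameter_names": ["log_g", "temperature"]}

same_task = fingerprint(payload_a) == fingerprint(payload_b)
reordered_parameters = fingerprint(payload_a) == fingerprint(payload_c)
```

Key order does not change the digest, but list order does, which is the desired behavior when the parameter order is part of the task definition.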

validate_observation_schema(observation_schema: ObservationSchema) petitRADTRANS.sbi.compatibility.CompatibilityReport#

Compare an external observation schema with the task expectation.

Parameters#

observation_schema:

Candidate schema to compare against the task’s required datasets and modalities.

Returns#

CompatibilityReport

Structured compatibility report containing mismatches, if any.
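A minimal sketch of the kind of comparison such a check performs, using plain dictionaries that map dataset name to modality. compare_schemas, the mismatch messages, and the dataset names are hypothetical; the real method returns a CompatibilityReport whose fields are not shown here.

```python
def compare_schemas(expected, candidate):
    """Return human-readable mismatches; an empty list means compatible."""
    mismatches = []
    for name, modality in expected.items():
        if name not in candidate:
            mismatches.append(f"missing dataset '{name}'")
        elif candidate[name] != modality:
            mismatches.append(
                f"dataset '{name}': expected modality '{modality}', "
                f"got '{candidate[name]}'"
            )
    for name in candidate:
        if name not in expected:
            mismatches.append(f"unexpected dataset '{name}'")
    return mismatches


task_schema = {"nirspec": "spectrum", "irac": "photometry"}
good = compare_schemas(task_schema, {"irac": "photometry", "nirspec": "spectrum"})
bad = compare_schemas(task_schema, {"nirspec": "photometry", "miri": "spectrum"})
```

Collecting every mismatch instead of failing on the first one gives the caller a complete picture of why an observation does not fit the task.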

validate_artifact_metadata(artifact_metadata: Any) petitRADTRANS.sbi.compatibility.CompatibilityReport#

Compare a task with persisted artifact metadata.

Parameters#

artifact_metadata:

Artifact-like object exposing task and preprocessing provenance fields. Any is accepted to avoid a hard dependency cycle.

Returns#

CompatibilityReport

Structured report describing whether the artifact is compatible with the current task definition.


build_observation_states(model_contract: str = 'differentiable') dict[str, petitRADTRANS.retrieval.runtime.ObservationState]#

Materialize immutable runtime observations for the task datasets.

Parameters#

model_contract:

Runtime contract requested from each retrieval dataset.

Returns#

dict[str, ObservationState]

Mapping from dataset name to immutable runtime observation state.

build_runtime() petitRADTRANS.retrieval.runtime.RetrievalRuntime#

Construct a retrieval runtime matching the task definition.

Returns#

RetrievalRuntime

Runtime object configured with the task’s parameter layout and retrieval datasets.

The default implementation mirrors the retrieval-runtime grouping logic used by the exact retrieval package and is sufficient for initial SBI simulation backends.

static _infer_modality(data: petitRADTRANS.retrieval.data.Data) str#
static _infer_forward_model_family(retrieval_config: petitRADTRANS.retrieval.retrieval_config.RetrievalConfig) str#
class petitRADTRANS.sbi.SimulationConfig#

Control how a task generates prior-predictive simulations.

Attributes:
batch_size:

Number of parameter points produced per simulator call.

n_simulations:

Optional target number of simulations for dataset generation jobs.

use_vectorized_runtime:

Whether to prefer the JAX-vectorized runtime path when available.

store_per_datapoint_log_likelihood:

Persist log-likelihood components alongside simulated observations.

noise_model:

Noise model configuration applied after forward-model evaluation.

observation_value_constraints:

Optional per-dataset admissible value ranges enforced on deterministic projected observations before simulations are accepted.

batch_size: int = 256#
n_simulations: int | None = None#
use_vectorized_runtime: bool = True#
store_per_datapoint_log_likelihood: bool = False#
noise_model: NoiseModelConfig#
observation_value_constraints: Mapping[str, ObservationValueConstraint]#
class petitRADTRANS.sbi.EarlyStoppingConfig#

Early stopping policy for SBI training.

patience: int#
min_delta: float = 0.0#
class petitRADTRANS.sbi.SBITrainer(config: TrainingConfig)#

Reusable optimization loop for amortized SBI posteriors.

config#
checkpoint_directory = None#
checkpoint_backend#
_checkpoint_backend_fallback_reason: str | None = None#
property _latest_checkpoint_directory: pathlib.Path | None#
property _best_checkpoint_directory: pathlib.Path | None#
_checkpoint_directory_for_kind(checkpoint_kind: str) pathlib.Path | None#
_save_checkpoint(checkpoint_kind: str, state: dict[str, Any], metadata: dict[str, Any]) None#
_restore_checkpoint(template_state: dict[str, Any], checkpoint_kind: str) tuple[dict[str, Any], dict[str, Any]] | None#
fit(model: Any, dataset: Any, loss_fn: Callable[[Any, Any], Any], eval_loss_fn: Callable[[Any, Any], Any] | None = None, eval_diagnostic_fn: Callable[[Any, Any], Mapping[str, float]] | None = None, selection_metric_fn: Callable[[float, Mapping[str, float] | None], float] | None = None, selection_metric_name: str | None = None, stability_metric_fn: Callable[[float, Mapping[str, float] | None], Mapping[str, float]] | None = None, stability_flag_key: str = 'checkpoint_is_stable') tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]#

Optimize one posterior model against a dataset reader.

Parameters#

model:

Trainable Equinox-style model to optimize.

dataset:

Reader object exposing iter_batches for train and optional validation splits.

loss_fn:

Callable receiving (model, batch) and returning the scalar loss to minimize.

eval_loss_fn:

Optional loss used for validation. Defaults to loss_fn when not provided.

Returns#

tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]

Best model, train-loss history, validation-loss history, and a metadata dictionary describing checkpointing and stopping behavior.

Notes#

The trainer monitors validation loss when a validation split exists and otherwise falls back to training loss for best-model selection. Checkpoints may be written after each epoch when enabled.
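The interaction between best-model selection and the patience and min_delta fields of EarlyStoppingConfig can be sketched as below. This is a generic early-stopping loop under the assumption that an epoch counts as an improvement only when the monitored loss drops by more than min_delta; the trainer's exact criterion is not documented in this section.

```python
def run_with_early_stopping(monitored_losses, patience, min_delta=0.0):
    """Return (best_epoch, stopped_epoch) for a sequence of epoch losses."""
    best_loss = float("inf")
    best_epoch = -1
    patience_counter = 0
    for epoch, loss in enumerate(monitored_losses):
        if loss < best_loss - min_delta:
            # Improvement: this epoch's model becomes the best so far.
            best_loss, best_epoch = loss, epoch
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= patience:
            return best_epoch, epoch
    return best_epoch, len(monitored_losses) - 1


losses = [1.00, 0.80, 0.79, 0.795, 0.80, 0.81, 0.70]
strict = run_with_early_stopping(losses, patience=2, min_delta=0.05)
lenient = run_with_early_stopping(losses, patience=2, min_delta=0.0)
```

With min_delta=0.05 the small gain at epoch 2 does not reset the counter, so training stops at epoch 3 and keeps the epoch-1 model; with min_delta=0.0 epoch 2 counts as the best and training stops at epoch 4.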

property _min_delta: float#
_should_stop(patience_counter: int) bool#
class petitRADTRANS.sbi.TrainingConfig#

Configuration for posterior optimization.

learning_rate: float = 0.001#
batch_size: int = 32#
num_epochs: int = 10#
parameter_space: str = 'unconstrained'#
seed: int = 0#
shuffle_train: bool = True#
early_stopping: EarlyStoppingConfig | None = None#
checkpoint_directory: str | None = None#
checkpoint_backend: str = 'auto'#
resume_from_checkpoint: bool = False#
data_parallel: bool = False#
gradient_clip_norm: float | None = 1.0#
weight_decay: float = 0.0#
use_cosine_schedule: bool = False#
warmup_fraction: float = 0.02#
warmup_epochs: float | None = None#
min_learning_rate: float = 1e-06#
lr_schedule_total_epochs: float | None = None#
verbose_diagnostics: bool = False#
diagnostics_output_directory: str | None = None#
diagnostics_plot_interval: int = 1#
n_validation_diagnostic_batches: int = 4#
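The schedule-related fields (use_cosine_schedule, warmup_fraction, min_learning_rate) suggest a linear warmup followed by cosine decay. The sketch below shows one conventional form of such a schedule using the defaults listed above; learning_rate_at is a hypothetical helper, and the formula is an assumption rather than the trainer's confirmed behavior.

```python
import math


def learning_rate_at(step, total_steps, peak_lr=1e-3, min_lr=1e-6,
                     warmup_fraction=0.02):
    """Linear warmup to peak_lr, then cosine decay down to min_lr."""
    warmup_steps = max(1, round(warmup_fraction * total_steps))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (peak_lr - min_lr) * cosine


total = 1000
schedule = [learning_rate_at(step, total) for step in range(total + 1)]
```

With these defaults the rate climbs over the first 20 steps, peaks at 1e-3, and decays smoothly to 1e-6 by the final step.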