petitRADTRANS.sbi#
Simulation-based inference interfaces for petitRADTRANS.
The modules in petitRADTRANS.sbi provide a production-oriented
architecture for amortized inference workflows built on top of the existing
retrieval runtime. The package is intentionally lightweight at this stage: it
defines the core task, simulation, dataset, posterior, and benchmarking
interfaces without committing to one training backend or storage engine.
Submodules#
- petitRADTRANS.sbi.artifacts
- petitRADTRANS.sbi.benchmark
- petitRADTRANS.sbi.calibration
- petitRADTRANS.sbi.compatibility
- petitRADTRANS.sbi.dataset
- petitRADTRANS.sbi.encoders
- petitRADTRANS.sbi.estimator_registry
- petitRADTRANS.sbi.flow_matching_posterior
- petitRADTRANS.sbi.flow_posterior
- petitRADTRANS.sbi.flows
- petitRADTRANS.sbi.inference
- petitRADTRANS.sbi.observation
- petitRADTRANS.sbi.plotting
- petitRADTRANS.sbi.posterior
- petitRADTRANS.sbi.posterior_base
- petitRADTRANS.sbi.preprocessing
- petitRADTRANS.sbi.simulator
- petitRADTRANS.sbi.task
- petitRADTRANS.sbi.training
Exceptions#
- TaskCompatibilityError: Raised when an observation or artifact is incompatible with a task.
Classes#
- BenchmarkComparison: Compare an amortized result to one or more exact retrieval baselines.
- BenchmarkMetrics: Metrics summarizing agreement and predictive performance.
- RetrievalBenchmarkCase: One benchmark problem used to compare inference backends.
- RetrievalBenchmarkSuite: Run standardized benchmark comparisons for SBI tasks.
- LocalSensitivityPointReport: Local linear-identifiability summary around one representative point.
- LocalSensitivityReport: Aggregate local information-content diagnostics for one observation.
- PosteriorPredictiveReport: Aggregate posterior-predictive summaries for held-out observations.
- SimulationBasedCalibrationReport: Rank-based SBC summary over a held-out set of observations.
- HDF5SimulationDatasetStore: HDF5-backed store for simulation corpora.
- SimulationDatasetStore: Backend-independent interface for reading and writing simulation data.
- StoredSimulationDataset: Read-only handle for a stored chunked simulation dataset.
- ZarrSimulationDatasetStore: Concrete Zarr-backed store for chunked simulation corpora.
- DatasetSplit: Named dataset partitions used during training and evaluation.
- NormalizedObservationDatasetReader: Lightweight reader yielding normalized ObservationBlock batches for training.
- HierarchicalObservationEncoder: Learned hierarchical encoder dispatching over modalities.
- PhotometryPointEncoder: Learned photometry encoder using per-point MLP features and pooling.
- SpectralConv1DEncoder: Learned spectral encoder using 1D convolutions to retain positional information.
- SpectralPatchEncoder: Learned spectral encoder based on patch summaries and MLP pooling.
- AmortizedRetrieval: Serve trained SBI models through a retrieval-like interface.
- Return type for amortized inference queries.
- Describe whether an observation is inside the training support.
- Represent one modality-specific observation block.
- Transform structured observation blocks into model-ready embeddings.
- Supported observation block types for SBI conditioning.
- Posterior estimator specialized to the autoregressive flow backend.
- Concrete amortized posterior using a conditional flow backend.
- Posterior estimator specialized to the neural autoregressive backend.
- Posterior estimator specialized to the spline flow backend.
- Conditional flow-matching posterior skeleton.
- Shared persistence helper for estimator backends with on-disk artifacts.
- Training batch passed to amortized posterior estimators.
- Backend-agnostic interface for amortized posterior models.
- Posterior samples and optional per-sample diagnostics.
- Serializable preprocessing metadata for an SBI task family.
- Interface for simulation proposals beyond the prior distribution.
- Concrete simulator backed by the retrieval runtime.
- Container for one simulated batch.
- Base simulator for SBI dataset generation and validation.
- Describe how observational noise is injected during simulation.
- Capture the supported observation family for an amortized task.
- Admissible range for simulated observation values.
- Bundle the immutable ingredients required for an SBI problem.
- Control how a task generates prior-predictive simulations.
- Early stopping policy for SBI training.
- Reusable optimization loop for amortized SBI posteriors.
- Configuration for posterior optimization.
Functions#
- generate_local_sensitivity_report: Diagnose local physical identifiability around representative posterior points.
- generate_posterior_predictive_report: Generate posterior-predictive summaries for a held-out dataset split.
- generate_sbc_report: Run SBC over a dataset reader using normalized observation batches.
- local_sensitivity_report_to_payload: Convert a local sensitivity report into a JSON-serializable payload.
- generate_simulation_dataset: Generate and persist a simulation corpus in one call.
- load_posterior_estimator: Load a saved posterior estimator without naming its concrete class.
- Build one observation block with modality normalization.
- Build observation blocks for each sample in a batched payload.
- Build modality-aware observation blocks for one simulated sample.
- Plot Fisher-correlation heatmaps for each representative point.
- Plot whitened Jacobian heatmaps for each representative point.
- Plot singular spectra of the whitened Jacobian for each point.
- Plot a lower-triangular corner view of posterior structure.
- Plot one histogram per posterior dimension.
- Plot observed values against posterior-predictive means and intervals.
- Plot one SBC rank histogram per inferred parameter.
- Fit preprocessing statistics from training observation blocks.
- Normalize one observation block with fitted preprocessing statistics.
- Normalize a list of observation blocks.
Package Contents#
- class petitRADTRANS.sbi.BenchmarkComparison#
Compare an amortized result to one or more exact retrieval baselines.
- case_name: str#
- amortized_result: petitRADTRANS.sbi.inference.AmortizedRetrievalResult#
- exact_results: Mapping[str, Any]#
- metrics: BenchmarkMetrics#
- metadata: Mapping[str, Any]#
- class petitRADTRANS.sbi.BenchmarkMetrics#
Metrics summarizing agreement and predictive performance.
- calibration: Mapping[str, float]#
- posterior_distance: Mapping[str, float]#
- predictive_checks: Mapping[str, float]#
- runtime: Mapping[str, float]#
- class petitRADTRANS.sbi.RetrievalBenchmarkCase#
One benchmark problem used to compare inference backends.
- name: str#
- observation: Any#
- reference_posterior: Any = None#
- metadata: Mapping[str, Any]#
- class petitRADTRANS.sbi.RetrievalBenchmarkSuite(cases: list[RetrievalBenchmarkCase])#
Run standardized benchmark comparisons for SBI tasks.
- cases#
- abstractmethod run_case(case: RetrievalBenchmarkCase) -> BenchmarkComparison#
Run one benchmark case and compute comparison metrics.
Parameters#
- case:
Benchmark case describing the task, observation, and optional exact reference posterior to compare against.
Returns#
- BenchmarkComparison
Comparison payload combining amortized and exact results together with any derived metrics.
Notes#
The base class is intentionally abstract. Concrete suites are expected to bind exact and amortized inference backends and define the metric computations appropriate for the comparison.
- run_all() -> list[BenchmarkComparison]#
Run all configured benchmark cases.
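The split between an abstract `run_case` and a concrete `run_all` can be illustrated with a self-contained mirror of the pattern. This sketch does not import petitRADTRANS: `ToyBenchmarkCase`, `ToyBenchmarkSuite`, and `StandardizedMeanSuite` are hypothetical stand-ins, `run_all` is assumed to evaluate cases in order, and the standardized mean difference is only one possible `posterior_distance` entry, not necessarily a metric the package computes.

```python
import abc
from dataclasses import dataclass, field
from typing import Any, Callable, Mapping

import numpy as np


@dataclass
class ToyBenchmarkCase:
    """Hypothetical stand-in for RetrievalBenchmarkCase."""
    name: str
    observation: Any
    reference_posterior: Any = None
    metadata: Mapping[str, Any] = field(default_factory=dict)


class ToyBenchmarkSuite(abc.ABC):
    """Hypothetical mirror of the abstract run_case / run_all contract."""

    def __init__(self, cases: list):
        self.cases = cases

    @abc.abstractmethod
    def run_case(self, case: ToyBenchmarkCase) -> dict:
        """Bind backends and compute metrics for one case."""

    def run_all(self) -> list:
        # Assumed behaviour: evaluate every configured case in order.
        return [self.run_case(case) for case in self.cases]


class StandardizedMeanSuite(ToyBenchmarkSuite):
    """Concrete suite comparing amortized draws with reference draws."""

    def __init__(self, cases: list, sample_amortized: Callable[[Any], np.ndarray]):
        super().__init__(cases)
        self.sample_amortized = sample_amortized  # hypothetical amortized backend

    def run_case(self, case: ToyBenchmarkCase) -> dict:
        amortized = np.asarray(self.sample_amortized(case.observation), dtype=float)
        reference = np.asarray(case.reference_posterior, dtype=float)
        # Per-parameter |mean difference| in units of the reference standard deviation.
        distance = np.abs(amortized.mean(axis=0) - reference.mean(axis=0)) / reference.std(axis=0)
        return {
            "case_name": case.name,
            "posterior_distance": {"max_standardized_mean_difference": float(distance.max())},
        }


rng = np.random.default_rng(0)
reference_draws = rng.normal(0.0, 1.0, size=(4000, 2))
suite = StandardizedMeanSuite(
    [ToyBenchmarkCase("toy", observation=None, reference_posterior=reference_draws)],
    sample_amortized=lambda observation: rng.normal(0.1, 1.0, size=(4000, 2)),
)
results = suite.run_all()
```

Keeping `run_all` concrete means a new suite only has to decide how to bind backends and score one case.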
- class petitRADTRANS.sbi.LocalSensitivityPointReport#
Local linear-identifiability summary around one representative point.
Attributes#
- label:
Short human-readable label for the representative point.
- parameters:
Physical parameter vector at which the Jacobian was evaluated.
- finite_difference_steps:
Per-parameter finite-difference step sizes used during Jacobian construction.
- finite_difference_schemes:
Per-parameter scheme labels such as 'central' or 'forward'.
- whitened_jacobian:
Observation Jacobian divided by observational uncertainty, with shape (n_observation_values, n_parameters).
- singular_values:
Singular values of the whitened Jacobian.
- effective_rank:
Number of singular values larger than the configured relative cutoff.
- condition_number:
Condition number inferred from the singular spectrum.
- fisher_matrix:
Approximate Fisher information matrix J^T J, where J is the whitened Jacobian.
- fisher_covariance:
Damped pseudo-inverse of the Fisher matrix.
- fisher_correlation:
Correlation matrix derived from the approximate Fisher covariance.
- parameter_sensitivity_norm:
Per-parameter square-root Fisher diagonal.
- local_sigma:
Approximate local standard deviation from the Fisher covariance.
- metadata:
Auxiliary diagnostics such as the ridge term and failed columns.
- label: str#
- parameters: numpy.ndarray#
- finite_difference_steps: numpy.ndarray#
- finite_difference_schemes: tuple[str, Ellipsis]#
- whitened_jacobian: numpy.ndarray#
- singular_values: numpy.ndarray#
- effective_rank: int#
- condition_number: float#
- fisher_matrix: numpy.ndarray#
- fisher_covariance: numpy.ndarray#
- fisher_correlation: numpy.ndarray#
- parameter_sensitivity_norm: numpy.ndarray#
- local_sigma: numpy.ndarray#
- metadata: Mapping[str, Any]#
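Most of the fields above follow from the whitened Jacobian through standard linear algebra. The NumPy sketch below derives them for one point; it is not the package's implementation, and the helper name `summarize_whitened_jacobian`, the relative SVD cutoff rule, and the form of the ridge damping are assumptions.

```python
import numpy as np


def summarize_whitened_jacobian(jacobian, sigma_obs, svd_relative_tolerance=1e-3, ridge=1e-12):
    """Hypothetical helper mirroring several LocalSensitivityPointReport fields.

    jacobian:  (n_observation_values, n_parameters) model derivatives.
    sigma_obs: (n_observation_values,) 1-sigma observational uncertainties.
    """
    # Whitening: divide each observation row by its uncertainty.
    whitened = np.asarray(jacobian, dtype=float) / np.asarray(sigma_obs, dtype=float)[:, None]
    singular_values = np.linalg.svd(whitened, compute_uv=False)  # sorted, largest first
    # Effective rank: singular values above a cutoff relative to the largest one.
    effective_rank = int(np.sum(singular_values > svd_relative_tolerance * singular_values[0]))
    smallest = singular_values[-1]
    condition_number = float(singular_values[0] / smallest) if smallest > 0 else float("inf")
    # Approximate Fisher information J^T J and a ridge-damped pseudo-inverse.
    fisher = whitened.T @ whitened
    covariance = np.linalg.pinv(fisher + ridge * np.eye(fisher.shape[0]))
    local_sigma = np.sqrt(np.clip(np.diag(covariance), 0.0, None))
    correlation = covariance / np.outer(local_sigma, local_sigma)
    return {
        "whitened_jacobian": whitened,
        "singular_values": singular_values,
        "effective_rank": effective_rank,
        "condition_number": condition_number,
        "fisher_matrix": fisher,
        "fisher_covariance": covariance,
        "fisher_correlation": correlation,
        "parameter_sensitivity_norm": np.sqrt(np.diag(fisher)),
        "local_sigma": local_sigma,
    }


summary = summarize_whitened_jacobian([[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]], sigma_obs=[1.0, 1.0, 1.0])
```

In this toy case the second parameter is constrained twice as strongly as the first, so its local sigma is half as large and the condition number is 2.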
- class petitRADTRANS.sbi.LocalSensitivityReport#
Aggregate local information-content diagnostics for one observation.
Attributes#
- parameter_names:
Parameter ordering used throughout the report.
- posterior_mean:
Posterior mean in physical parameter space.
- posterior_std:
Posterior standard deviation in physical parameter space.
- posterior_median:
Posterior median in physical parameter space.
- posterior_iqr:
Posterior interquartile range in physical parameter space.
- representative_points:
Local sensitivity summaries evaluated at representative posterior points such as the posterior mean and highest-density sample.
- aggregate_local_sigma:
Median local Fisher sigma across representative points.
- aggregate_parameter_sensitivity_norm:
Median parameter sensitivity norm across representative points.
- posterior_to_local_sigma_ratio:
Ratio between posterior standard deviation and local Fisher sigma.
- parameter_diagnostics:
Per-parameter heuristic summary separating weak data constraints from broader-than-local posterior structure.
- metadata:
Auxiliary metadata such as quantile levels and observation slices.
- parameter_names: tuple[str, Ellipsis]#
- posterior_mean: numpy.ndarray#
- posterior_std: numpy.ndarray#
- posterior_median: numpy.ndarray#
- posterior_iqr: numpy.ndarray#
- representative_points: tuple[LocalSensitivityPointReport, Ellipsis]#
- aggregate_local_sigma: numpy.ndarray#
- aggregate_parameter_sensitivity_norm: numpy.ndarray#
- posterior_to_local_sigma_ratio: numpy.ndarray#
- parameter_diagnostics: Mapping[str, Mapping[str, Any]]#
- metadata: Mapping[str, Any]#
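The aggregate fields combine per-point Fisher results with posterior summaries. A minimal sketch, assuming `aggregate_local_sigma` is the median across representative points (as described above) and the ratio is taken elementwise. The `broader_than_local` flag, which compares the ratio against a threshold like the `posterior_underexploited_ratio_threshold` default of 1.5, is an illustrative guess and not the package's heuristic; `aggregate_sensitivity` is a hypothetical name.

```python
import numpy as np


def aggregate_sensitivity(posterior_samples, per_point_local_sigma, ratio_threshold=1.5):
    """Hypothetical helper combining posterior spread with per-point Fisher sigmas.

    posterior_samples:     (n_samples, n_parameters) draws in physical space.
    per_point_local_sigma: (n_points, n_parameters) local_sigma of each point report.
    """
    samples = np.asarray(posterior_samples, dtype=float)
    posterior_std = samples.std(axis=0)
    q25, q75 = np.quantile(samples, [0.25, 0.75], axis=0)
    # Median local Fisher sigma across representative points.
    aggregate_local_sigma = np.median(np.asarray(per_point_local_sigma, dtype=float), axis=0)
    ratio = posterior_std / aggregate_local_sigma
    return {
        "posterior_std": posterior_std,
        "posterior_iqr": q75 - q25,
        "aggregate_local_sigma": aggregate_local_sigma,
        "posterior_to_local_sigma_ratio": ratio,
        # Illustrative flag only: posterior much wider than the local curvature implies.
        "broader_than_local": ratio > ratio_threshold,
    }


rng = np.random.default_rng(1)
draws = rng.normal(0.0, [1.0, 3.0], size=(20000, 2))
result = aggregate_sensitivity(draws, per_point_local_sigma=[[1.0, 1.0], [0.9, 1.1], [1.1, 0.9]])
```

A ratio near 1 means the posterior width matches the local curvature. A ratio well above 1 suggests structure (non-linearity, multimodality, or degeneracies) that a single linearization misses.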
- class petitRADTRANS.sbi.PosteriorPredictiveReport#
Aggregate posterior-predictive summaries for held-out observations.
Attributes#
- observed_values:
Observed values kept on the original observation scale.
- predictive_mean:
Posterior-predictive mean curves or vectors.
- predictive_std:
Posterior-predictive standard deviations.
- interval_lower, interval_upper:
Central predictive interval bounds for the requested level.
- interval_coverage:
Per-dataset fraction of observed points covered by the predictive interval.
- mean_absolute_error:
Mean absolute deviation between predictive mean and observed values.
- metadata:
Auxiliary metadata such as split name, number of cases, and parameter space used during predictive generation.
- observed_values: Mapping[str, numpy.ndarray]#
- predictive_mean: Mapping[str, numpy.ndarray]#
- predictive_std: Mapping[str, numpy.ndarray]#
- interval_lower: Mapping[str, numpy.ndarray]#
- interval_upper: Mapping[str, numpy.ndarray]#
- interval_coverage: Mapping[str, float]#
- mean_absolute_error: Mapping[str, float]#
- mean_absolute_error_sigma: Mapping[str, float]#
- median_interval_width_over_uncertainty: Mapping[str, float]#
- crps: Mapping[str, float]#
- metadata: Mapping[str, Any]#
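For one dataset, the scalar summaries can be computed from an array of posterior-predictive draws. This is a sketch under assumptions, since three of the fields are not documented above. It assumes that `interval_coverage` uses empirical central quantiles, that `mean_absolute_error_sigma` is the MAE in units of the observational uncertainty, and that `crps` uses the standard sample estimator E|X - y| - 0.5 E|X - X'| averaged over points. `predictive_summary` is a hypothetical name.

```python
import numpy as np


def predictive_summary(draws, observed, uncertainty, interval_level=0.9):
    """Hypothetical per-dataset summary from posterior-predictive draws.

    draws:       (n_draws, n_points) forward-model outputs for posterior samples.
    observed:    (n_points,) observed values on the original scale.
    uncertainty: (n_points,) observational 1-sigma uncertainties.
    """
    draws = np.asarray(draws, dtype=float)
    observed = np.asarray(observed, dtype=float)
    uncertainty = np.asarray(uncertainty, dtype=float)
    alpha = 0.5 * (1.0 - interval_level)
    lower, upper = np.quantile(draws, [alpha, 1.0 - alpha], axis=0)
    abs_error = np.abs(draws.mean(axis=0) - observed)
    # Sample CRPS per point: E|X - y| - 0.5 * E|X - X'| (pairwise term is O(n_draws^2)).
    spread_to_obs = np.abs(draws - observed).mean(axis=0)
    spread_pairs = np.abs(draws[:, None, :] - draws[None, :, :]).mean(axis=(0, 1))
    crps = spread_to_obs - 0.5 * spread_pairs
    return {
        "interval_coverage": float(np.mean((observed >= lower) & (observed <= upper))),
        "mean_absolute_error": float(abs_error.mean()),
        "mean_absolute_error_sigma": float(np.mean(abs_error / uncertainty)),
        "median_interval_width_over_uncertainty": float(np.median((upper - lower) / uncertainty)),
        "crps": float(crps.mean()),
    }


rng = np.random.default_rng(2)
report = predictive_summary(rng.normal(size=(200, 50)), observed=np.zeros(50), uncertainty=np.ones(50))
```

A well-calibrated predictive distribution has coverage close to `interval_level`. A width-over-uncertainty ratio far above 1 means the posterior adds substantial spread beyond the noise model.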
- class petitRADTRANS.sbi.SimulationBasedCalibrationReport#
Rank-based SBC summary over a held-out set of observations.
Attributes#
- ranks:
Integer rank of the ground-truth parameter within posterior samples for each held-out case and parameter dimension.
- rank_histogram_counts:
Per-parameter SBC histogram counts.
- posterior_means:
Posterior mean for each held-out case.
- truths:
Ground-truth parameters paired with each posterior sample set.
- coverages:
Coverage summaries for the requested nominal interval levels.
- mean_rank:
Mean empirical rank for each parameter dimension.
- normalized_mean_rank_error:
Absolute mean-rank error normalized by the expected average rank.
- metadata:
Auxiliary run metadata such as split name and number of posterior draws.
- ranks: numpy.ndarray#
- rank_histogram_counts: numpy.ndarray#
- posterior_means: numpy.ndarray#
- truths: numpy.ndarray#
- coverages: tuple[CoverageLevelReport, Ellipsis]#
- mean_rank: numpy.ndarray#
- normalized_mean_rank_error: numpy.ndarray#
- metadata: Mapping[str, Any]#
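The rank statistics above can be reproduced on toy data. The sketch makes several assumptions that the package may not share. The rank is taken as the number of posterior draws below the truth, so it lies in {0, ..., n_draws}. The expected mean rank is then n_draws / 2, and coverage checks whether the truth falls inside the empirical central interval at each level. Conventions for ties and histogram binning may also differ from the package. `sbc_summary` is a hypothetical name.

```python
import numpy as np


def sbc_summary(posterior_draws, truths, n_bins=10, levels=(0.5, 0.8, 0.95)):
    """Hypothetical rank-based SBC summary.

    posterior_draws: (n_cases, n_draws, n_parameters) samples for each held-out case.
    truths:          (n_cases, n_parameters) ground-truth parameters.
    """
    draws = np.asarray(posterior_draws, dtype=float)
    truths = np.asarray(truths, dtype=float)
    n_draws = draws.shape[1]
    # Rank = number of posterior draws below the truth, in {0, ..., n_draws}.
    ranks = (draws < truths[:, None, :]).sum(axis=1)
    edges = np.linspace(0, n_draws + 1, n_bins + 1)
    counts = np.stack([np.histogram(ranks[:, j], bins=edges)[0] for j in range(ranks.shape[1])])
    mean_rank = ranks.mean(axis=0)
    expected = n_draws / 2.0  # mean of a uniform rank on {0, ..., n_draws}
    coverages = {}
    for level in levels:
        alpha = 0.5 * (1.0 - level)
        lower, upper = np.quantile(draws, [alpha, 1.0 - alpha], axis=1)
        coverages[level] = ((truths >= lower) & (truths <= upper)).mean(axis=0)
    return {
        "ranks": ranks,
        "rank_histogram_counts": counts,
        "mean_rank": mean_rank,
        "normalized_mean_rank_error": np.abs(mean_rank - expected) / expected,
        "coverages": coverages,
    }


# Calibrated toy problem: uninformative data, so the posterior equals the N(0, 1) prior.
rng = np.random.default_rng(3)
truths = rng.normal(size=(2000, 2))
draws = rng.normal(size=(2000, 99, 2))
sbc = sbc_summary(draws, truths)
```

For a calibrated posterior the ranks are uniform, so the histograms are flat, the normalized mean-rank error is close to zero, and each coverage is close to its nominal level.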
- petitRADTRANS.sbi.generate_local_sensitivity_report(task: petitRADTRANS.sbi.task.SBITask, posterior_samples: Any, observation_blocks: Sequence[Any], posterior_log_probabilities: Any = None, parameter_space: str = 'physical', simulator: petitRADTRANS.sbi.simulator.RuntimeSimulator | None = None, quantile_levels: Sequence[float] = (0.1, 0.5, 0.9), finite_difference_relative_step: float = 0.001, finite_difference_std_fraction: float = 0.1, finite_difference_absolute_floor: float = 1e-05, max_step_reduction_attempts: int = 6, svd_relative_tolerance: float = 0.001, posterior_underexploited_ratio_threshold: float = 1.5, weak_sensitivity_fraction_threshold: float = 0.15, seed: int | None = None) -> LocalSensitivityReport#
Diagnose local physical identifiability around representative posterior points.
The report evaluates a deterministic simulator Jacobian at representative posterior points, whitens it by observational uncertainty, and derives a Fisher-style local covariance approximation for each point.
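A deterministic Jacobian of this kind is typically built one column at a time with finite differences. The sketch below uses central differences and falls back to a one-sided difference when a perturbed point would leave the parameter bounds. The step rule `max(relative_step * |theta|, absolute_floor)` loosely mirrors `finite_difference_relative_step` and `finite_difference_absolute_floor`, but it is an assumption. The sketch ignores `finite_difference_std_fraction` and `max_step_reduction_attempts`, and `finite_difference_jacobian` is a hypothetical name.

```python
import numpy as np


def finite_difference_jacobian(model, theta, lower, upper, relative_step=1e-3, absolute_floor=1e-5):
    """Hypothetical column-by-column Jacobian of a deterministic model.

    Returns the Jacobian (n_outputs, n_parameters), the step sizes, and one
    scheme label per parameter.
    """
    theta = np.asarray(theta, dtype=float)
    f0 = np.asarray(model(theta), dtype=float)
    jacobian = np.empty((f0.size, theta.size))
    steps = np.maximum(relative_step * np.abs(theta), absolute_floor)
    schemes = []
    for j, h in enumerate(steps):
        shift = np.zeros_like(theta)
        shift[j] = h
        if lower[j] <= theta[j] - h and theta[j] + h <= upper[j]:
            # Both neighbours inside the bounds: second-order central difference.
            jacobian[:, j] = (np.asarray(model(theta + shift)) - np.asarray(model(theta - shift))) / (2.0 * h)
            schemes.append("central")
        elif theta[j] + h <= upper[j]:
            # Sitting on the lower bound: one-sided forward difference.
            jacobian[:, j] = (np.asarray(model(theta + shift)) - f0) / h
            schemes.append("forward")
        else:
            # Sitting on the upper bound: one-sided backward difference.
            jacobian[:, j] = (f0 - np.asarray(model(theta - shift))) / h
            schemes.append("backward")
    return jacobian, steps, tuple(schemes)


def toy_model(theta):
    return np.array([theta[0] ** 2, theta[0] * theta[1], np.sin(theta[1])])


jac, steps, schemes = finite_difference_jacobian(toy_model, [1.0, 0.5], lower=[1.0, 0.0], upper=[5.0, 2.0])
```

Here the first parameter sits on its lower bound, so it gets a forward difference, while the second gets a central one.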
- petitRADTRANS.sbi.generate_posterior_predictive_report(task: petitRADTRANS.sbi.task.SBITask, posterior: Any, dataset_reader: petitRADTRANS.sbi.dataset.NormalizedObservationDatasetReader, split: petitRADTRANS.sbi.dataset.DatasetSplit = DatasetSplit.TEST, n_posterior_samples: int = 256, interval_level: float = 0.9, max_cases: int | None = None, seed: int | None = None, simulator: petitRADTRANS.sbi.simulator.RuntimeSimulator | None = None, n_predictive_forward_model_samples: int | None = None, checkpoint_directory: str | pathlib.Path | None = None, data_parallel: bool | None = None) -> PosteriorPredictiveReport#
Generate posterior-predictive summaries for a held-out dataset split.
Parameters#
- task:
SBI task defining parameter transforms and the simulator configuration.
- posterior:
Trained posterior estimator used to sample held-out predictive draws.
- dataset_reader:
Reader providing normalized observations and preprocessing metadata.
- split:
Held-out split used for the predictive report.
- n_posterior_samples:
Number of posterior draws generated per held-out observation.
- interval_level:
Central predictive interval level reported for each dataset.
- max_cases:
Optional cap on the number of held-out observations evaluated.
- seed:
Optional base seed used to make predictive sampling reproducible.
- simulator:
Optional runtime simulator override.
- n_predictive_forward_model_samples:
Number of posterior draws passed through the forward model per held-out case. When None, all n_posterior_samples draws are forwarded. Subsampling here is the primary lever for keeping the total number of petitRADTRANS forward-model calls manageable when evaluating many held-out cases.
- checkpoint_directory:
Optional directory for per-case checkpoints. When provided, each completed case is written to disk as a compressed .npz file and skipped on resume. This makes the expensive forward-model loop restartable after an interruption.
Returns#
- PosteriorPredictiveReport
Aggregate posterior-predictive summary over the requested split.
Notes#
Observations are normalized internally for posterior encoding but compared on the original observation scale in the returned report.
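The restartable loop described for `checkpoint_directory`, combined with `n_predictive_forward_model_samples` subsampling, can be sketched as follows. Only the general pattern comes from the description above: one compressed `.npz` file per finished case, skipped on resume. The file naming (`case_00000.npz`), the stored array name, the per-case seeding, and the function `predictive_loop` are hypothetical.

```python
import pathlib
import tempfile

import numpy as np


def predictive_loop(posterior_draws, run_forward_model, checkpoint_directory, n_forward=None, seed=0):
    """Hypothetical restartable posterior-predictive loop.

    posterior_draws: sequence of (n_posterior_samples, n_parameters) arrays, one per case.
    Returns the per-case predictive arrays and the number of forward-model calls made.
    """
    directory = pathlib.Path(checkpoint_directory)
    directory.mkdir(parents=True, exist_ok=True)
    results, n_calls = [], 0
    for case_index, draws in enumerate(posterior_draws):
        path = directory / f"case_{case_index:05d}.npz"
        if path.exists():
            # Resume: reuse the finished case instead of re-running the forward model.
            with np.load(path) as saved:
                results.append(saved["predictive"])
            continue
        draws = np.asarray(draws)
        if n_forward is not None and n_forward < len(draws):
            # Per-case generator, so the subsample does not depend on which cases were skipped.
            rng = np.random.default_rng([seed, case_index])
            draws = draws[rng.choice(len(draws), size=n_forward, replace=False)]
        predictive = np.stack([run_forward_model(theta) for theta in draws])
        n_calls += len(draws)
        np.savez_compressed(path, predictive=predictive)
        results.append(predictive)
    return results, n_calls


rng = np.random.default_rng(4)
cases = [rng.normal(size=(256, 2)) for _ in range(3)]
with tempfile.TemporaryDirectory() as tmp:
    first, calls_first = predictive_loop(cases, lambda theta: np.array([theta.sum()]), tmp, n_forward=32)
    # A second run over the same directory finds every checkpoint and makes no new calls.
    second, calls_second = predictive_loop(cases, lambda theta: np.array([theta.sum()]), tmp, n_forward=32)
```

With 3 cases and 32 forwarded draws each, the first pass costs 96 forward-model calls and the resumed pass costs none.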
- petitRADTRANS.sbi.generate_sbc_report(posterior: Any, dataset_reader: petitRADTRANS.sbi.dataset.NormalizedObservationDatasetReader, split: petitRADTRANS.sbi.dataset.DatasetSplit = DatasetSplit.TEST, n_posterior_samples: int = 256, batch_size: int = 32, parameter_space: str | None = None, levels: Sequence[float] = (0.5, 0.8, 0.95), max_cases: int | None = None, seed: int | None = None, data_parallel: bool | None = None) -> SimulationBasedCalibrationReport#
Run SBC over a dataset reader using normalized observation batches.
Parameters#
- posterior:
Trained posterior estimator exposing encode_observation and sample_posterior.
- dataset_reader:
Reader yielding normalized held-out observations and matched parameter values.
- split:
Dataset split used for the SBC evaluation.
- n_posterior_samples:
Number of posterior draws generated per held-out case.
- batch_size:
Reader batch size used during report generation.
- parameter_space:
Optional parameter space override. Defaults to the posterior’s own configured parameter space.
- levels:
Coverage levels summarized in the returned report.
- max_cases:
Optional cap on the number of held-out observations evaluated.
- seed:
Optional base seed used to generate reproducible posterior draws.
Returns#
- SimulationBasedCalibrationReport
SBC summary computed from the requested dataset split.
- petitRADTRANS.sbi.local_sensitivity_report_to_payload(report: LocalSensitivityReport) -> dict[str, Any]#
Convert a local sensitivity report into a JSON-serializable payload.
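Making a report like this JSON-serializable mostly means converting NumPy arrays and scalars into plain Python types and recursing through nested dataclasses, mappings, and tuples. The following is a generic sketch, not the package's implementation. `to_json_payload` and `ToyPointReport` are hypothetical. Mapping non-finite floats to `None` is one possible policy, chosen because strict JSON has no NaN or Infinity.

```python
import dataclasses
import json
import math
from collections.abc import Mapping

import numpy as np


def to_json_payload(value):
    """Hypothetical recursive converter to JSON-compatible Python types."""
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        return {f.name: to_json_payload(getattr(value, f.name)) for f in dataclasses.fields(value)}
    if isinstance(value, Mapping):
        return {str(key): to_json_payload(item) for key, item in value.items()}
    if isinstance(value, np.ndarray):
        return to_json_payload(value.tolist())
    if isinstance(value, (list, tuple)):
        return [to_json_payload(item) for item in value]
    if isinstance(value, np.generic):
        return to_json_payload(value.item())  # NumPy scalar -> Python scalar
    if isinstance(value, float) and not math.isfinite(value):
        return None  # strict JSON has no NaN or Infinity
    return value


@dataclasses.dataclass
class ToyPointReport:
    label: str
    local_sigma: np.ndarray
    condition_number: float


payload = to_json_payload({"points": (ToyPointReport("mean", np.array([0.1, 0.2]), float("inf")),)})
text = json.dumps(payload, allow_nan=False)
```

The `np.generic` check comes before the plain `float` check because NumPy float64 scalars are also instances of `float`.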
- exception petitRADTRANS.sbi.TaskCompatibilityError#
Bases: ValueError
Raised when an observation or artifact is incompatible with a task.
- class petitRADTRANS.sbi.HDF5SimulationDatasetStore(chunk_size: int = 256)#
Bases: SimulationDatasetStore
HDF5-backed store for simulation corpora.
Stores all simulation data for a corpus in a single .h5 file, keeping the file count at 1 regardless of the number of simulations or splits. Requires h5py (already present in the jaxprt environment).
- chunk_size = 256#
- static _require_h5py() -> None#
- create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') -> HDF5SimulationDatasetWriter#
Create a writer for a new simulation dataset.
- open(manifest_or_uri: SimulationDatasetManifest | str) -> HDF5StoredSimulationDataset#
Open a stored dataset for training or evaluation.
- class petitRADTRANS.sbi.SimulationDatasetStore#
Bases: abc.ABC
Backend-independent interface for reading and writing simulation data.
- abstractmethod create_writer(manifest: SimulationDatasetManifest) -> SimulationDatasetWriter#
Create a writer for a new simulation dataset.
- abstractmethod open(manifest_or_uri: SimulationDatasetManifest | str) -> Any#
Open a stored dataset for training or evaluation.
- class petitRADTRANS.sbi.StoredSimulationDataset#
Read-only handle for a stored chunked simulation dataset.
- manifest: SimulationDatasetManifest#
- storage_uri: str#
- root: Any#
- iter_split_chunks(split: DatasetSplit = DatasetSplit.TRAIN, chunk_size: int = 256, indices: Any = None, observation_fields: frozenset[str] | None = None) -> Iterator[dict[str, Any]]#
Yield row-slices of one split without loading it all into RAM.
Parameters#
- split:
Dataset split to iterate.
- chunk_size:
Number of rows per yielded chunk.
- indices:
Optional integer array of row indices to read. When supplied, only those rows are fetched, in ascending order. When None, every row in the split is visited sequentially.
- observation_fields:
Optional set of observation field names to load. When None, every dataset field is loaded. Pass _OBSERVATION_TRAINING_FIELDS to skip large unused fields such as covariance during training.
Yields#
- dict[str, Any]
Same structure as read_split(), but covering at most chunk_size rows per iteration.
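The chunked-iteration contract can be mirrored over in-memory arrays. This sketch assumes the behaviour described above: rows are sliced `chunk_size` at a time, optional indices are sorted ascending, and fields are filtered optionally. `iter_chunks` and the dict-of-arrays layout are hypothetical stand-ins for the on-disk split.

```python
import numpy as np


def iter_chunks(split_data, chunk_size=256, indices=None, fields=None):
    """Hypothetical in-memory mirror of chunked split iteration.

    split_data: mapping of field name -> array whose first axis indexes rows.
    """
    names = [name for name in split_data if fields is None or name in fields]
    n_rows = len(next(iter(split_data.values())))
    if indices is None:
        # Visit every row sequentially with contiguous slices.
        for start in range(0, n_rows, chunk_size):
            yield {name: split_data[name][start:start + chunk_size] for name in names}
        return
    # Requested rows are read in ascending order, chunk_size at a time.
    order = np.sort(np.asarray(indices, dtype=np.int64))
    for start in range(0, order.size, chunk_size):
        rows = order[start:start + chunk_size]
        yield {name: split_data[name][rows] for name in names}


data = {
    "parameters": np.arange(10.0).reshape(10, 1),
    "values": np.arange(10.0) * 10.0,
    "covariance": np.zeros((10, 3)),
}
# Skip the large covariance field, as one would during training.
chunks = list(iter_chunks(data, chunk_size=4, fields={"parameters", "values"}))
picked = list(iter_chunks(data, chunk_size=2, indices=[7, 1, 3]))
```

The last chunk is shorter whenever the row count is not a multiple of `chunk_size` (here 4, 4, 2).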
- read_split(split: DatasetSplit = DatasetSplit.TRAIN) -> dict[str, Any]#
- class petitRADTRANS.sbi.ZarrSimulationDatasetStore(chunk_size: int = 256)#
Bases: SimulationDatasetStore
Concrete Zarr-backed store for chunked simulation corpora.
- chunk_size = 256#
- static _require_zarr() -> None#
- create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') -> SimulationDatasetWriter#
Create a writer for a new simulation dataset.
- open(manifest_or_uri: SimulationDatasetManifest | str) -> StoredSimulationDataset#
Open a stored dataset for training or evaluation.
- petitRADTRANS.sbi.generate_simulation_dataset(task: petitRADTRANS.sbi.task.SBITask, storage_uri: str, n_simulations: int, chunk_size: int = 256, split_counts: Mapping[DatasetSplit | str, int] | None = None, split_fractions: Mapping[DatasetSplit | str, float] | None = None, split_policy: SplitSamplingPolicy | str = SplitSamplingPolicy.SEQUENTIAL, split_seed: int | None = None, simulator: Any = None, include_preprocessing_metadata: bool = True, dataset_version: str = '0.1.0', resume: bool = False, backend: str = 'hdf5', data_parallel: bool | None = None, store_covariance: bool = False) -> GeneratedSimulationDataset#
Generate and persist a simulation corpus in one call.
Parameters#
- backend:
Storage backend to use. 'hdf5' (default) writes all data into a single .h5 file, keeping the file count at 1 regardless of the number of simulations. 'zarr' uses the Zarr directory store, which produces one file per compressed chunk.
- data_parallel:
When True (or None with multiple JAX devices), the vmapped RT kernel is distributed across devices using jax.pmap. The effective per-iteration batch size is scaled automatically by the device count, so each device processes the configured simulation_config.batch_size samples.
- store_covariance:
When True, the full covariance matrix for each simulated spectrum is written to disk. When False (default), only the covariance diagonal is stored under the covariance field to reduce storage pressure.
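`split_fractions` eventually have to become integer row counts that sum to `n_simulations`. The sketch below shows one common approach, largest-remainder rounding, followed by contiguous sequential assignment. Both the rounding rule and the helper names (`fractions_to_counts`, `sequential_ranges`) are assumptions about what `SplitSamplingPolicy.SEQUENTIAL` implies, not a description of the package code.

```python
import math


def fractions_to_counts(n_simulations, split_fractions):
    """Hypothetical largest-remainder conversion of split fractions to row counts."""
    total = sum(split_fractions.values())
    raw = {name: n_simulations * fraction / total for name, fraction in split_fractions.items()}
    counts = {name: math.floor(value) for name, value in raw.items()}
    # Hand the rows lost to flooring to the splits with the largest remainders.
    leftover = n_simulations - sum(counts.values())
    by_remainder = sorted(raw, key=lambda name: raw[name] - counts[name], reverse=True)
    for name in by_remainder[:leftover]:
        counts[name] += 1
    return counts


def sequential_ranges(counts):
    """Assign contiguous row ranges to splits in insertion order."""
    ranges, start = {}, 0
    for name, count in counts.items():
        ranges[name] = range(start, start + count)
        start += count
    return ranges


counts = fractions_to_counts(1001, {"train": 0.8, "validation": 0.1, "test": 0.1})
ranges = sequential_ranges(counts)
```

With 1001 simulations, plain flooring would leave one row unassigned. Largest-remainder rounding gives it to `train`, whose fractional share (800.8) is the largest.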
- class petitRADTRANS.sbi.DatasetSplit#
Bases: str, enum.Enum
Named dataset partitions used during training and evaluation.
- TRAIN = 'train'#
- VALIDATION = 'validation'#
- TEST = 'test'#
- BENCHMARK = 'benchmark'#
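Because `DatasetSplit` mixes in `str`, its members compare equal to their plain-string values. This is presumably why several functions accept `DatasetSplit | str`. The snippet reproduces the enum locally instead of importing it, so it runs without petitRADTRANS installed.

```python
import enum


class DatasetSplit(str, enum.Enum):
    """Local mirror of petitRADTRANS.sbi.DatasetSplit as listed above."""

    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"
    BENCHMARK = "benchmark"


# Lookup by value returns the canonical member.
assert DatasetSplit("test") is DatasetSplit.TEST
# The str mixin makes members compare equal to their raw string values.
assert DatasetSplit.TEST == "test"
assert isinstance(DatasetSplit.VALIDATION, str)
```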
- class petitRADTRANS.sbi.NormalizedObservationDatasetReader#
Lightweight reader yielding normalized ObservationBlock batches for training.
- preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata#
- _split_cache: dict#
- iter_batches(split: DatasetSplit = DatasetSplit.TRAIN, batch_size: int = 32, shuffle: bool = False, seed: int | None = None, parameter_space: str = 'physical', encoder: Any = None) -> Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]#
Yield mini-batches of normalized observations and matched parameters.
Parameters#
- split:
Dataset split to iterate over.
- batch_size:
Number of samples yielded in each batch.
- shuffle:
Whether to shuffle sample order within the requested split.
- seed:
Optional random seed used when shuffle=True.
- parameter_space:
Parameter representation returned in each batch. Supported values are 'physical', 'cube', and 'unconstrained'.
- encoder:
Optional encoder used to convert block lists into dense embedding arrays before batches are yielded.
Returns#
- Iterator[PosteriorBatch]
Iterator over batches containing parameters, observations, and small metadata dictionaries describing the batch provenance.
Notes#
Splits with fewer than _STREAMING_THRESHOLD rows are cached in RAM after the first read, so repeated epoch calls (e.g. validation) pay only the slicing cost. Larger splits are streamed directly from disk one batch at a time to avoid loading the full dataset into memory.
- _iter_batches_cached(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) -> Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]#
Iterate using the full-split RAM cache (small splits).
- _iter_batches_streaming(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) -> Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]#
Iterate by streaming rows directly from disk (large splits).
When shuffle=True, a global index permutation is computed in RAM (one integer per simulation row, which is negligible memory) and used to fetch HDF5 rows in sorted sub-windows of batch_size. This satisfies h5py’s requirement that fancy-index selections be in increasing order while still presenting shuffled order to the training loop.
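The shuffled streaming strategy can be sketched in a few lines. The sketch permutes all row indices once and cuts the permutation into `batch_size` windows. Each window is read in ascending order, as h5py requires for fancy indexing, and the sort is then undone so the batch keeps its shuffled order. Whether the package restores the permuted order within a batch or simply yields the sorted rows is not stated. The restore step, the `read_rows` callable, and the name `shuffled_sorted_windows` are assumptions.

```python
import numpy as np


def shuffled_sorted_windows(n_rows, batch_size, read_rows, seed=None):
    """Hypothetical shuffled streaming over a backend needing ascending indices.

    read_rows: callable taking an ascending integer index array and returning rows.
    Yields (row_indices_in_shuffled_order, rows_in_the_same_order).
    """
    rng = np.random.default_rng(seed)
    permutation = rng.permutation(n_rows)  # one integer per row
    for start in range(0, n_rows, batch_size):
        window = permutation[start:start + batch_size]
        order = np.argsort(window)
        rows = np.asarray(read_rows(window[order]))  # ascending fetch
        # Undo the sort so rows line up with the shuffled window again.
        restored = np.empty_like(rows)
        restored[order] = rows
        yield window, restored


stored = np.arange(20) * 10


def read_rows(index):
    assert np.all(np.diff(index) > 0), "backend requires increasing indices"
    return stored[index]


batches = list(shuffled_sorted_windows(20, 6, read_rows, seed=1))
```

Every row is visited exactly once per epoch, while each backend read sees only increasing indices.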
- static _select_parameters(split_data: dict, parameter_space: str) -> Any#
Return the parameter array for the requested coordinate space.
- class petitRADTRANS.sbi.HierarchicalObservationEncoder(embedding_dim: int = 128, spectrum_embedding_dim: int = 64, photometry_embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 128, spectrum_encoder_type: str = 'conv1d', n_wavelengths: int = 233, key: jax.Array | None = None)#
Bases: equinox.Module, petitRADTRANS.sbi.observation.ObservationEncoder
Learned hierarchical encoder dispatching over modalities.
Parameters#
- embedding_dim:
Size of the final joint observation embedding.
- spectrum_embedding_dim:
Intermediate embedding size used by the spectral sub-encoder.
- photometry_embedding_dim:
Intermediate embedding size used by the photometry sub-encoder.
- patch_size:
Patch size used by the spectral encoder.
- hidden_dim:
Hidden width shared across the component encoders and aggregator.
- key:
Optional JAX random key used to initialize all submodules.
Notes#
Spectrum and photometry blocks are handled by dedicated sub-encoders and then merged with a permutation-invariant aggregator. Unsupported modalities currently fall back to resized raw-value vectors.
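The key property of a permutation-invariant aggregator is that the joint embedding does not depend on block order. A NumPy sketch of a DeepSets-style aggregator (per-block map, mean pooling, output map) shows why. The random weights, the mean pooling, and the function name `aggregate_blocks` are illustrative assumptions, not the structure of `DatasetSetAggregator`.

```python
import numpy as np


def aggregate_blocks(block_embeddings, w_phi, w_rho):
    """Hypothetical DeepSets-style aggregator over per-block embeddings."""
    h = np.tanh(np.asarray(block_embeddings) @ w_phi)  # per-block map, (n_blocks, hidden)
    pooled = h.mean(axis=0)                             # symmetric pooling over blocks
    return np.tanh(pooled @ w_rho)                      # joint embedding, (embedding_dim,)


rng = np.random.default_rng(5)
blocks = rng.normal(size=(3, 64))  # e.g. embeddings of two spectra and one photometry block
w_phi = rng.normal(size=(64, 128)) / 8.0
w_rho = rng.normal(size=(128, 128)) / 11.0
z = aggregate_blocks(blocks, w_phi, w_rho)
z_permuted = aggregate_blocks(blocks[[2, 0, 1]], w_phi, w_rho)
```

Because the only operation that mixes blocks is a mean, reordering the blocks changes nothing beyond floating-point rounding.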
- spectrum_encoder: SpectralPatchEncoder | SpectralConv1DEncoder#
- photometry_encoder: PhotometryPointEncoder#
- aggregator: DatasetSetAggregator#
- embedding_dim: int#
- _encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray#
Encode one observation block with modality-aware dispatch.
Parameters#
- block:
Observation block to encode.
Returns#
- jnp.ndarray
Fixed-width embedding for the supplied block.
- encode(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation#
Encode one structured observation made of multiple blocks.
Parameters#
- blocks:
Observation blocks associated with one target system.
Returns#
- EncodedObservation
Aggregated embedding and light metadata describing the block set.
- encode_stacked_batch(blocks_batch: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) -> jax.numpy.ndarray#
Encode a batch of identically-structured observations using vmap.
All observations must share the same block structure (same number of blocks, same array shapes per block). This is always the case for SBI tasks where every observation is produced by the same forward model.
Parameters#
- blocks_batch:
Outer list indexes samples; inner list holds the per-block observation data for one sample.
Returns#
- jnp.ndarray
Float32 array of shape (n_samples, embedding_dim).
- encode_from_prestacked(obs: Any) -> jax.numpy.ndarray#
Encode a batch of observations from pre-stacked arrays.
Accepts a PreStackedObservations instance whose array fields have already been extracted from ObservationBlock objects outside the JAX JIT boundary. Only the vmapped XLA computation runs here, so the training step can be compiled once and reused across all batches.
Parameters#
- obs:
Pre-stacked observation container with stacked_blocks (one (values, uncertainties, coordinates, mask, log_scale, absolute_values) tuple per block, each of shape (batch_size, n_wl) except the scalar log_scale arrays) and modalities (static tuple of modality value strings).
Returns#
- jnp.ndarray
Float32 array of shape (batch_size, embedding_dim).
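The split between `encode_stacked_batch` and `encode_from_prestacked` amounts to doing the Python-object work once, outside the compiled function, so that the compiled code sees only dense arrays. A NumPy sketch of the stacking step follows. `ToyBlock`, `FIELDS`, `prestack`, and `make_block` are hypothetical. The field order mirrors the tuple described above, and `log_scale` is treated as a given per-block scalar rather than computed.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyBlock:
    """Hypothetical stand-in for a normalized ObservationBlock."""
    modality: str
    values: np.ndarray
    uncertainties: np.ndarray
    coordinates: np.ndarray
    mask: np.ndarray
    log_scale: float
    absolute_values: np.ndarray


FIELDS = (("values", np.float32), ("uncertainties", np.float32), ("coordinates", np.float32),
          ("mask", bool), ("log_scale", np.float32), ("absolute_values", np.float32))


def prestack(blocks_batch):
    """Stack per-sample blocks into one array tuple per block position."""
    n_blocks = len(blocks_batch[0])
    if any(len(sample) != n_blocks for sample in blocks_batch):
        raise ValueError("all samples must share the same block structure")
    stacked_blocks = []
    for b in range(n_blocks):
        column = [sample[b] for sample in blocks_batch]
        # Arrays become (batch_size, n_points); the scalar log_scale becomes (batch_size,).
        stacked_blocks.append(tuple(
            np.stack([np.asarray(getattr(block, name)) for block in column]).astype(dtype)
            for name, dtype in FIELDS
        ))
    modalities = tuple(block.modality for block in blocks_batch[0])
    return tuple(stacked_blocks), modalities


def make_block(modality, n):
    x = np.linspace(0.0, 1.0, n)
    return ToyBlock(modality, x, 0.1 * np.ones(n), x, np.zeros(n, bool), 0.0, x)


batch = [[make_block("spectrum", 5), make_block("photometry", 3)] for _ in range(4)]
stacked, modalities = prestack(batch)
```

Keeping `modalities` as a plain tuple of strings lets it act as static configuration rather than a traced array.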
- class petitRADTRANS.sbi.PhotometryPointEncoder(embedding_dim: int = 64, hidden_dim: int = 96, key: jax.Array | None = None)#
Bases: equinox.Module
Learned photometry encoder using per-point MLP features and pooling.
Parameters#
- embedding_dim:
Size of the returned photometric embedding.
- hidden_dim:
Hidden width of the per-point MLP.
- key:
Optional JAX random key used for initialization.
Notes#
Each photometric point is represented by value, uncertainty, coordinate, and an inferred width feature before permutation-invariant pooling.
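The non-learned part of this design, a per-point feature table followed by masked mean pooling, can be sketched with NumPy. The learned per-point MLP is omitted. The width feature is taken as the local spacing between neighbouring coordinates via `np.gradient`, which is only a guess at what "inferred width" means. Missing uncertainties or coordinates are replaced by zeros. The mask is assumed to be `True` for invalid points, matching the convention stated for `SpectralConv1DEncoder._forward_conv`. Both function names are hypothetical.

```python
import numpy as np


def photometry_point_features(values, uncertainties, coordinates):
    """Hypothetical per-point feature table: value, uncertainty, coordinate, width."""
    values = np.asarray(values, dtype=float)
    uncertainties = np.asarray(uncertainties, dtype=float)
    coordinates = np.asarray(coordinates, dtype=float)
    # Missing (length-0) uncertainties or coordinates are replaced by zeros.
    if uncertainties.size == 0:
        uncertainties = np.zeros_like(values)
    if coordinates.size == 0:
        coordinates = np.zeros_like(values)
    # Assumed width feature: local spacing between neighbouring coordinates.
    width = np.abs(np.gradient(coordinates)) if coordinates.size >= 2 else np.zeros_like(coordinates)
    return np.stack([values, uncertainties, coordinates, width], axis=1)  # (n_points, 4)


def masked_mean_pool(per_point, mask):
    """Average over valid points; mask is True for invalid points."""
    keep = ~np.asarray(mask, dtype=bool)
    if not keep.any():
        return np.zeros(per_point.shape[1])
    return per_point[keep].mean(axis=0)


features = photometry_point_features([0.1, 0.2, 0.3], [0.01, 0.01, 0.01], [1.0, 2.0, 4.0])
pooled = masked_mean_pool(features, mask=[False, False, True])
```

The mean pooling is order-independent, so the output does not depend on how the points are listed once their features are fixed.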
- point_mlp: equinox.nn.MLP#
- output_projection: equinox.nn.Linear#
- embedding_dim: int#
- encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray#
Encode one photometric observation block.
Parameters#
- block:
Photometry-like observation block to encode.
Returns#
- jnp.ndarray
Dense embedding for the supplied photometric block.
- _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray) -> jax.numpy.ndarray#
Encode pre-processed block arrays (vmappable; takes no ObservationBlock input).
Parameters#
- values, uncertainties, coordinates:
Float32 1-D arrays already produced by _as_vector. uncertainties and coordinates may have length 0.
- mask:
Boolean 1-D array already produced by _safe_mask.
Returns#
- jnp.ndarray
Dense photometric embedding of shape (embedding_dim,).
- class petitRADTRANS.sbi.SpectralConv1DEncoder(embedding_dim: int = 64, n_wavelengths: int = 233, hidden_channels: tuple[int, int] = (32, 64), key: jax.Array | None = None)#
Bases: equinox.Module
Learned spectral encoder using 1D convolutions to retain positional information.
Parameters#
- embedding_dim:
Size of the spectral embedding returned for one observation block.
- n_wavelengths:
Fixed number of spectral points per observation block. Must match the observation schema of the SBI task.
- hidden_channels:
Tuple of intermediate channel widths for conv layers.
- key:
Optional JAX random key used for weight initialization.
- conv1: equinox.nn.Conv1d#
- conv2: equinox.nn.Conv1d#
- conv3: equinox.nn.Conv1d#
- amplitude_projection: equinox.nn.MLP#
- pool_projection: equinox.nn.Linear#
- branch_projection: equinox.nn.Linear#
- output_projection: equinox.nn.Linear#
- scale_projection: equinox.nn.Linear#
- embedding_dim: int#
- n_wavelengths: int#
- _forward_conv(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray#
Run the 1D convolution stack on pre-processed arrays.
Parameters#
- values, uncertainties, coordinates:
Float32 1-D arrays of length n_wavelengths.
- mask:
Boolean 1-D mask (True = invalid).
Returns#
- jnp.ndarray
Embedding of shape (embedding_dim,).
- _pad_to_fixed(arr: jax.numpy.ndarray) -> jax.numpy.ndarray#
Pad or truncate a 1-D array to n_wavelengths.
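A NumPy version of the pad-or-truncate step follows. It assumes padding is appended at the end with a configurable fill value. Both the padding side and the fill values (zero for data, `True` for the mask, so padded positions count as invalid under the "True = invalid" convention) are assumptions, and `pad_to_fixed` is a hypothetical name.

```python
import numpy as np


def pad_to_fixed(arr, n_wavelengths, fill_value=0):
    """Hypothetical pad-or-truncate of a 1-D array to exactly n_wavelengths entries."""
    arr = np.asarray(arr)
    if arr.shape[0] >= n_wavelengths:
        return arr[:n_wavelengths]
    return np.pad(arr, (0, n_wavelengths - arr.shape[0]), constant_values=fill_value)


values = pad_to_fixed(np.linspace(0.0, 1.0, 230), 233)  # zero-padded tail
mask = pad_to_fixed(np.zeros(230, dtype=bool), 233, fill_value=True)  # padding marked invalid
```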
- encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray#
Encode one spectral observation block.
Parameters#
- block:
Spectrum-like observation block.
Returns#
- jnp.ndarray
Dense embedding of shape (embedding_dim,).
- _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray#
Encode pre-processed block arrays (vmappable; takes no ObservationBlock input).
Parameters#
- values, uncertainties, coordinates:
Float32 1-D arrays already produced by _as_vector.
- mask:
Boolean 1-D array already produced by _safe_mask.
- log_median_flux:
Scalar absolute-flux feature derived from the raw spectrum median.
Returns#
- jnp.ndarray
Dense spectral embedding of shape (embedding_dim,).
- class petitRADTRANS.sbi.SpectralPatchEncoder(embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 96, key: jax.Array | None = None)#
Bases: equinox.Module
Learned spectral encoder based on patch summaries and MLP pooling.
Parameters#
- embedding_dim:
Size of the spectral embedding returned for one observation block.
- patch_size:
Number of spectral points summarized together before MLP encoding.
- hidden_dim:
Hidden width of the internal summary MLP.
- key:
Optional JAX random key used for weight initialization.
Notes#
The encoder summarizes values, uncertainties, and coordinates patch by patch instead of operating on a fully convolutional representation. This keeps the implementation lightweight and shape-agnostic.
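Patch summarization can be sketched as cutting the spectrum into `patch_size` windows and computing a few statistics per window before any learned layer sees them. The choice of statistics (mean and standard deviation of the values, mean uncertainty, mean coordinate) and the zero-padding of the last patch are assumptions for illustration, and `patch_summaries` is a hypothetical name.

```python
import numpy as np


def patch_summaries(values, uncertainties, coordinates, patch_size=32):
    """Hypothetical per-patch statistics: (n_patches, 4) = value mean/std, mean sigma, mean coordinate."""
    values = np.asarray(values, dtype=float)
    n = values.size
    n_patches = -(-n // patch_size)  # ceiling division
    pad = n_patches * patch_size - n

    def to_patches(x):
        # Zero-pad the tail so the array reshapes into whole patches.
        return np.concatenate([np.asarray(x, dtype=float), np.zeros(pad)]).reshape(n_patches, patch_size)

    valid = to_patches(np.ones(n)) > 0  # False on padded positions
    v, u, c = to_patches(values), to_patches(uncertainties), to_patches(coordinates)
    count = valid.sum(axis=1)
    mean = np.where(valid, v, 0.0).sum(axis=1) / count
    std = np.sqrt(np.where(valid, (v - mean[:, None]) ** 2, 0.0).sum(axis=1) / count)
    mean_sigma = np.where(valid, u, 0.0).sum(axis=1) / count
    mean_coord = np.where(valid, c, 0.0).sum(axis=1) / count
    return np.stack([mean, std, mean_sigma, mean_coord], axis=1)


summaries = patch_summaries(np.arange(10.0), np.ones(10), 0.1 * np.arange(10.0), patch_size=4)
```

Each patch is reduced to a fixed-width row regardless of spectrum length. A subsequent MLP plus pooling over rows therefore works for any number of wavelengths.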
- patch_mlp: equinox.nn.MLP#
- amplitude_projection: equinox.nn.MLP#
- branch_projection: equinox.nn.Linear#
- output_projection: equinox.nn.Linear#
- scale_projection: equinox.nn.Linear#
- embedding_dim: int#
- patch_size: int#
- encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) jax.numpy.ndarray#
Encode one spectral observation block.
Parameters#
- block:
Spectrum-like observation block containing values and optional uncertainties, coordinates, and masks.
Returns#
- jnp.ndarray
Dense fixed-width embedding for the supplied spectral block.
- _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) jax.numpy.ndarray#
Encode pre-processed block arrays (vmappable — no ObservationBlock input).
Parameters#
- values, uncertainties, coordinates:
Float32 1-D arrays already produced by
_as_vector. uncertainties and coordinates may have length 0 to indicate absence.
- mask:
Boolean 1-D array already produced by
_safe_mask.
Returns#
- jnp.ndarray
Dense spectral embedding of shape
(embedding_dim,).
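Using length 0 to signal an absent channel keeps every argument a plain array, which is what makes the method vmappable. Below is a hedged sketch of how such optional channels might be resolved before encoding. The fallbacks (unit uncertainties and an evenly spaced grid on [0, 1]) and the helper name are assumptions, not the library's defaults.

```python
import numpy as np


def resolve_optional_channels(values, uncertainties, coordinates):
    """Replace zero-length optional channels with placeholder arrays (hypothetical)."""
    values = np.asarray(values, dtype=np.float32)
    uncertainties = np.asarray(uncertainties, dtype=np.float32)
    coordinates = np.asarray(coordinates, dtype=np.float32)
    n = values.shape[0]
    if uncertainties.size == 0:
        # Assumed fallback: unit uncertainties.
        uncertainties = np.ones(n, dtype=np.float32)
    if coordinates.size == 0:
        # Assumed fallback: evenly spaced positions on [0, 1].
        coordinates = np.linspace(0.0, 1.0, n, dtype=np.float32)
    return values, uncertainties, coordinates
```

Array lengths are static under JIT, so the branch on size is resolved at trace time rather than per element.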
- petitRADTRANS.sbi.load_posterior_estimator(input_directory: str) petitRADTRANS.sbi.posterior.PosteriorEstimator#
Load a saved posterior estimator without naming its concrete class.
- class petitRADTRANS.sbi.AmortizedRetrieval(task: petitRADTRANS.sbi.task.SBITask, posterior_estimator: petitRADTRANS.sbi.posterior.PosteriorEstimator, simulator: petitRADTRANS.sbi.simulator.RuntimeSimulator | None = None, preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | Mapping[str, Any] | None = None)#
Serve trained SBI models through a retrieval-like interface.
Parameters#
- task:
SBI task compatible with the trained posterior.
- posterior_estimator:
Trained posterior estimator used for encoding and sampling.
- simulator:
Optional simulator override used for posterior-predictive generation.
- preprocessing_metadata:
Optional preprocessing metadata or payload. When omitted the inference service attempts to recover the preprocessing payload saved inside the posterior artifact.
Notes#
Raw user-facing observation blocks are normalized internally before encoding when preprocessing metadata is available, but posterior-predictive comparisons remain on the original observation scale.
- task#
- posterior_estimator#
- simulator = None#
- preprocessing_metadata#
- _resolve_preprocessing_metadata(preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | Mapping[str, Any] | None) petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | None#
- prepare_observation_blocks(observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) list[petitRADTRANS.sbi.observation.ObservationBlock]#
Normalize user-facing observation blocks when preprocessing metadata is available.
Parameters#
- observation_blocks:
Raw observation blocks provided by the caller.
Returns#
- list[ObservationBlock]
Either the original blocks or normalized copies, depending on whether preprocessing metadata is available.
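Returning normalized copies instead of mutating the caller's blocks is what keeps posterior-predictive comparisons on the original scale. The sketch below shows this pattern with dataclasses.replace. The stand-in Block class and the per-block (offset, scale) statistics are hypothetical; they do not reflect the structure of TaskPreprocessingMetadata.

```python
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass, replace
from typing import Any

import numpy as np


@dataclass(frozen=True)
class Block:
    """Simplified stand-in for ObservationBlock."""
    name: str
    values: Any
    uncertainties: Any = None


def prepare_blocks(blocks: list[Block], stats: Mapping[str, tuple[float, float]] | None) -> list[Block]:
    """Return normalized copies when statistics exist, otherwise the originals."""
    if stats is None:
        return blocks
    prepared = []
    for block in blocks:
        offset, scale = stats[block.name]
        values = (np.asarray(block.values) - offset) / scale
        unc = None if block.uncertainties is None else np.asarray(block.uncertainties) / scale
        # replace() builds a new frozen instance; the caller's block is untouched.
        prepared.append(replace(block, values=values, uncertainties=unc))
    return prepared
```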
- infer(observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock], n_posterior_samples: int = 1000, include_posterior_predictive: bool = False, posterior_predictive_interval_level: float = 0.9, n_predictive_forward_model_samples: int | None = None, seed: int | None = None) AmortizedRetrievalResult#
Infer a posterior for one structured observation.
Parameters#
- observation_blocks:
Raw observation blocks for one target system.
- n_posterior_samples:
Number of posterior draws to generate.
- include_posterior_predictive:
Whether to also generate a posterior-predictive summary.
- posterior_predictive_interval_level:
Predictive interval level used when generating the optional posterior-predictive report.
- n_predictive_forward_model_samples:
Number of posterior draws passed through the forward model when generating the posterior-predictive summary. When None, all n_posterior_samples draws are forwarded. Set this to a small value (e.g. 50–200) to avoid multi-hour run times when n_posterior_samples is large.
- seed:
Optional seed for posterior and posterior-predictive sampling.
Returns#
- AmortizedRetrievalResult
Posterior samples and any optional predictive metadata.
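Each forwarded draw costs one petitRADTRANS call, so limiting n_predictive_forward_model_samples trades predictive resolution for run time. Below is a minimal sketch of choosing that subset reproducibly. It assumes uniform random selection without replacement. The helper name is hypothetical, and the library may choose draws differently.

```python
from __future__ import annotations

import numpy as np


def select_forward_model_draws(samples, n_forward: int | None, seed: int | None = None) -> np.ndarray:
    """Pick the posterior draws to push through the forward model (hypothetical)."""
    samples = np.asarray(samples)
    if n_forward is None or n_forward >= samples.shape[0]:
        return samples  # None means: forward every draw
    rng = np.random.default_rng(seed)
    index = rng.choice(samples.shape[0], size=n_forward, replace=False)
    return samples[np.sort(index)]
```

With n_posterior_samples=10_000 and n_predictive_forward_model_samples=100, this cuts the number of forward-model calls by a factor of 100.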
- posterior_predictive(posterior: petitRADTRANS.sbi.posterior.PosteriorSamples, observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock], interval_level: float = 0.9, n_predictive_forward_model_samples: int | None = None, seed: int | None = None) petitRADTRANS.sbi.calibration.PosteriorPredictiveReport#
Generate posterior-predictive draws for a fitted observation.
Parameters#
- posterior:
Posterior samples associated with one observed system.
- observation_blocks:
Original observation blocks used for user-facing comparison.
- interval_level:
Central predictive interval level to report.
- n_predictive_forward_model_samples:
Number of posterior draws passed through the forward model. When None, all samples in posterior are used. Set this to a small value (e.g. 50–200) to keep the number of expensive petitRADTRANS calls manageable.
- seed:
Optional seed for predictive simulation.
Returns#
- PosteriorPredictiveReport
Predictive summary for the supplied posterior samples.
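A central predictive interval at level interval_level spans the (1 - level)/2 and (1 + level)/2 quantiles of the simulated draws at each data point. A minimal NumPy sketch follows. It assumes the predictive draws are stacked with shape (n_draws, n_points), and the function name is hypothetical.

```python
import numpy as np


def central_interval(predictive_draws, interval_level: float = 0.9):
    """Return (mean, lower, upper) per data point for a central interval."""
    if not 0.0 < interval_level < 1.0:
        raise ValueError("interval_level must lie in (0, 1)")
    draws = np.asarray(predictive_draws, dtype=np.float64)
    alpha = 0.5 * (1.0 - interval_level)
    # One quantile pair per column, i.e. per wavelength or photometric point.
    lower, upper = np.quantile(draws, [alpha, 1.0 - alpha], axis=0)
    return draws.mean(axis=0), lower, upper
```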
- diagnose_domain(observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) OODDiagnostic | None#
Estimate whether the observation is in-distribution for the model.
Parameters#
- observation_blocks:
Observation blocks to diagnose.
Returns#
- OODDiagnostic | None
Robust support-distance diagnostic derived from preprocessing statistics, or None when preprocessing metadata is unavailable.
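A common way to build a robust support-distance score works in two steps. First, standardize the observation's summary features with the training median and median absolute deviation (MAD). Then compare the largest absolute robust z-score with a threshold. The sketch assumes exactly this recipe. The feature choice, the 1.4826 MAD scaling, the max aggregation, and the default threshold are assumptions, not what diagnose_domain computes. The returned stand-in mirrors the OODDiagnostic fields (score, threshold, passed, metadata).

```python
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any

import numpy as np


@dataclass(frozen=True)
class SketchOODDiagnostic:
    """Stand-in with the same field names as OODDiagnostic (illustration only)."""
    score: float
    threshold: float | None = None
    passed: bool | None = None
    metadata: Mapping[str, Any] = field(default_factory=dict)


def robust_support_distance(features, train_median, train_mad, threshold: float = 5.0) -> SketchOODDiagnostic:
    features = np.asarray(features, dtype=np.float64)
    # 1.4826 * MAD estimates the standard deviation for Gaussian data.
    scale = np.maximum(1.4826 * np.asarray(train_mad, dtype=np.float64), 1e-12)
    z = np.abs(features - np.asarray(train_median, dtype=np.float64)) / scale
    score = float(z.max())
    return SketchOODDiagnostic(score=score, threshold=threshold, passed=score <= threshold,
                               metadata={"worst_feature": int(z.argmax())})
```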
- class petitRADTRANS.sbi.AmortizedRetrievalResult#
Return type for amortized inference queries.
Attributes#
- posterior:
Posterior samples and any attached diagnostics.
- posterior_predictive:
Optional posterior-predictive report for the same observation.
- ood_diagnostic:
Optional in/out-of-distribution assessment.
- metadata:
Additional metadata such as preprocessing usage during inference.
- posterior_predictive: Any = None#
- ood_diagnostic: OODDiagnostic | None = None#
- metadata: Mapping[str, Any]#
- class petitRADTRANS.sbi.OODDiagnostic#
Describe whether an observation is inside the training support.
Attributes#
- score:
Scalar out-of-distribution score.
- threshold:
Optional threshold used to convert the score into a pass/fail decision.
- passed:
Optional boolean indicating whether the observation passed the OOD test.
- metadata:
Auxiliary diagnostic metadata.
- score: float#
- threshold: float | None = None#
- passed: bool | None = None#
- metadata: Mapping[str, Any]#
- class petitRADTRANS.sbi.ObservationBlock#
Represent one modality-specific observation block.
Attributes#
- name:
Stable identifier of the block within a task.
- modality:
Semantic block type used to dispatch encoding logic.
- values:
Observed values after any task-level preprocessing.
- uncertainties:
Optional per-element uncertainty representation.
- coordinates:
Optional coordinate arrays such as wavelengths or timestamps.
- mask:
Optional mask applied to the values.
- metadata:
Additional instrument and preprocessing metadata.
- name: str#
- modality: ObservationModality#
- values: Any#
- uncertainties: Any = None#
- coordinates: Any = None#
- mask: Any = None#
- metadata: Mapping[str, Any]#
- class petitRADTRANS.sbi.ObservationEncoder#
Bases:
abc.ABC
Transform structured observation blocks into model-ready embeddings.
- abstractmethod encode(blocks: list[ObservationBlock]) EncodedObservation#
Encode a list of observation blocks into a shared representation.
- batch_encode(observations: list[list[ObservationBlock]]) list[EncodedObservation]#
Encode multiple observations using repeated single-item encoding.
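The encoder contract follows the abstract-base-class pattern. Subclasses implement encode, and batch_encode falls back to a per-observation loop unless a backend overrides it. The posterior estimators' batch_encode_observation methods, for instance, use a single vmapped pass instead. A minimal sketch follows. ToyBlock, SketchEncoder, and the mean-pooling encoder are hypothetical stand-ins.

```python
from __future__ import annotations

import abc
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyBlock:
    name: str
    values: np.ndarray


class SketchEncoder(abc.ABC):
    """Mirrors the encode / batch_encode split of ObservationEncoder."""

    @abc.abstractmethod
    def encode(self, blocks: list[ToyBlock]) -> np.ndarray:
        """Encode one observation (a list of blocks) into a vector."""

    def batch_encode(self, observations: list[list[ToyBlock]]) -> list[np.ndarray]:
        # Default behaviour: repeated single-item encoding.
        return [self.encode(blocks) for blocks in observations]


class MeanPoolEncoder(SketchEncoder):
    def encode(self, blocks: list[ToyBlock]) -> np.ndarray:
        # One mean per block, concatenated in block order.
        return np.array([float(np.mean(b.values)) for b in blocks])
```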
- class petitRADTRANS.sbi.ObservationModality#
Bases:
str, enum.Enum
Supported observation block types for SBI conditioning.
- SPECTRUM = 'spectrum'#
- PHOTOMETRY = 'photometry'#
- TIME_SERIES = 'time_series'#
- AUXILIARY = 'auxiliary'#
- petitRADTRANS.sbi.build_observation_block(name: str, modality: ObservationModality | str, values: Any, uncertainties: Any = None, coordinates: Any = None, mask: Any = None, metadata: Mapping[str, Any] | None = None) ObservationBlock#
Build one observation block with modality normalization.
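Modality normalization lets callers pass either an ObservationModality member or its string value. Because the enum subclasses str, the lookup can be sketched as below. The local enum mirrors the four documented members. The strip/lower-case step and the error message are assumptions about how free-form input is handled.

```python
import enum


class Modality(str, enum.Enum):
    """Local mirror of the documented ObservationModality members."""
    SPECTRUM = "spectrum"
    PHOTOMETRY = "photometry"
    TIME_SERIES = "time_series"
    AUXILIARY = "auxiliary"


def normalize_modality(modality: "Modality | str") -> Modality:
    if isinstance(modality, Modality):
        return modality
    try:
        # Enum lookup by value; the strip/lower step is an assumption.
        return Modality(str(modality).strip().lower())
    except ValueError as exc:
        allowed = ", ".join(m.value for m in Modality)
        raise ValueError(f"unknown modality {modality!r}; expected one of: {allowed}") from exc
```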
- petitRADTRANS.sbi.build_observation_block_batch(observation_payloads: Mapping[str, Mapping[str, Any]], modalities: Mapping[str, str]) list[list[ObservationBlock]]#
Build observation blocks for each sample in a batched payload.
- petitRADTRANS.sbi.build_observation_blocks_from_sample(observation_payloads: Mapping[str, Mapping[str, Any]], modalities: Mapping[str, str], sample_index: int) list[ObservationBlock]#
Build modality-aware observation blocks for one simulated sample.
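build_observation_blocks_from_sample picks one row out of a batched payload, and build_observation_block_batch repeats that for every row. The minimal sketch below makes three assumptions. Each payload entry maps a block name to arrays whose leading axis is the sample index. The keys are named values and uncertainties. The coordinates are shared across samples. These key names, the shared-coordinate assumption, and the dict output are hypothetical.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any

import numpy as np


def blocks_from_sample(payloads: Mapping[str, Mapping[str, Any]], modalities: Mapping[str, str],
                       sample_index: int) -> list[dict]:
    blocks = []
    for name, arrays in payloads.items():
        block = {"name": name, "modality": modalities[name]}
        for key in ("values", "uncertainties"):
            if key in arrays:
                # Leading axis indexes the simulated sample.
                block[key] = np.asarray(arrays[key])[sample_index]
        if "coordinates" in arrays:
            # Assumed shared across samples (e.g. one wavelength grid).
            block["coordinates"] = np.asarray(arrays["coordinates"])
        blocks.append(block)
    return blocks


def block_batch(payloads: Mapping[str, Mapping[str, Any]], modalities: Mapping[str, str]) -> list[list[dict]]:
    first = next(iter(payloads.values()))
    n_samples = np.asarray(first["values"]).shape[0]
    return [blocks_from_sample(payloads, modalities, i) for i in range(n_samples)]
```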
- class petitRADTRANS.sbi.ConditionalAutoregressiveFlowPosterior(*args: Any, **kwargs: Any)#
Bases:
ConditionalFlowPosterior
Posterior estimator specialized to the autoregressive flow backend.
- class petitRADTRANS.sbi.ConditionalFlowPosterior(parameter_dim: int, embedding_dim: int = 128, num_coupling_layers: int = 4, hidden_dim: int = 128, conditioner_depth: int = 2, autoregressive_transform_units: int = 16, neural_autoregressive_min_slope: float = 0.001, neural_autoregressive_min_residual: float = 0.05, neural_autoregressive_inverse_bisection_steps: int = 48, learning_rate: float = 0.001, batch_size: int = 32, num_epochs: int = 5, parameter_space: str = 'unconstrained', flow_family: str = 'spline', num_spline_bins: int = 8, spline_bound: float = 10.0, base_distribution: str = 'gaussian', use_base_affine: bool = False, training_objective: str = 'npe', early_stopping_patience: int | None = None, early_stopping_min_delta: float = 0.0, checkpoint_directory: str | None = None, checkpoint_backend: str = 'auto', resume_from_checkpoint: bool = False, gradient_clip_norm: float | None = 1.0, embedding_noise_std: float = 0.0, embedding_noise_min_scale: float = 0.0, aux_scale_loss_weight: float = 0.0, aux_parameter_loss_weight: float = 0.0, parameter_noise_floor: float = 0.0, spline_small_bin_regularization_weight: float | None = None, spline_min_bin_ratio_target: float = 0.2, spline_entropy_regularization_weight: float = 0.0, spline_derivative_regularization_weight: float = 0.0, spline_entropy_floor: float = 0.6, spline_min_derivative_target: float = 0.25, spline_max_derivative_target: float = 5.0, weight_decay: float = 0.0, use_cosine_schedule: bool = False, warmup_fraction: float = 0.02, warmup_epochs: float | None = None, min_learning_rate: float = 1e-06, lr_schedule_total_epochs: float | None = None, stable_inverse_forward_max_abs_error_threshold: float | None = 0.0001, stable_inverse_forward_logdet_closure_max_abs_error_threshold: float | None = 0.0001, stable_cube_edge_hit_rate_threshold: float | None = 0.0, spectrum_encoder_type: str = 'conv1d', n_wavelengths: int = 233, spectrum_embedding_dim: int = 64, photometry_embedding_dim: int = 64, 
encoder_hidden_dim: int | None = None, encoder_patch_size: int = 32, seed: int = 0, task_metadata: Mapping[str, Any] | None = None, verbose_diagnostics: bool = False, diagnostics_output_directory: str | None = None, diagnostics_plot_interval: int = 1)#
Bases:
petitRADTRANS.sbi.posterior_base.PersistentPosteriorEstimator
Concrete amortized posterior using a conditional flow backend.
Parameters#
- parameter_dim:
Number of inferred free parameters represented by the posterior.
- embedding_dim:
Size of the learned observation embedding consumed by the flow.
- num_coupling_layers:
Number of conditional coupling or spline-transform layers in the flow.
- hidden_dim:
Hidden width used by encoder-side and flow-side MLP conditioners.
- learning_rate:
Optimizer learning rate used by
SBITrainer.
- batch_size:
Number of simulations processed per optimization step.
- num_epochs:
Maximum number of passes through the training split.
- parameter_space:
Parameter coordinates learned by the posterior. Supported values are
'physical', 'cube', and 'unconstrained'.
- flow_family:
Conditional density-transform family. Supported values are
'spline', 'affine', 'autoregressive', and 'neural_autoregressive'.
- num_spline_bins:
Number of rational-quadratic spline bins when
flow_family='spline'.
- spline_bound:
Finite support bound of each spline transform in latent space.
- early_stopping_patience:
Optional number of non-improving epochs tolerated before stopping.
- early_stopping_min_delta:
Minimum improvement required for early-stopping comparisons.
- checkpoint_directory:
Optional directory used to persist resumable trainer checkpoints.
- checkpoint_backend:
Checkpoint persistence backend name.
'auto' selects Orbax when available and otherwise falls back to Equinox serialization.
- resume_from_checkpoint:
Whether fit should attempt to resume from the latest checkpoint.
- seed:
Base random seed used for flow initialization and posterior sampling.
- task_metadata:
Optional user-supplied metadata persisted alongside the trained model.
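The 'auto' checkpoint-backend resolution can be sketched with importlib.util.find_spec, which reports whether a package is importable without actually importing it. The backend name strings and the error raised for unknown names are assumptions for illustration.

```python
import importlib.util


def resolve_checkpoint_backend(name: str = "auto") -> str:
    """Map a requested backend name to a concrete one (hypothetical helper)."""
    name = name.lower()
    if name == "auto":
        # Prefer Orbax when installed; otherwise fall back to Equinox serialization.
        return "orbax" if importlib.util.find_spec("orbax") is not None else "equinox"
    if name in {"orbax", "equinox"}:
        return name
    raise ValueError(f"unsupported checkpoint backend: {name!r}")
```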
Notes#
The posterior stores task fingerprinting, observation schema, and preprocessing payload information when those are available from the training dataset reader. That metadata is later reused by inference and artifact registration paths.
- estimator_family = 'conditional_flow'#
- embedding_dim = 128#
- num_coupling_layers = 4#
- conditioner_depth = 2#
- autoregressive_transform_units = 16#
- neural_autoregressive_min_slope#
- neural_autoregressive_min_residual#
- neural_autoregressive_inverse_bisection_steps = 48#
- learning_rate#
- batch_size = 32#
- num_epochs = 5#
- flow_family = ''#
- effective_flow_family = ''#
- base_distribution = ''#
- use_base_affine = False#
- training_objective = ''#
- num_spline_bins = 8#
- early_stopping_patience = None#
- early_stopping_min_delta#
- checkpoint_directory = None#
- checkpoint_backend = 'auto'#
- resume_from_checkpoint = False#
- gradient_clip_norm = 1.0#
- embedding_noise_std#
- embedding_noise_min_scale#
- aux_scale_loss_weight#
- aux_parameter_loss_weight#
- parameter_noise_floor#
- spline_entropy_regularization_weight#
- spline_small_bin_regularization_weight#
- spline_derivative_regularization_weight#
- spline_min_bin_ratio_target#
- spline_entropy_floor#
- spline_min_derivative_target#
- spline_max_derivative_target#
- weight_decay#
- use_cosine_schedule = False#
- warmup_fraction#
- warmup_epochs = None#
- min_learning_rate#
- lr_schedule_total_epochs = None#
- stable_inverse_forward_max_abs_error_threshold = None#
- stable_inverse_forward_logdet_closure_max_abs_error_threshold = None#
- stable_cube_edge_hit_rate_threshold = None#
- spectrum_encoder_type = ''#
- n_wavelengths = 233#
- spectrum_embedding_dim = 64#
- photometry_embedding_dim = 64#
- encoder_patch_size = 32#
- verbose_diagnostics = False#
- diagnostics_output_directory = None#
- diagnostics_plot_interval = 1#
- model#
- _build_flow(key: jax.Array) Any#
- static _batch_embeddings(model: _PosteriorModel, observations: Any) jax.numpy.ndarray#
- static _loss(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) jax.numpy.ndarray#
- static _training_loss(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) jax.numpy.ndarray#
- static _loss_from_observations(model: _PosteriorModel, parameters: jax.numpy.ndarray, observations: Any, parameter_space: str, *, include_spline_regularization: bool) jax.numpy.ndarray#
- _validation_diagnostics(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) dict[str, float]#
- static _make_parameter_noise_loss(base_loss_fn: Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray], noise_floor: float, parameter_space: str) Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]#
Wrap a training loss so the parameter targets are jittered.
Minimizing the conditional NLL in a very-low-noise regime with a highly parameter-predictive embedding drives the learned conditional density toward a delta (NLL -> -inf), which an over-sharp flow realizes as exploded transforms (and, for spline flows, cube-edge collapse). A small Gaussian jitter on the parameter targets gives the conditional a hard minimum width, capping sharpness so the NLL has a finite minimum and the flow’s inverse stays well-conditioned.
The jitter is applied in unconstrained (logit) space – the space the flow actually models – not in cube space. Additive cube-space noise is asymmetric near the [0, 1] bounds: clipping piles jittered targets onto the edges, and the cube->logit Jacobian rewards edge mass, so it worsens the very edge-collapse it was meant to prevent. Jittering in logit space is symmetric and boundary-free. Applied only during training; evaluation and checkpoint selection use the clean targets.
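The logit-space jitter can be sketched in NumPy. Cube coordinates are mapped to logits and perturbed with symmetric Gaussian noise of width noise_floor, and the result is never clipped. The clip epsilon used to keep the logit finite and the function names are assumptions. The real wrapper applies the equivalent step inside the JAX training loss, and only during training.

```python
import numpy as np


def cube_to_logit(u, eps: float = 1e-6) -> np.ndarray:
    u = np.clip(np.asarray(u, dtype=np.float64), eps, 1.0 - eps)  # keep the logit finite at 0 or 1
    return np.log(u) - np.log1p(-u)


def jitter_targets_in_logit_space(theta_cube, noise_floor: float, rng: np.random.Generator) -> np.ndarray:
    """Return jittered targets in unconstrained (logit) coordinates."""
    logits = cube_to_logit(theta_cube)
    # Symmetric, boundary-free noise: nothing is clipped, so no mass piles up at cube edges.
    return logits + noise_floor * rng.standard_normal(logits.shape)
```

For comparison, suppose N(0, 0.01²) noise is added to a cube target of 0.999 and the result is clipped to [0, 1]. About 46% of the jittered targets then land exactly on the boundary, because P(Z > 0.1) ≈ 0.46.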
- static _make_elbo_loss(log_likelihood_fn: Callable[[jax.numpy.ndarray, Any], jax.numpy.ndarray], parameter_space: str, *, num_samples: int = 1) Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]#
Return an amortized-variational (ELBO) training loss.
Maximizes, for observations x drawn from the (prior-predictive) dataset,
$$\mathrm{ELBO}(x) = \mathbb{E}_{q(\theta \mid x)}\left[\log p(x \mid \theta) + \log p(\theta) - \log q(\theta \mid x)\right]$$
with a single- (or few-) sample reparameterized estimator. q(theta|x) is the conditional flow: latents z are drawn from the flow base and pushed through flow.inverse to obtain reparameterized samples, so the gradient flows through both the flow and the encoder.
The -log q (entropy) term diverges to -inf as q collapses to a point mass, so a delta posterior is structurally impossible; the width is set by the likelihood/entropy balance rather than by the (near-deterministic) embedding. log p(x|theta) is supplied by the injected differentiable forward-model likelihood, evaluated at the physical parameters reconstructed from the sampled cube coordinates. With parameter_space='cube' the prior is uniform on the unit hypercube, so log p(theta_cube) = 0 and is dropped.
Parameters#
- log_likelihood_fn:
(theta_cube, observations) -> (batch,) returning the Gaussian observational log-likelihood log p(x | theta) for each batch element. Receives unit-cube coordinates (shape (batch, dim)) and the batch’s prestacked observations; it owns the cube->physical transform, the differentiable forward model, and the noise model. Must be JAX-differentiable w.r.t. theta_cube.
- parameter_space:
Must be
"cube".- num_samples:
Number of reparameterized posterior draws per observation used to estimate the ELBO expectation.
1 is standard for amortized VI; larger values reduce gradient variance at a proportional forward-model cost.
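The reparameterized ELBO estimator can be checked on a 1-D toy in which everything is Gaussian. In this toy, q(theta|x) = N(q_mean, q_std²), the prior is N(0, 1), and the likelihood is N(x; theta, obs_std²). These choices replace the flow, the cube prior, and the forward model purely for illustration. Because the toy has a closed-form ELBO, the Monte Carlo estimate can be compared against it.

```python
import numpy as np


def gaussian_logpdf(x, mean, std):
    return -0.5 * np.log(2.0 * np.pi * std**2) - 0.5 * ((x - mean) / std) ** 2


def elbo_estimate(x_obs, q_mean, q_std, obs_std, num_samples, rng) -> float:
    """Reparameterized Monte Carlo ELBO for the 1-D toy model."""
    z = rng.standard_normal(num_samples)   # draws from the base distribution
    theta = q_mean + q_std * z             # reparameterization: theta = T(z)
    log_lik = gaussian_logpdf(x_obs, theta, obs_std)
    log_prior = gaussian_logpdf(theta, 0.0, 1.0)
    log_q = gaussian_logpdf(theta, q_mean, q_std)
    return float(np.mean(log_lik + log_prior - log_q))


def elbo_exact(x_obs, q_mean, q_std, obs_std) -> float:
    """Closed form of the same toy ELBO, for checking the estimator."""
    e_lik = -0.5 * np.log(2 * np.pi * obs_std**2) - ((x_obs - q_mean) ** 2 + q_std**2) / (2 * obs_std**2)
    e_prior = -0.5 * np.log(2 * np.pi) - (q_mean**2 + q_std**2) / 2
    entropy = 0.5 * np.log(2 * np.pi * np.e * q_std**2)  # equals E_q[-log q]
    return float(e_lik + e_prior + entropy)
```

As q_std shrinks toward zero, the entropy term goes to -inf. This is the structural barrier against delta posteriors described above.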
- static _make_noisy_loss(noise_std: float, preprocessing_metadata: Any = None, *, noise_min_scale: float = 0.0, include_spline_regularization: bool = True) Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]#
Return a loss function that injects on-the-fly Gaussian noise into the observation spectra before encoding.
For spectra preprocessed into a per-sample transformed value space, the uncertainty channel is assumed to already be expressed in the same transformed units as the value channel, so noise can be injected directly from that channel. When preprocessing metadata is provided, non-spectral blocks still reconstruct raw measurement uncertainties so noise is injected at the correct physical scale in normalized-value space.
noise_std acts as a multiplier on the observational noise level (1.0 = exact observational noise, >1 = regularising over-noise), and noise_min_scale is treated as a minimum fraction of the typical transformed uncertainty for the block rather than as an absolute floor in normalized-value units.
When preprocessing metadata is unavailable, flat Gaussian noise with scale noise_std is applied as a fallback.
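The noise model can be sketched as follows. Each element receives Gaussian noise of width noise_std times its own uncertainty. That uncertainty is first floored at noise_min_scale times a typical uncertainty for the block. Without uncertainties, the sketch falls back to flat noise of width noise_std. Taking the median as the "typical" uncertainty is an assumption, as are the function and argument names.

```python
import numpy as np


def inject_noise(values, uncertainties, noise_std: float, noise_min_scale: float,
                 rng: np.random.Generator) -> np.ndarray:
    values = np.asarray(values, dtype=np.float64)
    if uncertainties is None:
        # Fallback: flat Gaussian noise with scale noise_std.
        return values + noise_std * rng.standard_normal(values.shape)
    sigma = np.asarray(uncertainties, dtype=np.float64)
    typical = np.median(sigma)  # assumed notion of "typical" uncertainty
    sigma = np.maximum(sigma, noise_min_scale * typical)
    # noise_std = 1.0 reproduces the observational noise; > 1 over-noises.
    return values + noise_std * sigma * rng.standard_normal(values.shape)
```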
- fit(dataset: Any, *, elbo_log_likelihood_fn: Callable[[jax.numpy.ndarray, Any], jax.numpy.ndarray] | None = None, elbo_num_samples: int = 1) petitRADTRANS.sbi.posterior_base.TrainingArtifacts#
Train the posterior on one normalized simulation dataset reader.
Parameters#
- dataset:
Reader-like object that yields PosteriorBatch instances via iter_batches and exposes dataset manifest metadata when available.
- elbo_log_likelihood_fn:
Required when training_objective='elbo'. A differentiable (theta_cube, observations) -> (batch,) callable returning the Gaussian observational log-likelihood through the forward model. See _make_elbo_loss(). Ignored for the default NPE objective.
- elbo_num_samples:
Number of reparameterized posterior draws per observation for the ELBO estimator (default 1).
Returns#
- TrainingArtifacts
Training history, validation metrics, and trainer metadata for the completed optimization run.
Notes#
If the reader exposes preprocessing metadata or a manifest fingerprint, those are cached on the posterior so they can be saved and reused by the inference and artifact layers.
- _build_estimator_config() dict[str, Any]#
Return backend-specific configuration for metadata persistence.
- _build_serialized_metadata(artifact_metadata) dict[str, Any]#
- static _resolve_estimator_config(metadata: Mapping[str, Any]) dict[str, Any]#
- hydrate_loaded_metadata(metadata: Mapping[str, Any]) None#
- classmethod from_serialized_metadata(metadata: Mapping[str, Any]) ConditionalFlowPosterior#
Rebuild an estimator instance from persisted metadata only.
- save_backend_state(output_path: pathlib.Path) None#
Persist backend-specific model state into the output directory.
- load_backend_state(input_path: pathlib.Path) None#
Restore backend-specific model state from the input directory.
- encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) petitRADTRANS.sbi.observation.EncodedObservation#
Encode one structured observation into the posterior context space.
Parameters#
- blocks:
Observation blocks describing one spectral/photometric observation.
Returns#
- EncodedObservation
Aggregated embedding and lightweight metadata describing the input observation family.
- batch_encode_observation(blocks_list: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) list[petitRADTRANS.sbi.observation.EncodedObservation]#
Encode a batch of observations using a single vmapped forward pass.
Parameters#
- blocks_list:
List of per-sample observation block lists.
Returns#
- list[EncodedObservation]
One encoded observation per input sample.
- sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) petitRADTRANS.sbi.posterior_base.PosteriorSamples#
Draw posterior samples conditioned on one encoded observation.
Parameters#
- observation:
Encoded observation produced by encode_observation().
- n_samples:
Number of posterior draws to generate.
- seed:
Optional random seed overriding the model-level default seed.
Returns#
- PosteriorSamples
Samples in the posterior’s configured parameter space. Non-finite outputs are clamped to large finite values for downstream stability.
- batch_sample_posterior(embeddings: Any, n_samples: int, base_seed: int = 0) numpy.ndarray#
Draw posterior samples for a batch of encoded observations.
Runs a single JIT-compiled vmapped call over all contexts rather than looping sample_posterior once per observation, eliminating the Python overhead of per-observation dispatch.
Parameters#
- embeddings:
Float32 array of shape (batch_size, embedding_dim) produced by stacking EncodedObservation.embedding vectors.
- n_samples:
Number of posterior draws per observation.
- base_seed:
Base random seed. Each observation in the batch receives a unique sub-key derived from this seed.
Returns#
- np.ndarray
Array of shape (batch_size, n_samples, parameter_dim) with non-finite values clamped to large finite values.
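Two details of the sampling contract can be sketched in NumPy. First, each observation in a batch draws from its own independent stream derived from base_seed. Second, non-finite outputs are clamped to large finite values. The sketch uses numpy.random.SeedSequence.spawn in place of JAX key splitting and a Python loop in place of the single vmapped call. It also uses a standard-normal stand-in for the flow, a clamp magnitude of 1e6, and maps NaN to 0. All of these are assumptions, not the library's values.

```python
import numpy as np

CLAMP = 1e6  # assumed clamp magnitude; not the library's constant


def clamp_non_finite(samples: np.ndarray) -> np.ndarray:
    return np.nan_to_num(samples, nan=0.0, posinf=CLAMP, neginf=-CLAMP)


def batch_sample(embeddings, n_samples: int, parameter_dim: int, base_seed: int = 0) -> np.ndarray:
    """Return draws of shape (batch_size, n_samples, parameter_dim)."""
    embeddings = np.asarray(embeddings, dtype=np.float32)
    children = np.random.SeedSequence(base_seed).spawn(embeddings.shape[0])
    out = np.empty((embeddings.shape[0], n_samples, parameter_dim))
    for i, (context, seq) in enumerate(zip(embeddings, children)):
        rng = np.random.default_rng(seq)  # unique stream per observation
        # Stand-in for the flow: standard-normal draws shifted by the context mean.
        out[i] = context.mean() + rng.standard_normal((n_samples, parameter_dim))
    return clamp_non_finite(out)
```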
- log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) Any#
Evaluate posterior log-density for one or many parameter vectors.
Parameters#
- observation:
Encoded observation that defines the posterior context.
- parameters:
One parameter vector or a batch of parameter vectors in the posterior’s configured parameter space.
Returns#
- Any
Scalar log-density or a vector of log-densities matching the input batch structure.
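Accepting either one parameter vector or a batch, and returning a scalar or a vector to match, is a small shape-dispatch pattern. The minimal sketch below uses a diagonal Gaussian as a hypothetical stand-in for the flow density.

```python
from __future__ import annotations

import numpy as np


def log_prob(mean: np.ndarray, std: np.ndarray, parameters) -> float | np.ndarray:
    params = np.asarray(parameters, dtype=np.float64)
    single = params.ndim == 1
    batch = np.atleast_2d(params)  # (n, parameter_dim)
    z = (batch - mean) / std
    lp = -0.5 * np.sum(z**2 + np.log(2.0 * np.pi * std**2), axis=1)
    # Scalar for one vector, vector of log-densities for a batch.
    return float(lp[0]) if single else lp
```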
- class petitRADTRANS.sbi.ConditionalNeuralAutoregressiveFlowPosterior(*args: Any, **kwargs: Any)#
Bases:
ConditionalFlowPosterior
Posterior estimator specialized to the neural autoregressive backend.
- class petitRADTRANS.sbi.ConditionalSplineFlowPosterior(*args: Any, **kwargs: Any)#
Bases:
ConditionalFlowPosterior
Posterior estimator specialized to the spline flow backend.
Notes#
This convenience subclass forces flow_family='spline' while keeping the rest of the ConditionalFlowPosterior API unchanged.
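Pinning one constructor argument while forwarding everything else unchanged can be sketched as below. BaseFlowPosterior is a hypothetical stand-in for ConditionalFlowPosterior. The sketch silently overrides any flow_family the caller passes. That is one reading of "forces"; the real subclass may handle a conflicting value differently.

```python
from typing import Any


class BaseFlowPosterior:
    """Hypothetical stand-in for ConditionalFlowPosterior."""

    def __init__(self, parameter_dim: int, flow_family: str = "spline", **kwargs: Any) -> None:
        self.parameter_dim = parameter_dim
        self.flow_family = flow_family
        self.options = kwargs


class SplineOnlyPosterior(BaseFlowPosterior):
    FORCED_FAMILY = "spline"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        kwargs["flow_family"] = self.FORCED_FAMILY  # pin the backend
        super().__init__(*args, **kwargs)           # forward everything else unchanged
```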
- class petitRADTRANS.sbi.FlowMatchingPosterior(parameter_dim: int, embedding_dim: int = 128, hidden_dim: int = 128, num_velocity_layers: int = 3, learning_rate: float = 0.001, batch_size: int = 32, num_epochs: int = 5, parameter_space: str = 'unconstrained', integration_steps: int = 32, early_stopping_patience: int | None = None, early_stopping_min_delta: float = 0.0, checkpoint_directory: str | None = None, checkpoint_backend: str = 'auto', resume_from_checkpoint: bool = False, seed: int = 0, task_metadata: Mapping[str, Any] | None = None)#
Bases:
petitRADTRANS.sbi.posterior_base.PersistentPosteriorEstimator
Conditional flow-matching posterior skeleton.
Warning
This estimator is experimental. It does not expose log_prob, and the ODE integration scheme is a simple midpoint rule. Expect the API and numerical behaviour to change in future releases.
This estimator family trains a conditional vector field on straight-line interpolation paths between Gaussian noise and target parameters, then generates posterior samples by integrating the learned field from noise to the terminal parameter state.
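Both halves of the description can be sketched in NumPy. Training regresses a velocity model onto the straight-line target theta - noise at x_t = (1 - t) * noise + t * theta. Sampling integrates dx/dt = v(x, t) from t = 0 to t = 1 with the midpoint rule. The velocity argument is a hypothetical callable. In the real estimator it is a network that is also conditioned on the observation embedding.

```python
from __future__ import annotations

from collections.abc import Callable

import numpy as np


def flow_matching_loss(velocity: Callable[..., np.ndarray], theta: np.ndarray,
                       noise: np.ndarray, t: np.ndarray) -> float:
    """Conditional flow-matching loss on straight-line paths."""
    t = np.asarray(t, dtype=np.float64)
    x_t = (1.0 - t)[:, None] * noise + t[:, None] * theta
    target = theta - noise  # constant velocity of the straight path
    return float(np.mean(np.sum((velocity(x_t, t) - target) ** 2, axis=1)))


def integrate_midpoint(velocity: Callable[..., np.ndarray], x0: np.ndarray, steps: int) -> np.ndarray:
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with the midpoint rule."""
    x, h = np.asarray(x0, dtype=np.float64), 1.0 / steps
    for k in range(steps):
        t = k * h
        x_mid = x + 0.5 * h * velocity(x, t)         # half step to the midpoint
        x = x + h * velocity(x_mid, t + 0.5 * h)     # full step using the midpoint slope
    return x
```

The midpoint rule never evaluates the field at t = 1, where the conditional field (theta - x) / (1 - t) is singular.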
- estimator_family = 'flow_matching'#
- embedding_dim = 128#
- num_velocity_layers = 3#
- learning_rate#
- batch_size = 32#
- num_epochs = 5#
- integration_steps = 32#
- early_stopping_patience = None#
- early_stopping_min_delta#
- checkpoint_directory = None#
- checkpoint_backend = 'auto'#
- resume_from_checkpoint = False#
- model#
- static _batch_embeddings(model: _FlowMatchingModel, observations: Any) jax.numpy.ndarray#
- static _loss(model: _FlowMatchingModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) jax.numpy.ndarray#
- fit(dataset: Any) petitRADTRANS.sbi.posterior_base.TrainingArtifacts#
Train the posterior estimator on a simulation dataset.
- _build_estimator_config() dict[str, Any]#
Return backend-specific configuration for metadata persistence.
- static _resolve_estimator_config(metadata: Mapping[str, Any]) dict[str, Any]#
- _build_serialized_metadata(artifact_metadata) dict[str, Any]#
- classmethod from_serialized_metadata(metadata: Mapping[str, Any]) FlowMatchingPosterior#
Rebuild an estimator instance from persisted metadata only.
- save_backend_state(output_path: pathlib.Path) None#
Persist backend-specific model state into the output directory.
- load_backend_state(input_path: pathlib.Path) None#
Restore backend-specific model state from the input directory.
- encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) petitRADTRANS.sbi.observation.EncodedObservation#
Encode a structured observation into the estimator input space.
- batch_encode_observation(blocks_list: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) list[petitRADTRANS.sbi.observation.EncodedObservation]#
Encode a batch of observations using a single vmapped forward pass.
- sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) petitRADTRANS.sbi.posterior_base.PosteriorSamples#
Sample the amortized posterior for one encoded observation.
- log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) Any#
Evaluate posterior log-density when supported by the backend.
- class petitRADTRANS.sbi.PersistentPosteriorEstimator(parameter_dim: int, parameter_space: str = 'unconstrained', seed: int = 0, task_metadata: Mapping[str, Any] | None = None)#
Bases:
PosteriorEstimator
Shared persistence helper for estimator backends with on-disk artifacts.
- estimator_family = 'persistent_estimator'#
- metadata_schema_version = '0.2.0'#
- parameter_dim#
- parameter_space = 'unconstrained'#
- seed = 0#
- task_metadata#
- training_artifacts: TrainingArtifacts | None = None#
- task_name: str | None#
- task_version: str | None = None#
- task_fingerprint: str | None = None#
- observation_schema: Mapping[str, Any]#
- preprocessing_metadata_payload: Mapping[str, Any]#
- artifact_metadata: petitRADTRANS.sbi.artifacts.ArtifactMetadata | None = None#
- abstractmethod _build_estimator_config() dict[str, Any]#
Return backend-specific configuration for metadata persistence.
- classmethod from_serialized_metadata(metadata: Mapping[str, Any]) PersistentPosteriorEstimator#
Abstract class method: rebuild an estimator instance from persisted metadata only.
- abstractmethod save_backend_state(output_path: pathlib.Path) None#
Persist backend-specific model state into the output directory.
- abstractmethod load_backend_state(input_path: pathlib.Path) None#
Restore backend-specific model state from the input directory.
- static _load_training_artifacts(metadata: Mapping[str, Any]) TrainingArtifacts | None#
- hydrate_loaded_metadata(metadata: Mapping[str, Any]) None#
- _build_artifact_metadata_payload() dict[str, Any]#
- build_artifact_metadata(version: str) petitRADTRANS.sbi.artifacts.ArtifactMetadata#
Assemble registry metadata for the currently trained estimator.
- _build_serialized_metadata(artifact_metadata: petitRADTRANS.sbi.artifacts.ArtifactMetadata) dict[str, Any]#
- save(output_directory: str, artifact_registry: petitRADTRANS.sbi.artifacts.ArtifactRegistry | None = None, artifact_version: str = '0.1.0') None#
Persist model weights, metadata, and optional artifact registration.
- classmethod load(input_directory: str) PersistentPosteriorEstimator#
Restore a saved persistent estimator from disk.
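The split between metadata and backend state is what lets load_posterior_estimator restore a model without the caller naming its class. The flow has three steps. The metadata records the estimator family, a registry maps that family to a class, and the class rebuilds itself from metadata before loading its weights. A minimal sketch follows. The file names (metadata.json, state.json), the REGISTRY dict, and the toy estimator that stores its "weights" as JSON are all hypothetical.

```python
from __future__ import annotations

import json
import pathlib
from collections.abc import Mapping
from typing import Any

REGISTRY: dict[str, type] = {}


def register(cls: type) -> type:
    REGISTRY[cls.estimator_family] = cls
    return cls


@register
class ToyEstimator:
    estimator_family = "toy"

    def __init__(self, parameter_dim: int) -> None:
        self.parameter_dim = parameter_dim
        self.weights = [0.0] * parameter_dim

    def save(self, output_directory: str) -> None:
        # Template: shared metadata first, then backend-specific state.
        path = pathlib.Path(output_directory)
        path.mkdir(parents=True, exist_ok=True)
        meta = {"estimator_family": self.estimator_family, "parameter_dim": self.parameter_dim}
        (path / "metadata.json").write_text(json.dumps(meta))
        self.save_backend_state(path)

    @classmethod
    def from_serialized_metadata(cls, metadata: Mapping[str, Any]) -> "ToyEstimator":
        return cls(parameter_dim=metadata["parameter_dim"])

    def save_backend_state(self, output_path: pathlib.Path) -> None:
        (output_path / "state.json").write_text(json.dumps(self.weights))

    def load_backend_state(self, input_path: pathlib.Path) -> None:
        self.weights = json.loads((input_path / "state.json").read_text())


def load_estimator(input_directory: str):
    """Restore any registered estimator without naming its concrete class."""
    path = pathlib.Path(input_directory)
    metadata = json.loads((path / "metadata.json").read_text())
    estimator = REGISTRY[metadata["estimator_family"]].from_serialized_metadata(metadata)
    estimator.load_backend_state(path)
    return estimator
```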
- class petitRADTRANS.sbi.PosteriorBatch#
Training batch passed to amortized posterior estimators.
- parameters: Any#
- observations: Any#
- metadata: Mapping[str, Any]#
- class petitRADTRANS.sbi.PosteriorEstimator#
Bases:
abc.ABC
Backend-agnostic interface for amortized posterior models.
- abstractmethod fit(dataset: Any) TrainingArtifacts#
Train the posterior estimator on a simulation dataset.
- abstractmethod encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) petitRADTRANS.sbi.observation.EncodedObservation#
Encode a structured observation into the estimator input space.
- abstractmethod sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) PosteriorSamples#
Sample the amortized posterior for one encoded observation.
- abstractmethod log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) Any#
Evaluate posterior log-density when supported by the backend.
- abstractmethod save(output_directory: str) None#
Persist trained model weights and metadata.
- classmethod load(input_directory: str) PosteriorEstimator#
Abstract class method: restore a saved estimator from disk.
- class petitRADTRANS.sbi.PosteriorSamples#
Posterior samples and optional per-sample diagnostics.
- samples: Any#
- log_probabilities: Any = None#
- weights: Any = None#
- metadata: Mapping[str, Any]#
- petitRADTRANS.sbi.plot_local_sensitivity_fisher_correlations(report: petitRADTRANS.sbi.calibration.LocalSensitivityReport, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#
Plot Fisher-correlation heatmaps for each representative point.
- petitRADTRANS.sbi.plot_local_sensitivity_jacobians(report: petitRADTRANS.sbi.calibration.LocalSensitivityReport, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#
Plot whitened Jacobian heatmaps for each representative point.
- petitRADTRANS.sbi.plot_local_sensitivity_singular_values(report: petitRADTRANS.sbi.calibration.LocalSensitivityReport, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#
Plot singular spectra of the whitened Jacobian for each point.
- petitRADTRANS.sbi.plot_posterior_corner(samples: Any, parameter_names: Sequence[str] | None = None, bins: int = 40, figsize: tuple[float, float] | None = None, max_points: int = 8192) tuple[Any, numpy.ndarray]#
Plot a lower-triangular corner view of posterior structure.
Parameters#
- samples:
Posterior samples with shape (n_samples, n_dim) or a scalar-vector equivalent.
- parameter_names:
Optional display names for each posterior dimension.
- bins:
Number of bins used for one- and two-dimensional histograms.
- figsize:
Optional Matplotlib figure size.
- max_points:
Maximum number of posterior draws plotted. Larger sample sets are evenly subsampled to keep rendering costs bounded.
Returns#
- tuple[Any, np.ndarray]
Matplotlib figure and axes array for further customization or saving.
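The max_points behavior ("evenly subsampled") can be pictured as picking evenly spaced row indices from the first to the last draw. The sketch below assumes that rule and treats a one-dimensional input as a single parameter. The helper name evenly_subsample is hypothetical, and the plotting function's actual selection rule may differ.

```python
import numpy as np


def evenly_subsample(samples, max_points=8192):
    """Return at most max_points rows of samples, spaced evenly across the array."""
    samples = np.asarray(samples)
    if samples.ndim == 1:
        # Treat a scalar-vector input as (n_samples, 1).
        samples = samples[:, None]
    n_samples = samples.shape[0]
    if n_samples <= max_points:
        return samples
    # Evenly spaced integer indices from the first to the last draw (inclusive).
    indices = np.linspace(0, n_samples - 1, num=max_points).round().astype(int)
    return samples[indices]
```

Because the spacing exceeds one row whenever subsampling kicks in, the rounded indices never repeat, so the cost of rendering stays bounded by max_points.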
- petitRADTRANS.sbi.plot_posterior_marginals(samples: Any, parameter_names: Sequence[str] | None = None, bins: int = 40, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#
Plot one histogram per posterior dimension.
Parameters#
- samples:
Posterior samples with shape (n_samples, n_dim) or a scalar-vector equivalent.
- parameter_names:
Optional display names for each posterior dimension.
- bins:
Number of histogram bins per dimension.
- figsize:
Optional Matplotlib figure size.
Returns#
- tuple[Any, np.ndarray]
Matplotlib figure and axes array for further customization or saving.
- petitRADTRANS.sbi.plot_posterior_predictive_report(report: petitRADTRANS.sbi.calibration.PosteriorPredictiveReport, dataset_names: Sequence[str] | None = None, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#
Plot observed values against posterior-predictive means and intervals.
Parameters#
- report:
Posterior-predictive report to visualize.
- dataset_names:
Optional subset of dataset names to plot. Defaults to all datasets in the report.
- figsize:
Optional Matplotlib figure size.
Returns#
- tuple[Any, np.ndarray]
Matplotlib figure and axes array.
Notes#
For batched reports the helper currently visualizes only the first case for each dataset, which keeps the plot compact for quick inspection.
- petitRADTRANS.sbi.plot_sbc_rank_histograms(report: petitRADTRANS.sbi.calibration.SimulationBasedCalibrationReport, parameter_names: Sequence[str] | None = None, figsize: tuple[float, float] | None = None) tuple[Any, numpy.ndarray]#
Plot one SBC rank histogram per inferred parameter.
Parameters#
- report:
SBC report containing per-parameter rank histogram counts.
- parameter_names:
Optional display names for each inferred parameter.
- figsize:
Optional Matplotlib figure size.
Returns#
- tuple[Any, np.ndarray]
Matplotlib figure and axes array.
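For orientation, an SBC rank is commonly defined as the number of posterior draws that fall below the true parameter value. Under a calibrated posterior these ranks are uniform on {0, ..., n_draws}, so each histogram should be flat. A minimal NumPy sketch of computing per-parameter rank histogram counts like those the report holds (the function name and the binning are illustrative assumptions, not the SimulationBasedCalibrationReport implementation):

```python
import numpy as np


def sbc_rank_histograms(true_parameters, posterior_draws, n_bins=10):
    """Compute per-parameter SBC rank histogram counts.

    true_parameters: (n_cases, n_dim) parameters used to simulate each held-out case.
    posterior_draws: (n_cases, n_draws, n_dim) posterior samples for each case.
    Returns an array of shape (n_dim, n_bins) with rank counts.
    """
    true_parameters = np.asarray(true_parameters)
    posterior_draws = np.asarray(posterior_draws)
    n_draws = posterior_draws.shape[1]
    # Rank = number of draws strictly below the true value, in [0, n_draws].
    ranks = np.sum(posterior_draws < true_parameters[:, None, :], axis=1)
    edges = np.linspace(0, n_draws + 1, n_bins + 1)
    return np.stack(
        [np.histogram(ranks[:, j], bins=edges)[0] for j in range(ranks.shape[1])]
    )
```

A U-shaped histogram usually signals an overconfident posterior and a central hump an underconfident one.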
- class petitRADTRANS.sbi.TaskPreprocessingMetadata#
Serializable preprocessing metadata for an SBI task family.
- version: str#
- blocks: Mapping[str, BlockNormalizationStats]#
- metadata: Mapping[str, Any]#
- to_payload() dict[str, Any]#
- classmethod from_payload(payload: Mapping[str, Any]) TaskPreprocessingMetadata#
- petitRADTRANS.sbi.fit_task_preprocessing(training_observations: list[list[petitRADTRANS.sbi.observation.ObservationBlock]], version: str = '0.5.0', metadata: Mapping[str, Any] | None = None) TaskPreprocessingMetadata#
Fit preprocessing statistics from training observation blocks.
- petitRADTRANS.sbi.normalize_observation_block(block: petitRADTRANS.sbi.observation.ObservationBlock, preprocessing_metadata: TaskPreprocessingMetadata) petitRADTRANS.sbi.observation.ObservationBlock#
Normalize one observation block with fitted preprocessing statistics.
Uses robust (median/IQR) normalization for values and uncertainties when available, falling back to mean/std for backward compatibility with older preprocessing metadata. Coordinates always use mean/std normalization.
- petitRADTRANS.sbi.normalize_observation_blocks(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock], preprocessing_metadata: TaskPreprocessingMetadata) list[petitRADTRANS.sbi.observation.ObservationBlock]#
Normalize a list of observation blocks.
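The value-normalization rule above (median/IQR when available, mean/std otherwise) can be sketched as follows. This assumes per-block statistics stored in a plain dict. The keys median, iqr, mean, and std and both helper names are hypothetical and are not the BlockNormalizationStats fields.

```python
import numpy as np


def fit_block_stats(training_values):
    """Fit robust and classical statistics from stacked training values."""
    values = np.asarray(training_values, dtype=float).ravel()
    q25, q50, q75 = np.percentile(values, [25, 50, 75])
    return {"median": q50, "iqr": q75 - q25, "mean": values.mean(), "std": values.std()}


def normalize_values(values, stats, eps=1e-12):
    """Median/IQR normalization, falling back to mean/std for older metadata."""
    values = np.asarray(values, dtype=float)
    if stats.get("median") is not None and stats.get("iqr") is not None:
        return (values - stats["median"]) / max(stats["iqr"], eps)
    return (values - stats["mean"]) / max(stats["std"], eps)
```

The median/IQR pair is less sensitive than mean/std to a handful of extreme simulated spectra, which is a common motivation for robust scaling of simulation corpora.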
- class petitRADTRANS.sbi.ProposalSampler#
Bases: abc.ABC
Interface for simulation proposals beyond the prior distribution.
- abstractmethod sample(n_samples: int, task: petitRADTRANS.sbi.task.SBITask) Any#
Draw free-parameter vectors for the given task.
- class petitRADTRANS.sbi.RuntimeSimulator(task: petitRADTRANS.sbi.task.SBITask, runtime: Any | None = None, seed: int | None = None, data_parallel: bool | None = None)#
Bases: BatchedSimulator
Concrete simulator backed by the retrieval runtime.
The simulator samples from the retrieval prior using JAX random utilities and projects deterministic forward-model outputs into the observation space using RetrievalRuntime.
- _rng_key#
- _validate_task_support() None#
- _next_key() jax.Array#
- static _row_invalid_value_mask(values: Any, mask: Any, constraint: Any) numpy.ndarray#
- _invalid_sample_mask(observations: Mapping[str, Mapping[str, Any]], n_samples: int) tuple[numpy.ndarray, dict[str, int]]#
- static _slice_batch_rows(batch: SimulationBatch, row_indices: numpy.ndarray) SimulationBatch#
- static _concatenate_simulation_batches(batches: list[SimulationBatch], diagnostics: Mapping[str, Any] | None = None) SimulationBatch#
- _sample_prior_parameter_batch(n_samples: int) dict[str, jax.numpy.ndarray]#
- sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) Any#
Sample free-parameter vectors from the task prior or a proposal.
- _advance_rng_for_batch(n_samples: int) None#
Advance the PRNG state as if simulate_batch(n_samples) were called.
This performs only cheap key splitting (no forward-model evaluation), which makes it suitable for quickly skipping already generated batches when resuming dataset generation.
The number of _next_key calls matches the pattern inside simulate_batch: one key in _sample_prior_parameter_batch and one key per noise-adding observation in _simulate_parameter_matrix.
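The resume trick works because JAX keys are advanced purely by splitting, so replaying the same number of splits reproduces the PRNG state without running the forward model. A minimal sketch, assuming a toy simulator that uses one key for the prior draw and one key per noisy observation (ToyKeyedSimulator, its toy forward model, and the n_noisy_observations argument are hypothetical):

```python
import jax
import jax.numpy as jnp


class ToyKeyedSimulator:
    """Hypothetical stand-in that consumes keys the way simulate_batch is described to."""

    def __init__(self, seed, n_noisy_observations):
        self._rng_key = jax.random.PRNGKey(seed)
        self.n_noisy_observations = n_noisy_observations

    def _next_key(self):
        # Split the carried key; keep one half as state, hand out the other.
        self._rng_key, subkey = jax.random.split(self._rng_key)
        return subkey

    def simulate_batch(self, n_samples):
        # One key for the prior draw ...
        theta = jax.random.normal(self._next_key(), (n_samples,))
        # ... and one key per observation that adds noise.
        outputs = []
        for _ in range(self.n_noisy_observations):
            noise = jax.random.normal(self._next_key(), (n_samples,))
            outputs.append(2.0 * theta + noise)  # toy forward model plus noise
        return theta, outputs

    def advance_rng_for_batch(self, n_samples):
        # Same number of _next_key calls as simulate_batch, but no sampling or model work.
        for _ in range(1 + self.n_noisy_observations):
            self._next_key()
```

If a real implementation ever changes how many keys simulate_batch consumes, the skip method must change with it, otherwise resumed datasets silently diverge from uninterrupted ones.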
- _apply_noise(values: Any, uncertainties: Any, covariance: Any) tuple[Any, Any, Any]#
- _apply_noise_batched(values_batch: Any, uncertainties_batch: Any, covariance_batch: Any) tuple[Any, Any, Any]#
Apply noise to a batch of simulation outputs.
- Args:
values_batch: Shape (n_samples, n_wavelengths).
uncertainties_batch: Shape (n_samples, n_wavelengths) or None.
covariance_batch: Shape (n_samples, n_wavelengths, n_wavelengths) or None.
- Returns:
Noisy values, uncertainties, and covariance with the same leading batch dimension as values_batch.
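A NumPy sketch of a batched noise step with these shapes, assuming independent Gaussian noise scaled by the per-point uncertainties and, when a covariance is given, correlated Gaussian noise drawn through a batched Cholesky factor. The function name, the rng argument, and giving the covariance precedence over the uncertainties are assumptions, not the RuntimeSimulator implementation.

```python
import numpy as np


def apply_noise_batched(values_batch, uncertainties_batch=None, covariance_batch=None, rng=None):
    """Add Gaussian noise to a batch of simulated spectra.

    values_batch: (n_samples, n_wavelengths)
    uncertainties_batch: (n_samples, n_wavelengths) or None
    covariance_batch: (n_samples, n_wavelengths, n_wavelengths) or None
    """
    rng = np.random.default_rng() if rng is None else rng
    values_batch = np.asarray(values_batch, dtype=float)
    z = rng.standard_normal(values_batch.shape)
    if covariance_batch is not None:
        # Batched Cholesky factors L with L @ L.T == covariance for each sample.
        chol = np.linalg.cholesky(np.asarray(covariance_batch, dtype=float))
        noise = np.einsum("nij,nj->ni", chol, z)
    elif uncertainties_batch is not None:
        noise = np.asarray(uncertainties_batch, dtype=float) * z
    else:
        noise = np.zeros_like(values_batch)
    return values_batch + noise, uncertainties_batch, covariance_batch
```

np.linalg.cholesky accepts stacked matrices, so one call factorizes every per-sample covariance in the batch.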
- static _stack_observation_payloads(payloads: list[dict[str, Any]]) dict[str, Any]#
- _simulate_parameter_matrix(parameter_matrix: Any, cube_parameters: Any = None, unconstrained_parameters: Any = None, data_parallel: bool = False) SimulationBatch#
- simulate_from_parameters(parameters: Any) SimulationBatch#
Run the forward model and noise pipeline for pre-specified parameters.
- simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) SimulationBatch#
Sample parameters and preserve prior-space coordinates when available.
- class petitRADTRANS.sbi.SimulationBatch#
Container for one simulated batch.
- Attributes:
- parameters:
Array-like free-parameter matrix with leading dimension equal to the number of simulated samples.
- observations:
Task-conditioned simulated observations.
- log_likelihood:
Optional scalar likelihood values associated with each sample.
- diagnostics:
Additional runtime diagnostics such as clipping flags or NaN counts.
- parameters: Any#
- observations: Any#
- cube_parameters: Any = None#
- unconstrained_parameters: Any = None#
- log_likelihood: Any = None#
- diagnostics: Mapping[str, Any]#
- property n_samples: int#
Return the number of simulated samples represented by the batch.
- class petitRADTRANS.sbi.Simulator(task: petitRADTRANS.sbi.task.SBITask)#
Bases: abc.ABC
Base simulator for SBI dataset generation and validation.
- task#
- abstractmethod sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) Any#
Sample free-parameter vectors from the task prior or a proposal.
- abstractmethod simulate_from_parameters(parameters: Any) SimulationBatch#
Run the forward model and noise pipeline for pre-specified parameters.
- simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) SimulationBatch#
Sample parameters and simulate one batch in a single call.
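The base class composes its two abstract steps in simulate_batch. A self-contained sketch of that composition follows, assuming a toy linear forward model, a uniform prior, and plain dicts in place of SimulationBatch. ToySimulatorBase, ToySimulator, GaussianProposal, and the batch dict keys are hypothetical. The toy classes also skip the SBITask argument that the real Simulator constructor and ProposalSampler.sample take.

```python
from abc import ABC, abstractmethod

import numpy as np


class ToySimulatorBase(ABC):
    """Mirrors the Simulator contract: sample parameters, then simulate them."""

    @abstractmethod
    def sample_parameters(self, n_samples, proposal=None): ...

    @abstractmethod
    def simulate_from_parameters(self, parameters): ...

    def simulate_batch(self, n_samples, proposal=None):
        # Draw from the prior (or the proposal), then run the forward model on those draws.
        return self.simulate_from_parameters(self.sample_parameters(n_samples, proposal))


class GaussianProposal:
    """Hypothetical proposal: Gaussian around a reference parameter vector."""

    def __init__(self, center, scale, seed=0):
        self.center = np.asarray(center, dtype=float)
        self.scale = scale
        self.rng = np.random.default_rng(seed)

    def sample(self, n_samples, task=None):
        return self.center + self.scale * self.rng.standard_normal((n_samples, self.center.size))


class ToySimulator(ToySimulatorBase):
    def __init__(self, n_wavelengths=5, noise_sigma=0.01, seed=0):
        self.grid = np.linspace(0.0, 1.0, n_wavelengths)
        self.noise_sigma = noise_sigma
        self.rng = np.random.default_rng(seed)

    def sample_parameters(self, n_samples, proposal=None):
        if proposal is not None:
            return proposal.sample(n_samples, task=None)
        # Uniform prior on [0, 1]^2 for (offset, slope).
        return self.rng.uniform(0.0, 1.0, size=(n_samples, 2))

    def simulate_from_parameters(self, parameters):
        parameters = np.atleast_2d(parameters)
        clean = parameters[:, :1] + parameters[:, 1:2] * self.grid  # (n, n_wavelengths)
        noisy = clean + self.noise_sigma * self.rng.standard_normal(clean.shape)
        return {"parameters": parameters, "observations": noisy}
```

Keeping simulate_from_parameters separate is what allows re-simulating fixed parameter sets (for example held-out validation points) without going through the sampler.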
- class petitRADTRANS.sbi.NoiseModelConfig#
Describe how observational noise is injected during simulation.
- Attributes:
- mode:
Short identifier of the noise model implementation.
- parameters:
Backend-specific parameters used to instantiate the noise model.
- seed:
Optional deterministic seed for repeatable simulation pipelines.
- mode: str = 'observational'#
- parameters: Mapping[str, Any]#
- seed: int | None = None#
- class petitRADTRANS.sbi.ObservationSchema#
Capture the supported observation family for an amortized task.
- Attributes:
- dataset_names:
Dataset identifiers included in the task.
- modalities:
Mapping from dataset name to a short modality label such as 'spectrum', 'photometry', or 'time_series'.
- metadata:
Additional task-level information required by encoders or benchmarks, such as instrument names or wavelength coverage.
- dataset_names: tuple[str, ...]#
- modalities: Mapping[str, str]#
- metadata: Mapping[str, Any]#
- class petitRADTRANS.sbi.ObservationValueConstraint#
Admissible range for simulated observation values.
- Attributes:
- min_value:
Optional lower bound for valid simulated values.
- max_value:
Optional upper bound for valid simulated values.
- min_inclusive:
Whether min_value is included in the valid interval.
- max_inclusive:
Whether max_value is included in the valid interval.
- min_value: float | None = None#
- max_value: float | None = None#
- min_inclusive: bool = True#
- max_inclusive: bool = True#
- to_payload() dict[str, Any]#
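How the inclusive flags translate into a validity test can be sketched with NumPy. The dataclass below is a stand-in with the same fields, not the library class, and treating non-finite values as always invalid is an added assumption.

```python
from __future__ import annotations

from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ValueConstraint:
    """Stand-in with the same fields as ObservationValueConstraint."""

    min_value: float | None = None
    max_value: float | None = None
    min_inclusive: bool = True
    max_inclusive: bool = True

    def invalid_mask(self, values):
        """Return True wherever a value falls outside the admissible interval."""
        values = np.asarray(values, dtype=float)
        invalid = ~np.isfinite(values)  # assumption: NaN/inf are always rejected
        if self.min_value is not None:
            below = values < self.min_value if self.min_inclusive else values <= self.min_value
            invalid |= below
        if self.max_value is not None:
            above = values > self.max_value if self.max_inclusive else values >= self.max_value
            invalid |= above
        return invalid
```

For a (n_samples, n_points) batch, calling .any(axis=-1) on the mask gives one rejection flag per simulated sample.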
- class petitRADTRANS.sbi.SBITask#
Bundle the immutable ingredients required for an SBI problem.
The task owns the retrieval configuration, canonical parameter layout, observation schema, and simulation policy needed to generate training data and evaluate amortized posteriors.
- name: str#
- retrieval_config: petitRADTRANS.retrieval.retrieval_config.RetrievalConfig#
- parameter_layout: petitRADTRANS.retrieval.runtime.ParameterLayout#
- observation_schema: ObservationSchema#
- simulation_config: SimulationConfig#
- task_version: str = '0.1.0'#
- preprocessing_version: str = '0.5.0'#
- forward_model_family: str = 'unknown'#
- petitradtrans_version: str#
- metadata: Mapping[str, Any]#
- classmethod from_retrieval_config(retrieval_config: petitRADTRANS.retrieval.retrieval_config.RetrievalConfig, simulation_config: SimulationConfig | None = None, metadata: Mapping[str, Any] | None = None) SBITask#
Create an SBI task from an existing retrieval configuration.
Parameters#
- retrieval_config:
Retrieval configuration describing the free parameters, observation datasets, and runtime-native forward model family.
- simulation_config:
Optional simulation policy overriding the default prior-predictive batch configuration.
- metadata:
Optional task-level metadata. Reserved keys such as task_version, preprocessing_version, and forward_model_family override the default inferred values.
Returns#
- SBITask
Immutable SBI task description with parameter layout, observation schema, and deterministic task fingerprint payload.
Notes#
The task inspects the configured retrieval datasets to infer modalities and a forward-model-family string that later participates in artifact compatibility checks.
- property parameter_names: tuple[str, ...]#
Return the free parameter names in canonical task order.
- property line_opacity_modes: Mapping[str, str]#
Return the line-opacity mode for each configured dataset.
- property fingerprint_payload: Mapping[str, Any]#
Return the deterministic payload used to fingerprint the task family.
- property task_fingerprint: petitRADTRANS.sbi.compatibility.TaskFingerprint#
Return the deterministic fingerprint of the task family.
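A deterministic fingerprint is typically a stable hash of a canonical serialization. The sketch below assumes canonical JSON (sorted keys, fixed separators) hashed with SHA-256. The payload keys and parameter names are hypothetical, and the actual TaskFingerprint construction may differ.

```python
import hashlib
import json


def fingerprint(payload):
    """Hash a JSON-serializable payload independently of dict key insertion order."""
    canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


# Hypothetical payloads: same content, different key insertion order.
payload_a = {"task_version": "0.1.0", "parameter_names": ["T_int", "log_g"]}
payload_b = {"parameter_names": ["T_int", "log_g"], "task_version": "0.1.0"}
```

Sorting keys removes dict-ordering noise, while list order is preserved, which is appropriate here because the canonical parameter order is itself part of the task definition.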
- validate_observation_schema(observation_schema: ObservationSchema) petitRADTRANS.sbi.compatibility.CompatibilityReport#
Compare an external observation schema with the task expectation.
Parameters#
- observation_schema:
Candidate schema to compare against the task’s required datasets and modalities.
Returns#
- CompatibilityReport
Structured compatibility report containing mismatches, if any.
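Conceptually, the schema comparison is a set difference on dataset names plus a per-dataset modality check. A minimal sketch returning plain mismatch strings follows. The real method returns a CompatibilityReport, and the function name, the message wording, and flagging unexpected extra datasets are assumptions.

```python
def schema_mismatches(expected_modalities, candidate_modalities):
    """Compare {dataset_name: modality} mappings and list human-readable mismatches."""
    mismatches = []
    for name in sorted(set(expected_modalities) - set(candidate_modalities)):
        mismatches.append(f"missing dataset: {name}")
    for name in sorted(set(candidate_modalities) - set(expected_modalities)):
        mismatches.append(f"unexpected dataset: {name}")
    for name in sorted(set(expected_modalities) & set(candidate_modalities)):
        if expected_modalities[name] != candidate_modalities[name]:
            mismatches.append(
                f"modality mismatch for {name}: expected "
                f"{expected_modalities[name]!r}, got {candidate_modalities[name]!r}"
            )
    return mismatches
```

An empty list corresponds to a compatible schema.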
- validate_artifact_metadata(artifact_metadata: Any) petitRADTRANS.sbi.compatibility.CompatibilityReport#
Compare a task with persisted artifact metadata.
Parameters#
- artifact_metadata:
Artifact-like object exposing task and preprocessing provenance fields. Any is accepted to avoid a hard dependency cycle.
Returns#
- CompatibilityReport
Structured report describing whether the artifact is compatible with the current task definition.
The method accepts Any to avoid a hard import dependency from the task module back to the artifact module.
- build_observation_states(model_contract: str = 'differentiable') dict[str, petitRADTRANS.retrieval.runtime.ObservationState]#
Materialize immutable runtime observations for the task datasets.
Parameters#
- model_contract:
Runtime contract requested from each retrieval dataset.
Returns#
- dict[str, ObservationState]
Mapping from dataset name to immutable runtime observation state.
- build_runtime() petitRADTRANS.retrieval.runtime.RetrievalRuntime#
Construct a retrieval runtime matching the task definition.
Returns#
- RetrievalRuntime
Runtime object configured with the task’s parameter layout and retrieval datasets.
The default implementation mirrors the retrieval-runtime grouping logic used by the exact retrieval package and is sufficient for initial SBI simulation backends.
- static _infer_modality(data: petitRADTRANS.retrieval.data.Data) str#
- static _infer_forward_model_family(retrieval_config: petitRADTRANS.retrieval.retrieval_config.RetrievalConfig) str#
- class petitRADTRANS.sbi.SimulationConfig#
Control how a task generates prior-predictive simulations.
- Attributes:
- batch_size:
Number of parameter points produced per simulator call.
- n_simulations:
Optional target number of simulations for dataset generation jobs.
- use_vectorized_runtime:
Whether to prefer the JAX-vectorized runtime path when available.
- store_per_datapoint_log_likelihood:
Persist log-likelihood components alongside simulated observations.
- noise_model:
Noise model configuration applied after forward-model evaluation.
- observation_value_constraints:
Optional per-dataset admissible value ranges enforced on deterministic projected observations before simulations are accepted.
- batch_size: int = 256#
- n_simulations: int | None = None#
- use_vectorized_runtime: bool = True#
- store_per_datapoint_log_likelihood: bool = False#
- noise_model: NoiseModelConfig#
- observation_value_constraints: Mapping[str, ObservationValueConstraint]#
- class petitRADTRANS.sbi.EarlyStoppingConfig#
Early stopping policy for SBI training.
- patience: int#
- min_delta: float = 0.0#
- class petitRADTRANS.sbi.SBITrainer(config: TrainingConfig)#
Reusable optimization loop for amortized SBI posteriors.
- config#
- checkpoint_directory = None#
- checkpoint_backend#
- _checkpoint_backend_fallback_reason: str | None = None#
- property _latest_checkpoint_directory: pathlib.Path | None#
- property _best_checkpoint_directory: pathlib.Path | None#
- _checkpoint_directory_for_kind(checkpoint_kind: str) pathlib.Path | None#
- _save_checkpoint(checkpoint_kind: str, state: dict[str, Any], metadata: dict[str, Any]) None#
- _restore_checkpoint(template_state: dict[str, Any], checkpoint_kind: str) tuple[dict[str, Any], dict[str, Any]] | None#
- fit(model: Any, dataset: Any, loss_fn: Callable[[Any, Any], Any], eval_loss_fn: Callable[[Any, Any], Any] | None = None, eval_diagnostic_fn: Callable[[Any, Any], Mapping[str, float]] | None = None, selection_metric_fn: Callable[[float, Mapping[str, float] | None], float] | None = None, selection_metric_name: str | None = None, stability_metric_fn: Callable[[float, Mapping[str, float] | None], Mapping[str, float]] | None = None, stability_flag_key: str = 'checkpoint_is_stable') tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]#
Optimize one posterior model against a dataset reader.
Parameters#
- model:
Trainable Equinox-style model to optimize.
- dataset:
Reader object exposing iter_batches for train and optional validation splits.
- loss_fn:
Callable receiving (model, batch) and returning the scalar loss to minimize.
- eval_loss_fn:
Optional loss used for validation. Defaults to loss_fn when not provided.
Returns#
- tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]
Best model, train-loss history, validation-loss history, and a metadata dictionary describing checkpointing and stopping behavior.
Notes#
The trainer monitors validation loss when a validation split exists and otherwise falls back to training loss for best-model selection. Checkpoints may be written after each epoch when enabled.
- property _min_delta: float#
- _should_stop(patience_counter: int) bool#
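The EarlyStoppingConfig fields suggest the usual patience rule. An epoch counts as an improvement only if the monitored loss (validation loss when a validation split exists, otherwise training loss) drops by more than min_delta, and training stops after patience epochs without improvement. A sketch under that assumption, using a hypothetical helper that replays a loss history:

```python
def early_stopping_epoch(losses, patience, min_delta=0.0):
    """Return (best_epoch, stop_epoch) for a sequence of monitored losses.

    stop_epoch is None when the patience budget is never exhausted.
    """
    best_loss = float("inf")
    best_epoch = None
    patience_counter = 0
    for epoch, loss in enumerate(losses):
        if loss < best_loss - min_delta:
            # Sufficient improvement: record the new best model and reset patience.
            best_loss, best_epoch = loss, epoch
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                return best_epoch, epoch
    return best_epoch, None
```

For example, with min_delta=0.05 a drop from 0.80 to 0.79 does not reset the patience counter, so small fluctuations near a plateau still end training.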
- class petitRADTRANS.sbi.TrainingConfig#
Configuration for posterior optimization.
- learning_rate: float = 0.001#
- batch_size: int = 32#
- num_epochs: int = 10#
- parameter_space: str = 'unconstrained'#
- seed: int = 0#
- shuffle_train: bool = True#
- early_stopping: EarlyStoppingConfig | None = None#
- checkpoint_directory: str | None = None#
- checkpoint_backend: str = 'auto'#
- resume_from_checkpoint: bool = False#
- data_parallel: bool = False#
- gradient_clip_norm: float | None = 1.0#
- weight_decay: float = 0.0#
- use_cosine_schedule: bool = False#
- warmup_fraction: float = 0.02#
- warmup_epochs: float | None = None#
- min_learning_rate: float = 1e-06#
- lr_schedule_total_epochs: float | None = None#
- verbose_diagnostics: bool = False#
- diagnostics_output_directory: str | None = None#
- diagnostics_plot_interval: int = 1#
- n_validation_diagnostic_batches: int = 4#
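The schedule fields (use_cosine_schedule, warmup_fraction, min_learning_rate) match a common shape: linear warmup to the peak learning rate followed by cosine decay to a floor. The per-step sketch below assumes that shape. The helper name is hypothetical, and it does not model how warmup_epochs or lr_schedule_total_epochs override the defaults in the trainer.

```python
import math


def warmup_cosine_lr(step, total_steps, learning_rate=1e-3, min_learning_rate=1e-6,
                     warmup_fraction=0.02):
    """Linear warmup to learning_rate, then cosine decay to min_learning_rate."""
    warmup_steps = max(1, int(round(warmup_fraction * total_steps)))
    if step < warmup_steps:
        # Ramp linearly so the final warmup step reaches the peak rate.
        return learning_rate * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_learning_rate + (learning_rate - min_learning_rate) * cosine
```

With the defaults and 1000 total steps, the first 20 steps ramp up to 1e-3 and the rate then decays smoothly to 1e-6 by the last step.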