petitRADTRANS.sbi.dataset

Dataset storage interfaces for SBI simulation corpora.

Classes

DatasetSplit

Named dataset partitions used during training and evaluation.

SplitSamplingPolicy

Deterministic policies for assigning generated samples to dataset splits.

SimulationDatasetManifest

Describe a stored SBI simulation corpus.

SimulationDatasetWriter

Append-only writer for chunked simulation datasets.

SimulationDatasetStore

Backend-independent interface for reading and writing simulation data.

StoredSimulationDataset

Read-only handle for a stored chunked simulation dataset.

GeneratedSimulationDataset

Return value for one-call simulation corpus generation.

NormalizedObservationDatasetReader

Lightweight reader yielding normalized ObservationBlock batches for training.

ZarrSimulationDatasetWriter

Append-only Zarr-backed writer for simulation corpora.

ZarrSimulationDatasetStore

Concrete Zarr-backed store for chunked simulation corpora.

HDF5SimulationDatasetWriter

Append-only HDF5-backed writer for simulation corpora.

HDF5StoredSimulationDataset

Read-only handle for an HDF5-backed simulation dataset.

HDF5SimulationDatasetStore

HDF5-backed store for simulation corpora.

Functions

generate_simulation_dataset(→ GeneratedSimulationDataset)

Generate and persist a simulation corpus in one call.

Module Contents

class petitRADTRANS.sbi.dataset.DatasetSplit

Bases: str, enum.Enum

Named dataset partitions used during training and evaluation.

TRAIN = 'train'
VALIDATION = 'validation'
TEST = 'test'
BENCHMARK = 'benchmark'
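Because the enum derives from both str and enum.Enum, its members can be used wherever a plain string key is expected. The sketch below uses a local mirror of the class (so it runs without petitRADTRANS installed) to show the behaviour that follows from the standard library alone; it is not a statement about how the package uses the enum internally.

```python
from enum import Enum


class DatasetSplit(str, Enum):
    """Local mirror of the documented enum, for illustration only."""

    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"
    BENCHMARK = "benchmark"


# Members compare equal to their string values because of the str mixin.
assert DatasetSplit.TRAIN == "train"

# A plain string (e.g. one read back from a manifest) converts to a member by value.
split = DatasetSplit("validation")

# Using .value gives stable string keys for mappings such as a split-size table.
split_sizes = {member.value: 0 for member in DatasetSplit}
```

Accepting DatasetSplit | str in function signatures, as generate_simulation_dataset does for its split mappings, fits naturally with this pattern.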
class petitRADTRANS.sbi.dataset.SplitSamplingPolicy

Bases: str, enum.Enum

Deterministic policies for assigning generated samples to dataset splits.

SEQUENTIAL = 'sequential'
ROUND_ROBIN = 'round_robin'
SHUFFLED = 'shuffled'
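The reference does not spell out what each policy does, so the following is only one plausible reading of the three names, written as a hypothetical helper (assign_splits is not part of the package). It assumes that sequential fills one split after another, round_robin cycles through the splits, and shuffled applies a seeded permutation.

```python
import random


def assign_splits(counts, policy="sequential", seed=None):
    """Return one split name per sample (hypothetical helper, assumed semantics)."""
    names = list(counts)
    if policy == "sequential":
        # All samples of the first split, then the second, and so on.
        return [name for name in names for _ in range(counts[name])]
    if policy == "round_robin":
        # Cycle through the splits, skipping any that are already full.
        remaining = dict(counts)
        labels = []
        while any(remaining.values()):
            for name in names:
                if remaining[name] > 0:
                    labels.append(name)
                    remaining[name] -= 1
        return labels
    if policy == "shuffled":
        # Seeded permutation of the sequential layout: same seed, same result.
        labels = assign_splits(counts, "sequential")
        random.Random(seed).shuffle(labels)
        return labels
    raise ValueError(f"unknown policy: {policy!r}")


counts = {"train": 4, "validation": 2}
sequential = assign_splits(counts, "sequential")
round_robin = assign_splits(counts, "round_robin")
shuffled = assign_splits(counts, "shuffled", seed=0)
```

Whatever the real semantics are, the point of a deterministic policy is that the same counts and seed reproduce the same assignment, which is what the seeded generator provides here.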
class petitRADTRANS.sbi.dataset.SimulationDatasetManifest

Describe a stored SBI simulation corpus.

task_name: str
parameter_names: tuple[str, ...]
n_simulations: int
splits: Mapping[str, int]
storage_uri: str | None = None
task_fingerprint: str | None = None
observation_schema: Mapping[str, Any]
available_fields: tuple[str, ...] = ()
metadata: Mapping[str, Any]
classmethod from_task(task: petitRADTRANS.sbi.task.SBITask, storage_uri: str, n_simulations: int = 0, splits: Mapping[str, int] | None = None, metadata: Mapping[str, Any] | None = None) → SimulationDatasetManifest
to_payload() → dict[str, Any]
classmethod from_payload(payload: Mapping[str, Any]) → SimulationDatasetManifest
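The to_payload() / from_payload() pair suggests a round trip through a plain dictionary, for example to embed the manifest as JSON in the storage backend. The sketch below illustrates that pattern with a hypothetical stand-in class (ManifestSketch) carrying only a subset of the documented fields; the real serialisation format and the parameter names used here are assumptions.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Any, Mapping, Optional


@dataclass(frozen=True)
class ManifestSketch:
    """Hypothetical stand-in with a subset of the documented fields."""

    task_name: str
    parameter_names: tuple
    n_simulations: int = 0
    splits: Mapping[str, int] = field(default_factory=dict)
    storage_uri: Optional[str] = None

    def to_payload(self) -> dict:
        # asdict keeps tuples as tuples; convert so the payload is plain JSON data.
        payload = asdict(self)
        payload["parameter_names"] = list(self.parameter_names)
        return payload

    @classmethod
    def from_payload(cls, payload: Mapping[str, Any]) -> "ManifestSketch":
        data = dict(payload)
        # JSON has no tuple type, so restore it on the way back in.
        data["parameter_names"] = tuple(data["parameter_names"])
        return cls(**data)


manifest = ManifestSketch(
    task_name="demo_task",                       # illustrative values only
    parameter_names=("temperature", "log_g"),
    n_simulations=100,
    splits={"train": 80, "validation": 20},
)
restored = ManifestSketch.from_payload(json.loads(json.dumps(manifest.to_payload())))
```

Keeping the payload restricted to JSON-compatible types is what allows the same manifest to be stored by either backend.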
class petitRADTRANS.sbi.dataset.SimulationDatasetWriter

Bases: abc.ABC

Append-only writer for chunked simulation datasets.

abstractmethod append(batch: petitRADTRANS.sbi.simulator.SimulationBatch, split: DatasetSplit = DatasetSplit.TRAIN) → None

Persist one simulation batch to the selected split.

abstractmethod finalize() → SimulationDatasetManifest

Seal the dataset and return its manifest.
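The append/finalize contract can be illustrated with a small in-memory implementation. Everything below is hypothetical: WriterSketch mirrors only the two abstract methods, batches are plain lists rather than SimulationBatch objects, and rejecting appends after finalize() is an assumption suggested by the word "seal", not documented behaviour.

```python
from abc import ABC, abstractmethod


class WriterSketch(ABC):
    """Hypothetical mirror of the append/finalize contract."""

    @abstractmethod
    def append(self, batch, split="train"):
        """Persist one batch to the selected split."""

    @abstractmethod
    def finalize(self):
        """Seal the dataset and return a summary."""


class InMemoryWriter(WriterSketch):
    def __init__(self):
        self._rows = {}
        self._sealed = False

    def append(self, batch, split="train"):
        # Assumed behaviour: a sealed dataset accepts no further batches.
        if self._sealed:
            raise RuntimeError("dataset already finalized")
        self._rows.setdefault(split, []).extend(batch)

    def finalize(self):
        # The summary here is a plain dict of per-split row counts.
        self._sealed = True
        splits = {name: len(rows) for name, rows in self._rows.items()}
        return {"splits": splits, "n_simulations": sum(splits.values())}


writer = InMemoryWriter()
writer.append([1, 2, 3], split="train")
writer.append([4], split="validation")
writer.append([5, 6], split="train")
summary = writer.finalize()
```

The concrete Zarr and HDF5 writers documented further down implement the same two methods against their respective storage formats.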

class petitRADTRANS.sbi.dataset.SimulationDatasetStore

Bases: abc.ABC

Backend-independent interface for reading and writing simulation data.

abstractmethod create_writer(manifest: SimulationDatasetManifest) → SimulationDatasetWriter

Create a writer for a new simulation dataset.

abstractmethod open(manifest_or_uri: SimulationDatasetManifest | str) → Any

Open a stored dataset for training or evaluation.

class petitRADTRANS.sbi.dataset.StoredSimulationDataset

Read-only handle for a stored chunked simulation dataset.

manifest: SimulationDatasetManifest
storage_uri: str
root: Any
iter_split_chunks(split: DatasetSplit = DatasetSplit.TRAIN, chunk_size: int = 256, indices: Any = None, observation_fields: frozenset[str] | None = None) → Iterator[dict[str, Any]]

Yield row-slices of one split without loading it all into RAM.

Parameters

split:

Dataset split to iterate.

chunk_size:

Number of rows per yielded chunk.

indices:

Optional integer array of row indices to read. When supplied, only those rows are fetched, in ascending index order. When None, every row in the split is visited sequentially.

observation_fields:

Optional set of observation field names to load. When None, every dataset field is loaded. Pass _OBSERVATION_TRAINING_FIELDS to skip large unused fields, such as covariance, during training.

Yields

dict[str, Any]

Same structure as read_split() but covering only chunk_size rows per iteration.
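The chunking and index-selection behaviour described above can be sketched on a plain list. This is a simplified, hypothetical helper (iter_chunks): the real method yields dictionaries of arrays read from storage, whereas this version yields lists of rows, but the windowing logic follows the parameter descriptions.

```python
def iter_chunks(rows, chunk_size=256, indices=None):
    """Yield lists of at most chunk_size rows (hypothetical helper)."""
    if indices is None:
        # No selection: visit every row sequentially.
        order = range(len(rows))
    else:
        # A selection is read in ascending index order, as documented.
        order = sorted(indices)
    for start in range(0, len(order), chunk_size):
        window = order[start:start + chunk_size]
        yield [rows[i] for i in window]


rows = list(range(100, 110))
all_chunks = list(iter_chunks(rows, chunk_size=4))
subset_chunks = list(iter_chunks(rows, chunk_size=2, indices=[7, 1, 4]))
```

Only one window of rows is materialised at a time, which is what keeps memory use bounded by chunk_size rather than by the size of the split.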

read_split(split: DatasetSplit = DatasetSplit.TRAIN) → dict[str, Any]

class petitRADTRANS.sbi.dataset.GeneratedSimulationDataset

Return value for one-call simulation corpus generation.

manifest: SimulationDatasetManifest
artifact_metadata: petitRADTRANS.sbi.artifacts.ArtifactMetadata
preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | None = None

class petitRADTRANS.sbi.dataset.NormalizedObservationDatasetReader

Lightweight reader yielding normalized ObservationBlock batches for training.

dataset: StoredSimulationDataset | HDF5StoredSimulationDataset
preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata
_split_cache: dict
iter_batches(split: DatasetSplit = DatasetSplit.TRAIN, batch_size: int = 32, shuffle: bool = False, seed: int | None = None, parameter_space: str = 'physical', encoder: Any = None) → Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]

Yield mini-batches of normalized observations and matched parameters.

Parameters

split:

Dataset split to iterate over.

batch_size:

Number of samples yielded in each batch.

shuffle:

Whether to shuffle sample order within the requested split.

seed:

Optional random seed used when shuffle=True.

parameter_space:

Parameter representation returned in each batch. Supported values are 'physical', 'cube', and 'unconstrained'.

encoder:

Optional encoder used to convert block lists into dense embedding arrays before batches are yielded.

Returns

Iterator[PosteriorBatch]

Iterator over batches containing parameters, observations, and small metadata dictionaries describing the batch provenance.

Notes

Splits with fewer than _STREAMING_THRESHOLD rows are cached in RAM after the first read, so repeated passes over the same split (e.g. validation at every epoch) pay only the cost of slicing. Larger splits are streamed from disk one batch at a time to avoid loading the full dataset into memory.
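The cache-or-stream decision described in the note can be sketched as follows. ReaderSketch, its disk_reads counter, and the threshold value of 4 are all hypothetical and chosen only to make the behaviour visible; the real threshold is a private module constant whose value is not given here.

```python
STREAMING_THRESHOLD = 4  # hypothetical value, for illustration only


class ReaderSketch:
    def __init__(self, splits):
        self._splits = splits      # stands in for data held on disk
        self._split_cache = {}
        self.disk_reads = 0

    def _read_rows(self, split, start, stop):
        # Every call here represents one read from storage.
        self.disk_reads += 1
        return self._splits[split][start:stop]

    def iter_batches(self, split, batch_size):
        n_rows = len(self._splits[split])
        if n_rows < STREAMING_THRESHOLD:
            # Small split: read once, then serve slices from the cache.
            if split not in self._split_cache:
                self._split_cache[split] = self._read_rows(split, 0, n_rows)
            cached = self._split_cache[split]
            for start in range(0, n_rows, batch_size):
                yield cached[start:start + batch_size]
        else:
            # Large split: one storage read per batch, nothing retained.
            for start in range(0, n_rows, batch_size):
                yield self._read_rows(split, start, start + batch_size)


reader = ReaderSketch({"validation": [1, 2, 3], "train": list(range(10))})
for _ in range(2):  # two "epochs" over the small split
    validation_batches = list(reader.iter_batches("validation", batch_size=2))
reads_after_validation = reader.disk_reads
train_batches = list(reader.iter_batches("train", batch_size=4))
reads_after_train = reader.disk_reads
```

In this sketch the small split costs a single read no matter how many epochs are run, while the large split costs one read per batch per epoch.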

_iter_batches_cached(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) → Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]

Iterate using the full-split RAM cache (small splits).

_iter_batches_streaming(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) → Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]

Iterate by streaming rows directly from disk (large splits).

When shuffle=True, a global index permutation is computed in RAM (one integer per simulation row, so the memory cost is negligible). The permutation is consumed in sub-windows of batch_size indices, and each sub-window is sorted before its rows are fetched from HDF5. This satisfies h5py's requirement that index lists be in increasing order while still presenting shuffled order to the training loop.
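The sorted sub-window technique can be shown without any storage backend. The helper below (shuffled_windows, a hypothetical name) only produces the index lists; in a real reader each list would be passed to the HDF5 dataset as a fancy index.

```python
import random


def shuffled_windows(n_rows, batch_size, seed=None):
    """Yield sorted index windows drawn from one global permutation."""
    permutation = list(range(n_rows))
    random.Random(seed).shuffle(permutation)
    for start in range(0, n_rows, batch_size):
        # Membership of each window is random; sorting inside the window
        # only makes the read monotonic.
        yield sorted(permutation[start:start + batch_size])


windows = list(shuffled_windows(n_rows=10, batch_size=4, seed=0))
visited = sorted(index for window in windows for index in window)
```

Each row is visited exactly once per pass and batch composition changes with the seed. The one property given up is the row order inside a batch, which is ascending rather than random; for most training objectives that averages over the batch this does not matter.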

static _select_parameters(split_data: dict, parameter_space: str) → Any

Return the parameter array for the requested coordinate space.

class petitRADTRANS.sbi.dataset.ZarrSimulationDatasetWriter(manifest: SimulationDatasetManifest, root: Any, chunk_size: int = 256)

Bases: SimulationDatasetWriter

Append-only Zarr-backed writer for simulation corpora.

_manifest
_root
_chunk_size = 256
_split_counts
_available_fields
static _get_or_create_group(parent: Any, name: str) → Any
_write_manifest() → None
_append_array(group: Any, name: str, values: Any, leading_chunk: int | None = None) → None
_append_observations(split_group: Any, observations: Mapping[str, Any]) → None
append(batch: petitRADTRANS.sbi.simulator.SimulationBatch, split: DatasetSplit = DatasetSplit.TRAIN) → None

Persist one simulation batch to the selected split.

finalize() → SimulationDatasetManifest

Seal the dataset and return its manifest.

class petitRADTRANS.sbi.dataset.ZarrSimulationDatasetStore(chunk_size: int = 256)

Bases: SimulationDatasetStore

Concrete Zarr-backed store for chunked simulation corpora.

chunk_size = 256
static _require_zarr() → None
create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') → SimulationDatasetWriter

Create a writer for a new simulation dataset.

open(manifest_or_uri: SimulationDatasetManifest | str) → StoredSimulationDataset

Open a stored dataset for training or evaluation.

class petitRADTRANS.sbi.dataset.HDF5SimulationDatasetWriter(manifest: SimulationDatasetManifest, file: Any, chunk_size: int = 256)

Bases: SimulationDatasetWriter

Append-only HDF5-backed writer for simulation corpora.

All simulation data is stored in a single .h5 file, which avoids the large number of chunk files produced by the Zarr directory-store backend. Datasets are created with an unlimited first dimension (maxshape=(None, ...)) so that batches can be appended incrementally without pre-allocation.
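The unlimited-first-dimension pattern mentioned above looks roughly like this in h5py. This is a minimal sketch rather than the writer's actual _append_array: the helper name append_rows and the group and dataset names ("train", "parameters") are assumptions, and it requires h5py and NumPy to be installed.

```python
import h5py
import numpy as np


def append_rows(group, name, values, chunk_rows=256):
    """Append values along axis 0, creating the dataset on first use."""
    values = np.asarray(values)
    if name not in group:
        group.create_dataset(
            name,
            shape=(0,) + values.shape[1:],
            maxshape=(None,) + values.shape[1:],  # unlimited first dimension
            chunks=(chunk_rows,) + values.shape[1:],
            dtype=values.dtype,
        )
    dataset = group[name]
    old_rows = dataset.shape[0]
    # Grow the dataset, then write the new rows into the added tail.
    dataset.resize(old_rows + values.shape[0], axis=0)
    dataset[old_rows:] = values


# In-memory file (core driver, no backing store), so nothing touches the disk.
with h5py.File("sketch.h5", "w", driver="core", backing_store=False) as handle:
    train = handle.require_group("train")
    append_rows(train, "parameters", np.zeros((3, 2)))
    append_rows(train, "parameters", np.ones((5, 2)))
    final_shape = train["parameters"].shape
    last_row = train["parameters"][...][-1].tolist()
```

Because the dataset is chunked and resizable, the total number of simulations does not need to be known when the file is created, which is what allows batches to be appended as they are generated.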

_manifest
_file
_chunk_size = 256
_split_counts
_available_fields
_write_manifest() → None
static _get_or_create_group(parent: Any, name: str) → Any
_append_array(group: Any, name: str, values: Any) → None
_append_observations(split_group: Any, observations: Mapping[str, Any]) → None
append(batch: petitRADTRANS.sbi.simulator.SimulationBatch, split: DatasetSplit = DatasetSplit.TRAIN) → None

Persist one simulation batch to the selected split.

finalize() → SimulationDatasetManifest

Seal the dataset and return its manifest.

class petitRADTRANS.sbi.dataset.HDF5StoredSimulationDataset

Read-only handle for an HDF5-backed simulation dataset.

manifest: SimulationDatasetManifest
storage_uri: str
_file_path: str
iter_split_chunks(split: DatasetSplit = DatasetSplit.TRAIN, chunk_size: int = 256, indices: Any = None, observation_fields: frozenset[str] | None = None) → Iterator[dict[str, Any]]

Yield row-slices of one split without loading it all into RAM.

Parameters

split:

Dataset split to iterate.

chunk_size:

Number of rows per yielded chunk.

indices:

Optional integer array of row indices to read. When supplied, only those rows are fetched, in ascending index order. When None, every row in the split is visited sequentially.

observation_fields:

Optional set of observation field names to load. When None, every dataset field is loaded. Pass _OBSERVATION_TRAINING_FIELDS to skip large unused fields, such as covariance, during training.

Yields

dict[str, Any]

Same structure as read_split() but covering only chunk_size rows at a time. The HDF5 file is opened once and kept open for the duration of the iteration.

read_split(split: DatasetSplit = DatasetSplit.TRAIN) → dict[str, Any]

class petitRADTRANS.sbi.dataset.HDF5SimulationDatasetStore(chunk_size: int = 256)

Bases: SimulationDatasetStore

HDF5-backed store for simulation corpora.

Stores all simulation data for a corpus in a single .h5 file, keeping the file count at 1 regardless of the number of simulations or splits. Requires h5py (already present in the jaxprt environment).

chunk_size = 256
static _require_h5py() → None
create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') → HDF5SimulationDatasetWriter

Create a writer for a new simulation dataset.

open(manifest_or_uri: SimulationDatasetManifest | str) → HDF5StoredSimulationDataset

Open a stored dataset for training or evaluation.

petitRADTRANS.sbi.dataset.generate_simulation_dataset(task: petitRADTRANS.sbi.task.SBITask, storage_uri: str, n_simulations: int, chunk_size: int = 256, split_counts: Mapping[DatasetSplit | str, int] | None = None, split_fractions: Mapping[DatasetSplit | str, float] | None = None, split_policy: SplitSamplingPolicy | str = SplitSamplingPolicy.SEQUENTIAL, split_seed: int | None = None, simulator: Any = None, include_preprocessing_metadata: bool = True, dataset_version: str = '0.1.0', resume: bool = False, backend: str = 'hdf5', data_parallel: bool | None = None, store_covariance: bool = False) → GeneratedSimulationDataset

Generate and persist a simulation corpus in one call.

Parameters

backend:

Storage backend to use. 'hdf5' (default) writes all data into a single .h5 file, keeping the file count at 1 regardless of the number of simulations. 'zarr' uses the Zarr directory store, which produces one file per compressed chunk.

data_parallel:

When True (or None with multiple JAX devices available), the vmapped RT kernel is distributed across devices using jax.pmap. The effective per-iteration batch size is automatically scaled by the device count, so that each device processes simulation_config.batch_size samples.

store_covariance:

When True, the full covariance matrix for each simulated spectrum is written to disk. When False (default), only the covariance diagonal is stored under the covariance field, which reduces storage pressure.
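Two of the parameters above have simple quantitative consequences that are easy to estimate before launching a large run. The sketch below assumes uncompressed 8-byte floats and uses made-up sizes (100,000 simulations, 1,000 spectral points, 4 devices, a per-device batch of 64); both helper functions are hypothetical and not part of the package.

```python
def effective_batch_size(per_device_batch, n_devices):
    # Each device keeps the configured batch size, so the total scales linearly.
    return per_device_batch * n_devices


def covariance_bytes(n_simulations, n_points, full, itemsize=8):
    # Full matrix: n_points**2 entries per spectrum; diagonal: n_points entries.
    entries = n_points * n_points if full else n_points
    return n_simulations * entries * itemsize


total_batch = effective_batch_size(per_device_batch=64, n_devices=4)
full_gb = covariance_bytes(100_000, 1_000, full=True) / 1e9
diag_gb = covariance_bytes(100_000, 1_000, full=False) / 1e9
```

Under these assumptions the full covariance costs 800 GB against 0.8 GB for the diagonal, a factor equal to the number of spectral points, which is why store_covariance defaults to False.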