petitRADTRANS.sbi.dataset
Dataset storage interfaces for SBI simulation corpora.
Classes
DatasetSplit | Named dataset partitions used during training and evaluation.
SplitSamplingPolicy | Deterministic policies for assigning generated samples to dataset splits.
SimulationDatasetManifest | Describe a stored SBI simulation corpus.
SimulationDatasetWriter | Append-only writer for chunked simulation datasets.
SimulationDatasetStore | Backend-independent interface for reading and writing simulation data.
StoredSimulationDataset | Read-only handle for a stored chunked simulation dataset.
GeneratedSimulationDataset | Return value for one-call simulation corpus generation.
NormalizedObservationDatasetReader | Lightweight reader yielding normalized ObservationBlock batches for training.
ZarrSimulationDatasetWriter | Append-only Zarr-backed writer for simulation corpora.
ZarrSimulationDatasetStore | Concrete Zarr-backed store for chunked simulation corpora.
HDF5SimulationDatasetWriter | Append-only HDF5-backed writer for simulation corpora.
HDF5StoredSimulationDataset | Read-only handle for an HDF5-backed simulation dataset.
HDF5SimulationDatasetStore | HDF5-backed store for simulation corpora.
Functions
generate_simulation_dataset | Generate and persist a simulation corpus in one call.
Module Contents
- class petitRADTRANS.sbi.dataset.DatasetSplit
Bases: str, enum.Enum
Named dataset partitions used during training and evaluation.
- TRAIN = 'train'
- VALIDATION = 'validation'
- TEST = 'test'
- BENCHMARK = 'benchmark'
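Because DatasetSplit mixes str into enum.Enum, its members compare equal to their string values, which is presumably why functions such as generate_simulation_dataset accept DatasetSplit | str keys. The snippet below is a minimal sketch that redeclares an equivalent enum locally (so it runs without petitRADTRANS installed); normalize_split is a hypothetical helper, not part of the module.

```python
import enum


# Local stand-in mirroring the documented members; not imported from petitRADTRANS.
class DatasetSplit(str, enum.Enum):
    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"
    BENCHMARK = "benchmark"


def normalize_split(split):
    """Accept a DatasetSplit member or its string value and return the member (hypothetical)."""
    # Calling the enum with a value looks the member up; passing a member returns it unchanged.
    return DatasetSplit(split)


assert DatasetSplit.TRAIN == "train"  # str mixin: members compare equal to their values
assert normalize_split("validation") is DatasetSplit.VALIDATION
assert normalize_split(DatasetSplit.TEST) is DatasetSplit.TEST
```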
- class petitRADTRANS.sbi.dataset.SplitSamplingPolicy
Bases: str, enum.Enum
Deterministic policies for assigning generated samples to dataset splits.
- SEQUENTIAL = 'sequential'
- ROUND_ROBIN = 'round_robin'
- SHUFFLED = 'shuffled'
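The exact semantics of each policy are not spelled out here. The following sketch shows one plausible, deterministic reading inferred from the member names only; assign_splits and its split-count dictionary are hypothetical and are not the library's implementation.

```python
import random
from typing import Dict, List, Optional


def assign_splits(
    split_counts: Dict[str, int],
    policy: str = "sequential",
    seed: Optional[int] = None,
) -> List[str]:
    """Return one split label per generated sample index (hypothetical helper).

    Assumed semantics, inferred from the policy names only:
    - 'sequential':  all samples of the first split, then the next split, and so on.
    - 'round_robin': cycle through the splits that still have quota left.
    - 'shuffled':    the sequential labels permuted by a seeded RNG.
    """
    sequential = [name for name, count in split_counts.items() for _ in range(count)]
    if policy == "sequential":
        return sequential
    if policy == "shuffled":
        labels = list(sequential)
        # A dedicated seeded Random instance keeps the assignment reproducible.
        random.Random(seed).shuffle(labels)
        return labels
    if policy == "round_robin":
        remaining = dict(split_counts)
        labels = []
        while any(remaining.values()):
            for name in split_counts:
                if remaining[name] > 0:
                    labels.append(name)
                    remaining[name] -= 1
        return labels
    raise ValueError(f"unknown policy: {policy!r}")


counts = {"train": 3, "validation": 1, "test": 1}
print(assign_splits(counts, "round_robin"))  # ['train', 'validation', 'test', 'train', 'train']
```

Under these assumptions, SEQUENTIAL keeps each split contiguous, ROUND_ROBIN interleaves splits so that a partially completed run already contains samples from every split, and SHUFFLED depends only on the seed, so a rerun reproduces the same assignment.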
- class petitRADTRANS.sbi.dataset.SimulationDatasetManifest
Describe a stored SBI simulation corpus.
- task_name: str
- parameter_names: tuple[str, ...]
- n_simulations: int
- splits: Mapping[str, int]
- storage_uri: str | None = None
- task_fingerprint: str | None = None
- observation_schema: Mapping[str, Any]
- available_fields: tuple[str, ...] = ()
- metadata: Mapping[str, Any]
- classmethod from_task(task: petitRADTRANS.sbi.task.SBITask, storage_uri: str, n_simulations: int = 0, splits: Mapping[str, int] | None = None, metadata: Mapping[str, Any] | None = None) → SimulationDatasetManifest
- to_payload() → dict[str, Any]
- classmethod from_payload(payload: Mapping[str, Any]) → SimulationDatasetManifest
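to_payload() and from_payload() suggest a plain-dictionary round trip suitable for JSON storage alongside the data. The sketch below mirrors the documented field names in a hypothetical frozen dataclass, ManifestSketch, and shows one way to turn tuple fields into lists and back; the actual serialization details of SimulationDatasetManifest are not documented here.

```python
import dataclasses
import json
from typing import Any, Dict, Mapping, Optional, Tuple


@dataclasses.dataclass(frozen=True)
class ManifestSketch:
    """Hypothetical stand-in with the same field names as SimulationDatasetManifest."""

    task_name: str
    parameter_names: Tuple[str, ...]
    n_simulations: int
    splits: Mapping[str, int]
    storage_uri: Optional[str] = None
    task_fingerprint: Optional[str] = None
    observation_schema: Mapping[str, Any] = dataclasses.field(default_factory=dict)
    available_fields: Tuple[str, ...] = ()
    metadata: Mapping[str, Any] = dataclasses.field(default_factory=dict)

    def to_payload(self) -> Dict[str, Any]:
        # Plain JSON-friendly containers: tuples become lists, mappings become dicts.
        payload = dataclasses.asdict(self)
        payload["parameter_names"] = list(self.parameter_names)
        payload["available_fields"] = list(self.available_fields)
        payload["splits"] = dict(self.splits)
        return payload

    @classmethod
    def from_payload(cls, payload: Mapping[str, Any]) -> "ManifestSketch":
        data = dict(payload)
        # JSON has no tuple type, so restore the tuple-typed fields explicitly.
        data["parameter_names"] = tuple(data["parameter_names"])
        data["available_fields"] = tuple(data.get("available_fields", ()))
        return cls(**data)


manifest = ManifestSketch(
    task_name="demo",
    parameter_names=("log_g", "t_int"),
    n_simulations=10,
    splits={"train": 8, "validation": 2},
    storage_uri="corpus.h5",
)
restored = ManifestSketch.from_payload(json.loads(json.dumps(manifest.to_payload())))
assert restored == manifest
```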
- class petitRADTRANS.sbi.dataset.SimulationDatasetWriter
Bases: abc.ABC
Append-only writer for chunked simulation datasets.
- abstractmethod append(batch: petitRADTRANS.sbi.simulator.SimulationBatch, split: DatasetSplit = DatasetSplit.TRAIN) → None
Persist one simulation batch to the selected split.
- abstractmethod finalize() → SimulationDatasetManifest
Seal the dataset and return its manifest.
- class petitRADTRANS.sbi.dataset.SimulationDatasetStore
Bases: abc.ABC
Backend-independent interface for reading and writing simulation data.
- abstractmethod create_writer(manifest: SimulationDatasetManifest) → SimulationDatasetWriter
Create a writer for a new simulation dataset.
- abstractmethod open(manifest_or_uri: SimulationDatasetManifest | str) → Any
Open a stored dataset for training or evaluation.
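The writer/store split separates producing a corpus (append, then finalize once) from reading it back (open). As an illustration under stated assumptions, the sketch below implements both contracts with an in-memory backend. WriterSketch, StoreSketch, InMemoryWriter, and InMemoryStore are hypothetical; batches are plain dicts of column lists standing in for SimulationBatch, finalize returns per-split row counts instead of a SimulationDatasetManifest, and the store is keyed by a URI string rather than a manifest.

```python
import abc
from typing import Any, Dict, List

Columns = Dict[str, List[Any]]


class WriterSketch(abc.ABC):
    """Local mirror of the documented writer contract: append, then finalize."""

    @abc.abstractmethod
    def append(self, batch: Columns, split: str = "train") -> None: ...

    @abc.abstractmethod
    def finalize(self) -> Dict[str, int]: ...


class StoreSketch(abc.ABC):
    """Local mirror of the documented store contract: create_writer and open."""

    @abc.abstractmethod
    def create_writer(self, uri: str) -> WriterSketch: ...

    @abc.abstractmethod
    def open(self, uri: str) -> Dict[str, Columns]: ...


class InMemoryWriter(WriterSketch):
    def __init__(self, splits: Dict[str, Columns]) -> None:
        self._splits = splits
        self._sealed = False

    def append(self, batch: Columns, split: str = "train") -> None:
        if self._sealed:
            raise RuntimeError("writer has already been finalized")
        columns = self._splits.setdefault(split, {})
        for name, values in batch.items():
            # Append-only: rows that were already written are never modified.
            columns.setdefault(name, []).extend(values)

    def finalize(self) -> Dict[str, int]:
        self._sealed = True
        # Stand-in for the manifest: row count of each split, taken from its first column.
        return {split: len(next(iter(cols.values()), [])) for split, cols in self._splits.items()}


class InMemoryStore(StoreSketch):
    def __init__(self) -> None:
        self._datasets: Dict[str, Dict[str, Columns]] = {}

    def create_writer(self, uri: str) -> InMemoryWriter:
        # A fresh dict per URI, i.e. overwrite semantics for a new dataset.
        self._datasets[uri] = {}
        return InMemoryWriter(self._datasets[uri])

    def open(self, uri: str) -> Dict[str, Columns]:
        return self._datasets[uri]


store = InMemoryStore()
writer = store.create_writer("mem://demo")
writer.append({"parameters": [[0.1], [0.2]], "flux": [[1.0], [2.0]]})
writer.append({"parameters": [[0.3]], "flux": [[3.0]]}, split="validation")
print(writer.finalize())  # {'train': 2, 'validation': 1}
```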
- class petitRADTRANS.sbi.dataset.StoredSimulationDataset
Read-only handle for a stored chunked simulation dataset.
- manifest: SimulationDatasetManifest
- storage_uri: str
- root: Any
- iter_split_chunks(split: DatasetSplit = DatasetSplit.TRAIN, chunk_size: int = 256, indices: Any = None, observation_fields: frozenset[str] | None = None) → Iterator[dict[str, Any]]
Yield row slices of one split without loading the whole split into RAM.
Parameters
- split:
Dataset split to iterate.
- chunk_size:
Number of rows per yielded chunk.
- indices:
Optional integer array of row indices to read. When supplied, only those rows are fetched, in ascending order. When None, every row in the split is visited sequentially.
- observation_fields:
Optional set of observation field names to load. When None, every dataset field is loaded. Pass _OBSERVATION_TRAINING_FIELDS to skip large fields that training does not use, such as covariance.
Yields
- dict[str, Any]
Same structure as read_split(), but covering only chunk_size rows per iteration.
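The chunking contract can be illustrated on an in-memory split. iter_chunks below is a hypothetical stand-in for iter_split_chunks: columns are Python lists rather than Zarr arrays, and, unlike the documented observation_fields argument, its fields filter applies to every column.

```python
from typing import Any, Dict, Iterator, List, Optional, Sequence


def iter_chunks(
    columns: Dict[str, Sequence[Any]],
    chunk_size: int = 256,
    indices: Optional[Sequence[int]] = None,
    fields: Optional[frozenset] = None,
) -> Iterator[Dict[str, List[Any]]]:
    """Yield row slices of an in-memory split (hypothetical sketch)."""
    names = [name for name in columns if fields is None or name in fields]
    if indices is None:
        n_rows = len(next(iter(columns.values()))) if columns else 0
        # Contiguous slices covering rows [start, start + chunk_size).
        for start in range(0, n_rows, chunk_size):
            yield {name: list(columns[name][start:start + chunk_size]) for name in names}
    else:
        # Requested rows are read in ascending order, as documented above.
        ordered = sorted(indices)
        for start in range(0, len(ordered), chunk_size):
            window = ordered[start:start + chunk_size]
            yield {name: [columns[name][i] for i in window] for name in names}


split = {"parameters": list(range(5)), "covariance": ["c"] * 5}
print([chunk["parameters"] for chunk in iter_chunks(split, chunk_size=2)])  # [[0, 1], [2, 3], [4]]
```

Only one chunk's worth of rows is materialized per step, which is what lets the real method walk splits larger than RAM.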
- read_split(split: DatasetSplit = DatasetSplit.TRAIN) → dict[str, Any]
- class petitRADTRANS.sbi.dataset.GeneratedSimulationDataset
Return value for one-call simulation corpus generation.
- manifest: SimulationDatasetManifest
- artifact_metadata: petitRADTRANS.sbi.artifacts.ArtifactMetadata
- preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | None = None
- class petitRADTRANS.sbi.dataset.NormalizedObservationDatasetReader
Lightweight reader yielding normalized ObservationBlock batches for training.
- preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata
- _split_cache: dict
- iter_batches(split: DatasetSplit = DatasetSplit.TRAIN, batch_size: int = 32, shuffle: bool = False, seed: int | None = None, parameter_space: str = 'physical', encoder: Any = None) → Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]
Yield mini-batches of normalized observations and matched parameters.
Parameters
- split:
Dataset split to iterate over.
- batch_size:
Number of samples yielded in each batch.
- shuffle:
Whether to shuffle sample order within the requested split.
- seed:
Optional random seed used when shuffle=True.
- parameter_space:
Parameter representation returned in each batch. Supported values are 'physical', 'cube', and 'unconstrained'.
- encoder:
Optional encoder used to convert block lists into dense embedding arrays before batches are yielded.
Returns
- Iterator[PosteriorBatch]
Iterator over batches containing parameters, observations, and small metadata dictionaries describing the batch provenance.
Notes
Splits with fewer than _STREAMING_THRESHOLD rows are cached in RAM after the first read, so repeated per-epoch calls (e.g. for validation) pay only the cost of slicing. Larger splits are streamed from disk one batch at a time to avoid loading the full dataset into memory.
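The cache-or-stream decision described in the Notes can be sketched as follows. CachingBatchReader, read_rows, and the STREAMING_THRESHOLD value of 4 are all hypothetical; the real _STREAMING_THRESHOLD value is not given in this documentation, and real batches are PosteriorBatch objects rather than lists of row values.

```python
from typing import Callable, Dict, Iterator, List

STREAMING_THRESHOLD = 4  # hypothetical value chosen for the demo


class CachingBatchReader:
    """Hypothetical sketch: cache small splits in RAM, stream large ones."""

    def __init__(self, split_sizes: Dict[str, int], read_rows: Callable[[str, int, int], List[int]]):
        self._split_sizes = split_sizes
        self._read_rows = read_rows  # read_rows(split, start, stop) reads rows from "disk"
        self._split_cache: Dict[str, List[int]] = {}

    def iter_batches(self, split: str, batch_size: int) -> Iterator[List[int]]:
        n_rows = self._split_sizes[split]
        if n_rows < STREAMING_THRESHOLD:
            # Small split: read once; later epochs only slice the cached list.
            if split not in self._split_cache:
                self._split_cache[split] = self._read_rows(split, 0, n_rows)
            rows = self._split_cache[split]
            for start in range(0, n_rows, batch_size):
                yield rows[start:start + batch_size]
        else:
            # Large split: fetch each batch from storage on demand and keep nothing.
            for start in range(0, n_rows, batch_size):
                yield self._read_rows(split, start, min(start + batch_size, n_rows))


reads = []


def read_rows(split: str, start: int, stop: int) -> List[int]:
    reads.append((split, start, stop))
    return list(range(start, stop))


reader = CachingBatchReader({"validation": 3, "train": 10}, read_rows)
for _epoch in range(2):
    list(reader.iter_batches("validation", 2))
print(reads)  # [('validation', 0, 3)]
```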
- _iter_batches_cached(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) → Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]
Iterate using the full-split RAM cache (small splits).
- _iter_batches_streaming(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) → Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]
Iterate by streaming rows directly from disk (large splits).
When shuffle=True, a global index permutation is computed in RAM (one integer per simulation row, which costs negligible memory) and used to fetch HDF5 rows in sorted sub-windows of batch_size. This satisfies h5py's requirement that selection indices be increasing while still presenting a shuffled order to the training loop.
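The shuffled streaming strategy can be sketched with the standard library: a seeded permutation defines batch membership, each batch's indices are sorted before the read (as h5py fancy indexing requires), and the rows are then scattered back into permuted order. iter_shuffled_batches and read_sorted are hypothetical names. Whether the library restores the within-batch order is not stated; for a mean loss over a mini-batch it makes no difference.

```python
import random
from typing import Callable, Iterator, List, Optional, Sequence


def iter_shuffled_batches(
    n_rows: int,
    batch_size: int,
    read_sorted: Callable[[Sequence[int]], List[str]],
    seed: Optional[int] = None,
) -> Iterator[List[str]]:
    """Shuffle globally but read every batch with increasing indices (hypothetical sketch)."""
    # One integer per row: cheap to hold in RAM even for large corpora.
    permutation = list(range(n_rows))
    random.Random(seed).shuffle(permutation)
    for start in range(0, n_rows, batch_size):
        window = permutation[start:start + batch_size]
        # Positions within the window, ordered by the row index stored there.
        order = sorted(range(len(window)), key=window.__getitem__)
        sorted_rows = read_sorted([window[i] for i in order])
        # Scatter the rows back so the batch follows the shuffled order.
        batch: List[str] = [""] * len(window)
        for position, row in zip(order, sorted_rows):
            batch[position] = row
        yield batch


data = [f"row{i}" for i in range(10)]


def read_sorted(idx: Sequence[int]) -> List[str]:
    assert list(idx) == sorted(idx), "h5py-style reads need increasing indices"
    return [data[i] for i in idx]


print([len(b) for b in iter_shuffled_batches(10, 4, read_sorted, seed=1)])  # [4, 4, 2]
```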
- static _select_parameters(split_data: dict, parameter_space: str) → Any
Return the parameter array for the requested coordinate space.
- class petitRADTRANS.sbi.dataset.ZarrSimulationDatasetWriter(manifest: SimulationDatasetManifest, root: Any, chunk_size: int = 256)
Bases: SimulationDatasetWriter
Append-only Zarr-backed writer for simulation corpora.
- _manifest
- _root
- _chunk_size = 256
- _split_counts
- _available_fields
- static _get_or_create_group(parent: Any, name: str) → Any
- _write_manifest() → None
- _append_array(group: Any, name: str, values: Any, leading_chunk: int | None = None) → None
- _append_observations(split_group: Any, observations: Mapping[str, Any]) → None
- append(batch: petitRADTRANS.sbi.simulator.SimulationBatch, split: DatasetSplit = DatasetSplit.TRAIN) → None
Persist one simulation batch to the selected split.
- finalize() → SimulationDatasetManifest
Seal the dataset and return its manifest.
- class petitRADTRANS.sbi.dataset.ZarrSimulationDatasetStore(chunk_size: int = 256)
Bases: SimulationDatasetStore
Concrete Zarr-backed store for chunked simulation corpora.
- chunk_size = 256
- static _require_zarr() → None
- create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') → SimulationDatasetWriter
Create a writer for a new simulation dataset.
- open(manifest_or_uri: SimulationDatasetManifest | str) → StoredSimulationDataset
Open a stored dataset for training or evaluation.
- class petitRADTRANS.sbi.dataset.HDF5SimulationDatasetWriter(manifest: SimulationDatasetManifest, file: Any, chunk_size: int = 256)
Bases: SimulationDatasetWriter
Append-only HDF5-backed writer for simulation corpora.
All simulation data is stored in a single .h5 file, which avoids the large number of chunk files produced by the Zarr directory-store backend. Datasets are created with an unlimited first dimension (maxshape=(None, ...)) so that batches can be appended incrementally without pre-allocation.
- _manifest
- _file
- _chunk_size = 256
- _split_counts
- _available_fields
- _write_manifest() → None
- static _get_or_create_group(parent: Any, name: str) → Any
- _append_array(group: Any, name: str, values: Any) → None
- _append_observations(split_group: Any, observations: Mapping[str, Any]) → None
- append(batch: petitRADTRANS.sbi.simulator.SimulationBatch, split: DatasetSplit = DatasetSplit.TRAIN) → None
Persist one simulation batch to the selected split.
- finalize() → SimulationDatasetManifest
Seal the dataset and return its manifest.
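An unlimited first dimension enables an append pattern of "grow the first axis, then write into the new tail". In h5py this corresponds to Dataset.resize(new_length, axis=0) followed by a slice assignment; the pure-Python sketch below reproduces the same two steps without requiring h5py. GrowableColumn and append_rows are hypothetical and only illustrate the idea, not the writer's actual _append_array.

```python
from typing import Any, Dict, List


class GrowableColumn:
    """Pure-Python stand-in for a dataset created with an unlimited first axis."""

    def __init__(self) -> None:
        self._rows: List[Any] = []

    def __len__(self) -> int:
        return len(self._rows)

    def resize(self, new_length: int) -> None:
        # Grow the first axis; new slots hold placeholders until they are written.
        self._rows.extend([None] * (new_length - len(self._rows)))

    def write(self, start: int, values: List[Any]) -> None:
        self._rows[start:start + len(values)] = values

    def to_list(self) -> List[Any]:
        return list(self._rows)


def append_rows(columns: Dict[str, GrowableColumn], batch: Dict[str, List[Any]]) -> int:
    """Resize-then-write every column of one batch; return the new row count (hypothetical)."""
    new_length = 0
    for name, values in batch.items():
        column = columns.setdefault(name, GrowableColumn())
        offset = len(column)
        column.resize(offset + len(values))
        column.write(offset, values)
        new_length = len(column)
    return new_length


split_group: Dict[str, GrowableColumn] = {}
append_rows(split_group, {"parameters": [1, 2, 3]})
print(append_rows(split_group, {"parameters": [4, 5]}))  # 5
```

Because no final size is needed up front, a run can be stopped and later resumed simply by appending more batches.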
- class petitRADTRANS.sbi.dataset.HDF5StoredSimulationDataset
Read-only handle for an HDF5-backed simulation dataset.
- manifest: SimulationDatasetManifest
- storage_uri: str
- _file_path: str
- iter_split_chunks(split: DatasetSplit = DatasetSplit.TRAIN, chunk_size: int = 256, indices: Any = None, observation_fields: frozenset[str] | None = None) → Iterator[dict[str, Any]]
Yield row slices of one split without loading the whole split into RAM.
Parameters
- split:
Dataset split to iterate.
- chunk_size:
Number of rows per yielded chunk.
- indices:
Optional integer array of row indices to read. When supplied, only those rows are fetched, in ascending order. When None, every row in the split is visited sequentially.
- observation_fields:
Optional set of observation field names to load. When None, every dataset field is loaded. Pass _OBSERVATION_TRAINING_FIELDS to skip large fields that training does not use, such as covariance.
Yields
- dict[str, Any]
Same structure as read_split(), but covering only chunk_size rows at a time. The HDF5 file is opened once and kept open for the duration of the iteration.
- read_split(split: DatasetSplit = DatasetSplit.TRAIN) → dict[str, Any]
- class petitRADTRANS.sbi.dataset.HDF5SimulationDatasetStore(chunk_size: int = 256)
Bases: SimulationDatasetStore
HDF5-backed store for simulation corpora.
Stores all simulation data for a corpus in a single .h5 file, so the file count stays at one regardless of the number of simulations or splits. Requires h5py (already present in the jaxprt environment).
- chunk_size = 256
- static _require_h5py() → None
- create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') → HDF5SimulationDatasetWriter
Create a writer for a new simulation dataset.
- open(manifest_or_uri: SimulationDatasetManifest | str) → HDF5StoredSimulationDataset
Open a stored dataset for training or evaluation.
- petitRADTRANS.sbi.dataset.generate_simulation_dataset(task: petitRADTRANS.sbi.task.SBITask, storage_uri: str, n_simulations: int, chunk_size: int = 256, split_counts: Mapping[DatasetSplit | str, int] | None = None, split_fractions: Mapping[DatasetSplit | str, float] | None = None, split_policy: SplitSamplingPolicy | str = SplitSamplingPolicy.SEQUENTIAL, split_seed: int | None = None, simulator: Any = None, include_preprocessing_metadata: bool = True, dataset_version: str = '0.1.0', resume: bool = False, backend: str = 'hdf5', data_parallel: bool | None = None, store_covariance: bool = False) → GeneratedSimulationDataset
Generate and persist a simulation corpus in one call.
Parameters
- backend:
Storage backend to use. 'hdf5' (default) writes all data into a single .h5 file, keeping the file count at one regardless of the number of simulations. 'zarr' uses the Zarr directory store, which produces one file per compressed chunk.
- data_parallel:
When True (or None with multiple JAX devices available), the vmapped RT kernel is distributed across devices using jax.pmap. The effective per-iteration batch size is scaled automatically by the device count, so that each device processes the configured simulation_config.batch_size samples.
- store_covariance:
When True, the full covariance matrix of each simulated spectrum is written to disk. When False (default), only the covariance diagonal is stored under the covariance field, to reduce storage pressure.
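The arithmetic behind store_covariance and data_parallel can be stated directly. The helpers below are hypothetical illustrations of the documented behaviour, not library functions: keeping only the diagonal stores n values per spectrum instead of n², and with data parallelism the samples per generation step scale with the device count. For a spectrum with 1,000 wavelength bins, that is 1,000 stored covariance values instead of 1,000,000.

```python
from typing import List, Optional, Sequence


def covariance_diagonal(cov: Sequence[Sequence[float]]) -> List[float]:
    """Variances only, i.e. what store_covariance=False is documented to keep."""
    return [row[i] for i, row in enumerate(cov)]


def stored_covariance_values(n_wavelengths: int, store_covariance: bool) -> int:
    """Covariance values written per spectrum: n**2 for the full matrix, n for the diagonal."""
    return n_wavelengths ** 2 if store_covariance else n_wavelengths


def effective_batch_size(batch_size: int, n_devices: int, data_parallel: Optional[bool]) -> int:
    """Samples per generation step; None means 'parallel if more than one device' (hypothetical)."""
    use_parallel = data_parallel if data_parallel is not None else n_devices > 1
    # Each device keeps the configured batch size, so the step total scales with device count.
    return batch_size * n_devices if use_parallel else batch_size


print(covariance_diagonal([[4.0, 0.1], [0.1, 9.0]]))  # [4.0, 9.0]
print(stored_covariance_values(1000, True), stored_covariance_values(1000, False))  # 1000000 1000
print(effective_batch_size(64, 4, None))  # 256
```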