petitRADTRANS.sbi.dataset
=========================

.. py:module:: petitRADTRANS.sbi.dataset

.. autoapi-nested-parse::

   Dataset storage interfaces for SBI simulation corpora.



Classes
-------

.. autoapisummary::

   petitRADTRANS.sbi.dataset.DatasetSplit
   petitRADTRANS.sbi.dataset.SplitSamplingPolicy
   petitRADTRANS.sbi.dataset.SimulationDatasetManifest
   petitRADTRANS.sbi.dataset.SimulationDatasetWriter
   petitRADTRANS.sbi.dataset.SimulationDatasetStore
   petitRADTRANS.sbi.dataset.StoredSimulationDataset
   petitRADTRANS.sbi.dataset.GeneratedSimulationDataset
   petitRADTRANS.sbi.dataset.NormalizedObservationDatasetReader
   petitRADTRANS.sbi.dataset.ZarrSimulationDatasetWriter
   petitRADTRANS.sbi.dataset.ZarrSimulationDatasetStore
   petitRADTRANS.sbi.dataset.HDF5SimulationDatasetWriter
   petitRADTRANS.sbi.dataset.HDF5StoredSimulationDataset
   petitRADTRANS.sbi.dataset.HDF5SimulationDatasetStore


Functions
---------

.. autoapisummary::

   petitRADTRANS.sbi.dataset.generate_simulation_dataset


Module Contents
---------------

.. py:class:: DatasetSplit

   Bases: :py:obj:`str`, :py:obj:`enum.Enum`


   Named dataset partitions used during training and evaluation.
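
   ``DatasetSplit`` mixes ``str`` into :py:class:`enum.Enum`, so its members compare equal to their plain string values and can be looked up from strings. The sketch below shows this standard mixin behaviour. It uses a local stand-in enum with the same four values and does not import petitRADTRANS.

   .. code-block:: python

      from enum import Enum


      class DatasetSplit(str, Enum):
          """Local stand-in mirroring the documented member values."""

          TRAIN = "train"
          VALIDATION = "validation"
          TEST = "test"
          BENCHMARK = "benchmark"


      # Look up a member from its string value, e.g. when parsing a config file.
      split = DatasetSplit("validation")
      assert split is DatasetSplit.VALIDATION

      # The str mixin makes members compare equal to plain strings.
      assert DatasetSplit.TRAIN == "train"

      # .value gives the bare string, e.g. for use as a group or key name.
      group_name = DatasetSplit.TEST.value

   Using ``.value`` avoids depending on how ``format()`` and f-strings render mixed-in enum members, which changed in Python 3.11.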


   .. py:attribute:: TRAIN
      :value: 'train'



   .. py:attribute:: VALIDATION
      :value: 'validation'



   .. py:attribute:: TEST
      :value: 'test'



   .. py:attribute:: BENCHMARK
      :value: 'benchmark'



.. py:class:: SplitSamplingPolicy

   Bases: :py:obj:`str`, :py:obj:`enum.Enum`


   Deterministic policies for assigning generated samples to dataset splits.
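
   This reference does not spell out how each policy orders samples. The sketch below encodes one plausible reading, offered only as an assumption: ``sequential`` fills each split in turn, ``round_robin`` cycles through the splits that still have room, and ``shuffled`` applies a seeded permutation to the sequential assignment. ``assign_splits`` is a hypothetical helper and is not part of this module.

   .. code-block:: python

      import random


      def assign_splits(split_counts, policy="sequential", seed=None):
          """Return one split label per sample index (hypothetical helper)."""
          names = list(split_counts)
          if policy in ("sequential", "shuffled"):
              # Fill each split completely before moving on to the next one.
              labels = [name for name in names for _ in range(split_counts[name])]
              if policy == "shuffled":
                  # A seeded generator keeps the assignment reproducible.
                  random.Random(seed).shuffle(labels)
              return labels
          if policy == "round_robin":
              remaining = dict(split_counts)
              labels = []
              while any(remaining.values()):
                  # Visit splits in order, skipping any that are already full.
                  for name in names:
                      if remaining[name] > 0:
                          labels.append(name)
                          remaining[name] -= 1
              return labels
          raise ValueError(f"unknown split policy: {policy!r}")


      counts = {"train": 4, "validation": 1, "test": 1}
      print(assign_splits(counts, "sequential"))
      # ['train', 'train', 'train', 'train', 'validation', 'test']
      print(assign_splits(counts, "round_robin"))
      # ['train', 'validation', 'test', 'train', 'train', 'train']

   Under every policy each split receives exactly its requested count. Only the order of assignment differs.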


   .. py:attribute:: SEQUENTIAL
      :value: 'sequential'



   .. py:attribute:: ROUND_ROBIN
      :value: 'round_robin'



   .. py:attribute:: SHUFFLED
      :value: 'shuffled'



.. py:class:: SimulationDatasetManifest

   Describe a stored SBI simulation corpus.


   .. py:attribute:: task_name
      :type:  str


   .. py:attribute:: parameter_names
      :type:  tuple[str, Ellipsis]


   .. py:attribute:: n_simulations
      :type:  int


   .. py:attribute:: splits
      :type:  Mapping[str, int]


   .. py:attribute:: storage_uri
      :type:  str | None
      :value: None



   .. py:attribute:: task_fingerprint
      :type:  str | None
      :value: None



   .. py:attribute:: observation_schema
      :type:  Mapping[str, Any]


   .. py:attribute:: available_fields
      :type:  tuple[str, Ellipsis]
      :value: ()



   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


   .. py:method:: from_task(task: petitRADTRANS.sbi.task.SBITask, storage_uri: str, n_simulations: int = 0, splits: Mapping[str, int] | None = None, metadata: Mapping[str, Any] | None = None) -> SimulationDatasetManifest
      :classmethod:



   .. py:method:: to_payload() -> dict[str, Any]


   .. py:method:: from_payload(payload: Mapping[str, Any]) -> SimulationDatasetManifest
      :classmethod:
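
   ``to_payload`` and ``from_payload`` suggest a round trip through a JSON-serializable dictionary. The sketch below shows that pattern on a hypothetical, reduced ``ManifestSketch`` dataclass that carries only a few of the documented fields. It is not the actual implementation, and the parameter names ``T_eq`` and ``log_g`` are placeholders. JSON has no tuple type, so the loader converts lists back to tuples to make the round trip compare equal.

   .. code-block:: python

      from __future__ import annotations

      import json
      from dataclasses import asdict, dataclass, field
      from typing import Any, Mapping


      @dataclass(frozen=True)
      class ManifestSketch:
          """Hypothetical subset of the documented manifest fields."""

          task_name: str
          parameter_names: tuple[str, ...]
          n_simulations: int
          splits: Mapping[str, int] = field(default_factory=dict)
          storage_uri: str | None = None

          def to_payload(self) -> dict[str, Any]:
              payload = asdict(self)
              # Lists and plain dicts serialize cleanly to JSON.
              payload["parameter_names"] = list(self.parameter_names)
              payload["splits"] = dict(self.splits)
              return payload

          @classmethod
          def from_payload(cls, payload: Mapping[str, Any]) -> ManifestSketch:
              data = dict(payload)
              # JSON turns tuples into lists, so restore the tuple type here.
              data["parameter_names"] = tuple(data["parameter_names"])
              return cls(**data)


      manifest = ManifestSketch("toy", ("T_eq", "log_g"), 6, {"train": 4, "test": 2}, "toy.h5")
      text = json.dumps(manifest.to_payload())
      restored = ManifestSketch.from_payload(json.loads(text))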



.. py:class:: SimulationDatasetWriter

   Bases: :py:obj:`abc.ABC`


   Append-only writer for chunked simulation datasets.


   .. py:method:: append(batch: petitRADTRANS.sbi.simulator.SimulationBatch, split: DatasetSplit = DatasetSplit.TRAIN) -> None
      :abstractmethod:


      Persist one simulation batch to the selected split.



   .. py:method:: finalize() -> SimulationDatasetManifest
      :abstractmethod:


      Seal the dataset and return its manifest.
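
   The writer contract above is small. ``append`` may be called repeatedly with batches for any split, and ``finalize`` seals the dataset and reports what was written. The sketch below is a hypothetical in-memory implementation of that life cycle. It uses plain dictionaries in place of ``SimulationBatch`` and ``SimulationDatasetManifest``, and its choice to reject appends after ``finalize`` is an assumption drawn from the word "seal".

   .. code-block:: python

      from abc import ABC, abstractmethod


      class WriterSketch(ABC):
          """Hypothetical mirror of the abstract writer interface."""

          @abstractmethod
          def append(self, batch, split="train"):
              """Persist one batch to the selected split."""

          @abstractmethod
          def finalize(self):
              """Seal the dataset and return a manifest-like summary."""


      class InMemoryWriter(WriterSketch):
          def __init__(self):
              self._rows = {}
              self._sealed = False

          def append(self, batch, split="train"):
              if self._sealed:
                  raise RuntimeError("dataset already finalized")
              # Each batch is assumed to be a dict holding a "parameters" row list.
              self._rows.setdefault(split, []).extend(batch["parameters"])

          def finalize(self):
              self._sealed = True
              splits = {name: len(rows) for name, rows in self._rows.items()}
              return {"n_simulations": sum(splits.values()), "splits": splits}


      writer = InMemoryWriter()
      writer.append({"parameters": [[0.1], [0.2], [0.3]]})
      writer.append({"parameters": [[0.4]]}, split="validation")
      summary = writer.finalize()
      # {'n_simulations': 4, 'splits': {'train': 3, 'validation': 1}}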



.. py:class:: SimulationDatasetStore

   Bases: :py:obj:`abc.ABC`


   Backend-independent interface for reading and writing simulation data.


   .. py:method:: create_writer(manifest: SimulationDatasetManifest) -> SimulationDatasetWriter
      :abstractmethod:


      Create a writer for a new simulation dataset.



   .. py:method:: open(manifest_or_uri: SimulationDatasetManifest | str) -> Any
      :abstractmethod:


      Open a stored dataset for training or evaluation.



.. py:class:: StoredSimulationDataset

   Read-only handle for a stored chunked simulation dataset.


   .. py:attribute:: manifest
      :type:  SimulationDatasetManifest


   .. py:attribute:: storage_uri
      :type:  str


   .. py:attribute:: root
      :type:  Any


   .. py:method:: iter_split_chunks(split: DatasetSplit = DatasetSplit.TRAIN, chunk_size: int = 256, indices: Any = None, observation_fields: frozenset[str] | None = None) -> Iterator[dict[str, Any]]

      Yield row-slices of one split without loading it all into RAM.

      Parameters
      ----------
      split:
          Dataset split to iterate.
      chunk_size:
          Number of rows per yielded chunk.
      indices:
          Optional integer array of row indices to read.  When supplied, only
          those rows are fetched, in ascending order.  When ``None``, every row
          in the split is visited sequentially.
      observation_fields:
          Optional set of observation field names to load.  When ``None``,
          every field is loaded.  Pass ``_OBSERVATION_TRAINING_FIELDS`` to
          skip large fields that training does not use, such as
          ``covariance``.

      Yields
      ------
      dict[str, Any]
          Same structure as :meth:`read_split` but covering only
          ``chunk_size`` rows per iteration.
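
      The following minimal sketch illustrates the chunking behaviour described above. An in-memory list stands in for one on-disk split, and the helper ``iter_row_chunks`` is hypothetical. It shows the two access modes: contiguous row slices when ``indices`` is ``None``, and sorted index subsets otherwise.

      .. code-block:: python

         def iter_row_chunks(rows, chunk_size=256, indices=None):
             """Yield lists of at most ``chunk_size`` rows (hypothetical helper)."""
             if indices is None:
                 # Sequential mode: contiguous slices that cover every row once.
                 for start in range(0, len(rows), chunk_size):
                     yield rows[start:start + chunk_size]
             else:
                 # Subset mode: visit the requested rows in ascending order.
                 ordered = sorted(indices)
                 for start in range(0, len(ordered), chunk_size):
                     yield [rows[i] for i in ordered[start:start + chunk_size]]


         rows = list(range(10))
         print(list(iter_row_chunks(rows, chunk_size=4)))
         # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
         print(list(iter_row_chunks(rows, chunk_size=2, indices=[7, 1, 4])))
         # [[1, 4], [7]]

      If an on-disk array replaced ``rows``, only one chunk would be read into memory per iteration.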



   .. py:method:: read_split(split: DatasetSplit = DatasetSplit.TRAIN) -> dict[str, Any]


.. py:class:: GeneratedSimulationDataset

   Return value for one-call simulation corpus generation.


   .. py:attribute:: manifest
      :type:  SimulationDatasetManifest


   .. py:attribute:: artifact_metadata
      :type:  petitRADTRANS.sbi.artifacts.ArtifactMetadata


   .. py:attribute:: preprocessing_metadata
      :type:  petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | None
      :value: None



.. py:class:: NormalizedObservationDatasetReader

   Lightweight reader yielding normalized ``ObservationBlock`` batches for training.


   .. py:attribute:: dataset
      :type:  StoredSimulationDataset | HDF5StoredSimulationDataset


   .. py:attribute:: preprocessing_metadata
      :type:  petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata


   .. py:attribute:: _split_cache
      :type:  dict


   .. py:method:: iter_batches(split: DatasetSplit = DatasetSplit.TRAIN, batch_size: int = 32, shuffle: bool = False, seed: int | None = None, parameter_space: str = 'physical', encoder: Any = None) -> Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]

      Yield mini-batches of normalized observations and matched parameters.

      Parameters
      ----------
      split:
          Dataset split to iterate over.
      batch_size:
          Number of samples yielded in each batch.
      shuffle:
          Whether to shuffle sample order within the requested split.
      seed:
          Optional random seed used when ``shuffle=True``.
      parameter_space:
          Parameter representation returned in each batch. Supported values
          are ``'physical'``, ``'cube'``, and ``'unconstrained'``.
      encoder:
          Optional encoder used to convert block lists into dense embedding
          arrays before batches are yielded.

      Returns
      -------
      Iterator[PosteriorBatch]
          Iterator over batches containing parameters, observations, and small
          metadata dictionaries describing the batch provenance.

      Notes
      -----
      Splits with fewer than ``_STREAMING_THRESHOLD`` rows are cached in RAM
      after the first read so that repeated epoch calls (e.g. validation) pay
      only slicing cost.  Larger splits are streamed directly from disk one
      batch at a time to avoid loading the full dataset into memory.
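
      The cache-or-stream dispatch described in these notes can be sketched as follows. ``STREAMING_THRESHOLD``, ``BatchReaderSketch`` and the ``read_rows`` callable are hypothetical stand-ins, and the real threshold value is not documented here. The read counter shows that a small split is read from storage only once across epochs, while a large split costs one read per batch every epoch.

      .. code-block:: python

         STREAMING_THRESHOLD = 1_000  # hypothetical row count, not the library's value


         class BatchReaderSketch:
             """Hypothetical reader: caches small splits, streams large ones."""

             def __init__(self, read_rows, split_sizes):
                 self._read_rows = read_rows  # read_rows(split, start, stop) -> list
                 self._split_sizes = split_sizes
                 self._split_cache = {}
                 self.reads = 0  # number of storage reads, for illustration

             def _read(self, split, start, stop):
                 self.reads += 1
                 return self._read_rows(split, start, stop)

             def iter_batches(self, split, batch_size):
                 n_rows = self._split_sizes[split]
                 if n_rows < STREAMING_THRESHOLD:
                     # Small split: one full read, later epochs only slice the cache.
                     if split not in self._split_cache:
                         self._split_cache[split] = self._read(split, 0, n_rows)
                     rows = self._split_cache[split]
                     for start in range(0, n_rows, batch_size):
                         yield rows[start:start + batch_size]
                 else:
                     # Large split: one storage read per batch, nothing kept afterwards.
                     for start in range(0, n_rows, batch_size):
                         yield self._read(split, start, min(start + batch_size, n_rows))


         sizes = {"validation": 10, "train": 5_000}
         reader = BatchReaderSketch(lambda split, a, b: list(range(a, b)), sizes)
         for _ in range(3):  # three validation epochs
             list(reader.iter_batches("validation", batch_size=4))
         print(reader.reads)  # 1
         list(reader.iter_batches("train", batch_size=1_000))
         print(reader.reads)  # 6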



   .. py:method:: _iter_batches_cached(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) -> Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]

      Iterate using the full-split RAM cache (small splits).



   .. py:method:: _iter_batches_streaming(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) -> Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]

      Iterate by streaming rows directly from disk (large splits).

      When ``shuffle=True``, a global index permutation is computed in RAM
      (one integer per simulation row, so the memory cost is negligible) and
      used to fetch HDF5 rows in sorted sub-windows of ``batch_size``. This
      satisfies h5py's requirement that fancy-indexing selections be given in
      increasing order while still presenting shuffled order to the training
      loop.
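
      The following sketch shows the sorted sub-window technique. It assumes a reader that, like h5py fancy indexing, only accepts strictly increasing index lists. ``read_increasing`` and ``iter_shuffled_batches`` are hypothetical names. Each window of the permutation is sorted for the read, and the rows are then mapped back into permuted order. The method description above does not say whether the within-batch order is restored this way, but batch composition is shuffled either way.

      .. code-block:: python

         import random


         def read_increasing(data, indices):
             """Stand-in for an h5py-style read that needs increasing indices."""
             if any(b <= a for a, b in zip(indices, indices[1:])):
                 raise ValueError("indices must be strictly increasing")
             return [data[i] for i in indices]


         def iter_shuffled_batches(data, batch_size, seed=None):
             # One integer per row: cheap to hold in memory even for large splits.
             order = list(range(len(data)))
             random.Random(seed).shuffle(order)
             for start in range(0, len(order), batch_size):
                 window = order[start:start + batch_size]
                 sorted_window = sorted(window)
                 rows = read_increasing(data, sorted_window)
                 # Map each index back to its row to restore the permuted order.
                 by_index = dict(zip(sorted_window, rows))
                 yield [by_index[i] for i in window]


         data = [f"row{i}" for i in range(10)]
         batches = list(iter_shuffled_batches(data, batch_size=4, seed=0))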



   .. py:method:: _select_parameters(split_data: dict, parameter_space: str) -> Any
      :staticmethod:


      Return the parameter array for the requested coordinate space.
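
      A hypothetical sketch of this selection step maps the three documented ``parameter_space`` values to keys in the split dictionary. The key names (``parameters_physical`` and so on) and the numbers are placeholders for illustration, not the dataset's actual field names or values.

      .. code-block:: python

         _PARAMETER_KEYS = {
             "physical": "parameters_physical",  # hypothetical key names
             "cube": "parameters_cube",
             "unconstrained": "parameters_unconstrained",
         }


         def select_parameters(split_data, parameter_space):
             try:
                 key = _PARAMETER_KEYS[parameter_space]
             except KeyError:
                 supported = ", ".join(repr(name) for name in _PARAMETER_KEYS)
                 raise ValueError(
                     f"unknown parameter_space {parameter_space!r}; expected one of {supported}"
                 ) from None
             return split_data[key]


         split_data = {
             "parameters_physical": [[1200.0]],
             "parameters_cube": [[0.4]],
             "parameters_unconstrained": [[-0.4]],
         }

      Failing fast on an unknown name gives a clearer error than a bare ``KeyError`` raised by the split dictionary.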



.. py:class:: ZarrSimulationDatasetWriter(manifest: SimulationDatasetManifest, root: Any, chunk_size: int = 256)

   Bases: :py:obj:`SimulationDatasetWriter`


   Append-only Zarr-backed writer for simulation corpora.


   .. py:attribute:: _manifest


   .. py:attribute:: _root


   .. py:attribute:: _chunk_size
      :value: 256



   .. py:attribute:: _split_counts


   .. py:attribute:: _available_fields


   .. py:method:: _get_or_create_group(parent: Any, name: str) -> Any
      :staticmethod:



   .. py:method:: _write_manifest() -> None


   .. py:method:: _append_array(group: Any, name: str, values: Any, leading_chunk: int | None = None) -> None


   .. py:method:: _append_observations(split_group: Any, observations: Mapping[str, Any]) -> None


   .. py:method:: append(batch: petitRADTRANS.sbi.simulator.SimulationBatch, split: DatasetSplit = DatasetSplit.TRAIN) -> None

      Persist one simulation batch to the selected split.



   .. py:method:: finalize() -> SimulationDatasetManifest

      Seal the dataset and return its manifest.



.. py:class:: ZarrSimulationDatasetStore(chunk_size: int = 256)

   Bases: :py:obj:`SimulationDatasetStore`


   Concrete Zarr-backed store for chunked simulation corpora.


   .. py:attribute:: chunk_size
      :value: 256



   .. py:method:: _require_zarr() -> None
      :staticmethod:



   .. py:method:: create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') -> SimulationDatasetWriter

      Create a writer for a new simulation dataset.



   .. py:method:: open(manifest_or_uri: SimulationDatasetManifest | str) -> StoredSimulationDataset

      Open a stored dataset for training or evaluation.



.. py:class:: HDF5SimulationDatasetWriter(manifest: SimulationDatasetManifest, file: Any, chunk_size: int = 256)

   Bases: :py:obj:`SimulationDatasetWriter`


   Append-only HDF5-backed writer for simulation corpora.

   All simulation data is stored in a single ``.h5`` file, which avoids the
   large number of chunk files produced by the Zarr directory-store backend.
   Datasets are created with an unlimited first dimension
   (``maxshape=(None, ...)``) so that batches can be appended incrementally
   without pre-allocation.
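
   The incremental-append pattern can be sketched with the public ``h5py`` API. The dataset is created on first use with ``maxshape=(None, ...)``; each later batch resizes it along axis 0 and writes into the new tail. The helper ``append_rows`` is hypothetical. The in-memory ``core`` driver is used only so the example leaves no file behind.

   .. code-block:: python

      import h5py
      import numpy as np


      def append_rows(group, name, values):
          """Append ``values`` along axis 0 of ``group[name]`` (hypothetical helper)."""
          values = np.asarray(values)
          if name not in group:
              # Unlimited leading axis; the remaining axes are fixed by the first batch.
              group.create_dataset(
                  name, data=values, maxshape=(None,) + values.shape[1:], chunks=True
              )
              return
          dset = group[name]
          start = dset.shape[0]
          dset.resize(start + values.shape[0], axis=0)
          dset[start:] = values


      # driver="core" with backing_store=False keeps the file purely in memory.
      with h5py.File("sketch.h5", "w", driver="core", backing_store=False) as f:
          train = f.require_group("train")
          append_rows(train, "parameters", np.zeros((4, 3)))
          append_rows(train, "parameters", np.ones((2, 3)))
          stored = train["parameters"][...]

   With ``chunks=True``, h5py chooses the chunk shape. A fixed row count, such as the writer's ``chunk_size``, could be passed instead.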


   .. py:attribute:: _manifest


   .. py:attribute:: _file


   .. py:attribute:: _chunk_size
      :value: 256



   .. py:attribute:: _split_counts


   .. py:attribute:: _available_fields


   .. py:method:: _write_manifest() -> None


   .. py:method:: _get_or_create_group(parent: Any, name: str) -> Any
      :staticmethod:



   .. py:method:: _append_array(group: Any, name: str, values: Any) -> None


   .. py:method:: _append_observations(split_group: Any, observations: Mapping[str, Any]) -> None


   .. py:method:: append(batch: petitRADTRANS.sbi.simulator.SimulationBatch, split: DatasetSplit = DatasetSplit.TRAIN) -> None

      Persist one simulation batch to the selected split.



   .. py:method:: finalize() -> SimulationDatasetManifest

      Seal the dataset and return its manifest.



.. py:class:: HDF5StoredSimulationDataset

   Read-only handle for an HDF5-backed simulation dataset.


   .. py:attribute:: manifest
      :type:  SimulationDatasetManifest


   .. py:attribute:: storage_uri
      :type:  str


   .. py:attribute:: _file_path
      :type:  str


   .. py:method:: iter_split_chunks(split: DatasetSplit = DatasetSplit.TRAIN, chunk_size: int = 256, indices: Any = None, observation_fields: frozenset[str] | None = None) -> Iterator[dict[str, Any]]

      Yield row-slices of one split without loading it all into RAM.

      Parameters
      ----------
      split:
          Dataset split to iterate.
      chunk_size:
          Number of rows per yielded chunk.
      indices:
          Optional integer array of row indices to read.  When supplied, only
          those rows are fetched, in ascending order.  When ``None``, every row
          in the split is visited sequentially.
      observation_fields:
          Optional set of observation field names to load.  When ``None``,
          every field is loaded.  Pass ``_OBSERVATION_TRAINING_FIELDS`` to
          skip large fields that training does not use, such as
          ``covariance``.

      Yields
      ------
      dict[str, Any]
          Same structure as :meth:`read_split` but covering only
          ``chunk_size`` rows at a time.  The HDF5 file is opened once and
          kept open for the duration of the iteration.



   .. py:method:: read_split(split: DatasetSplit = DatasetSplit.TRAIN) -> dict[str, Any]


.. py:class:: HDF5SimulationDatasetStore(chunk_size: int = 256)

   Bases: :py:obj:`SimulationDatasetStore`


   HDF5-backed store for simulation corpora.

   Stores all simulation data for a corpus in a single ``.h5`` file, so the
   corpus occupies exactly one file regardless of the number of simulations
   or splits. Requires ``h5py``, which is already present in the ``jaxprt``
   environment.


   .. py:attribute:: chunk_size
      :value: 256



   .. py:method:: _require_h5py() -> None
      :staticmethod:



   .. py:method:: create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') -> HDF5SimulationDatasetWriter

      Create a writer for a new simulation dataset.



   .. py:method:: open(manifest_or_uri: SimulationDatasetManifest | str) -> HDF5StoredSimulationDataset

      Open a stored dataset for training or evaluation.



.. py:function:: generate_simulation_dataset(task: petitRADTRANS.sbi.task.SBITask, storage_uri: str, n_simulations: int, chunk_size: int = 256, split_counts: Mapping[DatasetSplit | str, int] | None = None, split_fractions: Mapping[DatasetSplit | str, float] | None = None, split_policy: SplitSamplingPolicy | str = SplitSamplingPolicy.SEQUENTIAL, split_seed: int | None = None, simulator: Any = None, include_preprocessing_metadata: bool = True, dataset_version: str = '0.1.0', resume: bool = False, backend: str = 'hdf5', data_parallel: bool | None = None, store_covariance: bool = False) -> GeneratedSimulationDataset

   Generate and persist a simulation corpus in one call.

   Parameters
   ----------
   backend:
       Storage backend to use. ``'hdf5'`` (default) writes all data into a
       single ``.h5`` file, so the file count stays at 1 regardless of the
       number of simulations. ``'zarr'`` uses the Zarr directory store,
       which produces one file per compressed chunk.
   data_parallel:
       When ``True``, or when ``None`` and multiple JAX devices are available,
       the vmapped radiative-transfer (RT) kernel is distributed across devices
       using ``jax.pmap``. The effective per-iteration batch size is then
       scaled by the device count so that each device processes the configured
       ``simulation_config.batch_size`` samples.
   store_covariance:
       When ``True`` the full covariance matrix for each simulated spectrum is
       written to disk.  When ``False`` (default) only the covariance diagonal
       is stored under the ``covariance`` field to reduce storage pressure.
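
   The ``data_parallel`` and ``store_covariance`` descriptions imply some simple arithmetic, sketched below. The helper functions are hypothetical. In practice the device count would come from JAX (for example ``jax.local_device_count()``). The storage estimate assumes ``float64`` values (8 bytes each), which is not stated above, and the 2000-bin spectrum is an arbitrary example size.

   .. code-block:: python

      def resolve_data_parallel(data_parallel, n_devices):
          """``None`` means: parallelize only when more than one device exists."""
          if data_parallel is None:
              return n_devices > 1
          return bool(data_parallel)


      def effective_batch_size(batch_size, n_devices, data_parallel=None):
          # Each device processes batch_size samples, so the total scales with devices.
          if resolve_data_parallel(data_parallel, n_devices):
              return batch_size * n_devices
          return batch_size


      def covariance_bytes(n_bins, full, itemsize=8):
          """Bytes per spectrum for the full matrix or only its diagonal."""
          return itemsize * (n_bins * n_bins if full else n_bins)


      print(effective_batch_size(64, n_devices=4))  # 256
      print(effective_batch_size(64, n_devices=1))  # 64
      print(covariance_bytes(2000, full=True))   # 32000000 (32 MB)
      print(covariance_bytes(2000, full=False))  # 16000 (16 kB)

   The full matrix is ``n_bins`` times larger than its diagonal. The default of storing only the diagonal therefore matters more as spectral resolution grows.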


