petitRADTRANS.sbi
=================

.. py:module:: petitRADTRANS.sbi

.. autoapi-nested-parse::

   Simulation-based inference interfaces for petitRADTRANS.

   The modules in :mod:`petitRADTRANS.sbi` provide a production-oriented
   architecture for amortized inference workflows built on top of the existing
   retrieval runtime. The package is intentionally lightweight at this stage: it
   defines the core task, simulation, dataset, posterior, and benchmarking
   interfaces without committing to one training backend or storage engine.



Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/petitRADTRANS/sbi/artifacts/index
   /autoapi/petitRADTRANS/sbi/benchmark/index
   /autoapi/petitRADTRANS/sbi/calibration/index
   /autoapi/petitRADTRANS/sbi/compatibility/index
   /autoapi/petitRADTRANS/sbi/dataset/index
   /autoapi/petitRADTRANS/sbi/encoders/index
   /autoapi/petitRADTRANS/sbi/estimator_registry/index
   /autoapi/petitRADTRANS/sbi/flow_matching_posterior/index
   /autoapi/petitRADTRANS/sbi/flow_posterior/index
   /autoapi/petitRADTRANS/sbi/flows/index
   /autoapi/petitRADTRANS/sbi/inference/index
   /autoapi/petitRADTRANS/sbi/observation/index
   /autoapi/petitRADTRANS/sbi/plotting/index
   /autoapi/petitRADTRANS/sbi/posterior/index
   /autoapi/petitRADTRANS/sbi/posterior_base/index
   /autoapi/petitRADTRANS/sbi/preprocessing/index
   /autoapi/petitRADTRANS/sbi/simulator/index
   /autoapi/petitRADTRANS/sbi/task/index
   /autoapi/petitRADTRANS/sbi/training/index


Exceptions
----------

.. autoapisummary::

   petitRADTRANS.sbi.TaskCompatibilityError


Classes
-------

.. autoapisummary::

   petitRADTRANS.sbi.BenchmarkComparison
   petitRADTRANS.sbi.BenchmarkMetrics
   petitRADTRANS.sbi.RetrievalBenchmarkCase
   petitRADTRANS.sbi.RetrievalBenchmarkSuite
   petitRADTRANS.sbi.LocalSensitivityPointReport
   petitRADTRANS.sbi.LocalSensitivityReport
   petitRADTRANS.sbi.PosteriorPredictiveReport
   petitRADTRANS.sbi.SimulationBasedCalibrationReport
   petitRADTRANS.sbi.HDF5SimulationDatasetStore
   petitRADTRANS.sbi.SimulationDatasetStore
   petitRADTRANS.sbi.StoredSimulationDataset
   petitRADTRANS.sbi.ZarrSimulationDatasetStore
   petitRADTRANS.sbi.DatasetSplit
   petitRADTRANS.sbi.NormalizedObservationDatasetReader
   petitRADTRANS.sbi.HierarchicalObservationEncoder
   petitRADTRANS.sbi.PhotometryPointEncoder
   petitRADTRANS.sbi.SpectralConv1DEncoder
   petitRADTRANS.sbi.SpectralPatchEncoder
   petitRADTRANS.sbi.AmortizedRetrieval
   petitRADTRANS.sbi.AmortizedRetrievalResult
   petitRADTRANS.sbi.OODDiagnostic
   petitRADTRANS.sbi.ObservationBlock
   petitRADTRANS.sbi.ObservationEncoder
   petitRADTRANS.sbi.ObservationModality
   petitRADTRANS.sbi.ConditionalAutoregressiveFlowPosterior
   petitRADTRANS.sbi.ConditionalFlowPosterior
   petitRADTRANS.sbi.ConditionalNeuralAutoregressiveFlowPosterior
   petitRADTRANS.sbi.ConditionalSplineFlowPosterior
   petitRADTRANS.sbi.FlowMatchingPosterior
   petitRADTRANS.sbi.PersistentPosteriorEstimator
   petitRADTRANS.sbi.PosteriorBatch
   petitRADTRANS.sbi.PosteriorEstimator
   petitRADTRANS.sbi.PosteriorSamples
   petitRADTRANS.sbi.TaskPreprocessingMetadata
   petitRADTRANS.sbi.ProposalSampler
   petitRADTRANS.sbi.RuntimeSimulator
   petitRADTRANS.sbi.SimulationBatch
   petitRADTRANS.sbi.Simulator
   petitRADTRANS.sbi.NoiseModelConfig
   petitRADTRANS.sbi.ObservationSchema
   petitRADTRANS.sbi.ObservationValueConstraint
   petitRADTRANS.sbi.SBITask
   petitRADTRANS.sbi.SimulationConfig
   petitRADTRANS.sbi.EarlyStoppingConfig
   petitRADTRANS.sbi.SBITrainer
   petitRADTRANS.sbi.TrainingConfig


Functions
---------

.. autoapisummary::

   petitRADTRANS.sbi.generate_local_sensitivity_report
   petitRADTRANS.sbi.generate_posterior_predictive_report
   petitRADTRANS.sbi.generate_sbc_report
   petitRADTRANS.sbi.local_sensitivity_report_to_payload
   petitRADTRANS.sbi.generate_simulation_dataset
   petitRADTRANS.sbi.load_posterior_estimator
   petitRADTRANS.sbi.build_observation_block
   petitRADTRANS.sbi.build_observation_block_batch
   petitRADTRANS.sbi.build_observation_blocks_from_sample
   petitRADTRANS.sbi.plot_local_sensitivity_fisher_correlations
   petitRADTRANS.sbi.plot_local_sensitivity_jacobians
   petitRADTRANS.sbi.plot_local_sensitivity_singular_values
   petitRADTRANS.sbi.plot_posterior_corner
   petitRADTRANS.sbi.plot_posterior_marginals
   petitRADTRANS.sbi.plot_posterior_predictive_report
   petitRADTRANS.sbi.plot_sbc_rank_histograms
   petitRADTRANS.sbi.fit_task_preprocessing
   petitRADTRANS.sbi.normalize_observation_block
   petitRADTRANS.sbi.normalize_observation_blocks


Package Contents
----------------

.. py:class:: BenchmarkComparison

   Compare an amortized result to one or more exact retrieval baselines.


   .. py:attribute:: case_name
      :type:  str


   .. py:attribute:: amortized_result
      :type:  petitRADTRANS.sbi.inference.AmortizedRetrievalResult


   .. py:attribute:: exact_results
      :type:  Mapping[str, Any]


   .. py:attribute:: metrics
      :type:  BenchmarkMetrics


   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


.. py:class:: BenchmarkMetrics

   Metrics summarizing posterior agreement, calibration, predictive checks, and runtime.


   .. py:attribute:: calibration
      :type:  Mapping[str, float]


   .. py:attribute:: posterior_distance
      :type:  Mapping[str, float]


   .. py:attribute:: predictive_checks
      :type:  Mapping[str, float]


   .. py:attribute:: runtime
      :type:  Mapping[str, float]
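As an illustration of what a ``posterior_distance`` entry could hold, the sketch below computes a per-parameter 1-D Wasserstein-1 distance between amortized and exact posterior samples. The choice of metric, the helper names, and the parameter names ``log_g`` and ``T_int`` are assumptions for illustration. This page does not document which distances the package actually reports.

```python
import numpy as np


def wasserstein_1d(samples_a, samples_b):
    """W1 distance between two equal-size 1-D sample sets."""
    a = np.sort(np.asarray(samples_a, dtype=float))
    b = np.sort(np.asarray(samples_b, dtype=float))
    if a.shape != b.shape:
        raise ValueError("this sketch assumes equal sample counts")
    # For equal-size empirical distributions, W1 is the mean absolute gap
    # between matched order statistics.
    return float(np.mean(np.abs(a - b)))


def posterior_distance_by_parameter(names, amortized, exact):
    """Map each (hypothetical) parameter name to the W1 distance of its marginals."""
    return {
        name: wasserstein_1d(amortized[:, i], exact[:, i])
        for i, name in enumerate(names)
    }


rng = np.random.default_rng(0)
exact = rng.normal(size=(2000, 2))
amortized = exact + np.array([0.0, 0.5])  # shift only the second parameter
distances = posterior_distance_by_parameter(("log_g", "T_int"), amortized, exact)
```

Shifting one marginal by a constant shows up directly as a distance of that size, which makes the metric easy to interpret per parameter.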


.. py:class:: RetrievalBenchmarkCase

   One benchmark problem used to compare inference backends.


   .. py:attribute:: name
      :type:  str


   .. py:attribute:: task
      :type:  petitRADTRANS.sbi.task.SBITask


   .. py:attribute:: observation
      :type:  Any


   .. py:attribute:: reference_posterior
      :type:  Any
      :value: None



   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


.. py:class:: RetrievalBenchmarkSuite(cases: list[RetrievalBenchmarkCase])

   Run standardized benchmark comparisons for SBI tasks.


   .. py:attribute:: cases


   .. py:method:: run_case(case: RetrievalBenchmarkCase) -> BenchmarkComparison
      :abstractmethod:


      Run one benchmark case and compute comparison metrics.

      Parameters
      ----------
      case:
          Benchmark case describing the task, observation, and optional exact
          reference posterior to compare against.

      Returns
      -------
      BenchmarkComparison
          Comparison payload combining amortized and exact results together
          with any derived metrics.

      Notes
      -----
      The base class is intentionally abstract. Concrete suites are expected
      to bind exact and amortized inference backends and define the metric
      computations appropriate for the comparison.



   .. py:method:: run_all() -> list[BenchmarkComparison]

      Run all configured benchmark cases.



.. py:class:: LocalSensitivityPointReport

   Local linear-identifiability summary around one representative point.

   Attributes
   ----------
   label:
       Short human-readable label for the representative point.
   parameters:
       Physical parameter vector at which the Jacobian was evaluated.
   finite_difference_steps:
       Per-parameter finite-difference step sizes used during Jacobian
       construction.
   finite_difference_schemes:
       Per-parameter scheme labels such as ``'central'`` or ``'forward'``.
   whitened_jacobian:
       Observation Jacobian divided by observational uncertainty, with shape
       ``(n_observation_values, n_parameters)``.
   singular_values:
       Singular values of the whitened Jacobian.
   effective_rank:
       Number of singular values larger than the configured relative cutoff.
   condition_number:
       Condition number inferred from the singular spectrum.
   fisher_matrix:
       Approximate Fisher information matrix ``J^T J``, where ``J`` is the
       whitened Jacobian.
   fisher_covariance:
       Damped pseudo-inverse of the Fisher matrix.
   fisher_correlation:
       Correlation matrix derived from the approximate Fisher covariance.
   parameter_sensitivity_norm:
       Per-parameter square-root Fisher diagonal.
   local_sigma:
       Approximate local standard deviation from the Fisher covariance.
   metadata:
       Auxiliary diagnostics such as the ridge term and failed columns.


   .. py:attribute:: label
      :type:  str


   .. py:attribute:: parameters
      :type:  numpy.ndarray


   .. py:attribute:: finite_difference_steps
      :type:  numpy.ndarray


   .. py:attribute:: finite_difference_schemes
      :type:  tuple[str, Ellipsis]


   .. py:attribute:: whitened_jacobian
      :type:  numpy.ndarray


   .. py:attribute:: singular_values
      :type:  numpy.ndarray


   .. py:attribute:: effective_rank
      :type:  int


   .. py:attribute:: condition_number
      :type:  float


   .. py:attribute:: fisher_matrix
      :type:  numpy.ndarray


   .. py:attribute:: fisher_covariance
      :type:  numpy.ndarray


   .. py:attribute:: fisher_correlation
      :type:  numpy.ndarray


   .. py:attribute:: parameter_sensitivity_norm
      :type:  numpy.ndarray


   .. py:attribute:: local_sigma
      :type:  numpy.ndarray


   .. py:attribute:: metadata
      :type:  Mapping[str, Any]
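The sketch below shows how the quantities above relate, assuming Gaussian observational noise with independent 1-sigma uncertainties. ``local_fisher_summary`` is a hypothetical helper. Its ridge value and effective-rank rule are illustrative stand-ins for the package's damping term and configured relative cutoff, not its implementation.

```python
import numpy as np


def local_fisher_summary(jacobian, sigma, relative_tolerance=1e-3, ridge=1e-12):
    """Fisher-style local summary from a raw Jacobian and 1-sigma uncertainties."""
    # Whiten: divide each observation row by its uncertainty.
    J = np.asarray(jacobian, dtype=float) / np.asarray(sigma, dtype=float)[:, None]
    s = np.linalg.svd(J, compute_uv=False)  # singular values, descending
    effective_rank = int(np.sum(s > relative_tolerance * s[0]))
    condition_number = float(s[0] / s[-1]) if s[-1] > 0 else float("inf")
    fisher = J.T @ J
    # Damped pseudo-inverse: a small ridge keeps near-singular cases finite.
    covariance = np.linalg.pinv(fisher + ridge * np.eye(fisher.shape[0]))
    local_sigma = np.sqrt(np.diag(covariance))
    return {
        "singular_values": s,
        "effective_rank": effective_rank,
        "condition_number": condition_number,
        "fisher_matrix": fisher,
        "fisher_covariance": covariance,
        "fisher_correlation": covariance / np.outer(local_sigma, local_sigma),
        "parameter_sensitivity_norm": np.sqrt(np.diag(fisher)),
        "local_sigma": local_sigma,
    }


well_posed = local_fisher_summary([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]], [1.0, 1.0, 1.0])
# Two identical Jacobian columns: the data cannot separate the parameters.
degenerate = local_fisher_summary([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], [1.0, 1.0, 1.0])
```

In the degenerate case the effective rank drops to one, which is the situation the report's ``effective_rank`` and ``condition_number`` are meant to expose.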


.. py:class:: LocalSensitivityReport

   Aggregate local information-content diagnostics for one observation.

   Attributes
   ----------
   parameter_names:
       Parameter ordering used throughout the report.
   posterior_mean:
       Posterior mean in physical parameter space.
   posterior_std:
       Posterior standard deviation in physical parameter space.
   posterior_median:
       Posterior median in physical parameter space.
   posterior_iqr:
       Posterior interquartile range in physical parameter space.
   representative_points:
       Local sensitivity summaries evaluated at representative posterior
       points such as the posterior mean and highest-density sample.
   aggregate_local_sigma:
       Median local Fisher sigma across representative points.
   aggregate_parameter_sensitivity_norm:
       Median parameter sensitivity norm across representative points.
   posterior_to_local_sigma_ratio:
       Ratio between posterior standard deviation and local Fisher sigma.
   parameter_diagnostics:
       Per-parameter heuristic summary separating weak data constraints from
       broader-than-local posterior structure.
   metadata:
       Auxiliary metadata such as quantile levels and observation slices.


   .. py:attribute:: parameter_names
      :type:  tuple[str, Ellipsis]


   .. py:attribute:: posterior_mean
      :type:  numpy.ndarray


   .. py:attribute:: posterior_std
      :type:  numpy.ndarray


   .. py:attribute:: posterior_median
      :type:  numpy.ndarray


   .. py:attribute:: posterior_iqr
      :type:  numpy.ndarray


   .. py:attribute:: representative_points
      :type:  tuple[LocalSensitivityPointReport, Ellipsis]


   .. py:attribute:: aggregate_local_sigma
      :type:  numpy.ndarray


   .. py:attribute:: aggregate_parameter_sensitivity_norm
      :type:  numpy.ndarray


   .. py:attribute:: posterior_to_local_sigma_ratio
      :type:  numpy.ndarray


   .. py:attribute:: parameter_diagnostics
      :type:  Mapping[str, Mapping[str, Any]]


   .. py:attribute:: metadata
      :type:  Mapping[str, Any]
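A minimal sketch of the aggregation described above, assuming the aggregate is a per-parameter median over representative points and the ratio is the posterior standard deviation divided by that median. The flag threshold of 1.5 mirrors the ``posterior_underexploited_ratio_threshold`` default of :func:`generate_local_sensitivity_report`. How the package turns it into ``parameter_diagnostics`` is an assumption, and ``compare_posterior_to_local`` is hypothetical.

```python
import numpy as np


def compare_posterior_to_local(posterior_samples, point_local_sigmas, ratio_threshold=1.5):
    """Aggregate per-point local sigmas and compare them with the posterior width."""
    posterior_std = np.asarray(posterior_samples, dtype=float).std(axis=0)
    # Median over representative points (rows) of each parameter's local sigma.
    aggregate_local_sigma = np.median(np.asarray(point_local_sigmas, dtype=float), axis=0)
    ratio = posterior_std / aggregate_local_sigma
    # Hypothetical flag: posterior much broader than the local Fisher width.
    broader_than_local = ratio > ratio_threshold
    return aggregate_local_sigma, ratio, broader_than_local


samples = np.array([[-1.0, -3.0], [1.0, 3.0]])       # posterior std = (1, 3)
point_sigmas = [[1.0, 1.0], [1.0, 2.0], [1.0, 1.5]]  # three representative points
aggregate, ratio, flags = compare_posterior_to_local(samples, point_sigmas)
```

A ratio near one means the posterior width matches the local linear estimate; a ratio well above one points to posterior structure that the local Fisher picture does not capture.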


.. py:class:: PosteriorPredictiveReport

   Aggregate posterior-predictive summaries for held-out observations.

   Attributes
   ----------
   observed_values:
       Observed values kept on the original observation scale.
   predictive_mean:
       Posterior-predictive mean curves or vectors.
   predictive_std:
       Posterior-predictive standard deviations.
   interval_lower, interval_upper:
       Central predictive interval bounds for the requested level.
   interval_coverage:
       Per-dataset fraction of observed points covered by the predictive
       interval.
   mean_absolute_error:
       Mean absolute deviation between predictive mean and observed values.
   metadata:
       Auxiliary metadata such as split name, number of cases, and parameter
       space used during predictive generation.


   .. py:attribute:: observed_values
      :type:  Mapping[str, numpy.ndarray]


   .. py:attribute:: predictive_mean
      :type:  Mapping[str, numpy.ndarray]


   .. py:attribute:: predictive_std
      :type:  Mapping[str, numpy.ndarray]


   .. py:attribute:: interval_lower
      :type:  Mapping[str, numpy.ndarray]


   .. py:attribute:: interval_upper
      :type:  Mapping[str, numpy.ndarray]


   .. py:attribute:: interval_coverage
      :type:  Mapping[str, float]


   .. py:attribute:: mean_absolute_error
      :type:  Mapping[str, float]


   .. py:attribute:: mean_absolute_error_sigma
      :type:  Mapping[str, float]


   .. py:attribute:: median_interval_width_over_uncertainty
      :type:  Mapping[str, float]


   .. py:attribute:: crps
      :type:  Mapping[str, float]


   .. py:attribute:: metadata
      :type:  Mapping[str, Any]
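The summaries above can be sketched with NumPy for one dataset, given predictive draws of shape ``(n_draws, n_points)``. The CRPS uses the common sample estimator ``E|X - y| - 0.5 E|X - X'|``. Whether the package computes ``crps`` this way, and the helper name ``predictive_summary``, are assumptions.

```python
import numpy as np


def predictive_summary(draws, observed, interval_level=0.9):
    """Summaries of predictive draws (n_draws, n_points) against observed (n_points,)."""
    draws = np.asarray(draws, dtype=float)
    observed = np.asarray(observed, dtype=float)
    mean = draws.mean(axis=0)
    alpha = 0.5 * (1.0 - interval_level)
    lower, upper = np.quantile(draws, [alpha, 1.0 - alpha], axis=0)
    # Sample-based CRPS estimator: E|X - y| - 0.5 * E|X - X'|, averaged over points.
    term_obs = np.mean(np.abs(draws - observed), axis=0)
    term_pairs = np.mean(np.abs(draws[:, None, :] - draws[None, :, :]), axis=(0, 1))
    return {
        "predictive_mean": mean,
        "predictive_std": draws.std(axis=0),
        "interval_lower": lower,
        "interval_upper": upper,
        "interval_coverage": float(np.mean((observed >= lower) & (observed <= upper))),
        "mean_absolute_error": float(np.mean(np.abs(mean - observed))),
        "crps": float(np.mean(term_obs - 0.5 * term_pairs)),
    }


rng = np.random.default_rng(1)
draws = rng.normal(size=(400, 20))  # predictive draws for 20 observed points
summary = predictive_summary(draws, np.zeros(20))
```

For standard-normal draws and an observation at zero, the exact CRPS is about 0.234, so the estimate should land near that value. The pairwise term needs ``O(n_draws**2)`` memory per point, which is acceptable for a few hundred draws.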


.. py:class:: SimulationBasedCalibrationReport

   Rank-based SBC summary over a held-out set of observations.

   Attributes
   ----------
   ranks:
       Integer rank of the ground-truth parameter within posterior samples for
       each held-out case and parameter dimension.
   rank_histogram_counts:
       Per-parameter SBC histogram counts.
   posterior_means:
       Posterior mean for each held-out case.
   truths:
       Ground-truth parameters paired with each posterior sample set.
   coverages:
       Coverage summaries for the requested nominal interval levels.
   mean_rank:
       Mean empirical rank for each parameter dimension.
   normalized_mean_rank_error:
       Absolute mean-rank error normalized by the expected average rank.
   metadata:
       Auxiliary run metadata such as split name and number of posterior draws.


   .. py:attribute:: ranks
      :type:  numpy.ndarray


   .. py:attribute:: rank_histogram_counts
      :type:  numpy.ndarray


   .. py:attribute:: posterior_means
      :type:  numpy.ndarray


   .. py:attribute:: truths
      :type:  numpy.ndarray


   .. py:attribute:: coverages
      :type:  tuple[CoverageLevelReport, Ellipsis]


   .. py:attribute:: mean_rank
      :type:  numpy.ndarray


   .. py:attribute:: normalized_mean_rank_error
      :type:  numpy.ndarray


   .. py:attribute:: metadata
      :type:  Mapping[str, Any]
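A sketch of the rank statistic, assuming ranks count how many posterior draws fall below the truth (so they range from 0 to ``n_draws``) and that the expected average rank is therefore ``n_draws / 2``. The helper names and the bin count are hypothetical. In the calibrated toy problem below, truths and draws share one distribution, so the ranks should be close to uniform.

```python
import numpy as np


def sbc_ranks(posterior_samples, truths):
    """Rank of each truth among its posterior draws.

    posterior_samples: (n_cases, n_draws, n_params); truths: (n_cases, n_params).
    Ranks are integers in [0, n_draws].
    """
    return np.sum(posterior_samples < truths[:, None, :], axis=1)


def sbc_summary(ranks, n_draws, n_bins=10):
    """Histogram counts, mean rank and normalized mean-rank error per parameter."""
    counts = np.stack([
        np.histogram(ranks[:, j], bins=n_bins, range=(0, n_draws + 1))[0]
        for j in range(ranks.shape[1])
    ])
    mean_rank = ranks.mean(axis=0)
    expected = n_draws / 2.0  # mean of a uniform rank on {0, ..., n_draws}
    return counts, mean_rank, np.abs(mean_rank - expected) / expected


rng = np.random.default_rng(2)
n_cases, n_draws = 2000, 99
truths = rng.normal(size=(n_cases, 2))
samples = rng.normal(size=(n_cases, n_draws, 2))  # calibrated toy "posterior"
ranks = sbc_ranks(samples, truths)
counts, mean_rank, normalized_error = sbc_summary(ranks, n_draws)
```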


.. py:function:: generate_local_sensitivity_report(task: petitRADTRANS.sbi.task.SBITask, posterior_samples: Any, observation_blocks: Sequence[Any], posterior_log_probabilities: Any = None, parameter_space: str = 'physical', simulator: petitRADTRANS.sbi.simulator.RuntimeSimulator | None = None, quantile_levels: Sequence[float] = (0.1, 0.5, 0.9), finite_difference_relative_step: float = 0.001, finite_difference_std_fraction: float = 0.1, finite_difference_absolute_floor: float = 1e-05, max_step_reduction_attempts: int = 6, svd_relative_tolerance: float = 0.001, posterior_underexploited_ratio_threshold: float = 1.5, weak_sensitivity_fraction_threshold: float = 0.15, seed: int | None = None) -> LocalSensitivityReport

   Diagnose local physical identifiability around representative posterior points.

   The report evaluates a deterministic simulator Jacobian at representative
   posterior points, whitens it by observational uncertainty, and derives a
   Fisher-style local covariance approximation for each point.
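The Jacobian step can be sketched with finite differences. The keyword arguments below echo ``finite_difference_relative_step``, ``finite_difference_std_fraction``, ``finite_difference_absolute_floor``, and ``max_step_reduction_attempts``. However, the rule that combines them (the largest of the three step terms), the central-then-forward fallback, and step halving are assumptions rather than documented package behavior. ``toy_forward`` is a hypothetical model whose last output is undefined for negative inputs.

```python
import numpy as np


def finite_difference_jacobian(forward, theta, posterior_std, relative_step=1e-3,
                               std_fraction=0.1, absolute_floor=1e-5, max_reductions=6):
    """Column-wise finite-difference Jacobian with scheme and step fallbacks."""
    theta = np.asarray(theta, dtype=float)
    with np.errstate(invalid="ignore"):
        base = np.asarray(forward(theta), dtype=float)
    jacobian = np.full((base.size, theta.size), np.nan)  # failed columns stay NaN
    steps = np.full(theta.size, np.nan)
    schemes = ["failed"] * theta.size
    for j in range(theta.size):
        # Hypothetical step rule combining the three step knobs.
        h = max(relative_step * abs(theta[j]), std_fraction * posterior_std[j], absolute_floor)
        for _ in range(max_reductions + 1):
            offset = np.zeros_like(theta)
            offset[j] = h
            with np.errstate(invalid="ignore"):
                f_up = np.asarray(forward(theta + offset), dtype=float)
                f_down = np.asarray(forward(theta - offset), dtype=float)
            for scheme, column in (("central", (f_up - f_down) / (2.0 * h)),
                                   ("forward", (f_up - base) / h)):
                if np.all(np.isfinite(column)):
                    jacobian[:, j] = column
                    steps[j] = h
                    schemes[j] = scheme
                    break
            if schemes[j] != "failed":
                break
            h *= 0.5  # neither scheme was finite: shrink the step and retry
    return jacobian, steps, tuple(schemes)


def toy_forward(t):
    # sqrt(t[1]) is NaN for t[1] < 0, which defeats the central scheme near 0.
    return np.array([t[0] ** 2, t[0] * t[1], np.sqrt(t[1])])


jac, steps, schemes = finite_difference_jacobian(toy_forward, [2.0, 0.05], [0.01, 1.0])
```

The second column falls back to a forward difference because the backward evaluation leaves the model's domain, matching the per-parameter ``finite_difference_schemes`` labels the point report records.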


.. py:function:: generate_posterior_predictive_report(task: petitRADTRANS.sbi.task.SBITask, posterior: Any, dataset_reader: petitRADTRANS.sbi.dataset.NormalizedObservationDatasetReader, split: petitRADTRANS.sbi.dataset.DatasetSplit = DatasetSplit.TEST, n_posterior_samples: int = 256, interval_level: float = 0.9, max_cases: int | None = None, seed: int | None = None, simulator: petitRADTRANS.sbi.simulator.RuntimeSimulator | None = None, n_predictive_forward_model_samples: int | None = None, checkpoint_directory: str | pathlib.Path | None = None, data_parallel: bool | None = None) -> PosteriorPredictiveReport

   Generate posterior-predictive summaries for a held-out dataset split.

   Parameters
   ----------
   task:
       SBI task defining parameter transforms and the simulator configuration.
   posterior:
       Trained posterior estimator used to draw parameter samples for each
       held-out observation before they are passed through the forward model.
   dataset_reader:
       Reader providing normalized observations and preprocessing metadata.
   split:
       Held-out split used for the predictive report.
   n_posterior_samples:
       Number of posterior draws generated per held-out observation.
   interval_level:
       Central predictive interval level reported for each dataset.
   max_cases:
       Optional cap on the number of held-out observations evaluated.
   seed:
       Optional base seed used to make predictive sampling reproducible.
   simulator:
       Optional runtime simulator override.
   n_predictive_forward_model_samples:
       Number of posterior draws passed through the forward model per
       held-out case.  When ``None`` all ``n_posterior_samples`` draws are
       forwarded.  Subsampling here is the primary lever for keeping the
       total number of petitRADTRANS calls to a manageable level when
       evaluating many held-out cases.
   checkpoint_directory:
       Optional directory for per-case checkpoints.  When provided, each
       completed case is written to disk as a compressed ``.npz`` file and
       skipped on resume.  This makes the expensive forward-model loop
       restartable after interruption.

   Returns
   -------
   PosteriorPredictiveReport
       Aggregate posterior-predictive summary over the requested split.

   Notes
   -----
   Observations are normalized internally for posterior encoding but compared
   on the original observation scale in the returned report.
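The checkpoint-and-resume pattern described for ``checkpoint_directory`` can be sketched with ``numpy.savez_compressed``. The file naming scheme and the helper names are hypothetical, and ``fake_case`` stands in for the expensive forward-model loop.

```python
import pathlib
import tempfile

import numpy as np


def run_with_checkpoints(case_ids, compute_case, checkpoint_directory):
    """Compute per-case arrays, persisting each case and skipping finished ones."""
    directory = pathlib.Path(checkpoint_directory)
    directory.mkdir(parents=True, exist_ok=True)
    results, newly_computed = {}, []
    for case_id in case_ids:
        path = directory / f"case_{case_id:06d}.npz"  # hypothetical naming scheme
        if path.exists():  # resume: reuse the finished case
            with np.load(path) as stored:
                results[case_id] = {name: stored[name] for name in stored.files}
            continue
        arrays = compute_case(case_id)  # the expensive forward-model work
        np.savez_compressed(path, **arrays)
        results[case_id] = arrays
        newly_computed.append(case_id)
    return results, newly_computed


def fake_case(case_id):
    return {"predictive_mean": np.full(3, float(case_id))}


with tempfile.TemporaryDirectory() as tmp:
    first, first_computed = run_with_checkpoints(range(3), fake_case, tmp)
    resumed, resumed_computed = run_with_checkpoints(range(3), fake_case, tmp)
```

The number of forward-model calls scales as the number of cases times the draws forwarded per case. For example, 500 hypothetical cases at 256 forwarded draws each need 128,000 calls, while forwarding 32 draws per case needs 16,000.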


.. py:function:: generate_sbc_report(posterior: Any, dataset_reader: petitRADTRANS.sbi.dataset.NormalizedObservationDatasetReader, split: petitRADTRANS.sbi.dataset.DatasetSplit = DatasetSplit.TEST, n_posterior_samples: int = 256, batch_size: int = 32, parameter_space: str | None = None, levels: Sequence[float] = (0.5, 0.8, 0.95), max_cases: int | None = None, seed: int | None = None, data_parallel: bool | None = None) -> SimulationBasedCalibrationReport

   Run SBC over a dataset reader using normalized observation batches.

   Parameters
   ----------
   posterior:
       Trained posterior estimator exposing ``encode_observation`` and
       ``sample_posterior``.
   dataset_reader:
       Reader yielding normalized held-out observations and matched parameter
       values.
   split:
       Dataset split used for the SBC evaluation.
   n_posterior_samples:
       Number of posterior draws generated per held-out case.
   batch_size:
       Reader batch size used during report generation.
   parameter_space:
       Optional parameter space override. Defaults to the posterior's own
       configured parameter space.
   levels:
       Coverage levels summarized in the returned report.
   max_cases:
       Optional cap on the number of held-out observations evaluated.
   seed:
       Optional base seed used to generate reproducible posterior draws.

   Returns
   -------
   SimulationBasedCalibrationReport
       SBC summary computed from the requested dataset split.
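The ``levels`` argument can be illustrated with a coverage sketch, assuming central (equal-tailed) credible intervals. The report may use a different interval construction, and ``central_interval_coverage`` is hypothetical. For a calibrated posterior, the empirical coverage at each level should be close to the nominal level.

```python
import numpy as np


def central_interval_coverage(posterior_samples, truths, levels=(0.5, 0.8, 0.95)):
    """Per-level, per-parameter fraction of cases whose truth lies in the central interval.

    posterior_samples: (n_cases, n_draws, n_params); truths: (n_cases, n_params).
    """
    coverage = {}
    for level in levels:
        alpha = 0.5 * (1.0 - level)
        lower = np.quantile(posterior_samples, alpha, axis=1)
        upper = np.quantile(posterior_samples, 1.0 - alpha, axis=1)
        coverage[level] = ((truths >= lower) & (truths <= upper)).mean(axis=0)
    return coverage


rng = np.random.default_rng(3)
truths = rng.normal(size=(2000, 1))
samples = rng.normal(size=(2000, 200, 1))  # calibrated toy posterior
coverage = central_interval_coverage(samples, truths)
```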


.. py:function:: local_sensitivity_report_to_payload(report: LocalSensitivityReport) -> dict[str, Any]

   Convert a local sensitivity report into a JSON-serializable payload.
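One way to build such a payload is a recursive converter over dataclasses, mappings, sequences, and NumPy values. The sketch assumes report objects are dataclasses, which their attribute listings suggest but this page does not confirm. It also maps non-finite floats to ``None`` so the output is strict JSON. ``PointSummary`` is a hypothetical stand-in, not :class:`LocalSensitivityPointReport`.

```python
import dataclasses
import json
import math
from collections.abc import Mapping

import numpy as np


def to_payload(obj):
    """Recursively convert report-like objects into strict-JSON-friendly values."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return {f.name: to_payload(getattr(obj, f.name)) for f in dataclasses.fields(obj)}
    if isinstance(obj, np.ndarray):
        return to_payload(obj.tolist())
    if isinstance(obj, np.generic):
        return to_payload(obj.item())
    if isinstance(obj, Mapping):
        return {str(key): to_payload(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_payload(value) for value in obj]
    if isinstance(obj, float) and not math.isfinite(obj):
        return None  # e.g. an infinite condition number; strict JSON has no inf/NaN
    return obj


@dataclasses.dataclass(frozen=True)
class PointSummary:  # hypothetical stand-in with a few report-like fields
    label: str
    singular_values: np.ndarray
    condition_number: float
    metadata: Mapping


point = PointSummary(
    label="posterior_mean",
    singular_values=np.array([3.0, 0.5]),
    condition_number=float("inf"),
    metadata={"ridge": np.float64(1e-12), "failed_columns": ()},
)
text = json.dumps(to_payload(point), allow_nan=False)
```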


.. py:exception:: TaskCompatibilityError

   Bases: :py:obj:`ValueError`


   Raised when an observation or artifact is incompatible with a task.


.. py:class:: HDF5SimulationDatasetStore(chunk_size: int = 256)

   Bases: :py:obj:`SimulationDatasetStore`


   HDF5-backed store for simulation corpora.

   Stores all simulation data for a corpus in a single ``.h5`` file, so the
   file count stays at one regardless of the number of simulations or splits.
   Requires ``h5py`` (already present in the ``jaxprt`` environment).


   .. py:attribute:: chunk_size
      :value: 256



   .. py:method:: _require_h5py() -> None
      :staticmethod:



   .. py:method:: create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') -> HDF5SimulationDatasetWriter

      Create a writer for a new simulation dataset.



   .. py:method:: open(manifest_or_uri: SimulationDatasetManifest | str) -> HDF5StoredSimulationDataset

      Open a stored dataset for training or evaluation.



.. py:class:: SimulationDatasetStore

   Bases: :py:obj:`abc.ABC`


   Backend-independent interface for reading and writing simulation data.


   .. py:method:: create_writer(manifest: SimulationDatasetManifest) -> SimulationDatasetWriter
      :abstractmethod:


      Create a writer for a new simulation dataset.



   .. py:method:: open(manifest_or_uri: SimulationDatasetManifest | str) -> Any
      :abstractmethod:


      Open a stored dataset for training or evaluation.



.. py:class:: StoredSimulationDataset

   Read-only handle for a stored chunked simulation dataset.


   .. py:attribute:: manifest
      :type:  SimulationDatasetManifest


   .. py:attribute:: storage_uri
      :type:  str


   .. py:attribute:: root
      :type:  Any


   .. py:method:: iter_split_chunks(split: DatasetSplit = DatasetSplit.TRAIN, chunk_size: int = 256, indices: Any = None, observation_fields: frozenset[str] | None = None) -> Iterator[dict[str, Any]]

      Yield row-slices of one split without loading it all into RAM.

      Parameters
      ----------
      split:
          Dataset split to iterate.
      chunk_size:
          Number of rows per yielded chunk.
      indices:
          Optional integer array of row indices to read. When supplied, only
          those rows are fetched, in ascending order. When ``None``, every row
          in the split is visited sequentially.
      observation_fields:
          Optional set of observation field names to load.  When ``None``
          every dataset field is loaded.  Pass
          ``_OBSERVATION_TRAINING_FIELDS`` to skip large unused fields such
          as ``covariance`` during training.

      Yields
      ------
      dict[str, Any]
          Same structure as :meth:`read_split`, but covering at most
          ``chunk_size`` rows per iteration (the final chunk may be shorter).



   .. py:method:: read_split(split: DatasetSplit = DatasetSplit.TRAIN) -> dict[str, Any]
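The chunked-read behavior of :meth:`StoredSimulationDataset.iter_split_chunks` can be mimicked in memory. In the sketch below, a dictionary of NumPy arrays stands in for the on-disk HDF5 or Zarr datasets. The function is an illustration, not the package method.

```python
import numpy as np


def iter_split_chunks(split_arrays, chunk_size=256, indices=None, fields=None):
    """Yield row-slices of an in-memory split, mimicking chunked on-disk reads.

    split_arrays maps field name -> array whose first axis indexes rows.
    """
    names = [name for name in split_arrays if fields is None or name in fields]
    n_rows = len(next(iter(split_arrays.values())))
    if indices is None:
        for start in range(0, n_rows, chunk_size):  # sequential contiguous slices
            yield {name: split_arrays[name][start:start + chunk_size] for name in names}
        return
    rows = np.sort(np.asarray(indices, dtype=int))  # ascending order
    for start in range(0, rows.size, chunk_size):
        window = rows[start:start + chunk_size]
        yield {name: split_arrays[name][window] for name in names}


split = {
    "parameters": np.arange(20.0).reshape(10, 2),
    "values": np.arange(30.0).reshape(10, 3),
    "covariance": np.zeros((10, 3)),  # large field we choose to skip
}
chunks = list(iter_split_chunks(split, chunk_size=4, fields={"parameters", "values"}))
picked = next(iter_split_chunks(split, chunk_size=4, indices=[7, 1, 3]))
```

Restricting ``fields`` plays the same role as passing a reduced field set to skip ``covariance`` during training.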


.. py:class:: ZarrSimulationDatasetStore(chunk_size: int = 256)

   Bases: :py:obj:`SimulationDatasetStore`


   Concrete Zarr-backed store for chunked simulation corpora.


   .. py:attribute:: chunk_size
      :value: 256



   .. py:method:: _require_zarr() -> None
      :staticmethod:



   .. py:method:: create_writer(manifest: SimulationDatasetManifest, mode: str = 'w') -> SimulationDatasetWriter

      Create a writer for a new simulation dataset.



   .. py:method:: open(manifest_or_uri: SimulationDatasetManifest | str) -> StoredSimulationDataset

      Open a stored dataset for training or evaluation.



.. py:function:: generate_simulation_dataset(task: petitRADTRANS.sbi.task.SBITask, storage_uri: str, n_simulations: int, chunk_size: int = 256, split_counts: Mapping[DatasetSplit | str, int] | None = None, split_fractions: Mapping[DatasetSplit | str, float] | None = None, split_policy: SplitSamplingPolicy | str = SplitSamplingPolicy.SEQUENTIAL, split_seed: int | None = None, simulator: Any = None, include_preprocessing_metadata: bool = True, dataset_version: str = '0.1.0', resume: bool = False, backend: str = 'hdf5', data_parallel: bool | None = None, store_covariance: bool = False) -> GeneratedSimulationDataset

   Generate and persist a simulation corpus in one call.

   Parameters
   ----------
   backend:
       Storage backend to use. ``'hdf5'`` (default) writes all data into a
       single ``.h5`` file, keeping the file count at 1 regardless of the
       number of simulations. ``'zarr'`` uses the Zarr directory store,
       which produces one file per compressed chunk.
   data_parallel:
       When ``True`` (or ``None`` with multiple JAX devices), the vmapped
       radiative-transfer (RT) kernel is distributed across devices using
       ``jax.pmap``. The effective per-iteration batch size is automatically
       scaled by the device count so that each device processes the
       configured ``simulation_config.batch_size`` samples.
   store_covariance:
       When ``True`` the full covariance matrix for each simulated spectrum is
       written to disk.  When ``False`` (default) only the covariance diagonal
       is stored under the ``covariance`` field to reduce storage pressure.
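A rough storage comparison for ``store_covariance``, followed by the device batch layout for ``data_parallel``. The corpus size, the 233-point spectrum (the default ``n_wavelengths`` of the encoders), float32 storage, and the device count are assumptions. Under these assumptions, diagonal-only storage needs about 0.09 GB, while full matrices need about 21.7 GB, a factor of ``n_wavelengths`` more.

```python
import numpy as np

# Hypothetical corpus: 100,000 spectra with 233 wavelength points stored as float32.
n_simulations = 100_000
n_wavelengths = 233
bytes_per_value = np.dtype(np.float32).itemsize

diagonal_bytes = n_simulations * n_wavelengths * bytes_per_value   # store_covariance=False
full_bytes = n_simulations * n_wavelengths ** 2 * bytes_per_value  # store_covariance=True
print(f"diagonal: {diagonal_bytes / 1e9:.3f} GB, full: {full_bytes / 1e9:.1f} GB")

# Data-parallel batching: jax.pmap maps over a leading device axis, so an
# effective batch of n_devices * batch_size rows is reshaped before dispatch.
n_devices = 4    # hypothetical device count
batch_size = 32  # stands in for simulation_config.batch_size
effective_batch = n_devices * batch_size
theta = np.zeros((effective_batch, 5), dtype=np.float32)
per_device = theta.reshape(n_devices, batch_size, -1)
```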


.. py:class:: DatasetSplit

   Bases: :py:obj:`str`, :py:obj:`enum.Enum`


   Named dataset partitions used during training and evaluation.


   .. py:attribute:: TRAIN
      :value: 'train'



   .. py:attribute:: VALIDATION
      :value: 'validation'



   .. py:attribute:: TEST
      :value: 'test'



   .. py:attribute:: BENCHMARK
      :value: 'benchmark'
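Because :class:`DatasetSplit` mixes in ``str``, its members compare equal to their string values. This fits signatures such as ``split_counts: Mapping[DatasetSplit | str, int]``. The sketch below uses a stand-in enum with the same members, and ``normalize_split_keys`` is a hypothetical helper, not package code.

```python
import enum


class DatasetSplitSketch(str, enum.Enum):
    """Stand-in mirroring the documented DatasetSplit members."""

    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"
    BENCHMARK = "benchmark"


def normalize_split_keys(mapping):
    """Coerce str or enum keys to enum members (hypothetical helper)."""
    return {DatasetSplitSketch(key): value for key, value in mapping.items()}


counts = normalize_split_keys({"train": 8000, DatasetSplitSketch.VALIDATION: 1000, "test": 1000})
```

Coercing an unknown name such as ``'holdout'`` raises ``ValueError``, which surfaces typos early.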



.. py:class:: NormalizedObservationDatasetReader

   Lightweight reader yielding normalized ObservationBlock batches for training.


   .. py:attribute:: dataset
      :type:  StoredSimulationDataset | HDF5StoredSimulationDataset


   .. py:attribute:: preprocessing_metadata
      :type:  petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata


   .. py:attribute:: _split_cache
      :type:  dict


   .. py:method:: iter_batches(split: DatasetSplit = DatasetSplit.TRAIN, batch_size: int = 32, shuffle: bool = False, seed: int | None = None, parameter_space: str = 'physical', encoder: Any = None) -> Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]

      Yield mini-batches of normalized observations and matched parameters.

      Parameters
      ----------
      split:
          Dataset split to iterate over.
      batch_size:
          Number of samples yielded in each batch.
      shuffle:
          Whether to shuffle sample order within the requested split.
      seed:
          Optional random seed used when ``shuffle=True``.
      parameter_space:
          Parameter representation returned in each batch. Supported values
          are ``'physical'``, ``'cube'``, and ``'unconstrained'``.
      encoder:
          Optional encoder used to convert block lists into dense embedding
          arrays before batches are yielded.

      Returns
      -------
      Iterator[PosteriorBatch]
          Iterator over batches containing parameters, observations, and small
          metadata dictionaries describing the batch provenance.

      Notes
      -----
      Splits with fewer than ``_STREAMING_THRESHOLD`` rows are cached in RAM
      after the first read so that repeated epoch calls (e.g. validation) pay
      only slicing cost.  Larger splits are streamed directly from disk one
      batch at a time to avoid loading the full dataset into memory.



   .. py:method:: _iter_batches_cached(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) -> Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]

      Iterate using the full-split RAM cache (small splits).



   .. py:method:: _iter_batches_streaming(split: DatasetSplit, batch_size: int, shuffle: bool, seed: int | None, parameter_space: str, encoder: Any) -> Iterator[petitRADTRANS.sbi.posterior.PosteriorBatch]

      Iterate by streaming rows directly from disk (large splits).

      When ``shuffle=True``, a global index permutation is computed in RAM
      (one integer per simulation row, a negligible amount of memory) and used to fetch
      HDF5 rows in sorted sub-windows of ``batch_size``, satisfying h5py's
      monotonic-index requirement while still presenting shuffled order to the
      training loop.



   .. py:method:: _select_parameters(split_data: dict, parameter_space: str) -> Any
      :staticmethod:


      Return the parameter array for the requested coordinate space.
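The shuffled streaming strategy described for ``_iter_batches_streaming`` can be sketched as follows. A global permutation fixes which rows land in each batch, each window is read with ascending indices (h5py fancy indexing requires increasing indices), and the fetched rows are then put back into permutation order. Restoring the within-window order is an assumption; the package may present rows sorted inside each batch. A NumPy array stands in for the HDF5 dataset.

```python
import numpy as np


def shuffled_sorted_windows(n_rows, batch_size, seed=None):
    """Yield (ascending_indices, restore_order) for each window of a global permutation."""
    permutation = np.random.default_rng(seed).permutation(n_rows)
    for start in range(0, n_rows, batch_size):
        window = permutation[start:start + batch_size]
        order = np.argsort(window)
        # window[order] is ascending (what h5py fancy indexing needs);
        # indexing the fetched rows with argsort(order) restores the shuffled order.
        yield window[order], np.argsort(order)


data = np.arange(10) * 10  # stands in for an on-disk dataset
batches = []
for ascending, restore in shuffled_sorted_windows(len(data), batch_size=4, seed=0):
    assert np.all(np.diff(ascending) > 0)
    rows = data[ascending]         # one monotonic read per window
    batches.append(rows[restore])  # back to permutation order
```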



.. py:class:: HierarchicalObservationEncoder(embedding_dim: int = 128, spectrum_embedding_dim: int = 64, photometry_embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 128, spectrum_encoder_type: str = 'conv1d', n_wavelengths: int = 233, key: jax.Array | None = None)

   Bases: :py:obj:`equinox.Module`, :py:obj:`petitRADTRANS.sbi.observation.ObservationEncoder`


   Learned hierarchical encoder dispatching over modalities.

   Parameters
   ----------
   embedding_dim:
       Size of the final joint observation embedding.
   spectrum_embedding_dim:
       Intermediate embedding size used by the spectral sub-encoder.
   photometry_embedding_dim:
       Intermediate embedding size used by the photometry sub-encoder.
   patch_size:
       Patch size used by the spectral encoder.
   hidden_dim:
       Hidden width shared across the component encoders and aggregator.
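   spectrum_encoder_type:
       Spectral sub-encoder to build. The default ``'conv1d'`` selects
       :class:`SpectralConv1DEncoder`.
   n_wavelengths:
       Number of wavelength samples per spectrum expected by the
       convolutional spectral encoder.
   key:
       Optional JAX random key used to initialize all submodules.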

   Notes
   -----
   Spectrum and photometry blocks are handled by dedicated sub-encoders and
   then merged with a permutation-invariant aggregator. Unsupported modalities
   currently fall back to resized raw-value vectors.


   .. py:attribute:: spectrum_encoder
      :type:  SpectralPatchEncoder | SpectralConv1DEncoder


   .. py:attribute:: photometry_encoder
      :type:  PhotometryPointEncoder


   .. py:attribute:: aggregator
      :type:  DatasetSetAggregator


   .. py:attribute:: embedding_dim
      :type:  int


   .. py:method:: _encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray

      Encode one observation block with modality-aware dispatch.

      Parameters
      ----------
      block:
          Observation block to encode.

      Returns
      -------
      jnp.ndarray
          Fixed-width embedding for the supplied block.



   .. py:method:: encode(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation

      Encode one structured observation made of multiple blocks.

      Parameters
      ----------
      blocks:
          Observation blocks associated with one target system.

      Returns
      -------
      EncodedObservation
          Aggregated embedding and light metadata describing the block set.



   .. py:method:: encode_stacked_batch(blocks_batch: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) -> jax.numpy.ndarray

      Encode a batch of identically structured observations using vmap.

      All observations must share the same block structure (same number of
      blocks, same array shapes per block).  This is always the case for SBI
      tasks where every observation is produced by the same forward model.

      Parameters
      ----------
      blocks_batch:
          Outer list indexes samples; inner list holds the per-block
          observation data for one sample.

      Returns
      -------
      jnp.ndarray
          Float32 array of shape ``(n_samples, embedding_dim)``.



   .. py:method:: encode_from_prestacked(obs: Any) -> jax.numpy.ndarray

      Encode a batch of observations from pre-stacked arrays.

      Accepts a :class:`~petitRADTRANS.sbi.observation.PreStackedObservations`
      instance whose array fields have already been extracted from
      ``ObservationBlock`` objects outside the JAX JIT boundary.  Only the
      vmapped XLA computation runs here, enabling the training step to be
      compiled once and reused across all batches.

      Parameters
      ----------
      obs:
          Pre-stacked observation container with ``stacked_blocks`` (one
          ``(values, uncertainties, coordinates, mask, log_scale,
          absolute_values)`` tuple per block, each of shape ``(batch_size,
          n_wl)`` except the scalar ``log_scale`` arrays) and ``modalities``
          (static tuple of modality value strings).

      Returns
      -------
      jnp.ndarray
          Float32 array of shape ``(batch_size, embedding_dim)``.
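Pre-stacking, as used by :meth:`encode_from_prestacked` and required by :meth:`encode_stacked_batch`, amounts to stacking each block's arrays along a new batch axis after checking that every observation has the same block structure. The sketch represents blocks as dictionaries of NumPy arrays rather than ``ObservationBlock`` objects or the package's ``stacked_blocks`` tuples. ``prestack`` and ``make_sample`` are hypothetical.

```python
import numpy as np


def prestack(blocks_batch):
    """Stack per-sample block dicts into one batched dict per block.

    blocks_batch[i][b] maps field name -> array for block b of sample i.
    """
    n_blocks = len(blocks_batch[0])
    stacked = []
    for b in range(n_blocks):
        reference = blocks_batch[0][b]
        for sample in blocks_batch:
            if len(sample) != n_blocks or any(
                np.shape(sample[b][name]) != np.shape(reference[name]) for name in reference
            ):
                raise ValueError("all observations must share the same block structure")
        stacked.append({name: np.stack([sample[b][name] for sample in blocks_batch])
                        for name in reference})
    return tuple(stacked)


def make_sample(seed):
    rng = np.random.default_rng(seed)
    spectrum = {"values": rng.normal(size=5), "uncertainties": np.full(5, 0.1)}
    photometry = {"values": rng.normal(size=3), "uncertainties": np.full(3, 0.2)}
    return [spectrum, photometry]


stacked = prestack([make_sample(0), make_sample(1)])
```

Doing this stacking once, outside any compiled function, leaves only fixed-shape arrays for the vmapped encoder.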



.. py:class:: PhotometryPointEncoder(embedding_dim: int = 64, hidden_dim: int = 96, key: jax.Array | None = None)

   Bases: :py:obj:`equinox.Module`


   Learned photometry encoder using per-point MLP features and pooling.

   Parameters
   ----------
   embedding_dim:
       Size of the returned photometric embedding.
   hidden_dim:
       Hidden width of the per-point MLP.
   key:
       Optional JAX random key used for initialization.

   Notes
   -----
   Each photometric point is represented by value, uncertainty, coordinate,
   and an inferred width feature before permutation-invariant pooling.
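
   Permutation invariance means the embedding must not depend on the order in
   which photometric points are listed. A minimal sketch of one common way to
   obtain it, masked mean pooling over per-point features, follows. It is
   plain Python, not the encoder's implementation: ``point_features`` is a
   fixed toy function standing in for the learned ``point_mlp`` (the width
   feature is omitted), the mask is assumed to mark invalid points with
   ``True``, and the pooling operator actually used by
   :class:`PhotometryPointEncoder` may differ.

   .. code-block:: python

      import math


      def point_features(value, uncertainty, coordinate):
          """Toy per-point features; a stand-in for the learned point MLP."""
          return [value, math.log(uncertainty), coordinate, value / uncertainty]


      def masked_mean_pool(points, mask):
          """Average per-point features over valid points (assumed: True = invalid)."""
          kept = [point_features(*point) for point, invalid in zip(points, mask) if not invalid]
          if not kept:
              return [0.0] * 4  # no valid points: return a zero embedding
          return [sum(column) / len(kept) for column in zip(*kept)]


      points = [(1.0, 0.1, 0.5), (2.0, 0.2, 1.2), (3.0, 0.3, 2.2)]
      mask = [False, False, True]
      pooled = masked_mean_pool(points, mask)
      reordered = masked_mean_pool(points[::-1], mask[::-1])
      # The same points in a different order give the same pooled features.
      assert all(math.isclose(a, b) for a, b in zip(pooled, reordered))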


   .. py:attribute:: point_mlp
      :type:  equinox.nn.MLP


   .. py:attribute:: output_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: embedding_dim
      :type:  int


   .. py:method:: encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray

      Encode one photometric observation block.

      Parameters
      ----------
      block:
          Photometry-like observation block to encode.

      Returns
      -------
      jnp.ndarray
          Dense embedding for the supplied photometric block.



   .. py:method:: _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray) -> jax.numpy.ndarray

      Encode pre-processed block arrays (vmappable; takes no ``ObservationBlock`` input).

      Parameters
      ----------
      values, uncertainties, coordinates:
          Float32 1-D arrays already produced by ``_as_vector``.
          ``uncertainties`` and ``coordinates`` may have length 0.
      mask:
          Boolean 1-D array already produced by ``_safe_mask``.

      Returns
      -------
      jnp.ndarray
          Dense photometric embedding of shape ``(embedding_dim,)``.



.. py:class:: SpectralConv1DEncoder(embedding_dim: int = 64, n_wavelengths: int = 233, hidden_channels: tuple[int, int] = (32, 64), key: jax.Array | None = None)

   Bases: :py:obj:`equinox.Module`


   Learned spectral encoder using 1D convolutions to retain positional information.

   Parameters
   ----------
   embedding_dim:
       Size of the spectral embedding returned for one observation block.
   n_wavelengths:
       Fixed number of spectral points per observation block. Must match the
       observation schema of the SBI task.
   hidden_channels:
       Intermediate channel widths of the convolutional layers.
   key:
       Optional JAX random key used for weight initialization.


   .. py:attribute:: conv1
      :type:  equinox.nn.Conv1d


   .. py:attribute:: conv2
      :type:  equinox.nn.Conv1d


   .. py:attribute:: conv3
      :type:  equinox.nn.Conv1d


   .. py:attribute:: amplitude_projection
      :type:  equinox.nn.MLP


   .. py:attribute:: pool_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: branch_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: output_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: scale_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: embedding_dim
      :type:  int


   .. py:attribute:: n_wavelengths
      :type:  int


   .. py:method:: _forward_conv(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray

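      Run the 1D convolutional stack on pre-processed arrays.

      Parameters
      ----------
      values, uncertainties, coordinates:
          Float32 1-D arrays of length ``n_wavelengths``.
      mask:
          Boolean 1-D mask (True = invalid).
      log_median_flux:
          Optional scalar absolute-flux feature derived from the median of
          the raw spectrum.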

      Returns
      -------
      jnp.ndarray
          Embedding of shape ``(embedding_dim,)``.



   .. py:method:: _pad_to_fixed(arr: jax.numpy.ndarray) -> jax.numpy.ndarray

      Pad or truncate a 1-D array to ``n_wavelengths``.
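
      A plain-Python sketch of the pad-or-truncate rule follows. Padding at
      the end with a fill value and dropping trailing points when the input
      is too long are assumptions about ``_pad_to_fixed``, as is the idea of
      padding a companion mask with ``True`` so that padded positions are
      treated as invalid.

      .. code-block:: python

         def pad_to_fixed(values, n_wavelengths, fill=0.0):
             """Truncate or right-pad ``values`` so exactly ``n_wavelengths`` items remain."""
             kept = list(values[:n_wavelengths])
             # A negative repeat count yields an empty list, so nothing is padded
             # when the input was already too long.
             return kept + [fill] * (n_wavelengths - len(kept))


         print(pad_to_fixed([1.0, 2.0], 4))                 # [1.0, 2.0, 0.0, 0.0]
         print(pad_to_fixed([1.0, 2.0, 3.0, 4.0, 5.0], 3))  # [1.0, 2.0, 3.0]
         # Padding the mask with True marks padded points as invalid (assumption).
         print(pad_to_fixed([False, False], 4, fill=True))  # [False, False, True, True]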



   .. py:method:: encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray

      Encode one spectral observation block.

      Parameters
      ----------
      block:
          Spectrum-like observation block.

      Returns
      -------
      jnp.ndarray
          Dense embedding of shape ``(embedding_dim,)``.



   .. py:method:: _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray

      Encode pre-processed block arrays (vmappable; takes no ``ObservationBlock`` input).

      Parameters
      ----------
      values, uncertainties, coordinates:
          Float32 1-D arrays already produced by ``_as_vector``.
      mask:
          Boolean 1-D array already produced by ``_safe_mask``.
      log_median_flux:
          Optional scalar absolute-flux feature derived from the median of the raw spectrum.

      Returns
      -------
      jnp.ndarray
          Dense spectral embedding of shape ``(embedding_dim,)``.



.. py:class:: SpectralPatchEncoder(embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 96, key: jax.Array | None = None)

   Bases: :py:obj:`equinox.Module`


   Learned spectral encoder based on patch summaries and MLP pooling.

   Parameters
   ----------
   embedding_dim:
       Size of the spectral embedding returned for one observation block.
   patch_size:
       Number of spectral points summarized together before MLP encoding.
   hidden_dim:
       Hidden width of the internal summary MLP.
   key:
       Optional JAX random key used for weight initialization.

   Notes
   -----
   The encoder summarizes values, uncertainties, and coordinates patch by
   patch instead of operating on a fully convolutional representation. This
   keeps the implementation lightweight and shape-agnostic.
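
   The patch-by-patch summary can be sketched in plain Python. Using the mean
   of each field as the per-patch statistic is an assumption made for
   illustration (the learned ``patch_mlp`` consumes whatever summaries the
   real encoder computes), and ``patch_summaries`` is a hypothetical helper.
   Allowing a shorter final patch is one way to accept spectra of any length.

   .. code-block:: python

      def patch_summaries(values, uncertainties, coordinates, patch_size):
          """Return one (mean value, mean uncertainty, mean coordinate) per patch.

          The last patch may hold fewer than ``patch_size`` points, so any
          spectrum length is accepted.
          """
          summaries = []
          for start in range(0, len(values), patch_size):
              stop = start + patch_size
              v = values[start:stop]
              u = uncertainties[start:stop]
              c = coordinates[start:stop]
              n = len(v)
              summaries.append((sum(v) / n, sum(u) / n, sum(c) / n))
          return summaries


      values = [1.0, 2.0, 3.0, 4.0, 5.0]
      uncertainties = [0.1] * 5
      coordinates = [0.0, 1.0, 2.0, 3.0, 4.0]
      for summary in patch_summaries(values, uncertainties, coordinates, patch_size=2):
          print(summary)

   With ``patch_size=2`` the five points yield three summaries whose mean
   values are 1.5, 3.5, and 5.0; the number of patches still varies with the
   spectrum length, so a fixed-width embedding needs a pooling step after the
   patch MLP.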


   .. py:attribute:: patch_mlp
      :type:  equinox.nn.MLP


   .. py:attribute:: amplitude_projection
      :type:  equinox.nn.MLP


   .. py:attribute:: branch_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: output_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: scale_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: embedding_dim
      :type:  int


   .. py:attribute:: patch_size
      :type:  int


   .. py:method:: encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray

      Encode one spectral observation block.

      Parameters
      ----------
      block:
          Spectrum-like observation block containing values and optional
          uncertainties, coordinates, and masks.

      Returns
      -------
      jnp.ndarray
          Dense fixed-width embedding for the supplied spectral block.



   .. py:method:: _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray

      Encode pre-processed block arrays (vmappable; takes no ``ObservationBlock`` input).

      Parameters
      ----------
      values, uncertainties, coordinates:
          Float32 1-D arrays already produced by ``_as_vector``.
          ``uncertainties`` and ``coordinates`` may have length 0 to indicate
          absence.
      mask:
          Boolean 1-D array already produced by ``_safe_mask``.

      Returns
      -------
      jnp.ndarray
          Dense spectral embedding of shape ``(embedding_dim,)``.



.. py:function:: load_posterior_estimator(input_directory: str) -> petitRADTRANS.sbi.posterior.PosteriorEstimator

   Load a saved posterior estimator without the caller specifying its concrete class.
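
   One way to load an estimator without naming its concrete class is to
   dispatch on a family string stored in the saved metadata. The sketch below
   is hypothetical: ``ESTIMATOR_REGISTRY``, ``register``,
   ``ToyFlowPosterior``, ``load_from_metadata``, and the ``"estimator_config"``
   key are illustrative names, and reading the metadata from
   ``input_directory`` is omitted. It only mirrors the documented
   ``estimator_family`` attribute and ``from_serialized_metadata`` classmethod
   of :class:`ConditionalFlowPosterior`.

   .. code-block:: python

      # Hypothetical registry: estimator_family string -> estimator class.
      ESTIMATOR_REGISTRY = {}


      def register(cls):
          """Class decorator recording ``cls`` under its ``estimator_family``."""
          ESTIMATOR_REGISTRY[cls.estimator_family] = cls
          return cls


      @register
      class ToyFlowPosterior:
          estimator_family = "conditional_flow"

          def __init__(self, embedding_dim):
              self.embedding_dim = embedding_dim

          @classmethod
          def from_serialized_metadata(cls, metadata):
              # "estimator_config" is an assumed metadata key.
              return cls(embedding_dim=metadata["estimator_config"]["embedding_dim"])


      def load_from_metadata(metadata):
          """Rebuild an estimator from metadata without the caller naming its class."""
          family = metadata["estimator_family"]
          if family not in ESTIMATOR_REGISTRY:
              raise ValueError(f"unknown estimator family {family!r}")
          return ESTIMATOR_REGISTRY[family].from_serialized_metadata(metadata)


      restored = load_from_metadata(
          {"estimator_family": "conditional_flow", "estimator_config": {"embedding_dim": 128}}
      )
      print(type(restored).__name__, restored.embedding_dim)  # ToyFlowPosterior 128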


.. py:class:: AmortizedRetrieval(task: petitRADTRANS.sbi.task.SBITask, posterior_estimator: petitRADTRANS.sbi.posterior.PosteriorEstimator, simulator: petitRADTRANS.sbi.simulator.RuntimeSimulator | None = None, preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | Mapping[str, Any] | None = None)

   Serve trained SBI models through a retrieval-like interface.

   Parameters
   ----------
   task:
       SBI task compatible with the trained posterior.
   posterior_estimator:
       Trained posterior estimator used for encoding and sampling.
   simulator:
       Optional simulator override used for posterior-predictive generation.
   preprocessing_metadata:
       Optional preprocessing metadata or payload. When omitted, the inference
       service attempts to recover the preprocessing payload saved inside the
       posterior artifact.

   Notes
   -----
   Raw user-facing observation blocks are normalized internally before
   encoding when preprocessing metadata is available, but posterior-predictive
   comparisons remain on the original observation scale.


   .. py:attribute:: task


   .. py:attribute:: posterior_estimator


   .. py:attribute:: simulator
      :value: None



   .. py:attribute:: preprocessing_metadata


   .. py:method:: _resolve_preprocessing_metadata(preprocessing_metadata: petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | Mapping[str, Any] | None) -> petitRADTRANS.sbi.preprocessing.TaskPreprocessingMetadata | None


   .. py:method:: prepare_observation_blocks(observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> list[petitRADTRANS.sbi.observation.ObservationBlock]

      Normalize user-facing observation blocks when preprocessing metadata is available.

      Parameters
      ----------
      observation_blocks:
          Raw observation blocks provided by the caller.

      Returns
      -------
      list[ObservationBlock]
          Either the original blocks or normalized copies, depending on
          whether preprocessing metadata is available.



   .. py:method:: infer(observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock], n_posterior_samples: int = 1000, include_posterior_predictive: bool = False, posterior_predictive_interval_level: float = 0.9, n_predictive_forward_model_samples: int | None = None, seed: int | None = None) -> AmortizedRetrievalResult

      Infer a posterior for one structured observation.

      Parameters
      ----------
      observation_blocks:
          Raw observation blocks for one target system.
      n_posterior_samples:
          Number of posterior draws to generate.
      include_posterior_predictive:
          Whether to also generate a posterior-predictive summary.
      posterior_predictive_interval_level:
          Predictive interval level used when generating the optional
          posterior-predictive report.
      n_predictive_forward_model_samples:
          Number of posterior draws passed through the forward model when
          generating the posterior-predictive summary. When ``None``, all
          ``n_posterior_samples`` draws are forwarded. Set this to a small
          value (e.g. 50–200) when ``n_posterior_samples`` is large;
          otherwise the forward-model calls can take several hours.
      seed:
          Optional seed for posterior and posterior-predictive sampling.

      Returns
      -------
      AmortizedRetrievalResult
          Posterior samples and any optional predictive metadata.



   .. py:method:: posterior_predictive(posterior: petitRADTRANS.sbi.posterior.PosteriorSamples, observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock], interval_level: float = 0.9, n_predictive_forward_model_samples: int | None = None, seed: int | None = None) -> petitRADTRANS.sbi.calibration.PosteriorPredictiveReport

      Generate posterior-predictive draws for a fitted observation.

      Parameters
      ----------
      posterior:
          Posterior samples associated with one observed system.
      observation_blocks:
          Original observation blocks used for user-facing comparison.
      interval_level:
          Central predictive interval level to report.
      n_predictive_forward_model_samples:
          Number of posterior draws passed through the forward model. When
          ``None``, all samples in ``posterior`` are used. Set this to a
          small value (e.g. 50–200) to keep the number of expensive
          petitRADTRANS calls manageable.
      seed:
          Optional seed for predictive simulation.

      Returns
      -------
      PosteriorPredictiveReport
          Predictive summary for the supplied posterior samples.
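
      The roles of ``n_predictive_forward_model_samples`` and
      ``interval_level`` can be sketched in plain Python: draw a subset of
      posterior samples, run each through the forward model, and take
      per-channel empirical quantiles at ``(1 - interval_level) / 2`` and
      ``1 - (1 - interval_level) / 2``. The helpers and ``toy_forward_model``
      are hypothetical stand-ins, and subsampling without replacement is an
      assumption about how the forwarded draws are selected.

      .. code-block:: python

         import random


         def quantile(sorted_values, q):
             """Linearly interpolated empirical quantile of pre-sorted data."""
             position = q * (len(sorted_values) - 1)
             lower = int(position)
             upper = min(lower + 1, len(sorted_values) - 1)
             weight = position - lower
             return sorted_values[lower] * (1.0 - weight) + sorted_values[upper] * weight


         def central_interval(draws, interval_level):
             """Bounds of the central interval holding ``interval_level`` of the draws."""
             tail = (1.0 - interval_level) / 2.0
             ordered = sorted(draws)
             return quantile(ordered, tail), quantile(ordered, 1.0 - tail)


         def subsample(posterior_draws, n_forward, seed=None):
             """Pick at most ``n_forward`` draws, without replacement, to forward-model."""
             if n_forward is None or n_forward >= len(posterior_draws):
                 return list(posterior_draws)
             return random.Random(seed).sample(posterior_draws, n_forward)


         def toy_forward_model(theta):
             # Hypothetical stand-in for one expensive petitRADTRANS call (2 channels).
             return [theta, 2.0 * theta]


         rng = random.Random(1)
         posterior_draws = [rng.gauss(0.0, 1.0) for _ in range(5000)]
         subset = subsample(posterior_draws, 200, seed=0)  # 200 forward calls, not 5000
         predictions = [toy_forward_model(theta) for theta in subset]
         bands = [central_interval(channel, 0.9) for channel in zip(*predictions)]

      The forward-model cost scales with the size of ``subset`` rather than
      with the full posterior sample count, while the interval estimate only
      needs enough draws to resolve the chosen tail quantiles.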



   .. py:method:: diagnose_domain(observation_blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> OODDiagnostic | None

      Estimate whether the observation is in-distribution for the model.

      Parameters
      ----------
      observation_blocks:
          Observation blocks to diagnose.

      Returns
      -------
      OODDiagnostic | None
          Robust support-distance diagnostic derived from preprocessing
          statistics, or ``None`` when preprocessing metadata is unavailable.
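
      The docstring only states that the score is a robust support distance
      derived from preprocessing statistics. A minimal sketch of one such
      distance, the largest per-feature robust z-score based on the median and
      the median absolute deviation (MAD), is shown below together with a
      pass/fail rule of the kind stored in :class:`OODDiagnostic`. The feature
      vectors, the helper names, and the threshold of 5.0 are assumptions, not
      the values used by ``diagnose_domain``.

      .. code-block:: python

         import statistics

         MAD_TO_SIGMA = 1.4826  # rescales the MAD to a Gaussian standard deviation


         def robust_reference(training_features):
             """Per-feature (median, scaled MAD) from training-set feature vectors."""
             reference = []
             for column in zip(*training_features):
                 median = statistics.median(column)
                 mad = statistics.median([abs(x - median) for x in column])
                 reference.append((median, MAD_TO_SIGMA * mad))
             return reference


         def ood_score(features, reference, eps=1e-12):
             """Largest robust z-score across features (0 means at the training median)."""
             return max(
                 abs(x - median) / max(scale, eps)
                 for x, (median, scale) in zip(features, reference)
             )


         training = [(1.0, 10.0), (2.0, 12.0), (3.0, 14.0), (4.0, 16.0), (5.0, 18.0)]
         reference = robust_reference(training)
         threshold = 5.0  # hypothetical cut-off
         for observation in [(3.0, 14.0), (30.0, 14.0)]:
             score = ood_score(observation, reference)
             print(observation, round(score, 2), score <= threshold)

      Median and MAD are insensitive to a few extreme training simulations,
      which is why a distance of this kind is usually preferred over a plain
      mean and standard deviation for support checks.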



.. py:class:: AmortizedRetrievalResult

   Return type for amortized inference queries.

   Attributes
   ----------
   posterior:
       Posterior samples and any attached diagnostics.
   posterior_predictive:
       Optional posterior-predictive report for the same observation.
   ood_diagnostic:
       Optional in/out-of-distribution assessment.
   metadata:
       Additional metadata such as preprocessing usage during inference.


   .. py:attribute:: posterior
      :type:  petitRADTRANS.sbi.posterior.PosteriorSamples


   .. py:attribute:: posterior_predictive
      :type:  Any
      :value: None



   .. py:attribute:: ood_diagnostic
      :type:  OODDiagnostic | None
      :value: None



   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


.. py:class:: OODDiagnostic

   Describe whether an observation is inside the training support.

   Attributes
   ----------
   score:
       Scalar out-of-distribution score.
   threshold:
       Optional threshold used to convert the score into a pass/fail decision.
   passed:
       Optional boolean indicating whether the observation passed the OOD test.
   metadata:
       Auxiliary diagnostic metadata.


   .. py:attribute:: score
      :type:  float


   .. py:attribute:: threshold
      :type:  float | None
      :value: None



   .. py:attribute:: passed
      :type:  bool | None
      :value: None



   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


.. py:class:: ObservationBlock

   Represent one modality-specific observation block.

   Attributes
   ----------
   name:
       Stable identifier of the block within a task.
   modality:
       Semantic block type used to dispatch encoding logic.
   values:
       Observed values after any task-level preprocessing.
   uncertainties:
       Optional per-element uncertainty representation.
   coordinates:
       Optional coordinate arrays such as wavelengths or timestamps.
   mask:
       Optional mask applied to the values.
   metadata:
       Additional instrument and preprocessing metadata.


   .. py:attribute:: name
      :type:  str


   .. py:attribute:: modality
      :type:  ObservationModality


   .. py:attribute:: values
      :type:  Any


   .. py:attribute:: uncertainties
      :type:  Any
      :value: None



   .. py:attribute:: coordinates
      :type:  Any
      :value: None



   .. py:attribute:: mask
      :type:  Any
      :value: None



   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


.. py:class:: ObservationEncoder

   Bases: :py:obj:`abc.ABC`


   Transform structured observation blocks into model-ready embeddings.


   .. py:method:: encode(blocks: list[ObservationBlock]) -> EncodedObservation
      :abstractmethod:


      Encode a list of observation blocks into a shared representation.



   .. py:method:: batch_encode(observations: list[list[ObservationBlock]]) -> list[EncodedObservation]

      Encode multiple observations by calling :meth:`encode` once per observation.



.. py:class:: ObservationModality

   Bases: :py:obj:`str`, :py:obj:`enum.Enum`


   Supported observation block types for SBI conditioning.


   .. py:attribute:: SPECTRUM
      :value: 'spectrum'



   .. py:attribute:: PHOTOMETRY
      :value: 'photometry'



   .. py:attribute:: TIME_SERIES
      :value: 'time_series'



   .. py:attribute:: AUXILIARY
      :value: 'auxiliary'



.. py:function:: build_observation_block(name: str, modality: ObservationModality | str, values: Any, uncertainties: Any = None, coordinates: Any = None, mask: Any = None, metadata: Mapping[str, Any] | None = None) -> ObservationBlock

   Build one observation block with modality normalization.
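
   Modality normalization can rely on the ``str``/``enum.Enum`` mix-in
   documented for :class:`ObservationModality`: looking up a member by value
   turns a string into the enum, and members compare equal to their string
   values. The sketch below uses a local copy of the documented members;
   ``normalize_modality`` is hypothetical, and the case-folding step is an
   assumption about what ``build_observation_block`` accepts.

   .. code-block:: python

      import enum


      class ObservationModality(str, enum.Enum):
          # Local copy of the documented members, for illustration only.
          SPECTRUM = "spectrum"
          PHOTOMETRY = "photometry"
          TIME_SERIES = "time_series"
          AUXILIARY = "auxiliary"


      def normalize_modality(modality):
          """Accept an enum member or a string value (case-insensitive, assumed)."""
          if isinstance(modality, ObservationModality):
              return modality
          # Raises ValueError for strings that match no member value.
          return ObservationModality(str(modality).strip().lower())


      print(normalize_modality("Spectrum") is ObservationModality.SPECTRUM)  # True
      print(ObservationModality.PHOTOMETRY == "photometry")                # True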


.. py:function:: build_observation_block_batch(observation_payloads: Mapping[str, Mapping[str, Any]], modalities: Mapping[str, str]) -> list[list[ObservationBlock]]

   Build observation blocks for each sample in a batched payload.


.. py:function:: build_observation_blocks_from_sample(observation_payloads: Mapping[str, Mapping[str, Any]], modalities: Mapping[str, str], sample_index: int) -> list[ObservationBlock]

   Build modality-aware observation blocks for one simulated sample.
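
   The batched payload implied by these signatures maps each block name to a
   mapping of fields, with one row per sample. A plain-Python sketch of
   slicing one sample out of such a payload, and of building the whole batch
   from it, follows. The field names, the ``None`` convention for absent
   fields, the block names, and the dict-based blocks are assumptions; the
   real functions return :class:`ObservationBlock` instances.

   .. code-block:: python

      def blocks_for_sample(observation_payloads, modalities, sample_index):
          """Slice sample ``sample_index`` out of a batched payload.

          Assumed layout: ``observation_payloads[name][field]`` holds one row per
          sample, or ``None`` when the field is absent.
          """
          blocks = []
          for name, fields in observation_payloads.items():
              row = {
                  field: None if batched is None else batched[sample_index]
                  for field, batched in fields.items()
              }
              blocks.append({"name": name, "modality": modalities[name], **row})
          return blocks


      def blocks_for_batch(observation_payloads, modalities):
          """Build one block list per sample, using the first block's row count."""
          first = next(iter(observation_payloads.values()))
          n_samples = len(first["values"])
          return [
              blocks_for_sample(observation_payloads, modalities, index)
              for index in range(n_samples)
          ]


      payloads = {
          "spectrum_block": {"values": [[1.0, 2.0], [3.0, 4.0]], "uncertainties": None},
          "photometry_block": {"values": [[0.5], [0.7]]},
      }
      modalities = {"spectrum_block": "spectrum", "photometry_block": "photometry"}
      batch = blocks_for_batch(payloads, modalities)
      print(batch[1][0]["values"])  # [3.0, 4.0]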


.. py:class:: ConditionalAutoregressiveFlowPosterior(*args: Any, **kwargs: Any)

   Bases: :py:obj:`ConditionalFlowPosterior`


   Posterior estimator specialized to the autoregressive flow backend.


.. py:class:: ConditionalFlowPosterior(parameter_dim: int, embedding_dim: int = 128, num_coupling_layers: int = 4, hidden_dim: int = 128, conditioner_depth: int = 2, autoregressive_transform_units: int = 16, neural_autoregressive_min_slope: float = 0.001, neural_autoregressive_min_residual: float = 0.05, neural_autoregressive_inverse_bisection_steps: int = 48, learning_rate: float = 0.001, batch_size: int = 32, num_epochs: int = 5, parameter_space: str = 'unconstrained', flow_family: str = 'spline', num_spline_bins: int = 8, spline_bound: float = 10.0, base_distribution: str = 'gaussian', use_base_affine: bool = False, training_objective: str = 'npe', early_stopping_patience: int | None = None, early_stopping_min_delta: float = 0.0, checkpoint_directory: str | None = None, checkpoint_backend: str = 'auto', resume_from_checkpoint: bool = False, gradient_clip_norm: float | None = 1.0, embedding_noise_std: float = 0.0, embedding_noise_min_scale: float = 0.0, aux_scale_loss_weight: float = 0.0, aux_parameter_loss_weight: float = 0.0, parameter_noise_floor: float = 0.0, spline_small_bin_regularization_weight: float | None = None, spline_min_bin_ratio_target: float = 0.2, spline_entropy_regularization_weight: float = 0.0, spline_derivative_regularization_weight: float = 0.0, spline_entropy_floor: float = 0.6, spline_min_derivative_target: float = 0.25, spline_max_derivative_target: float = 5.0, weight_decay: float = 0.0, use_cosine_schedule: bool = False, warmup_fraction: float = 0.02, warmup_epochs: float | None = None, min_learning_rate: float = 1e-06, lr_schedule_total_epochs: float | None = None, stable_inverse_forward_max_abs_error_threshold: float | None = 0.0001, stable_inverse_forward_logdet_closure_max_abs_error_threshold: float | None = 0.0001, stable_cube_edge_hit_rate_threshold: float | None = 0.0, spectrum_encoder_type: str = 'conv1d', n_wavelengths: int = 233, spectrum_embedding_dim: int = 64, photometry_embedding_dim: int = 64, encoder_hidden_dim: int | None = None, encoder_patch_size: int = 32, seed: int = 0, task_metadata: Mapping[str, Any] | None = None, verbose_diagnostics: bool = False, diagnostics_output_directory: str | None = None, diagnostics_plot_interval: int = 1)

   Bases: :py:obj:`petitRADTRANS.sbi.posterior_base.PersistentPosteriorEstimator`


   Concrete amortized posterior using a conditional flow backend.

   Parameters
   ----------
   parameter_dim:
       Number of inferred free parameters represented by the posterior.
   embedding_dim:
       Size of the learned observation embedding consumed by the flow.
   num_coupling_layers:
       Number of conditional coupling or spline-transform layers in the flow.
   hidden_dim:
       Hidden width used by encoder-side and flow-side MLP conditioners.
   learning_rate:
       Optimizer learning rate used by :class:`SBITrainer`.
   batch_size:
       Number of simulations processed per optimization step.
   num_epochs:
       Maximum number of passes through the training split.
   parameter_space:
       Parameter coordinates learned by the posterior. Supported values are
       ``'physical'``, ``'cube'``, and ``'unconstrained'``.
   flow_family:
       Conditional density-transform family. Supported values are
       ``'spline'``, ``'affine'``, ``'autoregressive'``, and
       ``'neural_autoregressive'``.
   num_spline_bins:
       Number of rational-quadratic spline bins when ``flow_family='spline'``.
   spline_bound:
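       Finite support bound of each spline transform in latent space.
   training_objective:
       Training objective. ``'npe'`` (the default) trains by neural posterior
       estimation; ``'elbo'`` trains an amortized variational posterior and
       requires ``elbo_log_likelihood_fn`` in :meth:`fit`.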
   early_stopping_patience:
       Optional number of non-improving epochs tolerated before stopping.
   early_stopping_min_delta:
       Minimum improvement required for early-stopping comparisons.
   checkpoint_directory:
       Optional directory used to persist resumable trainer checkpoints.
   checkpoint_backend:
       Checkpoint persistence backend name. ``'auto'`` selects Orbax when
       available and otherwise falls back to Equinox serialization.
   resume_from_checkpoint:
       Whether ``fit`` should attempt to resume from the latest checkpoint.
   seed:
       Base random seed used for flow initialization and posterior sampling.
   task_metadata:
       Optional user-supplied metadata persisted alongside the trained model.

   Notes
   -----
   The posterior stores the task fingerprint, observation schema, and
   preprocessing payload when the training dataset reader provides them.
   Inference and artifact registration later reuse that metadata.
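
   The interplay of ``early_stopping_patience`` and
   ``early_stopping_min_delta`` can be illustrated with a small plain-Python
   sketch. It assumes that an epoch counts as an improvement only when its
   validation loss falls below the best loss so far by more than
   ``min_delta``, and that training stops after ``patience`` consecutive
   non-improving epochs; the exact rule applied by :class:`SBITrainer` may
   differ, and ``early_stopping_epoch`` is a hypothetical helper.

   .. code-block:: python

      def early_stopping_epoch(validation_losses, patience, min_delta=0.0):
          """Return the epoch index at which training would stop, or None.

          Assumed rule: an epoch improves only if its loss is below the best loss
          so far by more than ``min_delta``; training stops after ``patience``
          consecutive epochs without improvement.
          """
          if patience is None:
              return None  # early stopping disabled
          best = float("inf")
          stale = 0
          for epoch, loss in enumerate(validation_losses):
              if loss < best - min_delta:
                  best, stale = loss, 0
              else:
                  stale += 1
                  if stale >= patience:
                      return epoch
          return None


      losses = [1.0, 0.8, 0.79, 0.795, 0.81]
      print(early_stopping_epoch(losses, patience=2, min_delta=0.02))  # 3
      print(early_stopping_epoch(losses, patience=2))                  # 4

   A larger ``min_delta`` makes small gains count as stagnation, so the
   0.01 improvement at epoch 2 no longer resets the counter and training stops
   one epoch earlier.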


   .. py:attribute:: estimator_family
      :value: 'conditional_flow'



   .. py:attribute:: embedding_dim
      :value: 128



   .. py:attribute:: num_coupling_layers
      :value: 4



   .. py:attribute:: hidden_dim
      :value: 128



   .. py:attribute:: conditioner_depth
      :value: 2



   .. py:attribute:: autoregressive_transform_units
      :value: 16



   .. py:attribute:: neural_autoregressive_min_slope


   .. py:attribute:: neural_autoregressive_min_residual


   .. py:attribute:: neural_autoregressive_inverse_bisection_steps
      :value: 48



   .. py:attribute:: learning_rate


   .. py:attribute:: batch_size
      :value: 32



   .. py:attribute:: num_epochs
      :value: 5



   .. py:attribute:: flow_family
      :value: ''



   .. py:attribute:: effective_flow_family
      :value: ''



   .. py:attribute:: base_distribution
      :value: ''



   .. py:attribute:: use_base_affine
      :value: False



   .. py:attribute:: training_objective
      :value: ''



   .. py:attribute:: num_spline_bins
      :value: 8



   .. py:attribute:: early_stopping_patience
      :value: None



   .. py:attribute:: early_stopping_min_delta


   .. py:attribute:: checkpoint_directory
      :value: None



   .. py:attribute:: checkpoint_backend
      :value: 'auto'



   .. py:attribute:: resume_from_checkpoint
      :value: False



   .. py:attribute:: gradient_clip_norm
      :value: 1.0



   .. py:attribute:: embedding_noise_std


   .. py:attribute:: embedding_noise_min_scale


   .. py:attribute:: aux_scale_loss_weight


   .. py:attribute:: aux_parameter_loss_weight


   .. py:attribute:: parameter_noise_floor


   .. py:attribute:: spline_entropy_regularization_weight


   .. py:attribute:: spline_small_bin_regularization_weight


   .. py:attribute:: spline_derivative_regularization_weight


   .. py:attribute:: spline_min_bin_ratio_target


   .. py:attribute:: spline_entropy_floor


   .. py:attribute:: spline_min_derivative_target


   .. py:attribute:: spline_max_derivative_target


   .. py:attribute:: weight_decay


   .. py:attribute:: use_cosine_schedule
      :value: False



   .. py:attribute:: warmup_fraction


   .. py:attribute:: warmup_epochs
      :value: None



   .. py:attribute:: min_learning_rate


   .. py:attribute:: lr_schedule_total_epochs
      :value: None



   .. py:attribute:: stable_inverse_forward_max_abs_error_threshold
      :value: None



   .. py:attribute:: stable_inverse_forward_logdet_closure_max_abs_error_threshold
      :value: None



   .. py:attribute:: stable_cube_edge_hit_rate_threshold
      :value: None



   .. py:attribute:: spectrum_encoder_type
      :value: ''



   .. py:attribute:: n_wavelengths
      :value: 233



   .. py:attribute:: spectrum_embedding_dim
      :value: 64



   .. py:attribute:: photometry_embedding_dim
      :value: 64



   .. py:attribute:: encoder_hidden_dim


   .. py:attribute:: encoder_patch_size
      :value: 32



   .. py:attribute:: verbose_diagnostics
      :value: False



   .. py:attribute:: diagnostics_output_directory
      :value: None



   .. py:attribute:: diagnostics_plot_interval
      :value: 1



   .. py:attribute:: model


   .. py:method:: _build_flow(key: jax.Array) -> Any


   .. py:method:: _batch_embeddings(model: _PosteriorModel, observations: Any) -> jax.numpy.ndarray
      :staticmethod:



   .. py:method:: _loss(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) -> jax.numpy.ndarray
      :staticmethod:



   .. py:method:: _training_loss(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) -> jax.numpy.ndarray
      :staticmethod:



   .. py:method:: _loss_from_observations(model: _PosteriorModel, parameters: jax.numpy.ndarray, observations: Any, parameter_space: str, *, include_spline_regularization: bool) -> jax.numpy.ndarray
      :staticmethod:



   .. py:method:: _validation_diagnostics(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) -> dict[str, float]


   .. py:method:: _make_parameter_noise_loss(base_loss_fn: Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray], noise_floor: float, parameter_space: str) -> Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]
      :staticmethod:


      Wrap a training loss so the parameter *targets* are jittered.

      Minimizing the conditional NLL in a very-low-noise regime with a highly
      parameter-predictive embedding drives the learned conditional density
      toward a delta (NLL -> -inf), which an over-sharp flow realizes as
      exploded transforms (and, for spline flows, cube-edge collapse). A small
      Gaussian jitter on the parameter targets gives the conditional a hard
      minimum width, capping sharpness so the NLL has a finite minimum and the
      flow's inverse stays well-conditioned.

      The jitter is applied in *unconstrained* (logit) space -- the space the
      flow actually models -- not in cube space. Additive cube-space noise is
      asymmetric near the [0, 1] bounds: clipping piles jittered targets onto
      the edges, and the cube->logit Jacobian rewards edge mass, so it
      worsens the very edge-collapse it was meant to prevent. Jittering in
      logit space is symmetric and boundary-free. Applied only during
      training; evaluation and checkpoint selection use the clean targets.
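
      The boundary argument above can be checked numerically. In the
      plain-Python sketch below, targets close to the upper cube edge are
      jittered either in logit space (mapped back through the sigmoid only to
      compare with the cube) or additively in cube space with clipping. The
      helper names are hypothetical, and treating ``noise_floor`` as a
      standard deviation in logit units is an assumption about
      ``parameter_noise_floor``.

      .. code-block:: python

         import math
         import random


         def logit(p):
             """Map a cube coordinate in (0, 1) to unconstrained space."""
             return math.log(p) - math.log1p(-p)


         def sigmoid(x):
             """Inverse of ``logit``."""
             return 1.0 / (1.0 + math.exp(-x))


         def jitter_in_logit_space(cube_targets, noise_floor, rng):
             # Symmetric Gaussian noise on the logit; the sigmoid keeps results in (0, 1).
             return [sigmoid(logit(p) + rng.gauss(0.0, noise_floor)) for p in cube_targets]


         def jitter_in_cube_space(cube_targets, noise_floor, rng):
             # Naive alternative: additive cube noise, clipped back into [0, 1].
             return [min(1.0, max(0.0, p + rng.gauss(0.0, noise_floor))) for p in cube_targets]


         rng = random.Random(0)
         targets = [0.999] * 1000  # every target sits next to the upper edge
         logit_jittered = jitter_in_logit_space(targets, 0.05, rng)
         cube_jittered = jitter_in_cube_space(targets, 0.05, rng)
         print(sum(v == 1.0 for v in logit_jittered))  # 0: nothing lands on the edge
         print(sum(v == 1.0 for v in cube_jittered))   # many targets piled onto the edge

      During training the jittered logit values would be used directly as
      flow targets, since that is the space the flow models; the sigmoid here
      only exposes what each scheme does in cube coordinates.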



   .. py:method:: _make_elbo_loss(log_likelihood_fn: Callable[[jax.numpy.ndarray, Any], jax.numpy.ndarray], parameter_space: str, *, num_samples: int = 1) -> Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]
      :staticmethod:


      Return an amortized-variational (ELBO) training loss.

      Maximizes, for observations ``x`` drawn from the (prior-predictive)
      dataset,

      .. math::

         \mathrm{ELBO}(x) = \mathbb{E}_{q(\theta \mid x)}\left[\log p(x \mid \theta) + \log p(\theta) - \log q(\theta \mid x)\right]

      with a single- (or few-) sample reparameterized estimator. ``q(theta|x)``
      is the conditional flow: latents ``z`` are drawn from the flow base and
      pushed through ``flow.inverse`` to reparameterized samples, so the
      gradient flows through both the flow and the encoder.

      The ``-log q`` (entropy) term diverges to ``-inf`` as ``q`` collapses to
      a point mass, so a delta posterior is *structurally* impossible -- the
      width is set by the likelihood/entropy balance rather than by the
      (near-deterministic) embedding. ``log p(x|theta)`` is supplied by the
      injected differentiable forward-model likelihood evaluated at the
      physical parameters reconstructed from the sampled cube coordinates.
      With ``parameter_space='cube'`` the prior is uniform on the unit
      hypercube, so ``log p(theta_cube) = 0`` and is dropped.
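
      The estimator can be sanity-checked on a one-dimensional Gaussian toy
      problem in which the exact posterior and evidence are known. This
      plain-Python sketch swaps the documented cube-space setting (uniform
      prior, injected forward-model likelihood) for a standard-normal prior
      and a Gaussian likelihood, so the ``log p(theta)`` term is kept; the
      function names are hypothetical. When ``q`` equals the exact posterior,
      every reparameterized sample returns exactly ``log p(x)``; any other
      ``q`` gives a lower value in expectation.

      .. code-block:: python

         import math
         import random


         def log_normal(x, mean, std):
             return -0.5 * math.log(2.0 * math.pi * std**2) - 0.5 * ((x - mean) / std) ** 2


         def elbo_estimate(x, q_mean, q_std, noise_std, num_samples, rng):
             """Reparameterized Monte Carlo ELBO for a 1-D Gaussian toy model.

             Prior theta ~ N(0, 1); likelihood x | theta ~ N(theta, noise_std**2);
             variational posterior q(theta | x) = N(q_mean, q_std**2).
             """
             total = 0.0
             for _ in range(num_samples):
                 z = rng.gauss(0.0, 1.0)      # latent draw from the base N(0, 1)
                 theta = q_mean + q_std * z   # reparameterized sample from q
                 total += (
                     log_normal(x, theta, noise_std)       # log p(x | theta)
                     + log_normal(theta, 0.0, 1.0)         # log p(theta)
                     - log_normal(theta, q_mean, q_std)    # -log q(theta | x)
                 )
             return total / num_samples


         x, noise_std = 0.7, 0.5
         precision = 1.0 + 1.0 / noise_std**2
         exact_mean, exact_std = (x / noise_std**2) / precision, math.sqrt(1.0 / precision)
         log_evidence = log_normal(x, 0.0, math.sqrt(1.0 + noise_std**2))

         rng = random.Random(0)
         # q equal to the exact posterior: a single sample already gives log p(x).
         print(abs(elbo_estimate(x, exact_mean, exact_std, noise_std, 1, rng) - log_evidence) < 1e-9)  # True
         # q equal to the prior: the estimate falls below log p(x) by about KL(q || posterior).
         print(elbo_estimate(x, 0.0, 1.0, noise_std, 20000, rng) < log_evidence)  # True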

      Parameters
      ----------
      log_likelihood_fn:
          ``(theta_cube, observations) -> (batch,)`` returning the Gaussian
          observational log-likelihood ``log p(x | theta)`` for each batch
          element. Receives unit-cube coordinates (shape ``(batch, dim)``)
          and the batch's prestacked observations; it owns the cube->physical
          transform, the differentiable forward model, and the noise model.
          Must be JAX-differentiable w.r.t. ``theta_cube``.
      parameter_space:
          Must be ``"cube"``.
      num_samples:
          Number of reparameterized posterior draws per observation used to
          estimate the ELBO expectation. ``1`` is standard for amortized VI;
          larger values reduce gradient variance at a proportional
          forward-model cost.



   .. py:method:: _make_noisy_loss(noise_std: float, preprocessing_metadata: Any = None, *, noise_min_scale: float = 0.0, include_spline_regularization: bool = True) -> Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]
      :staticmethod:


      Return a loss function that injects on-the-fly Gaussian noise into
      the observation blocks before encoding.

      For spectra preprocessed into a per-sample transformed value space,
      the uncertainty channel is assumed to be expressed in the same
      transformed units as the value channel, so noise is drawn directly
      from that channel. When preprocessing metadata is provided,
      non-spectral blocks still reconstruct their raw measurement
      uncertainties, so noise is injected at the correct physical scale in
      normalized-value space.

      ``noise_std`` multiplies the observational noise level (1.0 reproduces
      the observational noise exactly; values above 1 add regularizing extra
      noise). ``noise_min_scale`` is a minimum fraction of the block's typical
      transformed uncertainty, not an absolute floor in normalized-value
      units.

      When preprocessing metadata is unavailable, a flat Gaussian noise with
      scale ``noise_std`` is applied as a fallback.
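
      A plain-Python sketch of these scaling rules for a single block follows.
      It is a simplification: taking the block median as the typical
      uncertainty is an assumption, the flat-noise fallback is triggered here
      by missing uncertainties rather than by missing preprocessing metadata,
      and ``inject_noise`` is a hypothetical helper.

      .. code-block:: python

         import random
         import statistics


         def inject_noise(values, uncertainties, noise_std, rng, noise_min_scale=0.0):
             """Perturb one block's values with Gaussian noise scaled by its uncertainties.

             ``noise_std`` multiplies each element's sigma (1.0 = stated noise).
             ``noise_min_scale`` floors each sigma at that fraction of the block's
             median uncertainty (assumed definition of "typical"). Without
             uncertainties, flat noise of scale ``noise_std`` is used instead.
             """
             if uncertainties is None:
                 return [v + rng.gauss(0.0, noise_std) for v in values]
             floor = noise_min_scale * statistics.median(uncertainties)
             return [
                 v + noise_std * max(sigma, floor) * rng.gauss(0.0, 1.0)
                 for v, sigma in zip(values, uncertainties)
             ]


         rng = random.Random(0)
         clean = [1.0, 1.0, 1.0]
         sigmas = [0.0, 0.2, 0.2]  # the first point reports zero uncertainty
         # Without a floor the first point is never perturbed; a 0.5 floor gives it
         # sigma = 0.5 * median(sigmas) = 0.1.
         print(inject_noise(clean, sigmas, 1.0, rng)[0])  # 1.0
         print(inject_noise(clean, sigmas, 1.0, rng, noise_min_scale=0.5)[0])

      Expressing the floor relative to the block's own uncertainties keeps it
      meaningful across blocks whose transformed units differ by orders of
      magnitude.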



   .. py:method:: fit(dataset: Any, *, elbo_log_likelihood_fn: Callable[[jax.numpy.ndarray, Any], jax.numpy.ndarray] | None = None, elbo_num_samples: int = 1) -> petitRADTRANS.sbi.posterior_base.TrainingArtifacts

      Train the posterior on one normalized simulation dataset reader.

      Parameters
      ----------
      dataset:
          Reader-like object that yields :class:`PosteriorBatch` instances via
          ``iter_batches`` and exposes dataset manifest metadata when
          available.
      elbo_log_likelihood_fn:
          Required when ``training_objective='elbo'``. A differentiable
          ``(theta_cube, observations) -> (batch,)`` callable returning the
          Gaussian observational log-likelihood through the forward model. See
          :meth:`_make_elbo_loss`. Ignored for the default NPE objective.
      elbo_num_samples:
          Number of reparameterized posterior draws per observation for the
          ELBO estimator (default 1).

      Returns
      -------
      TrainingArtifacts
          Training history, validation metrics, and trainer metadata for the
          completed optimization run.

      Notes
      -----
      If the reader exposes preprocessing metadata or a manifest fingerprint,
      those are cached on the posterior so they can be saved and reused by the
      inference and artifact layers.
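
      Training with ``training_objective='elbo'`` maximizes an evidence lower
      bound. Its textbook reparameterized Monte Carlo estimate is shown below,
      with :math:`K` = ``elbo_num_samples`` draws
      :math:`\theta_k \sim q_\phi(\theta \mid x)` and
      :math:`\log p(x \mid \theta_k)` supplied by ``elbo_log_likelihood_fn``.
      It is given for orientation only; how the backend treats the prior term
      in cube space is not documented here.

      .. math::

         \mathcal{L}(\phi; x) \approx \frac{1}{K} \sum_{k=1}^{K}
         \left[ \log p(x \mid \theta_k) + \log p(\theta_k)
         - \log q_\phi(\theta_k \mid x) \right]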



   .. py:method:: _build_estimator_config() -> dict[str, Any]

      Return backend-specific configuration for metadata persistence.



   .. py:method:: _build_serialized_metadata(artifact_metadata) -> dict[str, Any]


   .. py:method:: _resolve_estimator_config(metadata: Mapping[str, Any]) -> dict[str, Any]
      :staticmethod:



   .. py:method:: hydrate_loaded_metadata(metadata: Mapping[str, Any]) -> None


   .. py:method:: from_serialized_metadata(metadata: Mapping[str, Any]) -> ConditionalFlowPosterior
      :classmethod:


      Rebuild an estimator instance from persisted metadata only.



   .. py:method:: save_backend_state(output_path: pathlib.Path) -> None

      Persist backend-specific model state into the output directory.



   .. py:method:: load_backend_state(input_path: pathlib.Path) -> None

      Restore backend-specific model state from the input directory.



   .. py:method:: encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation

      Encode one structured observation into the posterior context space.

      Parameters
      ----------
      blocks:
          Observation blocks describing one spectral/photometric observation.

      Returns
      -------
      EncodedObservation
          Aggregated embedding and lightweight metadata describing the input
          observation family.



   .. py:method:: batch_encode_observation(blocks_list: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) -> list[petitRADTRANS.sbi.observation.EncodedObservation]

      Encode a batch of observations using a single vmapped forward pass.

      Parameters
      ----------
      blocks_list:
          List of per-sample observation block lists.

      Returns
      -------
      list[EncodedObservation]
          One encoded observation per input sample.



   .. py:method:: sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) -> petitRADTRANS.sbi.posterior_base.PosteriorSamples

      Draw posterior samples conditioned on one encoded observation.

      Parameters
      ----------
      observation:
          Encoded observation produced by :meth:`encode_observation`.
      n_samples:
          Number of posterior draws to generate.
      seed:
          Optional random seed overriding the model-level default seed.

      Returns
      -------
      PosteriorSamples
          Samples in the posterior's configured parameter space. Non-finite
          outputs are clamped to large finite values for downstream stability.
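
      Examples
      --------
      A duck-typed usage sketch that relies only on the signatures documented
      on this page: encode one observation, draw samples, and summarize them.
      ``summarize_posterior`` is a hypothetical helper, and the 16th/84th
      percentiles are an illustrative choice.

      .. code-block:: python

         import numpy as np


         def summarize_posterior(posterior, blocks, n_samples=2000, seed=0):
             """Per-parameter mean and 16th/84th percentiles (hypothetical)."""
             encoded = posterior.encode_observation(blocks)
             draws = posterior.sample_posterior(encoded, n_samples, seed=seed)
             samples = np.asarray(draws.samples)  # (n_samples, parameter_dim)
             low, high = np.percentile(samples, [16.0, 84.0], axis=0)
             return samples.mean(axis=0), low, high

      Passing an explicit ``seed`` makes repeated summaries reproducible
      without changing the model-level default seed.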



   .. py:method:: batch_sample_posterior(embeddings: Any, n_samples: int, base_seed: int = 0) -> numpy.ndarray

      Draw posterior samples for a batch of encoded observations.

      Runs a single JIT-compiled vmapped call over all contexts rather than
      looping ``sample_posterior`` once per observation, eliminating the
      Python overhead of per-observation dispatch.

      Parameters
      ----------
      embeddings:
          Float32 array of shape ``(batch_size, embedding_dim)`` produced by
          stacking :attr:`EncodedObservation.embedding` vectors.
      n_samples:
          Number of posterior draws per observation.
      base_seed:
          Base random seed. Each observation in the batch receives a unique
          sub-key derived from this seed.

      Returns
      -------
      np.ndarray
          Array of shape ``(batch_size, n_samples, parameter_dim)`` with
          non-finite values clamped to large finite values.
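
      Examples
      --------
      The backend derives one sub-key per observation from ``base_seed``. The
      runnable sketch below mimics that idea with NumPy's
      ``SeedSequence.spawn`` in place of JAX key splitting, a Python loop in
      place of the vmapped call, and a toy Gaussian in place of the flow;
      ``batch_sample_stub`` is hypothetical. ``np.nan_to_num`` stands in for
      the clamping step, whose exact bounds are not documented here.

      .. code-block:: python

         import numpy as np


         def batch_sample_stub(embeddings, n_samples, parameter_dim, base_seed=0):
             """Hypothetical stand-in: one independent stream per observation."""
             embeddings = np.asarray(embeddings, dtype=np.float32)
             children = np.random.SeedSequence(base_seed).spawn(len(embeddings))
             out = np.empty((len(embeddings), n_samples, parameter_dim))
             for i, (context, child) in enumerate(zip(embeddings, children)):
                 rng = np.random.default_rng(child)
                 # Toy conditional: a standard normal shifted by the context mean.
                 out[i] = context.mean() + rng.standard_normal((n_samples, parameter_dim))
             # NaN -> 0.0, +/-inf -> largest finite float of the dtype.
             return np.nan_to_num(out)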



   .. py:method:: log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) -> Any

      Evaluate posterior log-density for one or many parameter vectors.

      Parameters
      ----------
      observation:
          Encoded observation that defines the posterior context.
      parameters:
          One parameter vector or a batch of parameter vectors in the
          posterior's configured parameter space.

      Returns
      -------
      Any
          Scalar log-density or a vector of log-densities matching the input
          batch structure.



.. py:class:: ConditionalNeuralAutoregressiveFlowPosterior(*args: Any, **kwargs: Any)

   Bases: :py:obj:`ConditionalFlowPosterior`


   Posterior estimator specialized to the neural autoregressive backend.


.. py:class:: ConditionalSplineFlowPosterior(*args: Any, **kwargs: Any)

   Bases: :py:obj:`ConditionalFlowPosterior`


   Posterior estimator specialized to the spline flow backend.

   Notes
   -----
   This convenience subclass forces ``flow_family='spline'`` while keeping the
   rest of the :class:`ConditionalFlowPosterior` API unchanged.


.. py:class:: FlowMatchingPosterior(parameter_dim: int, embedding_dim: int = 128, hidden_dim: int = 128, num_velocity_layers: int = 3, learning_rate: float = 0.001, batch_size: int = 32, num_epochs: int = 5, parameter_space: str = 'unconstrained', integration_steps: int = 32, early_stopping_patience: int | None = None, early_stopping_min_delta: float = 0.0, checkpoint_directory: str | None = None, checkpoint_backend: str = 'auto', resume_from_checkpoint: bool = False, seed: int = 0, task_metadata: Mapping[str, Any] | None = None)

   Bases: :py:obj:`petitRADTRANS.sbi.posterior_base.PersistentPosteriorEstimator`


   Conditional flow-matching posterior skeleton.

   .. warning::

       This estimator is **experimental**. It does not support ``log_prob``
       evaluation, and the ODE integration scheme is a simple fixed-step
       midpoint rule. Expect the API and numerical behaviour to change in
       future releases.

   This estimator family trains a conditional vector field on straight-line
   interpolation paths between Gaussian noise and target parameters, then
   generates posterior samples by integrating the learned field from noise to
   the terminal parameter state.
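
   The two ingredients described above can be sketched in NumPy: a
   straight-line training pair and a fixed-step midpoint integrator. The
   function names are hypothetical, the observation embedding that conditions
   the real velocity field is omitted, and the convention of noise at
   :math:`t=0`, parameters at :math:`t=1`, and target velocity
   ``theta - noise`` is the usual rectified-flow choice rather than a
   documented detail of this class.

   .. code-block:: python

      import numpy as np


      def flow_matching_pair(theta, noise, t):
          """Interpolated state x_t and target velocity on a straight path."""
          t = np.asarray(t, dtype=float)[:, None]    # (batch, 1)
          x_t = (1.0 - t) * noise + t * theta        # noise at t=0, theta at t=1
          return x_t, theta - noise                  # constant path velocity


      def flow_matching_loss(velocity_fn, theta, noise, t):
          """Mean squared error between predicted and target velocities."""
          x_t, target = flow_matching_pair(theta, noise, t)
          return np.mean((velocity_fn(x_t, t) - target) ** 2)


      def integrate_midpoint(velocity_fn, x0, steps=32):
          """Fixed-step midpoint rule from t=0 (noise) to t=1 (parameters)."""
          x = np.asarray(x0, dtype=float)
          dt = 1.0 / steps
          for k in range(steps):
              t = k * dt
              x_mid = x + 0.5 * dt * velocity_fn(x, t)
              x = x + dt * velocity_fn(x_mid, t + 0.5 * dt)
          return x

   In this sketch ``steps`` plays the role of ``integration_steps``. The
   midpoint rule has second-order global error, so halving the step count
   roughly quadruples the integration error for a smooth field.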


   .. py:attribute:: estimator_family
      :value: 'flow_matching'



   .. py:attribute:: embedding_dim
      :value: 128



   .. py:attribute:: hidden_dim
      :value: 128



   .. py:attribute:: num_velocity_layers
      :value: 3



   .. py:attribute:: learning_rate


   .. py:attribute:: batch_size
      :value: 32



   .. py:attribute:: num_epochs
      :value: 5



   .. py:attribute:: integration_steps
      :value: 32



   .. py:attribute:: early_stopping_patience
      :value: None



   .. py:attribute:: early_stopping_min_delta


   .. py:attribute:: checkpoint_directory
      :value: None



   .. py:attribute:: checkpoint_backend
      :value: 'auto'



   .. py:attribute:: resume_from_checkpoint
      :value: False



   .. py:attribute:: model


   .. py:method:: _batch_embeddings(model: _FlowMatchingModel, observations: Any) -> jax.numpy.ndarray
      :staticmethod:



   .. py:method:: _loss(model: _FlowMatchingModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) -> jax.numpy.ndarray
      :staticmethod:



   .. py:method:: fit(dataset: Any) -> petitRADTRANS.sbi.posterior_base.TrainingArtifacts

      Train the posterior estimator on a simulation dataset.



   .. py:method:: _build_estimator_config() -> dict[str, Any]

      Return backend-specific configuration for metadata persistence.



   .. py:method:: _resolve_estimator_config(metadata: Mapping[str, Any]) -> dict[str, Any]
      :staticmethod:



   .. py:method:: _build_serialized_metadata(artifact_metadata) -> dict[str, Any]


   .. py:method:: from_serialized_metadata(metadata: Mapping[str, Any]) -> FlowMatchingPosterior
      :classmethod:


      Rebuild an estimator instance from persisted metadata only.



   .. py:method:: save_backend_state(output_path: pathlib.Path) -> None

      Persist backend-specific model state into the output directory.



   .. py:method:: load_backend_state(input_path: pathlib.Path) -> None

      Restore backend-specific model state from the input directory.



   .. py:method:: encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation

      Encode a structured observation into the estimator input space.



   .. py:method:: batch_encode_observation(blocks_list: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) -> list[petitRADTRANS.sbi.observation.EncodedObservation]

      Encode a batch of observations using a single vmapped forward pass.



   .. py:method:: sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) -> petitRADTRANS.sbi.posterior_base.PosteriorSamples

      Sample the amortized posterior for one encoded observation.



   .. py:method:: log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) -> Any

      Evaluate posterior log-density when supported by the backend.



.. py:class:: PersistentPosteriorEstimator(parameter_dim: int, parameter_space: str = 'unconstrained', seed: int = 0, task_metadata: Mapping[str, Any] | None = None)

   Bases: :py:obj:`PosteriorEstimator`


   Shared persistence helper for estimator backends with on-disk artifacts.
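
   The abstract hooks listed below suggest a template-method design in which
   ``save`` and ``load`` handle shared metadata while backends implement only
   the four hooks. The sketch below shows one way such a base could compose
   them; the class name, the ``metadata.json`` file name, and the metadata
   keys are hypothetical and do not describe the on-disk layout this class
   actually writes.

   .. code-block:: python

      import json
      import pathlib
      from abc import ABC, abstractmethod


      class PersistenceSketch(ABC):
          """Hypothetical base composing save/load from backend hooks."""

          def __init__(self, parameter_dim, seed=0):
              self.parameter_dim = parameter_dim
              self.seed = seed

          @abstractmethod
          def _build_estimator_config(self): ...

          @classmethod
          @abstractmethod
          def from_serialized_metadata(cls, metadata): ...

          @abstractmethod
          def save_backend_state(self, output_path): ...

          @abstractmethod
          def load_backend_state(self, input_path): ...

          def save(self, output_directory):
              path = pathlib.Path(output_directory)
              path.mkdir(parents=True, exist_ok=True)
              metadata = {
                  "parameter_dim": self.parameter_dim,
                  "seed": self.seed,
                  "estimator_config": self._build_estimator_config(),
              }
              # Shared metadata first, then backend-specific model state.
              (path / "metadata.json").write_text(json.dumps(metadata))
              self.save_backend_state(path)

          @classmethod
          def load(cls, input_directory):
              path = pathlib.Path(input_directory)
              metadata = json.loads((path / "metadata.json").read_text())
              # Rebuild from metadata only, then restore the backend state.
              estimator = cls.from_serialized_metadata(metadata)
              estimator.load_backend_state(path)
              return estimator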


   .. py:attribute:: estimator_family
      :value: 'persistent_estimator'



   .. py:attribute:: metadata_schema_version
      :value: '0.2.0'



   .. py:attribute:: parameter_dim


   .. py:attribute:: parameter_space
      :value: 'unconstrained'



   .. py:attribute:: seed
      :value: 0



   .. py:attribute:: task_metadata


   .. py:attribute:: training_artifacts
      :type:  TrainingArtifacts | None
      :value: None



   .. py:attribute:: task_name
      :type:  str | None


   .. py:attribute:: task_version
      :type:  str | None
      :value: None



   .. py:attribute:: task_fingerprint
      :type:  str | None
      :value: None



   .. py:attribute:: observation_schema
      :type:  Mapping[str, Any]


   .. py:attribute:: preprocessing_metadata_payload
      :type:  Mapping[str, Any]


   .. py:attribute:: artifact_metadata
      :type:  petitRADTRANS.sbi.artifacts.ArtifactMetadata | None
      :value: None



   .. py:method:: _build_estimator_config() -> dict[str, Any]
      :abstractmethod:


      Return backend-specific configuration for metadata persistence.



   .. py:method:: from_serialized_metadata(metadata: Mapping[str, Any]) -> PersistentPosteriorEstimator
      :classmethod:
      :abstractmethod:


      Rebuild an estimator instance from persisted metadata only.



   .. py:method:: save_backend_state(output_path: pathlib.Path) -> None
      :abstractmethod:


      Persist backend-specific model state into the output directory.



   .. py:method:: load_backend_state(input_path: pathlib.Path) -> None
      :abstractmethod:


      Restore backend-specific model state from the input directory.



   .. py:method:: _load_training_artifacts(metadata: Mapping[str, Any]) -> TrainingArtifacts | None
      :staticmethod:



   .. py:method:: hydrate_loaded_metadata(metadata: Mapping[str, Any]) -> None


   .. py:method:: _build_artifact_metadata_payload() -> dict[str, Any]


   .. py:method:: build_artifact_metadata(version: str) -> petitRADTRANS.sbi.artifacts.ArtifactMetadata

      Assemble registry metadata for the currently trained estimator.



   .. py:method:: _build_serialized_metadata(artifact_metadata: petitRADTRANS.sbi.artifacts.ArtifactMetadata) -> dict[str, Any]


   .. py:method:: save(output_directory: str, artifact_registry: petitRADTRANS.sbi.artifacts.ArtifactRegistry | None = None, artifact_version: str = '0.1.0') -> None

      Persist model weights, metadata, and optional artifact registration.



   .. py:method:: load(input_directory: str) -> PersistentPosteriorEstimator
      :classmethod:


      Restore a saved persistent estimator from disk.



.. py:class:: PosteriorBatch

   Training batch passed to amortized posterior estimators.


   .. py:attribute:: parameters
      :type:  Any


   .. py:attribute:: observations
      :type:  Any


   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


.. py:class:: PosteriorEstimator

   Bases: :py:obj:`abc.ABC`


   Backend-agnostic interface for amortized posterior models.


   .. py:method:: fit(dataset: Any) -> TrainingArtifacts
      :abstractmethod:


      Train the posterior estimator on a simulation dataset.



   .. py:method:: encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation
      :abstractmethod:


      Encode a structured observation into the estimator input space.



   .. py:method:: sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) -> PosteriorSamples
      :abstractmethod:


      Sample the amortized posterior for one encoded observation.



   .. py:method:: log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) -> Any
      :abstractmethod:


      Evaluate posterior log-density when supported by the backend.



   .. py:method:: save(output_directory: str) -> None
      :abstractmethod:


      Persist trained model weights and metadata.



   .. py:method:: load(input_directory: str) -> PosteriorEstimator
      :classmethod:
      :abstractmethod:


      Restore a saved estimator from disk.



.. py:class:: PosteriorSamples

   Posterior samples and optional per-sample diagnostics.
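
   When ``weights`` is populated, summaries should account for it; when it is
   ``None`` the draws can be treated as equally weighted. The helper below is
   a hypothetical sketch of that convention and assumes non-negative, not
   necessarily normalized, weights of shape ``(n_samples,)``.

   .. code-block:: python

      import numpy as np


      def weighted_posterior_mean(samples, weights=None):
          """Posterior mean honoring optional per-sample weights (hypothetical)."""
          samples = np.asarray(samples, dtype=float)
          if weights is None:
              return samples.mean(axis=0)
          # np.average normalizes the weights internally.
          return np.average(samples, axis=0, weights=np.asarray(weights, dtype=float))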


   .. py:attribute:: samples
      :type:  Any


   .. py:attribute:: log_probabilities
      :type:  Any
      :value: None



   .. py:attribute:: weights
      :type:  Any
      :value: None



   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


.. py:function:: plot_local_sensitivity_fisher_correlations(report: petitRADTRANS.sbi.calibration.LocalSensitivityReport, figsize: tuple[float, float] | None = None) -> tuple[Any, numpy.ndarray]

   Plot Fisher-correlation heatmaps for each representative point.


.. py:function:: plot_local_sensitivity_jacobians(report: petitRADTRANS.sbi.calibration.LocalSensitivityReport, figsize: tuple[float, float] | None = None) -> tuple[Any, numpy.ndarray]

   Plot whitened Jacobian heatmaps for each representative point.


.. py:function:: plot_local_sensitivity_singular_values(report: petitRADTRANS.sbi.calibration.LocalSensitivityReport, figsize: tuple[float, float] | None = None) -> tuple[Any, numpy.ndarray]

   Plot singular spectra of the whitened Jacobian for each point.


.. py:function:: plot_posterior_corner(samples: Any, parameter_names: Sequence[str] | None = None, bins: int = 40, figsize: tuple[float, float] | None = None, max_points: int = 8192) -> tuple[Any, numpy.ndarray]

   Plot a lower-triangular corner view of posterior structure.

   Parameters
   ----------
   samples:
       Posterior samples with shape ``(n_samples, n_dim)``, or a
       one-dimensional array of draws for a single parameter.
   parameter_names:
       Optional display names for each posterior dimension.
   bins:
       Number of bins used for one- and two-dimensional histograms.
   figsize:
       Optional Matplotlib figure size.
   max_points:
       Maximum number of posterior draws plotted. Larger sample sets are
       evenly subsampled to keep rendering costs bounded.

   Returns
   -------
   tuple[Any, np.ndarray]
       Matplotlib figure and axes array for further customization or saving.
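
   Examples
   --------
   The exact subsampling scheme behind ``max_points`` is not documented here.
   One plausible reading of "evenly subsampled" is to keep evenly spaced row
   indices, as in this hypothetical helper, which also promotes a 1-D input
   to a single-parameter column.

   .. code-block:: python

      import numpy as np


      def evenly_subsample(samples, max_points=8192):
          """Keep at most ``max_points`` evenly spaced rows (hypothetical)."""
          samples = np.asarray(samples)
          if samples.ndim == 1:
              samples = samples[:, None]  # one parameter, one column
          n = samples.shape[0]
          if n <= max_points:
              return samples
          # Spacing exceeds 1 here, so the floored indices stay distinct.
          idx = np.linspace(0, n - 1, max_points).astype(int)
          return samples[idx]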


.. py:function:: plot_posterior_marginals(samples: Any, parameter_names: Sequence[str] | None = None, bins: int = 40, figsize: tuple[float, float] | None = None) -> tuple[Any, numpy.ndarray]

   Plot one histogram per posterior dimension.

   Parameters
   ----------
   samples:
       Posterior samples with shape ``(n_samples, n_dim)``, or a
       one-dimensional array of draws for a single parameter.
   parameter_names:
       Optional display names for each posterior dimension.
   bins:
       Number of histogram bins per dimension.
   figsize:
       Optional Matplotlib figure size.

   Returns
   -------
   tuple[Any, np.ndarray]
       Matplotlib figure and axes array for further customization or saving.


.. py:function:: plot_posterior_predictive_report(report: petitRADTRANS.sbi.calibration.PosteriorPredictiveReport, dataset_names: Sequence[str] | None = None, figsize: tuple[float, float] | None = None) -> tuple[Any, numpy.ndarray]

   Plot observed values against posterior-predictive means and intervals.

   Parameters
   ----------
   report:
       Posterior-predictive report to visualize.
   dataset_names:
       Optional subset of dataset names to plot. Defaults to all datasets in
       the report.
   figsize:
       Optional Matplotlib figure size.

   Returns
   -------
   tuple[Any, np.ndarray]
       Matplotlib figure and axes array.

   Notes
   -----
   For batched reports the helper currently visualizes only the first case for
   each dataset, which keeps the plot compact for quick inspection.
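
   Examples
   --------
   The quantities this plot draws can be computed from raw posterior-predictive
   draws as in the hypothetical helpers below. The central 68% interval is an
   illustrative assumption, since the report's own interval level is not
   stated here; ``interval_coverage`` is a quick numerical counterpart to
   reading the plot.

   .. code-block:: python

      import numpy as np


      def predictive_summary(predictive_draws, level=0.68):
          """Mean and central interval per data point (hypothetical).

          ``predictive_draws`` has shape ``(n_draws, n_points)``.
          """
          draws = np.asarray(predictive_draws, dtype=float)
          tail = 50.0 * (1.0 - level)
          low, high = np.percentile(draws, [tail, 100.0 - tail], axis=0)
          return draws.mean(axis=0), low, high


      def interval_coverage(observed, low, high):
          """Fraction of observed points inside the predictive interval."""
          observed = np.asarray(observed, dtype=float)
          return float(np.mean((observed >= low) & (observed <= high)))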


.. py:function:: plot_sbc_rank_histograms(report: petitRADTRANS.sbi.calibration.SimulationBasedCalibrationReport, parameter_names: Sequence[str] | None = None, figsize: tuple[float, float] | None = None) -> tuple[Any, numpy.ndarray]

   Plot one SBC rank histogram per inferred parameter.

   Parameters
   ----------
   report:
       SBC report containing per-parameter rank histogram counts.
   parameter_names:
       Optional display names for each inferred parameter.
   figsize:
       Optional Matplotlib figure size.

   Returns
   -------
   tuple[Any, np.ndarray]
       Matplotlib figure and axes array.
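
   Examples
   --------
   Each histogram summarizes, over many simulated test cases, the rank of the
   true parameter among that case's posterior draws; for a calibrated
   posterior the ranks are approximately uniform. The sketch follows the
   standard SBC recipe with hypothetical helper names and does not reproduce
   how the report computes or bins its counts.

   .. code-block:: python

      import numpy as np


      def sbc_ranks(true_parameters, posterior_draws):
          """Rank of each true value among its posterior draws.

          true_parameters: (n_cases, n_dim)
          posterior_draws: (n_cases, n_draws, n_dim)
          Returns integer ranks in [0, n_draws], shape (n_cases, n_dim).
          """
          truth = np.asarray(true_parameters)[:, None, :]
          return np.sum(np.asarray(posterior_draws) < truth, axis=1)


      def sbc_histograms(ranks, n_draws, n_bins=10):
          """Per-parameter rank histogram counts, shape (n_dim, n_bins)."""
          edges = np.linspace(0, n_draws + 1, n_bins + 1)
          return np.stack([np.histogram(r, bins=edges)[0] for r in np.asarray(ranks).T])

   Equal-width bins hold equal numbers of rank values only when
   ``n_draws + 1`` is divisible by ``n_bins``. A U shape points to an
   overconfident posterior, a central hump to an underconfident one, and a
   monotone slope to a systematic bias.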


.. py:class:: TaskPreprocessingMetadata

   Serializable preprocessing metadata for an SBI task family.


   .. py:attribute:: version
      :type:  str


   .. py:attribute:: blocks
      :type:  Mapping[str, BlockNormalizationStats]


   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


   .. py:method:: to_payload() -> dict[str, Any]


   .. py:method:: from_payload(payload: Mapping[str, Any]) -> TaskPreprocessingMetadata
      :classmethod:



.. py:function:: fit_task_preprocessing(training_observations: list[list[petitRADTRANS.sbi.observation.ObservationBlock]], version: str = '0.5.0', metadata: Mapping[str, Any] | None = None) -> TaskPreprocessingMetadata

   Fit preprocessing statistics from training observation blocks.


.. py:function:: normalize_observation_block(block: petitRADTRANS.sbi.observation.ObservationBlock, preprocessing_metadata: TaskPreprocessingMetadata) -> petitRADTRANS.sbi.observation.ObservationBlock

   Normalize one observation block with fitted preprocessing statistics.

   Values and uncertainties use robust (median/IQR) normalization when the
   preprocessing metadata provides robust statistics, and fall back to
   mean/std normalization for backward compatibility with older metadata.
   Coordinates always use mean/std normalization.
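
   A minimal sketch of the robust path and its fallback, assuming per-block
   statistics stored as a median and interquartile range (or, for older
   metadata, a mean and standard deviation). The dictionary keys and helper
   names are hypothetical and do not mirror ``BlockNormalizationStats``. A
   single outlier shows why the robust path is preferred: it barely moves the
   median and IQR but dominates the mean and standard deviation.

   .. code-block:: python

      import numpy as np


      def fit_block_stats(training_values):
          """Hypothetical per-block statistics, robust and classical."""
          v = np.asarray(training_values, dtype=float).ravel()
          q25, q50, q75 = np.percentile(v, [25.0, 50.0, 75.0])
          return {"median": q50, "iqr": q75 - q25, "mean": v.mean(), "std": v.std()}


      def normalize_values(values, stats, eps=1e-12):
          """Median/IQR normalization, falling back to mean/std for old stats."""
          values = np.asarray(values, dtype=float)
          if "median" in stats and "iqr" in stats:
              center, scale = stats["median"], stats["iqr"]
          else:  # older metadata without robust statistics
              center, scale = stats["mean"], stats["std"]
          return (values - center) / max(scale, eps)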


.. py:function:: normalize_observation_blocks(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock], preprocessing_metadata: TaskPreprocessingMetadata) -> list[petitRADTRANS.sbi.observation.ObservationBlock]

   Normalize a list of observation blocks.


.. py:class:: ProposalSampler

   Bases: :py:obj:`abc.ABC`


   Interface for simulation proposals beyond the prior distribution.


   .. py:method:: sample(n_samples: int, task: petitRADTRANS.sbi.task.SBITask) -> Any
      :abstractmethod:


      Draw free-parameter vectors for the given task.



.. py:class:: RuntimeSimulator(task: petitRADTRANS.sbi.task.SBITask, runtime: Any | None = None, seed: int | None = None, data_parallel: bool | None = None)

   Bases: :py:obj:`BatchedSimulator`


   Concrete simulator backed by the retrieval runtime.

   The simulator samples from the retrieval prior using JAX random utilities
   and projects deterministic forward-model outputs into the observation space
   using ``RetrievalRuntime``.


   .. py:attribute:: _rng_key


   .. py:method:: _validate_task_support() -> None


   .. py:method:: _next_key() -> jax.Array


   .. py:method:: _row_invalid_value_mask(values: Any, mask: Any, constraint: Any) -> numpy.ndarray
      :staticmethod:



   .. py:method:: _invalid_sample_mask(observations: Mapping[str, Mapping[str, Any]], n_samples: int) -> tuple[numpy.ndarray, dict[str, int]]


   .. py:method:: _slice_batch_rows(batch: SimulationBatch, row_indices: numpy.ndarray) -> SimulationBatch
      :staticmethod:



   .. py:method:: _concatenate_simulation_batches(batches: list[SimulationBatch], diagnostics: Mapping[str, Any] | None = None) -> SimulationBatch
      :staticmethod:



   .. py:method:: _sample_prior_parameter_batch(n_samples: int) -> dict[str, jax.numpy.ndarray]


   .. py:method:: sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) -> Any

      Sample free-parameter vectors from the task prior or a proposal.



   .. py:method:: _advance_rng_for_batch(n_samples: int) -> None

      Advance the PRNG state as if ``simulate_batch(n_samples)`` had been called.

      Only cheap key splitting is performed (no forward-model evaluation),
      which makes this suitable for quickly skipping batches when resuming
      dataset generation.

      The number of ``_next_key`` calls mirrors ``simulate_batch``: one key
      for ``_sample_prior_parameter_batch`` plus one key per observation that
      adds noise in ``_simulate_parameter_matrix``.
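
      The invariant is easier to see with a toy key stream. The sketch below
      uses NumPy's ``SeedSequence.spawn`` as a stand-in for JAX key
      splitting; ``KeyStream``, ``simulate_keys``, ``advance_for_batch``, and
      ``n_noisy_observations`` are hypothetical. Skipping a batch reproduces
      the simulated path only when both draw the same number of keys.

      .. code-block:: python

         import numpy as np


         class KeyStream:
             """Toy PRNG key stream; each call hands out a fresh child seed."""

             def __init__(self, seed):
                 self._root = np.random.SeedSequence(seed)

             def next_key(self):
                 return self._root.spawn(1)[0]


         def simulate_keys(stream, n_noisy_observations):
             """Keys used by one batch: 1 for the prior draw, 1 per noisy observation."""
             return [stream.next_key() for _ in range(1 + n_noisy_observations)]


         def advance_for_batch(stream, n_noisy_observations):
             """Consume the same number of keys without simulating anything."""
             for _ in range(1 + n_noisy_observations):
                 stream.next_key()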



   .. py:method:: _apply_noise(values: Any, uncertainties: Any, covariance: Any) -> tuple[Any, Any, Any]


   .. py:method:: _apply_noise_batched(values_batch: Any, uncertainties_batch: Any, covariance_batch: Any) -> tuple[Any, Any, Any]

      Apply noise to a batch of simulation outputs.

      Args:
          values_batch: Shape ``(n_samples, n_wavelengths)``.
          uncertainties_batch: Shape ``(n_samples, n_wavelengths)`` or ``None``.
          covariance_batch: Shape ``(n_samples, n_wavelengths, n_wavelengths)`` or ``None``.

      Returns:
          Noisy values, uncertainties, and covariance with the same leading
          batch dimension as ``values_batch``.
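
      A NumPy sketch of the batched noise step, assuming that a full
      covariance, when present, takes precedence over per-point uncertainties
      and is applied through its Cholesky factor. The function name is
      hypothetical, and returning the uncertainties and covariance unchanged
      is an assumption about the last two tuple elements.

      .. code-block:: python

         import numpy as np


         def apply_noise_batched(values_batch, uncertainties_batch=None,
                                 covariance_batch=None, rng=None):
             """Add Gaussian observational noise to a batch (hypothetical)."""
             rng = np.random.default_rng() if rng is None else rng
             values = np.asarray(values_batch, dtype=float)  # (n_samples, n_wavelengths)
             eps = rng.standard_normal(values.shape)
             if covariance_batch is not None:
                 # Correlated noise: L @ eps has covariance L @ L.T = C.
                 chol = np.linalg.cholesky(np.asarray(covariance_batch, dtype=float))
                 noise = np.einsum("nij,nj->ni", chol, eps)
             elif uncertainties_batch is not None:
                 noise = np.asarray(uncertainties_batch, dtype=float) * eps
             else:
                 noise = np.zeros_like(values)
             return values + noise, uncertainties_batch, covariance_batch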



   .. py:method:: _stack_observation_payloads(payloads: list[dict[str, Any]]) -> dict[str, Any]
      :staticmethod:



   .. py:method:: _simulate_parameter_matrix(parameter_matrix: Any, cube_parameters: Any = None, unconstrained_parameters: Any = None, data_parallel: bool = False) -> SimulationBatch


   .. py:method:: simulate_from_parameters(parameters: Any) -> SimulationBatch

      Run the forward model and noise pipeline for pre-specified parameters.



   .. py:method:: simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) -> SimulationBatch

      Sample parameters, simulate one batch, and preserve prior-space coordinates when available.



.. py:class:: SimulationBatch

   Container for one simulated batch.

   Attributes:
       parameters:
           Array-like free-parameter matrix with leading dimension equal to the
           number of simulated samples.
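       observations:
           Task-conditioned simulated observations.
       cube_parameters:
           Optional unit-cube (prior-space) coordinates of each sample, when
           available.
       unconstrained_parameters:
           Optional unconstrained-space coordinates of each sample, when
           available.
       log_likelihood:
           Optional log-likelihood value for each sample.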
       diagnostics:
           Additional runtime diagnostics such as clipping flags or NaN counts.


   .. py:attribute:: parameters
      :type:  Any


   .. py:attribute:: observations
      :type:  Any


   .. py:attribute:: cube_parameters
      :type:  Any
      :value: None



   .. py:attribute:: unconstrained_parameters
      :type:  Any
      :value: None



   .. py:attribute:: log_likelihood
      :type:  Any
      :value: None



   .. py:attribute:: diagnostics
      :type:  Mapping[str, Any]


   .. py:property:: n_samples
      :type: int


      Return the number of simulated samples represented by the batch.



.. py:class:: Simulator(task: petitRADTRANS.sbi.task.SBITask)

   Bases: :py:obj:`abc.ABC`


   Base simulator for SBI dataset generation and validation.


   .. py:attribute:: task


   .. py:method:: sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) -> Any
      :abstractmethod:


      Sample free-parameter vectors from the task prior or a proposal.



   .. py:method:: simulate_from_parameters(parameters: Any) -> SimulationBatch
      :abstractmethod:


      Run the forward model and noise pipeline for pre-specified parameters.



   .. py:method:: simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) -> SimulationBatch

      Sample parameters and simulate one batch in a single call.



.. py:class:: NoiseModelConfig

   Describe how observational noise is injected during simulation.

   Attributes:
       mode:
           Short identifier of the noise model implementation.
       parameters:
           Backend-specific parameters used to instantiate the noise model.
       seed:
           Optional deterministic seed for repeatable simulation pipelines.


   .. py:attribute:: mode
      :type:  str
      :value: 'observational'



   .. py:attribute:: parameters
      :type:  Mapping[str, Any]


   .. py:attribute:: seed
      :type:  int | None
      :value: None



.. py:class:: ObservationSchema

   Capture the supported observation family for an amortized task.

   Attributes:
       dataset_names:
           Dataset identifiers included in the task.
       modalities:
           Mapping from dataset name to a short modality label such as
           ``'spectrum'``, ``'photometry'``, or ``'time_series'``.
       metadata:
           Additional task-level information required by encoders or
           benchmarks, such as instrument names or wavelength coverage.


   .. py:attribute:: dataset_names
      :type:  tuple[str, Ellipsis]


   .. py:attribute:: modalities
      :type:  Mapping[str, str]


   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


.. py:class:: ObservationValueConstraint

   Admissible range for simulated observation values.

   Attributes:
       min_value:
           Optional lower bound for valid simulated values.
       max_value:
           Optional upper bound for valid simulated values.
       min_inclusive:
           Whether ``min_value`` is included in the valid interval.
       max_inclusive:
           Whether ``max_value`` is included in the valid interval.
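
   A sketch of how a simulated row could be rejected under such a constraint,
   honoring the inclusive flags. The function is hypothetical; treating
   non-finite values as invalid and ignoring masked-out points are
   assumptions modeled on the ``_row_invalid_value_mask(values, mask,
   constraint)`` helper of :class:`RuntimeSimulator`, whose exact behavior is
   not documented here.

   .. code-block:: python

      import numpy as np


      def row_invalid_mask(values, mask, min_value=None, max_value=None,
                           min_inclusive=True, max_inclusive=True):
          """Flag rows with any active point outside the admissible range.

          values: (n_samples, n_points); mask: (n_points,), True = point used.
          """
          values = np.asarray(values, dtype=float)
          bad = ~np.isfinite(values)
          if min_value is not None:
              below = (values < min_value) if min_inclusive else (values <= min_value)
              bad |= below
          if max_value is not None:
              above = (values > max_value) if max_inclusive else (values >= max_value)
              bad |= above
          return np.any(bad & np.asarray(mask, dtype=bool), axis=1)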


   .. py:attribute:: min_value
      :type:  float | None
      :value: None



   .. py:attribute:: max_value
      :type:  float | None
      :value: None



   .. py:attribute:: min_inclusive
      :type:  bool
      :value: True



   .. py:attribute:: max_inclusive
      :type:  bool
      :value: True



   .. py:method:: to_payload() -> dict[str, Any]


.. py:class:: SBITask

   Bundle the immutable ingredients required for an SBI problem.

   The task owns the retrieval configuration, canonical parameter layout,
   observation schema, and simulation policy needed to generate training data
   and evaluate amortized posteriors.


   .. py:attribute:: name
      :type:  str


   .. py:attribute:: retrieval_config
      :type:  petitRADTRANS.retrieval.retrieval_config.RetrievalConfig


   .. py:attribute:: parameter_layout
      :type:  petitRADTRANS.retrieval.runtime.ParameterLayout


   .. py:attribute:: observation_schema
      :type:  ObservationSchema


   .. py:attribute:: simulation_config
      :type:  SimulationConfig


   .. py:attribute:: task_version
      :type:  str
      :value: '0.1.0'



   .. py:attribute:: preprocessing_version
      :type:  str
      :value: '0.5.0'



   .. py:attribute:: forward_model_family
      :type:  str
      :value: 'unknown'



   .. py:attribute:: petitradtrans_version
      :type:  str


   .. py:attribute:: metadata
      :type:  Mapping[str, Any]


   .. py:method:: from_retrieval_config(retrieval_config: petitRADTRANS.retrieval.retrieval_config.RetrievalConfig, simulation_config: SimulationConfig | None = None, metadata: Mapping[str, Any] | None = None) -> SBITask
      :classmethod:


      Create an SBI task from an existing retrieval configuration.

      Parameters
      ----------
      retrieval_config:
          Retrieval configuration describing the free parameters, observation
          datasets, and runtime-native forward model family.
      simulation_config:
          Optional simulation policy overriding the default prior-predictive
          batch configuration.
      metadata:
          Optional task-level metadata. Reserved keys such as
          ``task_version``, ``preprocessing_version``, and
          ``forward_model_family`` override the default inferred values.

      Returns
      -------
      SBITask
          Immutable SBI task description with parameter layout, observation
          schema, and deterministic task fingerprint payload.

      Notes
      -----
      The task inspects the configured retrieval datasets to infer modalities
      and a forward-model-family string that later participates in artifact
      compatibility checks.



   .. py:property:: parameter_names
      :type: tuple[str, Ellipsis]


      Return the free parameter names in canonical task order.



   .. py:property:: line_opacity_modes
      :type: Mapping[str, str]


      Return the line-opacity mode for each configured dataset.



   .. py:property:: fingerprint_payload
      :type: Mapping[str, Any]


      Return the deterministic payload used to fingerprint the task family.



   .. py:property:: task_fingerprint
      :type: petitRADTRANS.sbi.compatibility.TaskFingerprint


      Return the deterministic fingerprint of the task family.
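
      A common way to turn such a payload into a deterministic fingerprint is
      to hash a canonical JSON encoding so that key order does not matter.
      The sketch below shows that idea with :mod:`hashlib`; the function name,
      the choice of SHA-256, and the JSON canonicalization are assumptions,
      not the documented behavior of ``TaskFingerprint``.

      .. code-block:: python

         import hashlib
         import json


         def payload_fingerprint(payload):
             """SHA-256 hex digest of a canonical JSON encoding (hypothetical)."""
             canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"),
                                    default=str)
             return hashlib.sha256(canonical.encode("utf-8")).hexdigest()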



   .. py:method:: validate_observation_schema(observation_schema: ObservationSchema) -> petitRADTRANS.sbi.compatibility.CompatibilityReport

      Compare an external observation schema with the task expectation.

      Parameters
      ----------
      observation_schema:
          Candidate schema to compare against the task's required datasets and
          modalities.

      Returns
      -------
      CompatibilityReport
          Structured compatibility report containing mismatches, if any.



   .. py:method:: validate_artifact_metadata(artifact_metadata: Any) -> petitRADTRANS.sbi.compatibility.CompatibilityReport

      Compare a task with persisted artifact metadata.

      Parameters
      ----------
      artifact_metadata:
          Artifact-like object exposing task and preprocessing provenance
          fields. ``Any`` is accepted to avoid a hard import dependency from
          the task module back to the artifact module.

      Returns
      -------
      CompatibilityReport
          Structured report describing whether the artifact is compatible with
          the current task definition.



   .. py:method:: build_observation_states(model_contract: str = 'differentiable') -> dict[str, petitRADTRANS.retrieval.runtime.ObservationState]

      Materialize immutable runtime observations for the task datasets.

      Parameters
      ----------
      model_contract:
          Runtime contract requested from each retrieval dataset.

      Returns
      -------
      dict[str, ObservationState]
          Mapping from dataset name to immutable runtime observation state.



   .. py:method:: build_runtime() -> petitRADTRANS.retrieval.runtime.RetrievalRuntime

      Construct a retrieval runtime matching the task definition.

      Returns
      -------
      RetrievalRuntime
          Runtime object configured with the task's parameter layout and
          retrieval datasets.

      Notes
      -----
      The default implementation mirrors the retrieval-runtime grouping logic
      used by the exact retrieval package and is sufficient for initial SBI
      simulation backends.



   .. py:method:: _infer_modality(data: petitRADTRANS.retrieval.data.Data) -> str
      :staticmethod:



   .. py:method:: _infer_forward_model_family(retrieval_config: petitRADTRANS.retrieval.retrieval_config.RetrievalConfig) -> str
      :staticmethod:



.. py:class:: SimulationConfig

   Control how a task generates prior-predictive simulations.

   Attributes:
       batch_size:
           Number of parameter points produced per simulator call.
       n_simulations:
           Optional target number of simulations for dataset generation jobs.
       use_vectorized_runtime:
           Whether to prefer the JAX-vectorized runtime path when available.
       store_per_datapoint_log_likelihood:
           Whether to persist per-datapoint log-likelihood components alongside
           simulated observations.
       noise_model:
           Noise model configuration applied after forward-model evaluation.
       observation_value_constraints:
           Optional per-dataset admissible value ranges enforced on
           deterministic projected observations before simulations are accepted.
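
   How ``batch_size``, ``n_simulations``, and ``observation_value_constraints``
   interact can be sketched as follows. The sketch assumes the generator calls
   the simulator in fixed-size batches and that a constraint is a closed
   ``[minimum, maximum]`` interval checked on every projected value of a
   dataset. ``ValueRange``, ``n_simulator_calls``, ``accept``, and the dataset
   name ``"nirspec"`` are hypothetical and do not reproduce the real
   ``ObservationValueConstraint`` API.

   .. code-block:: python

      import math
      from dataclasses import dataclass


      @dataclass(frozen=True)
      class ValueRange:
          """Hypothetical closed interval standing in for ObservationValueConstraint."""

          minimum: float = -math.inf
          maximum: float = math.inf

          def admits(self, values: list[float]) -> bool:
              return all(self.minimum <= v <= self.maximum for v in values)


      def n_simulator_calls(n_simulations: int, batch_size: int = 256) -> int:
          """Simulator calls needed to reach n_simulations at batch_size points per call."""
          if batch_size <= 0:
              raise ValueError("batch_size must be positive")
          # The final call may be only partially used.
          return math.ceil(n_simulations / batch_size)


      def accept(projected: dict[str, list[float]], constraints: dict[str, ValueRange]) -> bool:
          """Accept a simulation only if every constrained dataset stays in range.

          Datasets without a constraint are accepted unconditionally.
          """
          return all(
              constraints[name].admits(values)
              for name, values in projected.items()
              if name in constraints
          )


      limits = {"nirspec": ValueRange(0.0, 1.0)}
      print(n_simulator_calls(10_000))                  # 40
      print(accept({"nirspec": [0.01, 0.02]}, limits))  # True
      print(accept({"nirspec": [-0.5, 0.02]}, limits))  # False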


   .. py:attribute:: batch_size
      :type:  int
      :value: 256



   .. py:attribute:: n_simulations
      :type:  int | None
      :value: None



   .. py:attribute:: use_vectorized_runtime
      :type:  bool
      :value: True



   .. py:attribute:: store_per_datapoint_log_likelihood
      :type:  bool
      :value: False



   .. py:attribute:: noise_model
      :type:  NoiseModelConfig


   .. py:attribute:: observation_value_constraints
      :type:  Mapping[str, ObservationValueConstraint]


.. py:class:: EarlyStoppingConfig

   Early-stopping policy for SBI training.


   .. py:attribute:: patience
      :type:  int


   .. py:attribute:: min_delta
      :type:  float
      :value: 0.0
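
   One common reading of ``patience`` and ``min_delta`` is sketched below. It
   assumes an epoch counts as an improvement only when it lowers the best
   monitored loss by more than ``min_delta``, and that training stops once
   ``patience`` consecutive epochs fail to improve; the exact comparison used
   by ``SBITrainer._should_stop`` may differ. ``stop_epoch`` is a hypothetical
   helper.

   .. code-block:: python

      from typing import Optional, Sequence


      def stop_epoch(losses: Sequence[float], patience: int, min_delta: float = 0.0) -> Optional[int]:
          """Return the 0-based epoch at which training stops, or None if it never does."""
          best = float("inf")
          patience_counter = 0
          for epoch, loss in enumerate(losses):
              if loss < best - min_delta:
                  # Improvement larger than min_delta: record it and reset patience.
                  best = loss
                  patience_counter = 0
              else:
                  patience_counter += 1
              if patience_counter >= patience:
                  return epoch
          return None


      history = [1.0, 0.8, 0.79, 0.795, 0.81]
      print(stop_epoch(history, patience=2, min_delta=0.05))  # 3
      print(stop_epoch(history, patience=2))                  # 4
      print(stop_epoch(history, patience=3))                  # None

   A positive ``min_delta`` makes the policy ignore marginal gains, so the
   0.79 epoch above no longer resets the counter.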



.. py:class:: SBITrainer(config: TrainingConfig)

   Reusable optimization loop for amortized SBI posteriors.


   .. py:attribute:: config


   .. py:attribute:: checkpoint_directory
      :value: None



   .. py:attribute:: checkpoint_backend


   .. py:attribute:: _checkpoint_backend_fallback_reason
      :type:  str | None
      :value: None



   .. py:property:: _latest_checkpoint_directory
      :type: pathlib.Path | None



   .. py:property:: _best_checkpoint_directory
      :type: pathlib.Path | None



   .. py:method:: _checkpoint_directory_for_kind(checkpoint_kind: str) -> pathlib.Path | None


   .. py:method:: _save_checkpoint(checkpoint_kind: str, state: dict[str, Any], metadata: dict[str, Any]) -> None


   .. py:method:: _restore_checkpoint(template_state: dict[str, Any], checkpoint_kind: str) -> tuple[dict[str, Any], dict[str, Any]] | None


   .. py:method:: fit(model: Any, dataset: Any, loss_fn: Callable[[Any, Any], Any], eval_loss_fn: Callable[[Any, Any], Any] | None = None, eval_diagnostic_fn: Callable[[Any, Any], Mapping[str, float]] | None = None, selection_metric_fn: Callable[[float, Mapping[str, float] | None], float] | None = None, selection_metric_name: str | None = None, stability_metric_fn: Callable[[float, Mapping[str, float] | None], Mapping[str, float]] | None = None, stability_flag_key: str = 'checkpoint_is_stable') -> tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]

      Optimize one posterior model against a dataset reader.

      Parameters
      ----------
      model:
          Trainable Equinox-style model to optimize.
      dataset:
          Reader object exposing ``iter_batches`` for train and optional
          validation splits.
      loss_fn:
          Callable receiving ``(model, batch)`` and returning the scalar loss
          to minimize.
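      eval_loss_fn:
          Optional loss used for validation. Defaults to ``loss_fn`` when not
          provided.
      eval_diagnostic_fn:
          Optional callable with the same two-argument signature as
          ``loss_fn`` that returns a mapping from diagnostic name to value.
      selection_metric_fn:
          Optional callable receiving a scalar loss and the optional
          diagnostics mapping and returning a scalar selection metric.
      selection_metric_name:
          Optional name of the selection metric.
      stability_metric_fn:
          Optional callable with the same inputs as ``selection_metric_fn``
          that returns a mapping of named stability metrics.
      stability_flag_key:
          Key identifying the checkpoint-stability flag. Defaults to
          ``'checkpoint_is_stable'``.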

      Returns
      -------
      tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]
          Best model, train-loss history, validation-loss history, and a
          metadata dictionary describing checkpointing and stopping behavior.

      Notes
      -----
      The trainer monitors validation loss when a validation split exists and
      otherwise falls back to training loss for best-model selection.
      Checkpoints may be written after each epoch when enabled.
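
      The selection rule in these notes can be sketched on per-epoch loss
      histories. ``select_best_epoch`` is a hypothetical helper; it ignores
      ``selection_metric_fn`` and the stability checks, which may change how
      the real trainer picks its best checkpoint.

      .. code-block:: python

         from typing import Optional, Sequence


         def select_best_epoch(
             train_losses: Sequence[float],
             val_losses: Optional[Sequence[float]] = None,
         ) -> tuple[str, int, float]:
             """Return (monitored split, best epoch, best loss) using the lowest loss."""
             # Monitor validation loss when a validation history exists,
             # otherwise fall back to the training loss.
             if val_losses:
                 split, history = "validation", val_losses
             else:
                 split, history = "train", train_losses
             if not history:
                 raise ValueError("no loss history to select from")
             best_epoch = min(range(len(history)), key=lambda i: history[i])
             return split, best_epoch, history[best_epoch]


         train = [1.0, 0.7, 0.5, 0.4]
         val = [1.1, 0.8, 0.85, 0.9]
         print(select_best_epoch(train, val))  # ('validation', 1, 0.8)
         print(select_best_epoch(train))       # ('train', 3, 0.4)

      With a validation split, the selected epoch can precede the point where
      training loss is lowest, which is the purpose of monitoring held-out
      loss.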



   .. py:property:: _min_delta
      :type: float



   .. py:method:: _should_stop(patience_counter: int) -> bool


.. py:class:: TrainingConfig

   Configuration for posterior optimization.


   .. py:attribute:: learning_rate
      :type:  float
      :value: 0.001



   .. py:attribute:: batch_size
      :type:  int
      :value: 32



   .. py:attribute:: num_epochs
      :type:  int
      :value: 10



   .. py:attribute:: parameter_space
      :type:  str
      :value: 'unconstrained'



   .. py:attribute:: seed
      :type:  int
      :value: 0



   .. py:attribute:: shuffle_train
      :type:  bool
      :value: True



   .. py:attribute:: early_stopping
      :type:  EarlyStoppingConfig | None
      :value: None



   .. py:attribute:: checkpoint_directory
      :type:  str | None
      :value: None



   .. py:attribute:: checkpoint_backend
      :type:  str
      :value: 'auto'



   .. py:attribute:: resume_from_checkpoint
      :type:  bool
      :value: False



   .. py:attribute:: data_parallel
      :type:  bool
      :value: False



   .. py:attribute:: gradient_clip_norm
      :type:  float | None
      :value: 1.0
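
   A minimal sketch of global-norm gradient clipping, the behavior of optax's
   ``clip_by_global_norm``, is shown below. It is an assumption that
   ``gradient_clip_norm`` bounds the global L2 norm in this way and that
   ``None`` disables clipping; ``clip_by_global_norm`` here is a hypothetical
   pure-Python helper operating on nested lists instead of a JAX pytree.

   .. code-block:: python

      import math
      from typing import Optional


      def clip_by_global_norm(
          grads: list[list[float]], max_norm: Optional[float] = 1.0
      ) -> tuple[list[list[float]], float]:
          """Scale all gradient leaves by one factor so their joint L2 norm is <= max_norm."""
          global_norm = math.sqrt(sum(g * g for leaf in grads for g in leaf))
          if max_norm is None or global_norm <= max_norm:
              # Clipping disabled or already within bounds: leave gradients unchanged.
              return grads, global_norm
          scale = max_norm / global_norm
          return [[g * scale for g in leaf] for leaf in grads], global_norm


      clipped, norm = clip_by_global_norm([[3.0], [4.0]])
      print(norm)     # 5.0
      print(clipped)  # approximately [[0.6], [0.8]]

   Scaling every leaf by the same factor preserves the gradient direction,
   unlike clipping each component independently.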



   .. py:attribute:: weight_decay
      :type:  float
      :value: 0.0



   .. py:attribute:: use_cosine_schedule
      :type:  bool
      :value: False



   .. py:attribute:: warmup_fraction
      :type:  float
      :value: 0.02



   .. py:attribute:: warmup_epochs
      :type:  float | None
      :value: None



   .. py:attribute:: min_learning_rate
      :type:  float
      :value: 1e-06



   .. py:attribute:: lr_schedule_total_epochs
      :type:  float | None
      :value: None



   .. py:attribute:: verbose_diagnostics
      :type:  bool
      :value: False



   .. py:attribute:: diagnostics_output_directory
      :type:  str | None
      :value: None



   .. py:attribute:: diagnostics_plot_interval
      :type:  int
      :value: 1



   .. py:attribute:: n_validation_diagnostic_batches
      :type:  int
      :value: 4
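
   ``use_cosine_schedule``, ``warmup_fraction``, ``warmup_epochs``,
   ``min_learning_rate``, and ``lr_schedule_total_epochs`` suggest a
   warmup-plus-cosine learning-rate schedule. The sketch below shows one such
   schedule under these assumptions: it applies only when
   ``use_cosine_schedule`` is true, it is evaluated per optimizer step,
   ``warmup_epochs`` overrides ``warmup_fraction`` when set, and
   ``lr_schedule_total_epochs`` replaces ``num_epochs`` as the decay horizon
   when set. ``learning_rate_at`` is hypothetical; its defaults copy the
   documented ``TrainingConfig`` defaults.

   .. code-block:: python

      import math
      from typing import Optional


      def learning_rate_at(
          step: int,
          steps_per_epoch: int,
          num_epochs: int = 10,
          learning_rate: float = 1e-3,
          min_learning_rate: float = 1e-6,
          warmup_fraction: float = 0.02,
          warmup_epochs: Optional[float] = None,
          lr_schedule_total_epochs: Optional[float] = None,
      ) -> float:
          """Learning rate at an optimizer step under linear warmup plus cosine decay."""
          # The schedule horizon defaults to the training length.
          total_epochs = num_epochs if lr_schedule_total_epochs is None else lr_schedule_total_epochs
          total_steps = max(1, round(total_epochs * steps_per_epoch))
          # An explicit warmup_epochs takes precedence over warmup_fraction.
          if warmup_epochs is not None:
              warmup_steps = round(warmup_epochs * steps_per_epoch)
          else:
              warmup_steps = round(warmup_fraction * total_steps)
          if step < warmup_steps:
              # Linear ramp; step + 1 avoids a zero learning rate on the first step.
              return learning_rate * (step + 1) / warmup_steps
          decay_steps = max(1, total_steps - warmup_steps)
          progress = min(1.0, (step - warmup_steps) / decay_steps)
          cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
          # Decay from learning_rate down to min_learning_rate.
          return min_learning_rate + (learning_rate - min_learning_rate) * cosine


      for step in (0, 19, 20, 510, 1000):
          print(step, learning_rate_at(step, steps_per_epoch=100))

   With 100 steps per epoch and the defaults, warmup covers the first 20
   steps, the rate peaks at ``learning_rate``, and it reaches
   ``min_learning_rate`` at step 1000.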



