petitRADTRANS.sbi.simulator
===========================

.. py:module:: petitRADTRANS.sbi.simulator

.. autoapi-nested-parse::

   Simulation interfaces for prior-predictive and proposal-driven generation.



Attributes
----------

.. autoapisummary::

   petitRADTRANS.sbi.simulator.logger


Classes
-------

.. autoapisummary::

   petitRADTRANS.sbi.simulator.SimulationBatch
   petitRADTRANS.sbi.simulator.SimulationRuntime
   petitRADTRANS.sbi.simulator.ProposalSampler
   petitRADTRANS.sbi.simulator.Simulator
   petitRADTRANS.sbi.simulator.BatchedSimulator
   petitRADTRANS.sbi.simulator.RuntimeSimulator


Functions
---------

.. autoapisummary::

   petitRADTRANS.sbi.simulator._stack_simulation_output_payloads


Module Contents
---------------

.. py:data:: logger

.. py:class:: SimulationBatch

   Container for one simulated batch.

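   Attributes:
       parameters:
           Array-like free-parameter matrix whose leading dimension equals the
           number of simulated samples.
       observations:
           Task-conditioned simulated observations.
       cube_parameters:
           Optional prior-space (cube) coordinates of each sample, preserved
           when available.
       unconstrained_parameters:
           Optional unconstrained-space coordinates of each sample, preserved
           when available.
       log_likelihood:
           Optional log-likelihood value for each sample.
       diagnostics:
           Additional runtime diagnostics, such as clipping flags or NaN counts.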


   .. py:attribute:: parameters
      :type:  Any


   .. py:attribute:: observations
      :type:  Any


   .. py:attribute:: cube_parameters
      :type:  Any
      :value: None



   .. py:attribute:: unconstrained_parameters
      :type:  Any
      :value: None



   .. py:attribute:: log_likelihood
      :type:  Any
      :value: None



   .. py:attribute:: diagnostics
      :type:  Mapping[str, Any]


   .. py:property:: n_samples
      :type: int


      Return the number of simulated samples represented by the batch.
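
   A minimal sketch of the container's shape, assuming that ``n_samples`` is
   read from the leading dimension of ``parameters``. ``BatchSketch`` is a
   hypothetical stand-in written for illustration, not the class defined in
   this module.

   .. code-block:: python

      from dataclasses import dataclass, field
      from typing import Any, Mapping

      import numpy as np


      @dataclass
      class BatchSketch:
          """Hypothetical mirror of SimulationBatch, for illustration only."""

          parameters: Any
          observations: Any
          cube_parameters: Any = None
          unconstrained_parameters: Any = None
          log_likelihood: Any = None
          diagnostics: Mapping[str, Any] = field(default_factory=dict)

          @property
          def n_samples(self) -> int:
              # Assumption: the sample count is the leading dimension of ``parameters``.
              return int(np.shape(self.parameters)[0])


      batch = BatchSketch(
          parameters=np.zeros((8, 3)),
          observations={"spectrum": {"values": np.zeros((8, 50))}},
          diagnostics={"n_nan": 0},
      )
      print(batch.n_samples)  # 8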



.. py:function:: _stack_simulation_output_payloads(payloads: list[dict]) -> dict

   Stack per-sample observation payloads returned by
   :meth:`SimulationRuntime.simulate_observations_scalar`.

   Array fields are stacked into arrays with a leading ``(n_samples, ...)``
   dimension. The ``"metadata"`` field is kept as a ``list`` with one dict
   per sample. Fields that are ``None`` for every sample remain ``None``.
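
   The stacking rules above can be sketched with NumPy as follows.
   ``stack_payloads`` is a hypothetical helper, and the payload keys
   ``"values"`` and ``"mask"`` are assumptions chosen for illustration; only
   the treatment of ``"metadata"`` and of all-``None`` fields follows the
   description.

   .. code-block:: python

      from typing import Any

      import numpy as np


      def stack_payloads(payloads: list[dict[str, Any]]) -> dict[str, Any]:
          """Hypothetical sketch of the stacking rules described above."""
          stacked: dict[str, Any] = {}
          for key in payloads[0]:
              items = [payload[key] for payload in payloads]
              if key == "metadata":
                  # Metadata stays a list with one dict per sample.
                  stacked[key] = list(items)
              elif all(item is None for item in items):
                  # Fields that are None for every sample remain None.
                  stacked[key] = None
              else:
                  # Array fields gain a leading (n_samples, ...) dimension.
                  # Mixed None/array fields are not handled in this sketch.
                  stacked[key] = np.stack([np.asarray(item) for item in items], axis=0)
          return stacked


      payloads = [
          {"values": np.ones(4) * i, "mask": None, "metadata": {"row": i}}
          for i in range(3)
      ]
      out = stack_payloads(payloads)
      print(out["values"].shape)  # (3, 4)
      print(out["mask"])          # None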


.. py:class:: SimulationRuntime

   Bases: :py:obj:`petitRADTRANS.retrieval.runtime.RetrievalRuntime`


   Extension of :class:`~petitRADTRANS.retrieval.runtime.RetrievalRuntime`
   that adds deterministic forward-model simulation methods.

   ``RetrievalRuntime`` covers likelihood evaluation. This subclass adds a
   *simulation* layer on top: it projects parameters into observation space
   without noise, so that :class:`RuntimeSimulator` can inject noise
   afterwards and assemble :class:`SimulationBatch` objects.


   .. py:method:: simulate_observations_scalar(physical_params: petitRADTRANS.retrieval.runtime.PhysicalParams) -> dict[str, dict[str, Any]]

      Return deterministic observation-space model outputs for one parameter point.



   .. py:method:: _run_batched_rt_kernel(parameter_matrix: Any) -> tuple[Any, list[tuple[Any, Any]]]

      Run the expensive vmapped RT kernel and return raw results.

      Returns:
          A ``(batched_physical_params, group_results)`` tuple, where
          ``group_results`` is a list of ``(group, batched_model_result)``
          pairs holding the raw per-group model outputs before observation
          projection.



   .. py:method:: _run_batched_rt_kernel_multi_device(parameter_matrix: Any) -> tuple[Any, list[tuple[Any, Any]]]

      Multi-device version of :meth:`_run_batched_rt_kernel`.

      Distributes the batch across all available JAX devices by placing
      each chunk on a separate device with ``jax.device_put`` and then
      calling the single-device ``_run_batched_rt_kernel`` per chunk.
      JAX's asynchronous GPU dispatch allows the per-device computations
      to overlap.  Results are gathered and concatenated before return.

      The leading dimension of the parameter matrix **must** be a multiple
      of the device count; the caller is responsible for padding.
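
      Because the caller must pad the batch, a caller-side helper might look
      like the sketch below. It pads by repeating the last row (an
      assumption; any filler works as long as the padded rows are dropped
      afterwards), splits the matrix into one chunk per device, and trims the
      concatenated result back to the original sample count.
      ``pad_to_multiple`` and ``split_for_devices`` are hypothetical names,
      and plain NumPy stands in for device placement; with JAX, each chunk
      would be moved with ``jax.device_put`` before the kernel call.

      .. code-block:: python

         import numpy as np


         def pad_to_multiple(matrix: np.ndarray, n_devices: int) -> tuple[np.ndarray, int]:
             """Pad rows so the leading dimension is a multiple of ``n_devices``."""
             n_samples = matrix.shape[0]
             n_pad = (-n_samples) % n_devices
             if n_pad:
                 # Repeat the last row so padded rows are still valid parameter points.
                 filler = np.repeat(matrix[-1:], n_pad, axis=0)
                 matrix = np.concatenate([matrix, filler], axis=0)
             return matrix, n_samples


         def split_for_devices(matrix: np.ndarray, n_devices: int) -> list[np.ndarray]:
             """Split a padded matrix into equal, contiguous per-device chunks."""
             if matrix.shape[0] % n_devices:
                 raise ValueError("leading dimension must be a multiple of n_devices")
             return np.split(matrix, n_devices, axis=0)


         params = np.arange(10 * 2, dtype=float).reshape(10, 2)
         padded, n_valid = pad_to_multiple(params, n_devices=4)
         chunks = split_for_devices(padded, n_devices=4)
         # Stand-in for the per-chunk kernel call: here, just sum each row.
         results = np.concatenate([chunk.sum(axis=1) for chunk in chunks])[:n_valid]
         print(padded.shape, [c.shape for c in chunks], results.shape)
         # (12, 2) [(3, 2), (3, 2), (3, 2), (3, 2)] (10,)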



   .. py:method:: _project_batched_rt_results(batched_physical_params: Any, group_batched_results: list[tuple[Any, Any]], n_samples: int) -> dict[str, dict[str, Any]]

      Project raw RT kernel outputs into observation space.

      This is the cheap projection step (interpolation, scaling, binning).
      When possible, the projection and flux-transform functions are
      vectorized with ``jax.vmap``, replacing the Python-level loop over
      samples with a single traced call.



   .. py:method:: simulate_observations_batched(parameter_matrix: Any, multi_device: bool = False) -> dict[str, dict[str, Any]]

      Return deterministic batched observation outputs for a parameter matrix.

      For model groups using ``MODEL_CONTRACT_DIFFERENTIABLE``, this method
      vmaps the expensive RT kernel over the whole batch in a single XLA call
      and then performs the cheap projection step (interpolation, scaling,
      binning) separately, as described in :meth:`_project_batched_rt_results`.
      Groups that do not use this contract fall back to a scalar loop
      equivalent to calling :meth:`simulate_observations_scalar` once per row.

      Args:
          parameter_matrix: Shape ``(n_samples, n_free_params)``.
          multi_device: When ``True`` and more than one JAX device is
              available, distribute the vmapped RT kernel across devices
              (see :meth:`_run_batched_rt_kernel_multi_device`).

      Returns:
          Mapping from observation name to a payload dict in which
          ``"values"`` has shape ``(n_samples, n_wavelengths)``, while
          ``"mask"``, ``"coordinates"``, and ``"metadata"`` are either static
          or stored as per-sample lists.
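
      The requirement that the batched path match the scalar-loop fallback
      row by row can be illustrated with a toy model. The power-law
      ``toy_spectrum`` below is a hypothetical stand-in for the RT kernel plus
      projection, the observation name ``"spectrum"`` is made up, and NumPy
      broadcasting stands in for ``jax.vmap``; the point is only that both
      paths return a ``"values"`` array of shape ``(n_samples, n_wavelengths)``
      with identical rows.

      .. code-block:: python

         import numpy as np

         WAVELENGTHS = np.linspace(1.0, 5.0, 16)  # hypothetical wavelength grid


         def toy_spectrum(params: np.ndarray) -> np.ndarray:
             """Scalar path: one parameter row -> one spectrum (hypothetical model)."""
             amplitude, slope = params
             return amplitude * WAVELENGTHS ** (-slope)


         def toy_spectrum_batched(parameter_matrix: np.ndarray) -> np.ndarray:
             """Batched path: broadcast over rows instead of looping in Python."""
             amplitude = parameter_matrix[:, 0:1]  # shape (n_samples, 1)
             slope = parameter_matrix[:, 1:2]      # shape (n_samples, 1)
             return amplitude * WAVELENGTHS[None, :] ** (-slope)


         parameter_matrix = np.array([[1.0, 0.5], [2.0, 1.0], [0.5, 2.0]])
         batched = {"spectrum": {"values": toy_spectrum_batched(parameter_matrix)}}
         looped = {"spectrum": {"values": np.stack([toy_spectrum(row) for row in parameter_matrix])}}
         print(batched["spectrum"]["values"].shape)  # (3, 16)
         print(np.allclose(batched["spectrum"]["values"], looped["spectrum"]["values"]))  # True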



.. py:class:: ProposalSampler

   Bases: :py:obj:`abc.ABC`


   Interface for proposal distributions that replace the task prior when
   drawing simulation parameters.


   .. py:method:: sample(n_samples: int, task: petitRADTRANS.sbi.task.SBITask) -> Any
      :abstractmethod:


      Draw free-parameter vectors for the given task.
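
   A concrete proposal only has to implement ``sample``. The sketch below
   re-declares a minimal abstract base so that it runs on its own and draws
   from an isotropic Gaussian around a fixed center. ``GaussianProposal``,
   its ``center``, ``scale``, and ``seed`` arguments, and the
   ``n_free_params`` attribute of the stand-in task are all hypothetical and
   not part of :class:`~petitRADTRANS.sbi.task.SBITask`.

   .. code-block:: python

      import abc
      from dataclasses import dataclass
      from typing import Any

      import numpy as np


      class ProposalSamplerSketch(abc.ABC):
          """Stand-in for ProposalSampler so the example is self-contained."""

          @abc.abstractmethod
          def sample(self, n_samples: int, task: Any) -> Any:
              """Draw free-parameter vectors for the given task."""


      @dataclass
      class ToyTask:
          n_free_params: int  # hypothetical attribute, not SBITask's API


      class GaussianProposal(ProposalSamplerSketch):
          """Hypothetical proposal: isotropic Gaussian around ``center``."""

          def __init__(self, center: Any, scale: float, seed: int = 0) -> None:
              self.center = np.asarray(center, dtype=float)
              self.scale = float(scale)
              self.rng = np.random.default_rng(seed)

          def sample(self, n_samples: int, task: Any) -> np.ndarray:
              if self.center.shape != (task.n_free_params,):
                  raise ValueError("center must have one entry per free parameter")
              noise = self.rng.standard_normal((n_samples, task.n_free_params))
              return self.center + self.scale * noise


      proposal = GaussianProposal(center=[0.0, 1.0, -2.0], scale=0.1)
      draws = proposal.sample(5, ToyTask(n_free_params=3))
      print(draws.shape)  # (5, 3)

   A Gaussian proposal can place draws outside the prior support; this
   sketch does not handle that case.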



.. py:class:: Simulator(task: petitRADTRANS.sbi.task.SBITask)

   Bases: :py:obj:`abc.ABC`


   Base simulator for SBI dataset generation and validation.


   .. py:attribute:: task


   .. py:method:: sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) -> Any
      :abstractmethod:


      Sample free-parameter vectors from the task prior or a proposal.



   .. py:method:: simulate_from_parameters(parameters: Any) -> SimulationBatch
      :abstractmethod:


      Run the forward model and noise pipeline for pre-specified parameters.



   .. py:method:: simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) -> SimulationBatch

      Sample parameters and simulate one batch in a single call.
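
      A minimal sketch of how the two abstract methods could compose,
      assuming ``simulate_batch`` forwards ``n_samples`` and ``proposal`` to
      :meth:`sample_parameters` and passes the result to
      :meth:`simulate_from_parameters`. ``SimulatorSketch``,
      ``ToySimulator``, and its linear "forward model" are hypothetical, and
      the sketch returns a plain dict instead of a :class:`SimulationBatch`.

      .. code-block:: python

         import abc
         from typing import Any

         import numpy as np


         class SimulatorSketch(abc.ABC):
             """Stand-in for Simulator; only the call pattern is illustrated."""

             @abc.abstractmethod
             def sample_parameters(self, n_samples: int, proposal: Any = None) -> np.ndarray:
                 """Draw an (n_samples, n_free_params) parameter matrix."""

             @abc.abstractmethod
             def simulate_from_parameters(self, parameters: np.ndarray) -> dict[str, Any]:
                 """Run the forward model for fixed parameters."""

             def simulate_batch(self, n_samples: int, proposal: Any = None) -> dict[str, Any]:
                 # Assumed composition: sample first, then run the forward model.
                 parameters = self.sample_parameters(n_samples, proposal)
                 return self.simulate_from_parameters(parameters)


         class ToySimulator(SimulatorSketch):
             """Hypothetical subclass with a uniform prior and a linear model."""

             def __init__(self, seed: int = 0) -> None:
                 self.rng = np.random.default_rng(seed)

             def sample_parameters(self, n_samples: int, proposal: Any = None) -> np.ndarray:
                 # Proposals are ignored in this toy; a real subclass could
                 # delegate to ``proposal.sample(n_samples, task)`` instead.
                 return self.rng.uniform(-1.0, 1.0, size=(n_samples, 2))

             def simulate_from_parameters(self, parameters: np.ndarray) -> dict[str, Any]:
                 grid = np.linspace(0.0, 1.0, 8)
                 values = parameters[:, :1] + parameters[:, 1:] * grid  # (n_samples, 8)
                 return {"parameters": parameters, "observations": {"values": values}}


         batch = ToySimulator().simulate_batch(4)
         print(batch["observations"]["values"].shape)  # (4, 8)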



.. py:class:: BatchedSimulator(task: petitRADTRANS.sbi.task.SBITask, runtime: Any | None = None)

   Bases: :py:obj:`Simulator`


   Runtime-backed simulator that prefers vectorized task execution.

   At this stage the class intentionally defines only the interface.
   Concrete subclasses can back it with
   ``RetrievalRuntime.evaluate_vectorized_jax``, or with a scalar loop as a
   fallback when JAX batching is unavailable.


   .. py:attribute:: runtime
      :value: None



   .. py:method:: sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) -> Any

      Sample free-parameter vectors from the task prior or a proposal.



   .. py:method:: simulate_from_parameters(parameters: Any) -> SimulationBatch
      :abstractmethod:


      Run the forward model and noise pipeline for pre-specified parameters.



.. py:class:: RuntimeSimulator(task: petitRADTRANS.sbi.task.SBITask, runtime: Any | None = None, seed: int | None = None, data_parallel: bool | None = None)

   Bases: :py:obj:`BatchedSimulator`


   Concrete simulator backed by the retrieval runtime.

   The simulator samples from the retrieval prior using JAX random utilities
   and projects deterministic forward-model outputs into the observation space
   using ``RetrievalRuntime``.


   .. py:attribute:: _rng_key


   .. py:method:: _validate_task_support() -> None


   .. py:method:: _next_key() -> jax.Array


   .. py:method:: _row_invalid_value_mask(values: Any, mask: Any, constraint: Any) -> numpy.ndarray
      :staticmethod:



   .. py:method:: _invalid_sample_mask(observations: Mapping[str, Mapping[str, Any]], n_samples: int) -> tuple[numpy.ndarray, dict[str, int]]


   .. py:method:: _slice_batch_rows(batch: SimulationBatch, row_indices: numpy.ndarray) -> SimulationBatch
      :staticmethod:



   .. py:method:: _concatenate_simulation_batches(batches: list[SimulationBatch], diagnostics: Mapping[str, Any] | None = None) -> SimulationBatch
      :staticmethod:



   .. py:method:: _sample_prior_parameter_batch(n_samples: int) -> dict[str, jax.numpy.ndarray]


   .. py:method:: sample_parameters(n_samples: int, proposal: ProposalSampler | None = None) -> Any

      Sample free-parameter vectors from the task prior or a proposal.



   .. py:method:: _advance_rng_for_batch(n_samples: int) -> None

      Advance the PRNG state as if ``simulate_batch(n_samples)`` had been called.

      Only cheap key splitting is performed, with no forward-model
      evaluation, which makes this method suitable for quickly skipping
      already generated batches when resuming dataset generation.

      The number of ``_next_key`` calls matches the calls made by
      ``simulate_batch``: one key in ``_sample_prior_parameter_batch`` and
      one key per noise-adding observation in ``_simulate_parameter_matrix``.
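
      The resume-by-replay idea can be sketched without JAX. Below, NumPy's
      ``SeedSequence.spawn`` stands in for ``jax.random.split``: each
      ``next_key`` call deterministically consumes one child seed, just as
      each ``_next_key`` call consumes one split. ``KeyStream``,
      ``draw_batch_keys``, ``advance_for_batch``, and the
      ``n_noisy_observations`` argument are hypothetical; only the key count
      (one prior key plus one per noise-adding observation) follows the
      description above.

      .. code-block:: python

         import numpy as np


         class KeyStream:
             """Hypothetical PRNG state; each next_key() consumes one child seed."""

             def __init__(self, seed: int) -> None:
                 self._root = np.random.SeedSequence(seed)

             def next_key(self) -> np.random.SeedSequence:
                 # spawn() advances an internal counter, so keys depend only on
                 # the seed and on how many keys were consumed before.
                 return self._root.spawn(1)[0]


         def draw_batch_keys(stream: KeyStream, n_noisy_observations: int) -> list:
             """Keys a full batch consumes: 1 for the prior + 1 per noisy observation."""
             return [stream.next_key() for _ in range(1 + n_noisy_observations)]


         def advance_for_batch(stream: KeyStream, n_noisy_observations: int) -> None:
             """Consume the same number of keys without running any forward model."""
             for _ in range(1 + n_noisy_observations):
                 stream.next_key()


         # Original run: three batches, two noisy observations each.
         original = KeyStream(seed=42)
         original_keys = [draw_batch_keys(original, 2) for _ in range(3)]

         # Resumed run: skip the first two batches cheaply, then draw the third.
         resumed = KeyStream(seed=42)
         for _ in range(2):
             advance_for_batch(resumed, 2)
         resumed_keys = draw_batch_keys(resumed, 2)

         same = all(
             np.array_equal(a.generate_state(4), b.generate_state(4))
             for a, b in zip(original_keys[2], resumed_keys)
         )
         print(same)  # True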



   .. py:method:: _apply_noise(values: Any, uncertainties: Any, covariance: Any) -> tuple[Any, Any, Any]


   .. py:method:: _apply_noise_batched(values_batch: Any, uncertainties_batch: Any, covariance_batch: Any) -> tuple[Any, Any, Any]

      Apply noise to a batch of simulation outputs.

      Args:
          values_batch: Shape ``(n_samples, n_wavelengths)``.
          uncertainties_batch: Shape ``(n_samples, n_wavelengths)`` or ``None``.
          covariance_batch: Shape ``(n_samples, n_wavelengths, n_wavelengths)`` or ``None``.

      Returns:
          Noisy values, uncertainties, and covariance with the same leading
          batch dimension as ``values_batch``.



   .. py:method:: _stack_observation_payloads(payloads: list[dict[str, Any]]) -> dict[str, Any]
      :staticmethod:



   .. py:method:: _simulate_parameter_matrix(parameter_matrix: Any, cube_parameters: Any = None, unconstrained_parameters: Any = None, data_parallel: bool = False) -> SimulationBatch


   .. py:method:: simulate_from_parameters(parameters: Any) -> SimulationBatch

      Run the forward model and noise pipeline for pre-specified parameters.



   .. py:method:: simulate_batch(n_samples: int, proposal: ProposalSampler | None = None) -> SimulationBatch

      Sample parameters and simulate one batch, preserving prior-space
      coordinates when available.



