petitRADTRANS.sbi.flows
=======================

.. py:module:: petitRADTRANS.sbi.flows

.. autoapi-nested-parse::

   Conditional flow components for amortized SBI posteriors.



Classes
-------

.. autoapisummary::

   petitRADTRANS.sbi.flows.NeuralInverseDiagnostics
   petitRADTRANS.sbi.flows.CouplingConditioner
   petitRADTRANS.sbi.flows.ConditionalAffineCoupling
   petitRADTRANS.sbi.flows.ConditionalRationalQuadraticSplineCoupling
   petitRADTRANS.sbi.flows.ConditionalScalarSplineTransform
   petitRADTRANS.sbi.flows.ConditionalAutoregressiveAffineTransform
   petitRADTRANS.sbi.flows.ConditionalNeuralAutoregressiveTransform
   petitRADTRANS.sbi.flows.ConditionalAffineFlow
   petitRADTRANS.sbi.flows.ConditionalAutoregressiveFlow
   petitRADTRANS.sbi.flows.ConditionalNeuralAutoregressiveFlow
   petitRADTRANS.sbi.flows.ConditionalSplineFlow


Module Contents
---------------

.. py:class:: NeuralInverseDiagnostics

   Bases: :py:obj:`NamedTuple`


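   Aggregate diagnostics for one numerical inverse pass of a neural
   autoregressive transform: root-finding residuals, bracket widths,
   bracket-expansion counts, and the fraction of converged scalar inversions.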


   .. py:attribute:: max_abs_residual
      :type:  jax.numpy.ndarray


   .. py:attribute:: mean_abs_residual
      :type:  jax.numpy.ndarray


   .. py:attribute:: max_bracket_width
      :type:  jax.numpy.ndarray


   .. py:attribute:: mean_bracket_width
      :type:  jax.numpy.ndarray


   .. py:attribute:: max_expansion_steps
      :type:  jax.numpy.ndarray


   .. py:attribute:: mean_expansion_steps
      :type:  jax.numpy.ndarray


   .. py:attribute:: converged_fraction
      :type:  jax.numpy.ndarray


.. py:class:: CouplingConditioner(input_dim: int, context_dim: int, output_dim: int, hidden_dim: int, key: jax.Array, depth: int = 2)

   Bases: :py:obj:`equinox.Module`


   Context-conditioned network producing coupling parameters.

   Parameters
   ----------
   input_dim:
       Size of the masked parameter vector presented to the conditioner.
   context_dim:
       Size of the observation embedding appended to the masked input.
   output_dim:
       Number of coupling parameters produced by the network.
   hidden_dim:
       Hidden width of the internal MLP.
   key:
       JAX random key used to initialize network weights.
   depth:
       Depth of the internal MLP.

   Notes
   -----
   The conditioner concatenates masked parameter values with the observation
   context before passing them through a small MLP.
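
   As an illustration only, the concatenate-then-MLP pattern can be sketched in
   plain Python. The one-hidden-layer ``tanh`` network, the omitted biases, and
   the helper name ``make_conditioner`` are assumptions; the real class wraps an
   ``equinox.nn.MLP`` and runs under JAX.

   .. code-block:: python

      import math
      import random


      def make_conditioner(input_dim, context_dim, output_dim, hidden_dim, seed=0):
          """Hypothetical stand-in for CouplingConditioner (one tanh hidden layer)."""
          rng = random.Random(seed)
          in_dim = input_dim + context_dim
          # Small random weights; biases are omitted (equivalent to zero biases).
          w_hidden = [[rng.gauss(0.0, 0.1) for _ in range(in_dim)]
                      for _ in range(hidden_dim)]
          w_out = [[rng.gauss(0.0, 0.1) for _ in range(hidden_dim)]
                   for _ in range(output_dim)]

          def conditioner(masked_x, context):
              assert len(masked_x) == input_dim and len(context) == context_dim
              features = list(masked_x) + list(context)  # masked parameters, then context
              hidden = [math.tanh(sum(w * f for w, f in zip(row, features)))
                        for row in w_hidden]
              return [sum(w * h for w, h in zip(row, hidden)) for row in w_out]

          return conditioner


      conditioner = make_conditioner(input_dim=3, context_dim=4, output_dim=6, hidden_dim=16)
      params = conditioner([0.1, 0.0, -0.2], [1.0, 2.0, 3.0, 4.0])
      print(len(params))  # 6 coupling parameters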


   .. py:attribute:: network
      :type:  equinox.nn.MLP


   .. py:attribute:: input_dim
      :type:  int


   .. py:attribute:: context_dim
      :type:  int


   .. py:attribute:: output_dim
      :type:  int


   .. py:method:: __call__(masked_x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> jax.numpy.ndarray


.. py:class:: ConditionalAffineCoupling

   Bases: :py:obj:`equinox.Module`


   Affine coupling layer conditioned on an observation embedding.

   Notes
   -----
   The mask selects which parameter dimensions are held fixed and which are
   transformed. The conditioner predicts the shift and log-scale terms from the
   fixed (masked) input and the context embedding only, which is what makes the
   inverse available in closed form.
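
   A minimal plain-Python sketch of the coupling logic follows. It assumes the
   convention ``mask == 1`` for fixed dimensions and ``y = x * exp(s) + t`` for
   transformed ones; neither convention is confirmed by this page, and
   ``toy_conditioner`` is a hypothetical placeholder for ``CouplingConditioner``.
   Because the fixed dimensions pass through unchanged, the conditioner sees
   identical input in both directions, so the inverse needs no iteration.

   .. code-block:: python

      import math


      def toy_conditioner(masked_x, context):
          """Hypothetical conditioner: returns (shift, log_scale) per dimension."""
          total = sum(masked_x) + sum(context)
          shift = [0.5 * total for _ in masked_x]
          log_scale = [math.tanh(0.1 * total) for _ in masked_x]
          return shift, log_scale


      def coupling_forward(x, context, mask, conditioner=toy_conditioner):
          masked_x = [xi * mi for xi, mi in zip(x, mask)]  # keep only fixed dims
          shift, log_scale = conditioner(masked_x, context)
          y, log_det = [], 0.0
          for xi, mi, t, s in zip(x, mask, shift, log_scale):
              if mi:  # fixed dimension: passed through unchanged
                  y.append(xi)
              else:   # transformed dimension: scale, then shift
                  y.append(xi * math.exp(s) + t)
                  log_det += s
          return y, log_det


      def coupling_inverse(y, context, mask, conditioner=toy_conditioner):
          # Fixed dims are identical in x and y, so the conditioner input is recovered exactly.
          masked_y = [yi * mi for yi, mi in zip(y, mask)]
          shift, log_scale = conditioner(masked_y, context)
          x, log_det = [], 0.0
          for yi, mi, t, s in zip(y, mask, shift, log_scale):
              if mi:
                  x.append(yi)
              else:
                  x.append((yi - t) * math.exp(-s))
                  log_det -= s
          return x, log_det


      x = [0.3, -1.2, 0.8]
      mask = [1, 0, 1]
      context = [0.5, -0.5]
      y, log_det = coupling_forward(x, context, mask)
      x_rec, inv_log_det = coupling_inverse(y, context, mask)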


   .. py:attribute:: mask
      :type:  jax.numpy.ndarray


   .. py:attribute:: conditioner
      :type:  CouplingConditioner


   .. py:method:: _parameters(masked_x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]


   .. py:method:: forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

      Apply the forward affine transform.

      Parameters
      ----------
      x:
          Parameter vector in data space.
      context:
          Observation embedding conditioning the transform.

      Returns
      -------
      tuple[jnp.ndarray, jnp.ndarray]
          Transformed vector and scalar log-determinant contribution.



   .. py:method:: inverse(y: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

      Apply the inverse affine transform.

      Parameters
      ----------
      y:
          Latent-space vector.
      context:
          Observation embedding conditioning the inverse transform.

      Returns
      -------
      tuple[jnp.ndarray, jnp.ndarray]
          Recovered data-space vector and inverse log-determinant.



.. py:class:: ConditionalRationalQuadraticSplineCoupling

   Bases: :py:obj:`equinox.Module`


   Rational-quadratic spline coupling layer conditioned on an embedding.

   Notes
   -----
   This is the main expressive transform used by spline-based SBI posteriors.
   The conditioner predicts per-dimension spline widths, heights, and
   derivatives for the unmasked subset of the parameter vector.
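
   The per-dimension map can be sketched for a single scalar, following the
   monotonic rational-quadratic spline of Durkan et al. (2019, "Neural Spline
   Flows") with identity tails outside ``[-bound, bound]``. The softmax and
   softplus parameterization, the use of ``num_bins - 1`` interior derivatives
   with boundary derivatives fixed to one, and the identity tails are
   assumptions modeled on that paper, not details confirmed by this page.

   .. code-block:: python

      import math


      def _softmax(logits):
          peak = max(logits)
          exps = [math.exp(v - peak) for v in logits]
          total = sum(exps)
          return [e / total for e in exps]


      def _softplus(v):
          # Stable log(1 + exp(v)).
          return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


      def _knots(logits, bound, min_size):
          """Knot positions on [-bound, bound]; every bin keeps at least min_size."""
          k = len(logits)
          sizes = [min_size + (1.0 - min_size * k) * p for p in _softmax(logits)]
          knots = [-bound]
          for size in sizes:
              knots.append(knots[-1] + 2.0 * bound * size)
          knots[-1] = bound  # remove accumulated rounding error
          return knots


      def rq_spline_forward(x, width_logits, height_logits, derivative_logits,
                            bound=10.0, min_bin_width=1e-3, min_bin_height=1e-3,
                            min_derivative=1e-3):
          """Return (y, log|dy/dx|) for one scalar with identity tails."""
          if x <= -bound or x >= bound:
              return x, 0.0  # identity outside the spline interval
          xs = _knots(width_logits, bound, min_bin_width)
          ys = _knots(height_logits, bound, min_bin_height)
          # Boundary derivatives of 1 join the spline smoothly to the identity tails.
          ds = [1.0] + [min_derivative + _softplus(d) for d in derivative_logits] + [1.0]
          k = max(i for i in range(len(xs) - 1) if xs[i] <= x)  # bin containing x
          width, height = xs[k + 1] - xs[k], ys[k + 1] - ys[k]
          slope = height / width
          xi = (x - xs[k]) / width  # position within the bin, in [0, 1)
          mix = xi * (1.0 - xi)
          denom = slope + (ds[k + 1] + ds[k] - 2.0 * slope) * mix
          y = ys[k] + height * (slope * xi ** 2 + ds[k] * mix) / denom
          dy_dx = slope ** 2 * (ds[k + 1] * xi ** 2 + 2.0 * slope * mix
                                + ds[k] * (1.0 - xi) ** 2) / denom ** 2
          return y, math.log(dy_dx)


      widths = [0.1, -0.3, 0.5, 0.0]     # num_bins = 4
      heights = [0.2, 0.1, -0.4, 0.3]
      derivatives = [0.3, -0.2, 0.1]    # num_bins - 1 interior knots
      y, log_det = rq_spline_forward(1.3, widths, heights, derivatives)

   In the coupling layer, the conditioner would emit one such set of widths,
   heights, and derivatives for every unmasked dimension.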


   .. py:attribute:: mask
      :type:  jax.numpy.ndarray


   .. py:attribute:: conditioner
      :type:  CouplingConditioner


   .. py:attribute:: parameter_dim
      :type:  int


   .. py:attribute:: num_bins
      :type:  int


   .. py:attribute:: bound
      :type:  float


   .. py:attribute:: min_bin_width
      :type:  float


   .. py:attribute:: min_bin_height
      :type:  float


   .. py:attribute:: min_derivative
      :type:  float


   .. py:attribute:: max_derivative
      :type:  float | None


   .. py:method:: _parameters(masked_x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]


   .. py:method:: forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

      Apply the forward spline transform.

      Parameters
      ----------
      x:
          Parameter vector in data space.
      context:
          Observation embedding conditioning the spline parameters.

      Returns
      -------
      tuple[jnp.ndarray, jnp.ndarray]
          Transformed vector and summed log-determinant contribution.



   .. py:method:: inverse(y: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

      Apply the inverse spline transform.

      Parameters
      ----------
      y:
          Latent-space vector.
      context:
          Observation embedding conditioning the inverse transform.

      Returns
      -------
      tuple[jnp.ndarray, jnp.ndarray]
          Recovered parameter vector and inverse log-determinant.



.. py:class:: ConditionalScalarSplineTransform

   Bases: :py:obj:`equinox.Module`


   Context-conditioned scalar spline transform used for 1D posteriors.

   Notes
   -----
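   A coupling mask cannot split a one-dimensional parameter vector into a
   conditioning part and a transformed part, so one-dimensional posterior models
   use this dedicated transform, whose spline parameters depend on the
   observation embedding alone.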


   .. py:attribute:: conditioner
      :type:  _ScalarSplineConditioner


   .. py:attribute:: num_bins
      :type:  int


   .. py:attribute:: bound
      :type:  float


   .. py:attribute:: min_bin_width
      :type:  float


   .. py:attribute:: min_bin_height
      :type:  float


   .. py:attribute:: min_derivative
      :type:  float


   .. py:attribute:: max_derivative
      :type:  float | None


   .. py:method:: _parameters(context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]


   .. py:method:: _as_scalar(value: jax.numpy.ndarray) -> jax.numpy.ndarray
      :staticmethod:



   .. py:method:: forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

      Apply the forward scalar spline transform.

      Parameters
      ----------
      x:
          One-dimensional parameter vector.
      context:
          Observation embedding conditioning the spline parameters.

      Returns
      -------
      tuple[jnp.ndarray, jnp.ndarray]
          Transformed one-element vector and scalar log-determinant.



   .. py:method:: inverse(y: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

      Apply the inverse scalar spline transform.

      Parameters
      ----------
      y:
          One-dimensional latent vector.
      context:
          Observation embedding conditioning the inverse transform.

      Returns
      -------
      tuple[jnp.ndarray, jnp.ndarray]
          Recovered one-element parameter vector and inverse log-determinant.



.. py:class:: ConditionalAutoregressiveAffineTransform(parameter_dim: int, context_dim: int, hidden_dim: int, key: jax.Array, depth: int = 5)

   Bases: :py:obj:`equinox.Module`


   Context-conditioned affine autoregressive transform.

   Notes
   -----
   The transform maps data to latent space in a masked-autoregressive manner,
   using previous data dimensions and the observation embedding to predict one
   affine transform per parameter dimension.
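
   The masked-autoregressive structure can be sketched in plain Python. The
   parameterization ``z_i = (x_i - shift_i) * exp(-log_scale_i)`` and the
   ``toy_dimension_conditioner`` helper are hypothetical; the sketch only shows
   why the data-to-latent direction works from already-known inputs, while the
   inverse must recover the dimensions one at a time.

   .. code-block:: python

      import math


      def toy_dimension_conditioner(index, prefix, context):
          """Hypothetical per-dimension conditioner returning (shift, log_scale)."""
          total = sum(prefix) + sum(context) + 0.1 * index
          return 0.3 * total, 0.5 * math.tanh(total)


      def ar_forward(x, context, conditioner=toy_dimension_conditioner):
          """Data -> latent. Each step only needs the known inputs x[:i], so a
          masked network could evaluate all dimensions in a single pass."""
          z, log_det = [], 0.0
          for i, xi in enumerate(x):
              shift, log_scale = conditioner(i, x[:i], context)
              z.append((xi - shift) * math.exp(-log_scale))
              log_det -= log_scale
          return z, log_det


      def ar_inverse(z, context, conditioner=toy_dimension_conditioner):
          """Latent -> data. Inherently sequential: x[i] requires x[:i] first."""
          x, log_det = [], 0.0
          for i, zi in enumerate(z):
              shift, log_scale = conditioner(i, x[:i], context)
              x.append(zi * math.exp(log_scale) + shift)
              log_det += log_scale
          return x, log_det


      x = [0.4, -0.7, 1.5]
      context = [0.2]
      z, log_det = ar_forward(x, context)
      x_rec, inv_log_det = ar_inverse(z, context)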


   .. py:attribute:: conditioners
      :type:  tuple[_AutoregressiveDimensionConditioner, Ellipsis]


   .. py:attribute:: parameter_dim
      :type:  int


   .. py:method:: forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]


   .. py:method:: inverse(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]


.. py:class:: ConditionalNeuralAutoregressiveTransform(parameter_dim: int, context_dim: int, hidden_dim: int, transform_units: int, key: jax.Array, depth: int = 5, min_slope: float = 0.001, min_residual: float = 0.05, inverse_base_bound: float = 4.0, inverse_bisection_steps: int = 48, inverse_expansion_steps: int = 16, inverse_newton_steps: int = 3)

   Bases: :py:obj:`equinox.Module`


   Context-conditioned neural autoregressive transform.

   Notes
   -----
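   Each scalar transform uses a monotonic deep-sigmoidal-style network whose
   parameters are conditioned on previous dimensions and the observation
   embedding. Because each scalar map is strictly monotonic, it can be inverted
   numerically: the root is bracketed starting from ``inverse_base_bound`` (the
   bracket is widened at most ``inverse_expansion_steps`` times), narrowed with
   ``inverse_bisection_steps`` fixed bisection steps, and refined with
   ``inverse_newton_steps`` Newton steps.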
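
   A plain-Python sketch of the scalar inversion, assuming a simple monotone map
   ``y = r*x + sum_j w_j*sigmoid(a_j*x + b_j)`` with positive ``a_j``, ``w_j``,
   and ``r``. This is a stand-in for the deep-sigmoidal-style network, not its
   actual formula. The defaults echo ``inverse_base_bound``,
   ``inverse_expansion_steps``, ``inverse_bisection_steps``, and
   ``inverse_newton_steps``, but the doubling rule, the step order, and the
   tolerance are assumptions. The returned dictionary mirrors the per-scalar
   quantities that ``NeuralInverseDiagnostics`` aggregates.

   .. code-block:: python

      import math


      def _sigmoid(v):
          # Numerically stable logistic function.
          if v >= 0:
              return 1.0 / (1.0 + math.exp(-v))
          e = math.exp(v)
          return e / (1.0 + e)


      def monotone_map(x, slopes, offsets, weights, residual):
          """Return (y, dy/dx) for y = residual*x + sum_j w_j*sigmoid(a_j*x + b_j)."""
          y, dy = residual * x, residual
          for a, b, w in zip(slopes, offsets, weights):
              s = _sigmoid(a * x + b)
              y += w * s
              dy += w * a * s * (1.0 - s)
          return y, dy


      def invert_monotone(target, params, base_bound=4.0, expansion_steps=16,
                          bisection_steps=48, newton_steps=3, tol=1e-8):
          """Invert monotone_map by bracketing, bisection, then Newton refinement."""

          def f(v):
              return monotone_map(v, *params)

          lo, hi, expansions = -base_bound, base_bound, 0
          # Double the symmetric bracket until it contains the target (or give up).
          while (f(lo)[0] > target or f(hi)[0] < target) and expansions < expansion_steps:
              lo, hi, expansions = 2.0 * lo, 2.0 * hi, expansions + 1
          # Bisection is safe because the map is strictly increasing.
          for _ in range(bisection_steps):
              mid = 0.5 * (lo + hi)
              if f(mid)[0] < target:
                  lo = mid
              else:
                  hi = mid
          x = 0.5 * (lo + hi)
          # Newton steps polish the estimate; dy >= residual > 0, so no division by zero.
          for _ in range(newton_steps):
              y, dy = f(x)
              x -= (y - target) / dy
          abs_residual = abs(f(x)[0] - target)
          return x, {"residual": abs_residual, "bracket_width": hi - lo,
                     "expansion_steps": expansions, "converged": abs_residual < tol}


      params = ([0.5, 2.0, 1.0], [0.0, -1.0, 0.5], [0.2, 0.5, 0.3], 0.05)
      x, diagnostics = invert_monotone(3.7, params)
      print(diagnostics["expansion_steps"])  # 4: bracket grew from [-4, 4] to [-64, 64]

   The strictly positive residual slope keeps the map unbounded, so a bracket
   containing any target exists; the expansion cap only limits how far the
   search is allowed to look.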


   .. py:attribute:: conditioners
      :type:  tuple[_NeuralAutoregressiveDimensionConditioner, Ellipsis]


   .. py:attribute:: parameter_dim
      :type:  int


   .. py:attribute:: transform_units
      :type:  int


   .. py:attribute:: min_slope
      :type:  float


   .. py:attribute:: min_residual
      :type:  float


   .. py:attribute:: inverse_base_bound
      :type:  float


   .. py:attribute:: inverse_bisection_steps
      :type:  int


   .. py:attribute:: inverse_expansion_steps
      :type:  int


   .. py:attribute:: inverse_newton_steps
      :type:  int


   .. py:method:: _parameters(conditioner: _NeuralAutoregressiveDimensionConditioner, prefix: jax.numpy.ndarray, context: jax.numpy.ndarray, dtype: jax.numpy.dtype) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]


   .. py:method:: _evaluate_scalar(value: jax.numpy.ndarray, positive_slopes: jax.numpy.ndarray, offsets: jax.numpy.ndarray, weight_logits: jax.numpy.ndarray, residual_fraction: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]


   .. py:method:: _inverse_scalar(target: jax.numpy.ndarray, positive_slopes: jax.numpy.ndarray, offsets: jax.numpy.ndarray, weight_logits: jax.numpy.ndarray, residual_fraction: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]


   .. py:method:: _inverse_scalar_with_diagnostics(target: jax.numpy.ndarray, positive_slopes: jax.numpy.ndarray, offsets: jax.numpy.ndarray, weight_logits: jax.numpy.ndarray, residual_fraction: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]


   .. py:method:: forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]


   .. py:method:: inverse(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]


   .. py:method:: inverse_with_diagnostics(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, NeuralInverseDiagnostics]


.. py:class:: ConditionalAffineFlow(parameter_dim: int, context_dim: int, num_layers: int = 4, hidden_dim: int = 128, conditioner_depth: int = 2, key: jax.Array | None = None)

   Bases: :py:obj:`_ConditionalFlowStack`


   Stack of conditional affine couplings with alternating masks.

   Parameters
   ----------
   parameter_dim:
       Number of inferred parameters transformed by the flow.
   context_dim:
       Size of the observation embedding conditioning each layer.
   num_layers:
       Number of affine coupling layers.
   hidden_dim:
       Hidden width of each conditioner MLP.
   conditioner_depth:
       Depth of each conditioner MLP.
   key:
       Optional JAX random key used for layer initialization.

   Notes
   -----
   For one-dimensional problems the mask degenerates and each layer reduces to a
   context-only affine transform. Higher-dimensional problems alternate masks
   across layers so that every dimension is transformed by some layer.
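
   The mask schedule can be sketched as follows. The page states only that
   masks alternate; the even/odd checkerboard pattern, the ``1 = fixed``
   convention, and the helper name ``alternating_masks`` are assumptions. The
   one-dimensional branch mirrors the degenerate case described in these notes:
   the single dimension is transformed in every layer, while the conditioner
   effectively sees only the context.

   .. code-block:: python

      def alternating_masks(parameter_dim, num_layers):
          """Hypothetical mask schedule; 1 marks a fixed dimension, 0 a transformed one."""
          if parameter_dim == 1:
              # Nothing can be held fixed: the masked input is all zeros, so the
              # layer is conditioned on the observation embedding alone.
              return [[0] for _ in range(num_layers)]
          pattern_a = [i % 2 for i in range(parameter_dim)]  # [0, 1, 0, 1, ...]
          pattern_b = [1 - m for m in pattern_a]             # [1, 0, 1, 0, ...]
          return [list(pattern_a if layer % 2 == 0 else pattern_b)
                  for layer in range(num_layers)]


      for mask in alternating_masks(parameter_dim=4, num_layers=4):
          print(mask)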


   .. py:attribute:: parameter_dim


   .. py:attribute:: context_dim


   .. py:attribute:: layers
      :value: ()



.. py:class:: ConditionalAutoregressiveFlow(parameter_dim: int, context_dim: int, num_layers: int = 3, hidden_dim: int = 512, conditioner_depth: int = 5, key: jax.Array | None = None)

   Bases: :py:obj:`_ConditionalFlowStack`


   Stack of context-conditioned affine autoregressive transforms.

   Notes
   -----
   This flow is a lighter-weight autoregressive alternative to the spline
   stack. Each transform uses ELU MLPs to predict affine parameters for one
   dimension at a time, conditioned on previous dimensions and the
   observation embedding.


   .. py:attribute:: parameter_dim


   .. py:attribute:: context_dim


   .. py:attribute:: layers


.. py:class:: ConditionalNeuralAutoregressiveFlow(parameter_dim: int, context_dim: int, num_layers: int = 3, hidden_dim: int = 512, conditioner_depth: int = 5, transform_units: int = 16, min_slope: float = 0.001, min_residual: float = 0.05, inverse_bisection_steps: int = 48, inverse_newton_steps: int = 3, key: jax.Array | None = None)

   Bases: :py:obj:`_ConditionalFlowStack`


   Stack of context-conditioned neural autoregressive transforms.

   Notes
   -----
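   This backend uses a deep-sigmoidal-style monotonic scalar bijector per
   parameter dimension. It is more expressive than the affine autoregressive
   stack and remains invertible, although the inverse has no closed form: it is
   computed numerically, one dimension at a time, by scalar bisection with
   Newton refinement.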
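
   As a rough worked bound, assume the bracket starts at
   ``[-inverse_base_bound, inverse_base_bound]`` with the default
   ``inverse_base_bound`` of 4.0 from ``ConditionalNeuralAutoregressiveTransform``
   and needs no expansion. Each bisection step halves the bracket, so after
   :math:`n` steps its width is

   .. math::

      w_n = \frac{w_0}{2^{n}}, \qquad
      w_{48} = \frac{8}{2^{48}} \approx 2.8 \times 10^{-14}.

   If the flow runs in float32 (JAX's default precision), the spacing between
   representable values near one is about :math:`1.2 \times 10^{-7}`, which an
   8-wide bracket reaches after roughly 26 halvings
   (:math:`8 / 2^{26} \approx 1.2 \times 10^{-7}`); beyond that point, rounding
   rather than the bisection budget limits the accuracy for values of order one.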


   .. py:attribute:: transform_units
      :type:  int


   .. py:attribute:: min_slope
      :type:  float


   .. py:attribute:: min_residual
      :type:  float


   .. py:attribute:: inverse_bisection_steps
      :type:  int


   .. py:attribute:: inverse_newton_steps
      :type:  int


   .. py:attribute:: parameter_dim


   .. py:attribute:: context_dim


   .. py:attribute:: layers


   .. py:method:: inverse_with_diagnostics(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, NeuralInverseDiagnostics]


.. py:class:: ConditionalSplineFlow(parameter_dim: int, context_dim: int, num_layers: int = 6, hidden_dim: int = 128, conditioner_depth: int = 2, num_bins: int = 8, bound: float = 10.0, min_bin_width: float = 0.001, min_bin_height: float = 0.001, min_derivative: float = 0.001, max_derivative: float | None = None, base_distribution: str = 'gaussian', use_base_affine: bool = False, key: jax.Array | None = None)

   Bases: :py:obj:`_ConditionalFlowStack`


   Stack of conditional rational-quadratic spline transforms.

   Parameters
   ----------
   parameter_dim:
       Number of inferred parameters transformed by the flow.
   context_dim:
       Size of the conditioning observation embedding.
   num_layers:
       Number of spline transform layers.
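   hidden_dim:
       Hidden width of the conditioner networks.
   conditioner_depth:
       Depth of the conditioner networks.
   num_bins:
       Number of rational-quadratic bins used per transformed dimension.
   bound:
       Finite support bound of the spline transform in latent coordinates.
   min_bin_width, min_bin_height, min_derivative:
       Numerical-stability floors applied to spline parameters.
   max_derivative:
       Optional upper limit on spline derivatives; ``None`` applies no limit.
   base_distribution:
       Name of the base distribution (default ``'gaussian'``).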
   key:
       Optional JAX random key used for initialization.

   Notes
   -----
   One-dimensional posteriors use a dedicated scalar spline stack, while
   higher-dimensional models use alternating masked coupling layers.
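
   The stack's density follows from the change-of-variables formula. The generic
   sketch below assumes that each layer's forward map goes from data toward
   latent space and returns ``(z, log_det)``, and that the base distribution is a
   standard normal (the default ``base_distribution='gaussian'``).
   ``use_base_affine`` is not modeled, and all helper names are hypothetical.

   .. code-block:: python

      import math


      def standard_normal_log_prob(z):
          """Log-density of an isotropic standard normal evaluated at z."""
          return -0.5 * sum(v * v for v in z) - 0.5 * len(z) * math.log(2.0 * math.pi)


      def flow_log_prob(x, context, forward_layers):
          """log q(x | context) = log N(z; 0, I) + sum of per-layer log-determinants."""
          z, total_log_det = list(x), 0.0
          for forward in forward_layers:
              z, log_det = forward(z, context)
              total_log_det += log_det
          return standard_normal_log_prob(z) + total_log_det


      def flow_sample(z, context, inverse_layers):
          """Map a base draw back to data space, undoing the layers in reverse order."""
          x = list(z)
          for inverse in reversed(inverse_layers):
              x, _ = inverse(x, context)
          return x


      # Toy layer: z = 2 * x + context[0], so log|det| = len(x) * log(2).
      def toy_forward(x, context):
          return [2.0 * v + context[0] for v in x], len(x) * math.log(2.0)


      def toy_inverse(z, context):
          return [(v - context[0]) / 2.0 for v in z], -len(z) * math.log(2.0)


      log_q = flow_log_prob([0.1, -0.3], [0.4], [toy_forward, toy_forward])
      sample = flow_sample([0.5, -1.0], [0.4], [toy_inverse, toy_inverse])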


   .. py:attribute:: num_bins
      :type:  int


   .. py:attribute:: bound
      :type:  float


   .. py:attribute:: min_bin_width
      :type:  float


   .. py:attribute:: min_bin_height
      :type:  float


   .. py:attribute:: min_derivative
      :type:  float


   .. py:attribute:: max_derivative
      :type:  float | None


   .. py:attribute:: parameter_dim


   .. py:attribute:: context_dim


   .. py:attribute:: base_distribution
      :value: ''



   .. py:attribute:: layers
      :value: ()



