petitRADTRANS.sbi.flows

Conditional flow components for amortized SBI posteriors.

Classes

NeuralInverseDiagnostics

Aggregate diagnostics for one neural autoregressive inverse pass.

CouplingConditioner

Context-conditioned network producing coupling parameters.

ConditionalAffineCoupling

Affine coupling layer conditioned on an observation embedding.

ConditionalRationalQuadraticSplineCoupling

Rational-quadratic spline coupling layer conditioned on an embedding.

ConditionalScalarSplineTransform

Context-conditioned scalar spline transform used for 1D posteriors.

ConditionalAutoregressiveAffineTransform

Context-conditioned affine autoregressive transform.

ConditionalNeuralAutoregressiveTransform

Context-conditioned neural autoregressive transform.

ConditionalAffineFlow

Stack of conditional affine couplings with alternating masks.

ConditionalAutoregressiveFlow

Stack of context-conditioned affine autoregressive transforms.

ConditionalNeuralAutoregressiveFlow

Stack of context-conditioned neural autoregressive transforms.

ConditionalSplineFlow

Stack of conditional rational-quadratic spline transforms.

Module Contents

class petitRADTRANS.sbi.flows.NeuralInverseDiagnostics

Bases: NamedTuple

Aggregate diagnostics for one neural autoregressive inverse pass.

max_abs_residual: jax.numpy.ndarray
mean_abs_residual: jax.numpy.ndarray
max_bracket_width: jax.numpy.ndarray
mean_bracket_width: jax.numpy.ndarray
max_expansion_steps: jax.numpy.ndarray
mean_expansion_steps: jax.numpy.ndarray
converged_fraction: jax.numpy.ndarray
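The seven fields are summary statistics over one inverse pass. The sketch below shows one way they could be filled, assuming per-dimension records of absolute residuals, final bracket widths, expansion-step counts, and convergence flags are available and are reduced with max/mean. These inputs and reductions are inferred from the field names. The stand-in class `InverseDiagnosticsSketch` and the helper `summarize_inverse` are hypothetical and are not part of petitRADTRANS.

```python
from typing import NamedTuple

import numpy as np


class InverseDiagnosticsSketch(NamedTuple):
    """Hypothetical stand-in mirroring the fields of NeuralInverseDiagnostics."""
    max_abs_residual: np.ndarray
    mean_abs_residual: np.ndarray
    max_bracket_width: np.ndarray
    mean_bracket_width: np.ndarray
    max_expansion_steps: np.ndarray
    mean_expansion_steps: np.ndarray
    converged_fraction: np.ndarray


def summarize_inverse(residuals, bracket_widths, expansion_steps, converged):
    """Reduce per-dimension inversion records to scalar summaries (assumed max/mean)."""
    abs_res = np.abs(np.asarray(residuals, dtype=float))
    widths = np.asarray(bracket_widths, dtype=float)
    steps = np.asarray(expansion_steps, dtype=float)
    conv = np.asarray(converged, dtype=bool)
    return InverseDiagnosticsSketch(
        max_abs_residual=abs_res.max(),
        mean_abs_residual=abs_res.mean(),
        max_bracket_width=widths.max(),
        mean_bracket_width=widths.mean(),
        max_expansion_steps=steps.max(),
        mean_expansion_steps=steps.mean(),
        converged_fraction=conv.mean(),  # fraction of True flags
    )


diag = summarize_inverse(
    residuals=[1e-7, -3e-6, 2e-8],
    bracket_widths=[1e-9, 2e-9, 4e-9],
    expansion_steps=[0, 2, 1],
    converged=[True, True, False],
)
print(diag.converged_fraction)
```

Under these assumptions, a `converged_fraction` below 1 would flag inversions that ran out of expansion or bisection steps.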

class petitRADTRANS.sbi.flows.CouplingConditioner(input_dim: int, context_dim: int, output_dim: int, hidden_dim: int, key: jax.Array, depth: int = 2)

Bases: equinox.Module

Context-conditioned network producing coupling parameters.

Parameters

input_dim:

Size of the masked parameter vector presented to the conditioner.

context_dim:

Size of the observation embedding appended to the masked input.

output_dim:

Number of coupling parameters produced by the network.

hidden_dim:

Hidden width of the internal MLP.

key:

JAX random key used to initialize network weights.

depth:

Depth of the internal MLP (default 2).

Notes

The conditioner concatenates masked parameter values with the observation context before passing them through a small MLP.
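The concatenate-then-MLP pattern can be sketched in NumPy as follows. The helpers `init_mlp` and `conditioner` are hypothetical, and so are two further choices: the tanh activations and passing a full-length masked vector with the transformed entries zeroed. The real class wraps an `equinox.nn.MLP`. The sketch follows the Equinox convention in which `depth` counts hidden layers, so `depth=2` yields three linear layers.

```python
import numpy as np


def init_mlp(seed, in_size, width, out_size, depth=2):
    """Random (weight, bias) pairs for an MLP with `depth` hidden layers (hypothetical)."""
    rng = np.random.default_rng(seed)
    sizes = [in_size] + [width] * depth + [out_size]
    return [
        (rng.normal(scale=1.0 / np.sqrt(m), size=(n, m)), np.zeros(n))
        for m, n in zip(sizes[:-1], sizes[1:])
    ]


def conditioner(params, masked_x, context):
    """Concatenate masked parameters with the context, then run the MLP."""
    h = np.concatenate([masked_x, context])
    for i, (w, b) in enumerate(params):
        h = w @ h + b
        if i < len(params) - 1:  # no activation on the output layer
            h = np.tanh(h)
    return h


input_dim, context_dim, output_dim = 4, 8, 6
params = init_mlp(0, input_dim + context_dim, 32, output_dim)
masked_x = np.array([0.3, 0.0, -1.2, 0.0])  # transformed entries zeroed by the mask
context = np.ones(context_dim)
out = conditioner(params, masked_x, context)
print(out.shape)  # (6,)
```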

network: equinox.nn.MLP
input_dim: int
context_dim: int
output_dim: int
__call__(masked_x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> jax.numpy.ndarray

class petitRADTRANS.sbi.flows.ConditionalAffineCoupling

Bases: equinox.Module

Affine coupling layer conditioned on an observation embedding.

Notes

The mask selects which parameter dimensions remain fixed and which are transformed. The conditioner predicts shift and log-scale terms only from the masked input and context embedding.
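The key property of a coupling layer is that the conditioner only sees the fixed entries, and those entries pass through unchanged. The inverse can therefore recompute exactly the same shift and log-scale. A minimal NumPy sketch follows. It assumes that `mask == 1` marks fixed dimensions and that the forward map is `y = x * exp(log_scale) + shift` on the transformed entries. `toy_conditioner` is a hypothetical stand-in for CouplingConditioner. It bounds the log-scale with tanh, which is a common stabilization and may not match petitRADTRANS.

```python
import numpy as np


def toy_conditioner(masked_x, context):
    """Hypothetical stand-in for CouplingConditioner: returns (shift, log_scale)."""
    s = masked_x.sum() + context.sum()
    shift = 0.5 * s * np.ones_like(masked_x)
    log_scale = np.tanh(0.1 * s) * np.ones_like(masked_x)
    return shift, log_scale


def coupling_forward(x, context, mask):
    shift, log_scale = toy_conditioner(mask * x, context)
    # Only unmasked entries are transformed; masked entries pass through unchanged.
    y = mask * x + (1 - mask) * (x * np.exp(log_scale) + shift)
    log_det = np.sum((1 - mask) * log_scale)
    return y, log_det


def coupling_inverse(y, context, mask):
    # mask * y == mask * x, so the conditioner sees the same input as in forward.
    shift, log_scale = toy_conditioner(mask * y, context)
    x = mask * y + (1 - mask) * ((y - shift) * np.exp(-log_scale))
    return x, -np.sum((1 - mask) * log_scale)


mask = np.array([1.0, 0.0, 1.0, 0.0])
x = np.array([0.2, -1.0, 0.7, 3.0])
context = np.array([0.5, -0.25])
y, ld = coupling_forward(x, context, mask)
x_rec, inv_ld = coupling_inverse(y, context, mask)
print(np.allclose(x, x_rec), np.isclose(ld, -inv_ld))  # True True
```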

mask: jax.numpy.ndarray
conditioner: CouplingConditioner
_parameters(masked_x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

Apply the forward affine transform.

Parameters

x:

Parameter vector in data space.

context:

Observation embedding conditioning the transform.

Returns

tuple[jnp.ndarray, jnp.ndarray]

Transformed vector and scalar log-determinant contribution.

inverse(y: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

Apply the inverse affine transform.

Parameters

y:

Latent-space vector.

context:

Observation embedding conditioning the inverse transform.

Returns

tuple[jnp.ndarray, jnp.ndarray]

Recovered data-space vector and inverse log-determinant.

class petitRADTRANS.sbi.flows.ConditionalRationalQuadraticSplineCoupling

Bases: equinox.Module

Rational-quadratic spline coupling layer conditioned on an embedding.

Notes

This is the main expressive transform used by spline-based SBI posteriors. The conditioner predicts per-dimension spline widths, heights, and derivatives for the unmasked subset of the parameter vector.
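To make the spline parameterization concrete, the NumPy sketch below implements a scalar rational-quadratic spline in the style of neural spline flows (Durkan et al., 2019). It uses softmax-normalized bin widths and heights with floors analogous to `min_bin_width` and `min_bin_height`, softplus-floored interior derivatives analogous to `min_derivative`, and identity behaviour outside `[-bound, bound]`. The identity tails, the unit boundary derivatives, and every helper name are illustrative assumptions; the petitRADTRANS implementation may differ. Under this layout, a conditioner would emit `3 * num_bins - 1` values per transformed dimension.

```python
import numpy as np


def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()


def softplus(a):
    return np.log1p(np.exp(a))


def spline_knots(uw, uh, ud, bound, min_w=1e-3, min_h=1e-3, min_d=1e-3):
    """Map unconstrained outputs (K widths, K heights, K-1 derivatives) to knots."""
    k = uw.shape[0]
    widths = min_w + (1 - min_w * k) * softmax(uw)
    heights = min_h + (1 - min_h * k) * softmax(uh)
    xs = np.concatenate([[0.0], np.cumsum(widths)]) * 2 * bound - bound
    ys = np.concatenate([[0.0], np.cumsum(heights)]) * 2 * bound - bound
    xs[0], xs[-1], ys[0], ys[-1] = -bound, bound, -bound, bound
    # Interior derivatives are floored; boundary derivatives of 1 match identity tails.
    ds = np.concatenate([[1.0], min_d + softplus(ud), [1.0]])
    return xs, ys, ds


def rq_forward(x, xs, ys, ds):
    """Rational-quadratic spline on [-B, B]; identity outside (assumed tails)."""
    if x <= xs[0] or x >= xs[-1]:
        return x, 0.0
    k = int(np.clip(np.searchsorted(xs, x, side="right") - 1, 0, len(xs) - 2))
    w, h = xs[k + 1] - xs[k], ys[k + 1] - ys[k]
    s, d0, d1 = h / w, ds[k], ds[k + 1]
    xi = (x - xs[k]) / w
    denom = s + (d1 + d0 - 2 * s) * xi * (1 - xi)
    y = ys[k] + h * (s * xi**2 + d0 * xi * (1 - xi)) / denom
    deriv = s**2 * (d1 * xi**2 + 2 * s * xi * (1 - xi) + d0 * (1 - xi) ** 2) / denom**2
    return y, np.log(deriv)


def rq_inverse(y, xs, ys, ds):
    """Invert rq_forward by solving the per-bin quadratic for xi."""
    if y <= ys[0] or y >= ys[-1]:
        return y, 0.0
    k = int(np.clip(np.searchsorted(ys, y, side="right") - 1, 0, len(ys) - 2))
    w, h = xs[k + 1] - xs[k], ys[k + 1] - ys[k]
    s, d0, d1 = h / w, ds[k], ds[k + 1]
    t = y - ys[k]
    a = h * (s - d0) + t * (d1 + d0 - 2 * s)
    b = h * d0 - t * (d1 + d0 - 2 * s)
    c = -s * t
    xi = 2 * c / (-b - np.sqrt(b**2 - 4 * a * c))  # numerically stable root
    x = xs[k] + xi * w
    _, fwd_log_det = rq_forward(x, xs, ys, ds)
    return x, -fwd_log_det


rng = np.random.default_rng(1)
num_bins, bound = 8, 10.0
xs, ys, ds = spline_knots(rng.normal(size=num_bins), rng.normal(size=num_bins),
                          rng.normal(size=num_bins - 1), bound)
for x in (-12.0, -3.3, 0.0, 4.7, 9.99):
    y, log_det = rq_forward(x, xs, ys, ds)
    x_back, _ = rq_inverse(y, xs, ys, ds)
    print(f"x={x:6.2f}  y={y:8.4f}  log_det={log_det:+.3f}  |x - x_back|={abs(x - x_back):.1e}")
```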

mask: jax.numpy.ndarray
conditioner: CouplingConditioner
parameter_dim: int
num_bins: int
bound: float
min_bin_width: float
min_bin_height: float
min_derivative: float
max_derivative: float | None
_parameters(masked_x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]
forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

Apply the forward spline transform.

Parameters

x:

Parameter vector in data space.

context:

Observation embedding conditioning the spline parameters.

Returns

tuple[jnp.ndarray, jnp.ndarray]

Transformed vector and summed log-determinant contribution.

inverse(y: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

Apply the inverse spline transform.

Parameters

y:

Latent-space vector.

context:

Observation embedding conditioning the inverse transform.

Returns

tuple[jnp.ndarray, jnp.ndarray]

Recovered parameter vector and inverse log-determinant.

class petitRADTRANS.sbi.flows.ConditionalScalarSplineTransform

Bases: equinox.Module

Context-conditioned scalar spline transform used for 1D posteriors.

Notes

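One-dimensional posterior models cannot use alternating coupling masks, because a single parameter leaves no fixed subset to condition on. They therefore use this dedicated transform, whose spline parameters depend on the context embedding alone.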

conditioner: _ScalarSplineConditioner
num_bins: int
bound: float
min_bin_width: float
min_bin_height: float
min_derivative: float
max_derivative: float | None
_parameters(context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]
static _as_scalar(value: jax.numpy.ndarray) -> jax.numpy.ndarray
forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

Apply the forward scalar spline transform.

Parameters

x:

One-dimensional parameter vector.

context:

Observation embedding conditioning the spline parameters.

Returns

tuple[jnp.ndarray, jnp.ndarray]

Transformed scalar vector and scalar log-determinant.

inverse(y: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

Apply the inverse scalar spline transform.

Parameters

y:

One-dimensional latent vector.

context:

Observation embedding conditioning the inverse transform.

Returns

tuple[jnp.ndarray, jnp.ndarray]

Recovered scalar parameter vector and inverse log-determinant.

class petitRADTRANS.sbi.flows.ConditionalAutoregressiveAffineTransform(parameter_dim: int, context_dim: int, hidden_dim: int, key: jax.Array, depth: int = 5)

Bases: equinox.Module

Context-conditioned affine autoregressive transform.

Notes

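The transform maps data to latent space in a masked-autoregressive manner. For each parameter dimension, the preceding data dimensions and the observation embedding predict one affine transform. The inverse (latent to data) is therefore sequential: each dimension's affine parameters depend on data dimensions that must be recovered first.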
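A minimal NumPy sketch of the masked-autoregressive pattern follows. It assumes the data-to-latent map `z_i = (x_i - shift_i) * exp(-log_scale_i)`, the classic MAF convention. The exact form petitRADTRANS uses is not stated here. `toy_dim_params` is a hypothetical stand-in for the per-dimension ELU conditioners. The forward loop only reads observed data, while the inverse must fill in `x` one entry at a time.

```python
import numpy as np


def toy_dim_params(i, prefix, context):
    """Hypothetical conditioner for dimension i: (shift, log_scale) from x[:i] and context."""
    s = prefix.sum() + 0.1 * (i + 1) * context.sum()
    return 0.3 * s, np.tanh(0.2 * s)


def maf_forward(x, context):
    """Data -> latent. Every prefix x[:i] is observed, so this loop could be vectorized."""
    z = np.empty_like(x)
    log_det = 0.0
    for i in range(x.shape[0]):
        shift, log_scale = toy_dim_params(i, x[:i], context)
        z[i] = (x[i] - shift) * np.exp(-log_scale)
        log_det -= log_scale
    return z, log_det


def maf_inverse(z, context):
    """Latent -> data. Dimension i needs the already-recovered x[:i], so order matters."""
    x = np.zeros_like(z)
    log_det = 0.0
    for i in range(z.shape[0]):
        shift, log_scale = toy_dim_params(i, x[:i], context)
        x[i] = z[i] * np.exp(log_scale) + shift
        log_det += log_scale
    return x, log_det


x = np.array([0.4, -1.3, 2.0])
context = np.array([0.5, 1.0])
z, log_det = maf_forward(x, context)
x_back, inv_log_det = maf_inverse(z, context)
print(np.allclose(x, x_back), np.isclose(log_det, -inv_log_det))  # True True
```

This asymmetry is why masked-autoregressive flows are usually cheap to evaluate as densities and slower to sample from.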

conditioners: tuple[_AutoregressiveDimensionConditioner, ...]
parameter_dim: int
forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
inverse(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]

class petitRADTRANS.sbi.flows.ConditionalNeuralAutoregressiveTransform(parameter_dim: int, context_dim: int, hidden_dim: int, transform_units: int, key: jax.Array, depth: int = 5, min_slope: float = 0.001, min_residual: float = 0.05, inverse_base_bound: float = 4.0, inverse_bisection_steps: int = 48, inverse_expansion_steps: int = 16, inverse_newton_steps: int = 3)

Bases: equinox.Module

Context-conditioned neural autoregressive transform.

Notes

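Each scalar transform is a monotonic, deep-sigmoidal-style network whose parameters are predicted from the preceding dimensions and the observation embedding. Because each scalar map is strictly monotonic, its inverse is unique and can be found numerically. A bracket starting at ±inverse_base_bound is widened up to inverse_expansion_steps times, narrowed with inverse_bisection_steps fixed bisection steps, and then refined with inverse_newton_steps Newton steps.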
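The sketch below illustrates the bracket/bisect/Newton inversion of a strictly monotone scalar map. The functional form is an assumption: a linear residual term plus a softmax-weighted mixture of sigmoids. The argument names echo `positive_slopes`, `offsets`, `weight_logits`, and `residual_fraction` from the private signatures, but the exact formula petitRADTRANS evaluates is not documented here. The bracket-doubling rule and the clipping of Newton steps to the bracket are also assumptions. The linear residual makes the map unbounded, so every real target has a preimage.

```python
import numpy as np


def stable_sigmoid(v):
    # Algebraically equal to 1 / (1 + exp(-v)) but free of overflow warnings.
    return 0.5 * (1.0 + np.tanh(0.5 * v))


def scalar_map(v, slopes, offsets, weight_logits, residual_fraction):
    """Assumed monotone map: linear residual plus a convex mixture of sigmoids."""
    w = np.exp(weight_logits - weight_logits.max())
    w /= w.sum()
    sig = stable_sigmoid(slopes * v + offsets)
    value = residual_fraction * v + (1.0 - residual_fraction) * np.sum(w * sig)
    slope = residual_fraction + (1.0 - residual_fraction) * np.sum(w * slopes * sig * (1.0 - sig))
    return value, slope


def invert_scalar(target, params, base_bound=4.0, expansion_steps=16,
                  bisection_steps=48, newton_steps=3):
    """Expand a bracket, bisect it, then polish with Newton steps clipped to the bracket."""
    def f(v):
        return scalar_map(v, *params)[0]

    lo, hi, n_expand = -base_bound, base_bound, 0
    while (f(lo) > target or f(hi) < target) and n_expand < expansion_steps:
        lo, hi, n_expand = 2.0 * lo, 2.0 * hi, n_expand + 1  # double the bracket
    for _ in range(bisection_steps):
        mid = 0.5 * (lo + hi)
        if f(mid) < target:
            lo = mid
        else:
            hi = mid
    v = 0.5 * (lo + hi)
    for _ in range(newton_steps):
        value, slope = scalar_map(v, *params)
        v = float(np.clip(v - (value - target) / slope, lo, hi))
    return v, f(v) - target, hi - lo, n_expand


params = (np.array([0.5, 2.0, 5.0]),   # positive slopes
          np.array([0.0, -1.0, 3.0]),  # offsets
          np.array([0.2, -0.4, 1.0]),  # mixture weight logits
          0.05)                        # residual fraction
for target in (-10.0, 0.3, 25.0):
    v, residual, width, n_expand = invert_scalar(target, params)
    print(f"target={target:6.2f} v={v:10.5f} residual={residual:.1e} expansions={n_expand}")
```

The returned residual, bracket width, and expansion count correspond in spirit to the quantities summarized by NeuralInverseDiagnostics.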

conditioners: tuple[_NeuralAutoregressiveDimensionConditioner, ...]
parameter_dim: int
transform_units: int
min_slope: float
min_residual: float
inverse_base_bound: float
inverse_bisection_steps: int
inverse_expansion_steps: int
inverse_newton_steps: int
_parameters(conditioner: _NeuralAutoregressiveDimensionConditioner, prefix: jax.numpy.ndarray, context: jax.numpy.ndarray, dtype: jax.numpy.dtype) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]
_evaluate_scalar(value: jax.numpy.ndarray, positive_slopes: jax.numpy.ndarray, offsets: jax.numpy.ndarray, weight_logits: jax.numpy.ndarray, residual_fraction: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
_inverse_scalar(target: jax.numpy.ndarray, positive_slopes: jax.numpy.ndarray, offsets: jax.numpy.ndarray, weight_logits: jax.numpy.ndarray, residual_fraction: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
_inverse_scalar_with_diagnostics(target: jax.numpy.ndarray, positive_slopes: jax.numpy.ndarray, offsets: jax.numpy.ndarray, weight_logits: jax.numpy.ndarray, residual_fraction: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]
forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
inverse(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
inverse_with_diagnostics(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, NeuralInverseDiagnostics]

class petitRADTRANS.sbi.flows.ConditionalAffineFlow(parameter_dim: int, context_dim: int, num_layers: int = 4, hidden_dim: int = 128, conditioner_depth: int = 2, key: jax.Array | None = None)

Bases: _ConditionalFlowStack

Stack of conditional affine couplings with alternating masks.

Parameters

parameter_dim:

Number of inferred parameters transformed by the flow.

context_dim:

Size of the observation embedding conditioning each layer.

num_layers:

Number of affine coupling layers.

hidden_dim:

Hidden width of each conditioner MLP.

conditioner_depth:

Depth of each conditioner MLP.

key:

Optional JAX random key used for layer initialization.

Notes

For one-dimensional problems the mask degenerates to a context-only affine transform. Higher-dimensional problems alternate masks across layers.
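The exact mask pattern is not documented here. A parity-based checkerboard that flips at every layer is one common choice, and the NumPy sketch below assumes it. It uses the convention that `1` marks a fixed (conditioned-on) dimension and `0` a transformed one. `alternating_masks` is a hypothetical helper.

```python
import numpy as np


def alternating_masks(parameter_dim, num_layers):
    """Parity masks: 1 marks a dimension held fixed (conditioned on), 0 a transformed one."""
    base = (np.arange(parameter_dim) % 2).astype(float)
    return [base if layer % 2 == 0 else 1.0 - base for layer in range(num_layers)]


masks = alternating_masks(parameter_dim=5, num_layers=4)
for layer, mask in enumerate(masks):
    print(layer, mask)

# Every dimension is transformed by at least one of two consecutive layers.
print(np.all((masks[0] == 0) | (masks[1] == 0)))
```

With `parameter_dim=1` this scheme yields the masks `[0.]` and `[1.]`. The first leaves the conditioner no fixed input, and the second transforms nothing. This illustrates why the one-dimensional case needs a context-only transform.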

parameter_dim
context_dim
layers = ()

class petitRADTRANS.sbi.flows.ConditionalAutoregressiveFlow(parameter_dim: int, context_dim: int, num_layers: int = 3, hidden_dim: int = 512, conditioner_depth: int = 5, key: jax.Array | None = None)

Bases: _ConditionalFlowStack

Stack of context-conditioned affine autoregressive transforms.

Notes

This flow is a lighter-weight autoregressive alternative to the spline stack. Each transform uses ELU MLPs to predict affine parameters for one dimension at a time, conditioned on previous dimensions and the observation embedding.

parameter_dim
context_dim
layers

class petitRADTRANS.sbi.flows.ConditionalNeuralAutoregressiveFlow(parameter_dim: int, context_dim: int, num_layers: int = 3, hidden_dim: int = 512, conditioner_depth: int = 5, transform_units: int = 16, min_slope: float = 0.001, min_residual: float = 0.05, inverse_bisection_steps: int = 48, inverse_newton_steps: int = 3, key: jax.Array | None = None)

Bases: _ConditionalFlowStack

Stack of context-conditioned neural autoregressive transforms.

Notes

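This backend uses a deep-sigmoidal-style monotonic scalar bijector for each parameter dimension. It is more expressive than the affine autoregressive stack and remains invertible. The inverse has no closed form, however, and is computed numerically by per-dimension bisection followed by inverse_newton_steps Newton refinement steps.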

transform_units: int
min_slope: float
min_residual: float
inverse_bisection_steps: int
inverse_newton_steps: int
parameter_dim
context_dim
layers
inverse_with_diagnostics(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, NeuralInverseDiagnostics]

class petitRADTRANS.sbi.flows.ConditionalSplineFlow(parameter_dim: int, context_dim: int, num_layers: int = 6, hidden_dim: int = 128, conditioner_depth: int = 2, num_bins: int = 8, bound: float = 10.0, min_bin_width: float = 0.001, min_bin_height: float = 0.001, min_derivative: float = 0.001, max_derivative: float | None = None, base_distribution: str = 'gaussian', use_base_affine: bool = False, key: jax.Array | None = None)

Bases: _ConditionalFlowStack

Stack of conditional rational-quadratic spline transforms.

Parameters

parameter_dim:

Number of inferred parameters transformed by the flow.

context_dim:

Size of the conditioning observation embedding.

num_layers:

Number of spline transform layers.

hidden_dim:

Hidden width of the conditioner networks.

conditioner_depth:

Depth of the conditioner networks.

num_bins:

Number of rational-quadratic bins used per transformed dimension.

bound:

Finite support bound of the spline transform in latent coordinates.

min_bin_width, min_bin_height, min_derivative:

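Numerical-stability floors applied to the spline bin widths, bin heights, and knot derivatives.

max_derivative:

Optional upper bound on the knot derivatives; None leaves them unbounded above.

base_distribution:

Name of the base distribution (default 'gaussian').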

key:

Optional JAX random key used for initialization.

Notes

One-dimensional posteriors use a dedicated scalar spline stack, while higher-dimensional models use alternating masked coupling layers.
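Each stack scores a parameter vector with the change-of-variables formula: log p(x | c) = log N(z; 0, I) + sum over layers of the per-layer log-determinants, where z is x pushed through every layer. The NumPy sketch below makes three assumptions. First, `forward` maps data to latent space, as the forward docstrings above describe. Second, layer log-determinants are summed. Third, `base_distribution='gaussian'` means a standard normal. `make_toy_layer` is a hypothetical context-conditioned affine layer standing in for the real spline or coupling layers.

```python
import numpy as np


def standard_normal_log_prob(z):
    """Log-density of a standard multivariate normal (assumed 'gaussian' base)."""
    return -0.5 * np.sum(z**2) - 0.5 * z.shape[0] * np.log(2.0 * np.pi)


def make_toy_layer(seed):
    """Hypothetical context-conditioned elementwise affine layer (data -> latent)."""
    a, b = np.random.default_rng(seed).normal(size=2)

    def forward(x, context):
        log_scale = np.tanh(a * context.mean()) * np.ones_like(x)
        y = x * np.exp(log_scale) + b * context.sum()
        return y, np.sum(log_scale)

    return forward


def stack_log_prob(layers, x, context):
    """Change of variables: base log-density at z plus the summed log-determinants."""
    total_log_det = 0.0
    for layer in layers:
        x, log_det = layer(x, context)
        total_log_det += log_det
    return standard_normal_log_prob(x) + total_log_det


layers = [make_toy_layer(0), make_toy_layer(1)]
context = np.array([0.5, -1.0, 2.0])
print(stack_log_prob(layers, np.array([0.1, -0.4]), context))
```

Because the log-determinants are accumulated, the resulting density stays normalized for any fixed context. The test below checks this numerically in one dimension.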

num_bins: int
bound: float
min_bin_width: float
min_bin_height: float
min_derivative: float
max_derivative: float | None
parameter_dim
context_dim
base_distribution = ''
layers = ()