petitRADTRANS.sbi.flows
Conditional flow components for amortized SBI posteriors.
Classes
- NeuralInverseDiagnostics: Aggregate diagnostics for one neural autoregressive inverse pass.
- CouplingConditioner: Context-conditioned network producing coupling parameters.
- ConditionalAffineCoupling: Affine coupling layer conditioned on an observation embedding.
- ConditionalRationalQuadraticSplineCoupling: Rational-quadratic spline coupling layer conditioned on an embedding.
- ConditionalScalarSplineTransform: Context-conditioned scalar spline transform used for 1D posteriors.
- ConditionalAutoregressiveAffineTransform: Context-conditioned affine autoregressive transform.
- ConditionalNeuralAutoregressiveTransform: Context-conditioned neural autoregressive transform.
- ConditionalAffineFlow: Stack of conditional affine couplings with alternating masks.
- ConditionalAutoregressiveFlow: Stack of context-conditioned affine autoregressive transforms.
- ConditionalNeuralAutoregressiveFlow: Stack of context-conditioned neural autoregressive transforms.
- ConditionalSplineFlow: Stack of conditional rational-quadratic spline transforms.
Module Contents
- class petitRADTRANS.sbi.flows.NeuralInverseDiagnostics
Bases: NamedTuple
Aggregate diagnostics for one neural autoregressive inverse pass.
- max_abs_residual: jax.numpy.ndarray
- mean_abs_residual: jax.numpy.ndarray
- max_bracket_width: jax.numpy.ndarray
- mean_bracket_width: jax.numpy.ndarray
- max_expansion_steps: jax.numpy.ndarray
- mean_expansion_steps: jax.numpy.ndarray
- converged_fraction: jax.numpy.ndarray
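The field names suggest that per-element root-finding statistics are reduced to scalar maxima, means, and a converged fraction. The sketch below shows one way such a reduction could look. It uses NumPy rather than JAX. `InverseDiagnosticsSketch`, `aggregate_inverse_diagnostics`, and the convergence tolerance `tol` are hypothetical names and are not part of this module.

```python
from typing import NamedTuple

import numpy as np


class InverseDiagnosticsSketch(NamedTuple):
    """Hypothetical NumPy mirror of the documented NeuralInverseDiagnostics fields."""

    max_abs_residual: np.ndarray
    mean_abs_residual: np.ndarray
    max_bracket_width: np.ndarray
    mean_bracket_width: np.ndarray
    max_expansion_steps: np.ndarray
    mean_expansion_steps: np.ndarray
    converged_fraction: np.ndarray


def aggregate_inverse_diagnostics(residuals, bracket_widths, expansion_steps, tol=1e-6):
    """Reduce per-element root-finding statistics to scalar summaries.

    `tol` is a hypothetical convergence threshold on |f(x) - target|.
    """
    abs_res = np.abs(np.asarray(residuals, dtype=float))
    widths = np.asarray(bracket_widths, dtype=float)
    steps = np.asarray(expansion_steps, dtype=float)
    return InverseDiagnosticsSketch(
        max_abs_residual=abs_res.max(),
        mean_abs_residual=abs_res.mean(),
        max_bracket_width=widths.max(),
        mean_bracket_width=widths.mean(),
        max_expansion_steps=steps.max(),
        mean_expansion_steps=steps.mean(),
        converged_fraction=np.mean(abs_res <= tol),  # share of elements within tol
    )


diag = aggregate_inverse_diagnostics(
    residuals=[1e-8, -2e-7, 3e-3],
    bracket_widths=[1e-9, 2e-9, 4e-9],
    expansion_steps=[0, 1, 3],
)
```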
- class petitRADTRANS.sbi.flows.CouplingConditioner(input_dim: int, context_dim: int, output_dim: int, hidden_dim: int, key: jax.Array, depth: int = 2)
Bases: equinox.Module
Context-conditioned network producing coupling parameters.
Parameters
- input_dim:
Size of the masked parameter vector presented to the conditioner.
- context_dim:
Size of the observation embedding appended to the masked input.
- output_dim:
Number of coupling parameters produced by the network.
- hidden_dim:
Hidden width of the internal MLP.
- key:
JAX random key used to initialize network weights.
- depth:
Depth of the internal MLP.
Notes
The conditioner concatenates the masked parameter values with the observation context and passes the result through a small MLP.
- network: equinox.nn.MLP
- input_dim: int
- context_dim: int
- output_dim: int
- __call__(masked_x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> jax.numpy.ndarray
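The following minimal NumPy sketch shows the conditioner's data flow. It assumes a plain fully connected MLP with `depth` hidden layers and a tanh activation, since the module's activation is not documented here. `init_mlp` and `conditioner_apply` are hypothetical helpers. The real class wraps an `equinox.nn.MLP`.

```python
import numpy as np


def init_mlp(in_size, out_size, hidden_dim, depth, rng):
    """Random weights for an MLP with `depth` hidden layers (hypothetical initializer)."""
    sizes = [in_size] + [hidden_dim] * depth + [out_size]
    return [
        (rng.normal(scale=1.0 / np.sqrt(m), size=(m, n)), np.zeros(n))
        for m, n in zip(sizes[:-1], sizes[1:])
    ]


def conditioner_apply(params, masked_x, context):
    """Concatenate the masked parameters with the context, then run the MLP."""
    h = np.concatenate([masked_x, context])
    for w, b in params[:-1]:
        h = np.tanh(h @ w + b)  # hidden activation is an assumption
    w, b = params[-1]
    return h @ w + b  # linear output layer: raw coupling parameters


rng = np.random.default_rng(0)
input_dim, context_dim, output_dim = 4, 8, 6
params = init_mlp(input_dim + context_dim, output_dim, hidden_dim=32, depth=2, rng=rng)

x = rng.normal(size=input_dim)
mask = np.array([1.0, 0.0, 1.0, 0.0])
out = conditioner_apply(params, mask * x, rng.normal(size=context_dim))
```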
- class petitRADTRANS.sbi.flows.ConditionalAffineCoupling
Bases: equinox.Module
Affine coupling layer conditioned on an observation embedding.
Notes
The mask selects which parameter dimensions are held fixed and which are transformed. The conditioner predicts the shift and log-scale terms from the masked input and the context embedding only. Because the masked entries pass through unchanged, the inverse can recompute the same shift and log-scale.
- mask: jax.numpy.ndarray
- conditioner: CouplingConditioner
- _parameters(masked_x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
- forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
Apply the forward affine transform.
Parameters
- x:
Parameter vector in data space.
- context:
Observation embedding conditioning the transform.
Returns
- tuple[jnp.ndarray, jnp.ndarray]
Transformed vector and scalar log-determinant contribution.
- inverse(y: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
Apply the inverse affine transform.
Parameters
- y:
Latent-space vector.
- context:
Observation embedding conditioning the inverse transform.
Returns
- tuple[jnp.ndarray, jnp.ndarray]
Recovered data-space vector and inverse log-determinant.
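A small NumPy sketch illustrates the coupling logic. It assumes the common convention y = mask * x + (1 - mask) * (x * exp(s) + t), where t is the shift and s is the log-scale. `toy_conditioner` is a hypothetical stand-in for `CouplingConditioner`. This document does not say whether the module bounds or rescales the log-scale.

```python
import numpy as np


def toy_conditioner(masked_x, context):
    """Hypothetical stand-in for CouplingConditioner: returns (shift, log_scale)."""
    feature = np.sum(masked_x) + np.sum(context)
    shift = 0.5 * np.tanh(feature) * np.ones_like(masked_x)
    log_scale = 0.3 * np.sin(feature) * np.ones_like(masked_x)
    return shift, log_scale


def coupling_forward(x, context, mask):
    """Data -> latent: masked entries pass through, the rest are scaled and shifted."""
    shift, log_scale = toy_conditioner(mask * x, context)
    y = mask * x + (1.0 - mask) * (x * np.exp(log_scale) + shift)
    log_det = np.sum((1.0 - mask) * log_scale)  # only transformed entries contribute
    return y, log_det


def coupling_inverse(y, context, mask):
    """Latent -> data: mask * y equals mask * x, so the conditioner sees the same input."""
    shift, log_scale = toy_conditioner(mask * y, context)
    x = mask * y + (1.0 - mask) * (y - shift) * np.exp(-log_scale)
    return x, -np.sum((1.0 - mask) * log_scale)


mask = np.array([1.0, 0.0, 1.0])
context = np.array([0.1, -0.4])
x = np.array([0.3, -1.2, 2.0])
y, log_det = coupling_forward(x, context, mask)
x_rec, inv_log_det = coupling_inverse(y, context, mask)
```

The inverse costs the same as the forward pass, because both need only one conditioner evaluation on the unchanged entries.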
- class petitRADTRANS.sbi.flows.ConditionalRationalQuadraticSplineCoupling
Bases: equinox.Module
Rational-quadratic spline coupling layer conditioned on an embedding.
Notes
This is the main expressive transform used by spline-based SBI posteriors. The conditioner predicts per-dimension spline widths, heights, and derivatives for the unmasked subset of the parameter vector.
- mask: jax.numpy.ndarray
- conditioner: CouplingConditioner
- parameter_dim: int
- num_bins: int
- bound: float
- min_bin_width: float
- min_bin_height: float
- min_derivative: float
- max_derivative: float | None
- _parameters(masked_x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]
- forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
Apply the forward spline transform.
Parameters
- x:
Parameter vector in data space.
- context:
Observation embedding conditioning the spline parameters.
Returns
- tuple[jnp.ndarray, jnp.ndarray]
Transformed vector and summed log-determinant contribution.
- inverse(y: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
Apply the inverse spline transform.
Parameters
- y:
Latent-space vector.
- context:
Observation embedding conditioning the inverse transform.
Returns
- tuple[jnp.ndarray, jnp.ndarray]
Recovered parameter vector and inverse log-determinant.
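For each transformed dimension, a rational-quadratic spline is a monotonic piecewise map defined by knot positions and knot derivatives. The sketch below applies the scalar forward and inverse formulas of the rational-quadratic spline from Durkan et al.'s neural spline flows to a single value. It takes already-valid knots rather than raw conditioner outputs, and it assumes an identity map outside [-bound, bound]. `rq_spline` is a hypothetical helper, not this module's API.

```python
import numpy as np


def rq_spline(x, knots_x, knots_y, derivs, inverse=False):
    """Monotonic rational-quadratic spline for one scalar.

    knots_x, knots_y: increasing knots (K + 1 each) sharing the endpoints -B and B.
    derivs: positive derivatives at the K + 1 knots.
    Outside [-B, B] the map is the identity (an assumed tail convention).
    Returns (output, log|d output / d input|).
    """
    lo, hi = knots_x[0], knots_x[-1]
    if not lo <= x <= hi:
        return x, 0.0
    edges = knots_y if inverse else knots_x
    k = int(np.clip(np.searchsorted(edges, x, side="right") - 1, 0, len(edges) - 2))
    xk, yk = knots_x[k], knots_y[k]
    w, h = knots_x[k + 1] - xk, knots_y[k + 1] - yk
    s = h / w  # bin slope
    d0, d1 = derivs[k], derivs[k + 1]

    if inverse:
        # Solve a*xi^2 + b*xi + c = 0 for the bin-relative coordinate xi.
        dy = x - yk
        a = dy * (d1 + d0 - 2 * s) + h * (s - d0)
        b = h * d0 - dy * (d1 + d0 - 2 * s)
        c = -s * dy
        xi = 2 * c / (-b - np.sqrt(b * b - 4 * a * c))  # cancellation-free root
    else:
        xi = (x - xk) / w

    t = xi * (1 - xi)
    denom = s + (d1 + d0 - 2 * s) * t
    deriv = s**2 * (d1 * xi**2 + 2 * s * t + d0 * (1 - xi) ** 2) / denom**2
    if inverse:
        return xk + xi * w, -np.log(deriv)
    y = yk + h * (s * xi**2 + d0 * t) / denom
    return y, np.log(deriv)


knots_x = np.array([-3.0, -1.5, 0.2, 1.0, 3.0])
knots_y = np.array([-3.0, -2.0, -0.5, 1.5, 3.0])
derivs = np.array([1.0, 0.5, 2.0, 1.3, 1.0])
y, log_det = rq_spline(0.6, knots_x, knots_y, derivs)
x_back, inv_log_det = rq_spline(y, knots_x, knots_y, derivs, inverse=True)
```

The inverse is closed-form because each bin's rational-quadratic map reduces to a quadratic in the bin-relative coordinate.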
- class petitRADTRANS.sbi.flows.ConditionalScalarSplineTransform
Bases: equinox.Module
Context-conditioned scalar spline transform used for 1D posteriors.
Notes
With a single parameter, an alternating coupling mask either leaves that parameter untouched or hides it from the conditioner entirely. One-dimensional posterior models therefore use this dedicated transform, whose spline parameters are predicted from the context embedding alone.
- conditioner: _ScalarSplineConditioner
- num_bins: int
- bound: float
- min_bin_width: float
- min_bin_height: float
- min_derivative: float
- max_derivative: float | None
- _parameters(context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]
- static _as_scalar(value: jax.numpy.ndarray) -> jax.numpy.ndarray
- forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
Apply the forward scalar spline transform.
Parameters
- x:
One-dimensional parameter vector.
- context:
Observation embedding conditioning the spline parameters.
Returns
- tuple[jnp.ndarray, jnp.ndarray]
Transformed scalar vector and scalar log-determinant.
- inverse(y: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
Apply the inverse scalar spline transform.
Parameters
- y:
One-dimensional latent vector.
- context:
Observation embedding conditioning the inverse transform.
Returns
- tuple[jnp.ndarray, jnp.ndarray]
Recovered scalar parameter vector and inverse log-determinant.
- class petitRADTRANS.sbi.flows.ConditionalAutoregressiveAffineTransform(parameter_dim: int, context_dim: int, hidden_dim: int, key: jax.Array, depth: int = 5)
Bases: equinox.Module
Context-conditioned affine autoregressive transform.
Notes
The transform maps data to latent space in a masked-autoregressive manner. It uses the previous data dimensions and the observation embedding to predict one affine transform per parameter dimension.
- conditioners: tuple[_AutoregressiveDimensionConditioner, ...]
- parameter_dim: int
- forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
- inverse(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
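The autoregressive structure described in the notes can be sketched in NumPy. `dim_conditioner` is a hypothetical stand-in for the module's per-dimension MLPs, and the convention z_i = x_i * exp(s_i) + t_i is an assumption. The sketch shows why the data-to-latent direction needs only known data, while the inverse has to rebuild x one dimension at a time.

```python
import numpy as np


def dim_conditioner(i, prefix, context):
    """Hypothetical stand-in for the i-th per-dimension conditioner.

    Sees only x[:i] (the prefix) and the context; returns (shift, log_scale).
    """
    feature = np.sum(prefix) + (i + 1) * np.sum(context)
    return 0.4 * np.tanh(feature), 0.2 * np.cos(feature)


def ar_forward(x, context):
    """Data -> latent. Every shift/scale depends only on known data, so the
    per-dimension calls are independent (they could run in parallel)."""
    z = np.empty_like(x)
    log_det = 0.0
    for i in range(x.shape[0]):
        shift, log_scale = dim_conditioner(i, x[:i], context)
        z[i] = x[i] * np.exp(log_scale) + shift
        log_det += log_scale
    return z, log_det


def ar_inverse(z, context):
    """Latent -> data. Dimension i needs the recovered x[:i], so this pass is sequential."""
    x = np.zeros_like(z)
    log_det = 0.0
    for i in range(z.shape[0]):
        shift, log_scale = dim_conditioner(i, x[:i], context)
        x[i] = (z[i] - shift) * np.exp(-log_scale)
        log_det -= log_scale
    return x, log_det


x = np.array([0.7, -1.1, 0.4, 2.3])
context = np.array([0.2, -0.5, 0.9])
z, log_det = ar_forward(x, context)
x_rec, inv_log_det = ar_inverse(z, context)
```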
- class petitRADTRANS.sbi.flows.ConditionalNeuralAutoregressiveTransform(parameter_dim: int, context_dim: int, hidden_dim: int, transform_units: int, key: jax.Array, depth: int = 5, min_slope: float = 0.001, min_residual: float = 0.05, inverse_base_bound: float = 4.0, inverse_bisection_steps: int = 48, inverse_expansion_steps: int = 16, inverse_newton_steps: int = 3)
Bases: equinox.Module
Context-conditioned neural autoregressive transform.
Notes
Each scalar transform uses a monotonic deep-sigmoidal-style network whose parameters are conditioned on the previous dimensions and the observation embedding. Because each scalar map is strictly monotonic, it can be inverted by root finding. A fixed number of bisection steps (`inverse_bisection_steps`) narrows a bracket governed by `inverse_base_bound` and `inverse_expansion_steps`, and `inverse_newton_steps` Newton iterations then refine the result.
- conditioners: tuple[_NeuralAutoregressiveDimensionConditioner, ...]
- parameter_dim: int
- transform_units: int
- min_slope: float
- min_residual: float
- inverse_base_bound: float
- inverse_bisection_steps: int
- inverse_expansion_steps: int
- inverse_newton_steps: int
- _parameters(conditioner: _NeuralAutoregressiveDimensionConditioner, prefix: jax.numpy.ndarray, context: jax.numpy.ndarray, dtype: jax.numpy.dtype) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]
- _evaluate_scalar(value: jax.numpy.ndarray, positive_slopes: jax.numpy.ndarray, offsets: jax.numpy.ndarray, weight_logits: jax.numpy.ndarray, residual_fraction: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
- _inverse_scalar(target: jax.numpy.ndarray, positive_slopes: jax.numpy.ndarray, offsets: jax.numpy.ndarray, weight_logits: jax.numpy.ndarray, residual_fraction: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
- _inverse_scalar_with_diagnostics(target: jax.numpy.ndarray, positive_slopes: jax.numpy.ndarray, offsets: jax.numpy.ndarray, weight_logits: jax.numpy.ndarray, residual_fraction: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray, jax.numpy.ndarray]
- forward(x: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
- inverse(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]
- inverse_with_diagnostics(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, NeuralInverseDiagnostics]
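The NumPy sketch below illustrates two ideas from this class: a strictly monotonic scalar map, and its inversion by bracketing, bisection, and Newton refinement. The map f(x) = r * x + (1 - r) * sum_k w_k * tanh(a_k * x + b_k) is a hypothetical deep-sigmoidal-style form chosen so that f'(x) >= r > 0. It is not the module's documented parameterization, and the bracket-doubling rule is an assumption. `make_monotone_map` and `invert_monotone` are hypothetical helpers. Their returned residual, bracket width, and expansion count roughly correspond to the quantities that `NeuralInverseDiagnostics` summarizes.

```python
import numpy as np


def make_monotone_map(raw_slopes, offsets, weight_logits, residual_logit,
                      min_slope=1e-3, min_residual=0.05):
    """Hypothetical scalar bijector f(x) = r*x + (1-r)*sum_k w_k*tanh(a_k*x + b_k).

    a_k > min_slope, softmax weights w_k and r >= min_residual give f'(x) >= r > 0.
    """
    a = min_slope + np.logaddexp(0.0, raw_slopes)  # softplus keeps slopes positive
    w = np.exp(weight_logits - np.max(weight_logits))
    w /= w.sum()
    r = min_residual + (1 - min_residual) / (1 + np.exp(-residual_logit))

    def f(x):
        return r * x + (1 - r) * np.sum(w * np.tanh(a * x + offsets))

    def df(x):
        return r + (1 - r) * np.sum(w * a * (1.0 - np.tanh(a * x + offsets) ** 2))

    return f, df


def invert_monotone(f, df, target, base_bound=4.0, expansion_steps=16,
                    bisection_steps=48, newton_steps=3):
    """Bracket, bisect, then polish with Newton steps kept inside the bracket."""
    lo, hi = -base_bound, base_bound
    n_expand = 0
    # Grow the bracket until it contains the root; f is increasing, so this terminates.
    while (f(lo) > target or f(hi) < target) and n_expand < expansion_steps:
        lo, hi = 2 * lo, 2 * hi
        n_expand += 1
    for _ in range(bisection_steps):  # invariant: f(lo) <= target <= f(hi)
        mid = 0.5 * (lo + hi)
        if f(mid) < target:
            lo = mid
        else:
            hi = mid
    x = 0.5 * (lo + hi)
    for _ in range(newton_steps):
        x = np.clip(x - (f(x) - target) / df(x), lo, hi)
    return x, f(x) - target, hi - lo, n_expand


f, df = make_monotone_map(
    raw_slopes=np.array([0.5, -1.0, 2.0]),
    offsets=np.array([0.0, 1.0, -2.0]),
    weight_logits=np.array([0.1, 0.3, -0.2]),
    residual_logit=0.0,
)
x_hat, residual, width, n_expand = invert_monotone(f, df, f(2.2))
```

The linear residual term keeps the slope bounded away from zero, so f is unbounded in both directions and a bracket always exists after enough expansions.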
- class petitRADTRANS.sbi.flows.ConditionalAffineFlow(parameter_dim: int, context_dim: int, num_layers: int = 4, hidden_dim: int = 128, conditioner_depth: int = 2, key: jax.Array | None = None)
Bases: _ConditionalFlowStack
Stack of conditional affine couplings with alternating masks.
Parameters
- parameter_dim:
Number of inferred parameters transformed by the flow.
- context_dim:
Size of the observation embedding conditioning each layer.
- num_layers:
Number of affine coupling layers.
- hidden_dim:
Hidden width of each conditioner MLP.
- conditioner_depth:
Depth of each conditioner MLP.
- key:
Optional JAX random key used for layer initialization.
Notes
For one-dimensional problems the mask degenerates to a context-only affine transform. Higher-dimensional problems alternate masks across layers so that every dimension is transformed by some layer.
- parameter_dim
- context_dim
- layers = ()
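The NumPy sketch below shows one plausible way to build alternating masks. The even/odd pattern and the helper `alternating_masks` are assumptions, not the module's documented scheme. With `parameter_dim=1` the two masks are [1] and [0]. The first leaves the only dimension untouched, and the second gives the conditioner nothing but the context, which is the degenerate case described in the notes.

```python
import numpy as np


def alternating_masks(parameter_dim, num_layers):
    """Return one binary mask per layer (1 = held fixed and fed to the conditioner).

    Hypothetical even/odd pattern that swaps roles from one layer to the next.
    """
    base = (np.arange(parameter_dim) % 2 == 0).astype(float)  # [1, 0, 1, 0, ...]
    return [base if layer % 2 == 0 else 1.0 - base for layer in range(num_layers)]


masks = alternating_masks(parameter_dim=3, num_layers=4)
# Every dimension is transformed (mask == 0) by at least one layer.
transformed_somewhere = np.any([m == 0.0 for m in masks], axis=0)

masks_1d = alternating_masks(parameter_dim=1, num_layers=2)
```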
- class petitRADTRANS.sbi.flows.ConditionalAutoregressiveFlow(parameter_dim: int, context_dim: int, num_layers: int = 3, hidden_dim: int = 512, conditioner_depth: int = 5, key: jax.Array | None = None)
Bases: _ConditionalFlowStack
Stack of context-conditioned affine autoregressive transforms.
Notes
This flow is a lighter-weight autoregressive alternative to the spline stack. Each transform uses ELU MLPs to predict affine parameters for one dimension at a time, conditioned on the previous dimensions and the observation embedding.
- parameter_dim
- context_dim
- layers
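All stacks share the usual change-of-variables bookkeeping: the data-to-latent pass composes the layers and sums their log-determinants. The NumPy sketch below assumes a standard Gaussian base distribution and represents each layer as a hypothetical `(forward, inverse)` pair of callables. `stack_forward`, `conditional_log_prob`, and `make_shift_scale_layer` are illustrative names only.

```python
import numpy as np


def stack_forward(layers, x, context):
    """Data -> latent through every layer, summing log|det J| contributions."""
    total = 0.0
    for forward, _ in layers:
        x, log_det = forward(x, context)
        total += log_det
    return x, total


def conditional_log_prob(layers, x, context):
    """log p(x | context) = log N(z; 0, I) + sum of forward log-determinants."""
    z, log_det = stack_forward(layers, x, context)
    log_base = -0.5 * np.sum(z**2) - 0.5 * z.size * np.log(2.0 * np.pi)
    return log_base + log_det


def make_shift_scale_layer(scale):
    """Hypothetical toy layer: z = (x - context) / scale."""

    def forward(x, context):
        return (x - context) / scale, -x.size * np.log(scale)

    def inverse(z, context):
        return z * scale + context, z.size * np.log(scale)

    return forward, inverse


layers = [make_shift_scale_layer(2.0)]
x = np.array([1.0, -0.5])
context = np.array([0.2, 0.1])
log_p = conditional_log_prob(layers, x, context)
```

With this single toy layer, `log_p` equals the log density of a Gaussian with mean `context` and standard deviation 2 in each dimension. That gives a quick sanity check of the log-determinant sign.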
- class petitRADTRANS.sbi.flows.ConditionalNeuralAutoregressiveFlow(parameter_dim: int, context_dim: int, num_layers: int = 3, hidden_dim: int = 512, conditioner_depth: int = 5, transform_units: int = 16, min_slope: float = 0.001, min_residual: float = 0.05, inverse_bisection_steps: int = 48, inverse_newton_steps: int = 3, key: jax.Array | None = None)
Bases: _ConditionalFlowStack
Stack of context-conditioned neural autoregressive transforms.
Notes
This backend uses a deep-sigmoidal-style monotonic scalar bijector per parameter dimension. It is more expressive than the affine autoregressive stack and remains invertible. Because each scalar map is strictly monotonic, its inverse is computed numerically by bisection with Newton refinement rather than in closed form.
- transform_units: int
- min_slope: float
- min_residual: float
- inverse_bisection_steps: int
- inverse_newton_steps: int
- parameter_dim
- context_dim
- layers
- inverse_with_diagnostics(z: jax.numpy.ndarray, context: jax.numpy.ndarray) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray, NeuralInverseDiagnostics]
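Sampling runs the stack in the other direction: draw latent vectors from the base distribution, then apply the layer inverses in reverse order. For this backend every inverse call involves per-dimension numerical root finding, so sampling costs more than density evaluation. The NumPy sketch below assumes a standard Gaussian base and uses hypothetical `(forward, inverse)` pairs. `stack_inverse` and `sample` are illustrative names.

```python
import numpy as np


def stack_inverse(layers, z, context):
    """Latent -> data: undo the layers in reverse order, summing inverse log-dets."""
    total = 0.0
    for _, inverse in reversed(layers):
        z, log_det = inverse(z, context)
        total += log_det
    return z, total


def sample(layers, context, num_samples, dim, rng):
    """Draw z ~ N(0, I) and map each draw back to parameter space."""
    draws = rng.standard_normal((num_samples, dim))
    return np.stack([stack_inverse(layers, z, context)[0] for z in draws])


# Two toy layers listed in the forward (data -> latent) order: shift, then scale.
shift_layer = (lambda x, c: (x + c, 0.0), lambda z, c: (z - c, 0.0))
scale_layer = (
    lambda x, c: (2.0 * x, x.size * np.log(2.0)),
    lambda z, c: (z / 2.0, -z.size * np.log(2.0)),
)
layers = [shift_layer, scale_layer]

x, inv_log_det = stack_inverse(layers, np.array([2.0]), np.array([1.0]))
samples = sample(layers, np.array([1.0]), num_samples=5, dim=1, rng=np.random.default_rng(0))
```

Because the two toy layers do not commute, undoing them in the wrong order would give 0.5 instead of 0.0 for `x`.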
- class petitRADTRANS.sbi.flows.ConditionalSplineFlow(parameter_dim: int, context_dim: int, num_layers: int = 6, hidden_dim: int = 128, conditioner_depth: int = 2, num_bins: int = 8, bound: float = 10.0, min_bin_width: float = 0.001, min_bin_height: float = 0.001, min_derivative: float = 0.001, max_derivative: float | None = None, base_distribution: str = 'gaussian', use_base_affine: bool = False, key: jax.Array | None = None)
Bases: _ConditionalFlowStack
Stack of conditional rational-quadratic spline transforms.
Parameters
- parameter_dim:
Number of inferred parameters transformed by the flow.
- context_dim:
Size of the conditioning observation embedding.
- num_layers:
Number of spline transform layers.
- hidden_dim:
Hidden width of the conditioner networks.
- conditioner_depth:
Depth of each conditioner network.
- num_bins:
Number of rational-quadratic bins used per transformed dimension.
- bound:
Finite support bound of the spline transform in latent coordinates.
- min_bin_width, min_bin_height, min_derivative:
Numerical-stability floors applied to spline parameters.
- key:
Optional JAX random key used for initialization.
Notes
One-dimensional posteriors use a dedicated scalar spline stack, while higher-dimensional models use alternating masked coupling layers.
- num_bins: int
- bound: float
- min_bin_width: float
- min_bin_height: float
- min_derivative: float
- max_derivative: float | None
- parameter_dim
- context_dim
- base_distribution = ''
- layers = ()
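Floors on bin width, bin height, and derivative are commonly enforced with the neural-spline-flow construction sketched below in NumPy. Whether this module normalizes its conditioner outputs the same way is an assumption. `spline_knots` is a hypothetical helper, and its boundary derivatives are fixed to 1 to match identity tails. `max_derivative`, `base_distribution`, and `use_base_affine` are not modeled.

```python
import numpy as np


def softmax(v):
    e = np.exp(v - np.max(v))
    return e / e.sum()


def spline_knots(raw_widths, raw_heights, raw_derivs, bound=10.0,
                 min_bin_width=1e-3, min_bin_height=1e-3, min_derivative=1e-3):
    """Turn unconstrained conditioner outputs into valid spline knots.

    raw_widths, raw_heights: K values each; raw_derivs: K - 1 interior values.
    Boundary derivatives are fixed to 1 (an assumption matching identity tails).
    """
    k = len(raw_widths)
    # Fractions sum to 1 and each is at least the floor.
    widths = min_bin_width + (1 - min_bin_width * k) * softmax(raw_widths)
    heights = min_bin_height + (1 - min_bin_height * k) * softmax(raw_heights)
    knots_x = np.concatenate([[0.0], np.cumsum(widths)]) * 2 * bound - bound
    knots_y = np.concatenate([[0.0], np.cumsum(heights)]) * 2 * bound - bound
    interior = min_derivative + np.logaddexp(0.0, raw_derivs)  # softplus plus floor
    derivs = np.concatenate([[1.0], interior, [1.0]])
    return knots_x, knots_y, derivs


# Extreme logits: one bin would take everything without the floor.
knots_x, knots_y, derivs = spline_knots(
    np.array([50.0] + [-50.0] * 7), np.zeros(8), np.zeros(7)
)
```

Each bin keeps at least `min_bin_width * 2 * bound` of the input range, so no bin collapses to zero width even when the conditioner outputs extreme logits.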