petitRADTRANS.sbi.flow_posterior
Conditional normalizing-flow posterior estimator.
Classes
- ConditionalFlowPosterior: Concrete amortized posterior using a conditional flow backend.
- ConditionalSplineFlowPosterior: Posterior estimator specialized to the spline flow backend.
- ConditionalAutoregressiveFlowPosterior: Posterior estimator specialized to the autoregressive flow backend.
- ConditionalNeuralAutoregressiveFlowPosterior: Posterior estimator specialized to the neural autoregressive backend.
Module Contents
- class petitRADTRANS.sbi.flow_posterior.ConditionalFlowPosterior(parameter_dim: int, embedding_dim: int = 128, num_coupling_layers: int = 4, hidden_dim: int = 128, conditioner_depth: int = 2, autoregressive_transform_units: int = 16, neural_autoregressive_min_slope: float = 0.001, neural_autoregressive_min_residual: float = 0.05, neural_autoregressive_inverse_bisection_steps: int = 48, learning_rate: float = 0.001, batch_size: int = 32, num_epochs: int = 5, parameter_space: str = 'unconstrained', flow_family: str = 'spline', num_spline_bins: int = 8, spline_bound: float = 10.0, base_distribution: str = 'gaussian', use_base_affine: bool = False, training_objective: str = 'npe', early_stopping_patience: int | None = None, early_stopping_min_delta: float = 0.0, checkpoint_directory: str | None = None, checkpoint_backend: str = 'auto', resume_from_checkpoint: bool = False, gradient_clip_norm: float | None = 1.0, embedding_noise_std: float = 0.0, embedding_noise_min_scale: float = 0.0, aux_scale_loss_weight: float = 0.0, aux_parameter_loss_weight: float = 0.0, parameter_noise_floor: float = 0.0, spline_small_bin_regularization_weight: float | None = None, spline_min_bin_ratio_target: float = 0.2, spline_entropy_regularization_weight: float = 0.0, spline_derivative_regularization_weight: float = 0.0, spline_entropy_floor: float = 0.6, spline_min_derivative_target: float = 0.25, spline_max_derivative_target: float = 5.0, weight_decay: float = 0.0, use_cosine_schedule: bool = False, warmup_fraction: float = 0.02, warmup_epochs: float | None = None, min_learning_rate: float = 1e-06, lr_schedule_total_epochs: float | None = None, stable_inverse_forward_max_abs_error_threshold: float | None = 0.0001, stable_inverse_forward_logdet_closure_max_abs_error_threshold: float | None = 0.0001, stable_cube_edge_hit_rate_threshold: float | None = 0.0, spectrum_encoder_type: str = 'conv1d', n_wavelengths: int = 233, spectrum_embedding_dim: int = 64, photometry_embedding_dim: int = 64, encoder_hidden_dim: int | None = None, encoder_patch_size: int = 32, seed: int = 0, task_metadata: Mapping[str, Any] | None = None, verbose_diagnostics: bool = False, diagnostics_output_directory: str | None = None, diagnostics_plot_interval: int = 1)
Bases: petitRADTRANS.sbi.posterior_base.PersistentPosteriorEstimator
Concrete amortized posterior using a conditional flow backend.
Parameters
- parameter_dim:
Number of inferred free parameters represented by the posterior.
- embedding_dim:
Size of the learned observation embedding consumed by the flow.
- num_coupling_layers:
Number of conditional coupling or spline-transform layers in the flow.
- hidden_dim:
Hidden width used by encoder-side and flow-side MLP conditioners.
- learning_rate:
Optimizer learning rate used by SBITrainer.
- batch_size:
Number of simulations processed per optimization step.
- num_epochs:
Maximum number of passes through the training split.
- parameter_space:
Parameter coordinates learned by the posterior. Supported values are 'physical', 'cube', and 'unconstrained'.
- flow_family:
Conditional density-transform family. Supported values are 'spline', 'affine', 'autoregressive', and 'neural_autoregressive'.
- num_spline_bins:
Number of rational-quadratic spline bins when flow_family='spline'.
- spline_bound:
Finite support bound of each spline transform in latent space.
- early_stopping_patience:
Optional number of non-improving epochs tolerated before stopping.
- early_stopping_min_delta:
Minimum improvement required for early-stopping comparisons.
- checkpoint_directory:
Optional directory used to persist resumable trainer checkpoints.
- checkpoint_backend:
Checkpoint persistence backend name. 'auto' selects Orbax when available and otherwise falls back to Equinox serialization.
- resume_from_checkpoint:
Whether fit should attempt to resume from the latest checkpoint.
- seed:
Base random seed used for flow initialization and posterior sampling.
- task_metadata:
Optional user-supplied metadata persisted alongside the trained model.
Notes
The posterior stores task fingerprinting, observation schema, and preprocessing payload information when those are available from the training dataset reader. That metadata is later reused by inference and artifact registration paths.
- estimator_family = 'conditional_flow'
- embedding_dim = 128
- num_coupling_layers = 4
- conditioner_depth = 2
- autoregressive_transform_units = 16
- neural_autoregressive_min_slope
- neural_autoregressive_min_residual
- neural_autoregressive_inverse_bisection_steps = 48
- learning_rate
- batch_size = 32
- num_epochs = 5
- flow_family = ''
- effective_flow_family = ''
- base_distribution = ''
- use_base_affine = False
- training_objective = ''
- num_spline_bins = 8
- early_stopping_patience = None
- early_stopping_min_delta
- checkpoint_directory = None
- checkpoint_backend = 'auto'
- resume_from_checkpoint = False
- gradient_clip_norm = 1.0
- embedding_noise_std
- embedding_noise_min_scale
- aux_scale_loss_weight
- aux_parameter_loss_weight
- parameter_noise_floor
- spline_entropy_regularization_weight
- spline_small_bin_regularization_weight
- spline_derivative_regularization_weight
- spline_min_bin_ratio_target
- spline_entropy_floor
- spline_min_derivative_target
- spline_max_derivative_target
- weight_decay
- use_cosine_schedule = False
- warmup_fraction
- warmup_epochs = None
- min_learning_rate
- lr_schedule_total_epochs = None
- stable_inverse_forward_max_abs_error_threshold = None
- stable_inverse_forward_logdet_closure_max_abs_error_threshold = None
- stable_cube_edge_hit_rate_threshold = None
- spectrum_encoder_type = ''
- n_wavelengths = 233
- spectrum_embedding_dim = 64
- photometry_embedding_dim = 64
- encoder_patch_size = 32
- verbose_diagnostics = False
- diagnostics_output_directory = None
- diagnostics_plot_interval = 1
- model
- _build_flow(key: jax.Array) -> Any
- static _batch_embeddings(model: _PosteriorModel, observations: Any) -> jax.numpy.ndarray
- static _loss(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) -> jax.numpy.ndarray
- static _training_loss(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) -> jax.numpy.ndarray
- static _loss_from_observations(model: _PosteriorModel, parameters: jax.numpy.ndarray, observations: Any, parameter_space: str, *, include_spline_regularization: bool) -> jax.numpy.ndarray
- _validation_diagnostics(model: _PosteriorModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) -> dict[str, float]
- static _make_parameter_noise_loss(base_loss_fn: Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray], noise_floor: float, parameter_space: str) -> Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]
Wrap a training loss so the parameter targets are jittered.
Minimizing the conditional NLL in a very-low-noise regime with a highly parameter-predictive embedding drives the learned conditional density toward a delta (NLL -> -inf), which an over-sharp flow realizes as exploded transforms (and, for spline flows, cube-edge collapse). A small Gaussian jitter on the parameter targets gives the conditional a hard minimum width, capping sharpness so the NLL has a finite minimum and the flow’s inverse stays well-conditioned.
The jitter is applied in unconstrained (logit) space – the space the flow actually models – not in cube space. Additive cube-space noise is asymmetric near the [0, 1] bounds: clipping piles jittered targets onto the edges, and the cube->logit Jacobian rewards edge mass, so it worsens the very edge-collapse it was meant to prevent. Jittering in logit space is symmetric and boundary-free. Applied only during training; evaluation and checkpoint selection use the clean targets.
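The edge-piling argument above can be illustrated with a small NumPy sketch. It is not the library's implementation: the helper names, the use of NumPy instead of JAX, and the choice of targets sitting near the lower cube edge are assumptions made for illustration. The sketch counts how many jittered targets end up exactly on the cube boundary under each scheme.

```python
import numpy as np


def logit(u):
    # Map unit-cube coordinates in (0, 1) to unconstrained (logit) space.
    return np.log(u) - np.log1p(-u)


def sigmoid(z):
    # Inverse of logit: map unconstrained values back into (0, 1).
    return 1.0 / (1.0 + np.exp(-z))


def jitter_in_cube_space(theta_cube, noise_floor, rng):
    # Hypothetical naive variant: additive cube-space noise, clipped onto [0, 1].
    noisy = theta_cube + noise_floor * rng.standard_normal(theta_cube.shape)
    return np.clip(noisy, 0.0, 1.0)


def jitter_in_logit_space(theta_cube, noise_floor, rng):
    # Hypothetical variant following the docstring: jitter the logit-space target
    # (the space the flow models); it is symmetric and has no boundary to clip at.
    z = logit(theta_cube)
    return z + noise_floor * rng.standard_normal(z.shape)


rng = np.random.default_rng(0)
# Targets close to the lower cube edge, where clipping matters most.
theta_cube = np.full(10_000, 0.01)

cube_jittered = jitter_in_cube_space(theta_cube, 0.05, rng)
logit_jittered = jitter_in_logit_space(theta_cube, 0.05, rng)

edge_fraction_cube = np.mean(cube_jittered == 0.0)
edge_fraction_logit = np.mean(sigmoid(logit_jittered) <= 0.0)
print(f"cube-space jitter:  {edge_fraction_cube:.1%} of targets clipped onto the edge")
print(f"logit-space jitter: {edge_fraction_logit:.1%} of targets on the edge")
```

The same numeric noise level means different widths in the two spaces, so the comparison is qualitative: the point is that clipping creates a point mass on the boundary, while logit-space noise never does.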
- static _make_elbo_loss(log_likelihood_fn: Callable[[jax.numpy.ndarray, Any], jax.numpy.ndarray], parameter_space: str, *, num_samples: int = 1) -> Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]
Return an amortized-variational (ELBO) training loss.
Maximizes, for observations x drawn from the (prior-predictive) dataset,
ELBO(x) = E_{q(theta|x)}[ log p(x|theta) + log p(theta) - log q(theta|x) ]
with a single- (or few-) sample reparameterized estimator.
q(theta|x) is the conditional flow: latents z are drawn from the flow base and pushed through flow.inverse to obtain reparameterized samples, so the gradient flows through both the flow and the encoder.
The -log q (entropy) term diverges to -inf as q collapses to a point mass, so a delta posterior can never maximize the ELBO: the width is set by the likelihood/entropy balance rather than by the (near-deterministic) embedding. log p(x|theta) is supplied by the injected differentiable forward-model likelihood, evaluated at the physical parameters reconstructed from the sampled cube coordinates. With parameter_space='cube' the prior is uniform on the unit hypercube, so log p(theta_cube) = 0 and the term is dropped.
Parameters
- log_likelihood_fn:
Callable (theta_cube, observations) -> (batch,) returning the Gaussian observational log-likelihood log p(x | theta) for each batch element. It receives unit-cube coordinates (shape (batch, dim)) and the batch's prestacked observations; it owns the cube->physical transform, the differentiable forward model, and the noise model. Must be JAX-differentiable w.r.t. theta_cube.
- parameter_space:
Must be "cube".
- num_samples:
Number of reparameterized posterior draws per observation used to estimate the ELBO expectation. 1 is standard for amortized VI; larger values reduce gradient variance at a proportional forward-model cost.
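The reparameterized ELBO estimator can be sketched on a toy problem. Everything below is an assumption for illustration, not the library's loss: a diagonal Gaussian in logit space pushed through a sigmoid stands in for the conditional flow, the forward model is a hypothetical linear map, and gradients are omitted (inside JAX they would flow through the sampled parameters). As stated above, the uniform cube prior contributes log p(theta_cube) = 0.

```python
import numpy as np


def gaussian_log_likelihood(theta_cube, observations, noise_sigma=0.1):
    # Hypothetical (theta_cube, observations) -> (batch,) likelihood for a
    # linear toy forward model y = 2 * theta_cube - 1 with Gaussian noise.
    predicted = 2.0 * theta_cube - 1.0
    resid = (observations - predicted) / noise_sigma
    return np.sum(-0.5 * resid**2 - np.log(noise_sigma) - 0.5 * np.log(2.0 * np.pi), axis=-1)


def elbo_estimate(mu, log_sigma, observations, rng, num_samples=1):
    # Reparameterized draws: z = mu + sigma * eps, theta_cube = sigmoid(z).
    sigma = np.exp(log_sigma)
    eps = rng.standard_normal((num_samples,) + mu.shape)
    z = mu + sigma * eps
    theta_cube = 1.0 / (1.0 + np.exp(-z))

    # Change of variables: log q(theta_cube) = log N(z; mu, sigma) - log|d theta_cube / dz|.
    log_q_z = np.sum(-0.5 * eps**2 - log_sigma - 0.5 * np.log(2.0 * np.pi), axis=-1)
    log_abs_det = np.sum(np.log(theta_cube) + np.log1p(-theta_cube), axis=-1)
    log_q = log_q_z - log_abs_det

    # Uniform cube prior: log p(theta_cube) = 0, so only likelihood and entropy remain.
    log_lik = gaussian_log_likelihood(theta_cube, observations)
    # Average the Monte Carlo draws; one ELBO value per batch element.
    return np.mean(log_lik - log_q, axis=0)


rng = np.random.default_rng(0)
p_true = np.array([[0.6, 0.3]])
mu = np.log(p_true / (1.0 - p_true))  # posterior mean in logit space
obs = 2.0 * p_true - 1.0              # noiseless toy observation

wide = elbo_estimate(mu, np.full_like(mu, np.log(0.1)), obs, rng, num_samples=256)
collapsed = elbo_estimate(mu, np.full_like(mu, np.log(1e-6)), obs, rng, num_samples=256)
print(wide, collapsed)  # the near-delta posterior scores far lower
```

The comparison shows the mechanism described above: shrinking q toward a point mass barely improves the likelihood term but drives the entropy term sharply negative.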
- static _make_noisy_loss(noise_std: float, preprocessing_metadata: Any = None, *, noise_min_scale: float = 0.0, include_spline_regularization: bool = True) -> Callable[[_PosteriorModel, petitRADTRANS.sbi.posterior_base.PosteriorBatch], jax.numpy.ndarray]
Return a loss function that injects on-the-fly Gaussian noise into the observation spectra before encoding.
For spectra preprocessed into a per-sample transformed value space, the uncertainty channel is assumed to already be expressed in the same transformed units as the value channel, so noise can be injected directly from that channel. When preprocessing metadata is provided, non-spectral blocks still reconstruct raw measurement uncertainties so noise is injected at the correct physical scale in normalized-value space.
noise_std acts as a multiplier on the observational noise level (1.0 = exact observational noise, >1 = regularizing over-noise), and noise_min_scale is treated as a minimum fraction of the typical transformed uncertainty for the block rather than as an absolute floor in normalized-value units.
When preprocessing metadata is unavailable, flat Gaussian noise with scale noise_std is applied as a fallback.
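A minimal sketch of the per-block noise rule, assuming the uncertainty channel is already in the same transformed units as the values and taking the block median as the "typical" uncertainty. The helper name, the median, the array layout, and applying the floor before the noise_std multiplier are all assumptions; the reconstruction of raw uncertainties for non-spectral blocks is not shown.

```python
import numpy as np


def inject_block_noise(values, uncertainties, noise_std, noise_min_scale, rng):
    """Hypothetical helper: add Gaussian noise to one observation block.

    values, uncertainties: arrays of shape (batch, n_channels) in the same
    transformed units. noise_std multiplies the observational noise level;
    noise_min_scale is a floor expressed as a fraction of the block's typical
    (here assumed: median) uncertainty, not an absolute value.
    """
    typical = np.median(uncertainties, axis=-1, keepdims=True)
    # Assumption: the floor is applied before scaling by noise_std.
    sigma = np.maximum(uncertainties, noise_min_scale * typical)
    return values + noise_std * sigma * rng.standard_normal(values.shape)


rng = np.random.default_rng(1)
values = np.zeros((4096, 3))
uncertainties = np.tile(np.array([0.0, 1.0, 2.0]), (4096, 1))
noisy = inject_block_noise(values, uncertainties, noise_std=1.0, noise_min_scale=0.5, rng=rng)
print(noisy.std(axis=0))  # roughly [0.5, 1.0, 2.0]
```

The zero-uncertainty channel still receives noise at half the median uncertainty, which is what a relative floor buys over an absolute one when blocks have very different normalized scales.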
- fit(dataset: Any, *, elbo_log_likelihood_fn: Callable[[jax.numpy.ndarray, Any], jax.numpy.ndarray] | None = None, elbo_num_samples: int = 1) -> petitRADTRANS.sbi.posterior_base.TrainingArtifacts
Train the posterior on one normalized simulation dataset reader.
Parameters
- dataset:
Reader-like object that yields PosteriorBatch instances via iter_batches and exposes dataset manifest metadata when available.
- elbo_log_likelihood_fn:
Required when training_objective='elbo'. A differentiable (theta_cube, observations) -> (batch,) callable returning the Gaussian observational log-likelihood through the forward model. See _make_elbo_loss(). Ignored for the default NPE objective.
- elbo_num_samples:
Number of reparameterized posterior draws per observation for the ELBO estimator (default 1).
Returns
- TrainingArtifacts
Training history, validation metrics, and trainer metadata for the completed optimization run.
Notes
If the reader exposes preprocessing metadata or a manifest fingerprint, those are cached on the posterior so they can be saved and reused by the inference and artifact layers.
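The elbo_log_likelihood_fn contract, (theta_cube, observations) -> (batch,), can be met by a callable along the lines below. The parameter bounds, the toy forward model, and the dictionary layout of observations are hypothetical; the real prestacked observation structure is not specified here. Only operations shared by numpy and jax.numpy are used, so when JAX is installed the function stays differentiable with respect to theta_cube; the NumPy fallback exists only so the sketch runs without JAX.

```python
try:
    import jax.numpy as jnp  # preferred: keeps the callable JAX-differentiable
except ImportError:  # fallback so the sketch also runs without JAX
    import numpy as jnp

# Hypothetical physical bounds for a two-parameter problem (e.g. T and log g).
LOWER = jnp.array([500.0, 2.0])
UPPER = jnp.array([2500.0, 5.5])
WAVELENGTHS = jnp.linspace(1.0, 2.0, 16)


def toy_forward_model(theta_phys):
    # Hypothetical smooth stand-in for a differentiable spectrum model:
    # (batch, 2) physical parameters -> (batch, n_wavelengths) fluxes.
    temperature = theta_phys[:, :1]
    log_g = theta_phys[:, 1:2]
    return temperature / 1000.0 * jnp.exp(-WAVELENGTHS / log_g)


def elbo_log_likelihood_fn(theta_cube, observations):
    # The callable owns the cube -> physical transform, as the docstring requires.
    theta_phys = LOWER + theta_cube * (UPPER - LOWER)
    model_flux = toy_forward_model(theta_phys)
    flux = observations["flux"]    # hypothetical observation layout
    sigma = observations["sigma"]
    resid = (flux - model_flux) / sigma
    # Gaussian log-likelihood summed over wavelengths: shape (batch,).
    return jnp.sum(-0.5 * resid**2 - jnp.log(sigma) - 0.5 * jnp.log(2.0 * jnp.pi), axis=-1)


theta_cube = jnp.array([[0.5, 0.5], [0.1, 0.9]])
true_flux = toy_forward_model(LOWER + theta_cube * (UPPER - LOWER))
obs = {"flux": true_flux, "sigma": 0.05 * jnp.ones_like(true_flux)}
log_lik = elbo_log_likelihood_fn(theta_cube, obs)
print(log_lik.shape)  # (2,)
```

Such a callable would then be passed as posterior.fit(dataset, elbo_log_likelihood_fn=elbo_log_likelihood_fn) on a posterior constructed with training_objective='elbo'; the _make_elbo_loss docstring requires parameter_space='cube'.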
- _build_estimator_config() -> dict[str, Any]
Return backend-specific configuration for metadata persistence.
- _build_serialized_metadata(artifact_metadata) -> dict[str, Any]
- static _resolve_estimator_config(metadata: Mapping[str, Any]) -> dict[str, Any]
- hydrate_loaded_metadata(metadata: Mapping[str, Any]) -> None
- classmethod from_serialized_metadata(metadata: Mapping[str, Any]) -> ConditionalFlowPosterior
Rebuild an estimator instance from persisted metadata only.
- save_backend_state(output_path: pathlib.Path) -> None
Persist backend-specific model state into the output directory.
- load_backend_state(input_path: pathlib.Path) -> None
Restore backend-specific model state from the input directory.
- encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation
Encode one structured observation into the posterior context space.
Parameters
- blocks:
Observation blocks describing one spectral/photometric observation.
Returns
- EncodedObservation
Aggregated embedding and lightweight metadata describing the input observation family.
- batch_encode_observation(blocks_list: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) -> list[petitRADTRANS.sbi.observation.EncodedObservation]
Encode a batch of observations using a single vmapped forward pass.
Parameters
- blocks_list:
List of per-sample observation block lists.
Returns
- list[EncodedObservation]
One encoded observation per input sample.
- sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) -> petitRADTRANS.sbi.posterior_base.PosteriorSamples
Draw posterior samples conditioned on one encoded observation.
Parameters
- observation:
Encoded observation produced by encode_observation().
- n_samples:
Number of posterior draws to generate.
- seed:
Optional random seed overriding the model-level default seed.
Returns
- PosteriorSamples
Samples in the posterior’s configured parameter space. Non-finite outputs are clamped to large finite values for downstream stability.
- batch_sample_posterior(embeddings: Any, n_samples: int, base_seed: int = 0) -> numpy.ndarray
Draw posterior samples for a batch of encoded observations.
Runs a single JIT-compiled vmapped call over all contexts rather than calling sample_posterior once per observation, eliminating the Python overhead of per-observation dispatch.
Parameters
- embeddings:
Float32 array of shape (batch_size, embedding_dim) produced by stacking EncodedObservation.embedding vectors.
- n_samples:
Number of posterior draws per observation.
- base_seed:
Base random seed. Each observation in the batch receives a unique sub-key derived from this seed.
Returns
- np.ndarray
Array of shape (batch_size, n_samples, parameter_dim) with non-finite values clamped to large finite values.
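The batching contract (one derived seed per observation, output shape (batch_size, n_samples, parameter_dim), non-finite values clamped) can be sketched in NumPy. The stand-in sampler, numpy.random.SeedSequence.spawn in place of JAX key splitting, the clamp limit of 1e30, and mapping NaN to +1e30 are assumptions; the library's actual clamp value is not documented here. The Python loop is only for readability and stands in for the single vmapped JAX call described above.

```python
import numpy as np

CLAMP = 1e30  # hypothetical "large finite value" used for clamping


def toy_conditional_sampler(embedding, n_samples, parameter_dim, rng):
    # Hypothetical stand-in for the flow: Gaussian draws centred on the
    # first parameter_dim entries of the embedding.
    loc = embedding[:parameter_dim]
    return loc + rng.standard_normal((n_samples, parameter_dim))


def batch_sample(embeddings, n_samples, parameter_dim, base_seed=0):
    # One independent child seed per observation, all derived from base_seed.
    children = np.random.SeedSequence(base_seed).spawn(len(embeddings))
    samples = np.stack([
        toy_conditional_sampler(emb, n_samples, parameter_dim, np.random.default_rng(child))
        for emb, child in zip(embeddings, children)
    ])
    # Replace NaN/inf and clip to +/- CLAMP so downstream code sees finite values.
    samples = np.nan_to_num(samples, nan=CLAMP, posinf=CLAMP, neginf=-CLAMP)
    return np.clip(samples, -CLAMP, CLAMP)


embeddings = np.zeros((3, 8), dtype=np.float32)
embeddings[1, 0] = np.inf  # force non-finite draws for one observation
out = batch_sample(embeddings, n_samples=5, parameter_dim=2, base_seed=42)
print(out.shape)  # (3, 5, 2)
```

Deriving every per-observation stream from one base seed keeps a whole batch reproducible from a single integer while keeping the streams independent of each other.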
- log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) -> Any
Evaluate posterior log-density for one or many parameter vectors.
Parameters
- observation:
Encoded observation that defines the posterior context.
- parameters:
One parameter vector or a batch of parameter vectors in the posterior’s configured parameter space.
Returns
- Any
Scalar log-density or a vector of log-densities matching the input batch structure.
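One common way to honor a "scalar for one vector, vector for a batch" contract is to promote a single vector to a batch of one and unwrap the result. The sketch below shows that pattern with a hypothetical standard-normal density standing in for the conditional flow; it is not taken from the library, and the observation argument is omitted for brevity.

```python
import numpy as np


def toy_batched_log_density(parameters):
    # Hypothetical stand-in for the flow density: standard normal, shape (batch,).
    dim = parameters.shape[-1]
    return -0.5 * np.sum(parameters**2, axis=-1) - 0.5 * dim * np.log(2.0 * np.pi)


def log_prob(parameters):
    params = np.asarray(parameters, dtype=np.float64)
    single = params.ndim == 1
    # Promote one parameter vector to a batch of one so a single code path is used.
    batch = params[None, :] if single else params
    values = toy_batched_log_density(batch)
    # Return a Python float for one vector, a (batch,) array otherwise.
    return float(values[0]) if single else values


print(log_prob([0.0, 0.0]))             # scalar: -log(2 * pi)
print(log_prob(np.zeros((3, 2))).shape)  # (3,)
```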
- class petitRADTRANS.sbi.flow_posterior.ConditionalSplineFlowPosterior(*args: Any, **kwargs: Any)
Bases: ConditionalFlowPosterior
Posterior estimator specialized to the spline flow backend.
Notes
This convenience subclass forces flow_family='spline' while keeping the rest of the ConditionalFlowPosterior API unchanged.
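A subclass that pins one constructor argument while forwarding everything else can be written as below. This is a generic sketch with a hypothetical stand-in base class (FlowPosteriorStandIn), not the library's code; whether the real subclass overrides or rejects a conflicting flow_family argument is not stated here, and this sketch simply overrides it.

```python
from typing import Any


class FlowPosteriorStandIn:
    # Hypothetical stand-in for ConditionalFlowPosterior's constructor.
    def __init__(self, parameter_dim: int, flow_family: str = "spline", **kwargs: Any) -> None:
        self.parameter_dim = parameter_dim
        self.flow_family = flow_family
        self.options = kwargs


class SplineFlowPosteriorStandIn(FlowPosteriorStandIn):
    # Force the flow family (assumption: override silently) and forward all other arguments.
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        kwargs["flow_family"] = "spline"
        super().__init__(*args, **kwargs)


posterior = SplineFlowPosteriorStandIn(parameter_dim=4, flow_family="affine", num_spline_bins=16)
print(posterior.flow_family)  # spline
```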
- class petitRADTRANS.sbi.flow_posterior.ConditionalAutoregressiveFlowPosterior(*args: Any, **kwargs: Any)
Bases: ConditionalFlowPosterior
Posterior estimator specialized to the autoregressive flow backend.
- class petitRADTRANS.sbi.flow_posterior.ConditionalNeuralAutoregressiveFlowPosterior(*args: Any, **kwargs: Any)
Bases: ConditionalFlowPosterior
Posterior estimator specialized to the neural autoregressive backend.