petitRADTRANS.sbi.flow_matching_posterior

Flow-matching posterior estimator (experimental).

Classes

FlowMatchingPosterior

Conditional flow-matching posterior skeleton.

Module Contents

class petitRADTRANS.sbi.flow_matching_posterior.FlowMatchingPosterior(parameter_dim: int, embedding_dim: int = 128, hidden_dim: int = 128, num_velocity_layers: int = 3, learning_rate: float = 0.001, batch_size: int = 32, num_epochs: int = 5, parameter_space: str = 'unconstrained', integration_steps: int = 32, early_stopping_patience: int | None = None, early_stopping_min_delta: float = 0.0, checkpoint_directory: str | None = None, checkpoint_backend: str = 'auto', resume_from_checkpoint: bool = False, seed: int = 0, task_metadata: Mapping[str, Any] | None = None)

Bases: petitRADTRANS.sbi.posterior_base.PersistentPosteriorEstimator

Conditional flow-matching posterior skeleton.

Warning

This estimator is experimental. It does not support posterior log-density evaluation through log_prob, and the ODE integration scheme is a simple midpoint rule. Expect the API and numerical behaviour to change in future releases.

This estimator family trains a conditional vector field on straight-line interpolation paths between Gaussian noise and target parameters, then generates posterior samples by integrating the learned field from noise to the terminal parameter state.
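The training objective described above can be sketched in NumPy. This is a minimal illustration of the standard conditional flow-matching loss on straight-line paths, assuming noise sits at t = 0, the target parameters at t = 1, and t is drawn uniformly per sample. `velocity_fn` is a hypothetical stand-in for the learned network; the module's own JAX `_loss` may differ in details such as time sampling or weighting.

```python
import numpy as np


def flow_matching_loss(velocity_fn, theta, context, rng):
    """Monte Carlo estimate of a conditional flow-matching loss (illustrative sketch).

    velocity_fn(x_t, t, context) is a hypothetical learned field returning an
    array shaped like x_t. theta has shape (batch, parameter_dim) and context
    holds the matching observation embeddings.
    """
    batch, dim = theta.shape
    noise = rng.standard_normal((batch, dim))   # x_0 ~ N(0, I) at t = 0 (assumed convention)
    t = rng.uniform(size=(batch, 1))            # one interpolation time per sample
    x_t = (1.0 - t) * noise + t * theta         # point on the straight line noise -> theta
    target_velocity = theta - noise             # constant time derivative of that line
    predicted = velocity_fn(x_t, t, context)
    # Mean over the batch of the squared Euclidean error in velocity space.
    return np.mean(np.sum((predicted - target_velocity) ** 2, axis=-1))
```

With a zero velocity field and all-zero targets the loss reduces to the mean squared norm of the noise, roughly `parameter_dim`, which makes a quick sanity check.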

estimator_family = 'flow_matching'
embedding_dim = 128
hidden_dim = 128
num_velocity_layers = 3
learning_rate
batch_size = 32
num_epochs = 5
integration_steps = 32
early_stopping_patience = None
early_stopping_min_delta
checkpoint_directory = None
checkpoint_backend = 'auto'
resume_from_checkpoint = False
model
static _batch_embeddings(model: _FlowMatchingModel, observations: Any) -> jax.numpy.ndarray
static _loss(model: _FlowMatchingModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) -> jax.numpy.ndarray
fit(dataset: Any) -> petitRADTRANS.sbi.posterior_base.TrainingArtifacts

Train the posterior estimator on a simulation dataset.
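The early_stopping_patience and early_stopping_min_delta parameters suggest a conventional patience rule inside fit. The sketch below shows one common form of that rule; which loss is monitored and whether the comparison is strict are assumptions here, and `epochs_before_stop` is a hypothetical helper, not part of the module.

```python
def epochs_before_stop(monitored_losses, patience=None, min_delta=0.0):
    """Return how many epochs run before a patience rule stops training (sketch).

    An epoch counts as an improvement only if its loss beats the best loss so
    far by more than min_delta. Training stops once `patience` consecutive
    epochs fail to improve; patience=None disables early stopping.
    """
    best = float("inf")
    stale_epochs = 0
    for epoch, loss in enumerate(monitored_losses, start=1):
        if loss < best - min_delta:
            best = loss          # meaningful improvement: reset the counter
            stale_epochs = 0
        else:
            stale_epochs += 1    # no meaningful improvement this epoch
        if patience is not None and stale_epochs >= patience:
            return epoch
    return len(monitored_losses)


losses = [1.0, 0.8, 0.79, 0.795, 0.9]
stopped_at = epochs_before_stop(losses, patience=2, min_delta=0.05)
```

With min_delta=0.05 the drop from 0.8 to 0.79 is too small to count, so two stale epochs follow and training stops after epoch 4.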

_build_estimator_config() -> dict[str, Any]

Return backend-specific configuration for metadata persistence.

static _resolve_estimator_config(metadata: Mapping[str, Any]) -> dict[str, Any]
_build_serialized_metadata(artifact_metadata) -> dict[str, Any]
classmethod from_serialized_metadata(metadata: Mapping[str, Any]) -> FlowMatchingPosterior

Rebuild an estimator instance from persisted metadata only.
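_build_estimator_config and from_serialized_metadata together imply a round trip from hyperparameters to persisted metadata and back. A minimal sketch of that pattern follows, using a hypothetical `FlowMatchingConfig` dataclass that mirrors a few documented constructor defaults and a hypothetical "estimator_config" key; the real metadata schema is not documented in this reference.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class FlowMatchingConfig:
    # Hypothetical subset of the constructor arguments, with the documented defaults.
    parameter_dim: int
    embedding_dim: int = 128
    hidden_dim: int = 128
    num_velocity_layers: int = 3
    integration_steps: int = 32
    seed: int = 0


def to_metadata(config: FlowMatchingConfig) -> dict:
    # "estimator_family" mirrors the documented class attribute; "estimator_config" is hypothetical.
    return {"estimator_family": "flow_matching", "estimator_config": asdict(config)}


def from_metadata(metadata: dict) -> FlowMatchingConfig:
    # Refuse metadata written by a different estimator family.
    if metadata.get("estimator_family") != "flow_matching":
        raise ValueError("metadata was not written by a flow-matching estimator")
    return FlowMatchingConfig(**metadata["estimator_config"])


# Round trip through JSON text, as a metadata file on disk would be.
config = FlowMatchingConfig(parameter_dim=5, hidden_dim=256)
restored = from_metadata(json.loads(json.dumps(to_metadata(config))))
```

Rebuilding from metadata alone, without model weights, keeps the constructor arguments recoverable even before any backend state is loaded.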

save_backend_state(output_path: pathlib.Path) -> None

Persist backend-specific model state into the output directory.

load_backend_state(input_path: pathlib.Path) -> None

Restore backend-specific model state from the input directory.
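save_backend_state and load_backend_state write and read model state inside a directory. As a sketch of that directory-based contract only, the example below stores a flat dict of NumPy arrays in a single .npz file. The file name and the flat parameter dict are hypothetical, and the module's actual on-disk format (which may depend on checkpoint_backend) is not described here.

```python
import pathlib
import tempfile

import numpy as np

STATE_FILE = "flow_matching_state.npz"  # hypothetical file name inside the output directory


def save_state(params: dict, output_path) -> None:
    # Write every named array into one uncompressed .npz archive.
    output_path = pathlib.Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    np.savez(output_path / STATE_FILE, **params)


def load_state(input_path) -> dict:
    # Read the archive back into an in-memory dict of arrays, then close it.
    with np.load(pathlib.Path(input_path) / STATE_FILE) as archive:
        return {name: archive[name] for name in archive.files}


with tempfile.TemporaryDirectory() as tmp:
    params = {"weights": np.arange(6.0).reshape(2, 3), "bias": np.zeros(3)}
    save_state(params, tmp)
    restored_params = load_state(tmp)
```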

encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation

Encode a structured observation into the estimator input space.

batch_encode_observation(blocks_list: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) -> list[petitRADTRANS.sbi.observation.EncodedObservation]

Encode a batch of observations using a single vmapped forward pass.
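Batch encoding applies the same encoder to every observation in one pass instead of looping. In JAX, `jax.vmap` performs this vectorisation automatically; the NumPy sketch below does it by hand with a hypothetical one-layer encoder (`W`, `encode_one`), stacking observations so that a single batched matrix product replaces the per-observation loop. It assumes every observation flattens to the same feature length, which real ObservationBlock inputs may or may not satisfy.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((16, 4))  # hypothetical encoder weights: 16 features -> 4-dim embedding


def encode_one(x):
    # Encode a single observation vector of shape (16,) into shape (4,).
    return np.tanh(x @ W)


observations = [rng.standard_normal(16) for _ in range(8)]

# Reference: encode observations one at a time.
one_at_a_time = np.stack([encode_one(x) for x in observations])

# Batched: stack to shape (8, 16) and run one forward pass for the whole batch.
batched = np.tanh(np.stack(observations) @ W)
```

Both paths produce the same (8, 4) embeddings; the batched form simply avoids Python-level iteration.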

sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) -> petitRADTRANS.sbi.posterior_base.PosteriorSamples

Sample the amortized posterior for one encoded observation.
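Sampling integrates the learned velocity field from Gaussian noise at t = 0 to t = 1 with the midpoint rule, using integration_steps fixed steps. The NumPy sketch below assumes a uniform step size of 1 / integration_steps; `sample_midpoint` and `point_mass_field` are hypothetical names. The test field is the exact straight-line flow for a point-mass posterior at `context`, along which the midpoint rule is exact, so every sample should land on that point up to rounding.

```python
import numpy as np


def sample_midpoint(velocity_fn, context, n_samples, parameter_dim, integration_steps=32, seed=0):
    """Integrate dx/dt = v(x, t, context) from t = 0 (noise) to t = 1 with the midpoint rule."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_samples, parameter_dim))  # start from Gaussian noise at t = 0
    dt = 1.0 / integration_steps                          # assumed uniform step size
    for step in range(integration_steps):
        t = step * dt
        k1 = velocity_fn(x, t, context)                   # slope at the start of the step
        x_half = x + 0.5 * dt * k1                        # half Euler step to the midpoint
        k2 = velocity_fn(x_half, t + 0.5 * dt, context)   # slope at the midpoint
        x = x + dt * k2                                   # full step using the midpoint slope
    return x


def point_mass_field(x, t, context):
    # Exact straight-line field for a point-mass target: (context - x) / (1 - t).
    # The midpoint rule never evaluates t = 1, so the division is safe.
    return (context - x) / (1.0 - t)


theta_star = np.array([0.3, -1.2, 2.0])
samples = sample_midpoint(point_mass_field, theta_star, n_samples=5, parameter_dim=3)
```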

log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) -> Any

Evaluate the posterior log-density when the backend supports it. This estimator does not (see the warning above).