petitRADTRANS.sbi.flow_matching_posterior
Flow-matching posterior estimator (experimental).
Classes
FlowMatchingPosterior | Conditional flow-matching posterior skeleton.
Module Contents
- class petitRADTRANS.sbi.flow_matching_posterior.FlowMatchingPosterior(parameter_dim: int, embedding_dim: int = 128, hidden_dim: int = 128, num_velocity_layers: int = 3, learning_rate: float = 0.001, batch_size: int = 32, num_epochs: int = 5, parameter_space: str = 'unconstrained', integration_steps: int = 32, early_stopping_patience: int | None = None, early_stopping_min_delta: float = 0.0, checkpoint_directory: str | None = None, checkpoint_backend: str = 'auto', resume_from_checkpoint: bool = False, seed: int = 0, task_metadata: Mapping[str, Any] | None = None)
Bases: petitRADTRANS.sbi.posterior_base.PersistentPosteriorEstimator
Conditional flow-matching posterior skeleton.
Warning
This estimator is experimental. It does not expose log_prob, and the ODE integration scheme is a simple midpoint rule. Expect the API and numerical behaviour to change in future releases.
This estimator family trains a conditional vector field on straight-line interpolation paths between Gaussian noise and target parameters, then generates posterior samples by integrating the learned field from noise to the terminal parameter state.
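The training objective described above can be illustrated with a minimal JAX sketch. It is not the estimator's actual `_loss`: `velocity_field`, `flow_matching_loss`, the single linear layer, and the uniform time sampling are all hypothetical choices made only to show the straight-line interpolation and its constant target velocity.

```python
import jax
import jax.numpy as jnp


def velocity_field(params, x_t, t, embedding):
    """Hypothetical stand-in for the learned velocity network (one linear layer)."""
    features = jnp.concatenate([x_t, t[:, None], embedding], axis=-1)
    return features @ params["w"] + params["b"]


def flow_matching_loss(params, key, theta, embedding):
    """Mean squared error between predicted and straight-line target velocity."""
    noise_key, time_key = jax.random.split(key)
    x0 = jax.random.normal(noise_key, theta.shape)       # Gaussian noise endpoint
    t = jax.random.uniform(time_key, (theta.shape[0],))  # one time per sample (assumed uniform)
    x_t = (1.0 - t[:, None]) * x0 + t[:, None] * theta   # straight-line interpolation
    target = theta - x0                                  # velocity is constant along the path
    pred = velocity_field(params, x_t, t, embedding)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))


# Toy shapes, chosen only for illustration.
key = jax.random.PRNGKey(0)
batch, parameter_dim, embedding_dim = 8, 3, 4
in_dim = parameter_dim + 1 + embedding_dim
params = {"w": jnp.zeros((in_dim, parameter_dim)), "b": jnp.zeros((parameter_dim,))}
theta = jnp.ones((batch, parameter_dim))        # target parameters
embedding = jnp.zeros((batch, embedding_dim))   # observation embeddings
loss, grads = jax.value_and_grad(flow_matching_loss)(params, key, theta, embedding)
```

Because the path is linear in time, the regression target does not depend on `t`; only the point `x_t` at which the field is evaluated does.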
- estimator_family = 'flow_matching'
- embedding_dim = 128
- num_velocity_layers = 3
- learning_rate
- batch_size = 32
- num_epochs = 5
- integration_steps = 32
- early_stopping_patience = None
- early_stopping_min_delta
- checkpoint_directory = None
- checkpoint_backend = 'auto'
- resume_from_checkpoint = False
- model
- static _batch_embeddings(model: _FlowMatchingModel, observations: Any) -> jax.numpy.ndarray
- static _loss(model: _FlowMatchingModel, batch: petitRADTRANS.sbi.posterior_base.PosteriorBatch) -> jax.numpy.ndarray
- fit(dataset: Any) -> petitRADTRANS.sbi.posterior_base.TrainingArtifacts
Train the posterior estimator on a simulation dataset.
- _build_estimator_config() -> dict[str, Any]
Return backend-specific configuration for metadata persistence.
- static _resolve_estimator_config(metadata: Mapping[str, Any]) -> dict[str, Any]
- _build_serialized_metadata(artifact_metadata) -> dict[str, Any]
- classmethod from_serialized_metadata(metadata: Mapping[str, Any]) -> FlowMatchingPosterior
Rebuild an estimator instance from persisted metadata only.
- save_backend_state(output_path: pathlib.Path) -> None
Persist backend-specific model state into the output directory.
- load_backend_state(input_path: pathlib.Path) -> None
Restore backend-specific model state from the input directory.
- encode_observation(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation
Encode a structured observation into the estimator input space.
- batch_encode_observation(blocks_list: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) -> list[petitRADTRANS.sbi.observation.EncodedObservation]
Encode a batch of observations using a single vmapped forward pass.
- sample_posterior(observation: petitRADTRANS.sbi.observation.EncodedObservation, n_samples: int, seed: int | None = None) -> petitRADTRANS.sbi.posterior_base.PosteriorSamples
Sample the amortized posterior for one encoded observation.
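As a rough illustration of how sampling by midpoint integration can work, the sketch below integrates a velocity field from t = 0 to t = 1 with the explicit midpoint rule. It is not the library's implementation: `integrate_midpoint` and `toy_velocity` are hypothetical names, and the toy field stands in for the learned, observation-conditioned field.

```python
import jax
import jax.numpy as jnp


def integrate_midpoint(velocity, x0, num_steps):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with the explicit midpoint rule."""
    dt = 1.0 / num_steps

    def step(x, i):
        t = i * dt
        x_mid = x + 0.5 * dt * velocity(x, t)            # half step to the midpoint
        x_next = x + dt * velocity(x_mid, t + 0.5 * dt)  # full step using the midpoint slope
        return x_next, None

    x1, _ = jax.lax.scan(step, x0, jnp.arange(num_steps))
    return x1


def toy_velocity(x, t):
    """Toy field with a known solution x(1) = x(0) * exp(-1); a trained model would go here."""
    return -x


n_samples, parameter_dim = 16, 3
x0 = jax.random.normal(jax.random.PRNGKey(0), (n_samples, parameter_dim))  # noise draws
samples = integrate_midpoint(toy_velocity, x0, num_steps=32)
```

The step count plays the role that `integration_steps` appears to play in the constructor: more steps reduce the discretization error at the cost of more evaluations of the velocity field.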
- log_prob(observation: petitRADTRANS.sbi.observation.EncodedObservation, parameters: Any) -> Any
Evaluate posterior log-density when supported by the backend.