petitRADTRANS.sbi.training

Training and checkpoint orchestration for SBI posterior models.

Classes

EarlyStoppingConfig

Early stopping policy for SBI training.

TrainingConfig

Configuration for posterior optimization.

CheckpointBackend

Persistence backend for trainer checkpoints.

EquinoxCheckpointBackend

Checkpoint backend backed by Equinox tree serialization.

OrbaxCheckpointBackend

Checkpoint backend that uses Orbax when available.

SBITrainer

Reusable optimization loop for amortized SBI posteriors.

Functions

resolve_checkpoint_backend(preferred='auto') → CheckpointBackend

Resolve the requested checkpoint backend with graceful fallback.

load_trainer_checkpoint_state(checkpoint_directory, checkpoint_kind='best') → tuple[Any, dict[str, Any]] | None

Load one persisted trainer checkpoint state and its metadata.

Module Contents

class petitRADTRANS.sbi.training.EarlyStoppingConfig

Early stopping policy for SBI training.

patience: int
min_delta: float = 0.0

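The two fields above leave the exact stopping rule implicit. The following minimal sketch shows one common interpretation. It assumes that an epoch counts as an improvement only when the monitored loss falls more than `min_delta` below the best value seen so far, and that training stops once `patience` consecutive epochs fail to improve. `EarlyStoppingSketch` is a hypothetical stand-in, not a class from this module.

```python
from dataclasses import dataclass


@dataclass
class EarlyStoppingSketch:
    """Hypothetical mirror of EarlyStoppingConfig plus the state a trainer would track."""

    patience: int
    min_delta: float = 0.0
    best_loss: float = float("inf")
    patience_counter: int = 0

    def update(self, monitored_loss: float) -> bool:
        """Record one epoch's monitored loss; return True when training should stop."""
        if monitored_loss < self.best_loss - self.min_delta:
            # Improvement larger than min_delta: remember it and reset the counter.
            self.best_loss = monitored_loss
            self.patience_counter = 0
        else:
            # No sufficient improvement: count one more stale epoch.
            self.patience_counter += 1
        # Assumed rule: stop once the stale-epoch count reaches patience.
        return self.patience_counter >= self.patience


stopper = EarlyStoppingSketch(patience=2, min_delta=0.01)
losses = [1.0, 0.8, 0.795, 0.792, 0.5]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.update(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss)
```

In this run, the drops from 0.8 to 0.795 and then to 0.792 are smaller than `min_delta`, so they count as stale epochs. Training therefore stops before the later 0.5 is ever seen.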
class petitRADTRANS.sbi.training.TrainingConfig

Configuration for posterior optimization.

learning_rate: float = 0.001
batch_size: int = 32
num_epochs: int = 10
parameter_space: str = 'unconstrained'
seed: int = 0
shuffle_train: bool = True
early_stopping: EarlyStoppingConfig | None = None
checkpoint_directory: str | None = None
checkpoint_backend: str = 'auto'
resume_from_checkpoint: bool = False
data_parallel: bool = False
gradient_clip_norm: float | None = 1.0
weight_decay: float = 0.0
use_cosine_schedule: bool = False
warmup_fraction: float = 0.02
warmup_epochs: float | None = None
min_learning_rate: float = 1e-06
lr_schedule_total_epochs: float | None = None
verbose_diagnostics: bool = False
diagnostics_output_directory: str | None = None
diagnostics_plot_interval: int = 1
n_validation_diagnostic_batches: int = 4

class petitRADTRANS.sbi.training.CheckpointBackend

Bases: abc.ABC

Persistence backend for trainer checkpoints.

name: str
abstractmethod save(state: Any, output_directory: pathlib.Path, metadata: dict[str, Any]) → None

Persist checkpoint state and metadata.

Parameters

state:

Serializable optimizer/model state payload.

output_directory:

Directory that should receive the checkpoint files.

metadata:

Small JSON-serializable metadata dictionary written next to the checkpoint payload.

Returns

None

abstractmethod restore(template_state: Any, output_directory: pathlib.Path) → tuple[Any, dict[str, Any]]

Restore persisted checkpoint state into the supplied template.

Parameters

template_state:

Structure used by some backends to describe the expected state layout during restoration.

output_directory:

Directory containing the serialized checkpoint payload.

Returns

tuple[Any, dict[str, Any]]

Restored state object together with the deserialized checkpoint metadata dictionary.
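The following sketch makes the save/restore contract concrete using the hypothetical classes `CheckpointBackendSketch` and `PickleCheckpointBackend`. Where the Equinox backend uses dill, the sketch swaps in the standard-library `pickle` module. The file names `state.pkl` and `metadata.json` are assumptions, and the real backends may lay out their directories differently.

```python
import json
import pickle
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any


class CheckpointBackendSketch(ABC):
    """Stand-in for the CheckpointBackend contract documented above."""

    name: str

    @abstractmethod
    def save(self, state: Any, output_directory: Path, metadata: dict[str, Any]) -> None: ...

    @abstractmethod
    def restore(self, template_state: Any, output_directory: Path) -> tuple[Any, dict[str, Any]]: ...


class PickleCheckpointBackend(CheckpointBackendSketch):
    """Hypothetical backend: pickle for the state payload, JSON for the metadata."""

    name = "pickle"

    def save(self, state: Any, output_directory: Path, metadata: dict[str, Any]) -> None:
        output_directory.mkdir(parents=True, exist_ok=True)
        # Assumed file names; the real backends may use different ones.
        with open(output_directory / "state.pkl", "wb") as handle:
            pickle.dump(state, handle)
        (output_directory / "metadata.json").write_text(json.dumps(metadata))

    def restore(self, template_state: Any, output_directory: Path) -> tuple[Any, dict[str, Any]]:
        # pickle rebuilds the whole object, so template_state is not needed here.
        with open(output_directory / "state.pkl", "rb") as handle:
            state = pickle.load(handle)
        metadata = json.loads((output_directory / "metadata.json").read_text())
        return state, metadata


with tempfile.TemporaryDirectory() as tmp:
    backend = PickleCheckpointBackend()
    target = Path(tmp) / "latest"
    backend.save({"step": 3, "params": [0.1, 0.2]}, target, {"epoch": 3, "val_loss": 0.42})
    restored_state, restored_metadata = backend.restore(None, target)
print(restored_state, restored_metadata)
```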

class petitRADTRANS.sbi.training.EquinoxCheckpointBackend

Bases: CheckpointBackend

Checkpoint backend backed by Equinox tree serialization.

name = 'equinox'
save(state: Any, output_directory: pathlib.Path, metadata: dict[str, Any]) → None

Serialize checkpoint state with dill and save JSON metadata.

restore(template_state: Any, output_directory: pathlib.Path) → tuple[Any, dict[str, Any]]

Load a dill-serialized checkpoint and its metadata.

class petitRADTRANS.sbi.training.OrbaxCheckpointBackend

Bases: CheckpointBackend

Checkpoint backend that uses Orbax when available.

name = 'orbax'
_ocp
_checkpointer
save(state: Any, output_directory: pathlib.Path, metadata: dict[str, Any]) → None

Persist checkpoint state with Orbax and write JSON metadata.

restore(template_state: Any, output_directory: pathlib.Path) → tuple[Any, dict[str, Any]]

Restore Orbax checkpoint state and metadata.

petitRADTRANS.sbi.training.resolve_checkpoint_backend(preferred: str = 'auto') → CheckpointBackend

Resolve the requested checkpoint backend with graceful fallback.

Parameters

preferred:

Backend name. Supported values are 'auto', 'equinox', and 'orbax'.

Returns

CheckpointBackend

Concrete checkpoint backend implementation.

Notes

'auto' prefers Orbax when installed and otherwise falls back to the dill-based Equinox backend.

petitRADTRANS.sbi.training.load_trainer_checkpoint_state(checkpoint_directory: str | pathlib.Path, checkpoint_kind: str = 'best') → tuple[Any, dict[str, Any]] | None

Load one persisted trainer checkpoint state and its metadata.

Parameters

checkpoint_directory:

Root directory containing trainer checkpoint subdirectories such as best_selection and best_loss.

checkpoint_kind:

Checkpoint subdirectory to restore.

Returns

tuple[Any, dict[str, Any]] | None

Restored state payload and metadata, or None when the requested checkpoint does not exist.

class petitRADTRANS.sbi.training.SBITrainer(config: TrainingConfig)

Reusable optimization loop for amortized SBI posteriors.

config
checkpoint_directory = None
checkpoint_backend
_checkpoint_backend_fallback_reason: str | None = None
property _latest_checkpoint_directory: pathlib.Path | None
property _best_checkpoint_directory: pathlib.Path | None
_checkpoint_directory_for_kind(checkpoint_kind: str) → pathlib.Path | None
_save_checkpoint(checkpoint_kind: str, state: dict[str, Any], metadata: dict[str, Any]) → None
_restore_checkpoint(template_state: dict[str, Any], checkpoint_kind: str) → tuple[dict[str, Any], dict[str, Any]] | None
fit(model: Any, dataset: Any, loss_fn: Callable[[Any, Any], Any], eval_loss_fn: Callable[[Any, Any], Any] | None = None, eval_diagnostic_fn: Callable[[Any, Any], Mapping[str, float]] | None = None, selection_metric_fn: Callable[[float, Mapping[str, float] | None], float] | None = None, selection_metric_name: str | None = None, stability_metric_fn: Callable[[float, Mapping[str, float] | None], Mapping[str, float]] | None = None, stability_flag_key: str = 'checkpoint_is_stable') → tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]

Optimize one posterior model against a dataset reader.

Parameters

model:

Trainable Equinox-style model to optimize.

dataset:

Reader object exposing iter_batches for train and optional validation splits.

loss_fn:

Callable receiving (model, batch) and returning the scalar loss to minimize.

eval_loss_fn:

Optional loss used for validation. Defaults to loss_fn when not provided.

Returns

tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]

Best model, train-loss history, validation-loss history, and a metadata dictionary describing checkpointing and stopping behavior.

Notes

The trainer monitors validation loss when a validation split exists and otherwise falls back to training loss for best-model selection. Checkpoints may be written after each epoch when enabled.

property _min_delta: float
_should_stop(patience_counter: int) → bool