petitRADTRANS.sbi.training#
Training and checkpoint orchestration for SBI posterior models.
Classes#
EarlyStoppingConfig | Early stopping policy for SBI training.
TrainingConfig | Configuration for posterior optimization.
CheckpointBackend | Persistence backend for trainer checkpoints.
EquinoxCheckpointBackend | Checkpoint backend backed by Equinox tree serialization.
OrbaxCheckpointBackend | Checkpoint backend that uses Orbax when available.
SBITrainer | Reusable optimization loop for amortized SBI posteriors.
Functions#
resolve_checkpoint_backend | Resolve the requested checkpoint backend with graceful fallback.
load_trainer_checkpoint_state | Load one persisted trainer checkpoint state and its metadata.
Module Contents#
- class petitRADTRANS.sbi.training.EarlyStoppingConfig#
Early stopping policy for SBI training.
- patience: int#
- min_delta: float = 0.0#
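The two fields above are commonly combined as follows: an epoch counts as an improvement only when the monitored loss drops by more than min_delta, and training stops once patience consecutive epochs pass without one. The sketch below illustrates that usual rule under stated assumptions; EarlyStoppingSketch and epochs_until_stop are hypothetical stand-ins, and the library's exact comparison may differ.

```python
from dataclasses import dataclass


@dataclass
class EarlyStoppingSketch:
    """Hypothetical stand-in mirroring the documented fields."""

    patience: int
    min_delta: float = 0.0


def epochs_until_stop(losses, policy):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    patience_counter = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best - policy.min_delta:
            # Improvement larger than min_delta: record it and reset the counter.
            best = loss
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= policy.patience:
            return epoch
    return None


policy = EarlyStoppingSketch(patience=2, min_delta=0.05)
stop_epoch = epochs_until_stop([1.0, 0.8, 0.79, 0.795, 0.80], policy)
never_stops = epochs_until_stop([3.0, 2.0, 1.0], policy)
```

With these inputs the third and fourth epochs fail to beat 0.8 by more than 0.05, so the counter reaches the patience of 2 at epoch 4.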
- class petitRADTRANS.sbi.training.TrainingConfig#
Configuration for posterior optimization.
- learning_rate: float = 0.001#
- batch_size: int = 32#
- num_epochs: int = 10#
- parameter_space: str = 'unconstrained'#
- seed: int = 0#
- shuffle_train: bool = True#
- early_stopping: EarlyStoppingConfig | None = None#
- checkpoint_directory: str | None = None#
- checkpoint_backend: str = 'auto'#
- resume_from_checkpoint: bool = False#
- data_parallel: bool = False#
- gradient_clip_norm: float | None = 1.0#
- weight_decay: float = 0.0#
- use_cosine_schedule: bool = False#
- warmup_fraction: float = 0.02#
- warmup_epochs: float | None = None#
- min_learning_rate: float = 1e-06#
- lr_schedule_total_epochs: float | None = None#
- verbose_diagnostics: bool = False#
- diagnostics_output_directory: str | None = None#
- diagnostics_plot_interval: int = 1#
- n_validation_diagnostic_batches: int = 4#
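The schedule-related fields (use_cosine_schedule, warmup_fraction, warmup_epochs, min_learning_rate, lr_schedule_total_epochs) suggest a linear warmup followed by cosine decay. This page does not state how the trainer combines them, so the following is only a sketch of one conventional reading: warmup_epochs overrides warmup_fraction when set, and the rate decays from learning_rate to min_learning_rate over the remaining epochs. The function learning_rate_at is hypothetical.

```python
import math


def learning_rate_at(
    epoch,
    *,
    learning_rate=1e-3,
    min_learning_rate=1e-6,
    total_epochs=10.0,
    warmup_fraction=0.02,
    warmup_epochs=None,
):
    """Linear warmup followed by cosine decay (assumed interpretation)."""
    # Assumption: an explicit warmup_epochs takes precedence over the fraction.
    warmup = warmup_epochs if warmup_epochs is not None else warmup_fraction * total_epochs
    if warmup > 0 and epoch < warmup:
        return learning_rate * epoch / warmup
    span = max(total_epochs - warmup, 1e-12)
    progress = min((epoch - warmup) / span, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_learning_rate + (learning_rate - min_learning_rate) * cosine


# Learning rate at the start of each of 10 epochs, with a 2-epoch warmup.
schedule = [learning_rate_at(float(e), warmup_epochs=2.0) for e in range(11)]
```

The rate starts at zero, reaches the peak learning_rate when warmup ends, and arrives at min_learning_rate at the final epoch.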
- class petitRADTRANS.sbi.training.CheckpointBackend#
Bases: abc.ABC
Persistence backend for trainer checkpoints.
- name: str#
- abstractmethod save(state: Any, output_directory: pathlib.Path, metadata: dict[str, Any]) -> None#
Persist checkpoint state and metadata.
Parameters#
- state:
Serializable optimizer/model state payload.
- output_directory:
Directory that should receive the checkpoint files.
- metadata:
Small JSON-serializable metadata dictionary written next to the checkpoint payload.
Returns#
None
- abstractmethod restore(template_state: Any, output_directory: pathlib.Path) -> tuple[Any, dict[str, Any]]#
Restore persisted checkpoint state into the supplied template.
Parameters#
- template_state:
Structure used by some backends to describe the expected state layout during restoration.
- output_directory:
Directory containing the serialized checkpoint payload.
Returns#
- tuple[Any, dict[str, Any]]
Restored state object together with the deserialized checkpoint metadata dictionary.
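To illustrate the save/restore contract, here is a minimal sketch of a custom backend. So that it runs without petitRADTRANS installed, it subclasses a local stand-in (CheckpointBackendSketch) that copies the documented method signatures; in real use one would presumably subclass CheckpointBackend instead. The class names and the file names state.pkl and metadata.json are assumptions made for this example, not the library's on-disk layout.

```python
import abc
import json
import pickle
import tempfile
from pathlib import Path
from typing import Any


class CheckpointBackendSketch(abc.ABC):
    """Local stand-in copying the documented abstract interface."""

    name: str

    @abc.abstractmethod
    def save(self, state: Any, output_directory: Path, metadata: dict[str, Any]) -> None: ...

    @abc.abstractmethod
    def restore(self, template_state: Any, output_directory: Path) -> tuple[Any, dict[str, Any]]: ...


class PickleCheckpointBackend(CheckpointBackendSketch):
    """Hypothetical backend: pickle payload plus JSON metadata next to it."""

    name = "pickle-sketch"

    def save(self, state, output_directory, metadata):
        output_directory.mkdir(parents=True, exist_ok=True)
        with open(output_directory / "state.pkl", "wb") as handle:
            pickle.dump(state, handle)
        (output_directory / "metadata.json").write_text(json.dumps(metadata))

    def restore(self, template_state, output_directory):
        # This backend does not need the template; others may use it
        # to describe the expected state layout.
        with open(output_directory / "state.pkl", "rb") as handle:
            state = pickle.load(handle)
        metadata = json.loads((output_directory / "metadata.json").read_text())
        return state, metadata


with tempfile.TemporaryDirectory() as tmp:
    backend = PickleCheckpointBackend()
    target = Path(tmp) / "latest"
    backend.save({"step": 3, "weights": [0.1, 0.2]}, target, {"epoch": 3})
    restored_state, restored_metadata = backend.restore(None, target)
```

Keeping the metadata in a separate JSON file matches the documented requirement that it be small and JSON-serializable, and lets it be inspected without deserializing the payload.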
- class petitRADTRANS.sbi.training.EquinoxCheckpointBackend#
Bases: CheckpointBackend
Checkpoint backend backed by Equinox tree serialization.
- name = 'equinox'#
- save(state: Any, output_directory: pathlib.Path, metadata: dict[str, Any]) -> None#
Serialize checkpoint state with dill and save JSON metadata.
- restore(template_state: Any, output_directory: pathlib.Path) -> tuple[Any, dict[str, Any]]#
Load a dill-serialized checkpoint and its metadata.
- class petitRADTRANS.sbi.training.OrbaxCheckpointBackend#
Bases: CheckpointBackend
Checkpoint backend that uses Orbax when available.
- name = 'orbax'#
- _ocp#
- _checkpointer#
- save(state: Any, output_directory: pathlib.Path, metadata: dict[str, Any]) -> None#
Persist checkpoint state with Orbax and write JSON metadata.
- restore(template_state: Any, output_directory: pathlib.Path) -> tuple[Any, dict[str, Any]]#
Restore Orbax checkpoint state and metadata.
- petitRADTRANS.sbi.training.resolve_checkpoint_backend(preferred: str = 'auto') -> CheckpointBackend#
Resolve the requested checkpoint backend with graceful fallback.
Parameters#
- preferred:
Backend name. Supported values are 'auto', 'equinox', and 'orbax'.
Returns#
- CheckpointBackend
Concrete checkpoint backend implementation.
Notes#
'auto' prefers Orbax when installed and otherwise falls back to the dill-based Equinox backend.
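The 'auto' rule in the note can be sketched as a name-resolution step. The helper below is hypothetical and returns only the backend name, not a backend object; it checks for Orbax with importlib.util.find_spec, and accepts an injectable is_installed callable so the rule can be exercised without Orbax present. What happens when 'orbax' is requested explicitly but Orbax is missing is not described on this page, so the sketch simply returns the requested name in that case.

```python
import importlib.util

SUPPORTED_BACKENDS = ("auto", "equinox", "orbax")


def resolve_backend_name(preferred="auto", is_installed=None):
    """Return the backend name that 'preferred' resolves to (sketch only)."""
    if is_installed is None:
        # Default probe: is a top-level module of this name importable?
        def is_installed(module_name):
            return importlib.util.find_spec(module_name) is not None

    if preferred not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unsupported checkpoint backend: {preferred!r}")
    if preferred == "auto":
        # Documented rule: prefer Orbax when installed, else Equinox.
        return "orbax" if is_installed("orbax") else "equinox"
    return preferred


without_orbax = resolve_backend_name("auto", is_installed=lambda name: False)
with_orbax = resolve_backend_name("auto", is_installed=lambda name: True)
```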
- petitRADTRANS.sbi.training.load_trainer_checkpoint_state(checkpoint_directory: str | pathlib.Path, checkpoint_kind: str = 'best') -> tuple[Any, dict[str, Any]] | None#
Load one persisted trainer checkpoint state and its metadata.
Parameters#
- checkpoint_directory:
Root directory containing trainer checkpoint subdirectories such as best_selection and best_loss.
- checkpoint_kind:
Checkpoint subdirectory to restore.
Returns#
- tuple[Any, dict[str, Any]] | None
Restored state payload and metadata, or None when the requested checkpoint does not exist.
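Because the function returns None rather than raising when a checkpoint is missing, callers need to check the result before unpacking it. The sketch below shows one way to try several checkpoint kinds in order. The helper first_available_checkpoint and the fake_loader used to exercise it are hypothetical; in real use load_trainer_checkpoint_state would be passed as load_fn. Only 'best' is documented as a kind here, so 'latest' is an assumed value.

```python
def first_available_checkpoint(load_fn, checkpoint_directory, kinds):
    """Try each checkpoint kind in order; return (kind, state, metadata) or None."""
    for kind in kinds:
        result = load_fn(checkpoint_directory, checkpoint_kind=kind)
        if result is not None:
            # Unpack only after confirming the checkpoint exists.
            state, metadata = result
            return kind, state, metadata
    return None


def fake_loader(checkpoint_directory, checkpoint_kind="best"):
    """Stand-in loader: pretends only a 'latest' checkpoint exists."""
    store = {"latest": ({"step": 7}, {"epoch": 7})}
    return store.get(checkpoint_kind)


found = first_available_checkpoint(fake_loader, "runs/demo", ["best", "latest"])
missing = first_available_checkpoint(fake_loader, "runs/demo", ["best"])
```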
- class petitRADTRANS.sbi.training.SBITrainer(config: TrainingConfig)#
Reusable optimization loop for amortized SBI posteriors.
- config#
- checkpoint_directory = None#
- checkpoint_backend#
- _checkpoint_backend_fallback_reason: str | None = None#
- property _latest_checkpoint_directory: pathlib.Path | None#
- property _best_checkpoint_directory: pathlib.Path | None#
- _checkpoint_directory_for_kind(checkpoint_kind: str) -> pathlib.Path | None#
- _save_checkpoint(checkpoint_kind: str, state: dict[str, Any], metadata: dict[str, Any]) -> None#
- _restore_checkpoint(template_state: dict[str, Any], checkpoint_kind: str) -> tuple[dict[str, Any], dict[str, Any]] | None#
- fit(model: Any, dataset: Any, loss_fn: Callable[[Any, Any], Any], eval_loss_fn: Callable[[Any, Any], Any] | None = None, eval_diagnostic_fn: Callable[[Any, Any], Mapping[str, float]] | None = None, selection_metric_fn: Callable[[float, Mapping[str, float] | None], float] | None = None, selection_metric_name: str | None = None, stability_metric_fn: Callable[[float, Mapping[str, float] | None], Mapping[str, float]] | None = None, stability_flag_key: str = 'checkpoint_is_stable') -> tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]#
Optimize one posterior model against a dataset reader.
Parameters#
- model:
Trainable Equinox-style model to optimize.
- dataset:
Reader object exposing iter_batches for train and optional validation splits.
- loss_fn:
Callable receiving (model, batch) and returning the scalar loss to minimize.
- eval_loss_fn:
Optional loss used for validation. Defaults to loss_fn when not provided.
Returns#
- tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]
Best model, train-loss history, validation-loss history, and a metadata dictionary describing checkpointing and stopping behavior.
Notes#
The trainer monitors validation loss when a validation split exists and otherwise falls back to training loss for best-model selection. Checkpoints may be written after each epoch when enabled.
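The selection rule in the note can be illustrated with a small standalone function: use the validation history when it is non-empty, otherwise the training history, and pick the epoch with the lowest monitored loss. The function select_best_epoch is hypothetical and ignores the optional selection_metric_fn and stability_metric_fn hooks, whose behavior is not described on this page.

```python
def select_best_epoch(train_losses, validation_losses=None):
    """Return (monitored split, 0-based best epoch, best loss)."""
    if validation_losses:
        monitored_name, monitored = "validation", validation_losses
    else:
        # No validation split: fall back to the training loss.
        monitored_name, monitored = "train", train_losses
    best_epoch = min(range(len(monitored)), key=lambda index: monitored[index])
    return monitored_name, best_epoch, monitored[best_epoch]


with_validation = select_best_epoch([0.9, 0.5, 0.3], [1.0, 0.6, 0.7])
train_only = select_best_epoch([0.9, 0.5, 0.3])
```

In the first call the training loss keeps falling, but the validation loss bottoms out at the second epoch, so that epoch is the one selected.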
- property _min_delta: float#
- _should_stop(patience_counter: int) -> bool#