petitRADTRANS.sbi.training
==========================

.. py:module:: petitRADTRANS.sbi.training

.. autoapi-nested-parse::

   Training and checkpoint orchestration for SBI posterior models.



Classes
-------

.. autoapisummary::

   petitRADTRANS.sbi.training.EarlyStoppingConfig
   petitRADTRANS.sbi.training.TrainingConfig
   petitRADTRANS.sbi.training.CheckpointBackend
   petitRADTRANS.sbi.training.EquinoxCheckpointBackend
   petitRADTRANS.sbi.training.OrbaxCheckpointBackend
   petitRADTRANS.sbi.training.SBITrainer


Functions
---------

.. autoapisummary::

   petitRADTRANS.sbi.training.resolve_checkpoint_backend
   petitRADTRANS.sbi.training.load_trainer_checkpoint_state


Module Contents
---------------

.. py:class:: EarlyStoppingConfig

   Early stopping policy for SBI training.


   .. py:attribute:: patience
      :type:  int


   .. py:attribute:: min_delta
      :type:  float
      :value: 0.0
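
   The snippet below is a minimal, self-contained sketch of how a
   ``patience``/``min_delta`` policy of this kind is commonly applied. It is an
   illustration under stated assumptions, not the trainer's implementation. The
   ``EarlyStoppingSketch`` class is hypothetical. It assumes that an epoch counts
   as an improvement only when the monitored loss falls more than ``min_delta``
   below the best value seen so far, and that training stops once ``patience``
   consecutive epochs fail to improve.

   .. code-block:: python

      from dataclasses import dataclass
      import math


      @dataclass
      class EarlyStoppingSketch:
          # Hypothetical helper; mirrors the two EarlyStoppingConfig fields.
          patience: int
          min_delta: float = 0.0
          best_loss: float = math.inf
          patience_counter: int = 0

          def update(self, loss: float) -> bool:
              """Record one epoch's monitored loss; return True if training should stop."""
              if loss < self.best_loss - self.min_delta:
                  # Assumed improvement rule: better by strictly more than min_delta.
                  self.best_loss = loss
                  self.patience_counter = 0
              else:
                  self.patience_counter += 1
              return self.patience_counter >= self.patience


      stopper = EarlyStoppingSketch(patience=2, min_delta=0.01)
      # 0.795 is lower than 0.8, but not by more than min_delta, so it does not count.
      decisions = [stopper.update(loss) for loss in [1.0, 0.8, 0.795, 0.797]]
      # decisions == [False, False, False, True]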



.. py:class:: TrainingConfig

   Configuration for posterior optimization.


   .. py:attribute:: learning_rate
      :type:  float
      :value: 0.001



   .. py:attribute:: batch_size
      :type:  int
      :value: 32



   .. py:attribute:: num_epochs
      :type:  int
      :value: 10



   .. py:attribute:: parameter_space
      :type:  str
      :value: 'unconstrained'



   .. py:attribute:: seed
      :type:  int
      :value: 0



   .. py:attribute:: shuffle_train
      :type:  bool
      :value: True



   .. py:attribute:: early_stopping
      :type:  EarlyStoppingConfig | None
      :value: None



   .. py:attribute:: checkpoint_directory
      :type:  str | None
      :value: None



   .. py:attribute:: checkpoint_backend
      :type:  str
      :value: 'auto'



   .. py:attribute:: resume_from_checkpoint
      :type:  bool
      :value: False



   .. py:attribute:: data_parallel
      :type:  bool
      :value: False



   .. py:attribute:: gradient_clip_norm
      :type:  float | None
      :value: 1.0
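
   A common reading of a gradient-clipping threshold such as
   ``gradient_clip_norm``, used for example by Optax's ``clip_by_global_norm``,
   is to rescale all gradient leaves jointly so that their global L2 norm does
   not exceed the threshold. The pure-Python sketch below assumes that
   interpretation and treats ``None`` as "no clipping"; whether the trainer
   applies exactly this rule is an assumption, and the
   ``clip_by_global_norm_sketch`` helper is hypothetical.

   .. code-block:: python

      import math


      def clip_by_global_norm_sketch(grads, max_norm):
          """Scale a list of gradient vectors so their joint L2 norm is at most max_norm.

          Hypothetical helper: gradients are plain lists of floats here, whereas a
          JAX trainer would operate on a pytree of arrays.
          """
          if max_norm is None:
              # Assumed meaning of gradient_clip_norm=None: leave gradients unchanged.
              return grads
          global_norm = math.sqrt(sum(g * g for leaf in grads for g in leaf))
          if global_norm <= max_norm:
              return grads
          scale = max_norm / global_norm
          return [[g * scale for g in leaf] for leaf in grads]


      clipped = clip_by_global_norm_sketch([[3.0], [4.0]], max_norm=1.0)
      # The global norm is 5.0, so every entry is scaled by 0.2 (about [[0.6], [0.8]]).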



   .. py:attribute:: weight_decay
      :type:  float
      :value: 0.0



   .. py:attribute:: use_cosine_schedule
      :type:  bool
      :value: False



   .. py:attribute:: warmup_fraction
      :type:  float
      :value: 0.02



   .. py:attribute:: warmup_epochs
      :type:  float | None
      :value: None



   .. py:attribute:: min_learning_rate
      :type:  float
      :value: 1e-06



   .. py:attribute:: lr_schedule_total_epochs
      :type:  float | None
      :value: None
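
   The schedule-related fields can be read as a linear warmup followed by a
   cosine decay, similar in spirit to Optax's ``warmup_cosine_decay_schedule``.
   The sketch below models only the ``use_cosine_schedule=True`` case and rests
   on assumptions that are not confirmed here: the schedule is evaluated at
   fractional epochs, ``lr_schedule_total_epochs`` (when set) overrides
   ``num_epochs`` as the schedule length, and ``warmup_epochs`` (when set)
   overrides ``warmup_fraction``. The ``learning_rate_at`` function is
   hypothetical; its defaults copy the documented field defaults.

   .. code-block:: python

      from __future__ import annotations

      import math


      def learning_rate_at(
          epoch: float,
          learning_rate: float = 1e-3,
          num_epochs: int = 10,
          warmup_fraction: float = 0.02,
          warmup_epochs: float | None = None,
          min_learning_rate: float = 1e-6,
          lr_schedule_total_epochs: float | None = None,
      ) -> float:
          """Hypothetical warmup + cosine schedule evaluated at a (fractional) epoch."""
          # Assumption: an explicit schedule length overrides num_epochs.
          total = lr_schedule_total_epochs if lr_schedule_total_epochs is not None else float(num_epochs)
          # Assumption: an explicit warmup length overrides warmup_fraction.
          warmup = warmup_epochs if warmup_epochs is not None else warmup_fraction * total
          if warmup > 0 and epoch < warmup:
              # Linear ramp from 0 up to the peak learning rate.
              return learning_rate * epoch / warmup
          # Cosine decay from learning_rate down to min_learning_rate after warmup.
          decay_span = max(total - warmup, 1e-12)
          progress = min(max((epoch - warmup) / decay_span, 0.0), 1.0)
          cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
          return min_learning_rate + (learning_rate - min_learning_rate) * cosine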



   .. py:attribute:: verbose_diagnostics
      :type:  bool
      :value: False



   .. py:attribute:: diagnostics_output_directory
      :type:  str | None
      :value: None



   .. py:attribute:: diagnostics_plot_interval
      :type:  int
      :value: 1



   .. py:attribute:: n_validation_diagnostic_batches
      :type:  int
      :value: 4



.. py:class:: CheckpointBackend

   Bases: :py:obj:`abc.ABC`


   Persistence backend for trainer checkpoints.


   .. py:attribute:: name
      :type:  str


   .. py:method:: save(state: Any, output_directory: pathlib.Path, metadata: dict[str, Any]) -> None
      :abstractmethod:


      Persist checkpoint state and metadata.

      Parameters
      ----------
      state:
          Serializable optimizer/model state payload.
      output_directory:
          Directory that should receive the checkpoint files.
      metadata:
          Small JSON-serializable metadata dictionary written next to the
          checkpoint payload.

      Returns
      -------
      None



   .. py:method:: restore(template_state: Any, output_directory: pathlib.Path) -> tuple[Any, dict[str, Any]]
      :abstractmethod:


      Restore persisted checkpoint state into the supplied template.

      Parameters
      ----------
      template_state:
          Structure used by some backends to describe the expected state
          layout during restoration.
      output_directory:
          Directory containing the serialized checkpoint payload.

      Returns
      -------
      tuple[Any, dict[str, Any]]
          Restored state object together with the deserialized checkpoint
          metadata dictionary.



.. py:class:: EquinoxCheckpointBackend

   Bases: :py:obj:`CheckpointBackend`


   Checkpoint backend for Equinox model state that serializes checkpoints with dill.


   .. py:attribute:: name
      :value: 'equinox'



   .. py:method:: save(state: Any, output_directory: pathlib.Path, metadata: dict[str, Any]) -> None

      Serialize checkpoint state with dill and save JSON metadata.



   .. py:method:: restore(template_state: Any, output_directory: pathlib.Path) -> tuple[Any, dict[str, Any]]

      Load a dill-serialized checkpoint and its metadata.



.. py:class:: OrbaxCheckpointBackend

   Bases: :py:obj:`CheckpointBackend`


   Checkpoint backend that uses Orbax when available.


   .. py:attribute:: name
      :value: 'orbax'



   .. py:attribute:: _ocp


   .. py:attribute:: _checkpointer


   .. py:method:: save(state: Any, output_directory: pathlib.Path, metadata: dict[str, Any]) -> None

      Persist checkpoint state with Orbax and write JSON metadata.



   .. py:method:: restore(template_state: Any, output_directory: pathlib.Path) -> tuple[Any, dict[str, Any]]

      Restore Orbax checkpoint state and metadata.



.. py:function:: resolve_checkpoint_backend(preferred: str = 'auto') -> CheckpointBackend

   Resolve the requested checkpoint backend with graceful fallback.

   Parameters
   ----------
   preferred:
       Backend name. Supported values are ``'auto'``, ``'equinox'``, and
       ``'orbax'``.

   Returns
   -------
   CheckpointBackend
       Concrete checkpoint backend implementation.

   Notes
   -----
   ``'auto'`` prefers Orbax when installed and otherwise falls back to the
   dill-based Equinox backend.
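
   The ``'auto'`` rule above can be sketched as follows. The helper
   ``resolve_backend_name_sketch`` is hypothetical and returns only a backend
   name and an optional fallback reason instead of a backend instance. It also
   assumes, beyond what is documented, that an explicit ``'orbax'`` request
   falls back to Equinox when Orbax is missing and that unknown names raise
   ``ValueError``. Orbax availability is probed with
   :func:`importlib.util.find_spec` on ``orbax.checkpoint``.

   .. code-block:: python

      import importlib.util

      SUPPORTED_BACKENDS = ("auto", "equinox", "orbax")


      def _orbax_installed() -> bool:
          try:
              return importlib.util.find_spec("orbax.checkpoint") is not None
          except ModuleNotFoundError:
              # find_spec imports the parent package "orbax"; if it is absent, Orbax is not installed.
              return False


      def resolve_backend_name_sketch(preferred="auto"):
          """Return (backend_name, fallback_reason) under the assumptions stated above."""
          if preferred not in SUPPORTED_BACKENDS:
              raise ValueError(f"Unsupported checkpoint backend: {preferred!r}")
          if preferred == "equinox":
              return "equinox", None
          if _orbax_installed():
              return "orbax", None
          # 'auto' falling back to Equinox is the documented default path, so no reason is recorded.
          reason = None if preferred == "auto" else "orbax is not installed"
          return "equinox", reason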


.. py:function:: load_trainer_checkpoint_state(checkpoint_directory: str | pathlib.Path, checkpoint_kind: str = 'best') -> tuple[Any, dict[str, Any]] | None

   Load one persisted trainer checkpoint state and its metadata.

   Parameters
   ----------
   checkpoint_directory:
       Root directory containing trainer checkpoint subdirectories such as
       ``best_selection`` and ``best_loss``.
   checkpoint_kind:
       Checkpoint subdirectory to restore.

   Returns
   -------
   tuple[Any, dict[str, Any]] | None
       Restored state payload and metadata, or ``None`` when the requested
       checkpoint does not exist.
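
   The "``None`` when missing" contract can be illustrated with the sketch
   below. It assumes, hypothetically, that ``checkpoint_kind`` names a
   subdirectory directly under ``checkpoint_directory``; how the library maps a
   kind such as the default ``'best'`` onto subdirectories like
   ``best_selection`` or ``best_loss`` is not specified here. The
   ``checkpoint_path_or_none`` helper is hypothetical and only locates the
   directory; it does not deserialize anything.

   .. code-block:: python

      import pathlib
      import tempfile


      def checkpoint_path_or_none(checkpoint_directory, checkpoint_kind="best"):
          """Return the checkpoint subdirectory for a kind, or None if it is absent."""
          # Assumption: each checkpoint kind is stored in a subdirectory of the same name.
          candidate = pathlib.Path(checkpoint_directory) / checkpoint_kind
          return candidate if candidate.is_dir() else None


      with tempfile.TemporaryDirectory() as root:
          (pathlib.Path(root) / "best_loss").mkdir()
          found = checkpoint_path_or_none(root, "best_loss")
          missing = checkpoint_path_or_none(root, "best_selection")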


.. py:class:: SBITrainer(config: TrainingConfig)

   Reusable optimization loop for amortized SBI posteriors.


   .. py:attribute:: config


   .. py:attribute:: checkpoint_directory
      :value: None



   .. py:attribute:: checkpoint_backend


   .. py:attribute:: _checkpoint_backend_fallback_reason
      :type:  str | None
      :value: None



   .. py:property:: _latest_checkpoint_directory
      :type: pathlib.Path | None



   .. py:property:: _best_checkpoint_directory
      :type: pathlib.Path | None



   .. py:method:: _checkpoint_directory_for_kind(checkpoint_kind: str) -> pathlib.Path | None


   .. py:method:: _save_checkpoint(checkpoint_kind: str, state: dict[str, Any], metadata: dict[str, Any]) -> None


   .. py:method:: _restore_checkpoint(template_state: dict[str, Any], checkpoint_kind: str) -> tuple[dict[str, Any], dict[str, Any]] | None


   .. py:method:: fit(model: Any, dataset: Any, loss_fn: Callable[[Any, Any], Any], eval_loss_fn: Callable[[Any, Any], Any] | None = None, eval_diagnostic_fn: Callable[[Any, Any], Mapping[str, float]] | None = None, selection_metric_fn: Callable[[float, Mapping[str, float] | None], float] | None = None, selection_metric_name: str | None = None, stability_metric_fn: Callable[[float, Mapping[str, float] | None], Mapping[str, float]] | None = None, stability_flag_key: str = 'checkpoint_is_stable') -> tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]

      Optimize one posterior model against a dataset reader.

      Parameters
      ----------
      model:
          Trainable Equinox-style model to optimize.
      dataset:
          Reader object exposing ``iter_batches`` for train and optional
          validation splits.
      loss_fn:
          Callable receiving ``(model, batch)`` and returning the scalar loss
          to minimize.
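      eval_loss_fn:
          Optional loss used for validation. Defaults to ``loss_fn`` when not
          provided.
      eval_diagnostic_fn:
          Optional callable receiving ``(model, batch)`` and returning a
          mapping of named scalar diagnostics.
      selection_metric_fn:
          Optional callable receiving a scalar loss and the diagnostics mapping
          (or ``None``) and returning a scalar selection metric.
      selection_metric_name:
          Optional name for the metric returned by ``selection_metric_fn``.
      stability_metric_fn:
          Optional callable with the same inputs as ``selection_metric_fn``
          that returns a mapping of named scalar stability values.
      stability_flag_key:
          Key in the ``stability_metric_fn`` output that holds the checkpoint
          stability flag. Defaults to ``'checkpoint_is_stable'``.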

      Returns
      -------
      tuple[Any, dict[str, list[float]], dict[str, list[float]], dict[str, Any]]
          Best model, train-loss history, validation-loss history, and a
          metadata dictionary describing checkpointing and stopping behavior.

      Notes
      -----
      The trainer monitors validation loss when a validation split exists and
      otherwise falls back to training loss for best-model selection.
      Checkpoints may be written after each epoch when enabled.
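
      The optional callbacks of ``fit`` are plain callables whose shapes follow
      from the signature above. The functions below are hypothetical examples
      of those shapes only. The diagnostic name ``coverage_error``, the
      weighting in the selection metric, the assumption that lower selection
      values are better, the stability threshold, and the assumption that the
      first ``float`` argument is the monitored loss are all illustrative
      choices, not library behavior. Stand-in batch dictionaries keep the
      snippet runnable without JAX.

      .. code-block:: python

         from typing import Any, Mapping, Optional


         def eval_diagnostic_fn(model: Any, batch: Any) -> Mapping[str, float]:
             # (model, batch) -> named scalar diagnostics; a toy stand-in here.
             return {"coverage_error": abs(batch["coverage"] - 0.68)}


         def selection_metric_fn(loss: float, diagnostics: Optional[Mapping[str, float]]) -> float:
             # (loss, diagnostics or None) -> scalar selection metric (lower assumed better).
             if diagnostics is None:
                 return loss
             return loss + 10.0 * diagnostics["coverage_error"]


         def stability_metric_fn(loss: float, diagnostics: Optional[Mapping[str, float]]) -> Mapping[str, float]:
             # Returns named values; the entry under stability_flag_key
             # ("checkpoint_is_stable" by default) is assumed to carry the flag.
             is_stable = diagnostics is not None and diagnostics["coverage_error"] < 0.05
             return {"checkpoint_is_stable": float(is_stable)}


         diagnostics = eval_diagnostic_fn(None, {"coverage": 0.70})
         score = selection_metric_fn(1.5, diagnostics)
         stability = stability_metric_fn(1.5, diagnostics)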



   .. py:property:: _min_delta
      :type: float



   .. py:method:: _should_stop(patience_counter: int) -> bool


