petitRADTRANS.retrieval.optimal_estimation
==========================================

.. py:module:: petitRADTRANS.retrieval.optimal_estimation

.. autoapi-nested-parse::

   Gradient-based optimisation ("optimal estimation") for petitRADTRANS retrievals.

   This module provides an optax-powered maximum-a-posteriori (MAP) optimiser that
   operates on the unconstrained-space log-posterior built by
   ``petitRADTRANS.retrieval.sampler._build_unconstrained_mcmc_interface``. It is
   used in two ways:

   * As a stand-alone "sampler" selectable with ``sampler_type="optimal_estimation"``
     (aliases ``"optax"`` and ``"map"``): :class:`OptimalEstimationSampler` runs the
     optimiser and, by default, draws a Laplace (Rodgers optimal estimation)
     posterior from the inverse Hessian at the optimum. It therefore returns an
     approximate (Gaussian) posterior sample and 1-sigma error bars rather than a
     bare point estimate.

   * As an optional warm-up step for the NumPyro and BlackJAX samplers
     (``optimize_init=True``): the same optimiser finds a MAP point that replaces the
     prior-median ``initial_position``, so the gradient sampler starts inside the
     high-probability region instead of in a stiff, divergence-prone tail.

   The optimiser defaults to Adam with global-norm gradient clipping, but
   any optax optimiser can be supplied.



Attributes
----------

.. autoapisummary::

   petitRADTRANS.retrieval.optimal_estimation.jax
   petitRADTRANS.retrieval.optimal_estimation.jnp
   petitRADTRANS.retrieval.optimal_estimation.optax


Classes
-------

.. autoapisummary::

   petitRADTRANS.retrieval.optimal_estimation.OptimizeResult
   petitRADTRANS.retrieval.optimal_estimation.OptimalEstimationResults
   petitRADTRANS.retrieval.optimal_estimation.OptimalEstimationSampler


Functions
---------

.. autoapisummary::

   petitRADTRANS.retrieval.optimal_estimation._require_jax_local
   petitRADTRANS.retrieval.optimal_estimation._require_optax
   petitRADTRANS.retrieval.optimal_estimation.build_optimizer
   petitRADTRANS.retrieval.optimal_estimation.optimize_unconstrained_position
   petitRADTRANS.retrieval.optimal_estimation._lbfgs_polish
   petitRADTRANS.retrieval.optimal_estimation.make_walker_inits
   petitRADTRANS.retrieval.optimal_estimation.report_and_save_map_warmup
   petitRADTRANS.retrieval.optimal_estimation.laplace_covariance
   petitRADTRANS.retrieval.optimal_estimation._multivariate_normal_logpdf
   petitRADTRANS.retrieval.optimal_estimation._sample_gaussian


Module Contents
---------------

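.. py:data:: jax
   :value: None

   Lazily imported ``jax`` module handle. It is ``None`` at import time and is
   populated by :func:`_require_jax_local`.


.. py:data:: jnp
   :value: None

   Lazily imported ``jax.numpy`` handle. It is ``None`` at import time and is
   populated by :func:`_require_jax_local`.


.. py:data:: optax
   :value: None

   Lazily imported ``optax`` module handle. It is ``None`` at import time and is
   populated by :func:`_require_optax`.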


.. py:function:: _require_jax_local()

   Populate the module-level ``jax`` / ``jnp`` handles.


.. py:function:: _require_optax()

   Populate the module-level ``optax`` handle.


.. py:function:: build_optimizer(optimizer: Any = None, learning_rate: Any = 0.01, clip_norm: Optional[float] = 1.0)

   Return an ``optax.GradientTransformation``.

   Args:
       optimizer: selects the optax optimiser. May be one of:
           * ``None`` -> ``optax.adam`` (the default);
           * a ``str`` -> ``getattr(optax, name)(learning_rate)``, e.g. ``"adamw"``,
             ``"adabelief"``, ``"sgd"``;
           * an already-built ``optax.GradientTransformation`` (used verbatim);
           * a callable factory -> ``optimizer(learning_rate)``.
       learning_rate: scalar or optax schedule, passed to the factory. Ignored
           when ``optimizer`` is a pre-built ``GradientTransformation``.
       clip_norm: if not ``None``, prepend ``optax.clip_by_global_norm(clip_norm)``
           so that the very large gradients typical near the prior median still
           produce bounded steps.

   Returns:
       An ``optax.GradientTransformation``.

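To make the effect of ``clip_norm`` concrete, the NumPy sketch below reproduces, for illustration only, the arithmetic of a global-norm clip followed by one Adam update. It is not the optax implementation: ``clip_by_global_norm`` and ``adam_step`` here are hypothetical helpers, and the Adam hyper-parameters are the usual textbook defaults (assumed, not read from petitRADTRANS).

```python
import numpy as np


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    # Gradients already inside the bound are returned unchanged.
    scale = min(1.0, max_norm / global_norm) if global_norm > 0 else 1.0
    return [g * scale for g in grads]


def adam_step(params, grad, m, v, t, lr=0.01, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update (t counts from 1) that descends along ``grad``."""
    m = b1 * m + (1 - b1) * grad          # first-moment estimate
    v = b2 * v + (1 - b2) * grad ** 2     # second-moment estimate
    m_hat = m / (1 - b1 ** t)             # bias corrections
    v_hat = v / (1 - b2 ** t)
    return params - lr * m_hat / (np.sqrt(v_hat) + eps), m, v


# Clip first, then take the optimiser step (the order optax.chain would apply).
raw_grad = np.array([3e4, 4e4])
clipped = clip_by_global_norm([raw_grad], max_norm=1.0)[0]
params, m, v = adam_step(np.zeros(2), clipped, np.zeros(2), np.zeros(2), t=1)
```

Here the gradient ``[3e4, 4e4]`` (global norm ``5e4``) is rescaled to ``[0.6, 0.8]``, while a gradient already below the bound is left untouched. Because Adam divides by a running RMS, the first update moves each coordinate by roughly ``learning_rate`` regardless of the raw gradient scale.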

.. py:class:: OptimizeResult

   Outcome of :func:`optimize_unconstrained_position`.


   .. py:attribute:: best_position
      :type:  Any


   .. py:attribute:: best_logdensity
      :type:  float


   .. py:attribute:: trace
      :type:  numpy.ndarray


   .. py:attribute:: n_steps_run
      :type:  int


   .. py:attribute:: converged
      :type:  bool


   .. py:attribute:: n_starts
      :type:  int
      :value: 1



   .. py:attribute:: start_logdensities
      :type:  Optional[numpy.ndarray]
      :value: None



.. py:function:: optimize_unconstrained_position(logdensity: Callable, initial_position: Any, *, optimizer: Any = None, learning_rate: Any = 0.01, clip_norm: Optional[float] = 1.0, n_steps: int = 300, tol: Optional[float] = None, patience: Optional[int] = 60, n_starts: int = 1, start_jitter: float = 1.0, seed: int = 0, lbfgs_polish_steps: int = 0, verbose: bool = True) -> OptimizeResult

   Maximise an unconstrained-space log-posterior with optax.

   The optimisation minimises ``-logdensity`` in the unconstrained space, so
   parameter bounds are respected automatically through the change of variables
   that ``logdensity`` already includes. Non-finite gradient components are
   zeroed so that a single bad coordinate cannot poison a step, and the best
   finite point seen is returned.

   Args:
       logdensity: callable mapping an unconstrained position (shape
           ``(n_params,)``) to a scalar log-posterior. Must be JAX-differentiable.
       initial_position: starting unconstrained position, shape ``(n_params,)``.
       optimizer / learning_rate / clip_norm: see :func:`build_optimizer`.
       n_steps: maximum optimisation steps per start.
       tol: optional tolerance on the infinity norm of the gradient; when set,
           a start stops early once that norm falls below ``tol``.
       patience: stop a start after this many steps without improvement
           (``None`` disables).
       n_starts: number of random restarts (start 0 is ``initial_position``;
           the rest are ``initial_position`` plus Gaussian jitter). The best
           optimum across starts is returned.
       start_jitter: standard deviation of the Gaussian jitter for extra starts.
       seed: RNG seed for the restart jitter.
       lbfgs_polish_steps: if > 0, refine the best optax optimum with this many
           ``optax.lbfgs`` (zoom line-search) iterations.
       verbose: print progress.

   Returns:
       An :class:`OptimizeResult`.

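The control flow described above (multiple starts, zeroed non-finite gradient components, patience-based stopping, tracking of the best finite point) can be sketched in plain NumPy. This is a simplified, hypothetical stand-in, not the module's code: it takes an explicit gradient function instead of using JAX autodiff, uses plain clipped gradient descent in place of Adam, and omits ``tol`` and the L-BFGS polish.

```python
import numpy as np


def optimize_sketch(logdensity, grad_logdensity, initial_position, *,
                    learning_rate=0.1, clip_norm=1.0, n_steps=300, patience=60,
                    n_starts=1, start_jitter=1.0, seed=0):
    """Multi-start maximisation of ``logdensity``; returns (best_position, best_value)."""
    rng = np.random.default_rng(seed)
    x0 = np.asarray(initial_position, dtype=float)
    best_x, best_lp = x0.copy(), -np.inf
    for start in range(n_starts):
        # Start 0 is the supplied position; later starts add Gaussian jitter.
        if start == 0:
            x = x0.copy()
        else:
            x = x0 + start_jitter * rng.standard_normal(x0.shape)
        start_best, stale = -np.inf, 0
        for _ in range(n_steps):
            lp = logdensity(x)
            # Keep the best *finite* point seen across all starts.
            if np.isfinite(lp) and lp > best_lp:
                best_x, best_lp = x.copy(), lp
            # Patience counts steps without improvement within this start.
            if np.isfinite(lp) and lp > start_best:
                start_best, stale = lp, 0
            else:
                stale += 1
                if patience is not None and stale >= patience:
                    break
            # Descend -logdensity; zero non-finite components of the gradient.
            g = -np.asarray(grad_logdensity(x), dtype=float)
            g = np.where(np.isfinite(g), g, 0.0)
            norm = np.linalg.norm(g)
            if clip_norm is not None and norm > clip_norm:
                g = g * (clip_norm / norm)
            x = x - learning_rate * g
    return best_x, best_lp


# Usage: a 2-D Gaussian log-posterior centred on (1, -2).
mu = np.array([1.0, -2.0])
best_x, best_lp = optimize_sketch(
    lambda x: -0.5 * np.sum((x - mu) ** 2),
    lambda x: -(x - mu),
    [5.0, 5.0],
    n_starts=3,
)
```

Clipping keeps the early steps at a fixed length far from the optimum; once the gradient norm drops below ``clip_norm`` the iteration contracts geometrically towards the maximum.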

.. py:function:: _lbfgs_polish(logdensity, position, n_steps, verbose=True)

   Refine ``position`` by minimising ``-logdensity`` with ``optax.lbfgs`` (zoom
   line search) for ``n_steps`` iterations.


.. py:function:: make_walker_inits(position, num_chains, *, jitter=0.0, seed=0)

   Broadcast a single MAP point into per-chain initial positions.

   Returns a ``(n_params,)`` array for a single chain, or ``(num_chains, n_params)``
   for several chains (the first chain stays exactly at the optimum; the rest get
   Gaussian jitter of width ``jitter`` so multi-chain convergence diagnostics
   such as R-hat remain meaningful).

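A minimal NumPy sketch of this broadcasting rule follows. It assumes standard-normal jitter drawn from ``numpy.random.default_rng(seed)``; the RNG actually used by petitRADTRANS is not specified here, and ``make_walker_inits_sketch`` is a hypothetical name.

```python
import numpy as np


def make_walker_inits_sketch(position, num_chains, *, jitter=0.0, seed=0):
    """Broadcast one MAP point to per-chain starts; chain 0 stays exactly at the MAP."""
    position = np.asarray(position, dtype=float)
    if num_chains <= 1:
        return position.copy()                      # shape (n_params,)
    rng = np.random.default_rng(seed)
    inits = np.tile(position, (num_chains, 1))      # shape (num_chains, n_params)
    # Jitter every chain except the first so R-hat has spread to work with.
    inits[1:] += jitter * rng.standard_normal((num_chains - 1, position.size))
    return inits


map_point = np.array([0.5, -1.0, 2.0])
inits = make_walker_inits_sketch(map_point, 4, jitter=0.1, seed=1)
```

With the default ``jitter=0.0`` every chain starts at the optimum, which makes between-chain diagnostics uninformative early on; a small positive jitter avoids that.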

.. py:function:: report_and_save_map_warmup(optimize_result: OptimizeResult, transform_positions: Callable, parameter_names, output_directory: str, retrieval_name: str, *, output_subdirectory: str, label: str = 'optimal_estimation', verbose: bool = True)

   Print and persist the MAP point found by an ``optimize_init`` warm-up.

   Used by the NumPyro and BlackJAX samplers when ``optimize_init=True``: once the
   optax optimiser has located the MAP, this transforms the best unconstrained
   position into physical (constrained) parameter space, prints each parameter's
   MAP value, and writes the MAP sample to
   ``<output_directory>/<output_subdirectory>/<retrieval_name>_map_warmup.{npz,json}``.

   Args:
       optimize_result: the :class:`OptimizeResult` returned by
           :func:`optimize_unconstrained_position`.
       transform_positions: maps an unconstrained position to physical space.
       parameter_names: ordered free-parameter names matching the position.
       output_directory: retrieval output directory.
       retrieval_name: retrieval name used as the file-name stem.
       output_subdirectory: sampler-specific sub-directory (e.g. ``"out_BlackJAX"``
           or ``"out_NumPyro"``).
       label: short tag used in the printed lines.
       verbose: print the MAP values and the saved-file path.

   Returns:
       The path to the saved ``.npz`` file.

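The on-disk layout can be illustrated with the sketch below. The directory and file-name pattern follow the description above; the array keys (``samples``, ``parameter_names``), the JSON layout and the parameter names in the usage lines are hypothetical, so inspect the real files rather than relying on these names.

```python
import json
import os
import tempfile

import numpy as np


def save_map_sketch(map_values, parameter_names, output_directory,
                    retrieval_name, output_subdirectory):
    """Write one MAP point to <out>/<subdir>/<name>_map_warmup.{npz,json}."""
    directory = os.path.join(output_directory, output_subdirectory)
    os.makedirs(directory, exist_ok=True)
    stem = os.path.join(directory, f"{retrieval_name}_map_warmup")
    values = np.asarray(map_values, dtype=float)
    # Hypothetical layout: a single-row sample array plus a name -> value JSON map.
    np.savez(stem + ".npz", samples=values[None, :],
             parameter_names=np.array(parameter_names))
    with open(stem + ".json", "w") as handle:
        json.dump(dict(zip(parameter_names, values.tolist())), handle, indent=2)
    return stem + ".npz"


# Usage in a throw-away directory (illustrative parameter names).
with tempfile.TemporaryDirectory() as tmp:
    npz_path = save_map_sketch([1.5, -0.3], ["T_int", "log_g"], tmp,
                               "demo", "out_NumPyro")
```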

.. py:function:: laplace_covariance(logdensity, position, *, eigenvalue_floor=1e-08)

   Posterior covariance from the inverse Hessian of ``-logdensity`` at ``position``.

   This is the Laplace approximation underpinning Rodgers optimal estimation: near
   the MAP the posterior is approximately Gaussian with covariance equal to the
   inverse of the curvature (Fisher-information-like) matrix.

   Returns ``(covariance, hessian, is_positive_definite)``. If the raw Hessian is
   not positive definite (for example at a poorly converged optimum or a saddle
   point), its eigenvalues are floored at ``eigenvalue_floor`` before inversion
   and ``is_positive_definite`` is ``False``.

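The Laplace step can be sketched in NumPy as follows. The real function differentiates ``logdensity`` with JAX; this hypothetical stand-in substitutes a central finite-difference Hessian, then applies the eigenvalue-floor rule described above before inverting.

```python
import numpy as np


def finite_difference_hessian(f, x, h=1e-3):
    """Central-difference Hessian of a scalar function ``f`` at ``x`` (symmetrised)."""
    x = np.asarray(x, dtype=float)
    n = x.size
    steps = np.eye(n) * h
    hess = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            hess[i, j] = (f(x + steps[i] + steps[j]) - f(x + steps[i] - steps[j])
                          - f(x - steps[i] + steps[j])
                          + f(x - steps[i] - steps[j])) / (4 * h * h)
    return 0.5 * (hess + hess.T)


def laplace_covariance_sketch(logdensity, position, *, eigenvalue_floor=1e-8):
    """Return (covariance, hessian, is_positive_definite) for -logdensity at position."""
    hessian = finite_difference_hessian(lambda z: -logdensity(z), position)
    eigvals, eigvecs = np.linalg.eigh(hessian)
    is_positive_definite = bool(np.all(eigvals > 0))
    # Only a non-positive-definite Hessian has its eigenvalues floored.
    if is_positive_definite:
        floored = eigvals
    else:
        floored = np.maximum(eigvals, eigenvalue_floor)
    covariance = (eigvecs / floored) @ eigvecs.T    # V diag(1/lambda) V^T
    return covariance, hessian, is_positive_definite


# Usage: for a Gaussian log-posterior the Laplace covariance is exact.
mu = np.array([0.5, -1.0])
true_cov = np.array([[2.0, 0.3], [0.3, 0.5]])
precision = np.linalg.inv(true_cov)
cov, hess, ok = laplace_covariance_sketch(
    lambda x: -0.5 * (x - mu) @ precision @ (x - mu), mu)
```

At a saddle, the floored direction receives a very large variance (``1 / eigenvalue_floor``), which is a visible warning sign in the resulting error bars.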

.. py:function:: _multivariate_normal_logpdf(samples, mean, covariance)

   Gaussian log-density of each row of ``samples`` (unconstrained space).

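For reference, the quantity computed here is the standard multivariate normal log-density. A NumPy sketch using a Cholesky factor is shown below; ``mvn_logpdf_sketch`` is a hypothetical name, not the module's implementation.

```python
import numpy as np


def mvn_logpdf_sketch(samples, mean, covariance):
    """Log N(x | mean, covariance) for each row x of ``samples``."""
    samples = np.atleast_2d(np.asarray(samples, dtype=float))
    diff = samples - np.asarray(mean, dtype=float)
    chol = np.linalg.cholesky(covariance)                 # covariance = L L^T
    # Solving L z = diff^T gives sum(z**2) = diff^T covariance^{-1} diff per row.
    z = np.linalg.solve(chol, diff.T)
    mahalanobis = np.sum(z ** 2, axis=0)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    n_params = diff.shape[1]
    return -0.5 * (mahalanobis + log_det + n_params * np.log(2.0 * np.pi))


values = mvn_logpdf_sketch([[0.0, 0.0], [1.0, 0.0]], [0.0, 0.0], np.eye(2))
```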

.. py:function:: _sample_gaussian(mean, covariance, n_samples, seed)

   Draw ``n_samples`` from N(mean, covariance); row 0 is the mean (the MAP).

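A minimal sketch of this behaviour is given below. It assumes NumPy's ``Generator.multivariate_normal`` and that the result has exactly ``n_samples`` rows with the first overwritten by the mean; the real row count and RNG are assumptions, and ``sample_gaussian_sketch`` is a hypothetical name.

```python
import numpy as np


def sample_gaussian_sketch(mean, covariance, n_samples, seed):
    """Draw n_samples rows from N(mean, covariance), pinning row 0 to the mean."""
    rng = np.random.default_rng(seed)
    mean = np.asarray(mean, dtype=float)
    draws = rng.multivariate_normal(mean, covariance, size=n_samples)
    draws[0] = mean    # row 0 is the MAP itself, not a random draw
    return draws


mean = np.array([0.5, -1.0])
cov = np.array([[2.0, 0.3], [0.3, 0.5]])
draws = sample_gaussian_sketch(mean, cov, 20000, seed=0)
```

These draws live in unconstrained space; they still have to be mapped to physical parameters before being reported.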

.. py:class:: OptimalEstimationResults

   Results emitted by :class:`OptimalEstimationSampler`.

   ``samples`` (physical/constrained space, shape ``(n_draw, n_params)``) and
   ``log_likelihood`` (shape ``(n_draw,)``) are the fields consumed by
   ``Retrieval._build_samples_from_sampler_results``. For a Laplace run the rows
   are draws from the Gaussian posterior (row 0 is the MAP) and ``log_likelihood``
   is the Gaussian (Laplace) log-density; for a point estimate there is a single
   row at the MAP. The true model log-posterior at the MAP is ``best_logdensity``.


   .. py:attribute:: samples
      :type:  numpy.ndarray


   .. py:attribute:: log_likelihood
      :type:  numpy.ndarray


   .. py:attribute:: best_fit
      :type:  numpy.ndarray


   .. py:attribute:: best_logdensity
      :type:  float


   .. py:attribute:: parameter_names
      :type:  list


   .. py:attribute:: covariance
      :type:  Optional[numpy.ndarray]
      :value: None



   .. py:attribute:: covariance_is_positive_definite
      :type:  Optional[bool]
      :value: None



   .. py:attribute:: unconstrained_samples
      :type:  Optional[numpy.ndarray]
      :value: None



   .. py:attribute:: trace
      :type:  Optional[numpy.ndarray]
      :value: None



   .. py:attribute:: converged
      :type:  bool
      :value: False



   .. py:attribute:: output_file
      :type:  Optional[str]
      :value: None



   .. py:attribute:: metadata_file
      :type:  Optional[str]
      :value: None



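``samples`` is already in physical space, but it can help to see how unconstrained Laplace draws turn into bounded physical values and 1-sigma intervals. The sketch below assumes a hypothetical logistic transform onto ``(lower, upper)``; the actual ``transform_positions`` depends on each parameter's prior. Quoting the median with 16th/84th percentiles is one common convention, not necessarily what petitRADTRANS reports.

```python
import numpy as np


def to_physical(unconstrained, lower, upper):
    """Hypothetical bounded transform: logistic map from the real line to (lower, upper)."""
    return lower + (upper - lower) / (1.0 + np.exp(-unconstrained))


def one_sigma_summary(samples):
    """Median and lower/upper 1-sigma half-widths (16th/84th percentiles) per column."""
    p16, p50, p84 = np.percentile(samples, [16, 50, 84], axis=0)
    return p50, p50 - p16, p84 - p50


# Standard-normal unconstrained draws for one parameter bounded to (0, 10).
rng = np.random.default_rng(0)
physical = to_physical(rng.standard_normal((10000, 1)), 0.0, 10.0)
median, err_minus, err_plus = one_sigma_summary(physical)
```

Because the transform is non-linear, a symmetric Gaussian in unconstrained space generally yields asymmetric physical error bars, which is why separate lower and upper widths are returned.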
.. py:class:: OptimalEstimationSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`petitRADTRANS.retrieval.sampler.Sampler`


   Run an optax MAP optimisation (plus optional Laplace posterior) as a sampler.

   Selectable via ``RetrievalConfig(sampler_type="optimal_estimation")`` (aliases
   ``"optax"``, ``"map"``). With ``compute_laplace=True`` (default) it additionally
   draws a Gaussian posterior from the inverse Hessian at the optimum, so corner
   plots and 1-sigma intervals are available; with ``compute_laplace=False`` it
   returns the single MAP point.


   .. py:attribute:: results
      :value: None



   .. py:attribute:: logdensity
      :value: None



   .. py:attribute:: transform_positions
      :value: None



   .. py:attribute:: parameter_names
      :value: None



   .. py:method:: prepare(context: petitRADTRANS.retrieval.sampler.SamplerContext, **sampler_kwargs)
      :classmethod:


      Build a sampler instance, run-kwargs, and summary parameters.

      Every concrete sampler should override this to assemble its own
      run kwargs from the common ``SamplerContext`` instead of relying on
      the ``Retrieval`` class.

      Returns:
          ``(sampler_instance, run_kwargs, summary_parameters)``



   .. py:method:: run_sampler(initial_position, parameter_names, *, optimizer=None, learning_rate=0.01, clip_norm=1.0, n_steps=300, tol=None, patience=60, n_starts=1, start_jitter=1.0, lbfgs_polish_steps=0, compute_laplace=True, n_posterior_samples=2000, seed=0, verbose=True, **_ignored)


   .. py:method:: _get_output_prefix()


   .. py:method:: _save_results(results: OptimalEstimationResults)


   .. py:method:: pretty_print_results()


   .. py:method:: get_summary(free_parameter_names=None)


