petitRADTRANS.retrieval.optimal_estimation#
Gradient-based optimisation (“optimal estimation”) for petitRADTRANS retrievals.
This module provides an optax-powered maximum-a-posteriori (MAP) optimiser that
operates on the unconstrained-space log-posterior built by
petitRADTRANS.retrieval.sampler._build_unconstrained_mcmc_interface. It is
used in two ways:
- As a stand-alone “sampler” selectable with sampler_type="optimal_estimation" (alias "optax" / "map"): OptimalEstimationSampler runs the optimiser and, by default, draws a Laplace (Rodgers optimal estimation) posterior from the inverse Hessian at the optimum, so it emits a genuine sample distribution and 1-sigma error bars rather than a bare point estimate.
- As an optional warm-up step for the NumPyro and BlackJAX samplers (optimize_init=True): the same optimiser finds a MAP point that replaces the prior-median initial_position, so the gradient sampler starts inside the high-probability region instead of in a stiff, divergence-prone tail.
The optimiser defaults to Adam with global-norm gradient clipping, but any optax optimiser can be supplied.
Attributes#
| Name | Summary |
|---|---|
| jax | Module-level handle, None until populated by _require_jax_local(). |
| jnp | Module-level handle, None until populated by _require_jax_local(). |
| optax | Module-level handle, None until populated by _require_optax(). |
Classes#
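| Name | Summary |
|---|---|
| OptimizeResult | Outcome of optimize_unconstrained_position(). |
| OptimalEstimationResults | Results emitted by OptimalEstimationSampler. |
| OptimalEstimationSampler | Run an optax MAP optimisation (plus optional Laplace posterior) as a sampler. |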
Functions#
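| Name | Summary |
|---|---|
| _require_jax_local() | Populate the module-level jax/jnp handles. |
| _require_optax() | Populate the module-level optax handle. |
| build_optimizer() | Return an optax.GradientTransformation. |
| optimize_unconstrained_position() | Maximise an unconstrained-space log-posterior with optax. |
| _lbfgs_polish() | Refine position with optax.lbfgs (zoom line-search), minimising -logdensity. |
| make_walker_inits() | Broadcast a single MAP point into per-chain initial positions. |
| report_and_save_map_warmup() | Print and persist the MAP point found by an optimize_init warm-up. |
| laplace_covariance() | Posterior covariance from the inverse Hessian of -logdensity at position. |
| _multivariate_normal_logpdf() | Gaussian log-density of each row of samples (unconstrained space). |
| _sample_gaussian() | Draw n_samples from N(mean, covariance); row 0 is the mean (the MAP). |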
Module Contents#
- petitRADTRANS.retrieval.optimal_estimation.jax = None#
- petitRADTRANS.retrieval.optimal_estimation.jnp = None#
- petitRADTRANS.retrieval.optimal_estimation.optax = None#
- petitRADTRANS.retrieval.optimal_estimation._require_jax_local()#
Populate the module-level jax/jnp handles.
- petitRADTRANS.retrieval.optimal_estimation._require_optax()#
Populate the module-level optax handle.
- petitRADTRANS.retrieval.optimal_estimation.build_optimizer(optimizer: Any = None, learning_rate: Any = 0.01, clip_norm: float | None = 1.0)#
Return an optax.GradientTransformation.
- Args:
  - optimizer: selects the optax optimiser. May be
    - None -> optax.adam (the default);
    - a str -> getattr(optax, name)(learning_rate), e.g. "adamw", "adabelief", "sgd";
    - an already-built optax.GradientTransformation (used verbatim);
    - a callable factory -> optimizer(learning_rate).
  - learning_rate: scalar or optax schedule, passed to the factory. Ignored when optimizer is a pre-built GradientTransformation.
  - clip_norm: if not None, prepend optax.clip_by_global_norm(clip_norm) so the huge gradients near the prior median produce bounded steps.
- Returns:
  An optax.GradientTransformation.
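The effect of clip_norm can be illustrated without optax. The sketch below is a NumPy stand-in for the rescaling that global-norm clipping performs: if the joint L2 norm of the gradient exceeds the threshold, every component is scaled by the same factor so the norm equals the threshold; otherwise the gradient is left alone. The helper name clip_by_global_norm_sketch and the example gradient are hypothetical and are not part of petitRADTRANS.

```python
import numpy as np


def clip_by_global_norm_sketch(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    # Global norm: L2 norm taken over every component of every array.
    global_norm = np.sqrt(sum(float(np.sum(np.square(g))) for g in grads))
    if global_norm <= max_norm:
        return [np.asarray(g, dtype=float) for g in grads], global_norm
    scale = max_norm / global_norm
    return [np.asarray(g, dtype=float) * scale for g in grads], global_norm


# A very large gradient, of the kind that can occur far from the optimum.
huge = [np.array([3.0e4, -4.0e4])]
clipped, norm_before = clip_by_global_norm_sketch(huge, max_norm=1.0)
print(norm_before, clipped[0])
```

The direction of the gradient is preserved; only its length is bounded, which is why the resulting optimiser step stays finite and well-scaled.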
- class petitRADTRANS.retrieval.optimal_estimation.OptimizeResult#
Outcome of optimize_unconstrained_position().
- best_position: Any#
- best_logdensity: float#
- trace: numpy.ndarray#
- n_steps_run: int#
- converged: bool#
- n_starts: int = 1#
- start_logdensities: numpy.ndarray | None = None#
- petitRADTRANS.retrieval.optimal_estimation.optimize_unconstrained_position(logdensity: Callable, initial_position: Any, *, optimizer: Any = None, learning_rate: Any = 0.01, clip_norm: float | None = 1.0, n_steps: int = 300, tol: float | None = None, patience: int | None = 60, n_starts: int = 1, start_jitter: float = 1.0, seed: int = 0, lbfgs_polish_steps: int = 0, verbose: bool = True) OptimizeResult#
Maximise an unconstrained-space log-posterior with optax.
The optimisation minimises -logdensity in the unconstrained space (so all parameter bounds are automatically respected via the change-of-variables that logdensity already includes). Non-finite gradient components are zeroed so a single bad coordinate cannot poison a step, and the best finite point seen is returned.
- Args:
  - logdensity: callable mapping an unconstrained position (shape (n_params,)) to a scalar log-posterior. Must be JAX-differentiable.
  - initial_position: starting unconstrained position, shape (n_params,).
  - optimizer / learning_rate / clip_norm: see build_optimizer().
  - n_steps: maximum optimisation steps per start.
  - tol: optional gradient-infinity-norm tolerance for early stopping.
  - patience: stop a start after this many steps without improvement (None disables).
  - n_starts: number of random restarts (start 0 is initial_position; the rest are initial_position plus Gaussian jitter). The best optimum across starts is returned.
  - start_jitter: standard deviation of the Gaussian jitter for extra starts.
  - seed: RNG seed for the restart jitter.
  - lbfgs_polish_steps: if > 0, refine the best optax optimum with this many optax.lbfgs (zoom line-search) iterations.
  - verbose: print progress.
- Returns:
  An OptimizeResult.
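The safeguards described above (zeroing non-finite gradient components, bounding the step, tracking the best finite point, and stopping after patience steps without improvement) can be sketched in plain NumPy. This is a simplified stand-in, not the module's implementation: it uses clipped gradient ascent instead of an optax optimiser, takes an explicit gradient function instead of JAX autodiff, and omits restarts, tol, and the L-BFGS polish. The name maximise_sketch and the toy Gaussian target are hypothetical.

```python
import numpy as np


def maximise_sketch(logdensity, grad_fn, x0, *, learning_rate=0.1,
                    clip_norm=1.0, n_steps=300, patience=60):
    """Clipped gradient ascent that returns the best finite point seen."""
    x = np.asarray(x0, dtype=float)
    best_x, best_lp = x.copy(), -np.inf
    stale = 0       # steps since the last improvement
    steps_run = 0
    for _ in range(n_steps):
        lp = float(logdensity(x))
        if np.isfinite(lp) and lp > best_lp:
            best_x, best_lp, stale = x.copy(), lp, 0
        else:
            stale += 1
            if patience is not None and stale >= patience:
                break
        grad = np.asarray(grad_fn(x), dtype=float)
        # Zero non-finite components so one bad coordinate cannot poison the step.
        grad = np.where(np.isfinite(grad), grad, 0.0)
        norm = np.linalg.norm(grad)
        if clip_norm is not None and norm > clip_norm:
            grad = grad * (clip_norm / norm)
        x = x + learning_rate * grad  # ascent on logdensity == descent on -logdensity
        steps_run += 1
    return best_x, best_lp, steps_run


# Toy log-posterior: an isotropic Gaussian centred on `target`.
target = np.array([1.0, 2.0])
best_x, best_lp, steps_run = maximise_sketch(
    lambda x: -0.5 * np.sum((x - target) ** 2),
    lambda x: -(x - target),
    [5.0, -3.0],
)
print(best_x, best_lp, steps_run)
```

Far from the optimum the clipped step has fixed length learning_rate * clip_norm, so progress is linear; once the gradient norm drops below clip_norm the steps shrink with the gradient and the iterate converges geometrically.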
- petitRADTRANS.retrieval.optimal_estimation._lbfgs_polish(logdensity, position, n_steps, verbose=True)#
Refine position with optax.lbfgs (zoom line-search), minimising -logdensity.
- petitRADTRANS.retrieval.optimal_estimation.make_walker_inits(position, num_chains, *, jitter=0.0, seed=0)#
Broadcast a single MAP point into per-chain initial positions.
Returns an (n_params,) array for a single chain, or (num_chains, n_params) for several chains (the first chain stays exactly at the optimum; the rest get Gaussian jitter of width jitter so multi-chain convergence diagnostics such as R-hat remain meaningful).
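The documented return shapes can be reproduced with a few lines of NumPy. The sketch below follows the behaviour described above (single chain returns the point unchanged, chain 0 stays at the optimum, the others receive Gaussian jitter); the name walker_inits_sketch is hypothetical, and the random-number generator used by the real function is an assumption.

```python
import numpy as np


def walker_inits_sketch(position, num_chains, *, jitter=0.0, seed=0):
    """Broadcast one point into per-chain starting positions."""
    position = np.asarray(position, dtype=float)
    if num_chains == 1:
        return position  # shape (n_params,)
    rng = np.random.default_rng(seed)
    inits = np.tile(position, (num_chains, 1))  # shape (num_chains, n_params)
    # Chain 0 stays exactly at the optimum; the others are perturbed.
    inits[1:] += jitter * rng.standard_normal((num_chains - 1, position.size))
    return inits


map_point = np.array([0.3, -1.2])
single = walker_inits_sketch(map_point, 1)
several = walker_inits_sketch(map_point, 4, jitter=0.05, seed=1)
print(single.shape, several.shape)
```

With jitter=0.0 all chains would start at the same point, which leaves between-chain diagnostics uninformative early in the run; a small non-zero jitter avoids that.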
- petitRADTRANS.retrieval.optimal_estimation.report_and_save_map_warmup(optimize_result: OptimizeResult, transform_positions: Callable, parameter_names, output_directory: str, retrieval_name: str, *, output_subdirectory: str, label: str = 'optimal_estimation', verbose: bool = True)#
Print and persist the MAP point found by an optimize_init warm-up.
Used by the NumPyro and BlackJAX samplers when optimize_init=True: once the optax optimiser has located the MAP, this transforms the best unconstrained position into physical (constrained) parameter space, prints each parameter’s MAP value, and writes the MAP sample to <output_directory>/<output_subdirectory>/<retrieval_name>_map_warmup.{npz,json}.
- Args:
  - optimize_result: the OptimizeResult returned by the warm-up optimisation.
  - transform_positions: maps an unconstrained position to physical space.
  - parameter_names: ordered free-parameter names matching the position.
  - output_directory: retrieval output directory.
  - retrieval_name: retrieval name used as the file-name stem.
  - output_subdirectory: sampler-specific sub-directory (e.g. "out_BlackJAX" or "out_NumPyro").
  - label: short tag used in the printed lines.
  - verbose: print the MAP values and the saved-file path.
- Returns:
  The path to the saved .npz file.
- petitRADTRANS.retrieval.optimal_estimation.laplace_covariance(logdensity, position, *, eigenvalue_floor=1e-08)#
Posterior covariance from the inverse Hessian of -logdensity at position.
This is the Laplace approximation underpinning Rodgers optimal estimation: near the MAP the posterior is approximately Gaussian with covariance equal to the inverse of the curvature (Fisher-information-like) matrix.
Returns (covariance, hessian, is_positive_definite). If the raw Hessian is not positive definite (a poor or saddle-like optimum), its eigenvalues are floored at eigenvalue_floor before inversion and is_positive_definite is False.
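The eigenvalue flooring can be sketched in NumPy, starting from a Hessian that has already been computed. The actual function differentiates logdensity itself; the sketch only covers the inversion step, and both the name covariance_from_hessian_sketch and the exact positive-definiteness test (all eigenvalues strictly positive) are assumptions.

```python
import numpy as np


def covariance_from_hessian_sketch(hessian, eigenvalue_floor=1e-8):
    """Invert a symmetric Hessian of -logdensity, flooring small eigenvalues."""
    hessian = np.asarray(hessian, dtype=float)
    hessian = 0.5 * (hessian + hessian.T)  # remove numerical asymmetry
    eigenvalues, eigenvectors = np.linalg.eigh(hessian)
    is_positive_definite = bool(np.all(eigenvalues > 0.0))
    floored = np.maximum(eigenvalues, eigenvalue_floor)
    # V diag(1 / w) V^T: dividing the columns of V by w scales each eigenvector.
    covariance = (eigenvectors / floored) @ eigenvectors.T
    return covariance, hessian, is_positive_definite


# Well-behaved optimum: curvature is positive in every direction.
good_hessian = np.array([[4.0, 1.0], [1.0, 3.0]])
cov, _, ok = covariance_from_hessian_sketch(good_hessian)
one_sigma = np.sqrt(np.diag(cov))  # marginal 1-sigma errors

# Saddle-like point: one negative eigenvalue is floored before inversion.
saddle_hessian = np.array([[2.0, 0.0], [0.0, -1.0]])
cov_bad, _, ok_bad = covariance_from_hessian_sketch(saddle_hessian)
print(ok, one_sigma, ok_bad)
```

A floored eigenvalue turns into a very large variance (1 / eigenvalue_floor) along the offending direction, so a False flag together with huge error bars signals that the optimum should not be trusted.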
- petitRADTRANS.retrieval.optimal_estimation._multivariate_normal_logpdf(samples, mean, covariance)#
Gaussian log-density of each row of samples (unconstrained space).
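For reference, the row-wise Gaussian log-density is log N(x | m, C) = -0.5 * [d log(2π) + log det C + (x - m)^T C^{-1} (x - m)]. A minimal NumPy sketch via a Cholesky factor is shown below; the name mvn_logpdf_sketch is hypothetical, and whether the module uses a Cholesky factorisation internally is an assumption.

```python
import numpy as np


def mvn_logpdf_sketch(samples, mean, covariance):
    """Log-density of N(mean, covariance) evaluated at each row of samples."""
    samples = np.atleast_2d(np.asarray(samples, dtype=float))
    mean = np.asarray(mean, dtype=float)
    chol = np.linalg.cholesky(np.asarray(covariance, dtype=float))
    # Whitened residuals z = L^{-1} (x - m), shape (n_params, n_samples).
    z = np.linalg.solve(chol, (samples - mean).T)
    mahalanobis = np.sum(z ** 2, axis=0)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (mean.size * np.log(2.0 * np.pi) + log_det + mahalanobis)


points = np.array([[0.0, 0.0], [1.0, 0.0]])
values = mvn_logpdf_sketch(points, np.zeros(2), np.eye(2))
print(values)
```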
- petitRADTRANS.retrieval.optimal_estimation._sample_gaussian(mean, covariance, n_samples, seed)#
Draw n_samples from N(mean, covariance); row 0 is the mean (the MAP).
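A NumPy sketch of such a draw, with row 0 overwritten by the mean so the MAP is always present in the sample set, might look as follows. The name sample_gaussian_sketch, the Cholesky-based construction, and the choice of random-number generator are assumptions.

```python
import numpy as np


def sample_gaussian_sketch(mean, covariance, n_samples, seed):
    """Draw n_samples rows from N(mean, covariance); row 0 is the mean itself."""
    mean = np.asarray(mean, dtype=float)
    rng = np.random.default_rng(seed)
    chol = np.linalg.cholesky(np.asarray(covariance, dtype=float))
    # x = m + L z with z ~ N(0, I) has covariance L L^T = covariance.
    draws = mean + rng.standard_normal((n_samples, mean.size)) @ chol.T
    draws[0] = mean
    return draws


mean = np.array([1.0, -2.0])
covariance = np.array([[0.25, 0.10], [0.10, 0.50]])
draws = sample_gaussian_sketch(mean, covariance, n_samples=2000, seed=0)
print(draws.shape, np.cov(draws, rowvar=False))
```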
- class petitRADTRANS.retrieval.optimal_estimation.OptimalEstimationResults#
Results emitted by OptimalEstimationSampler.
samples (physical/constrained space, shape (n_draw, n_params)) and log_likelihood (shape (n_draw,)) are the fields consumed by Retrieval._build_samples_from_sampler_results. For a Laplace run the rows are draws from the Gaussian posterior (row 0 is the MAP) and log_likelihood is the Gaussian (Laplace) log-density; for a point estimate there is a single row at the MAP. The true model log-posterior at the MAP is best_logdensity.
- samples: numpy.ndarray#
- log_likelihood: numpy.ndarray#
- best_fit: numpy.ndarray#
- best_logdensity: float#
- parameter_names: list#
- covariance: numpy.ndarray | None = None#
- covariance_is_positive_definite: bool | None = None#
- unconstrained_samples: numpy.ndarray | None = None#
- trace: numpy.ndarray | None = None#
- converged: bool = False#
- output_file: str | None = None#
- metadata_file: str | None = None#
- class petitRADTRANS.retrieval.optimal_estimation.OptimalEstimationSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)#
Bases: petitRADTRANS.retrieval.sampler.Sampler
Run an optax MAP optimisation (plus optional Laplace posterior) as a sampler.
Selectable via RetrievalConfig(sampler_type="optimal_estimation") (aliases "optax", "map"). With compute_laplace=True (default) it additionally draws a Gaussian posterior from the inverse Hessian at the optimum, so corner plots and 1-sigma intervals are available; with compute_laplace=False it returns the single MAP point.
- results = None#
- logdensity = None#
- transform_positions = None#
- parameter_names = None#
- classmethod prepare(context: petitRADTRANS.retrieval.sampler.SamplerContext, **sampler_kwargs)#
Build a sampler instance, run-kwargs, and summary parameters.
Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.
- Returns:
  (sampler_instance, run_kwargs, summary_parameters)
- run_sampler(initial_position, parameter_names, *, optimizer=None, learning_rate=0.01, clip_norm=1.0, n_steps=300, tol=None, patience=60, n_starts=1, start_jitter=1.0, lbfgs_polish_steps=0, compute_laplace=True, n_posterior_samples=2000, seed=0, verbose=True, **_ignored)#
- _get_output_prefix()#
- _save_results(results: OptimalEstimationResults)#
- pretty_print_results()#
- get_summary(free_parameter_names=None)#
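As an illustration of how 1-sigma intervals follow from a Laplace run, the sketch below summarises a samples array of shape (n_draw, n_params) with the median and the central 68.27% interval. The array here is synthetic, the parameter names are placeholders, and summarise_sketch is a hypothetical helper; it is not what get_summary() does internally.

```python
import numpy as np


def summarise_sketch(samples, parameter_names):
    """Median and lower/upper 1-sigma offsets for each column of samples."""
    lower, median, upper = np.percentile(samples, [15.865, 50.0, 84.135], axis=0)
    return {
        name: (median[i], median[i] - lower[i], upper[i] - median[i])
        for i, name in enumerate(parameter_names)
    }


# Synthetic stand-in for a (n_draw, n_params) posterior sample array.
rng = np.random.default_rng(0)
parameter_names = ["temperature", "log_g"]  # placeholder names
samples = rng.normal(loc=[1200.0, 3.5], scale=[50.0, 0.1], size=(20000, 2))

summary = summarise_sketch(samples, parameter_names)
for name, (median, minus, plus) in summary.items():
    print(f"{name}: {median:.3f} -{minus:.3f} +{plus:.3f}")
```

Because a Laplace posterior is Gaussian in the unconstrained space, the intervals in physical space can be asymmetric after the transform, which is why percentiles are used here instead of a single standard deviation.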