petitRADTRANS.retrieval.optimal_estimation

Gradient-based optimisation (“optimal estimation”) for petitRADTRANS retrievals.

This module provides an optax-powered maximum-a-posteriori (MAP) optimiser that operates on the unconstrained-space log-posterior built by petitRADTRANS.retrieval.sampler._build_unconstrained_mcmc_interface. It is used in two ways:

  • As a stand-alone “sampler” selectable with sampler_type="optimal_estimation" (alias "optax" / "map"): OptimalEstimationSampler runs the optimiser and, by default, draws a Laplace (Rodgers optimal estimation) posterior from the inverse Hessian at the optimum so it emits a genuine sample distribution and 1-sigma error bars rather than a bare point estimate.

  • As an optional warm-up step for the NumPyro and BlackJAX samplers (optimize_init=True): the same optimiser finds a MAP point that replaces the prior-median initial_position, so the gradient sampler starts inside the high-probability region instead of in a stiff, divergence-prone tail.

The optimiser defaults to Adam with global-norm gradient clipping, but any optax optimiser can be supplied.

Attributes

Classes

OptimizeResult

Outcome of optimize_unconstrained_position().

OptimalEstimationResults

Results emitted by OptimalEstimationSampler.

OptimalEstimationSampler

Run an optax MAP optimisation (plus optional Laplace posterior) as a sampler.

Functions

_require_jax_local()

Populate the module-level jax / jnp handles.

_require_optax()

Populate the module-level optax handle.

build_optimizer([optimizer, learning_rate, clip_norm])

Return an optax.GradientTransformation.

optimize_unconstrained_position(→ OptimizeResult)

Maximise an unconstrained-space log-posterior with optax.

_lbfgs_polish(logdensity, position, n_steps[, verbose])

Refine position with optax.lbfgs (zoom line-search) minimising -logdensity.

make_walker_inits(position, num_chains, *[, jitter, seed])

Broadcast a single MAP point into per-chain initial positions.

report_and_save_map_warmup(optimize_result, ...[, ...])

Print and persist the MAP point found by an optimize_init warm-up.

laplace_covariance(logdensity, position, *[, ...])

Posterior covariance from the inverse Hessian of -logdensity at position.

_multivariate_normal_logpdf(samples, mean, covariance)

Gaussian log-density of each row of samples (unconstrained space).

_sample_gaussian(mean, covariance, n_samples, seed)

Draw n_samples from N(mean, covariance); row 0 is the mean (the MAP).

Module Contents

petitRADTRANS.retrieval.optimal_estimation.jax = None
petitRADTRANS.retrieval.optimal_estimation.jnp = None
petitRADTRANS.retrieval.optimal_estimation.optax = None
petitRADTRANS.retrieval.optimal_estimation._require_jax_local()

Populate the module-level jax / jnp handles.

petitRADTRANS.retrieval.optimal_estimation._require_optax()

Populate the module-level optax handle.

petitRADTRANS.retrieval.optimal_estimation.build_optimizer(optimizer: Any = None, learning_rate: Any = 0.01, clip_norm: float | None = 1.0)

Return an optax.GradientTransformation.

Args:
optimizer: selects the optax optimiser. May be:

  • None -> optax.adam (the default);

  • a str -> getattr(optax, name)(learning_rate), e.g. "adamw", "adabelief", "sgd";

  • an already-built optax.GradientTransformation (used verbatim);

  • a callable factory -> optimizer(learning_rate).

learning_rate: scalar or optax schedule, passed to the factory. Ignored when optimizer is a pre-built GradientTransformation.

clip_norm: if not None, prepend optax.clip_by_global_norm(clip_norm) so the huge gradients near the prior median produce bounded steps.

Returns:

An optax.GradientTransformation.

class petitRADTRANS.retrieval.optimal_estimation.OptimizeResult

Outcome of optimize_unconstrained_position().

best_position: Any
best_logdensity: float
trace: numpy.ndarray
n_steps_run: int
converged: bool
n_starts: int = 1
start_logdensities: numpy.ndarray | None = None
petitRADTRANS.retrieval.optimal_estimation.optimize_unconstrained_position(logdensity: Callable, initial_position: Any, *, optimizer: Any = None, learning_rate: Any = 0.01, clip_norm: float | None = 1.0, n_steps: int = 300, tol: float | None = None, patience: int | None = 60, n_starts: int = 1, start_jitter: float = 1.0, seed: int = 0, lbfgs_polish_steps: int = 0, verbose: bool = True) → OptimizeResult

Maximise an unconstrained-space log-posterior with optax.

The optimisation minimises -logdensity in the unconstrained space (so all parameter bounds are automatically respected via the change-of-variables that logdensity already includes). Non-finite gradient components are zeroed so a single bad coordinate cannot poison a step, and the best finite point seen is returned.

Args:
logdensity: callable mapping an unconstrained position (shape (n_params,)) to a scalar log-posterior. Must be JAX-differentiable.

initial_position: starting unconstrained position, shape (n_params,).

optimizer / learning_rate / clip_norm: see build_optimizer().

n_steps: maximum optimisation steps per start.

tol: optional gradient-infinity-norm tolerance for early stopping.

patience: stop a start after this many steps without improvement (None disables).

n_starts: number of random restarts (start 0 is initial_position; the rest are initial_position plus Gaussian jitter). The best optimum across starts is returned.

start_jitter: standard deviation of the Gaussian jitter for extra starts.

seed: RNG seed for the restart jitter.

lbfgs_polish_steps: if > 0, refine the best optax optimum with this many optax.lbfgs (zoom line-search) iterations.

verbose: print progress.

Returns:

An OptimizeResult.
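The safeguards described above (zeroing non-finite gradient components, returning the best finite point seen, and stopping after patience steps without improvement) can be illustrated with a small stand-alone loop. This sketch swaps in plain clipped gradient ascent in NumPy with a hand-written gradient instead of optax and JAX autodiff; sketch_map_search and the toy posterior are hypothetical and do not reproduce the module's code.

```python
import numpy as np


def sketch_map_search(logdensity, gradient, initial_position, *,
                      learning_rate=0.05, clip_norm=1.0,
                      n_steps=300, patience=60):
    """Clipped gradient ascent that tracks the best finite point seen."""
    position = np.asarray(initial_position, dtype=float).copy()
    best_position, best_value = position.copy(), -np.inf
    trace = []
    steps_since_improvement = 0

    for _ in range(n_steps):
        value = float(logdensity(position))
        trace.append(value)
        if np.isfinite(value) and value > best_value:
            best_position, best_value = position.copy(), value
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1
        if patience is not None and steps_since_improvement >= patience:
            break

        grad = np.asarray(gradient(position), dtype=float)
        grad = np.where(np.isfinite(grad), grad, 0.0)   # drop bad coordinates
        norm = np.linalg.norm(grad)
        if clip_norm is not None and norm > clip_norm:
            grad = grad * (clip_norm / norm)            # global-norm clipping
        position = position + learning_rate * grad      # ascent on logdensity

    return best_position, best_value, np.asarray(trace)


# Toy Gaussian log-posterior with its maximum at `target`.
target = np.array([1.0, -2.0])
calls = {"n": 0}


def toy_logdensity(x):
    return -0.5 * np.sum((x - target) ** 2)


def toy_gradient(x):
    calls["n"] += 1
    g = target - x
    if calls["n"] == 1:
        g[0] = np.nan   # simulate one bad coordinate on the first step
    return g


best_position, best_value, trace = sketch_map_search(
    toy_logdensity, toy_gradient, np.array([4.0, 3.0]))
```

The injected NaN only removes one coordinate from the first step; the remaining coordinate still moves, so the search recovers on the following iterations.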

petitRADTRANS.retrieval.optimal_estimation._lbfgs_polish(logdensity, position, n_steps, verbose=True)

Refine position with optax.lbfgs (zoom line-search) minimising -logdensity.

petitRADTRANS.retrieval.optimal_estimation.make_walker_inits(position, num_chains, *, jitter=0.0, seed=0)

Broadcast a single MAP point into per-chain initial positions.

Returns a (n_params,) array for a single chain, or (num_chains, n_params) for several chains (the first chain stays exactly at the optimum; the rest get Gaussian jitter of width jitter so multi-chain convergence diagnostics such as R-hat remain meaningful).
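The shape convention described above can be sketched in NumPy. sketch_walker_inits is a hypothetical stand-in that assumes independent zero-mean Gaussian jitter per coordinate; the module itself may generate the jitter differently (for example with JAX random keys).

```python
import numpy as np


def sketch_walker_inits(position, num_chains, *, jitter=0.0, seed=0):
    """Broadcast one point into per-chain starts; chain 0 stays at the optimum."""
    position = np.asarray(position, dtype=float)
    if num_chains == 1:
        return position.copy()                       # shape (n_params,)
    rng = np.random.default_rng(seed)
    inits = np.tile(position, (num_chains, 1))       # shape (num_chains, n_params)
    inits[1:] += jitter * rng.standard_normal((num_chains - 1, position.size))
    return inits


map_point = np.array([0.3, -1.2, 2.0])
single = sketch_walker_inits(map_point, 1)
several = sketch_walker_inits(map_point, 4, jitter=0.1, seed=1)
```

Keeping a non-zero jitter for the extra chains matters because identical starting points would make between-chain variance diagnostics such as R-hat uninformative at the beginning of sampling.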

petitRADTRANS.retrieval.optimal_estimation.report_and_save_map_warmup(optimize_result: OptimizeResult, transform_positions: Callable, parameter_names, output_directory: str, retrieval_name: str, *, output_subdirectory: str, label: str = 'optimal_estimation', verbose: bool = True)

Print and persist the MAP point found by an optimize_init warm-up.

Used by the NumPyro and BlackJAX samplers when optimize_init=True: once the optax optimiser has located the MAP, this transforms the best unconstrained position into physical (constrained) parameter space, prints each parameter’s MAP value, and writes the MAP sample to <output_directory>/<output_subdirectory>/<retrieval_name>_map_warmup.{npz,json}.

Args:
optimize_result: the OptimizeResult returned by optimize_unconstrained_position().

transform_positions: maps an unconstrained position to physical space.

parameter_names: ordered free-parameter names matching the position.

output_directory: retrieval output directory.

retrieval_name: retrieval name used as the file-name stem.

output_subdirectory: sampler-specific sub-directory (e.g. "out_BlackJAX" or "out_NumPyro").

label: short tag used in the printed lines.

verbose: print the MAP values and the saved-file path.

Returns:

The path to the saved .npz file.

petitRADTRANS.retrieval.optimal_estimation.laplace_covariance(logdensity, position, *, eigenvalue_floor=1e-08)

Posterior covariance from the inverse Hessian of -logdensity at position.

This is the Laplace approximation underpinning Rodgers optimal estimation: near the MAP the posterior is approximately Gaussian with covariance equal to the inverse of the curvature (Fisher-information-like) matrix.

Returns (covariance, hessian, is_positive_definite). If the raw Hessian is not positive definite (a poor or saddle-like optimum), its eigenvalues are floored at eigenvalue_floor before inversion and is_positive_definite is False.
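The eigenvalue flooring described above can be sketched for a Hessian that is already available as an array. sketch_covariance_from_hessian is a hypothetical helper; it assumes that "positive definite" means all eigenvalues are strictly positive, and it leaves out how the module obtains the Hessian from logdensity.

```python
import numpy as np


def sketch_covariance_from_hessian(hessian, eigenvalue_floor=1e-8):
    """Invert a Hessian of -logdensity, flooring its eigenvalues first."""
    hessian = np.asarray(hessian, dtype=float)
    hessian = 0.5 * (hessian + hessian.T)             # remove round-off asymmetry
    eigenvalues, eigenvectors = np.linalg.eigh(hessian)
    is_positive_definite = bool(np.all(eigenvalues > 0.0))
    floored = np.maximum(eigenvalues, eigenvalue_floor)
    # V diag(1 / lambda) V^T, the inverse of the floored Hessian.
    covariance = (eigenvectors / floored) @ eigenvectors.T
    return covariance, hessian, is_positive_definite


# Independent Gaussian with standard deviations 0.5 and 2.0:
# the Hessian of -logdensity is diag(1 / sigma**2).
good_cov, _, good_pd = sketch_covariance_from_hessian(np.diag([4.0, 0.25]))

# Saddle-like point: one negative curvature direction gets floored.
saddle_cov, _, saddle_pd = sketch_covariance_from_hessian(np.diag([1.0, -1.0]))
```

A floored direction ends up with a variance of 1 / eigenvalue_floor, so an is_positive_definite value of False signals that the reported error bars along that direction are not trustworthy.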

petitRADTRANS.retrieval.optimal_estimation._multivariate_normal_logpdf(samples, mean, covariance)

Gaussian log-density of each row of samples (unconstrained space).
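A row-wise Gaussian log-density of this kind can be sketched with a Cholesky factorisation. sketch_mvn_logpdf is a hypothetical NumPy version that assumes a positive-definite covariance; the private helper's actual numerical approach is not documented here.

```python
import numpy as np


def sketch_mvn_logpdf(samples, mean, covariance):
    """Log-density of N(mean, covariance) for each row of samples."""
    samples = np.atleast_2d(np.asarray(samples, dtype=float))
    mean = np.asarray(mean, dtype=float)
    chol = np.linalg.cholesky(np.asarray(covariance, dtype=float))
    # Solve L z = (x - mean)^T so that sum(z**2) is the squared Mahalanobis distance.
    z = np.linalg.solve(chol, (samples - mean).T)
    mahalanobis = np.sum(z ** 2, axis=0)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    n_params = mean.size
    return -0.5 * (n_params * np.log(2.0 * np.pi) + log_det + mahalanobis)


points = np.array([[0.0, 0.0], [1.0, 0.0]])
logp_unit = sketch_mvn_logpdf(points, np.zeros(2), np.eye(2))
logp_wide = sketch_mvn_logpdf(points[:1], np.zeros(2), np.diag([4.0, 1.0]))
```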

petitRADTRANS.retrieval.optimal_estimation._sample_gaussian(mean, covariance, n_samples, seed)

Draw n_samples from N(mean, covariance); row 0 is the mean (the MAP).
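The "row 0 is the mean" convention can be sketched in NumPy. sketch_sample_gaussian is hypothetical, and overwriting the first random draw with the mean is an assumption; the helper could equally prepend the mean to n_samples - 1 draws.

```python
import numpy as np


def sketch_sample_gaussian(mean, covariance, n_samples, seed):
    """Draw Gaussian samples and pin row 0 to the mean (the MAP)."""
    mean = np.asarray(mean, dtype=float)
    rng = np.random.default_rng(seed)
    samples = rng.multivariate_normal(mean, covariance, size=n_samples)
    samples[0] = mean        # the best-fit point is always present in the output
    return samples


map_point = np.array([0.5, -1.0])
covariance = np.array([[0.04, 0.01],
                       [0.01, 0.09]])
draws = sketch_sample_gaussian(map_point, covariance, 2000, seed=0)
```

Pinning the MAP to row 0 means downstream code can read the best-fit point from the sample array without a separate lookup.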

class petitRADTRANS.retrieval.optimal_estimation.OptimalEstimationResults

Results emitted by OptimalEstimationSampler.

samples (physical/constrained space, shape (n_draw, n_params)) and log_likelihood (shape (n_draw,)) are the fields consumed by Retrieval._build_samples_from_sampler_results. For a Laplace run the rows are draws from the Gaussian posterior (row 0 is the MAP) and log_likelihood is the Gaussian (Laplace) log-density; for a point estimate there is a single row at the MAP. The true model log-posterior at the MAP is best_logdensity.

samples: numpy.ndarray
log_likelihood: numpy.ndarray
best_fit: numpy.ndarray
best_logdensity: float
parameter_names: list
covariance: numpy.ndarray | None = None
covariance_is_positive_definite: bool | None = None
unconstrained_samples: numpy.ndarray | None = None
trace: numpy.ndarray | None = None
converged: bool = False
output_file: str | None = None
metadata_file: str | None = None
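The relationship between unconstrained_samples, samples, and 1-sigma intervals can be illustrated with a toy example. Everything below is hypothetical: the logistic map stands in for the real transform_positions, and the parameter names, bounds, and Gaussian widths are invented for the illustration.

```python
import numpy as np

parameter_names = ["temperature", "log_abundance"]   # hypothetical names
lower = np.array([500.0, -6.0])                       # hypothetical prior bounds
upper = np.array([2500.0, 0.0])


def toy_transform(unconstrained):
    """Stand-in for transform_positions: logistic map onto [lower, upper]."""
    return lower + (upper - lower) / (1.0 + np.exp(-unconstrained))


# Gaussian (Laplace-style) draws in unconstrained space; row 0 is the MAP.
rng = np.random.default_rng(0)
map_unconstrained = np.array([0.0, 1.0])
unconstrained_samples = rng.normal(loc=map_unconstrained,
                                   scale=[0.3, 0.1],
                                   size=(2000, 2))
unconstrained_samples[0] = map_unconstrained

# Physical-space samples, shape (n_draw, n_params).
samples = toy_transform(unconstrained_samples)

# 16th / 50th / 84th percentiles give a median and a 1-sigma interval.
p16, p50, p84 = np.percentile(samples, [16, 50, 84], axis=0)
for name, low, mid, high in zip(parameter_names, p16, p50, p84):
    print(f"{name}: {mid:.3f} (+{high - mid:.3f} / -{mid - low:.3f})")
```

Because the transform is non-linear, a symmetric Gaussian in unconstrained space generally yields asymmetric intervals in physical space, which is why the intervals are taken from the transformed draws rather than from the covariance directly.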
class petitRADTRANS.retrieval.optimal_estimation.OptimalEstimationSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: petitRADTRANS.retrieval.sampler.Sampler

Run an optax MAP optimisation (plus optional Laplace posterior) as a sampler.

Selectable via RetrievalConfig(sampler_type="optimal_estimation") (aliases "optax", "map"). With compute_laplace=True (default) it additionally draws a Gaussian posterior from the inverse Hessian at the optimum, so corner plots and 1-sigma intervals are available; with compute_laplace=False it returns the single MAP point.

results = None
logdensity = None
transform_positions = None
parameter_names = None
classmethod prepare(context: petitRADTRANS.retrieval.sampler.SamplerContext, **sampler_kwargs)

Build a sampler instance, run-kwargs, and summary parameters.

Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.

Returns:

(sampler_instance, run_kwargs, summary_parameters)

run_sampler(initial_position, parameter_names, *, optimizer=None, learning_rate=0.01, clip_norm=1.0, n_steps=300, tol=None, patience=60, n_starts=1, start_jitter=1.0, lbfgs_polish_steps=0, compute_laplace=True, n_posterior_samples=2000, seed=0, verbose=True, **_ignored)
_get_output_prefix()
_save_results(results: OptimalEstimationResults)
pretty_print_results()
get_summary(free_parameter_names=None)