Retrieval Samplers#
This page summarizes the sampler backends available in the pRT retrieval package, the most important keyword arguments for each backend, and the interface you need to implement if you want to add your own sampler.
Sampler choice in pRT is controlled through RetrievalConfig.sampler_type.
Sampler-specific keyword arguments are passed to petitRADTRANS.retrieval.retrieval.Retrieval.run().
For example, to run a Dynesty retrieval you would typically write:
rc = RetrievalConfig(
    retrieval_name="my_run",
    run_mode="retrieval",
    pressures=pressures,
    sampler_type="dynesty",
)
retrieval = Retrieval(rc, model_generating_function=my_model)
retrieval.run(nlive=800, dlogz_init=0.05, use_jit=True)
All built-in samplers follow the same high-level flow:
Retrieval builds a petitRADTRANS.retrieval.sampler.SamplerContext.
The sampler class turns that context into a backend-specific likelihood and prior.
The backend runs and returns posterior samples.
Retrieval saves a generic out_samples/<retrieval_name>_samples.npz archive for downstream plotting and analysis.
If a backend is inherently weighted, it should convert its posterior to equal-weight samples before handing the results back to Retrieval.
The built-in Dynesty and JAXNS integrations already do this.
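The conversion from a weighted posterior to equal-weight samples can be illustrated with systematic resampling. This is only a sketch of the general technique in pure Python; the function name `equal_weight_resample` and its arguments are hypothetical, and the built-in pRT integrations may use a different routine.

```python
import math
import random


def equal_weight_resample(samples, log_weights, n_draws=None, seed=0):
    """Systematic resampling of weighted samples into equal-weight draws (illustrative)."""
    n_draws = len(samples) if n_draws is None else n_draws

    # Normalise the weights stably by subtracting the largest log weight first.
    max_log_weight = max(log_weights)
    weights = [math.exp(lw - max_log_weight) for lw in log_weights]
    total = sum(weights)

    # Cumulative distribution over the weighted samples.
    cumulative = []
    running = 0.0
    for weight in weights:
        running += weight / total
        cumulative.append(running)
    cumulative[-1] = 1.0  # guard against round-off in the last bin

    # One random offset, then evenly spaced positions in [0, 1).
    offset = random.Random(seed).random()
    resampled = []
    index = 0
    for i in range(n_draws):
        position = (offset + i) / n_draws
        while cumulative[index] < position:
            index += 1
        resampled.append(samples[index])
    return resampled


weighted_samples = [[0.0], [1.0], [2.0]]
log_weights = [math.log(0.1), math.log(0.8), math.log(0.1)]
equal_samples = equal_weight_resample(weighted_samples, log_weights, n_draws=1000)
```

Each input sample appears in the output roughly in proportion to its weight, so downstream code can treat every row as one posterior draw.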
Choosing a sampler#
| Backend | Family | Requirements | Typical use |
|---|---|---|---|
| PyMultiNest | Nested sampling | MultiNest / PyMultiNest | Established MPI nested sampler for production nested-sampling workflows. |
| UltraNest | Nested sampling | UltraNest | Pure-Python reactive nested sampling, especially convenient when you want warm starts or step sampling. |
| Dynesty | Dynamic nested sampling | Dynesty | Dynamic nested sampling with pure-Python control flow and optional JIT-compiled likelihood evaluation. |
| JAXNS | Nested sampling | JAXNS, JAX, TFP | Default JAX-native nested sampler for differentiable retrievals. |
| JAXNS sharded static sampler | Nested sampling | JAXNS, JAX, TFP | Advanced static/sharded JAXNS backend for more specialized distributed runs. |
| NumPyro NUTS | Gradient MCMC | NumPyro, JAX | Adaptive NUTS for differentiable posteriors when you want efficient local exploration instead of evidence estimation. |
| NumPyro HMC | Gradient MCMC | NumPyro, JAX | Adaptive HMC when you want tighter control over integration length than NUTS provides. |
| BlackJAX NUTS | Gradient MCMC | BlackJAX, JAX | JAX-native NUTS with explicit warmup control and a lightweight integration. |
| BlackJAX HMC | Gradient MCMC | BlackJAX, JAX | JAX-native HMC for differentiable models with explicit integration-step control. |
| Optimal estimation (optax) | Optimisation (MAP + Laplace) | optax, JAX | Fast maximum-a-posteriori fit with an optional Laplace (Gaussian) posterior. Also the engine behind the optional NUTS/HMC warm start (optimize_init). |
Backend-specific parameters#
The lists below focus on the parameters you are most likely to tune in practice. Each backend may accept additional low-level keyword arguments that are forwarded to the underlying sampler implementation.
PyMultiNest#
Key parameters:
n_live_points: number of live points.
sampling_efficiency: trade-off between posterior estimation and evidence accuracy.
const_efficiency_mode: constant-efficiency mode, usually only for specialized runs.
evidence_tolerance: convergence tolerance on the evidence.
resume: continue an existing run instead of overwriting it.
max_iter: stop after a fixed number of iterations.
multimodal: enable mode finding.
importance_nested_sampling: enable INS mode.
use_MPI: use MPI if available.
likelihood_devices: list of JAX devices to dispatch likelihood evaluation to (see GPU-dispatched likelihood for CPU samplers).
Use PyMultiNest when you already rely on the MultiNest ecosystem or need its MPI-oriented execution model.
UltraNest#
Key parameters:
vectorized: evaluate batches of samples at once when the retrieval is fully differentiable.
resume: reuse an existing UltraNest log directory.
dlogz: evidence convergence target.
frac_remain: stop when the remaining live-point weight drops below this fraction.
min_num_live_points: minimum number of live points.
max_iters / max_ncalls: hard stopping limits.
enable_step_sampling: attach an UltraNest step sampler.
warmstart_max_tau: enable warm starts from similar previous runs.
likelihood_devices: list of JAX devices to dispatch likelihood evaluation to (non-vectorized path only; see GPU-dispatched likelihood for CPU samplers).
UltraNest is often the easiest pure-Python nested sampler to experiment with when you want reactive nested sampling or warm-start support.
Dynesty#
Key parameters:
use_jit: JIT-compile the scalar likelihood when the retrieval runtime is differentiable.
nlive: number of live points used by DynamicNestedSampler.
bound: bounding strategy.
sample: proposal strategy.
walks / slices / facc: proposal-tuning parameters.
periodic / reflective: periodic or reflective dimensions.
pool: pass an existing Dynesty-compatible worker pool object.
pool_njobs: convenience option that asks pRT to build an internal dynesty.pool.Pool with this many worker processes.
queue_size: number of queued proposals. When pRT creates the pool internally, this defaults to the pool size if you do not pass it explicitly.
use_pool: Dynesty per-stage pool switches such as loglikelihood and propose_point. If a pool is active and this is omitted, pRT defaults to using the pool for likelihood evaluation, point proposals, and bound updates, but not for the prior transform.
dlogz_init: initial evidence stopping threshold.
nlive_init / nlive_batch: live-point settings for the initial and follow-up batches.
n_effective: stop once an effective posterior sample target is reached.
resume / checkpoint_file / checkpoint_every: checkpoint and resume controls.
Dynesty is useful when you want dynamic nested sampling but prefer a Python-native backend while still reusing the pRT likelihood and prior wiring.
For parallel CPU runs, pRT can now construct dynesty.pool.Pool internally via pool_njobs and will wrap the retrieval likelihood/prior functions in pool-safe proxy callables before handing them to Dynesty.
For example, a parallel Dynesty run can be configured as:
retrieval.run(
    nlive=800,
    dlogz_init=0.05,
    use_jit=True,
    pool_njobs=8,
    queue_size=8,
    use_pool={
        "prior_transform": False,
        "loglikelihood": True,
        "propose_point": True,
        "update_bound": True,
    },
)
In practice, queue_size should usually match the number of worker processes, and pool and pool_njobs should not be passed together.
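The pool rules described above (mutually exclusive pool and pool_njobs, queue_size defaulting to the pool size, and the default use_pool switches) can be summarised in a small sketch. The helper `resolve_pool_settings` is hypothetical and only mirrors the documented behaviour; it is not the pRT implementation.

```python
def resolve_pool_settings(pool=None, pool_njobs=None, queue_size=None, use_pool=None):
    """Illustrative resolution of Dynesty pool options following the documented defaults."""
    # pool and pool_njobs are alternative ways to request parallelism.
    if pool is not None and pool_njobs is not None:
        raise ValueError("Pass either 'pool' or 'pool_njobs', not both.")

    pool_active = pool is not None or pool_njobs is not None

    # With an internally built pool, queue_size defaults to the number of workers.
    if pool_njobs is not None and queue_size is None:
        queue_size = pool_njobs

    # Default per-stage switches: everything on the pool except the prior transform.
    if pool_active and use_pool is None:
        use_pool = {
            "prior_transform": False,
            "loglikelihood": True,
            "propose_point": True,
            "update_bound": True,
        }

    return {
        "build_internal_pool": pool_njobs is not None,
        "queue_size": queue_size,
        "use_pool": use_pool,
    }


settings = resolve_pool_settings(pool_njobs=8)
```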
JAXNS#
Key parameters:
num_live_points: number of live points.
max_samples: hard cap on the number of nested-sampling evaluations.
s, k, c: JAXNS nested-sampling control parameters.
shell_fraction: shell-fragment control for proposal generation.
gradient_guided: enable gradient-guided proposals.
parameter_estimation: toggle estimation-oriented behavior.
init_efficiency_threshold: early-efficiency threshold.
devices: JAX devices to use.
seed: random seed.
use_jit: JIT the nested-sampler call.
termination_condition, term_cond, or termination_condition_kwargs: pass a mapping or JAXNS TerminationCondition-like object with stopping criteria.
Direct termination-condition fields are also accepted as keyword arguments: ess, evidence_uncert, live_evidence_frac, dlogZ, max_num_likelihood_evaluations, log_L_contour, efficiency_threshold, rtol, atol, and peak_XL_frac.
termination_<field> aliases are accepted for all termination-condition fields. In practice, termination_max_samples is the important one because plain max_samples configures the nested sampler itself.
JAXNS is the default JAX-native nested sampler in pRT. It expects a differentiable retrieval runtime and now saves equal-weight posterior samples for downstream plotting.
If you do not supply any termination-condition settings, pRT leaves
term_cond=None and JAXNS falls back to its internal defaults.
For example, this run keeps the sampler-level max_samples cap at 50000
while also building a termination condition with dlogZ=0.1, ess=512,
and a termination-condition max_samples of 20000:
retrieval.run(
    num_live_points=400,
    max_samples=50000,
    dlogZ=0.1,
    ess=512,
    termination_max_samples=20000,
    use_jit=True,
)
You can also pass the same stopping criteria as a mapping:
retrieval.run(
    num_live_points=400,
    termination_condition_kwargs={
        "dlogZ": 0.1,
        "rtol": 1e-4,
        "peak_XL_frac": 1e-2,
    },
    use_jit=True,
)
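To make the routing of these keyword arguments concrete, the sketch below separates sampler-level arguments from termination-condition fields. The helper `split_termination_kwargs` is hypothetical, and the rule that explicit keyword arguments override entries of the mapping is an assumption made for this example, not documented pRT behaviour.

```python
# Termination-condition fields that are accepted directly as keyword arguments.
DIRECT_FIELDS = {
    "ess", "evidence_uncert", "live_evidence_frac", "dlogZ",
    "max_num_likelihood_evaluations", "log_L_contour",
    "efficiency_threshold", "rtol", "atol", "peak_XL_frac",
}
ALIAS_PREFIX = "termination_"
# Names that carry a whole termination condition and are not field aliases.
PASS_THROUGH = {"termination_condition", "term_cond"}


def split_termination_kwargs(**kwargs):
    """Split run() keyword arguments into sampler kwargs and termination fields."""
    sampler_kwargs = {}
    # Start from the mapping form, if one was given.
    termination = dict(kwargs.pop("termination_condition_kwargs", None) or {})
    for key, value in kwargs.items():
        if key in DIRECT_FIELDS:
            termination[key] = value
        elif key.startswith(ALIAS_PREFIX) and key not in PASS_THROUGH:
            # termination_max_samples -> max_samples of the termination condition
            termination[key[len(ALIAS_PREFIX):]] = value
        else:
            # Plain max_samples stays a sampler-level setting.
            sampler_kwargs[key] = value
    return sampler_kwargs, termination


sampler_kwargs, termination = split_termination_kwargs(
    num_live_points=400,
    max_samples=50000,
    dlogZ=0.1,
    ess=512,
    termination_max_samples=20000,
    use_jit=True,
)
```

With the first example above as input, max_samples=50000 remains a sampler argument while the termination condition receives dlogZ, ess, and its own max_samples of 20000.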
JAXNS sharded static sampler#
Key parameters:
num_live_points: number of live points.
max_samples: hard cap on the total number of samples.
devices: JAX devices to use.
difficult_model: hint for harder likelihood surfaces.
verbose: backend verbosity.
seed: random seed.
use_jit: JIT the sampler call.
This backend accepts the same termination-condition setup as jaxns:
termination_condition / term_cond / termination_condition_kwargs,
direct termination-condition field kwargs, and termination_<field> aliases.
Prefer jaxns unless you specifically need the static sharded execution strategy.
NumPyro NUTS / HMC#
Key parameters shared by both NumPyro samplers:
num_samples: number of post-warmup draws.
num_warmup: warmup length.
initial_step_size: initial integrator step size.
target_acceptance_rate: acceptance target used during adaptation.
is_mass_matrix_diagonal: diagonal or dense mass matrix.
num_chains: number of chains.
thinning: chain thinning factor.
chain_method: NumPyro chain execution mode.
seed: random seed.
use_jit: accepted for API consistency; NumPyro still runs through its JIT-based execution path.
Additional HMC controls:
num_integration_steps or num_steps: leapfrog steps.
trajectory_length: alternative HMC trajectory specification.
NumPyro is a good choice when you want adaptive Hamiltonian MCMC rather than nested sampling and your retrieval is fully differentiable.
BlackJAX NUTS / HMC#
Key parameters shared by both BlackJAX samplers:
num_samples: number of post-warmup draws.
num_warmup: warmup length.
initial_step_size: initial integrator step size.
target_acceptance_rate: adaptation target.
is_mass_matrix_diagonal: diagonal or dense mass matrix.
num_chains: number of chains.
seed: random seed.
use_jit: JIT the per-chain sampling loop.
Additional HMC control:
num_integration_steps: leapfrog steps for HMC.
BlackJAX is useful when you want a thinner wrapper around JAX-native HMC or NUTS with explicit warmup and multi-chain control.
GPU-dispatched likelihood for CPU samplers#
PyMultiNest and UltraNest (non-vectorized) are CPU-bound samplers: the sampler
kernel itself runs on the host and calls the likelihood function as a plain Python
callable. When a GPU or other JAX-supported accelerator is available, you can
offload each likelihood evaluation to that device by passing a list of JAX devices
as likelihood_devices.
import jax

retrieval.run(
    n_live_points=400,
    likelihood_devices=jax.devices("gpu"),  # e.g. [CudaDevice(id=0)]
)
pRT wraps the likelihood in a JIT-compiled call to evaluate_jax, dispatches the
computation to the target device using JAX’s default_device context manager, and
returns a plain Python float to the sampler via jax.device_get. The sampler and
MPI orchestration layer stay on CPU throughout.
Requirements:
The retrieval must use a differentiable (JAX-compatible) model. pRT raises a ValueError at setup time if any model group uses a legacy (non-JAX) forward model.
JAX must be able to see the target device. For GPU runs you typically need JAX_PLATFORMS=cuda,cpu (or gpu,cpu) in the environment so that both backends are initialised.
MPI and multi-GPU mapping:
When running with MPI (e.g. via mpirun), each rank r independently selects
likelihood_devices[r % len(likelihood_devices)]. With two GPUs and four MPI
ranks the default striping is:
rank 0 → GPU 0
rank 1 → GPU 1
rank 2 → GPU 0
rank 3 → GPU 1
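The striping rule is plain modular indexing. The short sketch below reproduces the table above; the function name `device_for_rank` and the string device labels are placeholders for illustration only.

```python
def device_for_rank(rank, likelihood_devices):
    """Pick the device for an MPI rank by round-robin striping."""
    if not likelihood_devices:
        raise ValueError("likelihood_devices must not be empty")
    return likelihood_devices[rank % len(likelihood_devices)]


devices = ["GPU 0", "GPU 1"]  # stand-ins for JAX device objects
mapping = {rank: device_for_rank(rank, devices) for rank in range(4)}
```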
For true device isolation (no shared-memory contention between MPI processes on the
same GPU), set CUDA_VISIBLE_DEVICES per process before launching MPI, or use a
job launcher that handles device assignment automatically.
UltraNest vectorized mode:
When vectorized=True is also passed, likelihood_devices is ignored with a
RuntimeWarning: the vectorized path already dispatches through JAX
(evaluate_jax_batched) and device selection there follows the standard JAX
default-device mechanism.
Optimal estimation (optax)#
The optimal_estimation backend (aliases optax and map) does not sample
the posterior with a Markov chain or nested sampling. Instead it uses optax to find the maximum-a-posteriori (MAP) point by
gradient descent on the (negative) log-posterior, and — by default — builds a
Laplace approximation of the posterior from the curvature at that optimum. This is
the classic optimal estimation idea (Rodgers): near the mode the posterior is
approximately Gaussian with covariance equal to the inverse Hessian of the negative
log-posterior. It is implemented in
petitRADTRANS.retrieval.optimal_estimation.
Like the gradient-MCMC backends, it requires a fully differentiable retrieval runtime and operates in the unconstrained parameter space (so all prior bounds are respected automatically through the change-of-variables that pRT already applies).
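In symbols, the Laplace approximation described above can be written as follows. The notation is introduced here for illustration only: \(\theta\) is the unconstrained parameter vector, \(d\) the data, \(\hat{\theta}\) the MAP point, and \(H\) the Hessian of the negative log-posterior at \(\hat{\theta}\).

```latex
\hat{\theta} = \arg\min_{\theta} \left[ -\log p(\theta \mid d) \right], \qquad
H = \left. \nabla_{\theta} \nabla_{\theta}^{\top} \left[ -\log p(\theta \mid d) \right] \right|_{\theta = \hat{\theta}}, \qquad
p(\theta \mid d) \approx \mathcal{N}\!\left( \hat{\theta},\, H^{-1} \right)
```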
Key parameters:
optimizer: the optax optimiser. Accepts None (defaults to Adam), an optax name string such as "adam", "adamw", or "adabelief", an already-built optax.GradientTransformation, or a callable factory optimizer(learning_rate).
learning_rate: scalar or optax learning-rate schedule.
clip_norm: global-norm gradient-clipping threshold (prepended to the optimiser). This matters in practice: at the prior median a high-resolution petitRADTRANS log-posterior can have gradients of order \(10^{11}\), and clipping keeps the first steps bounded. Set to None to disable.
n_steps: maximum optimisation steps (alias num_steps).
tol: optional gradient infinity-norm tolerance for early stopping.
patience: stop after this many steps without improvement (None disables).
n_starts / start_jitter: number of random restarts and the jitter width for the extra starts (start 0 is always the prior median). The best optimum is kept, which guards against local optima.
lbfgs_polish_steps: if > 0, refine the best optax optimum with this many optax.lbfgs (zoom line-search) iterations. This is recommended for tightening very flat or ill-conditioned directions along which first-order methods converge slowly.
compute_laplace: if True (default), compute the inverse-Hessian covariance and draw a Gaussian posterior from it. If False, the result is the single MAP point.
n_posterior_samples: number of draws from the Laplace Gaussian (the first draw is the MAP itself).
seed: random seed for restarts and the Laplace draws.
Outputs and caveats:
Results are written to out_OptimalEstimation/<retrieval_name>_* (samples, parameters, covariance, and a metadata JSON with the MAP best fit), and the generic out_samples archive is saved as usual so the standard plotting and Retrieval.get_samples() paths work.
For a Laplace run, samples are physical-space draws from the Gaussian posterior and log_likelihood is the Gaussian (Laplace) log-density of each draw, which avoids one forward-model evaluation per draw. The true model log-posterior at the MAP is reported separately as best_logdensity.
This is a local optimiser and a Gaussian posterior approximation: it is fast and ideal for a quick best fit, a sanity check, or an initial guess, but it will not capture multimodal or strongly non-Gaussian posteriors. Use a nested sampler or gradient-MCMC backend for the full posterior.
If the Hessian at the optimum is not positive definite (a poor or saddle-like optimum), its eigenvalues are floored before inversion and a warning is emitted; treat the error bars as approximate in that case.
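The idea of a clipped gradient search followed by a Laplace covariance can be tried on a one-dimensional toy problem. This sketch uses plain gradient descent with finite-difference derivatives in pure Python rather than optax or JAX, and every name in it is hypothetical; it only illustrates the principle, not the pRT implementation.

```python
def neg_log_posterior(x, mean=3.0, sigma=0.5):
    """Toy negative log-posterior: a Gaussian with known mean and width."""
    return 0.5 * ((x - mean) / sigma) ** 2


def gradient(func, x, step=1e-5):
    """Central finite-difference derivative."""
    return (func(x + step) - func(x - step)) / (2.0 * step)


def clip_by_norm(grad, clip_norm):
    """Rescale the gradient so that its magnitude never exceeds clip_norm."""
    if clip_norm is None or abs(grad) <= clip_norm:
        return grad
    return grad * clip_norm / abs(grad)


def find_map(func, start, learning_rate=0.05, clip_norm=1.0, n_steps=2000, tol=1e-8):
    """Gradient descent with clipping and a gradient tolerance for early stopping."""
    x = start
    for _ in range(n_steps):
        grad = gradient(func, x)
        if abs(grad) < tol:
            break
        x -= learning_rate * clip_by_norm(grad, clip_norm)
    return x


def laplace_variance(func, x_map, step=1e-3, floor=1e-12):
    """Inverse curvature at the optimum; the curvature is floored before inversion."""
    hessian = (func(x_map + step) - 2.0 * func(x_map) + func(x_map - step)) / step**2
    return 1.0 / max(hessian, floor)


x_map = find_map(neg_log_posterior, start=0.0)
variance = laplace_variance(neg_log_posterior, x_map)
```

For this toy posterior the optimum lands at the mean and the Laplace variance recovers sigma squared. In several dimensions the floor would be applied to the Hessian eigenvalues instead of a single curvature value.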
A typical run with multi-start, an L-BFGS polish, and a Laplace posterior:
rc = RetrievalConfig(
    retrieval_name="my_run",
    run_mode="retrieval",
    pressures=pressures,
    sampler_type="optimal_estimation",
)
retrieval = Retrieval(rc, model_generating_function=my_model)
retrieval.run(
    optimizer="adam",
    learning_rate=1e-2,
    clip_norm=1.0,
    n_steps=500,
    n_starts=4,
    lbfgs_polish_steps=50,
    compute_laplace=True,
    n_posterior_samples=2000,
)
Warm-starting NUTS and HMC#
Gradient-MCMC samplers are local: they explore from wherever they are initialised. By default pRT initialises the NumPyro and BlackJAX walkers at the prior median (the origin of the unconstrained space). For a poorly-constrained starting model this can sit in a region where the log-posterior is extremely steep, so the very first leapfrog step diverges and the chain never moves — a failure mode that shows up as a near-zero acceptance rate, 100% divergences, and “samples” frozen at the prior median.
To avoid this, the NumPyro and BlackJAX backends accept an optional optax warm-up that
replaces the prior-median initial position with a MAP estimate before sampling starts.
Enable it with optimize_init=True and configure the optimiser through
optimizer_kwargs (forwarded to
petitRADTRANS.retrieval.optimal_estimation.optimize_unconstrained_position(),
so it accepts the same optimizer, learning_rate, clip_norm, n_steps,
patience, n_starts, and lbfgs_polish_steps keys as above). The optimiser
defaults to Adam.
rc = RetrievalConfig(
    retrieval_name="my_run",
    run_mode="retrieval",
    pressures=pressures,
    sampler_type="numpyro_nuts",  # or "blackjaxnuts", "numpyro_hmc", "blackjaxhmc"
)
retrieval = Retrieval(rc, model_generating_function=my_model)
retrieval.run(
    num_warmup=1000,
    num_samples=1500,
    target_acceptance_rate=0.9,
    optimize_init=True,
    optimizer_kwargs={
        "optimizer": "adam",
        "n_steps": 300,
        "learning_rate": 1e-2,
        "clip_norm": 1.0,
        "lbfgs_polish_steps": 50,
    },
)
Notes:
optimize_init is ignored if you pass your own initial_position explicitly.
For multi-chain runs, add "walker_jitter": <sigma> to optimizer_kwargs to spread the chains around the MAP (chain 0 stays exactly at the optimum, the rest get Gaussian jitter of width sigma). This keeps multi-chain convergence diagnostics such as R-hat meaningful.
The warm-up seed defaults to the sampler seed; override it with "seed": <int> inside optimizer_kwargs.
The achieved MAP log-posterior is recorded in the run summary as warmup_best_logdensity so you can confirm the warm-up moved the walkers into the bulk of the posterior.
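The walker_jitter behaviour for multi-chain runs can be pictured with a few lines of pure Python. The helper `jitter_chain_starts` is hypothetical and only mirrors the description above: chain 0 starts exactly at the MAP and the remaining chains receive independent Gaussian offsets.

```python
import random


def jitter_chain_starts(map_position, num_chains, walker_jitter, seed=0):
    """Build one start position per chain around a MAP estimate."""
    rng = random.Random(seed)
    # Chain 0 starts exactly at the optimum.
    starts = [list(map_position)]
    # Every other chain gets Gaussian jitter of width walker_jitter per coordinate.
    for _ in range(num_chains - 1):
        starts.append([value + rng.gauss(0.0, walker_jitter) for value in map_position])
    return starts


map_position = [0.5, -1.2]  # illustrative unconstrained-space MAP
starts = jitter_chain_starts(map_position, num_chains=4, walker_jitter=0.1)
```

Starting the chains at distinct points is what lets a between-chain diagnostic such as R-hat detect chains that have not mixed.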
Defining a custom sampler#
Custom samplers should inherit from petitRADTRANS.retrieval.sampler.Sampler.
For full integration with Retrieval.run(), implement a prepare class method that consumes a
petitRADTRANS.retrieval.sampler.SamplerContext and returns a sampler instance,
the keyword arguments for run_sampler(), and a small dictionary of summary parameters.
The minimal contract is:
prepare(context, **sampler_kwargs) returns (sampler, run_kwargs, summary_parameters).
run_sampler(**run_kwargs) executes the backend and returns results.
get_results() returns the in-memory results object.
Optionally implement post_run(), plot_diagnostics(), load_results(), and get_summary().
To participate cleanly in pRT posterior plotting and Retrieval.get_samples(), your results object should expose:
samples: an equal-weight posterior matrix of shape (n_samples, n_parameters).
log_likelihood: one log-likelihood value per posterior sample.
log_weights: optional log weights. For equal-weight posteriors this is just -log(n_samples).
If your backend is naturally weighted, keep the weighted representation in an auxiliary weighted_samples block and return equal-weight samples to Retrieval for plotting.
Here is a minimal example:
import numpy as np

from petitRADTRANS.retrieval.retrieval import SAMPLER_DISPATCH
from petitRADTRANS.retrieval.sampler import Sampler, SamplerContext


class MySampler(Sampler):
    def __init__(self, log_likelihood_func, prior_func, output_directory, retrieval_name):
        super().__init__(log_likelihood_func, prior_func, output_directory, retrieval_name)
        self.results = None

    @classmethod
    def prepare(cls, context: SamplerContext, **sampler_kwargs):
        # Build the sampler from the callables and paths provided by Retrieval.
        sampler = cls(
            log_likelihood_func=context.log_likelihood_func,
            prior_func=context.prior_ultranest_func,
            output_directory=context.output_directory,
            retrieval_name=context.retrieval_name,
        )
        # Keyword arguments that will be forwarded to run_sampler().
        run_kwargs = {
            "n_draws": sampler_kwargs.pop("n_draws", 512),
            "seed": sampler_kwargs.pop("seed", 42654),
            "parameter_names": list(context.free_parameter_names),
        }
        run_kwargs.update(sampler_kwargs)
        # Small dictionary recorded in the run summary.
        summary_parameters = {
            "sampler_type": "mysampler",
            "n_draws": run_kwargs["n_draws"],
            "seed": run_kwargs["seed"],
            "parameter_names": run_kwargs["parameter_names"],
        }
        return sampler, run_kwargs, summary_parameters

    def run_sampler(self, n_draws=512, seed=42654, parameter_names=None, **kwargs):
        rng = np.random.default_rng(seed)
        # Draw from the unit hypercube and map each draw through the prior transform.
        cube_samples = rng.random((n_draws, len(parameter_names)))
        physical_samples = np.asarray([self.prior_func(sample) for sample in cube_samples])
        log_likelihood = np.asarray(
            [self.log_likelihood_func(sample) for sample in physical_samples],
            dtype=float,
        )
        # Equal-weight results in the layout expected by Retrieval.
        self.results = {
            "samples": physical_samples,
            "log_likelihood": log_likelihood,
            "log_weights": np.full(n_draws, -np.log(n_draws), dtype=float),
        }
        return self.results

    def get_results(self):
        return self.results


# Register the class so that sampler_type="mysampler" selects it.
SAMPLER_DISPATCH["mysampler"] = MySampler
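Before handing results back to Retrieval, it can help to check them against the contract listed above. The checker below is a hypothetical helper, not part of pRT; it assumes a dict-like results object with the three keys used in the example and works on nested lists as well as arrays.

```python
import math


def check_results(results, n_parameters):
    """Validate the shape contract of an equal-weight results mapping."""
    samples = results["samples"]
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError("results contain no samples")
    # samples must be a (n_samples, n_parameters) matrix.
    if any(len(row) != n_parameters for row in samples):
        raise ValueError("samples must have shape (n_samples, n_parameters)")
    # One log-likelihood value per posterior sample.
    if len(results["log_likelihood"]) != n_samples:
        raise ValueError("log_likelihood must have one value per sample")
    # log_weights is optional; equal-weight posteriors use -log(n_samples).
    log_weights = results.get("log_weights")
    if log_weights is not None:
        expected = -math.log(n_samples)
        if len(log_weights) != n_samples or any(
            not math.isclose(value, expected) for value in log_weights
        ):
            raise ValueError("equal-weight log_weights must all equal -log(n_samples)")
    return n_samples


example_results = {
    "samples": [[0.1, 0.2], [0.3, 0.4]],
    "log_likelihood": [-1.0, -2.0],
    "log_weights": [-math.log(2), -math.log(2)],
}
n_checked = check_results(example_results, n_parameters=2)
```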
Two integration patterns are available:
Register the class in SAMPLER_DISPATCH and set RetrievalConfig.sampler_type="mysampler". This is the preferred route because Retrieval.run() will call prepare(), post_run(), and the standard summary machinery.
Pass an already-built instance with retrieval.run(sampler_object=my_sampler, ...). This is useful for one-off experiments, but it bypasses prepare() and post_run(). In that mode you are responsible for any extra setup, summary information, and diagnostics you need.
Practical guidance#
Use jaxns, NumPyro, or BlackJAX only with differentiable retrieval models.
Use equal-weight posterior samples for downstream plotting, archive export, and best-fit selection.
Keep backend-native outputs in a sampler-specific folder such as out_JAXNS or out_Dynesty if you need to reload diagnostics later.
If you add a new sampler backend to pRT itself, update both SAMPLER_DISPATCH and the user-facing sampler documentation on this page.