petitRADTRANS.retrieval.sampler

Attributes

Classes

SamplerContext

Lightweight container that carries everything a sampler needs from a Retrieval.

Sampler

Abstract base class for samplers.

BlackJAXSamplingResults

NumPyroSamplingResults

BlackJAXSampler

Base class for BlackJAX MCMC samplers.

BlackJAXHMCSampler

Sampler class for BlackJAX HMC.

BlackJAXNUTSSampler

Sampler class for BlackJAX NUTS.

NumPyroSampler

Base class for NumPyro MCMC samplers.

NumPyroHMCSampler

Sampler class for NumPyro HMC.

NumPyroNUTSSampler

Sampler class for NumPyro NUTS.

PymultinestSampler

Sampler class for pymultinest.

UltranestSampler

Sampler class for UltraNest.

DynestySampler

Sampler class for dynesty.

JAXNSSampler

Sampler class for JAXNS.

JAXNSShardedStaticNestedSampler

Sampler class for JAXNS.

Functions

_require_jax()

_require_jaxns()

_require_blackjax()

_require_numpyro()

_require_arviz()

_require_matplotlib_pyplot()

_require_pymultinest()

_require_ultranest()

_require_dynesty()

_check_differentiable_support(context, sampler_label) → None

Raise if the retrieval has only legacy (non-differentiable) model groups.

_build_device_likelihood_wrapper(context)

Return a PyMultiNest-compatible log-likelihood that dispatches to a JAX device.

_build_unconstrained_mcmc_interface(context)

Build the unconstrained-space logdensity and transforms for gradient-based samplers.

Module Contents

petitRADTRANS.retrieval.sampler.lax = None
petitRADTRANS.retrieval.sampler.random = None
petitRADTRANS.retrieval.sampler.tree = None
petitRADTRANS.retrieval.sampler.vmap = None
petitRADTRANS.retrieval.sampler.tree_util = None
petitRADTRANS.retrieval.sampler.jit = None
petitRADTRANS.retrieval.sampler.local_device_count = None
petitRADTRANS.retrieval.sampler.make_mesh = None
petitRADTRANS.retrieval.sampler.shard_map = None
petitRADTRANS.retrieval.sampler.set_mesh = None
petitRADTRANS.retrieval.sampler.devices = None
petitRADTRANS.retrieval.sampler.default_device = None
petitRADTRANS.retrieval.sampler.device_get = None
petitRADTRANS.retrieval.sampler.clear_caches = None
petitRADTRANS.retrieval.sampler.jnp = None
petitRADTRANS.retrieval.sampler.NamedSharding = None
petitRADTRANS.retrieval.sampler.P = None
petitRADTRANS.retrieval.sampler.tfp = None
petitRADTRANS.retrieval.sampler._require_jax()
petitRADTRANS.retrieval.sampler.jaxns = None
petitRADTRANS.retrieval.sampler.Model = None
petitRADTRANS.retrieval.sampler.NestedSampler = None
petitRADTRANS.retrieval.sampler.TerminationCondition = None
petitRADTRANS.retrieval.sampler.save_results = None
petitRADTRANS.retrieval.sampler.load_results = None
petitRADTRANS.retrieval.sampler.jaxns_resample = None
petitRADTRANS.retrieval.sampler.ShardedStaticNestedSampler = None
petitRADTRANS.retrieval.sampler.UniformSampler = None
petitRADTRANS.retrieval.sampler._JAXNS_TERMINATION_FIELDS = ('ess', 'evidence_uncert', 'live_evidence_frac', 'dlogZ', 'max_samples',...
petitRADTRANS.retrieval.sampler._require_jaxns()
petitRADTRANS.retrieval.sampler.blackjax = None
petitRADTRANS.retrieval.sampler._require_blackjax()
petitRADTRANS.retrieval.sampler.numpyro = None
petitRADTRANS.retrieval.sampler.NumPyroHMC = None
petitRADTRANS.retrieval.sampler.NumPyroNUTS = None
petitRADTRANS.retrieval.sampler.NumPyroMCMC = None
petitRADTRANS.retrieval.sampler.az = None
petitRADTRANS.retrieval.sampler.plt = None
petitRADTRANS.retrieval.sampler._require_numpyro()
petitRADTRANS.retrieval.sampler._require_arviz()
petitRADTRANS.retrieval.sampler._require_matplotlib_pyplot()
petitRADTRANS.retrieval.sampler.pymultinest = None
petitRADTRANS.retrieval.sampler._require_pymultinest()
petitRADTRANS.retrieval.sampler.ultranest = None
petitRADTRANS.retrieval.sampler.MLFriends = None
petitRADTRANS.retrieval.sampler.RobustEllipsoidRegion = None
petitRADTRANS.retrieval.sampler.dynesty = None
petitRADTRANS.retrieval.sampler.DynamicNestedSampler = None
petitRADTRANS.retrieval.sampler.DynestyPool = None
petitRADTRANS.retrieval.sampler.dyplot = None
petitRADTRANS.retrieval.sampler.dyfunc = None
petitRADTRANS.retrieval.sampler._require_ultranest()
petitRADTRANS.retrieval.sampler._require_dynesty()
class petitRADTRANS.retrieval.sampler.SamplerContext

Lightweight container that carries everything a sampler needs from a Retrieval.

Built once by Retrieval._build_sampler_context() and passed to each sampler’s prepare() class method so that sampler implementations never depend on the full Retrieval object.

parameter_layout: petitRADTRANS.retrieval.runtime.ParameterLayout
runtime: petitRADTRANS.retrieval.runtime.RetrievalRuntime
configuration: Any
output_directory: str
retrieval_name: str
uncertainties_mode: str = 'default'
print_log_likelihood_for_debugging: bool = False
log_likelihood_func: Any = None
prior_func: Any = None
prior_ultranest_func: Any = None
free_parameter_names: list = []
n_free_parameters: int = 0
likelihood_devices: list | None = None
petitRADTRANS.retrieval.sampler._check_differentiable_support(context: SamplerContext, sampler_label: str) → None

Raise if the retrieval has only legacy (non-differentiable) model groups.

petitRADTRANS.retrieval.sampler._build_device_likelihood_wrapper(context: SamplerContext)

Return a PyMultiNest-compatible log-likelihood that dispatches to a JAX device.

Each MPI rank r selects context.likelihood_devices[r % len(context.likelihood_devices)], so N ranks are distributed round-robin across M devices. Requires differentiable model groups.

Returns (device_likelihood_func, target_device).

Notes

The user is responsible for ensuring that JAX processes do not unexpectedly contend on the same device (e.g. set JAX_PLATFORMS=cuda,cpu or CUDA_VISIBLE_DEVICES per MPI process for true isolation when using multiple GPUs).
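The rank-to-device striping described above can be illustrated without MPI or JAX. This is a minimal sketch. `select_likelihood_device` is a hypothetical helper, and plain strings stand in for JAX device objects. It only mirrors the documented `r % len(likelihood_devices)` rule and is not the module's implementation.

```python
def select_likelihood_device(rank, likelihood_devices):
    """Return the device assigned to an MPI rank by round-robin striping.

    Hypothetical helper mirroring the documented rule
    likelihood_devices[rank % len(likelihood_devices)].
    """
    if not likelihood_devices:
        raise ValueError("likelihood_devices must contain at least one device")
    return likelihood_devices[rank % len(likelihood_devices)]


# Strings stand in for JAX device objects (e.g. entries of jax.devices()).
devices = ["gpu:0", "gpu:1", "gpu:2"]
assignment = {rank: select_likelihood_device(rank, devices) for rank in range(7)}
```

With 7 ranks and 3 devices, ranks 0, 3 and 6 all land on `gpu:0`. This is exactly the kind of sharing the note above warns about when processes are not isolated per device.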

petitRADTRANS.retrieval.sampler._build_unconstrained_mcmc_interface(context: SamplerContext)

Build the unconstrained-space logdensity and transforms for gradient-based samplers.

Returns (free_parameter_names, initial_position, logdensity, transform_positions, log_jacobian).
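Gradient-based samplers such as HMC and NUTS work on an unbounded space. One common construction for a parameter with a uniform prior on (low, high) is a scaled sigmoid transform plus its log-Jacobian. The sketch below rests on two assumptions:

* every prior is uniform;
* the helper names are hypothetical.

The actual transforms produced by `_build_unconstrained_mcmc_interface` are not documented on this page and may differ.

```python
import math


def to_constrained(z, low, high):
    """Map an unconstrained value z in (-inf, inf) onto (low, high) with a scaled sigmoid."""
    return low + (high - low) / (1.0 + math.exp(-z))


def to_unconstrained(theta, low, high):
    """Inverse map: logit of theta rescaled to (0, 1)."""
    u = (theta - low) / (high - low)
    return math.log(u) - math.log1p(-u)


def log_jacobian(z, low, high):
    """log|dtheta/dz| = log(high - low) + log s(z) + log(1 - s(z)), with s the sigmoid.

    Uses log s(z) = -log1p(exp(-z)) and log(1 - s(z)) = -log1p(exp(z));
    adequate for moderate |z| (exp overflows beyond roughly |z| > 709).
    """
    return math.log(high - low) - math.log1p(math.exp(-z)) - math.log1p(math.exp(z))


def unconstrained_logdensity(z, log_likelihood, low, high):
    """Target for HMC/NUTS in z-space: log L(theta(z)) plus the log-Jacobian.

    The uniform prior density 1 / (high - low) is constant, so it is left out.
    """
    return log_likelihood(to_constrained(z, low, high)) + log_jacobian(z, low, high)


def gaussian_log_likelihood(theta):
    # Toy likelihood peaked at theta = 1.5 (illustrative only).
    return -0.5 * (theta - 1.5) ** 2


z0 = to_unconstrained(1.5, 0.0, 4.0)
logp0 = unconstrained_logdensity(z0, gaussian_log_likelihood, 0.0, 4.0)
```

Adding the log-Jacobian is what keeps samples drawn in z-space distributed according to the intended posterior once they are mapped back with `to_constrained`.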

class petitRADTRANS.retrieval.sampler.Sampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: abc.ABC

Abstract base class for samplers.

log_likelihood_func
prior_func
output_directory
retrieval_name
abstractmethod run_sampler(**kwargs)
abstractmethod classmethod prepare(context: SamplerContext, **sampler_kwargs)

Build a sampler instance, run-kwargs, and summary parameters.

Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.

Returns:

(sampler_instance, run_kwargs, summary_parameters)

post_run(results)

Optional hook called after run_sampler.

Handles pretty-printing, results extraction, and any other sampler-specific finalization.

Returns:

(final_results, extra_finalize_kwargs) where the extra kwargs are forwarded to Retrieval._finalize_sampler_run.
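The prepare → run_sampler → post_run contract can be sketched with a self-contained toy. Everything below is hypothetical:

* `ToyContext` stands in for SamplerContext and carries only a few of its fields.
* `ToySamplerBase` and `ToySampler` are not petitRADTRANS classes.
* The sampler replaces a real algorithm with plain prior draws, so the flow runs without JAX or any sampler library.

```python
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class ToyContext:
    """Hypothetical stand-in for SamplerContext (a subset of its fields)."""
    log_likelihood_func: object
    prior_func: object
    output_directory: str
    retrieval_name: str
    free_parameter_names: list = field(default_factory=list)


class ToySamplerBase(ABC):
    """Mirrors the documented Sampler hooks; not the real base class."""

    def __init__(self, log_likelihood_func, prior_func, output_directory, retrieval_name):
        self.log_likelihood_func = log_likelihood_func
        self.prior_func = prior_func
        self.output_directory = output_directory
        self.retrieval_name = retrieval_name

    @classmethod
    @abstractmethod
    def prepare(cls, context, **sampler_kwargs):
        """Return (sampler_instance, run_kwargs, summary_parameters)."""

    @abstractmethod
    def run_sampler(self, **kwargs):
        """Run the sampler and return raw results."""

    def post_run(self, results):
        """Default hook: results unchanged, no extra finalize kwargs."""
        return results, {}


class ToySampler(ToySamplerBase):
    @classmethod
    def prepare(cls, context, **sampler_kwargs):
        sampler = cls(context.log_likelihood_func, context.prior_func,
                      context.output_directory, context.retrieval_name)
        run_kwargs = {"n_draws": sampler_kwargs.get("n_draws", 100),
                      "seed": sampler_kwargs.get("seed", 0)}
        summary_parameters = {"parameter_names": list(context.free_parameter_names)}
        return sampler, run_kwargs, summary_parameters

    def run_sampler(self, n_draws, seed):
        rng = random.Random(seed)
        # prior_func maps a unit-cube draw to physical parameter values.
        draws = [self.prior_func([rng.random()]) for _ in range(n_draws)]
        return {"samples": draws,
                "log_likelihood": [self.log_likelihood_func(d) for d in draws]}


context = ToyContext(
    log_likelihood_func=lambda p: -0.5 * (p[0] - 1.0) ** 2,
    prior_func=lambda cube: [4.0 * cube[0] - 2.0],  # maps [0, 1) to U(-2, 2)
    output_directory="out",
    retrieval_name="toy",
    free_parameter_names=["x"],
)
sampler, run_kwargs, summary = ToySampler.prepare(context, n_draws=50)
final_results, extra_finalize_kwargs = sampler.post_run(sampler.run_sampler(**run_kwargs))
```

Keeping `prepare` as a class method that reads only from the context is what lets each sampler assemble its own run kwargs without touching the full Retrieval object.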

plot_diagnostics(results=None)

Optional hook to persist sampler-specific diagnostics plots.

_get_samples_archive_path()
_load_sample_archive_results()
load_results()
static _extract_figure(plot_output)
classmethod _save_plot_figure(plot_output, output_file)
static _summary_scalar(value)
get_summary(free_parameter_names=None)
static _posterior_summary_from_matrix(samples, parameter_names)
static _posterior_summary_from_mapping(samples_by_parameter, free_parameter_names=None)
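`_posterior_summary_from_matrix(samples, parameter_names)` is undocumented here. The hypothetical function below illustrates what a per-parameter summary of an (n_samples, n_parameters) matrix typically contains:

* the median;
* the distances to the 16th and 84th percentiles, i.e. a ±1σ-equivalent interval.

The real method's return format is not specified on this page, and the parameter names in the example are illustrative.

```python
import numpy as np


def posterior_summary_from_matrix(samples, parameter_names):
    """Hypothetical summary: median and 16th/84th-percentile offsets per column."""
    samples = np.asarray(samples, dtype=float)
    if samples.ndim != 2 or samples.shape[1] != len(parameter_names):
        raise ValueError("samples must have shape (n_samples, len(parameter_names))")
    # np.percentile with a list of q values returns shape (3, n_parameters).
    p16, p50, p84 = np.percentile(samples, [16.0, 50.0, 84.0], axis=0)
    return {
        name: {"median": float(p50[i]),
               "lower": float(p50[i] - p16[i]),
               "upper": float(p84[i] - p50[i])}
        for i, name in enumerate(parameter_names)
    }


rng = np.random.default_rng(0)
draws = np.column_stack([rng.normal(1.0, 0.1, 20000), rng.normal(-3.0, 0.5, 20000)])
summary = posterior_summary_from_matrix(draws, ["param_a", "param_b"])
```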
class petitRADTRANS.retrieval.sampler.BlackJAXSamplingResults
algorithm: str
samples: numpy.ndarray
unconstrained_samples: numpy.ndarray
logdensity: numpy.ndarray
model_log_likelihood: numpy.ndarray | None
warmup_parameters: dict
acceptance_rate: numpy.ndarray | None = None
is_divergent: numpy.ndarray | None = None
num_integration_steps: numpy.ndarray | None = None
output_file: str | None = None
metadata_file: str | None = None
class petitRADTRANS.retrieval.sampler.NumPyroSamplingResults
algorithm: str
samples: numpy.ndarray
unconstrained_samples: numpy.ndarray
logdensity: numpy.ndarray
model_log_likelihood: numpy.ndarray | None
warmup_parameters: dict
acceptance_rate: numpy.ndarray | None = None
is_divergent: numpy.ndarray | None = None
num_integration_steps: numpy.ndarray | None = None
output_file: str | None = None
metadata_file: str | None = None
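BlackJAXSamplingResults and NumPyroSamplingResults share the same diagnostic fields. A quick health check on a finished run might reduce `acceptance_rate` and `is_divergent` as below. This sketch makes three assumptions:

* both fields are NumPy arrays over draws, possibly with a leading chain axis (shapes are not documented here);
* `None` means the field was not recorded;
* the helper name is hypothetical.

```python
import numpy as np


def mcmc_health_check(acceptance_rate, is_divergent, target_acceptance_rate=0.8):
    """Summarize acceptance and divergences; fields that are None stay None."""
    report = {"mean_acceptance": None, "acceptance_gap": None,
              "n_divergent": None, "divergent_fraction": None}
    if acceptance_rate is not None:
        report["mean_acceptance"] = float(np.mean(acceptance_rate))
        # Compare with the target used for step-size adaptation (run_sampler default 0.8).
        report["acceptance_gap"] = report["mean_acceptance"] - target_acceptance_rate
    if is_divergent is not None:
        divergent = np.asarray(is_divergent, dtype=bool)
        report["n_divergent"] = int(divergent.sum())
        report["divergent_fraction"] = float(divergent.mean())
    return report


# Illustrative arrays shaped (chains, draws).
acc = np.array([[0.85, 0.78, 0.81, 0.80], [0.79, 0.83, 0.77, 0.81]])
div = np.array([[False, False, True, False], [False, False, False, False]])
report = mcmc_health_check(acc, div)
```

A non-zero divergence count usually signals posterior geometry that the sampler struggles with. It is typically worth investigating before trusting the samples.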
class petitRADTRANS.retrieval.sampler.BlackJAXSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: Sampler

Base class for BlackJAX MCMC samplers.

results = None
classmethod _prepare_blackjax(context: SamplerContext, algorithm_name: str, **sampler_kwargs)

Shared preparation logic for all BlackJAX sampler variants.

abstract property algorithm_name
abstract property algorithm_factory
static _to_serializable(value)
_get_output_prefix()
_save_results(parameter_names, samples, unconstrained_samples, logdensity, model_log_likelihood, acceptance_rate, is_divergent, num_integration_steps, warmup_parameters, metadata)
run_sampler(initial_position, parameter_names, num_samples=1000, num_warmup=1000, initial_step_size=1.0, target_acceptance_rate=0.8, is_mass_matrix_diagonal=True, progress_bar=False, seed=42654, use_jit=True, log_jacobian_fn=None, num_chains=1, **blackjax_kwargs)
get_results()
load_results()
pretty_print_results()
get_summary(free_parameter_names=None)
class petitRADTRANS.retrieval.sampler.BlackJAXHMCSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: BlackJAXSampler

Sampler class for BlackJAX HMC.

property algorithm_name
property algorithm_factory
classmethod prepare(context: SamplerContext, **sampler_kwargs)

Build a sampler instance, run-kwargs, and summary parameters.

Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.

Returns:

(sampler_instance, run_kwargs, summary_parameters)

class petitRADTRANS.retrieval.sampler.BlackJAXNUTSSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: BlackJAXSampler

Sampler class for BlackJAX NUTS.

property algorithm_name
property algorithm_factory
classmethod prepare(context: SamplerContext, **sampler_kwargs)

Build a sampler instance, run-kwargs, and summary parameters.

Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.

Returns:

(sampler_instance, run_kwargs, summary_parameters)

class petitRADTRANS.retrieval.sampler.NumPyroSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: Sampler

Base class for NumPyro MCMC samplers.

results = None
mcmc = None
parameter_names = None
classmethod _prepare_numpyro(context: SamplerContext, algorithm_name: str, **sampler_kwargs)

Shared preparation logic for all NumPyro sampler variants.

abstract property algorithm_name
abstract property kernel_class
static _to_serializable(value)
_get_output_prefix()
static _prepare_kernel_kwargs(algorithm_name, log_likelihood_func, initial_step_size, is_mass_matrix_diagonal, target_acceptance_rate, numpyro_kwargs)
_save_results(parameter_names, samples, unconstrained_samples, logdensity, model_log_likelihood, acceptance_rate, is_divergent, num_integration_steps, warmup_parameters, metadata)
run_sampler(initial_position, parameter_names, num_samples=1000, num_warmup=1000, initial_step_size=1.0, target_acceptance_rate=0.8, is_mass_matrix_diagonal=True, progress_bar=False, seed=42654, use_jit=True, log_jacobian_fn=None, num_chains=1, thinning=1, chain_method='vectorized', **numpyro_kwargs)
get_results()
load_results()
pretty_print_results()
plot_diagnostics(results=None)

Optional hook to persist sampler-specific diagnostics plots.

get_summary(free_parameter_names=None)
class petitRADTRANS.retrieval.sampler.NumPyroHMCSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: NumPyroSampler

Sampler class for NumPyro HMC.

property algorithm_name
property kernel_class
classmethod prepare(context: SamplerContext, **sampler_kwargs)

Build a sampler instance, run-kwargs, and summary parameters.

Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.

Returns:

(sampler_instance, run_kwargs, summary_parameters)

class petitRADTRANS.retrieval.sampler.NumPyroNUTSSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: NumPyroSampler

Sampler class for NumPyro NUTS.

property algorithm_name
property kernel_class
classmethod prepare(context: SamplerContext, **sampler_kwargs)

Build a sampler instance, run-kwargs, and summary parameters.

Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.

Returns:

(sampler_instance, run_kwargs, summary_parameters)

class petitRADTRANS.retrieval.sampler.PymultinestSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: Sampler

Sampler class for pymultinest.

seed = -1
analyzer = None
_n_dims = None
outputfiles_basename
classmethod prepare(context: SamplerContext, **sampler_kwargs)

Build a sampler instance, run-kwargs, and summary parameters.

Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.

Returns:

(sampler_instance, run_kwargs, summary_parameters)

post_run(results)

Optional hook called after run_sampler.

Handles pretty-printing, results extraction, and any other sampler-specific finalization.

Returns:

(final_results, extra_finalize_kwargs) where the extra kwargs are forwarded to Retrieval._finalize_sampler_run.

run_sampler(n_dims, **kwargs)

Args:

sampling_efficiency (float)

pymultinest sampling efficiency. If constant efficiency mode is enabled, it should be set to around 0.05. Otherwise, it should be around 0.8 for parameter estimation and 0.3 for evidence comparison.

const_efficiency_mode (bool)

Whether to use pymultinest's constant efficiency mode.

n_live_points (int)

Number of live points to use in pymultinest, or the minimum number of live points to use for the UltraNest reactive sampler.

log_z_convergence (float)

Convergence criterion on log Z when UltraNest is used.

step_sampler (bool)

Use a step sampler to improve the efficiency of UltraNest.

warmstart_max_tau (float)

Warm start allows accelerated computation based on a different but similar UltraNest run.

n_iter_before_update (int)

Number of live point replacements before printing an update to a log file.

max_iters (int)

Maximum number of sampling iterations. If 0, sampling continues until the convergence criteria are satisfied.

frac_remain (float)

UltraNest convergence criterion. Halts integration when the remaining live-point weight falls below the specified fraction.

l_epsilon (float)

UltraNest convergence criterion for noisy likelihoods. Halts integration when the live-point likelihoods agree to within l_epsilon.

resume (bool)

Continue an existing retrieval. If False, THIS WILL OVERWRITE YOUR EXISTING RETRIEVAL.

error_checking (bool)

Test the model-generating function for typical errors. ONLY TURN THIS OFF IF YOU KNOW WHAT YOU’RE DOING!

force_serial_error_checking (bool)

If True, error checking is performed process by process instead of with all processes at once. This can prevent memory overflow.

seed (int)

Random number generator seed. Use a fixed non-negative value for reproducibility; a negative value seeds from the system clock.

Returns:

analyzer
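The sampling_efficiency guidance above can be captured in a small helper. `suggest_sampling_efficiency` is hypothetical and not part of petitRADTRANS. It only returns the rule-of-thumb values quoted in this docstring. The `n_live_points` and `seed` values in the dict are illustrative, not recommendations.

```python
def suggest_sampling_efficiency(const_efficiency_mode, goal="parameter_estimation"):
    """Rule-of-thumb pymultinest sampling efficiency, per the run_sampler docstring."""
    if const_efficiency_mode:
        return 0.05
    if goal == "parameter_estimation":
        return 0.8
    if goal == "evidence":
        return 0.3
    raise ValueError("goal must be 'parameter_estimation' or 'evidence'")


# Keyword arguments one might pass as sampler.run_sampler(n_dims, **run_kwargs).
run_kwargs = {
    "const_efficiency_mode": False,
    "sampling_efficiency": suggest_sampling_efficiency(False, goal="evidence"),
    "n_live_points": 400,  # illustrative value
    "resume": True,        # False would overwrite an existing retrieval
    "seed": 1234,          # fixed non-negative seed for reproducibility
}
```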

load_results()
pretty_print(free_parameter_names, prefix)
get_summary(free_parameter_names=None)
class petitRADTRANS.retrieval.sampler.UltranestSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: Sampler

Sampler class for UltraNest.

results = None
_get_log_dir()
classmethod _resolve_vectorized(context: SamplerContext, requested: bool) → bool
classmethod prepare(context: SamplerContext, **sampler_kwargs)

Build a sampler instance, run-kwargs, and summary parameters.

Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.

Returns:

(sampler_instance, run_kwargs, summary_parameters)

load_results()
run_sampler(parameter_names, **kwargs)
get_results()
pretty_print_results()
get_summary(free_parameter_names=None)
class petitRADTRANS.retrieval.sampler.DynestySampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: Sampler

Sampler class for dynesty.

sampler = None
results = None
raw_results = None
parameter_names = []
_run_summary
_dynesty_pool = None
_owns_dynesty_pool = False
static _default_use_pool() → dict[str, bool]
classmethod _configure_dynesty_pool(log_likelihood_func, prior_func, *, pool, pool_njobs)
_close_dynesty_pool()
classmethod _build_dynesty_interface(context: SamplerContext, *, use_jit: bool = True, pool_safe: bool = False)
classmethod prepare(context: SamplerContext, **sampler_kwargs)

Build a sampler instance, run-kwargs, and summary parameters.

Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.

Returns:

(sampler_instance, run_kwargs, summary_parameters)

post_run(results)

Optional hook called after run_sampler.

Handles pretty-printing, results extraction, and any other sampler-specific finalization.

Returns:

(final_results, extra_finalize_kwargs) where the extra kwargs are forwarded to Retrieval._finalize_sampler_run.

static _normalize_weights(weights)
_get_output_prefix()
_save_results(parameter_names, samples, log_likelihood, weighted_samples, weighted_log_likelihood, weighted_log_weights, metadata)
_build_equal_weight_posterior(raw_results, seed)
static _summarize_ncall(raw_results)
run_sampler(init_kwargs, **run_kwargs)
get_results()
load_results()
pretty_print_results()
plot_diagnostics(results=None)

Optional hook to persist sampler-specific diagnostics plots.

get_summary(free_parameter_names=None)
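DynestySampler exposes `_normalize_weights` and `_build_equal_weight_posterior(raw_results, seed)` without descriptions. The usual way to turn nested-sampling output into an equal-weight posterior has two steps:

1. exponentiate and normalize the log importance weights;
2. resample the weighted points.

The sketch below does this with NumPy multinomial resampling. The function names are hypothetical and not the module's code. dynesty itself also provides `dynesty.utils.resample_equal`, which uses systematic resampling instead.

```python
import numpy as np


def normalize_log_weights(log_weights):
    """Convert log importance weights to weights summing to 1 (max-subtracted for stability)."""
    log_weights = np.asarray(log_weights, dtype=float)
    weights = np.exp(log_weights - np.max(log_weights))
    return weights / weights.sum()


def equal_weight_posterior(samples, log_weights, n_draws=None, seed=0):
    """Multinomial resampling of weighted samples into an equal-weight set."""
    samples = np.asarray(samples, dtype=float)
    weights = normalize_log_weights(log_weights)
    n_draws = len(samples) if n_draws is None else n_draws
    rng = np.random.default_rng(seed)
    indices = rng.choice(len(samples), size=n_draws, replace=True, p=weights)
    return samples[indices]


# Three weighted points; the first carries zero weight (log weight -inf).
samples = np.array([[0.0], [1.0], [2.0]])
log_weights = np.array([-np.inf, 0.0, 0.0])
posterior = equal_weight_posterior(samples, log_weights, n_draws=1000, seed=1)
```

Subtracting the maximum log weight before exponentiating avoids underflow. Nested-sampling log weights can be very negative.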
class petitRADTRANS.retrieval.sampler.JAXNSSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: Sampler

Sampler class for JAXNS.

jaxns_parameters
seed = -1
sampler = None
results = None
raw_results = None
parameter_names = []
offload_results_to_cpu = True
classmethod _build_jaxns_interface(context: SamplerContext)

Build the JAXNS prior_model generator and the log_likelihood wrapper.

static _coerce_termination_condition_mapping(values, source_name: str) → dict[str, Any]
classmethod _build_termination_condition(sampler_kwargs: dict[str, Any]) → tuple[Any, dict[str, Any]]
classmethod prepare(context: SamplerContext, **sampler_kwargs)

Build a sampler instance, run-kwargs, and summary parameters.

Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.

Returns:

(sampler_instance, run_kwargs, summary_parameters)

_get_output_prefix()
_get_raw_results_path()
_get_equal_weight_samples_path()
_get_parameter_names_path()
_get_cpu_device()
_materialize_raw_results(termination_reason, state)
_to_numpy_tree(pytree)
static _get_named_samples_mapping(raw_results)
static _samples_mapping_to_matrix(samples_by_parameter, parameter_names=None)
_save_equal_weight_samples(parameter_names, samples, log_likelihood)
_build_processed_results(raw_results, *, equal_weight_samples=None, equal_weight_log_likelihood=None, output_file=None, params_file=None, raw_results_file=None)
_build_equal_weight_results(raw_results)
post_run(results)

JAXNS post-run hook. Because get_results() is already called inside run_sampler, this hook only passes stats=results through as an extra finalize keyword argument.

run_sampler(use_jit=True, parameters={}, **jaxns_kwargs)

Run the JAXNS sampler.

get_results()
load_results()
pretty_print_results()
plot_diagnostics(results=None)

Optional hook to persist sampler-specific diagnostics plots.

get_summary(free_parameter_names=None)
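`_build_termination_condition(sampler_kwargs)` returns a `(termination_condition, remaining_kwargs)` pair. That suggests it separates termination settings from the other sampler options. Below is a hedged sketch of that partitioning step, with these limits:

* It uses only the field names visible in the truncated `_JAXNS_TERMINATION_FIELDS` tuple above; the full tuple is elided.
* It returns a plain dict rather than building the `TerminationCondition` this module imports.
* `split_termination_kwargs` is a hypothetical name.

```python
from typing import Any

# Only the names visible on this page; the real tuple contains more entries.
VISIBLE_TERMINATION_FIELDS = ("ess", "evidence_uncert", "live_evidence_frac", "dlogZ", "max_samples")


def split_termination_kwargs(sampler_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Partition kwargs into termination-condition fields and everything else."""
    termination = {k: v for k, v in sampler_kwargs.items() if k in VISIBLE_TERMINATION_FIELDS}
    remaining = {k: v for k, v in sampler_kwargs.items() if k not in VISIBLE_TERMINATION_FIELDS}
    return termination, remaining


termination, remaining = split_termination_kwargs(
    {"dlogZ": 0.1, "max_samples": 100_000, "use_jit": True}
)
```

The `termination` mapping could then be used to construct the termination condition, while `remaining` goes on to the sampler itself.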
class petitRADTRANS.retrieval.sampler.JAXNSShardedStaticNestedSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

Bases: JAXNSSampler

Sampler class for JAXNS.

log_likelihood_func
prior_function
output_directory
retrieval_name
jaxns_parameters
seed = -1
sampler = None
results = None
run_sampler(use_jit=True, **jaxns_kwargs)

Run the JAXNS sampler.

get_results()
load_results()
pretty_print_results()