petitRADTRANS.retrieval.sampler
Attributes
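Classes
- SamplerContext: Lightweight container that carries everything a sampler needs from a Retrieval.
- Sampler: Abstract base class for samplers.
- BlackJAXSamplingResults: Container for BlackJAX sampling results.
- NumPyroSamplingResults: Container for NumPyro sampling results.
- BlackJAXSampler: Base class for BlackJAX samplers; subclasses provide algorithm_name and algorithm_factory.
- BlackJAXHMCSampler: BlackJAX HMC sampler.
- BlackJAXNUTSSampler: BlackJAX NUTS sampler.
- NumPyroSampler: Base class for NumPyro samplers; subclasses provide algorithm_name and kernel_class.
- NumPyroHMCSampler: NumPyro HMC sampler.
- NumPyroNUTSSampler: NumPyro NUTS sampler.
- PymultinestSampler: Sampler class for pymultinest.
- UltranestSampler: Sampler class for UltraNest.
- DynestySampler: Sampler class for dynesty.
- JAXNSSampler: Sampler class for JAXNS.
- JAXNSShardedStaticNestedSampler: JAXNS sampler using the sharded static nested sampler.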
Functions
- _check_differentiable_support: Raise if the retrieval has only legacy (non-differentiable) model groups.
- _build_device_likelihood_wrapper: Return a PyMultiNest-compatible log-likelihood that dispatches to a JAX device.
- _build_unconstrained_mcmc_interface: Build the unconstrained-space log-density and transforms for gradient-based samplers.
Module Contents
- petitRADTRANS.retrieval.sampler.lax = None
- petitRADTRANS.retrieval.sampler.random = None
- petitRADTRANS.retrieval.sampler.tree = None
- petitRADTRANS.retrieval.sampler.vmap = None
- petitRADTRANS.retrieval.sampler.tree_util = None
- petitRADTRANS.retrieval.sampler.jit = None
- petitRADTRANS.retrieval.sampler.local_device_count = None
- petitRADTRANS.retrieval.sampler.make_mesh = None
- petitRADTRANS.retrieval.sampler.shard_map = None
- petitRADTRANS.retrieval.sampler.set_mesh = None
- petitRADTRANS.retrieval.sampler.devices = None
- petitRADTRANS.retrieval.sampler.default_device = None
- petitRADTRANS.retrieval.sampler.device_get = None
- petitRADTRANS.retrieval.sampler.clear_caches = None
- petitRADTRANS.retrieval.sampler.jnp = None
- petitRADTRANS.retrieval.sampler.NamedSharding = None
- petitRADTRANS.retrieval.sampler.P = None
- petitRADTRANS.retrieval.sampler.tfp = None
- petitRADTRANS.retrieval.sampler._require_jax()
- petitRADTRANS.retrieval.sampler.jaxns = None
- petitRADTRANS.retrieval.sampler.Model = None
- petitRADTRANS.retrieval.sampler.NestedSampler = None
- petitRADTRANS.retrieval.sampler.TerminationCondition = None
- petitRADTRANS.retrieval.sampler.save_results = None
- petitRADTRANS.retrieval.sampler.load_results = None
- petitRADTRANS.retrieval.sampler.jaxns_resample = None
- petitRADTRANS.retrieval.sampler.ShardedStaticNestedSampler = None
- petitRADTRANS.retrieval.sampler.UniformSampler = None
- petitRADTRANS.retrieval.sampler._JAXNS_TERMINATION_FIELDS = ('ess', 'evidence_uncert', 'live_evidence_frac', 'dlogZ', 'max_samples',...
- petitRADTRANS.retrieval.sampler._require_jaxns()
- petitRADTRANS.retrieval.sampler.blackjax = None
- petitRADTRANS.retrieval.sampler._require_blackjax()
- petitRADTRANS.retrieval.sampler.numpyro = None
- petitRADTRANS.retrieval.sampler.NumPyroHMC = None
- petitRADTRANS.retrieval.sampler.NumPyroNUTS = None
- petitRADTRANS.retrieval.sampler.NumPyroMCMC = None
- petitRADTRANS.retrieval.sampler.az = None
- petitRADTRANS.retrieval.sampler.plt = None
- petitRADTRANS.retrieval.sampler._require_numpyro()
- petitRADTRANS.retrieval.sampler._require_arviz()
- petitRADTRANS.retrieval.sampler._require_matplotlib_pyplot()
- petitRADTRANS.retrieval.sampler.pymultinest = None
- petitRADTRANS.retrieval.sampler._require_pymultinest()
- petitRADTRANS.retrieval.sampler.ultranest = None
- petitRADTRANS.retrieval.sampler.MLFriends = None
- petitRADTRANS.retrieval.sampler.RobustEllipsoidRegion = None
- petitRADTRANS.retrieval.sampler.dynesty = None
- petitRADTRANS.retrieval.sampler.DynamicNestedSampler = None
- petitRADTRANS.retrieval.sampler.DynestyPool = None
- petitRADTRANS.retrieval.sampler.dyplot = None
- petitRADTRANS.retrieval.sampler.dyfunc = None
- petitRADTRANS.retrieval.sampler._require_ultranest()
- petitRADTRANS.retrieval.sampler._require_dynesty()
- class petitRADTRANS.retrieval.sampler.SamplerContext
Lightweight container that carries everything a sampler needs from a Retrieval.
Built once by Retrieval._build_sampler_context() and passed to each sampler’s prepare() class method, so that sampler implementations never depend on the full Retrieval object.
- parameter_layout: petitRADTRANS.retrieval.runtime.ParameterLayout
- configuration: Any
- output_directory: str
- retrieval_name: str
- uncertainties_mode: str = 'default'
- print_log_likelihood_for_debugging: bool = False
- log_likelihood_func: Any = None
- prior_func: Any = None
- prior_ultranest_func: Any = None
- free_parameter_names: list = []
- n_free_parameters: int = 0
- likelihood_devices: list | None = None
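SamplerContext carries two prior callables, prior_func and prior_ultranest_func, whose signatures are not documented on this page. A plausible reading, assumed here and not confirmed by the source, is that they follow the two standard prior-transform conventions: PyMultiNest calls prior(cube, ndim, nparams) and overwrites cube in place, while UltraNest calls transform(cube) and uses the returned array (a 2-D batch when vectorized=True). The sketch implements one uniform prior both ways; the bounds and function names are hypothetical.

```python
import numpy as np

# Hypothetical uniform prior bounds for two free parameters (illustration only).
LOWER = np.array([0.0, -2.0])
UPPER = np.array([1.0, 3.0])


def prior_in_place(cube, ndim, nparams):
    """PyMultiNest-style prior: rescale each unit-cube coordinate in place."""
    for i in range(ndim):
        cube[i] = LOWER[i] + (UPPER[i] - LOWER[i]) * cube[i]


def prior_transform(cube):
    """UltraNest-style prior: return a new array for one point or an (n_points, ndim) batch."""
    cube = np.asarray(cube, dtype=float)
    return LOWER + (UPPER - LOWER) * cube


point = [0.2, 0.8]
prior_in_place(point, ndim=2, nparams=2)       # point is now approximately [0.2, 2.0]
batch = prior_transform(np.full((4, 2), 0.5))  # shape (4, 2); every row is [0.5, 0.5]
```

The in-place form avoids allocating a new array per call, while the returning form broadcasts naturally over batches, which is what a vectorized likelihood needs.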
- petitRADTRANS.retrieval.sampler._check_differentiable_support(context: SamplerContext, sampler_label: str) -> None
Raise if the retrieval has only legacy (non-differentiable) model groups.
- petitRADTRANS.retrieval.sampler._build_device_likelihood_wrapper(context: SamplerContext)
Return a PyMultiNest-compatible log-likelihood that dispatches to a JAX device.
Each MPI rank r selects context.likelihood_devices[r % len(context.likelihood_devices)], so N ranks are assigned round-robin across M devices. Requires differentiable model groups.
Returns
(device_likelihood_func, target_device).
Notes
The user is responsible for ensuring that JAX processes do not unexpectedly contend for the same device (e.g. set JAX_PLATFORMS=cuda,cpu to select backends, or CUDA_VISIBLE_DEVICES per MPI process for true isolation when using multiple GPUs).
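The rank-to-device rule described above is plain round-robin. A minimal sketch, with string labels as hypothetical stand-ins for JAX device objects and ranks passed in directly rather than read from MPI:

```python
def select_device(rank, devices):
    """Round-robin rule: MPI rank r uses devices[r % len(devices)]."""
    if not devices:
        raise ValueError("at least one likelihood device is required")
    return devices[rank % len(devices)]


# Hypothetical stand-ins for two GPU devices shared by five MPI ranks.
devices = ["gpu:0", "gpu:1"]
assignment = [select_device(rank, devices) for rank in range(5)]
print(assignment)  # ['gpu:0', 'gpu:1', 'gpu:0', 'gpu:1', 'gpu:0']
```

With more ranks than devices, several ranks share one device, which is the contention the Notes warn about. In a real run the rank would come from MPI (for example mpi4py's MPI.COMM_WORLD.Get_rank()) and the device list from jax.devices().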
- petitRADTRANS.retrieval.sampler._build_unconstrained_mcmc_interface(context: SamplerContext)
Build the unconstrained-space logdensity and transforms for gradient-based samplers.
Returns
(free_parameter_names, initial_position, logdensity, transform_positions, log_jacobian).
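Gradient-based samplers move in an unconstrained space, so each bounded parameter needs a bijection, and the log-density must include the log-Jacobian of that map. This page does not say which bijection petitRADTRANS uses. The sketch below assumes uniform priors on [LOWER, UPPER] mapped with a logistic sigmoid, plus a hypothetical Gaussian log-likelihood; the function names mirror the returned tuple but are illustrative only.

```python
import numpy as np

# Hypothetical uniform prior bounds (illustration only).
LOWER = np.array([0.0, -2.0])
UPPER = np.array([1.0, 3.0])


def transform_positions(z):
    """Map unconstrained z in R^d to theta in (LOWER, UPPER) via a logistic sigmoid."""
    return LOWER + (UPPER - LOWER) / (1.0 + np.exp(-z))


def log_jacobian(z):
    """Sum over dimensions of log|d theta_i / d z_i| = log(width) + log s + log(1 - s)."""
    # log s = -log(1 + e^{-z}) and log(1 - s) = -log(1 + e^{z}), computed stably.
    return np.sum(np.log(UPPER - LOWER) - np.logaddexp(0.0, -z) - np.logaddexp(0.0, z))


def log_likelihood(theta):
    """Hypothetical Gaussian log-likelihood centred on 0.5 in every dimension."""
    return -0.5 * np.sum(((theta - 0.5) / 0.1) ** 2)


def logdensity(z):
    """Unconstrained log-density: log L + log prior + log|J|."""
    theta = transform_positions(z)
    log_prior = -np.sum(np.log(UPPER - LOWER))  # uniform density inside the box
    return log_likelihood(theta) + log_prior + log_jacobian(z)


initial_position = np.zeros(2)  # maps to the centre of each prior range
```

For uniform priors the log_prior term cancels the log(width) part of the Jacobian exactly, so it only shifts the density by a constant; with non-uniform priors log_prior would depend on theta and could not be dropped.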
- class petitRADTRANS.retrieval.sampler.Sampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: abc.ABC
Abstract base class for samplers.
- log_likelihood_func
- prior_func
- output_directory
- retrieval_name
- abstractmethod run_sampler(**kwargs)
- abstract classmethod prepare(context: SamplerContext, **sampler_kwargs)
Build a sampler instance, run kwargs, and summary parameters.
Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.
- Returns:
(sampler_instance, run_kwargs, summary_parameters)
- post_run(results)
Optional hook called after run_sampler. Handles pretty-printing, results extraction, and any other sampler-specific finalization.
- Returns:
(final_results, extra_finalize_kwargs), where the extra kwargs are forwarded to Retrieval._finalize_sampler_run.
- plot_diagnostics(results=None)
Optional hook to persist sampler-specific diagnostics plots.
- _get_samples_archive_path()
- _load_sample_archive_results()
- load_results()
- static _extract_figure(plot_output)
- classmethod _save_plot_figure(plot_output, output_file)
- static _summary_scalar(value)
- get_summary(free_parameter_names=None)
- static _posterior_summary_from_matrix(samples, parameter_names)
- static _posterior_summary_from_mapping(samples_by_parameter, free_parameter_names=None)
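The hooks above imply a fixed call order: prepare() builds the sampler and its run kwargs from a context, run_sampler() runs it, and post_run() returns the final results plus extra kwargs destined for Retrieval._finalize_sampler_run. The sketch below exercises that order with a stand-in abstract base and a toy sampler that only draws from the prior. ContextSketch, SamplerSketch, PriorDrawSampler and run_retrieval are hypothetical names, not petitRADTRANS classes, and the toy performs no real inference.

```python
import abc
from dataclasses import dataclass
from typing import Callable

import numpy as np


@dataclass
class ContextSketch:
    """Hypothetical subset of the SamplerContext fields."""
    log_likelihood_func: Callable[[np.ndarray], float]
    prior_func: Callable[[np.ndarray], np.ndarray]
    output_directory: str
    retrieval_name: str
    n_free_parameters: int


class SamplerSketch(abc.ABC):
    """Stand-in mirroring the documented Sampler hooks (not the real class)."""

    def __init__(self, log_likelihood_func, prior_func, output_directory, retrieval_name):
        self.log_likelihood_func = log_likelihood_func
        self.prior_func = prior_func
        self.output_directory = output_directory
        self.retrieval_name = retrieval_name

    @classmethod
    @abc.abstractmethod
    def prepare(cls, context, **sampler_kwargs):
        """Return (sampler_instance, run_kwargs, summary_parameters)."""

    @abc.abstractmethod
    def run_sampler(self, **kwargs):
        """Run the sampler and return raw results."""

    def post_run(self, results):
        """Default hook: pass results through with no extra finalization kwargs."""
        return results, {}


class PriorDrawSampler(SamplerSketch):
    """Toy sampler: draws points from the prior and records their log-likelihoods."""

    @classmethod
    def prepare(cls, context, **sampler_kwargs):
        sampler = cls(context.log_likelihood_func, context.prior_func,
                      context.output_directory, context.retrieval_name)
        n_draws = sampler_kwargs.get("n_draws", 100)
        run_kwargs = {"n_draws": n_draws, "n_dims": context.n_free_parameters,
                      "seed": sampler_kwargs.get("seed", 0)}
        return sampler, run_kwargs, {"n_draws": n_draws}

    def run_sampler(self, n_draws, n_dims, seed):
        rng = np.random.default_rng(seed)
        samples = np.array([self.prior_func(rng.random(n_dims)) for _ in range(n_draws)])
        log_l = np.array([self.log_likelihood_func(s) for s in samples])
        return {"samples": samples, "log_likelihood": log_l}


def run_retrieval(sampler_cls, context, **sampler_kwargs):
    """Hypothetical driver following prepare -> run_sampler -> post_run."""
    sampler, run_kwargs, summary = sampler_cls.prepare(context, **sampler_kwargs)
    results = sampler.run_sampler(**run_kwargs)
    final_results, extra_finalize_kwargs = sampler.post_run(results)
    return final_results, extra_finalize_kwargs, summary


context = ContextSketch(
    log_likelihood_func=lambda x: -0.5 * float(np.sum(x ** 2)),
    prior_func=lambda u: 2.0 * u - 1.0,  # uniform prior on [-1, 1]
    output_directory="out",
    retrieval_name="demo",
    n_free_parameters=2,
)
final_results, extra, summary = run_retrieval(PriorDrawSampler, context, n_draws=50)
```

Keeping the driver ignorant of sampler internals is the point of the context object: a new backend only has to implement prepare() and run_sampler(), and may override post_run() when it needs extra finalization kwargs.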
- class petitRADTRANS.retrieval.sampler.BlackJAXSamplingResults
- algorithm: str
- samples: numpy.ndarray
- unconstrained_samples: numpy.ndarray
- logdensity: numpy.ndarray
- model_log_likelihood: numpy.ndarray | None
- warmup_parameters: dict
- acceptance_rate: numpy.ndarray | None = None
- is_divergent: numpy.ndarray | None = None
- num_integration_steps: numpy.ndarray | None = None
- output_file: str | None = None
- metadata_file: str | None = None
- class petitRADTRANS.retrieval.sampler.NumPyroSamplingResults
- algorithm: str
- samples: numpy.ndarray
- unconstrained_samples: numpy.ndarray
- logdensity: numpy.ndarray
- model_log_likelihood: numpy.ndarray | None
- warmup_parameters: dict
- acceptance_rate: numpy.ndarray | None = None
- is_divergent: numpy.ndarray | None = None
- num_integration_steps: numpy.ndarray | None = None
- output_file: str | None = None
- metadata_file: str | None = None
- class petitRADTRANS.retrieval.sampler.BlackJAXSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: Sampler
Base class for BlackJAX samplers; subclasses provide algorithm_name and algorithm_factory.
- results = None
- classmethod _prepare_blackjax(context: SamplerContext, algorithm_name: str, **sampler_kwargs)
Shared preparation logic for all BlackJAX sampler variants.
- abstract property algorithm_name
- abstract property algorithm_factory
- static _to_serializable(value)
- _get_output_prefix()
- _save_results(parameter_names, samples, unconstrained_samples, logdensity, model_log_likelihood, acceptance_rate, is_divergent, num_integration_steps, warmup_parameters, metadata)
- run_sampler(initial_position, parameter_names, num_samples=1000, num_warmup=1000, initial_step_size=1.0, target_acceptance_rate=0.8, is_mass_matrix_diagonal=True, progress_bar=False, seed=42654, use_jit=True, log_jacobian_fn=None, num_chains=1, **blackjax_kwargs)
- get_results()
- load_results()
- pretty_print_results()
- get_summary(free_parameter_names=None)
- class petitRADTRANS.retrieval.sampler.BlackJAXHMCSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: BlackJAXSampler
BlackJAX HMC sampler.
- property algorithm_name
- property algorithm_factory
- classmethod prepare(context: SamplerContext, **sampler_kwargs)
Build a sampler instance, run kwargs, and summary parameters.
Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.
- Returns:
(sampler_instance, run_kwargs, summary_parameters)
- class petitRADTRANS.retrieval.sampler.BlackJAXNUTSSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: BlackJAXSampler
BlackJAX NUTS sampler.
- property algorithm_name
- property algorithm_factory
- classmethod prepare(context: SamplerContext, **sampler_kwargs)
Build a sampler instance, run kwargs, and summary parameters.
Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.
- Returns:
(sampler_instance, run_kwargs, summary_parameters)
- class petitRADTRANS.retrieval.sampler.NumPyroSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: Sampler
Base class for NumPyro samplers; subclasses provide algorithm_name and kernel_class.
- results = None
- mcmc = None
- parameter_names = None
- classmethod _prepare_numpyro(context: SamplerContext, algorithm_name: str, **sampler_kwargs)
Shared preparation logic for all NumPyro sampler variants.
- abstract property algorithm_name
- abstract property kernel_class
- static _to_serializable(value)
- _get_output_prefix()
- static _prepare_kernel_kwargs(algorithm_name, log_likelihood_func, initial_step_size, is_mass_matrix_diagonal, target_acceptance_rate, numpyro_kwargs)
- _save_results(parameter_names, samples, unconstrained_samples, logdensity, model_log_likelihood, acceptance_rate, is_divergent, num_integration_steps, warmup_parameters, metadata)
- run_sampler(initial_position, parameter_names, num_samples=1000, num_warmup=1000, initial_step_size=1.0, target_acceptance_rate=0.8, is_mass_matrix_diagonal=True, progress_bar=False, seed=42654, use_jit=True, log_jacobian_fn=None, num_chains=1, thinning=1, chain_method='vectorized', **numpyro_kwargs)
- get_results()
- load_results()
- pretty_print_results()
- plot_diagnostics(results=None)
Optional hook to persist sampler-specific diagnostics plots.
- get_summary(free_parameter_names=None)
- class petitRADTRANS.retrieval.sampler.NumPyroHMCSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: NumPyroSampler
NumPyro HMC sampler.
- property algorithm_name
- property kernel_class
- classmethod prepare(context: SamplerContext, **sampler_kwargs)
Build a sampler instance, run kwargs, and summary parameters.
Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.
- Returns:
(sampler_instance, run_kwargs, summary_parameters)
- class petitRADTRANS.retrieval.sampler.NumPyroNUTSSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: NumPyroSampler
NumPyro NUTS sampler.
- property algorithm_name
- property kernel_class
- classmethod prepare(context: SamplerContext, **sampler_kwargs)
Build a sampler instance, run kwargs, and summary parameters.
Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.
- Returns:
(sampler_instance, run_kwargs, summary_parameters)
- class petitRADTRANS.retrieval.sampler.PymultinestSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: Sampler
Sampler class for pymultinest.
- seed = -1
- analyzer = None
- _n_dims = None
- outputfiles_basename
- classmethod prepare(context: SamplerContext, **sampler_kwargs)
Build a sampler instance, run kwargs, and summary parameters.
Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.
- Returns:
(sampler_instance, run_kwargs, summary_parameters)
- post_run(results)
Optional hook called after run_sampler. Handles pretty-printing, results extraction, and any other sampler-specific finalization.
- Returns:
(final_results, extra_finalize_kwargs), where the extra kwargs are forwarded to Retrieval._finalize_sampler_run.
- run_sampler(n_dims, **kwargs)
- Args:
- sampling_efficiency (float)
pymultinest sampling efficiency. In constant efficiency mode, set it to around 0.05; otherwise use around 0.8 for parameter estimation and 0.3 for evidence comparison.
- const_efficiency_mode (bool)
pymultinest constant efficiency mode.
- n_live_points (int)
Number of live points to use in pymultinest, or the minimum number of live points for the UltraNest reactive sampler.
- log_z_convergence (float)
If UltraNest is used, the convergence criterion on log Z.
- step_sampler (bool)
Use a step sampler to improve efficiency in UltraNest.
- warmstart_max_tau (float)
Warm start accelerates computation by reusing a different but similar UltraNest run.
- n_iter_before_update (int)
Number of live point replacements before writing an update to the log file.
- max_iters (int)
Maximum number of sampling iterations. If 0, sampling continues until the convergence criteria are satisfied.
- frac_remain (float)
UltraNest convergence criterion. Halts integration when the live point weights fall below the specified value.
- l_epsilon (float)
UltraNest convergence criterion for noisy likelihoods. Halts integration when the live points are within l_epsilon.
- resume (bool)
Continue an existing retrieval. If False, THIS WILL OVERWRITE YOUR EXISTING RETRIEVAL.
- error_checking (bool)
Test the model-generating function for typical errors. ONLY TURN THIS OFF IF YOU KNOW WHAT YOU’RE DOING!
- force_serial_error_checking (bool)
If True, error checking is performed process by process instead of with all processes at once. This can prevent memory overflow.
- seed (int)
Random number generator seed. A negative value seeds from the system clock; set a non-negative value for reproducibility.
- Returns:
analyzer
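The sampling_efficiency guidance in the argument list can be encoded as a small rule of thumb when assembling run kwargs. suggested_sampling_efficiency is a hypothetical helper, and the n_live_points and seed values are placeholders rather than recommendations from this page.

```python
def suggested_sampling_efficiency(const_efficiency_mode, goal="parameter_estimation"):
    """Rule of thumb taken from the sampling_efficiency description of run_sampler."""
    if const_efficiency_mode:
        return 0.05
    if goal == "parameter_estimation":
        return 0.8
    if goal == "evidence":
        return 0.3
    raise ValueError(f"unknown goal: {goal!r}")


# Hypothetical keyword arguments using the documented argument names.
run_kwargs = {
    "const_efficiency_mode": False,
    "sampling_efficiency": suggested_sampling_efficiency(False, goal="evidence"),
    "n_live_points": 400,  # placeholder value
    "resume": False,       # False overwrites the existing retrieval
    "seed": 42,            # non-negative for reproducibility; negative seeds from the clock
}
```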
- load_results()
- pretty_print(free_parameter_names, prefix)
- get_summary(free_parameter_names=None)
- class petitRADTRANS.retrieval.sampler.UltranestSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: Sampler
Sampler class for UltraNest.
- results = None
- _get_log_dir()
- classmethod _resolve_vectorized(context: SamplerContext, requested: bool) -> bool
- classmethod prepare(context: SamplerContext, **sampler_kwargs)
Build a sampler instance, run kwargs, and summary parameters.
Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.
- Returns:
(sampler_instance, run_kwargs, summary_parameters)
- load_results()
- run_sampler(parameter_names, **kwargs)
- get_results()
- pretty_print_results()
- get_summary(free_parameter_names=None)
- class petitRADTRANS.retrieval.sampler.DynestySampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: Sampler
Sampler class for dynesty.
- sampler = None
- results = None
- raw_results = None
- parameter_names = []
- _run_summary
- _dynesty_pool = None
- _owns_dynesty_pool = False
- static _default_use_pool() -> dict[str, bool]
- classmethod _configure_dynesty_pool(log_likelihood_func, prior_func, *, pool, pool_njobs)
- _close_dynesty_pool()
- classmethod _build_dynesty_interface(context: SamplerContext, *, use_jit: bool = True, pool_safe: bool = False)
- classmethod prepare(context: SamplerContext, **sampler_kwargs)
Build a sampler instance, run kwargs, and summary parameters.
Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.
- Returns:
(sampler_instance, run_kwargs, summary_parameters)
- post_run(results)
Optional hook called after run_sampler. Handles pretty-printing, results extraction, and any other sampler-specific finalization.
- Returns:
(final_results, extra_finalize_kwargs), where the extra kwargs are forwarded to Retrieval._finalize_sampler_run.
- static _normalize_weights(weights)
- _get_output_prefix()
- _save_results(parameter_names, samples, log_likelihood, weighted_samples, weighted_log_likelihood, weighted_log_weights, metadata)
- _build_equal_weight_posterior(raw_results, seed)
- static _summarize_ncall(raw_results)
- run_sampler(init_kwargs, **run_kwargs)
- get_results()
- load_results()
- pretty_print_results()
- plot_diagnostics(results=None)
Optional hook to persist sampler-specific diagnostics plots.
- get_summary(free_parameter_names=None)
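The private helpers _normalize_weights and _build_equal_weight_posterior point to the usual post-processing of nested-sampling output: convert log-weights into normalized importance weights, then resample into an equally weighted posterior. Their actual implementation is not shown on this page. The sketch below uses systematic resampling, one common choice (dynesty itself provides dynesty.utils.resample_equal for this step); the function names are hypothetical.

```python
import numpy as np


def normalize_log_weights(log_weights):
    """Convert log importance weights into weights that sum to 1 (max-subtracted for stability)."""
    log_weights = np.asarray(log_weights, dtype=float)
    weights = np.exp(log_weights - np.max(log_weights))
    return weights / np.sum(weights)


def equal_weight_posterior(samples, log_weights, seed=0):
    """Systematic resampling: one uniform offset, n evenly spaced positions."""
    weights = normalize_log_weights(log_weights)
    n = len(weights)
    rng = np.random.default_rng(seed)
    positions = (rng.random() + np.arange(n)) / n
    cumulative = np.cumsum(weights)
    cumulative[-1] = 1.0  # guard against round-off in the final sum
    # side="right" never selects a sample whose weight is exactly zero.
    indices = np.searchsorted(cumulative, positions, side="right")
    return np.asarray(samples)[indices]


samples = np.array([[0.0], [1.0], [2.0]])
log_weights = np.array([-np.inf, np.log(1.0), np.log(3.0)])
weights = normalize_log_weights(log_weights)  # approximately [0.0, 0.25, 0.75]
posterior = equal_weight_posterior(samples, log_weights, seed=1)
```

Systematic resampling uses a single random offset, so it has lower variance than drawing each index independently while keeping the expected count of each sample proportional to its weight.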
- class petitRADTRANS.retrieval.sampler.JAXNSSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: Sampler
Sampler class for JAXNS.
- jaxns_parameters
- seed = -1
- sampler = None
- results = None
- raw_results = None
- parameter_names = []
- offload_results_to_cpu = True
- classmethod _build_jaxns_interface(context: SamplerContext)
Build the JAXNS prior_model generator and log_likelihood wrapper.
- static _coerce_termination_condition_mapping(values, source_name: str) -> dict[str, Any]
- classmethod _build_termination_condition(sampler_kwargs: dict[str, Any]) -> tuple[Any, dict[str, Any]]
- classmethod prepare(context: SamplerContext, **sampler_kwargs)
Build a sampler instance, run kwargs, and summary parameters.
Every concrete sampler should override this to assemble its own run kwargs from the common SamplerContext instead of relying on the Retrieval class.
- Returns:
(sampler_instance, run_kwargs, summary_parameters)
- _get_output_prefix()
- _get_raw_results_path()
- _get_equal_weight_samples_path()
- _get_parameter_names_path()
- _get_cpu_device()
- _materialize_raw_results(termination_reason, state)
- _to_numpy_tree(pytree)
- static _get_named_samples_mapping(raw_results)
- static _samples_mapping_to_matrix(samples_by_parameter, parameter_names=None)
- _save_equal_weight_samples(parameter_names, samples, log_likelihood)
- _build_processed_results(raw_results, *, equal_weight_samples=None, equal_weight_log_likelihood=None, output_file=None, params_file=None, raw_results_file=None)
- _build_equal_weight_results(raw_results)
- post_run(results)
Post-run hook for JAXNS. Because get_results() is already called inside run_sampler, this hook only passes the results on as stats=results.
- run_sampler(use_jit=True, parameters={}, **jaxns_kwargs)
Run the JAXNS sampler.
- get_results()
- load_results()
- pretty_print_results()
- plot_diagnostics(results=None)
Optional hook to persist sampler-specific diagnostics plots.
- get_summary(free_parameter_names=None)
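_build_termination_condition takes the sampler kwargs and returns a pair, which suggests it separates termination settings from the remaining JAXNS kwargs. A minimal sketch of that split, assuming plain keyword-name matching: it uses only the five field names visible in _JAXNS_TERMINATION_FIELDS (the rest of that tuple is elided on this page), split_termination_kwargs is a hypothetical name, and the real method returns a JAXNS termination-condition object rather than a dict.

```python
from typing import Any

# Only the field names visible in this listing; the documented tuple is longer.
VISIBLE_TERMINATION_FIELDS = ("ess", "evidence_uncert", "live_evidence_frac", "dlogZ", "max_samples")


def split_termination_kwargs(
    sampler_kwargs: dict[str, Any],
    fields: tuple[str, ...] = VISIBLE_TERMINATION_FIELDS,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Partition kwargs into (termination settings, everything else)."""
    termination = {k: v for k, v in sampler_kwargs.items() if k in fields}
    remaining = {k: v for k, v in sampler_kwargs.items() if k not in fields}
    return termination, remaining


termination, remaining = split_termination_kwargs(
    {"dlogZ": 0.1, "max_samples": 100_000, "use_jit": True}
)
# termination == {"dlogZ": 0.1, "max_samples": 100000}; remaining == {"use_jit": True}
```

With JAXNS installed, the termination mapping could then be passed as keyword arguments to jaxns.TerminationCondition; check the accepted field names against the installed JAXNS version before relying on that.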
- class petitRADTRANS.retrieval.sampler.JAXNSShardedStaticNestedSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)
Bases: JAXNSSampler
JAXNS sampler using the sharded static nested sampler.
- log_likelihood_func
- prior_function
- output_directory
- retrieval_name
- jaxns_parameters
- seed = -1
- sampler = None
- results = None
- run_sampler(use_jit=True, **jaxns_kwargs)
Run the JAXNS sampler.
- get_results()
- load_results()
- pretty_print_results()