petitRADTRANS.retrieval.sampler
===============================

.. py:module:: petitRADTRANS.retrieval.sampler


Attributes
----------

.. autoapisummary::

   petitRADTRANS.retrieval.sampler.lax
   petitRADTRANS.retrieval.sampler.random
   petitRADTRANS.retrieval.sampler.tree
   petitRADTRANS.retrieval.sampler.vmap
   petitRADTRANS.retrieval.sampler.tree_util
   petitRADTRANS.retrieval.sampler.jit
   petitRADTRANS.retrieval.sampler.local_device_count
   petitRADTRANS.retrieval.sampler.make_mesh
   petitRADTRANS.retrieval.sampler.shard_map
   petitRADTRANS.retrieval.sampler.set_mesh
   petitRADTRANS.retrieval.sampler.devices
   petitRADTRANS.retrieval.sampler.default_device
   petitRADTRANS.retrieval.sampler.device_get
   petitRADTRANS.retrieval.sampler.clear_caches
   petitRADTRANS.retrieval.sampler.jnp
   petitRADTRANS.retrieval.sampler.NamedSharding
   petitRADTRANS.retrieval.sampler.P
   petitRADTRANS.retrieval.sampler.tfp
   petitRADTRANS.retrieval.sampler.jaxns
   petitRADTRANS.retrieval.sampler.Model
   petitRADTRANS.retrieval.sampler.NestedSampler
   petitRADTRANS.retrieval.sampler.TerminationCondition
   petitRADTRANS.retrieval.sampler.save_results
   petitRADTRANS.retrieval.sampler.load_results
   petitRADTRANS.retrieval.sampler.jaxns_resample
   petitRADTRANS.retrieval.sampler.ShardedStaticNestedSampler
   petitRADTRANS.retrieval.sampler.UniformSampler
   petitRADTRANS.retrieval.sampler._JAXNS_TERMINATION_FIELDS
   petitRADTRANS.retrieval.sampler.blackjax
   petitRADTRANS.retrieval.sampler.numpyro
   petitRADTRANS.retrieval.sampler.NumPyroHMC
   petitRADTRANS.retrieval.sampler.NumPyroNUTS
   petitRADTRANS.retrieval.sampler.NumPyroMCMC
   petitRADTRANS.retrieval.sampler.az
   petitRADTRANS.retrieval.sampler.plt
   petitRADTRANS.retrieval.sampler.pymultinest
   petitRADTRANS.retrieval.sampler.ultranest
   petitRADTRANS.retrieval.sampler.MLFriends
   petitRADTRANS.retrieval.sampler.RobustEllipsoidRegion
   petitRADTRANS.retrieval.sampler.dynesty
   petitRADTRANS.retrieval.sampler.DynamicNestedSampler
   petitRADTRANS.retrieval.sampler.DynestyPool
   petitRADTRANS.retrieval.sampler.dyplot
   petitRADTRANS.retrieval.sampler.dyfunc


Classes
-------

.. autoapisummary::

   petitRADTRANS.retrieval.sampler.SamplerContext
   petitRADTRANS.retrieval.sampler.Sampler
   petitRADTRANS.retrieval.sampler.BlackJAXSamplingResults
   petitRADTRANS.retrieval.sampler.NumPyroSamplingResults
   petitRADTRANS.retrieval.sampler.BlackJAXSampler
   petitRADTRANS.retrieval.sampler.BlackJAXHMCSampler
   petitRADTRANS.retrieval.sampler.BlackJAXNUTSSampler
   petitRADTRANS.retrieval.sampler.NumPyroSampler
   petitRADTRANS.retrieval.sampler.NumPyroHMCSampler
   petitRADTRANS.retrieval.sampler.NumPyroNUTSSampler
   petitRADTRANS.retrieval.sampler.PymultinestSampler
   petitRADTRANS.retrieval.sampler.UltranestSampler
   petitRADTRANS.retrieval.sampler.DynestySampler
   petitRADTRANS.retrieval.sampler.JAXNSSampler
   petitRADTRANS.retrieval.sampler.JAXNSShardedStaticNestedSampler


Functions
---------

.. autoapisummary::

   petitRADTRANS.retrieval.sampler._require_jax
   petitRADTRANS.retrieval.sampler._require_jaxns
   petitRADTRANS.retrieval.sampler._require_blackjax
   petitRADTRANS.retrieval.sampler._require_numpyro
   petitRADTRANS.retrieval.sampler._require_arviz
   petitRADTRANS.retrieval.sampler._require_matplotlib_pyplot
   petitRADTRANS.retrieval.sampler._require_pymultinest
   petitRADTRANS.retrieval.sampler._require_ultranest
   petitRADTRANS.retrieval.sampler._require_dynesty
   petitRADTRANS.retrieval.sampler._check_differentiable_support
   petitRADTRANS.retrieval.sampler._build_device_likelihood_wrapper
   petitRADTRANS.retrieval.sampler._build_unconstrained_mcmc_interface


Module Contents
---------------

.. py:data:: lax
   :value: None


.. py:data:: random
   :value: None


.. py:data:: tree
   :value: None


.. py:data:: vmap
   :value: None


.. py:data:: tree_util
   :value: None


.. py:data:: jit
   :value: None


.. py:data:: local_device_count
   :value: None


.. py:data:: make_mesh
   :value: None


.. py:data:: shard_map
   :value: None


.. py:data:: set_mesh
   :value: None


.. py:data:: devices
   :value: None


.. py:data:: default_device
   :value: None


.. py:data:: device_get
   :value: None


.. py:data:: clear_caches
   :value: None


.. py:data:: jnp
   :value: None


.. py:data:: NamedSharding
   :value: None


.. py:data:: P
   :value: None


.. py:data:: tfp
   :value: None


.. py:function:: _require_jax()

.. py:data:: jaxns
   :value: None


.. py:data:: Model
   :value: None


.. py:data:: NestedSampler
   :value: None


.. py:data:: TerminationCondition
   :value: None


.. py:data:: save_results
   :value: None


.. py:data:: load_results
   :value: None


.. py:data:: jaxns_resample
   :value: None


.. py:data:: ShardedStaticNestedSampler
   :value: None


.. py:data:: UniformSampler
   :value: None


.. py:data:: _JAXNS_TERMINATION_FIELDS
   :value: ('ess', 'evidence_uncert', 'live_evidence_frac', 'dlogZ', 'max_samples',...


.. py:function:: _require_jaxns()

.. py:data:: blackjax
   :value: None


.. py:function:: _require_blackjax()

.. py:data:: numpyro
   :value: None


.. py:data:: NumPyroHMC
   :value: None


.. py:data:: NumPyroNUTS
   :value: None


.. py:data:: NumPyroMCMC
   :value: None


.. py:data:: az
   :value: None


.. py:data:: plt
   :value: None


.. py:function:: _require_numpyro()

.. py:function:: _require_arviz()

.. py:function:: _require_matplotlib_pyplot()

.. py:data:: pymultinest
   :value: None


.. py:function:: _require_pymultinest()

.. py:data:: ultranest
   :value: None


.. py:data:: MLFriends
   :value: None


.. py:data:: RobustEllipsoidRegion
   :value: None


.. py:data:: dynesty
   :value: None


.. py:data:: DynamicNestedSampler
   :value: None


.. py:data:: DynestyPool
   :value: None


.. py:data:: dyplot
   :value: None


.. py:data:: dyfunc
   :value: None


.. py:function:: _require_ultranest()

.. py:function:: _require_dynesty()

.. py:class:: SamplerContext

   Lightweight container that carries everything a sampler needs from a Retrieval.

   Built once by ``Retrieval._build_sampler_context()`` and passed to each
   sampler's ``prepare()`` class method so that sampler implementations never
   depend on the full ``Retrieval`` object.


   .. py:attribute:: parameter_layout
      :type:  petitRADTRANS.retrieval.runtime.ParameterLayout


   .. py:attribute:: runtime
      :type:  petitRADTRANS.retrieval.runtime.RetrievalRuntime


   .. py:attribute:: configuration
      :type:  Any


   .. py:attribute:: output_directory
      :type:  str


   .. py:attribute:: retrieval_name
      :type:  str


   .. py:attribute:: uncertainties_mode
      :type:  str
      :value: 'default'



   .. py:attribute:: print_log_likelihood_for_debugging
      :type:  bool
      :value: False



   .. py:attribute:: log_likelihood_func
      :type:  Any
      :value: None



   .. py:attribute:: prior_func
      :type:  Any
      :value: None



   .. py:attribute:: prior_ultranest_func
      :type:  Any
      :value: None



   .. py:attribute:: free_parameter_names
      :type:  list
      :value: []



   .. py:attribute:: n_free_parameters
      :type:  int
      :value: 0



   .. py:attribute:: likelihood_devices
      :type:  list | None
      :value: None



.. py:function:: _check_differentiable_support(context: SamplerContext, sampler_label: str) -> None

   Raise if the retrieval has only legacy (non-differentiable) model groups.


.. py:function:: _build_device_likelihood_wrapper(context: SamplerContext)

   Return a PyMultiNest-compatible log-likelihood that dispatches to a JAX device.

   Each MPI rank ``r`` selects ``context.likelihood_devices[r % len(context.likelihood_devices)]``,
   so N ranks are striped round-robin across M devices. Requires differentiable model groups.

   Returns ``(device_likelihood_func, target_device)``.

   Notes
   -----
   The user is responsible for ensuring that JAX processes do not unexpectedly contend
   on the same device (e.g. set ``JAX_PLATFORMS=cuda,cpu`` or ``CUDA_VISIBLE_DEVICES``
   per MPI process for true isolation when using multiple GPUs).


.. py:function:: _build_unconstrained_mcmc_interface(context: SamplerContext)

   Build the unconstrained-space logdensity and transforms for gradient-based samplers.

   Returns ``(free_parameter_names, initial_position, logdensity, transform_positions, log_jacobian)``.


.. py:class:: Sampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`abc.ABC`


   Abstract base class for samplers.


   .. py:attribute:: log_likelihood_func


   .. py:attribute:: prior_func


   .. py:attribute:: output_directory


   .. py:attribute:: retrieval_name


   .. py:method:: run_sampler(**kwargs)
      :abstractmethod:



   .. py:method:: prepare(context: SamplerContext, **sampler_kwargs)
      :classmethod:
      :abstractmethod:


      Build a sampler instance, run-kwargs, and summary parameters.

      Every concrete sampler should override this to assemble its own
      run kwargs from the common ``SamplerContext`` instead of relying on
      the ``Retrieval`` class.

      Returns:
          ``(sampler_instance, run_kwargs, summary_parameters)``



   .. py:method:: post_run(results)

      Optional hook called after ``run_sampler``.

      Handles pretty-printing, results extraction, and any other
      sampler-specific finalization.

      Returns:
          ``(final_results, extra_finalize_kwargs)`` where the extra kwargs
          are forwarded to ``Retrieval._finalize_sampler_run``.



   .. py:method:: plot_diagnostics(results=None)

      Optional hook to persist sampler-specific diagnostics plots.



   .. py:method:: _get_samples_archive_path()


   .. py:method:: _load_sample_archive_results()


   .. py:method:: load_results()


   .. py:method:: _extract_figure(plot_output)
      :staticmethod:



   .. py:method:: _save_plot_figure(plot_output, output_file)
      :classmethod:



   .. py:method:: _summary_scalar(value)
      :staticmethod:



   .. py:method:: get_summary(free_parameter_names=None)


   .. py:method:: _posterior_summary_from_matrix(samples, parameter_names)
      :staticmethod:



   .. py:method:: _posterior_summary_from_mapping(samples_by_parameter, free_parameter_names=None)
      :staticmethod:



.. py:class:: BlackJAXSamplingResults

   .. py:attribute:: algorithm
      :type:  str


   .. py:attribute:: samples
      :type:  numpy.ndarray


   .. py:attribute:: unconstrained_samples
      :type:  numpy.ndarray


   .. py:attribute:: logdensity
      :type:  numpy.ndarray


   .. py:attribute:: model_log_likelihood
      :type:  numpy.ndarray | None


   .. py:attribute:: warmup_parameters
      :type:  dict


   .. py:attribute:: acceptance_rate
      :type:  numpy.ndarray | None
      :value: None



   .. py:attribute:: is_divergent
      :type:  numpy.ndarray | None
      :value: None



   .. py:attribute:: num_integration_steps
      :type:  numpy.ndarray | None
      :value: None



   .. py:attribute:: output_file
      :type:  str | None
      :value: None



   .. py:attribute:: metadata_file
      :type:  str | None
      :value: None



.. py:class:: NumPyroSamplingResults

   .. py:attribute:: algorithm
      :type:  str


   .. py:attribute:: samples
      :type:  numpy.ndarray


   .. py:attribute:: unconstrained_samples
      :type:  numpy.ndarray


   .. py:attribute:: logdensity
      :type:  numpy.ndarray


   .. py:attribute:: model_log_likelihood
      :type:  numpy.ndarray | None


   .. py:attribute:: warmup_parameters
      :type:  dict


   .. py:attribute:: acceptance_rate
      :type:  numpy.ndarray | None
      :value: None



   .. py:attribute:: is_divergent
      :type:  numpy.ndarray | None
      :value: None



   .. py:attribute:: num_integration_steps
      :type:  numpy.ndarray | None
      :value: None



   .. py:attribute:: output_file
      :type:  str | None
      :value: None



   .. py:attribute:: metadata_file
      :type:  str | None
      :value: None



.. py:class:: BlackJAXSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`Sampler`


   Abstract base class for BlackJAX samplers.


   .. py:attribute:: results
      :value: None



   .. py:method:: _prepare_blackjax(context: SamplerContext, algorithm_name: str, **sampler_kwargs)
      :classmethod:


      Shared preparation logic for all BlackJAX sampler variants.



   .. py:property:: algorithm_name
      :abstractmethod:



   .. py:property:: algorithm_factory
      :abstractmethod:



   .. py:method:: _to_serializable(value)
      :staticmethod:



   .. py:method:: _get_output_prefix()


   .. py:method:: _save_results(parameter_names, samples, unconstrained_samples, logdensity, model_log_likelihood, acceptance_rate, is_divergent, num_integration_steps, warmup_parameters, metadata)


   .. py:method:: run_sampler(initial_position, parameter_names, num_samples=1000, num_warmup=1000, initial_step_size=1.0, target_acceptance_rate=0.8, is_mass_matrix_diagonal=True, progress_bar=False, seed=42654, use_jit=True, log_jacobian_fn=None, num_chains=1, **blackjax_kwargs)


   .. py:method:: get_results()


   .. py:method:: load_results()


   .. py:method:: pretty_print_results()


   .. py:method:: get_summary(free_parameter_names=None)


.. py:class:: BlackJAXHMCSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`BlackJAXSampler`


   Sampler class for BlackJAX HMC.


   .. py:property:: algorithm_name


   .. py:property:: algorithm_factory


   .. py:method:: prepare(context: SamplerContext, **sampler_kwargs)
      :classmethod:


      Build a sampler instance, run-kwargs, and summary parameters.

      Every concrete sampler should override this to assemble its own
      run kwargs from the common ``SamplerContext`` instead of relying on
      the ``Retrieval`` class.

      Returns:
          ``(sampler_instance, run_kwargs, summary_parameters)``



.. py:class:: BlackJAXNUTSSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`BlackJAXSampler`


   Sampler class for BlackJAX NUTS.


   .. py:property:: algorithm_name


   .. py:property:: algorithm_factory


   .. py:method:: prepare(context: SamplerContext, **sampler_kwargs)
      :classmethod:


      Build a sampler instance, run-kwargs, and summary parameters.

      Every concrete sampler should override this to assemble its own
      run kwargs from the common ``SamplerContext`` instead of relying on
      the ``Retrieval`` class.

      Returns:
          ``(sampler_instance, run_kwargs, summary_parameters)``



.. py:class:: NumPyroSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`Sampler`


   Abstract base class for NumPyro samplers.


   .. py:attribute:: results
      :value: None



   .. py:attribute:: mcmc
      :value: None



   .. py:attribute:: parameter_names
      :value: None



   .. py:method:: _prepare_numpyro(context: SamplerContext, algorithm_name: str, **sampler_kwargs)
      :classmethod:


      Shared preparation logic for all NumPyro sampler variants.



   .. py:property:: algorithm_name
      :abstractmethod:



   .. py:property:: kernel_class
      :abstractmethod:



   .. py:method:: _to_serializable(value)
      :staticmethod:



   .. py:method:: _get_output_prefix()


   .. py:method:: _prepare_kernel_kwargs(algorithm_name, log_likelihood_func, initial_step_size, is_mass_matrix_diagonal, target_acceptance_rate, numpyro_kwargs)
      :staticmethod:



   .. py:method:: _save_results(parameter_names, samples, unconstrained_samples, logdensity, model_log_likelihood, acceptance_rate, is_divergent, num_integration_steps, warmup_parameters, metadata)


   .. py:method:: run_sampler(initial_position, parameter_names, num_samples=1000, num_warmup=1000, initial_step_size=1.0, target_acceptance_rate=0.8, is_mass_matrix_diagonal=True, progress_bar=False, seed=42654, use_jit=True, log_jacobian_fn=None, num_chains=1, thinning=1, chain_method='vectorized', **numpyro_kwargs)


   .. py:method:: get_results()


   .. py:method:: load_results()


   .. py:method:: pretty_print_results()


   .. py:method:: plot_diagnostics(results=None)

      Optional hook to persist sampler-specific diagnostics plots.



   .. py:method:: get_summary(free_parameter_names=None)


.. py:class:: NumPyroHMCSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`NumPyroSampler`


   Sampler class for NumPyro HMC.


   .. py:property:: algorithm_name


   .. py:property:: kernel_class


   .. py:method:: prepare(context: SamplerContext, **sampler_kwargs)
      :classmethod:


      Build a sampler instance, run-kwargs, and summary parameters.

      Every concrete sampler should override this to assemble its own
      run kwargs from the common ``SamplerContext`` instead of relying on
      the ``Retrieval`` class.

      Returns:
          ``(sampler_instance, run_kwargs, summary_parameters)``



.. py:class:: NumPyroNUTSSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`NumPyroSampler`


   Sampler class for NumPyro NUTS.


   .. py:property:: algorithm_name


   .. py:property:: kernel_class


   .. py:method:: prepare(context: SamplerContext, **sampler_kwargs)
      :classmethod:


      Build a sampler instance, run-kwargs, and summary parameters.

      Every concrete sampler should override this to assemble its own
      run kwargs from the common ``SamplerContext`` instead of relying on
      the ``Retrieval`` class.

      Returns:
          ``(sampler_instance, run_kwargs, summary_parameters)``



.. py:class:: PymultinestSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`Sampler`


   Sampler class for pymultinest.


   .. py:attribute:: seed
      :value: -1



   .. py:attribute:: analyzer
      :value: None



   .. py:attribute:: _n_dims
      :value: None



   .. py:attribute:: outputfiles_basename


   .. py:method:: prepare(context: SamplerContext, **sampler_kwargs)
      :classmethod:


      Build a sampler instance, run-kwargs, and summary parameters.

      Every concrete sampler should override this to assemble its own
      run kwargs from the common ``SamplerContext`` instead of relying on
      the ``Retrieval`` class.

      Returns:
          ``(sampler_instance, run_kwargs, summary_parameters)``



   .. py:method:: post_run(results)

      Optional hook called after ``run_sampler``.

      Handles pretty-printing, results extraction, and any other
      sampler-specific finalization.

      Returns:
          ``(final_results, extra_finalize_kwargs)`` where the extra kwargs
          are forwarded to ``Retrieval._finalize_sampler_run``.



   .. py:method:: run_sampler(n_dims, **kwargs)

      Args:
          n_dims : int
              Dimensionality of the parameter space passed to pymultinest.
          sampling_efficiency : float
              pymultinest sampling efficiency. If const efficiency mode is true, should be set to around
              0.05. Otherwise, it should be around 0.8 for parameter estimation and 0.3 for evidence
              comparison.
          const_efficiency_mode : bool
              pymultinest constant efficiency mode
          n_live_points : int
              Number of live points to use in pymultinest, or the minimum number of live points to
              use for the Ultranest reactive sampler.
          log_z_convergence : float
              If ultranest is being used, the convergence criterion on log z.
          step_sampler : bool
              Use a step sampler to improve the efficiency in ultranest.
          warmstart_max_tau : float
              Warm start allows accelerated computation based on a different but similar UltraNest run.
          n_iter_before_update : int
              Number of live point replacements before printing an update to a log file.
          max_iters : int
              Maximum number of sampling iterations. If 0, will continue until convergence criteria are satisfied.
          frac_remain : float
              Ultranest convergence criterion. Halts integration if live point weights are below the specified value.
          l_epsilon : float
              Ultranest convergence criterion. Use with noisy likelihoods. Halts integration if live points are
              within l_epsilon.
          resume : bool
              Continue existing retrieval. If FALSE THIS WILL OVERWRITE YOUR EXISTING RETRIEVAL.
          error_checking : bool
              Test the model generating function for typical errors. ONLY TURN THIS OFF IF YOU KNOW WHAT YOU'RE DOING!
          force_serial_error_checking : bool
              If True, error checking will be performed process-by-process, instead of with all processes at once.
              This can prevent memory overflow.
          seed : int
              Random number generator seed. A negative value seeds from the system clock; set a
              non-negative value for reproducible runs.
      Returns:
          analyzer



   .. py:method:: load_results()


   .. py:method:: pretty_print(free_parameter_names, prefix)


   .. py:method:: get_summary(free_parameter_names=None)


.. py:class:: UltranestSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`Sampler`


   Sampler class for UltraNest.


   .. py:attribute:: results
      :value: None



   .. py:method:: _get_log_dir()


   .. py:method:: _resolve_vectorized(context: SamplerContext, requested: bool) -> bool
      :classmethod:



   .. py:method:: prepare(context: SamplerContext, **sampler_kwargs)
      :classmethod:


      Build a sampler instance, run-kwargs, and summary parameters.

      Every concrete sampler should override this to assemble its own
      run kwargs from the common ``SamplerContext`` instead of relying on
      the ``Retrieval`` class.

      Returns:
          ``(sampler_instance, run_kwargs, summary_parameters)``



   .. py:method:: load_results()


   .. py:method:: run_sampler(parameter_names, **kwargs)


   .. py:method:: get_results()


   .. py:method:: pretty_print_results()


   .. py:method:: get_summary(free_parameter_names=None)


.. py:class:: DynestySampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`Sampler`


   Sampler class for dynesty.


   .. py:attribute:: sampler
      :value: None



   .. py:attribute:: results
      :value: None



   .. py:attribute:: raw_results
      :value: None



   .. py:attribute:: parameter_names
      :value: []



   .. py:attribute:: _run_summary


   .. py:attribute:: _dynesty_pool
      :value: None



   .. py:attribute:: _owns_dynesty_pool
      :value: False



   .. py:method:: _default_use_pool() -> dict[str, bool]
      :staticmethod:



   .. py:method:: _configure_dynesty_pool(log_likelihood_func, prior_func, *, pool, pool_njobs)
      :classmethod:



   .. py:method:: _close_dynesty_pool()


   .. py:method:: _build_dynesty_interface(context: SamplerContext, *, use_jit: bool = True, pool_safe: bool = False)
      :classmethod:



   .. py:method:: prepare(context: SamplerContext, **sampler_kwargs)
      :classmethod:


      Build a sampler instance, run-kwargs, and summary parameters.

      Every concrete sampler should override this to assemble its own
      run kwargs from the common ``SamplerContext`` instead of relying on
      the ``Retrieval`` class.

      Returns:
          ``(sampler_instance, run_kwargs, summary_parameters)``



   .. py:method:: post_run(results)

      Optional hook called after ``run_sampler``.

      Handles pretty-printing, results extraction, and any other
      sampler-specific finalization.

      Returns:
          ``(final_results, extra_finalize_kwargs)`` where the extra kwargs
          are forwarded to ``Retrieval._finalize_sampler_run``.



   .. py:method:: _normalize_weights(weights)
      :staticmethod:



   .. py:method:: _get_output_prefix()


   .. py:method:: _save_results(parameter_names, samples, log_likelihood, weighted_samples, weighted_log_likelihood, weighted_log_weights, metadata)


   .. py:method:: _build_equal_weight_posterior(raw_results, seed)


   .. py:method:: _summarize_ncall(raw_results)
      :staticmethod:



   .. py:method:: run_sampler(init_kwargs, **run_kwargs)


   .. py:method:: get_results()


   .. py:method:: load_results()


   .. py:method:: pretty_print_results()


   .. py:method:: plot_diagnostics(results=None)

      Optional hook to persist sampler-specific diagnostics plots.



   .. py:method:: get_summary(free_parameter_names=None)


.. py:class:: JAXNSSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`Sampler`


   Sampler class for JAXNS.


   .. py:attribute:: jaxns_parameters


   .. py:attribute:: seed
      :value: -1



   .. py:attribute:: sampler
      :value: None



   .. py:attribute:: results
      :value: None



   .. py:attribute:: raw_results
      :value: None



   .. py:attribute:: parameter_names
      :value: []



   .. py:attribute:: offload_results_to_cpu
      :value: True



   .. py:method:: _build_jaxns_interface(context: SamplerContext)
      :classmethod:


      Build JAXNS prior_model generator and log_likelihood wrapper.



   .. py:method:: _coerce_termination_condition_mapping(values, source_name: str) -> dict[str, Any]
      :staticmethod:



   .. py:method:: _build_termination_condition(sampler_kwargs: dict[str, Any]) -> tuple[Any, dict[str, Any]]
      :classmethod:



   .. py:method:: prepare(context: SamplerContext, **sampler_kwargs)
      :classmethod:


      Build a sampler instance, run-kwargs, and summary parameters.

      Every concrete sampler should override this to assemble its own
      run kwargs from the common ``SamplerContext`` instead of relying on
      the ``Retrieval`` class.

      Returns:
          ``(sampler_instance, run_kwargs, summary_parameters)``



   .. py:method:: _get_output_prefix()


   .. py:method:: _get_raw_results_path()


   .. py:method:: _get_equal_weight_samples_path()


   .. py:method:: _get_parameter_names_path()


   .. py:method:: _get_cpu_device()


   .. py:method:: _materialize_raw_results(termination_reason, state)


   .. py:method:: _to_numpy_tree(pytree)


   .. py:method:: _get_named_samples_mapping(raw_results)
      :staticmethod:



   .. py:method:: _samples_mapping_to_matrix(samples_by_parameter, parameter_names=None)
      :staticmethod:



   .. py:method:: _save_equal_weight_samples(parameter_names, samples, log_likelihood)


   .. py:method:: _build_processed_results(raw_results, *, equal_weight_samples=None, equal_weight_log_likelihood=None, output_file=None, params_file=None, raw_results_file=None)


   .. py:method:: _build_equal_weight_results(raw_results)


   .. py:method:: post_run(results)

      JAXNS post-run hook. ``get_results`` is already called in ``run_sampler``, so the results are passed on as ``stats=results``.



   .. py:method:: run_sampler(use_jit=True, parameters={}, **jaxns_kwargs)

      Run the JAXNS sampler.



   .. py:method:: get_results()


   .. py:method:: load_results()


   .. py:method:: pretty_print_results()


   .. py:method:: plot_diagnostics(results=None)

      Optional hook to persist sampler-specific diagnostics plots.



   .. py:method:: get_summary(free_parameter_names=None)


.. py:class:: JAXNSShardedStaticNestedSampler(log_likelihood_func, prior_func, output_directory, retrieval_name)

   Bases: :py:obj:`JAXNSSampler`


   Sampler class for JAXNS using ``ShardedStaticNestedSampler``.


   .. py:attribute:: log_likelihood_func


   .. py:attribute:: prior_function


   .. py:attribute:: output_directory


   .. py:attribute:: retrieval_name


   .. py:attribute:: jaxns_parameters


   .. py:attribute:: seed
      :value: -1



   .. py:attribute:: sampler
      :value: None



   .. py:attribute:: results
      :value: None



   .. py:method:: run_sampler(use_jit=True, **jaxns_kwargs)

      Run the JAXNS sampler.



   .. py:method:: get_results()


   .. py:method:: load_results()


   .. py:method:: pretty_print_results()


