Retrievals: Emission Spectra Retrieval with NumPyro NUTS

[ ]:
import os
# Prevent NumPy from spawning its own threads
os.environ["OMP_NUM_THREADS"] = "1"

from jax import config
config.update("jax_enable_x64", True)

import tensorflow_probability.substrates.jax as tfp
from petitRADTRANS.retrieval import Retrieval, RetrievalConfig
from petitRADTRANS.retrieval.models import molliere_2020_emission
from petitRADTRANS import physical_constants as cst

tfpd = tfp.distributions


Written by Evert Nasedkin. Please cite pRT’s retrieval package (Nasedkin et al. 2024) in addition to pRT (Mollière et al. 2019) if you make use of the retrieval package for your work.

This retrieval is based on the forward model used in Mollière et al. (2020) for HR8799e, and shows a more realistic example of how to set up a retrieval. This notebook uses the differentiable molliere_2020_emission model together with the NumPyro NUTS sampler so that pRT can evaluate gradients through the runtime-native retrieval path. Emission spectra retrievals, particularly when multiple datasets are included and scattering is taken into account, can still take substantial computational resources to run. Therefore, this notebook outlines the setup for such a retrieval, but we still advise running production chains on a workstation or cluster.

For this example, the module chemistry.pre_calculated_chemistry is used to solve for disequilibrium chemistry (actually equilibrium chemistry with a simple quenching treatment), and the chemistry.clouds module is used for condensation. Note that you can import other models from petitRADTRANS/retrieval/models.py; see the API documentation. This file is included in the pRT package folder and is also available in the pRT git repository. You can of course also write your own model, for which we recommend using an existing model as a template. Remember to respect the input and output format of the model functions described in the “Basic Retrieval Tutorial”.

The model here uses a simple adaptive mesh refinement (AMR) algorithm to improve the pressure resolution around the location of the cloud bases.

Getting started

Please make sure to have worked through the “Basic Retrieval Tutorial” before looking at the material below.

In this tutorial, we will outline the process of setting up a RetrievalConfig object, which is the class used to set up a pRT retrieval. The basic process is always to set up the configuration, and then pass it to the Retrieval class to run the retrieval with the sampler backend selected in RetrievalConfig. In this notebook we use sampler_type="numpyro_nuts". As mentioned in the “Basic Retrieval Tutorial”, several standard plotting outputs will also be produced by the Retrieval class.

NUTS is a gradient-based MCMC method: during warmup it adapts the Hamiltonian trajectory length and step size so it can explore smooth differentiable posteriors efficiently, but it does not compute the Bayesian evidence. Nested sampling backends such as PyMultiNest, Dynesty, or UltraNest are designed to estimate both the posterior and the evidence, which makes them better suited for evidence-based model comparison and some strongly multimodal posteriors.

Most of the classes and functions used in this tutorial have more advanced features than what will be explained here, so it’s highly recommended to take a look at the code and API documentation. There should be enough flexibility built in to cover most typical retrieval studies, but if you have feature requests please get in touch, or open an issue on GitLab.

[ ]:
# Define the pRT run setup
retrieval_config = RetrievalConfig(
    retrieval_name="HR8799e_example",  # give a useful name for your retrieval
    run_mode="retrieval",  # 'retrieval' to run, or 'evaluate' to make plots
    sampler_type="numpyro_nuts",  # use NumPyro NUTS for this differentiable retrieval
    adaptive_mesh_refinement=True,  # adaptive mesh refinement, slower if True
    scattering_in_emission=True,  # add scattering for emission spectra clouds
)

For this example we include the GRAVITY data as published in Mollière et al. (2020). To reproduce the published results, please also include the archival SPHERE and GPI data from Zurlo et al. (2015) and Greenbaum et al. (2016).

[ ]:
# Read in Data

# Path to the folder that contains the example files on your machine.
# Here the files are assumed to sit relative to the folder you run this notebook from;
# change this if your data live elsewhere.
path_to_data = "./"

retrieval_config.add_data(
    "GRAVITY",
    # os.path.join avoids a doubled or missing slash whatever path_to_data ends with
    os.path.join(path_to_data, "retrievals/emission/observations/HR8799e_Spectra.fits"),
    data_resolution=500,
    model_resolution=1000,
    model_generating_function=molliere_2020_emission,
)

# Note that in Mollière et al. (2020), additional data sets from SPHERE and GPI were used

Photometric data

If we want to add photometry, we can do that as well! The photometry file should have the format:

# Name, lower wavelength edge [um], upper wavelength edge [um], flux density [W/m2/micron], flux error [W/m2/micron]

You are required to provide a model function for calculating the spectrum, as with spectral data, but also a photometric transformation function, which is used to convert the model spectrum into synthetic photometry. This would typically make use of the transmission function for a particular filter. We recommend the use of the species package (https://species.readthedocs.io/), in particular the SyntheticPhotometry module to provide these functions. If no function is provided, the RetrievalConfig will attempt to import species to use this module, using the name provided as the filter name.
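
As a rough illustration of what a photometric transformation function does, the sketch below parses one line in the format above and averages a model spectrum over a top-hat bandpass between the two wavelength edges. The top-hat filter, the function names `parse_photometry_line` and `tophat_photometry`, the filter name, and the example numbers are all assumptions made for illustration; a real retrieval would use the measured filter transmission, for example through species.

```python
def parse_photometry_line(line):
    """Split 'name, lower edge, upper edge, flux, flux error' into typed fields."""
    name, low, high, flux, error = [field.strip() for field in line.split(",")]
    return name, float(low), float(high), float(flux), float(error)


def tophat_photometry(wavelengths, fluxes, low, high):
    """Average the model flux over [low, high] with the trapezoidal rule.

    This assumes a top-hat bandpass; it is not a real filter curve.
    """
    points = [(w, f) for w, f in zip(wavelengths, fluxes) if low <= w <= high]
    if len(points) < 2:
        raise ValueError("bandpass must contain at least two model wavelengths")
    integral = 0.0
    for (w0, f0), (w1, f1) in zip(points[:-1], points[1:]):
        integral += 0.5 * (f0 + f1) * (w1 - w0)
    return integral / (points[-1][0] - points[0][0])


# Hypothetical filter entry and a made-up, linearly rising model spectrum
name, low, high, flux, error = parse_photometry_line(
    "Example/Filter.K, 2.0, 2.4, 1.0e-15, 1.0e-16"
)
model_wavelengths = [i / 10 for i in range(18, 27)]  # 1.8 to 2.6 micron
model_fluxes = [1.0e-15 * w for w in model_wavelengths]
synthetic_flux = tophat_photometry(model_wavelengths, model_fluxes, low, high)
```

The synthetic value can then be compared with the measured flux and its error in the likelihood, in the same way as a single spectral bin.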

If you are using transmission spectra, your photometric transformation function should model the difference between the clear and occulted stellar spectrum, returning the difference in (planet radius/stellar radius)^2.

[ ]:
retrieval_config.add_photometry(
    # os.path.join keeps this working whether or not path_to_data ends with a slash
    os.path.join(path_to_data, "retrievals/emission/observations/hr8799e_photometry.txt"),
    molliere_2020_emission,
    model_resolution=40,
)

Parameters and Priors

Here we add all of the parameters used in the retrieval for HR 8799 e, following the prescription of Mollière et al. (2020). The forward model used below is the differentiable molliere_2020_emission implementation, so the same parameter set can be sampled with NumPyro NUTS through the runtime-native contract.

For NumPyro NUTS, the recommended JAX-native way to define priors is to pass distribution= with TensorFlow Probability distributions. The lower-level transform_prior_cube_coordinate interface is still accepted by the generic JAX MCMC path, but distribution-backed priors are the clearer and more consistent choice for NumPyro, and they are required by JAXNS.

There are many other approaches we could take: varying the temperature structure parameterisation, retrieving different cloud properties, adding a blackbody circumplanetary disc (CPD) component, and so on. It is highly recommended to look at the API documentation of models.py, or at the Retrieval Models Tutorial, to get a better idea of what options are available.
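
The two styles of prior definition mentioned above describe the same thing in different forms. For a uniform prior, a unit-cube transform maps a coordinate u in [0, 1] to low + (high - low) * u, whereas a distribution object carries low and high directly. A minimal pure-Python sketch of the unit-cube form follows; `uniform_cube_transform` is a hypothetical helper written for this illustration and is not part of pRT.

```python
def uniform_cube_transform(low, high):
    """Return a function mapping a unit-cube coordinate to U(low, high)."""

    def transform(u):
        if not 0.0 <= u <= 1.0:
            raise ValueError("unit-cube coordinate must lie in [0, 1]")
        return low + (high - low) * u

    return transform


# Same bounds as the log_g prior used in this notebook
log_g_from_cube = uniform_cube_transform(2.0, 5.5)
log_g_low = log_g_from_cube(0.0)
log_g_mid = log_g_from_cube(0.5)
log_g_high = log_g_from_cube(1.0)
```

A distribution-backed prior states the same bounds declaratively, which is what lets a gradient-based sampler work with the log-density instead of a cube transform.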

[ ]:
# Add parameters and priors for free parameters

# This run uses the model of Mollière et al. (2020) for HR8799e
# For NumPyro NUTS we define JAX-native priors with TensorFlow Probability distributions.

# Distance to the planet in cm
retrieval_config.add_parameter(name="system_distance", is_free_parameter=False, value=41.2925 * cst.pc)

# Log of the surface gravity in cgs units.
retrieval_config.add_parameter(
    "log_g",
    True,
    distribution=tfpd.Uniform(low=2.0, high=5.5),
)

# Planet radius in cm
retrieval_config.add_parameter(
    "planet_radius",
    True,
    distribution=tfpd.Uniform(low=0.7 * cst.r_jup_mean, high=2.0 * cst.r_jup_mean),
)

# Interior temperature in Kelvin
retrieval_config.add_parameter(
    "T_int",
    True,
    distribution=tfpd.Uniform(low=300.0, high=2300.0),
    value=0.0,
)

# Spline temperature structure parameters. T1 < T2 < T3
# As these priors depend on each other, they are implemented in the model function
retrieval_config.add_parameter("T3", True, distribution=tfpd.Uniform(low=0.0, high=1.0), value=0.0)
retrieval_config.add_parameter("T2", True, distribution=tfpd.Uniform(low=0.0, high=1.0), value=0.0)
retrieval_config.add_parameter("T1", True, distribution=tfpd.Uniform(low=0.0, high=1.0))
# Optical depth model
# power law index in tau = delta * press_cgs**alpha
retrieval_config.add_parameter("alpha", True, distribution=tfpd.Uniform(low=1.0, high=2.0))
# proportionality factor in tau = delta * press_cgs**alpha
retrieval_config.add_parameter("log_delta", True, distribution=tfpd.Uniform(low=0.0, high=1.0))
# Chemistry
# A 'free retrieval' would have each line species as a parameter
# Using a (dis)equilibrium model, we only supply bulk parameters.
# Carbon quench pressure
retrieval_config.add_parameter("log_pquench", True, distribution=tfpd.Uniform(low=-6.0, high=3.0))
# Metallicity [Fe/H]
retrieval_config.add_parameter("Fe/H", True, distribution=tfpd.Uniform(low=-1.5, high=1.5))
# C/O ratio
retrieval_config.add_parameter("C/O", True, distribution=tfpd.Uniform(low=0.1, high=1.6))
# Clouds
# Based on the Ackerman & Marley (2001) cloud model
# Width of particle size distribution
retrieval_config.add_parameter("sigma_lnorm", True, distribution=tfpd.Uniform(low=1.05, high=3.0))
# Vertical mixing parameters
retrieval_config.add_parameter("log_kzz", True, distribution=tfpd.Uniform(low=5.0, high=13.0))
# Sedimentation parameter
retrieval_config.add_parameter("fsed", True, distribution=tfpd.Uniform(low=1.0, high=11.0))
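
The T1, T2 and T3 parameters above are sampled on the unit interval because their physical priors are nested. One way to picture the mapping done inside the model function is sketched below. It assumes that each value is a fractional drop below the next-deeper temperature, and that the deepest node is tied to an Eddington-type temperature at optical depth 0.1, in the spirit of Mollière et al. (2020). The function name `nested_temperatures` is hypothetical, and the actual implementation in molliere_2020_emission may differ in detail.

```python
def nested_temperatures(t_int, t1_unit, t2_unit, t3_unit, tau_connect=0.1):
    """Map unit-interval values to ordered node temperatures T1 <= T2 <= T3.

    Assumed mapping (illustrative only):
      T_connect = (3/4 * T_int**4 * (tau_connect + 2/3)) ** 0.25
      T3 = T_connect * (1 - t3_unit)
      T2 = T3 * (1 - t2_unit)
      T1 = T2 * (1 - t1_unit)
    """
    t_connect = (0.75 * t_int**4 * (tau_connect + 2.0 / 3.0)) ** 0.25
    t3 = t_connect * (1.0 - t3_unit)
    t2 = t3 * (1.0 - t2_unit)
    t1 = t2 * (1.0 - t1_unit)
    return t1, t2, t3


# Arbitrary example values, all in Kelvin for t_int and unitless otherwise
t1, t2, t3 = nested_temperatures(t_int=1000.0, t1_unit=0.3, t2_unit=0.2, t3_unit=0.1)
```

With this kind of mapping, independent uniform priors on the unit interval always yield an ordered set of temperatures, so the sampler never proposes an inverted profile at these nodes.
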
[ ]:
# Define opacity species to be included
retrieval_config.set_rayleigh_species(("H2", "He"))
retrieval_config.set_continuum_opacities(("H2--H2", "H2--He"))
retrieval_config.set_line_species(
    (
        "H2O__POKAZATEL",
        "CO-NatAbund",
        "CH4",
        "CO2",
        "HCN",
        "FeH",
        "H2S",
        "NH3",
        "PH3",
        "Na__NewAllard",
        "K__Allard",
        "TiO",
        "VO",
        "SiO",
    ),
    use_equilibrium_chemistry=True,
)

retrieval_config.add_cloud_species(
    "Fe(s)_crystalline_000__DHS",
    use_equilibrium_chemistry=True,
    equilbrium_mass_fraction_scaling_factor=(-3.5, 1.0),
)

retrieval_config.add_cloud_species(
    "MgSiO3(s)_crystalline_000__DHS",
    use_equilibrium_chemistry=True,
    equilbrium_mass_fraction_scaling_factor=(-3.5, 1.0),
)
[ ]:
# Before we run the retrieval, let's set up plotting.

# Define what to put into corner plot if run_mode == 'evaluate'
retrieval_config.parameters["planet_radius"].plot_in_corner = True
retrieval_config.parameters["planet_radius"].corner_label = r"$R_{\rm P}$ ($\rm R_{Jup}$)"
retrieval_config.parameters["planet_radius"].corner_transform = lambda x: x / cst.r_jup_mean
retrieval_config.parameters["log_g"].plot_in_corner = True
retrieval_config.parameters["log_g"].corner_ranges = [2.0, 5.0]
retrieval_config.parameters["log_g"].corner_label = "log g"
retrieval_config.parameters["fsed"].plot_in_corner = True
retrieval_config.parameters["log_kzz"].plot_in_corner = True
retrieval_config.parameters["log_kzz"].corner_label = "log Kzz"
retrieval_config.parameters["C/O"].plot_in_corner = True
retrieval_config.parameters["Fe/H"].plot_in_corner = True
retrieval_config.parameters["log_pquench"].plot_in_corner = True
retrieval_config.parameters["log_pquench"].corner_label = "log pquench"

for spec in retrieval_config.cloud_species:
    cname = spec.split("_")[0]
    retrieval_config.parameters["eq_scaling_" + cname].plot_in_corner = True
    retrieval_config.parameters["eq_scaling_" + cname].corner_label = cname



def configure_retrieval_plotter(retrieval):
    retrieval.plotter.spec_xlabel = "Wavelength [micron]"
    retrieval.plotter.spec_ylabel = "Flux [W/m2/micron]"
    retrieval.plotter.y_axis_scaling = 1.0
    retrieval.plotter.xscale = "log"
    retrieval.plotter.yscale = "linear"
    retrieval.plotter.resolution = 100.0  # maximum resolution, will rebin the data
    retrieval.plotter.nsample = 100  # if we want a plot with many sampled spectra
    retrieval.plotter.reference_data_name = "GRAVITY"
    retrieval.plotter.temp_limits = [150, 3000]
    retrieval.plotter.press_limits = [1e2, 1e-5]
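
The loop over cloud species above relies on the part of each species name before the first underscore to build the eq_scaling_ parameter name. The snippet below only demonstrates that string handling on the two cloud species added earlier; it does not touch pRT.

```python
cloud_species = [
    "Fe(s)_crystalline_000__DHS",
    "MgSiO3(s)_crystalline_000__DHS",
]

# Keep only the condensate name before the first underscore
scaling_parameter_names = ["eq_scaling_" + name.split("_")[0] for name in cloud_species]
```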

[9]:
retrieval = Retrieval(retrieval_config, output_directory="./", evaluate_sample_spectra=False)
Using provided Radtrans object for data 'GRAVITY'...
Using provided Radtrans object for data 'Keck/NIRC2.Ks'...
Using provided Radtrans object for data 'Paranal/NACO.Lp'...
Using provided Radtrans object for data 'Paranal/NACO.NB405'...
Using provided Radtrans object for data 'Paranal/SPHERE.IRDIS_B_J'...
Using provided Radtrans object for data 'Paranal/SPHERE.IRDIS_D_H23_2'...
Using provided Radtrans object for data 'Paranal/SPHERE.IRDIS_D_H23_3'...
Using provided Radtrans object for data 'Paranal/SPHERE.IRDIS_D_K12_1'...
Using provided Radtrans object for data 'Paranal/SPHERE.IRDIS_D_K12_2'...
[ ]:
configure_retrieval_plotter(retrieval)

As mentioned at the beginning of this tutorial, this retrieval is still expensive and benefits from running on a workstation or cluster. For NumPyro NUTS, the main controls are now the warmup length, posterior sample count, and number of chains rather than nested-sampling live points. The example below uses vectorized chains and a higher target acceptance rate for a more conservative NUTS setup. Reduce num_chains if you are limited by memory.

To try to run the retrieval, set run_retrieval below to True, then execute the cells below. It is usually worth starting with a shorter smoke test before launching a longer production run.
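
One way to organise a smoke test is to keep the production settings in a dictionary and override only the expensive entries. The keyword names below are the ones passed to retrieval.run in this notebook; the reduced values are arbitrary placeholders rather than recommendations, and `iteration_counts` is a hypothetical helper.

```python
production_settings = {
    "num_warmup": 1000,
    "num_samples": 2000,
    "num_chains": 4,
    "chain_method": "vectorized",
    "target_acceptance_rate": 0.9,
}

# Hypothetical short run: same sampler options, far fewer iterations
smoke_test_settings = {
    **production_settings,
    "num_warmup": 50,
    "num_samples": 50,
    "num_chains": 1,
}


def iteration_counts(settings):
    """Return (retained posterior draws, total sampler iterations) over all chains."""
    draws = settings["num_chains"] * settings["num_samples"]
    total = settings["num_chains"] * (settings["num_warmup"] + settings["num_samples"])
    return draws, total


production_draws, production_iterations = iteration_counts(production_settings)
smoke_draws, smoke_iterations = iteration_counts(smoke_test_settings)
```

Either dictionary could then be unpacked into the call, for example retrieval.run(**smoke_test_settings), assuming the run method accepts these keywords as it does in this notebook.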

[10]:
run_retrieval = False

NUTS is gradient-based MCMC: it uses local gradient information and warmup adaptation to explore a smooth posterior efficiently, but it does not estimate the Bayesian evidence. Nested sampling instead is designed to estimate both the posterior and the evidence, which makes it more natural for model comparison and often more robust for strongly multimodal posteriors.
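
To make "gradient-based" concrete, the sketch below runs the leapfrog integrator that Hamiltonian Monte Carlo methods such as NUTS build on, for a one-dimensional standard normal target. It is a toy illustration in plain Python, not NumPyro's implementation: the step size and number of steps are fixed by hand here, whereas NUTS adapts them.

```python
def grad_potential(q):
    """Gradient of the potential U(q) = q**2 / 2 for a standard normal target."""
    return q


def hamiltonian(q, p):
    """Total energy: potential q**2 / 2 plus kinetic p**2 / 2."""
    return 0.5 * q * q + 0.5 * p * p


def leapfrog(q, p, step_size, num_steps):
    """Integrate Hamiltonian dynamics with the leapfrog scheme."""
    p = p - 0.5 * step_size * grad_potential(q)  # initial half step in momentum
    for step in range(num_steps):
        q = q + step_size * p  # full step in position
        if step < num_steps - 1:
            p = p - step_size * grad_potential(q)  # full step in momentum
    p = p - 0.5 * step_size * grad_potential(q)  # final half step in momentum
    return q, p


q0, p0 = 1.0, 0.5
q1, p1 = leapfrog(q0, p0, step_size=0.1, num_steps=20)
energy_error = abs(hamiltonian(q1, p1) - hamiltonian(q0, p0))

# Flipping the momentum and integrating again retraces the trajectory
q_back, p_back = leapfrog(q1, -p1, step_size=0.1, num_steps=20)
```

A small energy error means that a distant proposal is still accepted with high probability, which is why gradient information lets the sampler take long, efficient moves through a smooth posterior.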

[ ]:
if run_retrieval:
    retrieval.run(
        num_warmup=1000,
        num_samples=2000,
        num_chains=4,
        chain_method="vectorized",
        target_acceptance_rate=0.9,
        progress_bar=True,
        use_jit=True,
        seed=12345,  # remove or randomize for a production retrieval
    )

Once the retrieval is complete, the easiest way to generate standard output plots is to use the plot_all function. The NumPyro backend also stores chain diagnostics in out_NumPyro, so you can inspect trace information alongside the usual pRT summary products.
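
When several chains are run, a common convergence check is the Gelman-Rubin statistic, which compares the variance between chains with the variance within them. The sketch below computes the basic (non-split, non-rank-normalised) version on synthetic draws using only the standard library. It does not read pRT's output files, whose layout is not assumed here, and the function name `gelman_rubin_basic` is chosen for this illustration.

```python
import random
import statistics


def gelman_rubin_basic(chains):
    """Basic R-hat for a list of equal-length chains of scalar draws."""
    n = len(chains[0])
    chain_means = [statistics.fmean(chain) for chain in chains]
    within = statistics.fmean([statistics.variance(chain) for chain in chains])
    between = n * statistics.variance(chain_means)
    pooled = (n - 1) / n * within + between / n
    return (pooled / within) ** 0.5


rng = random.Random(0)
# Four synthetic chains drawn from the same distribution
mixed = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]
# Four synthetic chains stuck around different values
stuck = [[rng.gauss(3.0 * k, 1.0) for _ in range(500)] for k in range(4)]

r_hat_mixed = gelman_rubin_basic(mixed)
r_hat_stuck = gelman_rubin_basic(stuck)
```

Values close to 1 suggest that the chains agree with each other, while clearly larger values point to chains that have not mixed. More refined variants of this statistic are available in NumPyro's diagnostics module.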

[12]:
if run_retrieval:
    retrieval.plot_all(contribution=True)

Contact

If you need any additional help, don’t hesitate to contact Evert Nasedkin.