petitRADTRANS.sbi.encoders

petitRADTRANS.sbi.encoders#

Learned observation encoders for SBI conditioning.

Classes#

SpectralPatchEncoder

Learned spectral encoder based on patch summaries and MLP pooling.

SpectralConv1DEncoder

Learned spectral encoder using 1D convolutions to retain positional information.

PhotometryPointEncoder

Learned photometry encoder using per-point MLP features and pooling.

DatasetSetAggregator

Learned permutation-invariant aggregator over block embeddings.

HierarchicalObservationEncoder

Learned hierarchical encoder dispatching over modalities.

Module Contents#

class petitRADTRANS.sbi.encoders.SpectralPatchEncoder(embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 96, key: jax.Array | None = None)#

Bases: equinox.Module

Learned spectral encoder based on patch summaries and MLP pooling.

Parameters#

embedding_dim:

Size of the spectral embedding returned for one observation block.

patch_size:

Number of spectral points summarized together before MLP encoding.

hidden_dim:

Hidden width of the internal summary MLP.

key:

Optional JAX random key used for weight initialization.

Notes#

The encoder summarizes values, uncertainties, and coordinates patch by patch rather than applying convolutions over the full spectrum. This keeps the implementation lightweight and independent of the number of spectral points in a block.
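The patch-wise summary described in the note can be illustrated with a minimal sketch. It is not the library's implementation: the function name patch_summaries and the choice of summary statistics (mean, standard deviation, mean uncertainty, mean coordinate) are assumptions made for illustration, and the learned MLP that would consume the summaries is left out.

```python
import jax.numpy as jnp


def patch_summaries(values, uncertainties, coordinates, patch_size=32):
    """Summarize a 1-D spectrum patch by patch (hypothetical helper).

    Returns an array of shape (n_patches, 4) holding, per patch, the mean
    value, the value standard deviation, the mean uncertainty, and the mean
    coordinate.
    """
    n_points = values.shape[0]
    n_patches = -(-n_points // patch_size)  # ceiling division
    pad = n_patches * patch_size - n_points

    def to_patches(x):
        # Zero-pad the tail so the array splits evenly into patches.
        return jnp.pad(x, (0, pad)).reshape(n_patches, patch_size)

    # Weight 1 marks real points, weight 0 marks padding in the last patch.
    weights = to_patches(jnp.ones(n_points))
    counts = weights.sum(axis=1)

    def masked_mean(x):
        return (to_patches(x) * weights).sum(axis=1) / counts

    mean_value = masked_mean(values)
    centered = (to_patches(values) - mean_value[:, None]) * weights
    std_value = jnp.sqrt((centered**2).sum(axis=1) / counts)
    return jnp.stack(
        [mean_value, std_value, masked_mean(uncertainties), masked_mean(coordinates)],
        axis=1,
    )


# A 70-point spectrum yields ceil(70 / 32) = 3 patches.
wavelengths = jnp.linspace(1.0, 2.0, 70)
flux = jnp.sin(10.0 * wavelengths)
sigma = jnp.full(70, 0.1)
summaries = patch_summaries(flux, sigma, wavelengths)
```

Because each patch collapses to a fixed number of statistics, a later pooling step over patches can return a fixed-width vector for any input length.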

patch_mlp: equinox.nn.MLP#
amplitude_projection: equinox.nn.MLP#
branch_projection: equinox.nn.Linear#
output_projection: equinox.nn.Linear#
scale_projection: equinox.nn.Linear#
embedding_dim: int#
patch_size: int#

encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray#

Encode one spectral observation block.

Parameters#

block:

Spectrum-like observation block containing values and optional uncertainties, coordinates, and masks.

Returns#

jnp.ndarray

Dense fixed-width embedding for the supplied spectral block.

_encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray#

Encode pre-processed block arrays. The method takes plain arrays rather than an ObservationBlock, so it can be mapped with jax.vmap.

Parameters#

values, uncertainties, coordinates:

Float32 1-D arrays already produced by _as_vector. uncertainties and coordinates may have length 0 to indicate absence.

mask:

Boolean 1-D array already produced by _safe_mask.

Returns#

jnp.ndarray

Dense spectral embedding of shape (embedding_dim,).
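The length-0 convention for absent uncertainties and coordinates works with jax.vmap because array shapes are static in JAX, so a Python-level shape check is resolved at trace time. The sketch below illustrates only that mechanism. The helper name fill_absent and both placeholder fill values are assumptions; the documentation does not state what the encoder substitutes for absent inputs.

```python
import jax
import jax.numpy as jnp


def fill_absent(values, uncertainties, coordinates):
    """Replace zero-length inputs with placeholders (hypothetical helper)."""
    n_points = values.shape[0]
    # Shapes are static, so these Python branches also work under jit and vmap.
    if uncertainties.shape[0] == 0:
        uncertainties = jnp.ones(n_points, dtype=jnp.float32)  # assumed placeholder
    if coordinates.shape[0] == 0:
        coordinates = jnp.linspace(0.0, 1.0, n_points, dtype=jnp.float32)  # assumed
    return jnp.stack([values, uncertainties, coordinates])


values = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
empty = jnp.zeros((0,), dtype=jnp.float32)
filled = fill_absent(values, empty, empty)  # shape (3, 3)

# Batched call: absent inputs are stacked as arrays of shape (batch, 0).
batched = jax.vmap(fill_absent)(
    jnp.stack([values, values]),
    jnp.zeros((2, 0), dtype=jnp.float32),
    jnp.zeros((2, 0), dtype=jnp.float32),
)
```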

class petitRADTRANS.sbi.encoders.SpectralConv1DEncoder(embedding_dim: int = 64, n_wavelengths: int = 233, hidden_channels: tuple[int, int] = (32, 64), key: jax.Array | None = None)#

Bases: equinox.Module

Learned spectral encoder using 1D convolutions to retain positional information.

Parameters#

embedding_dim:

Size of the spectral embedding returned for one observation block.

n_wavelengths:

Fixed number of spectral points per observation block. Must match the observation schema of the SBI task.

hidden_channels:

Tuple of intermediate channel widths for conv layers.

key:

Optional JAX random key used for weight initialization.

conv1: equinox.nn.Conv1d#
conv2: equinox.nn.Conv1d#
conv3: equinox.nn.Conv1d#
amplitude_projection: equinox.nn.MLP#
pool_projection: equinox.nn.Linear#
branch_projection: equinox.nn.Linear#
output_projection: equinox.nn.Linear#
scale_projection: equinox.nn.Linear#
embedding_dim: int#
n_wavelengths: int#

_forward_conv(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray#

Run 1D conv stack on pre-processed arrays.

Parameters#

values, uncertainties, coordinates:

Float32 1-D arrays of length n_wavelengths.

mask:

Boolean 1-D mask (True = invalid).

Returns#

jnp.ndarray

Embedding of shape (embedding_dim,).
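As a rough illustration of a 1D convolutional stack followed by pooling, the sketch below uses plain JAX instead of the equinox.nn.Conv1d layers listed above. The channel layout (values, uncertainties, coordinates, validity), the kernel width, the two-layer depth, and the global mean pooling are all assumptions, and the names conv1d and forward_conv are hypothetical. The real method has three convolution layers and additional projections.

```python
import jax
import jax.numpy as jnp


def conv1d(x, kernel):
    """Apply a 1-D convolution. x: (in_channels, length); kernel: (out, in, width)."""
    out = jax.lax.conv_general_dilated(
        x[None], kernel, window_strides=(1,), padding="SAME"
    )
    return out[0]


def forward_conv(kernels, values, uncertainties, coordinates, mask):
    # True marks an invalid point, matching the mask convention documented above.
    valid = jnp.where(mask, 0.0, 1.0)
    # Assumed input layout: four channels along the wavelength axis.
    x = jnp.stack([values * valid, uncertainties * valid, coordinates, valid])
    for kernel in kernels:
        x = jax.nn.relu(conv1d(x, kernel))
    # Global mean pooling turns the (channels, length) map into a fixed-width vector.
    return x.mean(axis=1)


n_wavelengths = 233
keys = jax.random.split(jax.random.PRNGKey(0), 2)
kernels = [
    0.1 * jax.random.normal(keys[0], (32, 4, 5)),   # 4 -> 32 channels
    0.1 * jax.random.normal(keys[1], (64, 32, 5)),  # 32 -> 64 channels
]
wl = jnp.linspace(1.0, 2.0, n_wavelengths)
features = forward_conv(
    kernels,
    jnp.sin(20.0 * wl),
    jnp.full(n_wavelengths, 0.05),
    wl,
    jnp.zeros(n_wavelengths, dtype=bool),
)
```

With "SAME" padding every intermediate feature map keeps the input length, so positional structure survives until the final pooling step.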

_pad_to_fixed(arr: jax.numpy.ndarray) -> jax.numpy.ndarray#

Pad or truncate a 1-D array to n_wavelengths.
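A minimal sketch of padding or truncating to a fixed length is shown below. The padding position (at the end) and the fill value (zero) are assumptions, since the documentation does not state them, and pad_to_fixed is a hypothetical stand-alone function rather than the method itself.

```python
import jax.numpy as jnp


def pad_to_fixed(arr, n_wavelengths=233, fill_value=0.0):
    """Pad at the end, or truncate, a 1-D array to exactly n_wavelengths entries."""
    arr = arr[:n_wavelengths]            # no-op when the array is short enough
    pad = n_wavelengths - arr.shape[0]   # 0 when the array was truncated
    return jnp.pad(arr, (0, pad), constant_values=fill_value)


short = pad_to_fixed(jnp.arange(100.0))  # padded from 100 to 233 entries
long = pad_to_fixed(jnp.arange(300.0))   # truncated from 300 to 233 entries
```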

encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray#

Encode one spectral observation block.

Parameters#

block:

Spectrum-like observation block.

Returns#

jnp.ndarray

Dense embedding of shape (embedding_dim,).

_encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray#

Encode pre-processed block arrays. The method takes plain arrays rather than an ObservationBlock, so it can be mapped with jax.vmap.

Parameters#

values, uncertainties, coordinates:

Float32 1-D arrays already produced by _as_vector.

mask:

Boolean 1-D array already produced by _safe_mask.

log_median_flux:

Scalar absolute-flux feature derived from the raw spectrum median.

Returns#

jnp.ndarray

Dense spectral embedding of shape (embedding_dim,).

class petitRADTRANS.sbi.encoders.PhotometryPointEncoder(embedding_dim: int = 64, hidden_dim: int = 96, key: jax.Array | None = None)#

Bases: equinox.Module

Learned photometry encoder using per-point MLP features and pooling.

Parameters#

embedding_dim:

Size of the returned photometric embedding.

hidden_dim:

Hidden width of the per-point MLP.

key:

Optional JAX random key used for initialization.

Notes#

Each photometric point is represented by value, uncertainty, coordinate, and an inferred width feature before permutation-invariant pooling.
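The per-point encoding with permutation-invariant pooling described in the note can be sketched as follows. This is an illustration under stated assumptions, not the library's code: the two-layer network with randomly initialized weights stands in for point_mlp, mean pooling is an assumed pooling choice, and the widths are passed in directly because the documentation does not say how the width feature is inferred.

```python
import jax
import jax.numpy as jnp


def init_point_mlp(key, n_features=4, hidden_dim=96, embedding_dim=64):
    """Random weights for a small stand-in per-point network."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (n_features, hidden_dim)) / jnp.sqrt(n_features),
        "b1": jnp.zeros(hidden_dim),
        "w2": jax.random.normal(k2, (hidden_dim, embedding_dim)) / jnp.sqrt(hidden_dim),
        "b2": jnp.zeros(embedding_dim),
    }


def encode_points(params, values, uncertainties, coordinates, widths):
    # One feature row per photometric point: shape (n_points, 4).
    features = jnp.stack([values, uncertainties, coordinates, widths], axis=1)
    hidden = jnp.tanh(features @ params["w1"] + params["b1"])
    per_point = hidden @ params["w2"] + params["b2"]
    # Mean pooling discards both the order and the number of points.
    return per_point.mean(axis=0)


params = init_point_mlp(jax.random.PRNGKey(0))
flux = jnp.array([1.0, 2.0, 3.0, 4.0])
sigma = jnp.array([0.1, 0.2, 0.1, 0.3])
centers = jnp.array([1.2, 1.6, 2.2, 3.5])
widths = jnp.array([0.2, 0.3, 0.3, 0.5])

embedding = encode_points(params, flux, sigma, centers, widths)

# Reordering the points leaves the pooled embedding unchanged.
perm = jnp.array([2, 0, 3, 1])
shuffled = encode_points(params, flux[perm], sigma[perm], centers[perm], widths[perm])
```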

point_mlp: equinox.nn.MLP#
output_projection: equinox.nn.Linear#
embedding_dim: int#

encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray#

Encode one photometric observation block.

Parameters#

block:

Photometry-like observation block to encode.

Returns#

jnp.ndarray

Dense embedding for the supplied photometric block.

_encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray) -> jax.numpy.ndarray#

Encode pre-processed block arrays. The method takes plain arrays rather than an ObservationBlock, so it can be mapped with jax.vmap.

Parameters#

values, uncertainties, coordinates:

Float32 1-D arrays already produced by _as_vector. uncertainties and coordinates may have length 0.

mask:

Boolean 1-D array already produced by _safe_mask.

Returns#

jnp.ndarray

Dense photometric embedding of shape (embedding_dim,).

class petitRADTRANS.sbi.encoders.DatasetSetAggregator(embedding_dim: int = 128, hidden_dim: int = 128, key: jax.Array | None = None)#

Bases: equinox.Module

Learned permutation-invariant aggregator over block embeddings.

Parameters#

embedding_dim:

Target dimensionality of the aggregated observation embedding.

hidden_dim:

Hidden width of the block projection MLP.

key:

Optional JAX random key used for initialization.

block_projection: equinox.nn.MLP#
output_projection: equinox.nn.Linear#
embedding_dim: int#

aggregate(block_embeddings: list[jax.numpy.ndarray]) -> jax.numpy.ndarray#

Aggregate a list of block embeddings into one observation embedding.

Parameters#

block_embeddings:

Encoded observation blocks for one retrieval target.

Returns#

jnp.ndarray

Permutation-invariant aggregated embedding.
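One common way to obtain a permutation-invariant aggregate is to project each block embedding, pool across blocks, and then apply an output projection. The sketch below follows that pattern under stated assumptions: the single-matrix projections, the GELU activation, and mean pooling are illustrative choices, and init_aggregator and the free function aggregate are hypothetical stand-ins for the class members.

```python
import jax
import jax.numpy as jnp


def init_aggregator(key, block_dim=64, hidden_dim=128, embedding_dim=128):
    """Random weights for a stand-in block projection and output projection."""
    k1, k2 = jax.random.split(key)
    return {
        "block": jax.random.normal(k1, (block_dim, hidden_dim)) / jnp.sqrt(block_dim),
        "out": jax.random.normal(k2, (hidden_dim, embedding_dim)) / jnp.sqrt(hidden_dim),
    }


def aggregate(params, block_embeddings):
    stacked = jnp.stack(block_embeddings)                # (n_blocks, block_dim)
    projected = jax.nn.gelu(stacked @ params["block"])   # per-block projection
    pooled = projected.mean(axis=0)                      # order-independent pooling
    return pooled @ params["out"]                        # (embedding_dim,)


param_key, data_key = jax.random.split(jax.random.PRNGKey(0))
params = init_aggregator(param_key)
blocks = [jax.random.normal(k, (64,)) for k in jax.random.split(data_key, 3)]

forward = aggregate(params, blocks)
backward = aggregate(params, blocks[::-1])  # same blocks, reversed order
```

Since pooling happens before the output projection, the result depends only on the set of blocks, not on their order in the list.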

class petitRADTRANS.sbi.encoders.HierarchicalObservationEncoder(embedding_dim: int = 128, spectrum_embedding_dim: int = 64, photometry_embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 128, spectrum_encoder_type: str = 'conv1d', n_wavelengths: int = 233, key: jax.Array | None = None)#

Bases: equinox.Module, petitRADTRANS.sbi.observation.ObservationEncoder

Learned hierarchical encoder dispatching over modalities.

Parameters#

embedding_dim:

Size of the final joint observation embedding.

spectrum_embedding_dim:

Intermediate embedding size used by the spectral sub-encoder.

photometry_embedding_dim:

Intermediate embedding size used by the photometry sub-encoder.

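patch_size:

Patch size used by the patch-based spectral encoder (SpectralPatchEncoder).

hidden_dim:

Hidden width shared across the component encoders and aggregator.

spectrum_encoder_type:

Selects the spectral sub-encoder. The default, 'conv1d', corresponds to SpectralConv1DEncoder; the other supported sub-encoder is SpectralPatchEncoder.

n_wavelengths:

Fixed number of spectral points per observation block, as in SpectralConv1DEncoder.

key:

Optional JAX random key used to initialize all submodules.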

Notes#

Spectrum and photometry blocks are handled by dedicated sub-encoders and then merged with a permutation-invariant aggregator. Unsupported modalities currently fall back to resized raw-value vectors.
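The dispatch-with-fallback behavior in the note can be sketched as a lookup table keyed by modality. Everything specific here is an assumption: the modality strings "spectrum" and "photometry", the stand-in sub-encoders, and the interpretation of "resized" as truncation or zero-padding to the embedding width.

```python
import jax.numpy as jnp

EMBEDDING_DIM = 8  # small width to keep the example readable


def encode_spectrum(values):
    # Stand-in for a learned spectral sub-encoder.
    return jnp.full(EMBEDDING_DIM, values.mean())


def encode_photometry(values):
    # Stand-in for a learned photometry sub-encoder.
    return jnp.full(EMBEDDING_DIM, values.max())


def resize_raw(values):
    # Assumed fallback: truncate or zero-pad the raw values to the embedding width.
    values = values[:EMBEDDING_DIM]
    return jnp.pad(values, (0, EMBEDDING_DIM - values.shape[0]))


# Hypothetical modality names; the library's actual strings are not documented here.
ENCODERS = {"spectrum": encode_spectrum, "photometry": encode_photometry}


def encode_block(modality, values):
    """Route a block to its sub-encoder, or to the raw-value fallback."""
    return ENCODERS.get(modality, resize_raw)(values)


spectrum_embedding = encode_block("spectrum", jnp.ones(20))
fallback_embedding = encode_block("polarimetry", jnp.arange(3.0))
```

Every branch returns a vector of the same width, which is what lets a downstream aggregator treat all blocks uniformly.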

spectrum_encoder: SpectralPatchEncoder | SpectralConv1DEncoder#
photometry_encoder: PhotometryPointEncoder#
aggregator: DatasetSetAggregator#
embedding_dim: int#

_encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray#

Encode one observation block with modality-aware dispatch.

Parameters#

block:

Observation block to encode.

Returns#

jnp.ndarray

Fixed-width embedding for the supplied block.

encode(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation#

Encode one structured observation made of multiple blocks.

Parameters#

blocks:

Observation blocks associated with one target system.

Returns#

EncodedObservation

Aggregated embedding and light metadata describing the block set.

encode_stacked_batch(blocks_batch: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) -> jax.numpy.ndarray#

Encode a batch of identically structured observations using vmap.

All observations must share the same block structure: the same number of blocks and the same array shapes per block. SBI tasks satisfy this requirement whenever every observation is produced by the same forward model.

Parameters#

blocks_batch:

Outer list indexes samples; inner list holds the per-block observation data for one sample.

Returns#

jnp.ndarray

Float32 array of shape (n_samples, embedding_dim).
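The requirement that all samples share the same array shapes follows from how jax.vmap works: per-sample arrays are stacked along a new leading axis, and the single-sample function is mapped over that axis. The sketch below shows the pattern with a stand-in encoder; encode_one and its signal-to-noise summary are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp


def encode_one(values, uncertainties):
    """Stand-in single-sample encoder: two signal-to-noise statistics."""
    snr = values / uncertainties
    return jnp.stack([snr.mean(), snr.std()])


# Three samples with identical array shapes, as the method requires.
samples = [
    (jnp.linspace(1.0, 2.0, 5) * scale, jnp.full(5, 0.1))
    for scale in (1.0, 2.0, 3.0)
]

# Stack along a new leading axis: shapes become (n_samples, n_points).
stacked_values = jnp.stack([v for v, _ in samples])
stacked_sigma = jnp.stack([s for _, s in samples])

batched = jax.vmap(encode_one)(stacked_values, stacked_sigma)  # (n_samples, 2)

# Reference result from an explicit Python loop over samples.
looped = jnp.stack([encode_one(v, s) for v, s in samples])
```

If one sample had a different number of points, jnp.stack would raise an error before vmap is reached, which is why the identical-structure condition matters.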

encode_from_prestacked(obs: Any) -> jax.numpy.ndarray#

Encode a batch of observations from pre-stacked arrays.

Accepts a PreStackedObservations instance whose array fields have already been extracted from ObservationBlock objects outside the JAX JIT boundary. Only the vmapped XLA computation runs here, enabling the training step to be compiled once and reused across all batches.

Parameters#

obs:

Pre-stacked observation container with two fields. stacked_blocks holds one (values, uncertainties, coordinates, mask, log_scale, absolute_values) tuple per block; each array has shape (batch_size, n_wl), except the log_scale arrays, which hold one scalar per sample. modalities is a static tuple of modality value strings.

Returns#

jnp.ndarray

Float32 array of shape (batch_size, embedding_dim).
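The benefit of extracting arrays outside the JIT boundary can be seen in a small sketch: when the jitted function only ever receives arrays of the same shape and dtype, JAX traces and compiles it once and reuses the result. The helper prestack, the stand-in encode_one, and the trace counter are hypothetical and stand in for PreStackedObservations and the real encoder.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def encode_one(values, mask):
    # Mask convention assumed here: True marks an invalid point.
    valid = jnp.where(mask, 0.0, 1.0)
    mean = (values * valid).sum() / jnp.maximum(valid.sum(), 1.0)
    return jnp.stack([mean, valid.sum()])


@jax.jit
def encode_batch(stacked_values, stacked_mask):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jax.vmap(encode_one)(stacked_values, stacked_mask)


def prestack(samples):
    """Plain-Python extraction step, kept outside the jitted function."""
    values = jnp.stack([jnp.asarray(v, dtype=jnp.float32) for v, _ in samples])
    mask = jnp.stack([jnp.asarray(m, dtype=bool) for _, m in samples])
    return values, mask


batch_a = prestack([([1.0, 2.0, 3.0], [False, False, True])] * 4)
batch_b = prestack([([4.0, 5.0, 6.0], [False, True, True])] * 4)

out_a = encode_batch(*batch_a)  # first call: trace and compile
out_b = encode_batch(*batch_b)  # same shapes and dtypes: compiled code is reused
```

After both calls the counter has been incremented a single time, since the second batch matches the first in shape and dtype.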