petitRADTRANS.sbi.encoders
Learned observation encoders for SBI conditioning.
Classes
| Class | Description |
|---|---|
| SpectralPatchEncoder | Learned spectral encoder based on patch summaries and MLP pooling. |
| SpectralConv1DEncoder | Learned spectral encoder using 1D convolutions to retain positional information. |
| PhotometryPointEncoder | Learned photometry encoder using per-point MLP features and pooling. |
| DatasetSetAggregator | Learned permutation-invariant aggregator over block embeddings. |
| HierarchicalObservationEncoder | Learned hierarchical encoder dispatching over modalities. |
Module Contents
- class petitRADTRANS.sbi.encoders.SpectralPatchEncoder(embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 96, key: jax.Array | None = None)
Bases: equinox.Module
Learned spectral encoder based on patch summaries and MLP pooling.
Parameters
- embedding_dim:
Size of the spectral embedding returned for one observation block.
- patch_size:
Number of spectral points summarized together before MLP encoding.
- hidden_dim:
Hidden width of the internal summary MLP.
- key:
Optional JAX random key used for weight initialization.
Notes
The encoder summarizes values, uncertainties, and coordinates patch by patch instead of operating on a fully convolutional representation. This keeps the implementation lightweight and shape-agnostic.
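The patch-by-patch summary described above can be sketched in NumPy, used here as a runnable stand-in for `jax.numpy`. This is a minimal illustration under stated assumptions, not the module's implementation. The following are all hypothetical: the helper name `summarize_patches`, the choice of statistics (masked mean and spread of the values, mean uncertainty, mean coordinate), zero-padding of the last patch, and the convention that `mask` is `True` for invalid points. In the real encoder the per-patch summaries would presumably pass through `patch_mlp` before pooling.

```python
import numpy as np


def summarize_patches(values, uncertainties, coordinates, mask, patch_size):
    """Hypothetical patch summary; returns an array of shape (n_patches, 4).

    Assumes mask is True for invalid points and all arrays share one length.
    """
    n = values.shape[0]
    n_patches = -(-n // patch_size)  # ceiling division
    pad = n_patches * patch_size - n

    def to_patches(arr, fill):
        # Pad the tail so the array splits into equal patches of patch_size.
        tail = np.full(pad, fill, dtype=arr.dtype)
        return np.concatenate([arr, tail]).reshape(n_patches, patch_size)

    v = to_patches(values, 0.0)
    u = to_patches(uncertainties, 0.0)
    c = to_patches(coordinates, 0.0)
    valid = ~to_patches(mask, True)  # padded points count as invalid

    count = np.maximum(valid.sum(axis=1), 1)  # avoid division by zero
    mean_v = np.where(valid, v, 0.0).sum(axis=1) / count
    var_v = np.where(valid, (v - mean_v[:, None]) ** 2, 0.0).sum(axis=1) / count
    mean_u = np.where(valid, u, 0.0).sum(axis=1) / count
    mean_c = np.where(valid, c, 0.0).sum(axis=1) / count
    return np.stack([mean_v, np.sqrt(var_v), mean_u, mean_c], axis=1)


# Ten spectral points in patches of 4 -> 3 patches (the last one half padded).
values = np.arange(10, dtype=np.float32)
uncertainties = np.full(10, 0.1, dtype=np.float32)
coordinates = np.linspace(1.0, 2.0, 10, dtype=np.float32)
mask = np.zeros(10, dtype=bool)
summary = summarize_patches(values, uncertainties, coordinates, mask, patch_size=4)

# Mean pooling over patches gives a fixed width whatever the spectrum length.
pooled = summary.mean(axis=0)
```

Because the number of patches grows with the input length but pooling collapses that axis, the output width does not depend on how many spectral points a block has. This matches the "shape-agnostic" property claimed above.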
- patch_mlp: equinox.nn.MLP
- amplitude_projection: equinox.nn.MLP
- branch_projection: equinox.nn.Linear
- output_projection: equinox.nn.Linear
- scale_projection: equinox.nn.Linear
- embedding_dim: int
- patch_size: int
- encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray
Encode one spectral observation block.
Parameters
- block:
Spectrum-like observation block containing values and optional uncertainties, coordinates, and masks.
Returns
- jnp.ndarray
Dense fixed-width embedding for the supplied spectral block.
- _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray
Encode pre-processed block arrays (vmappable: takes no ObservationBlock input).
Parameters
- values, uncertainties, coordinates:
Float32 1-D arrays already produced by _as_vector. uncertainties and coordinates may have length 0 to indicate absence.
- mask:
Boolean 1-D array already produced by _safe_mask.
Returns
- jnp.ndarray
Dense spectral embedding of shape (embedding_dim,).
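Because uncertainties and coordinates may arrive with length 0, something has to stand in for them before the features are combined. One possible approach is sketched below in NumPy. The helper `fill_optional` and the chosen defaults (unit uncertainties, an evenly spaced coordinate grid) are hypothetical and are not documented behaviour of `_encode_block_raw`. The branch tests only the array length, which is static under `jax.jit` and `jax.vmap`, so in JAX the same pattern would be resolved at trace time.

```python
import numpy as np


def fill_optional(arr, n, default):
    """Hypothetical helper: return arr, or a length-n placeholder if arr is empty."""
    if arr.shape[0] == 0:
        # Absent input (length 0): build a placeholder of the expected length.
        return np.asarray(default(n), dtype=np.float32)
    if arr.shape[0] != n:
        raise ValueError(f"expected length {n} or 0, got {arr.shape[0]}")
    return arr.astype(np.float32)


values = np.array([1.0, 2.0, 3.0], dtype=np.float32)
empty = np.zeros(0, dtype=np.float32)

# Assumed defaults: unit uncertainties and an evenly spaced coordinate grid.
uncertainties = fill_optional(empty, values.shape[0], np.ones)
coordinates = fill_optional(empty, values.shape[0], lambda n: np.linspace(0.0, 1.0, n))
```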
- class petitRADTRANS.sbi.encoders.SpectralConv1DEncoder(embedding_dim: int = 64, n_wavelengths: int = 233, hidden_channels: tuple[int, int] = (32, 64), key: jax.Array | None = None)
Bases: equinox.Module
Learned spectral encoder using 1D convolutions to retain positional information.
Parameters
- embedding_dim:
Size of the spectral embedding returned for one observation block.
- n_wavelengths:
Fixed number of spectral points per observation block. Must match the observation schema of the SBI task.
- hidden_channels:
Tuple of intermediate channel widths for conv layers.
- key:
Optional JAX random key used for weight initialization.
- conv1: equinox.nn.Conv1d
- conv2: equinox.nn.Conv1d
- conv3: equinox.nn.Conv1d
- amplitude_projection: equinox.nn.MLP
- pool_projection: equinox.nn.Linear
- branch_projection: equinox.nn.Linear
- output_projection: equinox.nn.Linear
- scale_projection: equinox.nn.Linear
- embedding_dim: int
- n_wavelengths: int
- _forward_conv(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray
Run 1D conv stack on pre-processed arrays.
Parameters
- values, uncertainties, coordinates:
Float32 1-D arrays of length n_wavelengths.
- mask:
Boolean 1-D mask (True = invalid).
Returns
- jnp.ndarray
Embedding of shape (embedding_dim,).
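A conv stack that keeps positional information can be sketched in NumPy, standing in for the `equinox.nn.Conv1d` layers. This is a minimal illustration, not the module's architecture. The following are all assumptions: the channel layout (values, uncertainties, coordinates, validity flag), zero "same" padding, the ReLU activation, two layers instead of three, the small channel counts, and the flattening dense readout. The readout keeps one feature per channel and position, so it can respond to where a feature occurs, but it only works for a fixed `n_wavelengths`. Whether `pool_projection` works this way is not documented.

```python
import numpy as np


def conv1d_same(x, w, b):
    """Cross-correlate x (c_in, n) with w (c_out, c_in, k) using zero 'same' padding."""
    k = w.shape[2]
    left = k // 2
    xp = np.pad(x, ((0, 0), (left, k - 1 - left)))
    windows = np.stack([xp[:, i:i + k] for i in range(x.shape[1])])  # (n, c_in, k)
    return np.einsum("nck,ock->on", windows, w) + b[:, None]  # (c_out, n)


def forward_conv_sketch(values, uncertainties, coordinates, mask, layers, readout):
    """Hypothetical conv stack plus dense readout (mask True = invalid)."""
    valid = (~mask).astype(np.float64)
    # Zero out invalid points and stack the inputs as channels: shape (4, n).
    h = np.stack([values * valid, uncertainties * valid, coordinates, valid])
    for w, b in layers:
        h = np.maximum(conv1d_same(h, w, b), 0.0)  # convolution + ReLU
    w_out, b_out = readout
    # Flattening keeps one feature per (channel, position) pair, so the
    # readout can depend on where a feature occurs; it also fixes the length.
    return w_out @ h.reshape(-1) + b_out


rng = np.random.default_rng(0)
n_wavelengths = 233
layers = [(0.1 * rng.normal(size=(c_out, c_in, 5)), np.zeros(c_out))
          for c_in, c_out in [(4, 8), (8, 16)]]
readout = (0.01 * rng.normal(size=(32, 16 * n_wavelengths)), np.zeros(32))

values = rng.normal(size=n_wavelengths)
uncertainties = np.full(n_wavelengths, 0.05)
coordinates = np.linspace(1.0, 5.0, n_wavelengths)
mask = np.zeros(n_wavelengths, dtype=bool)
mask[10] = True  # one flagged pixel
embedding = forward_conv_sketch(values, uncertainties, coordinates, mask, layers, readout)
```

In this sketch, masked pixels are zeroed before the first convolution, so their raw values cannot leak into the embedding.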
- _pad_to_fixed(arr: jax.numpy.ndarray) -> jax.numpy.ndarray
Pad or truncate a 1-D array to n_wavelengths entries.
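A pad-or-truncate step like `_pad_to_fixed` can be sketched as follows. The zero fill value, the extra `fill` argument, and the choice to pad the mask with `True` are assumptions. Padding the mask this way makes padded positions count as invalid, matching the True = invalid convention stated for `_forward_conv`. The documented method takes only the array.

```python
import numpy as np


def pad_to_fixed(arr, n_wavelengths, fill=0.0):
    """Hypothetical stand-in for _pad_to_fixed: pad with `fill` or truncate."""
    arr = np.asarray(arr)
    n = arr.shape[0]
    if n >= n_wavelengths:
        return arr[:n_wavelengths]  # drop surplus points
    tail = np.full(n_wavelengths - n, fill, dtype=arr.dtype)
    return np.concatenate([arr, tail])


values = np.array([1.0, 2.0, 3.0])
mask = np.zeros(3, dtype=bool)  # True would mean "invalid"

padded_values = pad_to_fixed(values, 5)
# Padding the mask with True marks the synthetic points as invalid.
padded_mask = pad_to_fixed(mask, 5, fill=True)
truncated = pad_to_fixed(np.arange(8.0), 5)
```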
- encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray
Encode one spectral observation block.
Parameters
- block:
Spectrum-like observation block.
Returns
- jnp.ndarray
Dense embedding of shape (embedding_dim,).
- _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray
Encode pre-processed block arrays (vmappable: takes no ObservationBlock input).
Parameters
- values, uncertainties, coordinates:
Float32 1-D arrays already produced by _as_vector.
- mask:
Boolean 1-D array already produced by _safe_mask.
- log_median_flux:
Scalar absolute-flux feature derived from the median of the raw spectrum.
Returns
- jnp.ndarray
Dense spectral embedding of shape (embedding_dim,).
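A scalar flux-level feature like `log_median_flux` might be computed as sketched below. The documentation says only that it is derived from the raw spectrum median. The following are all assumptions: the base-10 logarithm, the absolute value, restricting the median to valid points, the numerical floor, and the idea of dividing the spectrum by its median. The sketch shows the general pattern of separating overall scale from spectral shape.

```python
import numpy as np


def log_median_flux_sketch(values, mask, floor=1e-30):
    """Hypothetical: log10 of the median absolute flux over valid points."""
    kept = np.abs(values[~mask])
    if kept.size == 0:
        return np.float32(np.log10(floor))
    # The floor keeps the logarithm finite for an all-zero spectrum.
    return np.float32(np.log10(max(np.median(kept), floor)))


# Fluxes of order 1e-16 (arbitrary illustrative units); the last point is masked.
raw = np.array([2e-16, 4e-16, 1e-15, 5.0])
mask = np.array([False, False, False, True])

log_scale = log_median_flux_sketch(raw, mask)
# Dividing by the median leaves an order-unity spectrum for the network,
# while log_scale carries the absolute level as a single scalar.
normalized = raw / 10.0 ** log_scale
```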
- class petitRADTRANS.sbi.encoders.PhotometryPointEncoder(embedding_dim: int = 64, hidden_dim: int = 96, key: jax.Array | None = None)
Bases: equinox.Module
Learned photometry encoder using per-point MLP features and pooling.
Parameters
- embedding_dim:
Size of the returned photometric embedding.
- hidden_dim:
Hidden width of the per-point MLP.
- key:
Optional JAX random key used for initialization.
Notes
Each photometric point is represented by value, uncertainty, coordinate, and an inferred width feature before permutation-invariant pooling.
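The per-point encoding and permutation-invariant pooling described above can be sketched in NumPy. The documentation does not say how the width feature is inferred. Here it is assumed to be half the gap to the nearest valid neighbouring coordinate, which is a hypothetical choice. The one-hidden-layer MLP, ReLU, masked mean pooling, and the `True` = invalid mask convention are also assumptions. The key property shown is that shuffling the points leaves the embedding unchanged.

```python
import numpy as np


def inferred_width(coordinates, mask):
    """Hypothetical width feature: half the gap to the nearest valid neighbour."""
    d = np.abs(coordinates[:, None] - coordinates[None, :])
    np.fill_diagonal(d, np.inf)  # a point is not its own neighbour
    d[:, mask] = np.inf          # ignore invalid neighbours (mask True = invalid)
    nearest = d.min(axis=1)
    return np.where(np.isfinite(nearest), 0.5 * nearest, 0.0)


def encode_points(values, uncertainties, coordinates, mask, params):
    """Per-point MLP followed by masked mean pooling (order-independent)."""
    w1, b1, w2, b2 = params
    width = inferred_width(coordinates, mask)
    feats = np.stack([values, uncertainties, coordinates, width], axis=1)  # (n, 4)
    hidden = np.maximum(feats @ w1 + b1, 0.0)  # per-point hidden layer + ReLU
    points = hidden @ w2 + b2                  # (n, embedding_dim)
    valid = (~mask)[:, None].astype(np.float64)
    return (points * valid).sum(axis=0) / max(valid.sum(), 1.0)


rng = np.random.default_rng(1)
params = (rng.normal(size=(4, 16)), np.zeros(16), rng.normal(size=(16, 8)), np.zeros(8))
coordinates = np.array([1.25, 1.65, 2.2, 3.6, 4.5])  # e.g. band centres in micron
values = rng.normal(size=5)
uncertainties = np.full(5, 0.1)
mask = np.array([False, False, False, False, True])
embedding = encode_points(values, uncertainties, coordinates, mask, params)
```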
- point_mlp: equinox.nn.MLP
- output_projection: equinox.nn.Linear
- embedding_dim: int
- encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray
Encode one photometric observation block.
Parameters
- block:
Photometry-like observation block to encode.
Returns
- jnp.ndarray
Dense embedding for the supplied photometric block.
- _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray) -> jax.numpy.ndarray
Encode pre-processed block arrays (vmappable: takes no ObservationBlock input).
Parameters
- values, uncertainties, coordinates:
Float32 1-D arrays already produced by _as_vector. uncertainties and coordinates may have length 0.
- mask:
Boolean 1-D array already produced by _safe_mask.
Returns
- jnp.ndarray
Dense photometric embedding of shape (embedding_dim,).
- class petitRADTRANS.sbi.encoders.DatasetSetAggregator(embedding_dim: int = 128, hidden_dim: int = 128, key: jax.Array | None = None)
Bases: equinox.Module
Learned permutation-invariant aggregator over block embeddings.
Parameters
- embedding_dim:
Target dimensionality of the aggregated observation embedding.
- hidden_dim:
Hidden width of the block projection MLP.
- key:
Optional JAX random key used for initialization.
- block_projection: equinox.nn.MLP
- output_projection: equinox.nn.Linear
- embedding_dim: int
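A permutation-invariant aggregator of this kind is commonly built in the "DeepSets" style: apply the same network to each element, pool, then project. The NumPy sketch below follows that pattern. The two-layer `phi` stands in for `block_projection` and the final linear map for `output_projection`. Mean pooling, the layer count, ReLU, and the 64-dimensional block embeddings are assumptions, not documented details of `DatasetSetAggregator`.

```python
import numpy as np


def aggregate(block_embeddings, params):
    """Hypothetical DeepSets-style aggregator over a (n_blocks, d_in) array."""
    w1, b1, w2, b2, w_out, b_out = params
    # phi: the same MLP is applied to every block embedding independently.
    hidden = np.maximum(block_embeddings @ w1 + b1, 0.0) @ w2 + b2
    # Mean over blocks: invariant to block order, defined for any block count.
    pooled = hidden.mean(axis=0)
    return pooled @ w_out + b_out  # rho: final linear projection


rng = np.random.default_rng(2)
d_in, hidden_dim, embedding_dim = 64, 128, 128
params = (
    rng.normal(size=(d_in, hidden_dim)), np.zeros(hidden_dim),
    rng.normal(size=(hidden_dim, hidden_dim)), np.zeros(hidden_dim),
    rng.normal(size=(hidden_dim, embedding_dim)), np.zeros(embedding_dim),
)
spectrum_emb, photometry_emb, other_emb = rng.normal(size=(3, d_in))
two_blocks = aggregate(np.stack([spectrum_emb, photometry_emb]), params)
three_blocks = aggregate(np.stack([spectrum_emb, photometry_emb, other_emb]), params)
```

Mean pooling (rather than summation) keeps the pooled magnitude comparable whether an observation has two blocks or ten. Sum pooling would instead let the block count show up in the embedding.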
- class petitRADTRANS.sbi.encoders.HierarchicalObservationEncoder(embedding_dim: int = 128, spectrum_embedding_dim: int = 64, photometry_embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 128, spectrum_encoder_type: str = 'conv1d', n_wavelengths: int = 233, key: jax.Array | None = None)
Bases: equinox.Module, petitRADTRANS.sbi.observation.ObservationEncoder
Learned hierarchical encoder dispatching over modalities.
Parameters
- embedding_dim:
Size of the final joint observation embedding.
- spectrum_embedding_dim:
Intermediate embedding size used by the spectral sub-encoder.
- photometry_embedding_dim:
Intermediate embedding size used by the photometry sub-encoder.
- patch_size:
Patch size used by the spectral encoder.
- hidden_dim:
Hidden width shared across the component encoders and aggregator.
- key:
Optional JAX random key used to initialize all submodules.
Notes
Spectrum and photometry blocks are handled by dedicated sub-encoders and then merged with a permutation-invariant aggregator. Unsupported modalities currently fall back to resized raw-value vectors.
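The modality dispatch and the raw-value fallback can be sketched as a lookup table plus a resampling helper. All names here are hypothetical: the modality strings `"spectrum"` and `"photometry"`, the helpers `encode_block_dispatch` and `resize_raw`, and the toy stand-ins for the learned sub-encoders. The Notes say only that unsupported modalities fall back to "resized raw-value vectors". Linear interpolation, a zero vector for empty input, and the target width are assumptions.

```python
import numpy as np


def resize_raw(values, size):
    """Hypothetical fallback: linearly resample a 1-D vector to `size` entries."""
    values = np.asarray(values, dtype=np.float32)
    if values.size == 0:
        return np.zeros(size, dtype=np.float32)
    if values.size == 1:
        return np.full(size, values[0], dtype=np.float32)
    src = np.linspace(0.0, 1.0, values.size)
    dst = np.linspace(0.0, 1.0, size)
    return np.interp(dst, src, values).astype(np.float32)


def encode_block_dispatch(modality, values, encoders, size):
    """Route a block to its modality encoder, or fall back to resize_raw."""
    encoder = encoders.get(modality)
    if encoder is not None:
        return encoder(values)
    return resize_raw(values, size)


# Toy stand-ins for the learned sub-encoders (modality names are assumptions).
encoders = {
    "spectrum": lambda v: np.full(4, np.mean(v), dtype=np.float32),
    "photometry": lambda v: np.full(4, np.max(v), dtype=np.float32),
}
spec = encode_block_dispatch("spectrum", np.array([1.0, 3.0]), encoders, 4)
fallback = encode_block_dispatch("unknown", np.array([0.0, 10.0]), encoders, 5)
```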
- spectrum_encoder: SpectralPatchEncoder | SpectralConv1DEncoder
- photometry_encoder: PhotometryPointEncoder
- aggregator: DatasetSetAggregator
- embedding_dim: int
- _encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray
Encode one observation block with modality-aware dispatch.
Parameters
- block:
Observation block to encode.
Returns
- jnp.ndarray
Fixed-width embedding for the supplied block.
- encode(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation
Encode one structured observation made of multiple blocks.
Parameters
- blocks:
Observation blocks associated with one target system.
Returns
- EncodedObservation
Aggregated embedding and light metadata describing the block set.
- encode_stacked_batch(blocks_batch: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) -> jax.numpy.ndarray
Encode a batch of identically-structured observations using vmap.
All observations must share the same block structure (same number of blocks, same array shapes per block). This is always the case for SBI tasks where every observation is produced by the same forward model.
Parameters
- blocks_batch:
Outer list indexes samples; inner list holds the per-block observation data for one sample.
Returns
- jnp.ndarray
Float32 array of shape (n_samples, embedding_dim).
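The shared-structure requirement can be checked up front, so that a mismatch is reported clearly before arrays are stacked for `jax.vmap`. In this sketch plain dicts stand in for `ObservationBlock`, and the helpers `block_signature` and `check_same_structure` are hypothetical. Comparing only the modality and the shape of `values` is a simplification, since real blocks also carry uncertainties, coordinates, and masks.

```python
import numpy as np


def block_signature(blocks):
    """Hypothetical structural fingerprint: (modality, values shape) per block."""
    return tuple((b["modality"], np.shape(b["values"])) for b in blocks)


def check_same_structure(blocks_batch):
    """Raise if any sample's block layout differs from the first sample's."""
    reference = block_signature(blocks_batch[0])
    for i, blocks in enumerate(blocks_batch[1:], start=1):
        signature = block_signature(blocks)
        if signature != reference:
            raise ValueError(f"sample {i}: {signature} != {reference}")
    return reference


def make_sample(n_wl):
    # Plain dicts stand in for ObservationBlock objects here.
    return [{"modality": "spectrum", "values": np.zeros(n_wl)},
            {"modality": "photometry", "values": np.zeros(3)}]


reference = check_same_structure([make_sample(233), make_sample(233)])
```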
- encode_from_prestacked(obs: Any) -> jax.numpy.ndarray
Encode a batch of observations from pre-stacked arrays.
Accepts a PreStackedObservations instance whose array fields have already been extracted from ObservationBlock objects outside the JAX JIT boundary. Only the vmapped XLA computation runs here, so the training step can be compiled once and reused across all batches.
Parameters
- obs:
Pre-stacked observation container with stacked_blocks (one (values, uncertainties, coordinates, mask, log_scale, absolute_values) tuple per block; each array has shape (batch_size, n_wl) except the scalar log_scale arrays) and modalities (a static tuple of modality value strings).
Returns
- jnp.ndarray
Float32 array of shape (batch_size, embedding_dim).
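The pre-stacked layout described for `obs` can be built in plain Python before any jitted call, as sketched below. Under this assumption the compiled function sees only fixed-shape arrays plus a hashable `modalities` tuple. The dict-based blocks, the `FIELDS` constant, and the helpers `prestack` and `make_block` are hypothetical. The constructor of `PreStackedObservations` is not documented here, so the sketch returns a plain `(stacked_blocks, modalities)` pair, and the field contents are placeholders.

```python
import numpy as np

FIELDS = ("values", "uncertainties", "coordinates", "mask", "log_scale", "absolute_values")


def prestack(blocks_batch):
    """Hypothetical: build (stacked_blocks, modalities) outside any jitted code."""
    n_blocks = len(blocks_batch[0])
    stacked_blocks = tuple(
        # One 6-tuple per block; each field is stacked along a new batch axis.
        tuple(np.stack([sample[j][name] for sample in blocks_batch]) for name in FIELDS)
        for j in range(n_blocks)
    )
    # Plain Python strings: these could be treated as static under jax.jit.
    modalities = tuple(block["modality"] for block in blocks_batch[0])
    return stacked_blocks, modalities


def make_block(n_wl, rng):
    # A dict standing in for one pre-processed spectral block (placeholder data).
    values = rng.normal(size=n_wl).astype(np.float32)
    return {"modality": "spectrum", "values": values,
            "uncertainties": np.full(n_wl, 0.1, dtype=np.float32),
            "coordinates": np.linspace(1.0, 5.0, n_wl, dtype=np.float32),
            "mask": np.zeros(n_wl, dtype=bool),
            "log_scale": np.float32(-15.0),  # one scalar per sample
            "absolute_values": values.copy()}


rng = np.random.default_rng(3)
batch = [[make_block(233, rng)] for _ in range(8)]  # 8 samples, 1 block each
stacked_blocks, modalities = prestack(batch)
```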