petitRADTRANS.sbi.encoders
==========================

.. py:module:: petitRADTRANS.sbi.encoders

.. autoapi-nested-parse::

   Learned observation encoders that produce fixed-width conditioning embeddings for simulation-based inference (SBI).



Classes
-------

.. autoapisummary::

   petitRADTRANS.sbi.encoders.SpectralPatchEncoder
   petitRADTRANS.sbi.encoders.SpectralConv1DEncoder
   petitRADTRANS.sbi.encoders.PhotometryPointEncoder
   petitRADTRANS.sbi.encoders.DatasetSetAggregator
   petitRADTRANS.sbi.encoders.HierarchicalObservationEncoder


Module Contents
---------------

.. py:class:: SpectralPatchEncoder(embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 96, key: jax.Array | None = None)

   Bases: :py:obj:`equinox.Module`


   Learned spectral encoder based on patch summaries and MLP pooling.

   Parameters
   ----------
   embedding_dim:
       Size of the spectral embedding returned for one observation block.
   patch_size:
       Number of spectral points summarized together before MLP encoding.
   hidden_dim:
       Hidden width of the internal summary MLP.
   key:
       Optional JAX random key used for weight initialization.

   Notes
   -----
   The encoder summarizes values, uncertainties, and coordinates patch by
   patch instead of building a fully convolutional representation. This
   keeps the implementation lightweight and shape-agnostic: unlike
   :py:class:`SpectralConv1DEncoder`, it does not require a fixed number of
   spectral points.
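
   Examples
   --------
   A minimal NumPy sketch of the patch-summary idea, assuming a
   hypothetical helper ``patch_summaries`` that reduces each patch to a few
   masked statistics (value mean and spread, mean uncertainty, mean
   coordinate, valid fraction) and a hypothetical ``pool`` that combines
   patches with mean and max pooling. The real encoder passes its patch
   summaries through ``patch_mlp`` instead, and the mask convention used
   here (``True`` marks an invalid point) is an assumption.

   .. code-block:: python

      import numpy as np


      def patch_summaries(values, uncertainties, coordinates, mask, patch_size=32):
          """Reduce a 1-D spectrum to per-patch statistics of shape (n_patches, 5)."""
          n = values.shape[0]
          n_patches = -(-n // patch_size)  # ceiling division
          pad = n_patches * patch_size - n
          # Pad the tail so the spectrum splits into whole patches;
          # padded points are flagged invalid and ignored below.
          v = np.pad(values, (0, pad)).reshape(n_patches, patch_size)
          u = np.pad(uncertainties, (0, pad)).reshape(n_patches, patch_size)
          c = np.pad(coordinates, (0, pad)).reshape(n_patches, patch_size)
          invalid = np.pad(mask.astype(bool), (0, pad), constant_values=True)
          w = (~invalid).reshape(n_patches, patch_size).astype(np.float64)
          count = np.maximum(w.sum(axis=1), 1.0)  # guard against empty patches
          mean_v = (w * v).sum(axis=1) / count
          std_v = np.sqrt((w * (v - mean_v[:, None]) ** 2).sum(axis=1) / count)
          mean_u = (w * u).sum(axis=1) / count
          mean_c = (w * c).sum(axis=1) / count
          frac_valid = w.sum(axis=1) / patch_size
          return np.stack([mean_v, std_v, mean_u, mean_c, frac_valid], axis=1)


      def pool(summaries):
          """Length-independent pooling: concatenate mean and max over patches."""
          return np.concatenate([summaries.mean(axis=0), summaries.max(axis=0)])


      rng = np.random.default_rng(0)
      for n in (100, 233):
          s = patch_summaries(rng.normal(size=n), np.full(n, 0.1),
                              np.linspace(1.0, 5.0, n), np.zeros(n, dtype=bool))
          print(n, s.shape, pool(s).shape)

   Spectra of different lengths yield different numbers of patches, but the
   pooled vector always has the same width, which is what lets a
   patch-based design stay shape-agnostic.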


   .. py:attribute:: patch_mlp
      :type:  equinox.nn.MLP


   .. py:attribute:: amplitude_projection
      :type:  equinox.nn.MLP


   .. py:attribute:: branch_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: output_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: scale_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: embedding_dim
      :type:  int


   .. py:attribute:: patch_size
      :type:  int


   .. py:method:: encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray

      Encode one spectral observation block.

      Parameters
      ----------
      block:
          Spectrum-like observation block containing values and optional
          uncertainties, coordinates, and masks.

      Returns
      -------
      jnp.ndarray
          Dense fixed-width embedding for the supplied spectral block.



   .. py:method:: _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray

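      Encode pre-processed block arrays without an ``ObservationBlock`` wrapper.

      Because it operates on plain arrays, this method can be vmapped.

      Parameters
      ----------
      values, uncertainties, coordinates:
          Float32 1-D arrays already produced by ``_as_vector``.
          ``uncertainties`` and ``coordinates`` may have length 0 to indicate
          that they are absent.
      mask:
          Boolean 1-D array already produced by ``_safe_mask``.
      log_median_flux:
          Optional scalar absolute-flux feature derived from the median of the
          raw spectrum.
      absolute_values:
          Optional 1-D array with the same length as ``values``.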

      Returns
      -------
      jnp.ndarray
          Dense spectral embedding of shape ``(embedding_dim,)``.



.. py:class:: SpectralConv1DEncoder(embedding_dim: int = 64, n_wavelengths: int = 233, hidden_channels: tuple[int, int] = (32, 64), key: jax.Array | None = None)

   Bases: :py:obj:`equinox.Module`


   Learned spectral encoder using 1D convolutions to retain positional information.

   Parameters
   ----------
   embedding_dim:
       Size of the spectral embedding returned for one observation block.
   n_wavelengths:
       Fixed number of spectral points per observation block. Must match the
       observation schema of the SBI task.
   hidden_channels:
       Pair of intermediate channel widths used by the convolutional layers.
   key:
       Optional JAX random key used for weight initialization.


   .. py:attribute:: conv1
      :type:  equinox.nn.Conv1d


   .. py:attribute:: conv2
      :type:  equinox.nn.Conv1d


   .. py:attribute:: conv3
      :type:  equinox.nn.Conv1d


   .. py:attribute:: amplitude_projection
      :type:  equinox.nn.MLP


   .. py:attribute:: pool_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: branch_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: output_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: scale_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: embedding_dim
      :type:  int


   .. py:attribute:: n_wavelengths
      :type:  int


   .. py:method:: _forward_conv(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray

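      Run the 1D convolutional stack on pre-processed arrays.

      Parameters
      ----------
      values, uncertainties, coordinates:
          Float32 1-D arrays of length ``n_wavelengths``.
      mask:
          Boolean 1-D mask in which ``True`` marks an invalid point.
      log_median_flux:
          Optional scalar absolute-flux feature derived from the median of the
          raw spectrum.
      absolute_values:
          Optional 1-D array of length ``n_wavelengths``.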

      Returns
      -------
      jnp.ndarray
          Embedding of shape ``(embedding_dim,)``.
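
      Examples
      --------
      A NumPy sketch of how per-point features might be arranged for a
      convolutional stack. ``equinox.nn.Conv1d`` operates on unbatched,
      channels-first arrays of shape ``(in_channels, length)``, so the
      hypothetical ``build_conv_input`` stacks the features as channels,
      zeroes invalid points (``True`` in ``mask``), and adds a validity
      channel. The four-channel layout is an assumption, and
      ``conv1d_valid`` is a plain reference cross-correlation, not the
      Equinox layer.

      .. code-block:: python

         import numpy as np


         def build_conv_input(values, uncertainties, coordinates, mask):
             """Stack features into a channels-first array of shape (4, n_points)."""
             valid = (~mask).astype(np.float32)
             return np.stack([values * valid, uncertainties * valid,
                              coordinates, valid]).astype(np.float32)


         def conv1d_valid(x, kernel):
             """'Valid' 1-D cross-correlation: x (C_in, L), kernel (C_out, C_in, K)."""
             c_out, _, k = kernel.shape
             length = x.shape[1] - k + 1
             out = np.zeros((c_out, length), dtype=np.float32)
             for i in range(length):
                 # Each output column sees only a local window, so features
                 # stay aligned with their position on the wavelength axis.
                 out[:, i] = np.tensordot(kernel, x[:, i:i + k], axes=([1, 2], [0, 1]))
             return out


         n_wavelengths = 233
         mask = np.zeros(n_wavelengths, dtype=bool)
         mask[:10] = True  # pretend the first ten points are invalid
         x = build_conv_input(np.ones(n_wavelengths), np.full(n_wavelengths, 0.1),
                              np.linspace(1.0, 5.0, n_wavelengths), mask)
         kernel = np.random.default_rng(0).normal(size=(32, 4, 5))
         print(x.shape, conv1d_valid(x, kernel).shape)

      Unlike patch pooling, the convolution output keeps one column per
      window position, which is how a convolutional encoder can retain
      positional information along the spectrum.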



   .. py:method:: _pad_to_fixed(arr: jax.numpy.ndarray) -> jax.numpy.ndarray

      Pad or truncate a 1-D array to ``n_wavelengths``.
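
      Examples
      --------
      A NumPy sketch of the pad-or-truncate behaviour, written as a
      hypothetical standalone function. Padding at the end with a constant
      is an assumption; a mask would typically be padded with ``True`` so
      that the added points count as invalid.

      .. code-block:: python

         import numpy as np


         def pad_to_fixed(arr, n_wavelengths, fill_value=0.0):
             """Pad (at the end) or truncate a 1-D array to exactly n_wavelengths entries."""
             arr = np.asarray(arr)
             if arr.shape[0] >= n_wavelengths:
                 return arr[:n_wavelengths]
             return np.pad(arr, (0, n_wavelengths - arr.shape[0]),
                           constant_values=fill_value)


         print(pad_to_fixed(np.arange(5.0), 8))    # padded with zeros
         print(pad_to_fixed(np.arange(10.0), 8))   # truncated
         print(pad_to_fixed(np.zeros(5, dtype=bool), 8, fill_value=True))

      The branch depends only on the array shape, which is static under
      ``jax.jit``, so the same structure would also trace in JAX.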



   .. py:method:: encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray

      Encode one spectral observation block.

      Parameters
      ----------
      block:
          Spectrum-like observation block.

      Returns
      -------
      jnp.ndarray
          Dense embedding of shape ``(embedding_dim,)``.



   .. py:method:: _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray, log_median_flux: jax.numpy.ndarray | None = None, absolute_values: jax.numpy.ndarray | None = None) -> jax.numpy.ndarray

      Encode pre-processed block arrays without an ``ObservationBlock`` wrapper.

      Because it operates on plain arrays, this method can be vmapped.

      Parameters
      ----------
      values, uncertainties, coordinates:
          Float32 1-D arrays already produced by ``_as_vector``.
      mask:
          Boolean 1-D array already produced by ``_safe_mask``.
      log_median_flux:
          Optional scalar absolute-flux feature derived from the median of the
          raw spectrum.
      absolute_values:
          Optional 1-D array with the same length as ``values``.

      Returns
      -------
      jnp.ndarray
          Dense spectral embedding of shape ``(embedding_dim,)``.



.. py:class:: PhotometryPointEncoder(embedding_dim: int = 64, hidden_dim: int = 96, key: jax.Array | None = None)

   Bases: :py:obj:`equinox.Module`


   Learned photometry encoder using per-point MLP features and pooling.

   Parameters
   ----------
   embedding_dim:
       Size of the returned photometric embedding.
   hidden_dim:
       Hidden width of the per-point MLP.
   key:
       Optional JAX random key used for initialization.

   Notes
   -----
   Each photometric point is represented by its value, uncertainty,
   coordinate, and an inferred width feature. These per-point features are
   encoded by a shared MLP and then combined by permutation-invariant
   pooling.
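
   Examples
   --------
   A Deep-Sets-style NumPy sketch of per-point encoding followed by
   permutation-invariant pooling. The hypothetical ``encode_photometry``
   applies one shared two-layer MLP to every point's (value, uncertainty,
   coordinate, width) features and averages the valid outputs. Widths are
   passed in directly because how the encoder infers them is not specified
   here, and the masked-mean pooling and ``True``-means-invalid mask
   convention are assumptions.

   .. code-block:: python

      import numpy as np


      def mlp(x, w1, b1, w2, b2):
          """Two-layer MLP applied row-wise: (n_points, d_in) -> (n_points, d_out)."""
          return np.tanh(x @ w1 + b1) @ w2 + b2


      def encode_photometry(values, uncertainties, coordinates, widths, mask, params):
          """Shared per-point MLP followed by a masked mean over points."""
          features = np.stack([values, uncertainties, coordinates, widths], axis=1)
          h = mlp(features, *params)
          valid = (~mask).astype(h.dtype)[:, None]
          return (h * valid).sum(axis=0) / max(valid.sum(), 1.0)


      rng = np.random.default_rng(0)
      params = (rng.normal(size=(4, 16)), np.zeros(16),
                rng.normal(size=(16, 8)), np.zeros(8))
      vals, errs = rng.normal(size=5), np.full(5, 0.05)
      wl, widths = np.array([1.2, 1.6, 2.2, 3.6, 4.5]), np.full(5, 0.2)
      mask = np.array([False, False, True, False, False])
      emb = encode_photometry(vals, errs, wl, widths, mask, params)
      perm = rng.permutation(5)
      emb_perm = encode_photometry(vals[perm], errs[perm], wl[perm],
                                   widths[perm], mask[perm], params)
      print(emb.shape, np.allclose(emb, emb_perm))

   Because every point goes through the same MLP and the pooling is a
   symmetric masked mean, reordering the photometric points leaves the
   embedding unchanged.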


   .. py:attribute:: point_mlp
      :type:  equinox.nn.MLP


   .. py:attribute:: output_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: embedding_dim
      :type:  int


   .. py:method:: encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray

      Encode one photometric observation block.

      Parameters
      ----------
      block:
          Photometry-like observation block to encode.

      Returns
      -------
      jnp.ndarray
          Dense embedding for the supplied photometric block.



   .. py:method:: _encode_block_raw(values: jax.numpy.ndarray, uncertainties: jax.numpy.ndarray, coordinates: jax.numpy.ndarray, mask: jax.numpy.ndarray) -> jax.numpy.ndarray

      Encode pre-processed block arrays without an ``ObservationBlock`` wrapper.

      Because it operates on plain arrays, this method can be vmapped.

      Parameters
      ----------
      values, uncertainties, coordinates:
          Float32 1-D arrays already produced by ``_as_vector``.
          ``uncertainties`` and ``coordinates`` may have length 0 to indicate
          that they are absent.
      mask:
          Boolean 1-D array already produced by ``_safe_mask``.

      Returns
      -------
      jnp.ndarray
          Dense photometric embedding of shape ``(embedding_dim,)``.



.. py:class:: DatasetSetAggregator(embedding_dim: int = 128, hidden_dim: int = 128, key: jax.Array | None = None)

   Bases: :py:obj:`equinox.Module`


   Learned permutation-invariant aggregator over block embeddings.

   Parameters
   ----------
   embedding_dim:
       Target dimensionality of the aggregated observation embedding.
   hidden_dim:
       Hidden width of the block projection MLP.
   key:
       Optional JAX random key used for initialization.


   .. py:attribute:: block_projection
      :type:  equinox.nn.MLP


   .. py:attribute:: output_projection
      :type:  equinox.nn.Linear


   .. py:attribute:: embedding_dim
      :type:  int


   .. py:method:: aggregate(block_embeddings: list[jax.numpy.ndarray]) -> jax.numpy.ndarray

      Aggregate a list of block embeddings into one observation embedding.

      Parameters
      ----------
      block_embeddings:
          Encoded observation blocks for one retrieval target.

      Returns
      -------
      jnp.ndarray
          Permutation-invariant aggregated embedding.
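
      Examples
      --------
      A NumPy sketch of a permutation-invariant aggregator, assuming all
      block embeddings share one width. The hypothetical ``aggregate``
      applies one shared ReLU projection to each block (standing in for
      ``block_projection``), averages over blocks, and maps the mean to the
      output width (standing in for ``output_projection``). Mean pooling is
      an assumption.

      .. code-block:: python

         import numpy as np


         def aggregate(block_embeddings, w_block, w_out):
             """Project each block with shared weights, average, then project out."""
             stacked = np.stack(block_embeddings)             # (n_blocks, d_block)
             projected = np.maximum(stacked @ w_block, 0.0)   # shared ReLU projection
             pooled = projected.mean(axis=0)                  # order-independent pooling
             return pooled @ w_out                            # (embedding_dim,)


         rng = np.random.default_rng(0)
         w_block = rng.normal(size=(64, 128))
         w_out = rng.normal(size=(128, 128))
         blocks = [rng.normal(size=64) for _ in range(3)]
         out = aggregate(blocks, w_block, w_out)
         print(out.shape, np.allclose(out, aggregate(blocks[::-1], w_block, w_out)))

      The output width does not depend on how many blocks are supplied, so
      the same aggregator serves observations with one or several blocks.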



.. py:class:: HierarchicalObservationEncoder(embedding_dim: int = 128, spectrum_embedding_dim: int = 64, photometry_embedding_dim: int = 64, patch_size: int = 32, hidden_dim: int = 128, spectrum_encoder_type: str = 'conv1d', n_wavelengths: int = 233, key: jax.Array | None = None)

   Bases: :py:obj:`equinox.Module`, :py:obj:`petitRADTRANS.sbi.observation.ObservationEncoder`


   Learned hierarchical encoder that dispatches each observation block to a modality-specific sub-encoder.

   Parameters
   ----------
   embedding_dim:
       Size of the final joint observation embedding.
   spectrum_embedding_dim:
       Intermediate embedding size used by the spectral sub-encoder.
   photometry_embedding_dim:
       Intermediate embedding size used by the photometry sub-encoder.
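   patch_size:
       Patch size used by :py:class:`SpectralPatchEncoder`; ignored by the
       convolutional encoder.
   hidden_dim:
       Hidden width shared across the component encoders and aggregator.
   spectrum_encoder_type:
       Spectral sub-encoder to build. The default ``'conv1d'`` selects
       :py:class:`SpectralConv1DEncoder`; the alternative is
       :py:class:`SpectralPatchEncoder`.
   n_wavelengths:
       Fixed number of spectral points per block, used by
       :py:class:`SpectralConv1DEncoder` when it is selected.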
   key:
       Optional JAX random key used to initialize all submodules.

   Notes
   -----
   Spectrum and photometry blocks are handled by dedicated sub-encoders, and
   the resulting block embeddings are merged by a permutation-invariant
   aggregator. Blocks of any other modality currently fall back to their raw
   values, resized to a fixed-width vector.
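
   Examples
   --------
   A NumPy sketch of modality-aware dispatch with a fallback. The modality
   strings, the ``encoders`` mapping, and linear interpolation as the
   resizing rule are hypothetical choices for illustration; the class only
   documents that unsupported modalities fall back to resized raw-value
   vectors.

   .. code-block:: python

      import numpy as np


      def resize_raw(values, width):
          """Fallback: linearly resample raw values onto ``width`` evenly spaced points."""
          values = np.asarray(values, dtype=np.float32)
          if values.size == 0:
              return np.zeros(width, dtype=np.float32)
          src = np.linspace(0.0, 1.0, values.size)
          dst = np.linspace(0.0, 1.0, width)
          return np.interp(dst, src, values).astype(np.float32)


      def encode_block(modality, values, encoders, width):
          """Use the modality's encoder if one is registered, else resize_raw."""
          encoder = encoders.get(modality)
          if encoder is None:
              return resize_raw(values, width)
          return encoder(values)


      encoders = {
          "spectrum": lambda v: np.full(4, np.mean(v), dtype=np.float32),
          "photometry": lambda v: np.full(4, np.sum(v), dtype=np.float32),
      }
      print(encode_block("spectrum", np.arange(3.0), encoders, 4))
      print(encode_block("unknown", np.array([0.0, 3.0]), encoders, 4))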


   .. py:attribute:: spectrum_encoder
      :type:  SpectralPatchEncoder | SpectralConv1DEncoder


   .. py:attribute:: photometry_encoder
      :type:  PhotometryPointEncoder


   .. py:attribute:: aggregator
      :type:  DatasetSetAggregator


   .. py:attribute:: embedding_dim
      :type:  int


   .. py:method:: _encode_block(block: petitRADTRANS.sbi.observation.ObservationBlock) -> jax.numpy.ndarray

      Encode one observation block with modality-aware dispatch.

      Parameters
      ----------
      block:
          Observation block to encode.

      Returns
      -------
      jnp.ndarray
          Fixed-width embedding for the supplied block.



   .. py:method:: encode(blocks: list[petitRADTRANS.sbi.observation.ObservationBlock]) -> petitRADTRANS.sbi.observation.EncodedObservation

      Encode one structured observation made of multiple blocks.

      Parameters
      ----------
      blocks:
          Observation blocks associated with one target system.

      Returns
      -------
      EncodedObservation
          Aggregated embedding and lightweight metadata describing the block set.



   .. py:method:: encode_stacked_batch(blocks_batch: list[list[petitRADTRANS.sbi.observation.ObservationBlock]]) -> jax.numpy.ndarray

      Encode a batch of identically structured observations using ``vmap``.

      All observations must share the same block structure (the same number
      of blocks and the same array shapes per block). This is always the
      case for SBI tasks in which every observation is produced by the same
      forward model.

      Parameters
      ----------
      blocks_batch:
          Outer list indexes samples; inner list holds the per-block
          observation data for one sample.

      Returns
      -------
      jnp.ndarray
          Float32 array of shape ``(n_samples, embedding_dim)``.
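
      Examples
      --------
      A NumPy sketch of the stacking step that makes a ``vmap`` over
      samples possible. Plain 1-D arrays stand in for the
      ``ObservationBlock`` values, and ``stack_batch`` is a hypothetical
      helper: it transposes the sample-major nested list into one
      ``(n_samples, n_points)`` array per block position and rejects batches
      whose block structure differs.

      .. code-block:: python

         import numpy as np


         def stack_batch(blocks_batch):
             """Return one (n_samples, n_points) array per block position."""
             n_blocks = len(blocks_batch[0])
             if any(len(blocks) != n_blocks for blocks in blocks_batch):
                 raise ValueError("all samples must have the same number of blocks")
             stacked = []
             for j in range(n_blocks):
                 column = [np.asarray(blocks[j], dtype=np.float32) for blocks in blocks_batch]
                 if len({c.shape for c in column}) != 1:
                     raise ValueError(f"block {j} has inconsistent shapes across samples")
                 stacked.append(np.stack(column))
             return stacked


         rng = np.random.default_rng(0)
         # Three samples, each with a 233-point spectrum and 4 photometric points.
         batch = [[rng.normal(size=233), rng.normal(size=4)] for _ in range(3)]
         print([a.shape for a in stack_batch(batch)])

      Each stacked array has the sample axis first, which is the default
      mapped axis of ``jax.vmap`` (``in_axes=0``), so a per-sample encoding
      function can be mapped over it without a Python loop.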



   .. py:method:: encode_from_prestacked(obs: Any) -> jax.numpy.ndarray

      Encode a batch of observations from pre-stacked arrays.

      Accepts a :class:`~petitRADTRANS.sbi.observation.PreStackedObservations`
      instance whose array fields have already been extracted from
      ``ObservationBlock`` objects outside the JAX JIT boundary. Only the
      vmapped XLA computation runs here, which lets the training step be
      compiled once and reused across all batches.

      Parameters
      ----------
      obs:
          Pre-stacked observation container. Its ``stacked_blocks`` field
          holds one ``(values, uncertainties, coordinates, mask, log_scale,
          absolute_values)`` tuple per block; every array in the tuple has
          shape ``(batch_size, n_wl)`` except ``log_scale``, which holds one
          scalar per sample. Its ``modalities`` field is a static tuple of
          modality value strings.

      Returns
      -------
      jnp.ndarray
          Float32 array of shape ``(batch_size, embedding_dim)``.



