petitRADTRANS.radtrans_core.linalg
==================================

.. py:module:: petitRADTRANS.radtrans_core.linalg


Attributes
----------

.. autoapisummary::

   petitRADTRANS.radtrans_core.linalg.TINIEST
   petitRADTRANS.radtrans_core.linalg.HUGE_SQRT
   petitRADTRANS.radtrans_core.linalg.GRAD_SAFE_DENOM
   petitRADTRANS.radtrans_core.linalg._USE_GPU


Functions
---------

.. autoapisummary::

   petitRADTRANS.radtrans_core.linalg.solve_tridiagonal_thomas
   petitRADTRANS.radtrans_core.linalg.solve_tridiagonal_thomas_batched
   petitRADTRANS.radtrans_core.linalg.factor_feautrier_thomas
   petitRADTRANS.radtrans_core.linalg.solve_feautrier_thomas_prefactored
   petitRADTRANS.radtrans_core.linalg.solve_tridiagonal_thomas_batched_feautrier_fused
   petitRADTRANS.radtrans_core.linalg.solve_tridiagonal_thomas_batched_fused
   petitRADTRANS.radtrans_core.linalg.solve_tridiagonal_pcr
   petitRADTRANS.radtrans_core.linalg.solve_tridiagonal_pcr_batched_feautrier_fused
   petitRADTRANS.radtrans_core.linalg.solve_feautrier_tridiagonal
   petitRADTRANS.radtrans_core.linalg.factor_feautrier
   petitRADTRANS.radtrans_core.linalg.solve_feautrier_prefactored
   petitRADTRANS.radtrans_core.linalg.linear_fit


Module Contents
---------------

.. py:data:: TINIEST

.. py:data:: HUGE_SQRT

.. py:data:: GRAD_SAFE_DENOM

.. py:data:: _USE_GPU

.. py:function:: solve_tridiagonal_thomas(a, b, c, d)

   Solves a tridiagonal system of equations using the Thomas algorithm.
   The system has the form Ax = d, where A is a tridiagonal matrix.
   The inputs are JAX arrays.
   The implementation is a JAX port of the naive Thomas algorithm (no pivoting); see
   https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm

   WARNING: This implementation is not numerically stable for all matrices.
   Use it only for diagonally dominant matrices; see
   https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm#Numerical_stability

   Args:
       a (jnp.ndarray): The lower diagonal of the matrix A.
       b (jnp.ndarray): The main diagonal of the matrix A.
       c (jnp.ndarray): The upper diagonal of the matrix A.
       d (jnp.ndarray): The right-hand side of the equation.

   Returns:
       jnp.ndarray: The solution of the system.
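
To make the forward-elimination and back-substitution steps concrete, here is a minimal NumPy sketch of the textbook Thomas algorithm. It is not the module's JAX code: the name ``thomas_solve`` is hypothetical, and the convention that all four arrays have the same length (with ``a[0]`` and ``c[-1]`` ignored) is an assumption borrowed from the batched solver's documentation.

```python
import numpy as np


def thomas_solve(a, b, c, d):
    """Solve one tridiagonal system A x = d (a[0] and c[-1] are ignored)."""
    a, b, c, d = (np.asarray(v, dtype=float) for v in (a, b, c, d))
    n = b.shape[0]
    c_prime = np.empty(n)
    d_prime = np.empty(n)

    # Forward elimination: remove the lower diagonal row by row.
    c_prime[0] = c[0] / b[0]
    d_prime[0] = d[0] / b[0]
    for k in range(1, n):
        denom = b[k] - a[k] * c_prime[k - 1]
        c_prime[k] = c[k] / denom
        d_prime[k] = (d[k] - a[k] * d_prime[k - 1]) / denom

    # Back substitution: start from the last row and walk upwards.
    x = np.empty(n)
    x[-1] = d_prime[-1]
    for k in range(n - 2, -1, -1):
        x[k] = d_prime[k] - c_prime[k] * x[k + 1]
    return x


# Diagonally dominant test system with a known solution.
a = np.array([0.0, 1.0, 1.0, 1.0])
b = np.array([4.0, 4.0, 4.0, 4.0])
c = np.array([1.0, 1.0, 1.0, 0.0])
x_true = np.array([1.0, 2.0, 3.0, 4.0])
A = np.diag(b) + np.diag(a[1:], -1) + np.diag(c[:-1], 1)
x = thomas_solve(a, b, c, A @ x_true)
```

The test matrix is diagonally dominant, which is the case the warning above recommends.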


.. py:function:: solve_tridiagonal_thomas_batched(a, b, c, d)

   Solves ``batch`` independent tridiagonal systems simultaneously.

   Parameters
   ----------
   a, b, c, d : (batch, n_layers)
       Lower diagonal (a[:,0] unused), main diagonal,
       upper diagonal (c[:,-1] unused), RHS.

   Returns
   -------
   x : (batch, n_layers)
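
The batched variant can be pictured as the same recurrence applied to whole columns at once. The NumPy sketch below illustrates that idea under the documented ``(batch, n_layers)`` layout. The name ``thomas_solve_batched`` and the random test data are hypothetical, and the actual JAX implementation may be organised differently.

```python
import numpy as np


def thomas_solve_batched(a, b, c, d):
    """a, b, c, d: (batch, n_layers); returns x with the same shape."""
    a, b, c, d = (np.asarray(v, dtype=float) for v in (a, b, c, d))
    batch, n = b.shape
    # Work layer-major so each step touches one contiguous (batch,) row.
    c_prime = np.empty((n, batch))
    d_prime = np.empty((n, batch))
    c_prime[0] = c[:, 0] / b[:, 0]
    d_prime[0] = d[:, 0] / b[:, 0]
    for k in range(1, n):
        inv = 1.0 / (b[:, k] - a[:, k] * c_prime[k - 1])
        c_prime[k] = c[:, k] * inv
        d_prime[k] = (d[:, k] - a[:, k] * d_prime[k - 1]) * inv

    x = np.empty((n, batch))
    x[-1] = d_prime[-1]
    for k in range(n - 2, -1, -1):
        x[k] = d_prime[k] - c_prime[k] * x[k + 1]
    return x.T


rng = np.random.default_rng(0)
batch, n_layers = 3, 5
a, c = rng.uniform(-1.0, 1.0, size=(2, batch, n_layers))
b = 4.0 + rng.uniform(0.0, 1.0, size=(batch, n_layers))  # diagonally dominant
d = rng.normal(size=(batch, n_layers))
x = thomas_solve_batched(a, b, c, d)
```

The loop over layers stays sequential. Only the batch axis is vectorised.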


.. py:function:: factor_feautrier_thomas(a, b, c, n_g, n_freq_bins, n_angles, n_layers)

   Pre-factor the LHS of the Feautrier tridiagonal system.

   Performs the forward elimination that depends only on (a, b, c),
   caching the modified upper diagonal and reciprocal denominators so
   that subsequent solves with different RHS vectors only need the
   cheaper RHS-only forward sweep and backward substitution.

   Parameters
   ----------
   a, b, c : jax.Array, shape (batch, n_layers)
       Tridiagonal coefficients where batch = n_g * n_freq_bins * n_angles.
   n_g, n_freq_bins, n_angles, n_layers : int
       Static shape parameters.

   Returns
   -------
   c_prime : jax.Array, shape (n_layers, batch)
       Modified upper diagonal from Thomas forward elimination.
       ``c_prime[n_layers-1]`` is unused (padding) but kept for shape
       consistency.
   inv_denom : jax.Array, shape (n_layers, batch)
       Reciprocal of the elimination denominators.  Includes the surface
       row at index ``n_layers - 1``.
   a_times_inv_denom : jax.Array, shape (n_layers, batch)
       ``a[:, k] * inv_denom[k]`` pre-multiplied for the RHS sweep.
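
The split between LHS-only and RHS-only work can be sketched in NumPy as follows. This is an illustration of the idea, not the module's code: ``factor_thomas`` and ``solve_prefactored`` are hypothetical names, the shape parameters are omitted, and the RHS is passed as a plain ``(batch, n_layers)`` array rather than the ``d_inner`` / ``d_surf`` pair used by the real solver.

```python
import numpy as np


def factor_thomas(a, b, c):
    """a, b, c: (batch, n_layers) -> three arrays of shape (n_layers, batch)."""
    batch, n = b.shape
    c_prime = np.zeros((n, batch))       # last row stays zero (padding)
    inv_denom = np.empty((n, batch))
    for k in range(n):
        prev = c_prime[k - 1] if k > 0 else 0.0
        inv_denom[k] = 1.0 / (b[:, k] - a[:, k] * prev)
        if k < n - 1:
            c_prime[k] = c[:, k] * inv_denom[k]
    a_times_inv_denom = a.T * inv_denom  # row 0 is never used
    return c_prime, inv_denom, a_times_inv_denom


def solve_prefactored(c_prime, inv_denom, a_times_inv_denom, d):
    """d: (batch, n_layers). Only RHS-dependent work happens here."""
    n = d.shape[1]
    y = np.empty_like(c_prime)
    y[0] = d[:, 0] * inv_denom[0]
    for k in range(1, n):                # RHS-only forward sweep
        y[k] = d[:, k] * inv_denom[k] - a_times_inv_denom[k] * y[k - 1]
    x = np.empty_like(y)
    x[-1] = y[-1]
    for k in range(n - 2, -1, -1):       # backward substitution
        x[k] = y[k] - c_prime[k] * x[k + 1]
    return x.T


rng = np.random.default_rng(0)
batch, n_layers = 4, 6
a, c = rng.uniform(-1.0, 1.0, size=(2, batch, n_layers))
b = 4.0 + rng.uniform(0.0, 1.0, size=(batch, n_layers))
factors = factor_thomas(a, b, c)         # done once
d1 = rng.normal(size=(batch, n_layers))
d2 = rng.normal(size=(batch, n_layers))
x1 = solve_prefactored(*factors, d1)     # reused for every new RHS
x2 = solve_prefactored(*factors, d2)
```

Each prefactored solve then needs no divisions, because the reciprocals are already stored.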


.. py:function:: solve_feautrier_thomas_prefactored(c_prime, inv_denom, a_times_inv_denom, d_inner, d_surf, angle_weight_a, n_g, n_freq_bins, n_angles, n_layers)

   Solve the Feautrier tridiagonal system using pre-factored LHS arrays.

   Uses the output of :func:`factor_feautrier_thomas` to skip the LHS
   forward elimination, performing only the RHS forward sweep and backward
   substitution.

   Parameters
   ----------
   c_prime : jax.Array, shape (n_layers, batch)
       Modified upper diagonal from :func:`factor_feautrier_thomas`.
   inv_denom : jax.Array, shape (n_layers, batch)
       Reciprocal denominators from :func:`factor_feautrier_thomas`.
   a_times_inv_denom : jax.Array, shape (n_layers, batch)
       Pre-multiplied ``a * inv_denom`` from :func:`factor_feautrier_thomas`.
   d_inner : jax.Array, shape (n_g, n_freq_bins, n_layers - 1)
       Angle-independent RHS entries for layers 0 .. n_layers-2.
   d_surf : jax.Array, shape (n_g, n_freq_bins, n_angles)
       Boundary RHS entry for the surface layer.
   angle_weight_a : jax.Array, shape (n_angles,)
       Angular quadrature weights for mean-intensity accumulation.
   n_g, n_freq_bins, n_angles, n_layers : int
       Static shape parameters.

   Returns
   -------
   I_H_top : jax.Array, shape (n_g, n_freq_bins, n_angles)
   x_last : jax.Array, shape (n_g, n_freq_bins, n_angles)
   x_nm2 : jax.Array, shape (n_g, n_freq_bins, n_angles)
   J_bol_gfl : jax.Array, shape (n_g, n_freq_bins, n_layers)


.. py:function:: solve_tridiagonal_thomas_batched_feautrier_fused(a, b, c, d_inner, d_surf, angle_weight_a, n_g, n_freq_bins, n_angles, n_layers)

   Solves the Feautrier tridiagonal systems without materialising the full
   (n_g, n_freq_bins, n_angles, n_layers) solution tensor.

   Parameters
   ----------
   a, b, c : (batch, n_layers)
       Tridiagonal coefficients where batch = n_g * n_freq_bins * n_angles.
   d_inner : (n_g, n_freq_bins, n_layers - 1)
       Angle-independent RHS entries for layers 0 .. n_layers-2.
   d_surf : (n_g, n_freq_bins, n_angles)
       Boundary RHS entry for the surface layer.
   angle_weight_a : (n_angles,)
       Angular quadrature weights for mean-intensity accumulation.
   n_g, n_freq_bins, n_angles, n_layers : int
       Static shape parameters.

   Returns
   -------
   I_H_top : (n_g, n_freq_bins, n_angles)
   x_last : (n_g, n_freq_bins, n_angles)
   x_nm2 : (n_g, n_freq_bins, n_angles)
   J_bol_gfl : (n_g, n_freq_bins, n_layers)
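
To see what the compact ``d_inner`` / ``d_surf`` inputs stand for, the NumPy sketch below expands them into the dense ``(batch, n_layers)`` RHS that a non-fused solver would need. The expansion is an assumption inferred from the parameter descriptions: the flattening order (angle index fastest) is taken from the ``mu_weight_flat`` description of ``solve_tridiagonal_thomas_batched_fused``, and the random data are placeholders.

```python
import numpy as np

n_g, n_freq_bins, n_angles, n_layers = 2, 3, 4, 5
rng = np.random.default_rng(0)
d_inner = rng.normal(size=(n_g, n_freq_bins, n_layers - 1))
d_surf = rng.normal(size=(n_g, n_freq_bins, n_angles))

d_gfal = np.empty((n_g, n_freq_bins, n_angles, n_layers))
# Layers 0 .. n_layers-2: the same value for every angle.
d_gfal[..., :-1] = d_inner[:, :, None, :]
# Surface layer (last index): one boundary value per angle.
d_gfal[..., -1] = d_surf

# Flatten (g, f, a) into the batch axis, angle index fastest.
d_flat = d_gfal.reshape(n_g * n_freq_bins * n_angles, n_layers)

# Element counts of the compact and the dense representation.
compact_size = d_inner.size + d_surf.size
dense_size = d_flat.size
```

For these toy shapes the compact inputs hold 48 values against 120 for the dense RHS.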


.. py:function:: solve_tridiagonal_thomas_batched_fused(a, b, c, d, mu_weight_flat, n_g, n_freq_bins, n_angles, n_layers)

   Solves ``batch`` independent tridiagonal systems simultaneously and fuses
   three downstream post-processing steps into the backward substitution pass,
   so that the full (batch, n_layers) solution array is never materialised.
   This is more memory efficient, but slower on CPU.

   Parameters
   ----------
   a, b, c, d : (batch, n_layers)
       Lower diagonal (a[:,0] unused), main diagonal,
       upper diagonal (c[:,-1] unused), RHS.
   mu_weight_flat : (batch,)
       Per-batch-element quadrature weight for J_bol accumulation.
       Equals mu_weight_a[angle_index] for each batch element, tiled over
       (g, f) so that mu_weight_flat[g*n_freq_bins*n_angles + f*n_angles + a]
       = mu_weight_a[a].
   n_g, n_freq_bins : int
       Used to reshape the J_bol accumulator.

   Returns
   -------
   I_H_top_flat : (batch,)
       Equal to ``-x[0]``, the top-boundary intensity.
   x_nm1_flat : (batch,)
       Equal to ``x[n-1]``, the solution at the surface layer.
   x_nm2_flat : (batch,)
       Equal to ``x[n-2]``, the solution at the penultimate layer.
   J_bol_gfl : (n_g, n_freq_bins, n_layers)
       Angle-quadrature-weighted mean intensity per (g,f,l):
       J_bol_gfl[g,f,l] = sum_a( x[g,f,a,l] * mu_weight_a[a] ).

   Design
   ------
   Forward sweep  : identical to the original solver.
   Backward pass  :
     Rather than emitting x[k] at every scan step (which would materialise the
     full (n-1, batch) solution), we carry only what is needed:
       • x_next      (batch,)        — rolling solution value, used to compute x[k]
       • J_bol_acc   (n_g*n_freq_bins, n_layers) — per-(g,f,l) weighted partial sum;
                     updated at each step via a dynamic .at[:, k].add() scatter.

     x[n-1] and x[n-2] are extracted before the scan; x[0] is the final
     value of x_next in the carry after the scan.

   Memory saved vs. the original solver + downstream uses
   -------------------------------------------------------
     Eliminated : I_J_flat  (batch, n_layers)  — full Thomas solution
                  x_rest    (n-1, batch)        — backward scan output
                  I_J_gfal  (g, f, a, l)        — reshape of I_J_flat
                  Total ≈ 2 × n_g × n_freq_bins × n_angles × n_layers float64
     Added      : J_bol_acc (n_g*n_freq_bins, n_layers) in backward scan carry
                  ≈ n_g × n_freq_bins × n_layers float64  (factor n_angles smaller)
     Net saving ≈ (2*n_angles − 1) × n_g × n_freq_bins × n_layers float64.
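
The weight tiling and the per-layer accumulation described above can be checked with a small NumPy sketch. Here ``x_flat`` is a random stand-in for the Thomas solution, used only for the comparison. In the fused solver each column would instead be the rolling ``x_next`` value, and the exact scan structure of the JAX code is not reproduced.

```python
import numpy as np

n_g, n_freq_bins, n_angles, n_layers = 2, 3, 4, 5
batch = n_g * n_freq_bins * n_angles
rng = np.random.default_rng(1)
x_flat = rng.normal(size=(batch, n_layers))   # stand-in for the full solution
mu_weight_a = rng.uniform(size=n_angles)

# Tile the angle weights over (g, f); the angle index varies fastest.
mu_weight_flat = np.tile(mu_weight_a, n_g * n_freq_bins)

# Reference: reshape the full solution and contract over the angle axis.
x_gfal = x_flat.reshape(n_g, n_freq_bins, n_angles, n_layers)
J_ref = np.einsum("gfal,a->gfl", x_gfal, mu_weight_a)

# Fused style: visit one layer at a time, from the surface upwards,
# and add its angle-weighted sum into column k of the accumulator.
J_bol_acc = np.zeros((n_g * n_freq_bins, n_layers))
for k in range(n_layers - 1, -1, -1):
    x_k = x_flat[:, k]
    J_bol_acc[:, k] += (x_k * mu_weight_flat).reshape(-1, n_angles).sum(axis=1)
J_bol_gfl = J_bol_acc.reshape(n_g, n_freq_bins, n_layers)

# Net saving in float64 elements, following the estimate above.
saving = (2 * n_angles - 1) * n_g * n_freq_bins * n_layers
```

For these toy shapes the estimate is 2 × 120 eliminated elements minus 30 added, that is 210.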


.. py:function:: solve_tridiagonal_pcr(a, b, c, d)

   Solves a batch of tridiagonal systems using the parallel cyclic reduction (PCR) algorithm.

   This implementation is based on a parallel CUDA implementation and is suitable for GPU execution.
   For the method to be numerically stable, the matrix should be diagonally dominant.

   The shapes of the input arrays should be (..., N), where (...) is one or more batch dimensions
   and N is the size of each tridiagonal system.

   Source:
   https://github.com/tanim72/15418-final-project

   Args:
       a (jnp.ndarray): The lower diagonal of the matrix A, shape (..., N).
                        a[..., 0] should be 0.
       b (jnp.ndarray): The main diagonal of the matrix A, shape (..., N).
       c (jnp.ndarray): The upper diagonal of the matrix A, shape (..., N).
                        c[..., -1] should be 0.
       d (jnp.ndarray): The right-hand side of the equation, shape (..., N).

   Returns:
       jnp.ndarray: The solution of the system, shape (..., N).
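
For readers unfamiliar with parallel cyclic reduction, the NumPy sketch below shows the standard recurrence. At each step every equation eliminates its neighbours at distance ``s``, and ``s`` doubles until all equations are decoupled. It is a generic illustration, not a transcription of the module's JAX code or of the linked CUDA project, and ``pcr_solve`` and ``_shift`` are hypothetical names.

```python
import numpy as np


def _shift(v, s, fill):
    """out[..., i] = v[..., i - s], with `fill` where i - s is out of range."""
    out = np.full_like(v, fill)
    if s > 0:
        out[..., s:] = v[..., :-s]
    else:
        out[..., :s] = v[..., -s:]
    return out


def pcr_solve(a, b, c, d):
    """Arrays of shape (..., N); a[..., 0] and c[..., -1] must be 0."""
    a, b, c, d = (np.array(v, dtype=float) for v in (a, b, c, d))
    n = b.shape[-1]
    s = 1
    while s < n:
        # Out-of-range neighbours act as identity rows (b = 1, rest 0).
        a_m, b_m, c_m, d_m = (_shift(v, s, f) for v, f in
                              ((a, 0.0), (b, 1.0), (c, 0.0), (d, 0.0)))
        a_p, b_p, c_p, d_p = (_shift(v, -s, f) for v, f in
                              ((a, 0.0), (b, 1.0), (c, 0.0), (d, 0.0)))
        alpha = -a / b_m   # eliminates the coupling to x[i - s]
        gamma = -c / b_p   # eliminates the coupling to x[i + s]
        a, b, c, d = (alpha * a_m,
                      b + alpha * c_m + gamma * a_p,
                      gamma * c_p,
                      d + alpha * d_m + gamma * d_p)
        s *= 2
    return d / b           # every equation now reads b[i] * x[i] = d[i]


rng = np.random.default_rng(0)
shape = (2, 3, 7)          # two batch dimensions, N = 7
a, c = rng.uniform(-1.0, 1.0, size=(2, *shape))
a[..., 0] = 0.0
c[..., -1] = 0.0
b = 4.0 + rng.uniform(0.0, 1.0, size=shape)   # diagonally dominant
d = rng.normal(size=shape)
x = pcr_solve(a, b, c, d)
```

Every update inside a step is independent across ``i``, which is what makes the method attractive on GPUs. The price is O(N log N) total work instead of the O(N) of the Thomas algorithm.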


.. py:function:: solve_tridiagonal_pcr_batched_feautrier_fused(a, b, c, d_inner, d_surf, angle_weight_a, n_g, n_freq_bins, n_angles, n_layers)

   PCR-based Feautrier tridiagonal solver optimised for GPU.

   Mirrors the interface of ``solve_tridiagonal_thomas_batched_feautrier_fused``
   but uses the O(log N) parallel cyclic reduction algorithm internally.

   Parameters
   ----------
   a, b, c : (batch, n_layers)
       Tridiagonal coefficients where batch = n_g * n_freq_bins * n_angles.
   d_inner : (n_g, n_freq_bins, n_layers - 1)
       Angle-independent RHS entries for layers 0 .. n_layers-2.
   d_surf : (n_g, n_freq_bins, n_angles)
       Boundary RHS entry for the surface layer.
   angle_weight_a : (n_angles,)
       Angular quadrature weights for mean-intensity accumulation.
   n_g, n_freq_bins, n_angles, n_layers : int
       Static shape parameters.

   Returns
   -------
   I_H_top : (n_g, n_freq_bins, n_angles)
   x_last : (n_g, n_freq_bins, n_angles)
   x_nm2 : (n_g, n_freq_bins, n_angles)
   J_bol_gfl : (n_g, n_freq_bins, n_layers)


.. py:function:: solve_feautrier_tridiagonal(a, b, c, d_inner, d_surf, angle_weight_a, n_g, n_freq_bins, n_angles, n_layers)

   Dispatch to the optimal Feautrier tridiagonal solver for the current device.

   Uses the sequential Thomas algorithm on CPU (good sequential throughput)
   and the parallel cyclic reduction algorithm on GPU (O(log N) depth).

   Parameters are identical to ``solve_tridiagonal_thomas_batched_feautrier_fused``.


.. py:function:: factor_feautrier(a, b, c, n_g, n_freq_bins, n_angles, n_layers)

   Pre-factor the Feautrier tridiagonal LHS.

   On CPU, delegates to :func:`factor_feautrier_thomas`.
   On GPU, returns ``None`` (PCR has no separable factorisation step).

   Returns
   -------
   factors : tuple or None
       ``(c_prime, inv_denom, a_times_inv_denom)`` on CPU, ``None`` on GPU.


.. py:function:: solve_feautrier_prefactored(factors, a, b, c, d_inner, d_surf, angle_weight_a, n_g, n_freq_bins, n_angles, n_layers)

   Solve the Feautrier system, using pre-factored LHS when available.

   On CPU (``factors is not None``), uses :func:`solve_feautrier_thomas_prefactored`
   which skips the LHS forward elimination.  On GPU (``factors is None``), falls
   back to the full PCR solver.

   Parameters
   ----------
   factors : tuple or None
       Output of :func:`factor_feautrier`.  ``None`` triggers the GPU path.
   a, b, c, d_inner, d_surf, angle_weight_a, n_g, n_freq_bins, n_angles, n_layers
       Same as :func:`solve_feautrier_tridiagonal`.


.. py:function:: linear_fit(x, y)

   Calculates the slope and y-axis intercept of a straight-line fit to (x, y)
   data, assuming zero error on the data.
   Translated from the Fortran routine ``linear_fit``.

   Args:
       x: 1D array of x values.
       y: 1D array of y values.

   Returns:
       Tuple (a, b) where a is the y-intercept and b is the slope.
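
As an illustration of the return convention (intercept first, then slope), here is a NumPy sketch of an unweighted least-squares line fit. That the library routine uses exactly this closed form is an assumption. The phrase "assuming zero error on data" suggests an unweighted fit, but the source does not spell out the formula. ``linear_fit_sketch`` is a hypothetical name.

```python
import numpy as np


def linear_fit_sketch(x, y):
    """Unweighted least-squares line y = a + b * x; returns (a, b)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = x.size
    sx, sy = x.sum(), y.sum()
    sxx, sxy = (x * x).sum(), (x * y).sum()
    slope = (n * sxy - sx * sy) / (n * sxx - sx * sx)
    intercept = (sy - slope * sx) / n
    return intercept, slope


x = np.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 + 3.0 * x                       # exact line: intercept 2, slope 3
intercept, slope = linear_fit_sketch(x, y)
```

Note that the order is the reverse of ``numpy.polyfit(x, y, 1)``, which returns the slope first.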


