Source code for diffct.differentiable

import math
import numpy as np
import torch
from numba import cuda

# ---------------------------------------------------------------------------
# Global settings & helpers
# ---------------------------------------------------------------------------

# Default NumPy floating type of this module; it is used below to build the
# float32 scalars _INF and _EPSILON.
_DTYPE = np.float32

# Threads per block for 2D kernel launches: 16 x 16 = 256 threads.
# NOTE(review): the visible kernels use no shared memory, so this size is a
# general occupancy/register trade-off rather than a measured optimum --
# confirm with profiling on the target GPU.
_TPB_2D = (16, 16)

# Threads per block for 3D kernel launches: 8 x 8 x 8 = 512 threads.
# The smaller per-axis extent keeps the total thread count per block moderate
# for kernels that are presumably register-heavy (confirm with profiling).
_TPB_3D = (8, 8, 8)

# Shared Numba CUDA JIT decorator. cache=True persists compiled kernels to the
# on-disk cache; fastmath=True allows fast-math floating-point transformations
# (e.g. approximate division/sqrt), trading some precision for speed.
_FASTMATH_DECORATOR = cuda.jit(cache=True, fastmath=True)

# float32 infinity: the "unbounded" ray parameter used by the traversal kernels.
_INF = _DTYPE(np.inf)
# float32 tolerance: direction components at or below it are treated as zero,
# and ray segments at or below it are skipped by the traversal kernels.
_EPSILON = _DTYPE(1e-6)
# === Device Management Utilities ===
class DeviceManager:
    """Static helpers for inspecting and normalising tensor device placement."""

    @staticmethod
    def get_device(tensor):
        """Return the device that *tensor* lives on.

        Objects without a ``device`` attribute (plain Python values, NumPy
        arrays, ...) are reported as living on the CPU.

        Parameters
        ----------
        tensor : torch.Tensor or object
            Value whose device should be reported.

        Returns
        -------
        torch.device
            ``tensor.device`` when present, otherwise ``torch.device("cpu")``.

        Examples
        --------
        >>> DeviceManager.get_device(torch.tensor([1, 2, 3]))
        device(type='cpu')
        """
        return getattr(tensor, "device", torch.device("cpu"))

    @staticmethod
    def ensure_device(tensor, device):
        """Return *tensor* on *device*, moving it only when necessary.

        Values without a ``to`` method are returned untouched. Tensors already
        on *device* are returned as the same object, with no copy.

        Parameters
        ----------
        tensor : torch.Tensor or object
            Value to place.
        device : torch.device
            Target device.

        Returns
        -------
        torch.Tensor or object
            ``tensor.to(device)`` when a move is needed, otherwise ``tensor``.

        Examples
        --------
        >>> DeviceManager.ensure_device(
        ...     torch.tensor([1, 2, 3]),
        ...     torch.device('cuda')
        ... )
        tensor([1, 2, 3], device='cuda:0')
        """
        must_move = hasattr(tensor, "to") and tensor.device != device
        return tensor.to(device) if must_move else tensor

# === PyTorch-CUDA Bridge ===
class TorchCUDABridge:
    """Zero-copy interop between PyTorch CUDA tensors and Numba device arrays."""

    @staticmethod
    def tensor_to_cuda_array(tensor):
        """View a PyTorch CUDA tensor as a Numba CUDA ``DeviceNDArray``.

        The tensor is detached from the autograd graph and passed to
        ``numba.cuda.as_cuda_array``, which wraps the existing device buffer.
        No data is copied, and no CPU round trip happens: the result shares
        memory with ``tensor``.

        Parameters
        ----------
        tensor : torch.Tensor
            Tensor that must live on a CUDA device.

        Returns
        -------
        numba.cuda.cudadrv.devicearray.DeviceNDArray
            Numba view over the same device memory as ``tensor``.

        Raises
        ------
        ValueError
            If ``tensor`` is not on a CUDA device.

        Examples
        --------
        >>> t = torch.randn(10, device='cuda')
        >>> arr = TorchCUDABridge.tensor_to_cuda_array(t)
        """
        if tensor.is_cuda:
            return cuda.as_cuda_array(tensor.detach())
        raise ValueError("Tensor must be on CUDA device")

# ---------------------------------------------------------------------------
# Stream helper (cached external Numba stream)
# ---------------------------------------------------------------------------
# One-entry cache: raw CUDA stream handle of the last request and its Numba wrapper.
_cached_stream_ptr = None
_cached_numba_stream = None


def _get_numba_external_stream_for(pt_stream=None):
    """Wrap a PyTorch CUDA stream as a Numba external stream, with a 1-entry cache.

    Parameters
    ----------
    pt_stream : torch.cuda.Stream, optional
        Stream to wrap. Defaults to ``torch.cuda.current_stream()``.

    Returns
    -------
    object
        The stream returned by ``numba.cuda.external_stream`` for the same raw
        handle, so Numba kernels launched on it are ordered with PyTorch work
        on that stream.

    Notes
    -----
    Only the most recently requested handle is remembered. Alternating between
    two streams therefore rebuilds the wrapper on every switch. The cache key is
    the integer handle alone. NOTE(review): if handles can coincide across
    devices (e.g. a legacy default stream reported as 0 -- confirm), a wrapper
    created while one device was current could be reused for another.
    """
    global _cached_stream_ptr, _cached_numba_stream
    stream = torch.cuda.current_stream() if pt_stream is None else pt_stream
    # torch exposes the raw cudaStream_t as an integer via .cuda_stream.
    handle = int(stream.cuda_stream)
    if _cached_numba_stream is None or _cached_stream_ptr != handle:
        _cached_numba_stream = cuda.external_stream(stream.cuda_stream)
        _cached_stream_ptr = handle
    return _cached_numba_stream

# === GPU-aware Trigonometric Table Generation ===
# No lru_cache here: torch tensors hash by object identity rather than by value,
# so a cache keyed on the angle tensor would miss for equal angles held in
# different tensors (and would keep every tensor it ever saw alive).
def _as_torch_float_dtype(dtype):
    """Normalise a NumPy/torch floating dtype spec to a ``torch.dtype``.

    ``None`` and ``torch.dtype`` values pass through unchanged. Any NumPy spec
    (scalar type, ``np.dtype`` instance or string) describing float64 maps to
    ``torch.float64``. Every other spec, including ones NumPy cannot parse,
    falls back to ``torch.float32``.
    """
    if dtype is None or isinstance(dtype, torch.dtype):
        return dtype
    try:
        np_dtype = np.dtype(dtype)
    except (TypeError, ValueError):
        return torch.float32
    return torch.float64 if np_dtype == np.float64 else torch.float32


def _trig_tables(angles, dtype=_DTYPE, device=None):
    """Compute cosine and sine tables for input angles.

    Precompute cosine and sine values and return them as torch tensors.

    Parameters
    ----------
    angles : array-like or torch.Tensor
        Projection angles in radians. Can be a NumPy array, a sequence, or a
        PyTorch tensor on CPU or CUDA.
    dtype : numpy dtype, torch.dtype or None, optional
        Desired floating dtype of the tables.
        - torch dtypes are used as-is.
        - A NumPy float64 spec maps to ``torch.float64``; any other NumPy spec
          maps to ``torch.float32``.
        - ``None`` keeps the dtype of a tensor input; non-tensor input then
          uses float32.
        Default is `_DTYPE`.
    device : torch.device or str, optional
        Target device. For tensor input it defaults to ``angles.device``. For
        non-tensor input the tables stay on the CPU unless a device is given.

    Returns
    -------
    cos : torch.Tensor
        Cosine values of `angles`.
    sin : torch.Tensor
        Sine values of `angles`.

    Examples
    --------
    >>> angles = torch.linspace(0, torch.pi, 180, device='cuda')
    >>> cos, sin = _trig_tables(angles)
    >>> cos.device
    device(type='cuda', index=0)
    """
    # Tensor.to() only accepts torch dtypes, so NumPy specs such as the default
    # _DTYPE must be converted before use.
    torch_dtype = _as_torch_float_dtype(dtype)

    if isinstance(angles, torch.Tensor):
        if device is None:
            device = angles.device
        if torch_dtype is None:
            angles_device = angles.to(device=device)
        else:
            angles_device = angles.to(dtype=torch_dtype, device=device)
        # Two elementwise launches: one for cos, one for sin.
        return torch.cos(angles_device), torch.sin(angles_device)

    # Non-tensor input: compute on the CPU with torch for consistency, then move.
    if torch_dtype is None:
        torch_dtype = torch.float32
    angles_cpu = torch.tensor(angles, dtype=torch_dtype)
    cos_cpu = torch.cos(angles_cpu)
    sin_cpu = torch.sin(angles_cpu)
    if device is not None:
        return cos_cpu.to(device), sin_cpu.to(device)
    return cos_cpu, sin_cpu


# ############################################################################
# MEMORY LAYOUT VALIDATION
# ############################################################################

def _validate_3d_memory_layout(tensor, expected_order='DHW'):
    """Validate 3D tensor memory layout to prevent coordinate system inconsistencies.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor that must be 3D.
    expected_order : str, optional
        Logical axis order of ``tensor``.
        - ``'DHW'`` (volumes, shape ``(D, H, W)``) and ``'VHW'`` (sinograms,
          shape ``(num_views, det_v, det_u)``) are user-facing layouts and
          must be contiguous.
        - Any other value, e.g. the internal ``'WHD'`` layout, only gets the
          dimensionality check.
        Default is 'DHW'.

    Raises
    ------
    ValueError
        If ``tensor`` is not 3D, or if a 'DHW'/'VHW' tensor is non-contiguous.

    Notes
    -----
    Contiguity is the only layout property enforced. For a contiguous tensor the
    memory order already follows the logical axis order (0, 1, 2), so no
    separate stride-order comparison is needed.
    """
    shape = tensor.shape
    if len(shape) != 3:
        raise ValueError(f"Expected 3D tensor, got {len(shape)}D")

    # Internal layouts (e.g. 'WHD') are produced by this module itself; skip them.
    if expected_order not in ('DHW', 'VHW'):
        return

    if not tensor.is_contiguous():
        raise ValueError(
            "Input tensor must be contiguous. Call .contiguous() before passing to "
            "cone beam functions to avoid memory duplication and ensure correct results."
        )
    # For 'WHD' (internal layout), skip stride check entirely


def _grid_2d(n1, n2, tpb=_TPB_2D):
    """Return the CUDA launch configuration covering an ``n1 x n2`` index space.

    Each grid axis gets enough blocks of ``tpb`` threads to cover its extent
    (ceiling division). Surplus threads in the last block must be masked out
    by the kernel itself.

    Parameters
    ----------
    n1 : int
        Extent of the first launch axis (e.g. projection angles).
    n2 : int
        Extent of the second launch axis (e.g. detector elements).
    tpb : tuple of int, optional
        Threads per block per axis. Default is `_TPB_2D`.

    Returns
    -------
    grid : tuple of int
        Number of blocks per axis.
    tpb : tuple of int
        The threads-per-block tuple that was passed in.

    Examples
    --------
    >>> grid, tpb = _grid_2d(180, 256)
    >>> grid
    (12, 16)
    >>> tpb
    (16, 16)
    """
    extents = (n1, n2)
    blocks = tuple(math.ceil(extents[axis] / tpb[axis]) for axis in range(2))
    return blocks, tpb


def _grid_3d(n1, n2, n3, tpb=_TPB_3D):
    """Return the CUDA launch configuration covering an ``n1 x n2 x n3`` space.

    Each grid axis gets enough blocks of ``tpb`` threads to cover its extent
    (ceiling division). Surplus threads in the last block must be masked out
    by the kernel itself.

    Parameters
    ----------
    n1 : int
        Extent of the first launch axis (e.g. projection views).
    n2 : int
        Extent of the second launch axis (e.g. detector u-axis).
    n3 : int
        Extent of the third launch axis (e.g. detector v-axis).
    tpb : tuple of int, optional
        Threads per block per axis. Default is `_TPB_3D`.

    Returns
    -------
    grid : tuple of int
        Number of blocks per axis.
    tpb : tuple of int
        The threads-per-block tuple that was passed in.

    Examples
    --------
    >>> grid, tpb = _grid_3d(360, 256, 256)
    >>> grid
    (45, 32, 32)
    >>> tpb
    (8, 8, 8)
    """
    extents = (n1, n2, n3)
    blocks = tuple(math.ceil(extents[axis] / tpb[axis]) for axis in range(3))
    return blocks, tpb


def detector_coordinates_1d(num_detectors, detector_spacing, detector_offset=0.0, device=None, dtype=torch.float32):
    """Return centered detector coordinates in physical units.

    Bin ``i`` sits at ``(i - (N-1)/2) * spacing + offset``. The detector centre
    therefore falls exactly between the two middle bins for even ``N``, which
    avoids a half-pixel bias.

    Parameters
    ----------
    num_detectors : int
        Number of detector bins ``N``.
    detector_spacing : float
        Physical distance between neighbouring bins.
    detector_offset : float, optional
        Shift applied to every coordinate, in the same units as the spacing.
    device : torch.device, optional
        Device of the returned tensor.
    dtype : torch.dtype, optional
        Dtype of the returned tensor. Default is ``torch.float32``.

    Returns
    -------
    torch.Tensor
        1D tensor of length ``num_detectors``.
    """
    center_index = (num_detectors - 1) * 0.5
    bin_index = torch.arange(num_detectors, device=device, dtype=dtype)
    return (bin_index - center_index) * detector_spacing + detector_offset


def angular_integration_weights(angles, redundant_full_scan=True):
    """Compute per-view integration weights from the provided angle samples.

    Parameters
    ----------
    angles : torch.Tensor or array-like
        1D projection angles in radians, monotonically increasing or decreasing.
    redundant_full_scan : bool, optional
        If ``True``, applies a 0.5 factor for near-``2*pi`` scans to account for
        view redundancy in reconstruction formulas using full circular data.

    Returns
    -------
    torch.Tensor
        1D float32 tensor of non-negative weights, one per view, on the same
        device as ``angles``.

    Raises
    ------
    ValueError
        If ``angles`` is not 1D or has fewer than 2 elements.

    Notes
    -----
    Coverage is estimated as ``|a[-1] - a[0]| + |a[1] - a[0]|``, i.e. the
    first-to-last span plus one sampling interval. This assumes roughly uniform
    spacing. Scans covering at least ``2*pi - 1e-3`` use periodic trapezoidal
    weights; shorter scans use open (non-periodic) trapezoidal weights.
    """
    if not isinstance(angles, torch.Tensor):
        angles = torch.tensor(angles, dtype=torch.float32)
    a = angles.to(dtype=torch.float32)
    if a.ndim != 1 or a.numel() < 2:
        raise ValueError("angles must be a 1D tensor with at least 2 elements")

    coverage = torch.abs(a[-1] - a[0]) + torch.abs(a[1] - a[0])

    if coverage >= (2.0 * torch.pi - 1e-3):
        # Periodic boundary integration weights for full circular sampling.
        # The sample following the last view is the first view one full turn
        # later *in the scan direction*. Always adding +2*pi would make the
        # wrap-around interval ~4*pi for descending angle sequences.
        turn = 2.0 * torch.pi if bool(a[-1] >= a[0]) else -2.0 * torch.pi
        d = torch.abs(torch.diff(a, append=a[:1] + turn))
        # Each view gets half of the intervals on either side of it.
        w = 0.5 * (d + torch.roll(d, shifts=1, dims=0))
        if redundant_full_scan:
            w = 0.5 * w
    else:
        # Non-periodic trapezoidal weights for partial scans.
        w = torch.empty_like(a)
        w[1:-1] = 0.5 * (a[2:] - a[:-2])
        w[0] = 0.5 * (a[1] - a[0])
        w[-1] = 0.5 * (a[-1] - a[-2])
        w = torch.abs(w)
    return w


def fan_cosine_weights(num_detectors, detector_spacing, sdd, detector_offset=0.0, device=None, dtype=torch.float32):
    """Return fan-beam cosine pre-weights ``cos(gamma)`` for each detector bin.

    ``gamma = atan(u / sdd)`` is the fan angle of the bin at centred detector
    coordinate ``u`` (as produced by ``detector_coordinates_1d``).

    Parameters
    ----------
    num_detectors : int
        Number of detector bins.
    detector_spacing : float
        Physical bin spacing.
    sdd : float
        Source-to-detector distance, in the same units as the spacing.
    detector_offset : float, optional
        Detector shift, in the same units as the spacing.
    device : torch.device, optional
        Device of the result.
    dtype : torch.dtype, optional
        Dtype of the result. Default is ``torch.float32``.

    Returns
    -------
    torch.Tensor
        1D tensor of length ``num_detectors``.
    """
    positions = detector_coordinates_1d(
        num_detectors,
        detector_spacing,
        detector_offset,
        device=device,
        dtype=dtype,
    )
    fan_angles = torch.atan(positions / sdd)
    return torch.cos(fan_angles)


def cone_cosine_weights(det_u, det_v, du, dv, sdd, detector_offset_u=0.0, detector_offset_v=0.0, device=None, dtype=torch.float32):
    """Return FDK cosine pre-weights ``D/sqrt(D^2 + u^2 + v^2)`` on a 2D detector.

    Parameters
    ----------
    det_u, det_v : int
        Number of detector bins along u and v.
    du, dv : float
        Physical bin spacing along u and v.
    sdd : float
        Source-to-detector distance ``D``, in the same units as the spacings.
    detector_offset_u, detector_offset_v : float, optional
        Detector shifts along u and v.
    device : torch.device, optional
        Device of the result.
    dtype : torch.dtype, optional
        Dtype of the result. Default is ``torch.float32``.

    Returns
    -------
    torch.Tensor
        Weights of shape ``(det_u, det_v)``: u varies along dim 0, v along
        dim 1. NOTE(review): confirm this axis order matches the detector
        layout of the projections the weights are multiplied with.
    """
    u_coords = detector_coordinates_1d(det_u, du, detector_offset_u, device=device, dtype=dtype)
    v_coords = detector_coordinates_1d(det_v, dv, detector_offset_v, device=device, dtype=dtype)
    u = u_coords.view(det_u, 1)
    v = v_coords.view(1, det_v)
    return sdd / torch.sqrt(sdd * sdd + u * u + v * v)


def parker_weights(angles, num_detectors, detector_spacing, sdd, detector_offset=0.0):
    """Return Parker redundancy weights for fan/cone short-scan geometries.

    Parameters
    ----------
    angles : torch.Tensor or array-like
        1D projection angles in radians (at least 2). The rotation angle of
        each view is measured from the first sample, ``beta = a - a[0]``.
    num_detectors : int
        Number of detector bins.
    detector_spacing : float
        Physical bin spacing.
    sdd : float
        Source-to-detector distance, in the same units as the spacing.
    detector_offset : float, optional
        Detector shift, in the same units as the spacing.

    Returns
    -------
    torch.Tensor
        float32 weights of shape ``(len(angles), num_detectors)``.

    Raises
    ------
    ValueError
        If ``angles`` is not 1D with at least 2 elements, or if the coverage is
        below the minimal short scan ``pi + 2*gamma_max``.

    Notes
    -----
    Coverage is estimated as the first-to-last span plus one sampling interval.
    Three regimes are handled:

    * coverage >= ``2*pi - 1e-3``: all ones.
    * coverage within ``1e-3`` of ``pi + 2*gamma_max``: exact Parker weights.
    * otherwise (over-scan below ``2*pi``): a per-view raised-cosine taper of
      width ``2*gamma_max`` at both ends of the scan. This is a heuristic, not
      an exact redundancy normalisation. The result is an ``expand``-ed view,
      so in-place writes to it fail.

    NOTE(review): the piecewise conditions require ``beta >= 0``. For
    descending angle sequences ``beta`` is negative and the exact branch
    returns zeros -- confirm callers pass increasing angles.
    """
    if not isinstance(angles, torch.Tensor):
        angles = torch.tensor(angles, dtype=torch.float32)
    a = angles.to(dtype=torch.float32)
    if a.ndim != 1 or a.numel() < 2:
        raise ValueError("angles must be a 1D tensor with at least 2 elements")

    # Approximate covered range including the final sample interval.
    coverage = torch.abs(a[-1] - a[0]) + torch.abs(a[1] - a[0])
    if coverage >= (2.0 * torch.pi - 1e-3):
        return torch.ones((a.numel(), num_detectors), dtype=a.dtype, device=a.device)

    u = detector_coordinates_1d(
        num_detectors,
        detector_spacing,
        detector_offset=detector_offset,
        device=a.device,
        dtype=a.dtype,
    )
    # Fan angle of every detector bin, shaped (1, num_detectors) for broadcasting.
    gamma = torch.atan(u / sdd).view(1, num_detectors)
    gamma_max = torch.max(torch.abs(gamma))
    min_short_scan = torch.pi + 2.0 * gamma_max
    if coverage < (min_short_scan - 1e-3):
        raise ValueError(
            "Insufficient angular coverage for Parker weighting. "
            f"Need at least {float(min_short_scan):.6f} rad, got {float(coverage):.6f} rad."
        )

    beta = (a - a[0]).view(-1, 1)
    eps = 1e-6

    # Exact Parker form for minimal short scan (pi + 2*gamma_max).
    if coverage <= (min_short_scan + 1e-3):
        t1 = 2.0 * (gamma_max - gamma)   # end of the leading ramp
        t2 = torch.pi - 2.0 * gamma      # start of the trailing ramp
        t3 = torch.pi + 2.0 * gamma_max  # end of the scan
        t4 = 2.0 * (gamma_max + gamma)   # width of the trailing ramp

        # sin^2 ramps written as 0.5 * (1 - cos(.)); clamp avoids division by 0.
        w1 = 0.5 * (1.0 - torch.cos(torch.pi * beta / torch.clamp(t1, min=eps)))
        w3 = 0.5 * (1.0 - torch.cos(torch.pi * (t3 - beta) / torch.clamp(t4, min=eps)))

        cond1 = (beta >= 0.0) & (beta < t1)
        cond2 = (beta >= t1) & (beta <= t2)
        cond3 = (beta > t2) & (beta <= t3)
        ones = torch.ones_like(w1)
        zeros = torch.zeros_like(w1)
        return torch.where(cond1, w1, torch.where(cond2, ones, torch.where(cond3, w3, zeros)))

    # Fallback for over-scan (<2*pi): cosine taper at both ends.
    ramp = 2.0 * gamma_max
    wb = torch.ones_like(a)
    if ramp > eps:
        b = beta[:, 0]
        lead = 0.5 * (1.0 - torch.cos(torch.pi * torch.clamp(b / ramp, min=0.0, max=1.0)))
        trail = 0.5 * (1.0 - torch.cos(torch.pi * torch.clamp((coverage - b) / ramp, min=0.0, max=1.0)))
        # Both ramps are already in [0, 1]. Their minimum tapers the start
        # (via lead) and the end (via trail). A maximum would be 1 for every
        # view, because at least one ramp is saturated, i.e. no taper at all.
        wb = torch.minimum(lead, trail)
    return wb.view(-1, 1).expand(-1, num_detectors)


def ramp_filter_1d(sinogram_tensor, dim=-1, zero_pad=False):
    """Apply a 1D ramp filter along ``dim`` using FFT.

    The filter is the discrete ramp ``|2*pi*f|``, where ``f`` are the FFT
    sample frequencies in cycles per sample (``torch.fft.fftfreq`` with unit
    spacing). No detector-spacing scaling and no apodisation window are applied.

    Parameters
    ----------
    sinogram_tensor : torch.Tensor
        Data to filter. Real or complex; any number of dimensions.
    dim : int, optional
        Axis to filter along. Default is the last axis.
    zero_pad : bool, optional
        If ``True``, zero-pad along ``dim`` to the next power of two that is
        ``>= 2 * n``, filter, then crop back to ``n`` samples. This turns the
        FFT's circular convolution into a linear one and removes wrap-around
        from the opposite detector edge. Default ``False`` keeps the unpadded
        (circular) behaviour.

    Returns
    -------
    torch.Tensor
        Real part of the filtered data, same shape as ``sinogram_tensor``.
    """
    n = sinogram_tensor.shape[dim]
    n_fft = (1 << (2 * n - 1).bit_length()) if zero_pad else n

    # Build the ramp in the input's precision so float64 data is not filtered
    # with a float32 ramp.
    if sinogram_tensor.dtype in (torch.float64, torch.complex128):
        ramp_dtype = torch.float64
    else:
        ramp_dtype = torch.float32
    freqs = torch.fft.fftfreq(n_fft, device=sinogram_tensor.device, dtype=ramp_dtype)
    ramp = torch.abs(2.0 * torch.pi * freqs)
    shape = [1] * sinogram_tensor.ndim
    shape[dim] = n_fft
    ramp = ramp.reshape(shape)

    # fft(..., n=n_fft) zero-pads along dim (a no-op when n_fft == n).
    sino_fft = torch.fft.fft(sinogram_tensor, n=n_fft, dim=dim)
    filtered = torch.real(torch.fft.ifft(sino_fft * ramp, dim=dim))
    if n_fft != n:
        filtered = filtered.narrow(dim, 0, n)
    return filtered


# ############################################################################
# SHARED CUDA KERNELS
# ############################################################################

# ------------------------------------------------------------------
# 2-D PARALLEL BEAM KERNELS
# ------------------------------------------------------------------

@_FASTMATH_DECORATOR
def _parallel_2d_forward_kernel(
    d_image, Nx, Ny,
    d_sino, n_ang, n_det,
    det_spacing, d_cos, d_sin, cx, cy, voxel_spacing,
    det_offset, center_offset_x, center_offset_y
):
    """Compute the 2D parallel beam forward projection.

    One CUDA thread traces one ray (angle, detector bin) through the image.
    Cells are traversed Siddon-style, and the image is sampled by bilinear
    interpolation at the midpoint of each traversed segment.

    Parameters
    ----------
    d_image : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Input 2D image on CUDA, indexed ``d_image[iy, ix]`` (row = y, column = x).
    Nx : int
        Number of voxels along the x-axis.
    Ny : int
        Number of voxels along the y-axis.
    d_sino : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Output sinogram on CUDA, indexed ``d_sino[iang, idet]``. Every in-range
        entry is written; rays that miss the image get 0.0.
    n_ang : int
        Number of projection angles.
    n_det : int
        Number of detector elements.
    det_spacing : float
        Physical spacing between detector elements.
    d_cos : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed cosine values of projection angles.
    d_sin : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed sine values of projection angles.
    cx : float
        Half of image width in voxels; the traced region spans ``[-cx, cx]``.
    cy : float
        Half of image height in voxels; the traced region spans ``[-cy, cy]``.
    voxel_spacing : float
        Physical size of one voxel (same units as det_spacing).
    det_offset : float
        Added to the detector coordinate *after* the division by
        voxel_spacing, so it must already be in voxel units.
        NOTE(review): confirm the caller converts it.
    center_offset_x, center_offset_y : float
        Shift of the ray start point, added to voxel-unit coordinates (so also
        in voxel units).

    Notes
    -----
    The Siddon method with interpolation works as follows:
      - Clip the ray against the image bounds to get the traversal interval.
      - Step cell to cell along the ray via parametric boundary crossings.
      - Sample bilinearly at each segment midpoint.
      - Accumulate sample value times segment length.
    Assumes ``Nx >= 2`` and ``Ny >= 2``: the base index is clamped to ``N - 2``
    and index ``+1`` is read, which is out of range for a 1-voxel axis.
    """
    # CUDA THREAD ORGANIZATION: 2D grid maps directly to ray geometry
    # Each thread processes one ray defined by (projection_angle, detector_element) pair
    # Thread indexing: iang = projection angle index (grid x), idet = detector element index (grid y)
    # NOTE(review): cuda.grid(2) returns (x, y), and consecutive threads of a warp
    # differ in x, i.e. in iang. Neighbouring threads therefore write
    # d_sino[iang, idet] entries one row apart -- a strided, not coalesced,
    # pattern for a C-contiguous sinogram. Swapping x/y roles may help; profile first.
    iang, idet = cuda.grid(2)
    if iang >= n_ang or idet >= n_det:
        return

    # === RAY GEOMETRY SETUP ===
    # Extract projection angle and compute detector position
    cos_a = d_cos[iang]  # Precomputed cosine of projection angle
    sin_a = d_sin[iang]  # Precomputed sine of projection angle
    # Centred detector coordinate converted to voxel units (det_offset is added unscaled)
    u = (idet - (n_det - 1) * 0.5) * det_spacing / voxel_spacing + det_offset

    # Define ray direction and starting point for parallel beam geometry
    # Ray direction is (cos_a, sin_a); the detector axis (-sin_a, cos_a) is perpendicular to it
    # Ray starting point is offset along the detector axis by u (voxel units)
    dir_x, dir_y = cos_a, sin_a
    pnt_x = u * -sin_a + center_offset_x
    pnt_y = u * cos_a + center_offset_y

    # === RAY-VOLUME INTERSECTION CALCULATION ===
    # Compute parametric intersection points with volume boundaries using ray equation r(t) = pnt + t*dir
    # Volume extends from [-cx, cx] x [-cy, cy] in voxel coordinate system
    # Mathematical basis: For ray r(t) = origin + t*direction, solve r(t) = boundary for parameter t
    t_min, t_max = -_INF, _INF  # Initialize ray parameter range to unbounded

    # X-direction boundary intersections
    # Handle non-parallel rays: compute intersection parameters with left (-cx) and right (+cx) boundaries
    if abs(dir_x) > _EPSILON:  # Ray not parallel to x-axis (avoid division by zero)
        tx1, tx2 = (-cx - pnt_x) / dir_x, (cx - pnt_x) / dir_x  # Left and right boundary intersections
        # Update valid parameter range: intersection of current range with x-boundary constraints
        # min/max operations ensure we get the entry/exit points correctly regardless of ray direction
        t_min, t_max = max(t_min, min(tx1, tx2)), min(t_max, max(tx1, tx2))  # Update valid parameter range
    elif pnt_x < -cx or pnt_x > cx:  # Ray parallel to x-axis but outside volume bounds
        # Edge case: ray never intersects volume if parallel and outside boundaries
        d_sino[iang, idet] = 0.0; return

    # Y-direction boundary intersections (identical logic to x-direction)
    # Handle non-parallel rays: compute intersection parameters with bottom (-cy) and top (+cy) boundaries
    if abs(dir_y) > _EPSILON:  # Ray not parallel to y-axis (avoid division by zero)
        ty1, ty2 = (-cy - pnt_y) / dir_y, (cy - pnt_y) / dir_y  # Bottom and top boundary intersections
        # Intersect y-boundary constraints with existing parameter range from x-boundaries
        t_min, t_max = max(t_min, min(ty1, ty2)), min(t_max, max(ty1, ty2))  # Intersect with x-range
    elif pnt_y < -cy or pnt_y > cy:  # Ray parallel to y-axis but outside volume bounds
        # Edge case: ray never intersects volume if parallel and outside boundaries
        d_sino[iang, idet] = 0.0; return

    # Boundary intersection validation: check if ray actually intersects the volume
    # If t_min >= t_max, the ray misses the volume entirely (no valid intersection interval)
    if t_min >= t_max:
        d_sino[iang, idet] = 0.0; return

    # === SIDDON METHOD VOXEL TRAVERSAL INITIALIZATION ===
    accum = 0.0  # Accumulated projection value along ray
    t = t_min    # Current ray parameter; equals distance from pnt when (dir_x, dir_y) is unit length

    # Convert ray entry point to cell indices (coordinates shifted by +cx/+cy so cells are [i, i+1))
    ix = int(math.floor(pnt_x + t * dir_x + cx))  # Current cell x-index
    iy = int(math.floor(pnt_y + t * dir_y + cy))  # Current cell y-index

    # Determine traversal direction and step sizes for each axis
    step_x, step_y = (1 if dir_x >= 0 else -1), (1 if dir_y >= 0 else -1)  # Voxel stepping direction
    # Hoist inverse directions to reduce divisions and branches
    inv_dir_x = (1.0 / dir_x) if abs(dir_x) > _EPSILON else 0.0
    inv_dir_y = (1.0 / dir_y) if abs(dir_y) > _EPSILON else 0.0
    # Parameter distance between consecutive boundary crossings on each axis (INF if axis never crossed)
    dt_x = abs(inv_dir_x) if abs(dir_x) > _EPSILON else _INF
    dt_y = abs(inv_dir_y) if abs(dir_y) > _EPSILON else _INF

    # Calculate parameter values for next voxel boundary crossings using inv_dir_*
    tx = ((ix + (step_x > 0)) - cx - pnt_x) * inv_dir_x if abs(dir_x) > _EPSILON else _INF
    ty = ((iy + (step_y > 0)) - cy - pnt_y) * inv_dir_y if abs(dir_y) > _EPSILON else _INF

    # === MAIN RAY TRAVERSAL LOOP ===
    # Step through voxels along ray path, accumulating weighted contributions
    while t < t_max:
        # Only accumulate while the current cell lies inside the image
        if 0 <= ix < Nx and 0 <= iy < Ny:
            # Determine next voxel boundary crossing (minimum of x, y boundaries or ray exit)
            t_next = min(tx, ty, t_max)
            seg_len = t_next - t  # Length of ray segment within current voxel region

            if seg_len > _EPSILON:  # Only process segments with meaningful length (avoid numerical noise)
                # === BILINEAR INTERPOLATION SAMPLING ===
                # Sample volume at ray segment midpoint for accurate integration
                # Mathematical basis: Midpoint rule for numerical integration along ray segments
                t_mid = t + seg_len * 0.5
                mid_x = pnt_x + t_mid * dir_x + cx  # Midpoint x-coordinate in image space
                mid_y = pnt_y + t_mid * dir_y + cy  # Midpoint y-coordinate in image space

                # Convert continuous coordinates to discrete voxel indices and fractional weights
                # Integer coordinates are treated as sample positions here, while the traversal above
                # treats [i, i+1) as cell i. NOTE(review): whether samples sit on voxel centres depends
                # on how the caller defines cx/cy (e.g. N/2 vs (N-1)/2) -- confirm.
                ix0, iy0 = int(math.floor(mid_x)), int(math.floor(mid_y))  # Base sample indices (lower corner)
                dx, dy = mid_x - ix0, mid_y - iy0  # Fractional offsets in [0, 1) from the base sample

                # Clamp indices to stay in-bounds during interpolation
                # NOTE(review): dx/dy are not recomputed after clamping, so for mid_x >= Nx - 1
                # (or mid_y >= Ny - 1) the unclamped fraction is applied to the pair (N-2, N-1).
                ix0 = max(0, min(ix0, Nx - 2))
                iy0 = max(0, min(iy0, Ny - 2))

                # === BILINEAR INTERPOLATION WEIGHT CALCULATION ===
                # Mathematical basis: Bilinear interpolation formula f(x,y) = Σ f(xi,yi) * wi(x,y)
                # where wi(x,y) are the bilinear basis functions for each corner voxel
                # Weights are products of 1D linear interpolation weights: (1-dx) or dx, (1-dy) or dy
                one_minus_dx = 1.0 - dx
                one_minus_dy = 1.0 - dy
                v00 = d_image[iy0, ix0]
                v10 = d_image[iy0, ix0 + 1]
                v01 = d_image[iy0 + 1, ix0]
                v11 = d_image[iy0 + 1, ix0 + 1]
                row0 = (v00 * one_minus_dx + v10 * dx) * one_minus_dy
                row1 = (v01 * one_minus_dx + v11 * dx) * dy
                val = row0 + row1
                # Accumulate contribution weighted by ray segment length (discrete line integral approximation)
                # This implements the Radon transform: integral of f(x,y) along the ray path
                accum += val * seg_len

        # === VOXEL BOUNDARY CROSSING LOGIC ===
        # Advance to next voxel based on which boundary is crossed first
        if tx <= ty:  # X-boundary crossed first
            t = tx
            ix += step_x  # Move to next voxel in x-direction
            tx += dt_x    # Update next x-boundary crossing parameter
        else:         # Y-boundary crossed first
            t = ty
            iy += step_y  # Move to next voxel in y-direction
            ty += dt_y    # Update next y-boundary crossing parameter

    d_sino[iang, idet] = accum

@_FASTMATH_DECORATOR
def _parallel_2d_backward_kernel(
    d_sino, n_ang, n_det,
    d_image, Nx, Ny,
    det_spacing, d_cos, d_sin, cx, cy, voxel_spacing,
    det_offset, center_offset_x, center_offset_y
):
    """Compute the 2D parallel beam backprojection.

    One CUDA thread traces one ray (angle, detector bin) along the same path
    and with the same midpoint bilinear weights as the forward kernel. It
    scatters ``sinogram value * segment length * weight`` into the image, which
    makes this the adjoint of the forward projection.

    Parameters
    ----------
    d_sino : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Input sinogram on CUDA, indexed ``d_sino[iang, idet]``.
    n_ang : int
        Number of projection angles.
    n_det : int
        Number of detector elements.
    d_image : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Output image gradient on CUDA, indexed ``d_image[iy, ix]``. The kernel
        only adds to it (atomically) and never clears it.
        NOTE(review): presumably the caller zero-initialises it -- confirm.
    Nx : int
        Number of voxels along the x-axis.
    Ny : int
        Number of voxels along the y-axis.
    det_spacing : float
        Physical spacing between detector elements.
    d_cos : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed cosine values of projection angles.
    d_sin : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed sine values of projection angles.
    cx : float
        Half of image width in voxels; the traced region spans ``[-cx, cx]``.
    cy : float
        Half of image height in voxels; the traced region spans ``[-cy, cy]``.
    voxel_spacing : float
        Physical size of one voxel (same units as det_spacing).
    det_offset : float
        Added after the division by voxel_spacing, so it must already be in
        voxel units. NOTE(review): confirm the caller converts it.
    center_offset_x, center_offset_y : float
        Shift of the ray start point, in voxel units.

    Notes
    -----
    Several rays can hit the same voxel, so contributions are accumulated with
    ``cuda.atomic.add``. Rays that miss the image return without writing.
    Assumes ``Nx >= 2`` and ``Ny >= 2`` (see the clamping below).
    """
    # Grid x -> angle index, grid y -> detector index (cuda.grid(2) returns (x, y))
    iang, idet = cuda.grid(2)
    if iang >= n_ang or idet >= n_det:
        return

    # === RAY GEOMETRY SETUP (identical to forward projection) ===
    val   = d_sino[iang, idet]  # Sinogram value to backproject
    cos_a = d_cos[iang]         # Precomputed cosine of projection angle
    sin_a = d_sin[iang]         # Precomputed sine of projection angle
    # Centred detector coordinate converted to voxel units (det_offset is added unscaled)
    u = (idet - (n_det - 1) * 0.5) * det_spacing / voxel_spacing + det_offset

    # Define ray direction and starting point for parallel beam geometry
    dir_x, dir_y = cos_a, sin_a
    pnt_x = u * -sin_a + center_offset_x
    pnt_y = u * cos_a + center_offset_y

    # === RAY-VOLUME INTERSECTION CALCULATION (identical to forward) ===
    # Rays that miss the image simply return: nothing is scattered
    t_min, t_max = -_INF, _INF
    if abs(dir_x) > _EPSILON:
        tx1, tx2 = (-cx - pnt_x) / dir_x, (cx - pnt_x) / dir_x
        t_min, t_max = max(t_min, min(tx1, tx2)), min(t_max, max(tx1, tx2))
    elif pnt_x < -cx or pnt_x > cx: return

    if abs(dir_y) > _EPSILON:
        ty1, ty2 = (-cy - pnt_y) / dir_y, (cy - pnt_y) / dir_y
        t_min, t_max = max(t_min, min(ty1, ty2)), min(t_max, max(ty1, ty2))
    elif pnt_y < -cy or pnt_y > cy: return

    if t_min >= t_max: return

    # === SIDDON METHOD TRAVERSAL INITIALIZATION ===
    t = t_min
    ix = int(math.floor(pnt_x + t * dir_x + cx))
    iy = int(math.floor(pnt_y + t * dir_y + cy))

    step_x, step_y = (1 if dir_x >= 0 else -1), (1 if dir_y >= 0 else -1)
    inv_dir_x = (1.0 / dir_x) if abs(dir_x) > _EPSILON else 0.0
    inv_dir_y = (1.0 / dir_y) if abs(dir_y) > _EPSILON else 0.0
    dt_x = abs(inv_dir_x) if abs(dir_x) > _EPSILON else _INF
    dt_y = abs(inv_dir_y) if abs(dir_y) > _EPSILON else _INF
    tx = ((ix + (step_x > 0)) - cx - pnt_x) * inv_dir_x if abs(dir_x) > _EPSILON else _INF
    ty = ((iy + (step_y > 0)) - cy - pnt_y) * inv_dir_y if abs(dir_y) > _EPSILON else _INF

    # === BACKPROJECTION TRAVERSAL LOOP ===
    # Distribute sinogram value along ray path using bilinear interpolation
    while t < t_max:
        if 0 <= ix < Nx and 0 <= iy < Ny:
            t_next = min(tx, ty, t_max)
            seg_len = t_next - t
            if seg_len > _EPSILON:
                # Sample at ray segment midpoint (same as forward projection)
                t_mid = t + seg_len * 0.5
                mid_x = pnt_x + t_mid * dir_x + cx
                mid_y = pnt_y + t_mid * dir_y + cy
                ix0, iy0 = int(math.floor(mid_x)), int(math.floor(mid_y))
                dx, dy = mid_x - ix0, mid_y - iy0

                # Clamp indices to stay in-bounds during interpolation
                # NOTE(review): as in the forward kernel, dx/dy are not recomputed after clamping,
                # and Nx == 1 or Ny == 1 would make index +1 out of range.
                ix0 = max(0, min(ix0, Nx - 2))
                iy0 = max(0, min(iy0, Ny - 2))

                # === ATOMIC BACKPROJECTION WITH BILINEAR WEIGHTS ===
                # Distribute contribution weighted by segment length and interpolation weights
                # CUDA ATOMIC OPERATIONS: Essential for thread safety in backprojection
                # Multiple threads (rays) can write to the same voxel simultaneously, causing race conditions
                # Atomic add operations serialize these writes, ensuring correct accumulation of contributions
                # Performance impact: Atomic operations are slower than regular writes but necessary for correctness
                # Memory access pattern: global-memory atomics; contention grows when many rays hit the same voxel
                cval = val * seg_len  # Contribution value for this ray segment
                one_minus_dx = 1.0 - dx
                one_minus_dy = 1.0 - dy
                cuda.atomic.add(d_image, (iy0,     ix0),     cval * one_minus_dx * one_minus_dy)
                cuda.atomic.add(d_image, (iy0,     ix0 + 1), cval * dx          * one_minus_dy)
                cuda.atomic.add(d_image, (iy0 + 1, ix0),     cval * one_minus_dx * dy)
                cuda.atomic.add(d_image, (iy0 + 1, ix0 + 1), cval * dx          * dy)

        # Advance to next voxel (identical logic to forward projection)
        if tx <= ty:
            t = tx
            ix += step_x
            tx += dt_x
        else:
            t = ty
            iy += step_y
            ty += dt_y

# ------------------------------------------------------------------
# 2-D FAN BEAM KERNELS
# ------------------------------------------------------------------

@_FASTMATH_DECORATOR
def _fan_2d_forward_kernel(
    d_image, Nx, Ny,
    d_sino, n_ang, n_det,
    det_spacing, d_cos, d_sin,
    sdd, sid, cx, cy, voxel_spacing,
    det_offset, center_offset_x, center_offset_y
):
    """Compute the 2D fan beam forward projection.

    This CUDA kernel implements the Siddon ray-tracing method with interpolation for
    2D fan beam forward projection. Each thread of a 2D launch grid handles one
    ``(angle, detector)`` pair and writes exactly one sinogram entry (0.0 for
    degenerate rays and rays that miss the image).

    Parameters
    ----------
    d_image : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Input 2D image array on CUDA, indexed as ``d_image[iy, ix]``.
    Nx : int
        Number of voxels along the x-axis.
    Ny : int
        Number of voxels along the y-axis.
    d_sino : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Output fan beam sinogram array on CUDA, indexed as ``d_sino[iang, idet]``.
    n_ang : int
        Number of projection angles.
    n_det : int
        Number of detector elements.
    det_spacing : float
        Physical spacing between detector elements.
    d_cos : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed cosine values of projection angles.
    d_sin : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed sine values of projection angles.
    sdd : float
        Source-to-Detector Distance (SDD), total distance from source to detector.
    sid : float
        Source-to-Isocenter Distance (SID), distance from source to isocenter.
    cx : float
        Half of image width in voxels.
    cy : float
        Half of image height in voxels.
    voxel_spacing : float
        Physical size of one voxel (in same units as det_spacing, sid, sdd).
    det_offset : float
        Detector shift along ``u``. It is added after the detector coordinate
        has been divided by ``voxel_spacing``, so the kernel treats it as
        voxel units.
    center_offset_x, center_offset_y : float
        Shift of the rotation centre, added directly to the voxel-unit source
        and detector coordinates.

    Notes
    -----
    Fan beam geometry diverges from parallel beam in that its rays originate
    from a single point source to a linear detector array. Rays connect the
    rotated source position around the isocenter to each detector pixel.

    The upper interpolation neighbour is clamped to the last row/column, so
    images with ``Nx == 1`` or ``Ny == 1`` are read in bounds; for
    ``Nx, Ny >= 2`` this is identical to using ``ix0 + 1`` / ``iy0 + 1``.

    NOTE(review): the traversal treats voxel ``i`` as the cell ``[i, i + 1)``
    (``ix = floor(...)``), while the bilinear weights treat ``d_image[.., i]``
    as the value at coordinate ``i`` -- a half-voxel offset between the two
    grids. The backward kernel uses the same convention, so the pair stays
    adjoint; confirm the offset is intended.
    """
    iang, idet = cuda.grid(2)
    if iang >= n_ang or idet >= n_det:
        return

    # === FAN BEAM GEOMETRY SETUP ===
    cos_a = d_cos[iang]  # Precomputed cosine of projection angle
    sin_a = d_sin[iang]  # Precomputed sine of projection angle
    # Normalize all physical distances to voxel units
    u = (idet - (n_det - 1) * 0.5) * det_spacing / voxel_spacing + det_offset
    sid_v = sid / voxel_spacing  # Source-to-isocenter distance in voxel units
    sdd_v = sdd / voxel_spacing  # Source-to-detector distance in voxel units

    # Source position: rotated by angle around the rotation centre at distance SID
    src_x = -sid_v * sin_a + center_offset_x
    src_y = sid_v * cos_a + center_offset_y

    # Detector element position: IDD = SDD - SID (Isocenter-to-Detector Distance)
    idd = sdd_v - sid_v
    det_x = idd * sin_a + u * cos_a + center_offset_x
    det_y = -idd * cos_a + u * sin_a + center_offset_y

    # === RAY DIRECTION CALCULATION ===
    dir_x, dir_y = det_x - src_x, det_y - src_y
    length = math.sqrt(dir_x * dir_x + dir_y * dir_y)
    if length < _EPSILON:  # Degenerate ray (source coincides with detector element)
        d_sino[iang, idet] = 0.0; return

    # Normalize so the ray parameter t measures distance in voxel units
    inv_len = 1.0 / length
    dir_x, dir_y = dir_x * inv_len, dir_y * inv_len

    # === RAY-VOLUME INTERSECTION (slab method, source as ray origin) ===
    t_min, t_max = -_INF, _INF
    if abs(dir_x) > _EPSILON:
        tx1, tx2 = (-cx - src_x) / dir_x, (cx - src_x) / dir_x
        t_min, t_max = max(t_min, min(tx1, tx2)), min(t_max, max(tx1, tx2))
    elif src_x < -cx or src_x > cx:  # Ray parallel to x-slab and outside it
        d_sino[iang, idet] = 0.0; return

    if abs(dir_y) > _EPSILON:
        ty1, ty2 = (-cy - src_y) / dir_y, (cy - src_y) / dir_y
        t_min, t_max = max(t_min, min(ty1, ty2)), min(t_max, max(ty1, ty2))
    elif src_y < -cy or src_y > cy:
        d_sino[iang, idet] = 0.0; return

    if t_min >= t_max:  # Ray misses the image
        d_sino[iang, idet] = 0.0; return

    # === SIDDON TRAVERSAL INITIALIZATION ===
    accum = 0.0  # Accumulated line integral
    t = t_min    # Current ray parameter

    # Voxel containing the entry point (may be one past the edge; handled below)
    ix = int(math.floor(src_x + t * dir_x + cx))
    iy = int(math.floor(src_y + t * dir_y + cy))

    step_x, step_y = (1 if dir_x >= 0 else -1), (1 if dir_y >= 0 else -1)
    inv_dir_x = (1.0 / dir_x) if abs(dir_x) > _EPSILON else 0.0
    inv_dir_y = (1.0 / dir_y) if abs(dir_y) > _EPSILON else 0.0
    dt_x = abs(inv_dir_x) if abs(dir_x) > _EPSILON else _INF
    dt_y = abs(inv_dir_y) if abs(dir_y) > _EPSILON else _INF
    # Ray parameter of the next x / y voxel boundary
    tx = ((ix + (step_x > 0)) - cx - src_x) * inv_dir_x if abs(dir_x) > _EPSILON else _INF
    ty = ((iy + (step_y > 0)) - cy - src_y) * inv_dir_y if abs(dir_y) > _EPSILON else _INF

    # === TRAVERSAL LOOP WITH BILINEAR INTERPOLATION ===
    while t < t_max:
        if 0 <= ix < Nx and 0 <= iy < Ny:
            t_next = min(tx, ty, t_max)
            seg_len = t_next - t
            if seg_len > _EPSILON:
                # Midpoint-rule sample of the segment inside this voxel
                t_mid = t + seg_len * 0.5
                mid_x = src_x + t_mid * dir_x + cx
                mid_y = src_y + t_mid * dir_y + cy
                ix0, iy0 = int(math.floor(mid_x)), int(math.floor(mid_y))
                dx, dy = mid_x - ix0, mid_y - iy0

                # Clamp indices to stay in-bounds during interpolation
                ix0 = max(0, min(ix0, Nx - 2))
                iy0 = max(0, min(iy0, Ny - 2))
                # Upper neighbours, clamped separately: with Nx == 1 (or Ny == 1)
                # ix0 + 1 would index past the array; for N >= 2 this is ix0 + 1.
                ix1 = min(ix0 + 1, Nx - 1)
                iy1 = min(iy0 + 1, Ny - 1)

                one_minus_dx = 1.0 - dx
                one_minus_dy = 1.0 - dy
                v00 = d_image[iy0, ix0]
                v10 = d_image[iy0, ix1]
                v01 = d_image[iy1, ix0]
                v11 = d_image[iy1, ix1]
                row0 = (v00 * one_minus_dx + v10 * dx) * one_minus_dy
                row1 = (v01 * one_minus_dx + v11 * dx) * dy
                val = row0 + row1
                accum += val * seg_len

        # Step into the neighbouring voxel whose boundary is crossed first
        if tx <= ty:
            t = tx
            ix += step_x
            tx += dt_x
        else:
            t = ty
            iy += step_y
            ty += dt_y

    d_sino[iang, idet] = accum

@_FASTMATH_DECORATOR
def _fan_2d_backward_kernel(
    d_sino, n_ang, n_det,
    d_image, Nx, Ny,
    det_spacing, d_cos, d_sin,
    sdd, sid, cx, cy, voxel_spacing,
    det_offset, center_offset_x, center_offset_y,
    distance_weight
):
    """Compute the 2D fan beam backprojection.

    This CUDA kernel implements the Siddon ray-tracing method with interpolation for
    2D fan beam backprojection. Each thread of a 2D launch grid handles one
    ``(angle, detector)`` pair and scatters that sinogram value into the image
    with ``cuda.atomic.add``; the caller is expected to zero ``d_image`` first.

    Parameters
    ----------
    d_sino : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Input fan beam sinogram array on CUDA, indexed as ``d_sino[iang, idet]``.
    n_ang : int
        Number of projection angles.
    n_det : int
        Number of detector elements.
    d_image : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Output image gradient array on CUDA, indexed as ``d_image[iy, ix]``
        (accumulated into).
    Nx : int
        Number of voxels along the x-axis.
    Ny : int
        Number of voxels along the y-axis.
    det_spacing : float
        Physical spacing between detector elements.
    d_cos : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed cosine values of projection angles.
    d_sin : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed sine values of projection angles.
    sdd : float
        Source-to-Detector Distance (SDD), total distance from source to detector.
    sid : float
        Source-to-Isocenter Distance (SID), distance from source to isocenter.
    cx : float
        Half of image width in voxels.
    cy : float
        Half of image height in voxels.
    voxel_spacing : float
        Physical size of one voxel (in same units as det_spacing, sid, sdd).
    det_offset : float
        Detector shift along ``u``. It is added after the detector coordinate
        has been divided by ``voxel_spacing``, so the kernel treats it as
        voxel units.
    center_offset_x, center_offset_y : float
        Shift of the rotation centre, added directly to the voxel-unit source
        and detector coordinates.
    distance_weight : float
        When greater than 0.5, each contribution is multiplied by the fan-beam
        FBP distance weight ``(sid / L)**2``, where
        ``L = sid + x*sin(a) - y*cos(a)`` (voxel units, ``x, y`` relative to the
        rotation centre) is the source-to-sample distance along the central
        ray. Otherwise the kernel uses exactly the same traversal and weights
        as ``_fan_2d_forward_kernel``, i.e. it is that kernel's transpose.

    Notes
    -----
    As the adjoint to the fan beam forward projection, this operation
    distributes sinogram values back into the volume along divergent ray
    paths using atomic operations for thread-safe accumulation.

    The upper interpolation neighbour is clamped to the last row/column, so
    images with ``Nx == 1`` or ``Ny == 1`` are written in bounds; for
    ``Nx, Ny >= 2`` this is identical to using ``ix0 + 1`` / ``iy0 + 1``.
    """
    iang, idet = cuda.grid(2)
    if iang >= n_ang or idet >= n_det:
        return

    # === BACKPROJECTION VALUE AND GEOMETRY SETUP ===
    val   = d_sino[iang, idet]  # Sinogram value to backproject along this ray
    cos_a = d_cos[iang]         # Precomputed cosine of projection angle
    sin_a = d_sin[iang]         # Precomputed sine of projection angle
    # Normalize all physical distances to voxel units
    u = (idet - (n_det - 1) * 0.5) * det_spacing / voxel_spacing + det_offset
    sid_v = sid / voxel_spacing  # Source-to-isocenter distance in voxel units
    sdd_v = sdd / voxel_spacing  # Source-to-detector distance in voxel units

    # Source position: rotated by angle around the rotation centre at distance SID
    src_x = -sid_v * sin_a + center_offset_x
    src_y = sid_v * cos_a + center_offset_y

    # Detector element position: IDD = SDD - SID (Isocenter-to-Detector Distance)
    idd = sdd_v - sid_v
    det_x = idd * sin_a + u * cos_a + center_offset_x
    det_y = -idd * cos_a + u * sin_a + center_offset_y

    # === RAY DIRECTION CALCULATION ===
    dir_x, dir_y = det_x - src_x, det_y - src_y
    length = math.sqrt(dir_x * dir_x + dir_y * dir_y)
    if length < _EPSILON: return  # Skip degenerate rays
    inv_len = 1.0 / length        # t then measures distance in voxel units
    dir_x, dir_y = dir_x * inv_len, dir_y * inv_len

    # === RAY-VOLUME INTERSECTION (slab method, source as ray origin) ===
    t_min, t_max = -_INF, _INF
    if abs(dir_x) > _EPSILON:
        tx1, tx2 = (-cx - src_x) / dir_x, (cx - src_x) / dir_x
        t_min, t_max = max(t_min, min(tx1, tx2)), min(t_max, max(tx1, tx2))
    elif src_x < -cx or src_x > cx: return

    if abs(dir_y) > _EPSILON:
        ty1, ty2 = (-cy - src_y) / dir_y, (cy - src_y) / dir_y
        t_min, t_max = max(t_min, min(ty1, ty2)), min(t_max, max(ty1, ty2))
    elif src_y < -cy or src_y > cy: return

    if t_min >= t_max: return

    # === SIDDON TRAVERSAL INITIALIZATION ===
    t = t_min
    ix = int(math.floor(src_x + t * dir_x + cx))
    iy = int(math.floor(src_y + t * dir_y + cy))

    step_x, step_y = (1 if dir_x >= 0 else -1), (1 if dir_y >= 0 else -1)
    inv_dir_x = (1.0 / dir_x) if abs(dir_x) > _EPSILON else 0.0
    inv_dir_y = (1.0 / dir_y) if abs(dir_y) > _EPSILON else 0.0
    dt_x = abs(inv_dir_x) if abs(dir_x) > _EPSILON else _INF
    dt_y = abs(inv_dir_y) if abs(dir_y) > _EPSILON else _INF
    tx = ((ix + (step_x > 0)) - cx - src_x) * inv_dir_x if abs(dir_x) > _EPSILON else _INF
    ty = ((iy + (step_y > 0)) - cy - src_y) * inv_dir_y if abs(dir_y) > _EPSILON else _INF

    # === BACKPROJECTION TRAVERSAL LOOP (bilinear scatter) ===
    while t < t_max:
        if 0 <= ix < Nx and 0 <= iy < Ny:
            t_next = min(tx, ty, t_max)
            seg_len = t_next - t
            if seg_len > _EPSILON:
                # Sample at ray segment midpoint using source as ray origin
                t_mid = t + seg_len * 0.5
                mid_x = src_x + t_mid * dir_x + cx
                mid_y = src_y + t_mid * dir_y + cy
                ix0, iy0 = int(math.floor(mid_x)), int(math.floor(mid_y))
                dx, dy = mid_x - ix0, mid_y - iy0

                # Clamp indices to stay in-bounds during interpolation
                ix0 = max(0, min(ix0, Nx - 2))
                iy0 = max(0, min(iy0, Ny - 2))
                # Upper neighbours, clamped separately: with Nx == 1 (or Ny == 1)
                # ix0 + 1 would write past the array; for N >= 2 this is ix0 + 1.
                ix1 = min(ix0 + 1, Nx - 1)
                iy1 = min(iy0 + 1, Ny - 1)

                cval = val * seg_len  # Contribution value for this ray segment
                if distance_weight > 0.5:
                    # Fan-beam FBP distance weighting. The sample's depth along the
                    # central ray is measured from the source, which sits at
                    # distance SID from the rotation centre, so the depth is
                    # sid + x*sin - y*cos (SDD only sets a constant magnification).
                    # Numerator sid_v keeps the weight equal to 1 at the isocenter.
                    x_rel = (src_x + t_mid * dir_x) - center_offset_x
                    y_rel = (src_y + t_mid * dir_y) - center_offset_y
                    den = sid_v + x_rel * sin_a - y_rel * cos_a
                    if abs(den) <= _EPSILON:
                        cval = 0.0
                    else:
                        scale = sid_v / den
                        cval = cval * scale * scale

                # Atomic adds: several rays (threads) may hit the same pixel.
                one_minus_dx = 1.0 - dx
                one_minus_dy = 1.0 - dy
                cuda.atomic.add(d_image, (iy0, ix0), cval * one_minus_dx * one_minus_dy)
                cuda.atomic.add(d_image, (iy0, ix1), cval * dx           * one_minus_dy)
                cuda.atomic.add(d_image, (iy1, ix0), cval * one_minus_dx * dy)
                cuda.atomic.add(d_image, (iy1, ix1), cval * dx           * dy)

        # Step into the neighbouring voxel whose boundary is crossed first
        if tx <= ty:
            t = tx
            ix += step_x
            tx += dt_x
        else:
            t = ty
            iy += step_y
            ty += dt_y

# ------------------------------------------------------------------
# 3-D CONE BEAM KERNELS
# ------------------------------------------------------------------

@_FASTMATH_DECORATOR
def _cone_3d_forward_kernel(
    d_vol, Nx, Ny, Nz,
    d_sino, n_views, n_u, n_v,
    du, dv, d_cos, d_sin,
    sdd, sid, cx, cy, cz, voxel_spacing,
    det_offset_u, det_offset_v,
    center_offset_x, center_offset_y, center_offset_z
):
    """Compute the 3D cone-beam forward projection.

    This CUDA kernel implements the Siddon ray-tracing method with interpolation for
    3D cone-beam forward projection. Each thread of a 3D launch grid handles one
    ``(view, u, v)`` detector sample and writes exactly one sinogram entry
    (0.0 for degenerate rays and rays that miss the volume).

    Parameters
    ----------
    d_vol : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Input 3D volume array on CUDA, indexed as ``d_vol[ix, iy, iz]``.
    Nx : int
        Number of voxels along the x-axis.
    Ny : int
        Number of voxels along the y-axis.
    Nz : int
        Number of voxels along the z-axis.
    d_sino : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Output cone-beam sinogram array on CUDA, indexed as ``d_sino[iview, iu, iv]``.
    n_views : int
        Number of projection views.
    n_u : int
        Number of detector elements along the u-axis.
    n_v : int
        Number of detector elements along the v-axis.
    du : float
        Physical spacing between detector elements along the u-axis.
    dv : float
        Physical spacing between detector elements along the v-axis.
    d_cos : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed cosine values of projection angles.
    d_sin : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed sine values of projection angles.
    sdd : float
        Source-to-Detector Distance (SDD), total distance from source to detector.
    sid : float
        Source-to-Isocenter Distance (SID), distance from source to isocenter.
    cx : float
        Half of volume width along x-axis (in voxels).
    cy : float
        Half of volume height along y-axis (in voxels).
    cz : float
        Half of volume depth along z-axis (in voxels).
    voxel_spacing : float
        Physical size of one voxel (in same units as du, dv, sid, sdd).
    det_offset_u, det_offset_v : float
        Detector shifts along ``u`` / ``v``. They are added after division by
        ``voxel_spacing``, so the kernel treats them as voxel units.
    center_offset_x, center_offset_y, center_offset_z : float
        Shift of the rotation centre, added directly to the voxel-unit source
        and detector coordinates.

    Notes
    -----
    Cone-beam geometry extends the fan-beam configuration to 3D by employing
    a 2D detector array and trilinear interpolation for accurate volumetric
    sampling.

    Upper interpolation neighbours are clamped to the last index on each axis,
    so volumes with a dimension of size 1 (e.g. a single slice, ``Nz == 1``)
    are read in bounds; for sizes >= 2 this is identical to ``i0 + 1``.
    """
    iview, iu, iv = cuda.grid(3)
    if iview >= n_views or iu >= n_u or iv >= n_v:
        return

    # === 3D CONE BEAM GEOMETRY SETUP ===
    cos_a, sin_a = d_cos[iview], d_sin[iview]  # Projection angle trigonometry
    # Normalize all physical distances to voxel units
    u = (iu - (n_u - 1) * 0.5) * du / voxel_spacing + det_offset_u
    v = (iv - (n_v - 1) * 0.5) * dv / voxel_spacing + det_offset_v
    sid_v = sid / voxel_spacing  # Source-to-isocenter distance in voxel units
    sdd_v = sdd / voxel_spacing  # Source-to-detector distance in voxel units

    # Source rotates in the xy-plane around the rotation centre; its z is center_offset_z
    src_x = -sid_v * sin_a + center_offset_x
    src_y = sid_v * cos_a + center_offset_y
    src_z = center_offset_z

    # Detector element position: IDD = SDD - SID (Isocenter-to-Detector Distance)
    # u is the in-plane offset, v the vertical (z) offset
    idd = sdd_v - sid_v
    det_x = idd * sin_a + u * cos_a + center_offset_x
    det_y = -idd * cos_a + u * sin_a + center_offset_y
    det_z = v + center_offset_z

    # === 3D RAY DIRECTION CALCULATION ===
    dir_x, dir_y, dir_z = det_x - src_x, det_y - src_y, det_z - src_z
    length = math.sqrt(dir_x*dir_x + dir_y*dir_y + dir_z*dir_z)
    if length < _EPSILON:  # Degenerate ray case
        d_sino[iview, iu, iv] = 0.0; return

    # Normalize so the ray parameter t measures distance in voxel units
    inv_len = 1.0 / length
    dir_x, dir_y, dir_z = dir_x*inv_len, dir_y*inv_len, dir_z*inv_len

    # === 3D RAY-VOLUME INTERSECTION (slab method, source as ray origin) ===
    t_min, t_max = -_INF, _INF

    if abs(dir_x) > _EPSILON:
        tx1, tx2 = (-cx - src_x) / dir_x, (cx - src_x) / dir_x
        t_min, t_max = max(t_min, min(tx1, tx2)), min(t_max, max(tx1, tx2))
    elif src_x < -cx or src_x > cx:  # Parallel to x-slab and outside it
        d_sino[iview, iu, iv] = 0.0; return

    if abs(dir_y) > _EPSILON:
        ty1, ty2 = (-cy - src_y) / dir_y, (cy - src_y) / dir_y
        t_min, t_max = max(t_min, min(ty1, ty2)), min(t_max, max(ty1, ty2))
    elif src_y < -cy or src_y > cy:  # Parallel to y-slab and outside it
        d_sino[iview, iu, iv] = 0.0; return

    if abs(dir_z) > _EPSILON:
        tz1, tz2 = (-cz - src_z) / dir_z, (cz - src_z) / dir_z
        t_min, t_max = max(t_min, min(tz1, tz2)), min(t_max, max(tz1, tz2))
    elif src_z < -cz or src_z > cz:  # Parallel to z-slab and outside it
        d_sino[iview, iu, iv] = 0.0; return

    if t_min >= t_max:  # Ray misses the volume
        d_sino[iview, iu, iv] = 0.0; return

    # === 3D SIDDON TRAVERSAL INITIALIZATION ===
    accum = 0.0  # Accumulated line integral
    t = t_min    # Current ray parameter

    # Voxel containing the entry point (may be one past the edge; handled below)
    ix = int(math.floor(src_x + t * dir_x + cx))
    iy = int(math.floor(src_y + t * dir_y + cy))
    iz = int(math.floor(src_z + t * dir_z + cz))

    step_x, step_y, step_z = (1 if dir_x >= 0 else -1), (1 if dir_y >= 0 else -1), (1 if dir_z >= 0 else -1)
    inv_dir_x = (1.0 / dir_x) if abs(dir_x) > _EPSILON else 0.0
    inv_dir_y = (1.0 / dir_y) if abs(dir_y) > _EPSILON else 0.0
    inv_dir_z = (1.0 / dir_z) if abs(dir_z) > _EPSILON else 0.0
    dt_x = abs(inv_dir_x) if abs(dir_x) > _EPSILON else _INF  # t increment per x-voxel
    dt_y = abs(inv_dir_y) if abs(dir_y) > _EPSILON else _INF  # t increment per y-voxel
    dt_z = abs(inv_dir_z) if abs(dir_z) > _EPSILON else _INF  # t increment per z-voxel

    # Ray parameter of the next x / y / z voxel boundary
    tx = ((ix + (step_x > 0)) - cx - src_x) * inv_dir_x if abs(dir_x) > _EPSILON else _INF
    ty = ((iy + (step_y > 0)) - cy - src_y) * inv_dir_y if abs(dir_y) > _EPSILON else _INF
    tz = ((iz + (step_z > 0)) - cz - src_z) * inv_dir_z if abs(dir_z) > _EPSILON else _INF

    # === 3D TRAVERSAL LOOP WITH TRILINEAR INTERPOLATION ===
    while t < t_max:
        if 0 <= ix < Nx and 0 <= iy < Ny and 0 <= iz < Nz:
            t_next = min(tx, ty, tz, t_max)
            seg_len = t_next - t
            if seg_len > _EPSILON:
                # Midpoint-rule sample of the segment inside this voxel
                t_mid = t + seg_len * 0.5
                mid_x = src_x + t_mid * dir_x + cx
                mid_y = src_y + t_mid * dir_y + cy
                mid_z = src_z + t_mid * dir_z + cz

                ix0, iy0, iz0 = int(math.floor(mid_x)), int(math.floor(mid_y)), int(math.floor(mid_z))
                dx, dy, dz = mid_x - ix0, mid_y - iy0, mid_z - iz0

                # Clamp indices to stay in-bounds during interpolation
                ix0 = max(0, min(ix0, Nx - 2))
                iy0 = max(0, min(iy0, Ny - 2))
                iz0 = max(0, min(iz0, Nz - 2))
                # Upper neighbours, clamped separately: with a dimension of size 1
                # i0 + 1 would index past the array; for sizes >= 2 this is i0 + 1.
                ix1 = min(ix0 + 1, Nx - 1)
                iy1 = min(iy0 + 1, Ny - 1)
                iz1 = min(iz0 + 1, Nz - 1)

                omdx = 1.0 - dx
                omdy = 1.0 - dy
                omdz = 1.0 - dz

                val = (
                    d_vol[ix0, iy0, iz0] * omdx*omdy*omdz +
                    d_vol[ix1, iy0, iz0] * dx  *omdy*omdz +
                    d_vol[ix0, iy1, iz0] * omdx*dy  *omdz +
                    d_vol[ix0, iy0, iz1] * omdx*omdy*dz   +
                    d_vol[ix1, iy1, iz0] * dx  *dy  *omdz +
                    d_vol[ix1, iy0, iz1] * dx  *omdy*dz   +
                    d_vol[ix0, iy1, iz1] * omdx*dy  *dz   +
                    d_vol[ix1, iy1, iz1] * dx  *dy  *dz
                )
                # Discrete line integral: sampled value times segment length
                accum += val * seg_len

        # Step into the neighbouring voxel whose boundary is crossed first
        if tx <= ty and tx <= tz:
            t = tx
            ix += step_x
            tx += dt_x
        elif ty <= tx and ty <= tz:
            t = ty
            iy += step_y
            ty += dt_y
        else:
            t = tz
            iz += step_z
            tz += dt_z

    d_sino[iview, iu, iv] = accum

@_FASTMATH_DECORATOR
def _cone_3d_backward_kernel(
    d_sino, n_views, n_u, n_v,
    d_vol, Nx, Ny, Nz,
    du, dv, d_cos, d_sin,
    sdd, sid, cx, cy, cz, voxel_spacing,
    det_offset_u, det_offset_v,
    center_offset_x, center_offset_y, center_offset_z,
    distance_weight
):
    """Compute the 3D cone-beam backprojection.

    This CUDA kernel implements the Siddon ray-tracing method with interpolation for
    3D cone-beam backprojection. Each thread of a 3D launch grid handles one
    ``(view, u, v)`` detector sample and scatters it into the volume with
    ``cuda.atomic.add``; the caller is expected to zero ``d_vol`` first.

    Parameters
    ----------
    d_sino : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Input cone-beam sinogram array on CUDA, indexed as ``d_sino[iview, iu, iv]``.
    n_views : int
        Number of projection views.
    n_u : int
        Number of detector elements along the u-axis.
    n_v : int
        Number of detector elements along the v-axis.
    d_vol : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Output 3D volume gradient array on CUDA, indexed as ``d_vol[ix, iy, iz]``
        (accumulated into).
    Nx : int
        Number of voxels along the x-axis.
    Ny : int
        Number of voxels along the y-axis.
    Nz : int
        Number of voxels along the z-axis.
    du : float
        Physical spacing between detector elements along the u-axis.
    dv : float
        Physical spacing between detector elements along the v-axis.
    d_cos : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed cosine values of projection angles.
    d_sin : numba.cuda.cudadrv.devicearray.DeviceNDArray
        Precomputed sine values of projection angles.
    sdd : float
        Source-to-Detector Distance (SDD), total distance from source to detector.
    sid : float
        Source-to-Isocenter Distance (SID), distance from source to isocenter.
    cx : float
        Half of volume width along x-axis (in voxels).
    cy : float
        Half of volume height along y-axis (in voxels).
    cz : float
        Half of volume depth along z-axis (in voxels).
    voxel_spacing : float
        Physical size of one voxel (in same units as du, dv, sid, sdd).
    det_offset_u, det_offset_v : float
        Detector shifts along ``u`` / ``v``. They are added after division by
        ``voxel_spacing``, so the kernel treats them as voxel units.
    center_offset_x, center_offset_y, center_offset_z : float
        Shift of the rotation centre, added directly to the voxel-unit source
        and detector coordinates.
    distance_weight : float
        When greater than 0.5, each contribution is multiplied by the FDK
        distance weight ``(sid / L)**2``, where
        ``L = sid + x*sin(a) - y*cos(a)`` (voxel units, ``x, y`` relative to the
        rotation centre) is the source-to-sample distance along the central
        ray. Otherwise the kernel uses exactly the same traversal and weights
        as ``_cone_3d_forward_kernel``, i.e. it is that kernel's transpose.

    Notes
    -----
    As the adjoint to the cone-beam forward projection, this operation
    distributes sinogram values back into the 3D volume along ray paths using
    atomic operations for thread-safe accumulation.

    Upper interpolation neighbours are clamped to the last index on each axis,
    so volumes with a dimension of size 1 (e.g. a single slice, ``Nz == 1``)
    are written in bounds; for sizes >= 2 this is identical to ``i0 + 1``.
    """
    iview, iu, iv = cuda.grid(3)
    if iview >= n_views or iu >= n_u or iv >= n_v:
        return

    # === 3D BACKPROJECTION VALUE AND GEOMETRY SETUP ===
    g = d_sino[iview, iu, iv]  # Sinogram value to backproject along this ray
    cos_a, sin_a = d_cos[iview], d_sin[iview]  # Projection angle trigonometry
    # Normalize all physical distances to voxel units
    u = (iu - (n_u - 1) * 0.5) * du / voxel_spacing + det_offset_u
    v = (iv - (n_v - 1) * 0.5) * dv / voxel_spacing + det_offset_v
    sid_v = sid / voxel_spacing  # Source-to-isocenter distance in voxel units
    sdd_v = sdd / voxel_spacing  # Source-to-detector distance in voxel units

    # Source rotates in the xy-plane around the rotation centre; its z is center_offset_z
    src_x = -sid_v * sin_a + center_offset_x
    src_y = sid_v * cos_a + center_offset_y
    src_z = center_offset_z

    # Detector element position: IDD = SDD - SID (Isocenter-to-Detector Distance)
    # u is the in-plane offset, v the vertical (z) offset
    idd = sdd_v - sid_v
    det_x = idd * sin_a + u * cos_a + center_offset_x
    det_y = -idd * cos_a + u * sin_a + center_offset_y
    det_z = v + center_offset_z

    # === 3D RAY DIRECTION CALCULATION ===
    dir_x, dir_y, dir_z = det_x - src_x, det_y - src_y, det_z - src_z
    length = math.sqrt(dir_x*dir_x + dir_y*dir_y + dir_z*dir_z)
    if length < _EPSILON: return  # Skip degenerate rays
    inv_len = 1.0 / length        # t then measures distance in voxel units
    dir_x, dir_y, dir_z = dir_x*inv_len, dir_y*inv_len, dir_z*inv_len

    # === 3D RAY-VOLUME INTERSECTION (slab method, source as ray origin) ===
    t_min, t_max = -_INF, _INF

    if abs(dir_x) > _EPSILON:
        tx1, tx2 = (-cx - src_x) / dir_x, (cx - src_x) / dir_x
        t_min, t_max = max(t_min, min(tx1, tx2)), min(t_max, max(tx1, tx2))
    elif src_x < -cx or src_x > cx: return

    if abs(dir_y) > _EPSILON:
        ty1, ty2 = (-cy - src_y) / dir_y, (cy - src_y) / dir_y
        t_min, t_max = max(t_min, min(ty1, ty2)), min(t_max, max(ty1, ty2))
    elif src_y < -cy or src_y > cy: return

    if abs(dir_z) > _EPSILON:
        tz1, tz2 = (-cz - src_z) / dir_z, (cz - src_z) / dir_z
        t_min, t_max = max(t_min, min(tz1, tz2)), min(t_max, max(tz1, tz2))
    elif src_z < -cz or src_z > cz: return

    if t_min >= t_max: return

    # === 3D SIDDON TRAVERSAL INITIALIZATION ===
    t = t_min
    ix = int(math.floor(src_x + t * dir_x + cx))
    iy = int(math.floor(src_y + t * dir_y + cy))
    iz = int(math.floor(src_z + t * dir_z + cz))

    step_x, step_y, step_z = (1 if dir_x >= 0 else -1), (1 if dir_y >= 0 else -1), (1 if dir_z >= 0 else -1)
    inv_dir_x = (1.0 / dir_x) if abs(dir_x) > _EPSILON else 0.0
    inv_dir_y = (1.0 / dir_y) if abs(dir_y) > _EPSILON else 0.0
    inv_dir_z = (1.0 / dir_z) if abs(dir_z) > _EPSILON else 0.0
    dt_x = abs(inv_dir_x) if abs(dir_x) > _EPSILON else _INF  # t increment per x-voxel
    dt_y = abs(inv_dir_y) if abs(dir_y) > _EPSILON else _INF  # t increment per y-voxel
    dt_z = abs(inv_dir_z) if abs(dir_z) > _EPSILON else _INF  # t increment per z-voxel

    # Ray parameter of the next x / y / z voxel boundary
    tx = ((ix + (step_x > 0)) - cx - src_x) * inv_dir_x if abs(dir_x) > _EPSILON else _INF
    ty = ((iy + (step_y > 0)) - cy - src_y) * inv_dir_y if abs(dir_y) > _EPSILON else _INF
    tz = ((iz + (step_z > 0)) - cz - src_z) * inv_dir_z if abs(dir_z) > _EPSILON else _INF

    # === 3D BACKPROJECTION TRAVERSAL LOOP (trilinear scatter) ===
    while t < t_max:
        if 0 <= ix < Nx and 0 <= iy < Ny and 0 <= iz < Nz:
            t_next = min(tx, ty, tz, t_max)
            seg_len = t_next - t
            if seg_len > _EPSILON:
                # Sample at ray segment midpoint using source as ray origin
                t_mid = t + seg_len * 0.5
                mid_x = src_x + t_mid * dir_x + cx
                mid_y = src_y + t_mid * dir_y + cy
                mid_z = src_z + t_mid * dir_z + cz

                ix0, iy0, iz0 = int(math.floor(mid_x)), int(math.floor(mid_y)), int(math.floor(mid_z))
                dx, dy, dz = mid_x - ix0, mid_y - iy0, mid_z - iz0

                # Clamp indices to stay in-bounds during interpolation
                ix0 = max(0, min(ix0, Nx - 2))
                iy0 = max(0, min(iy0, Ny - 2))
                iz0 = max(0, min(iz0, Nz - 2))
                # Upper neighbours, clamped separately: with a dimension of size 1
                # i0 + 1 would write past the array; for sizes >= 2 this is i0 + 1.
                ix1 = min(ix0 + 1, Nx - 1)
                iy1 = min(iy0 + 1, Ny - 1)
                iz1 = min(iz0 + 1, Nz - 1)

                omdx = 1.0 - dx
                omdy = 1.0 - dy
                omdz = 1.0 - dz
                cval = g * seg_len
                if distance_weight > 0.5:
                    # FDK distance weighting. The sample's depth along the central
                    # ray is measured from the source, which sits at distance SID
                    # from the rotation centre, so the depth is sid + x*sin - y*cos
                    # (SDD only sets a constant magnification). Numerator sid_v
                    # keeps the weight equal to 1 at the isocenter.
                    x_rel = (src_x + t_mid * dir_x) - center_offset_x
                    y_rel = (src_y + t_mid * dir_y) - center_offset_y
                    den = sid_v + x_rel * sin_a - y_rel * cos_a
                    if abs(den) <= _EPSILON:
                        cval = 0.0
                    else:
                        scale = sid_v / den
                        cval = cval * scale * scale

                # Atomic adds: several rays (threads) may hit the same voxel.
                cuda.atomic.add(d_vol, (ix0, iy0, iz0), cval * omdx*omdy*omdz)
                cuda.atomic.add(d_vol, (ix1, iy0, iz0), cval * dx  *omdy*omdz)
                cuda.atomic.add(d_vol, (ix0, iy1, iz0), cval * omdx*dy  *omdz)
                cuda.atomic.add(d_vol, (ix0, iy0, iz1), cval * omdx*omdy*dz)
                cuda.atomic.add(d_vol, (ix1, iy1, iz0), cval * dx  *dy  *omdz)
                cuda.atomic.add(d_vol, (ix1, iy0, iz1), cval * dx  *omdy*dz)
                cuda.atomic.add(d_vol, (ix0, iy1, iz1), cval * omdx*dy  *dz)
                cuda.atomic.add(d_vol, (ix1, iy1, iz1), cval * dx  *dy  *dz)

        # Step into the neighbouring voxel whose boundary is crossed first
        if tx <= ty and tx <= tz:
            t = tx
            ix += step_x
            tx += dt_x
        elif ty <= tx and ty <= tz:
            t = ty
            iy += step_y
            ty += dt_y
        else:
            t = tz
            iz += step_z
            tz += dt_z


# ############################################################################
# DIFFERENTIABLE TORCH FUNCTIONS
# ############################################################################

class ParallelProjectorFunction(torch.autograd.Function):
    """
    Summary
    -------
    Differentiable 2D parallel beam forward projector (PyTorch autograd function).

    Notes
    -----
    Bridges PyTorch autograd and the CUDA Siddon ray-tracing kernels (with
    bilinear interpolation) for parallel beam CT geometry.

    - ``forward`` projects a 2D image into a sinogram with
      ``_parallel_2d_forward_kernel``.
    - ``backward`` produces the image gradient by running
      ``_parallel_2d_backward_kernel`` on the incoming sinogram gradient.

    A CUDA-capable GPU and a working CUDA environment are required. Every
    tensor argument must live on the same CUDA device.

    Examples
    --------
    >>> import torch
    >>> from diffct.differentiable import ParallelProjectorFunction
    >>>
    >>> # Create a 2D image with gradient tracking
    >>> image = torch.randn(128, 128, device='cuda', requires_grad=True)
    >>> # Define projection parameters
    >>> angles = torch.linspace(0, torch.pi, 180, device='cuda')
    >>> num_detectors = 128
    >>> detector_spacing = 1.0
    >>> # Compute forward projection
    >>> projector = ParallelProjectorFunction.apply
    >>> sinogram = projector(image, angles, num_detectors, detector_spacing)
    >>> # Compute loss and gradients
    >>> loss = sinogram.sum()
    >>> loss.backward()
    >>> print(f"Gradient shape: {image.grad.shape}")  # (128, 128)
    """

    @staticmethod
    def forward(
        ctx,
        image,
        angles,
        num_detectors,
        detector_spacing=1.0,
        voxel_spacing=1.0,
        detector_offset=0.0,
        center_offset_x=0.0,
        center_offset_y=0.0,
    ):
        """Project a 2D image into a parallel beam sinogram on the GPU.

        Parameters
        ----------
        ctx : torch.autograd.function.FunctionCtx
            Autograd context. Keeps ``angles`` and the scalar geometry for
            ``backward``.
        image : torch.Tensor
            2D image of shape (H, W) on a CUDA device. It is converted to
            contiguous float32 before the kernel launch.
        angles : torch.Tensor
            1D tensor of projection angles in radians, shape (num_angles,),
            on the same CUDA device as ``image``.
        num_detectors : int
            Number of detector bins (sinogram columns).
        detector_spacing : float, optional
            Physical spacing between detector bins (default: 1.0).
        voxel_spacing : float, optional
            Physical size of one pixel, in the same units as
            ``detector_spacing`` (default: 1.0).
        detector_offset : float, optional
            Detector offset in physical units (default: 0.0). It is divided
            by ``voxel_spacing`` before being handed to the kernel.
        center_offset_x, center_offset_y : float, optional
            Center offsets along x and y in physical units (default: 0.0).
            They are also divided by ``voxel_spacing``. Presumably they shift
            the rotation center relative to the image center; confirm this
            against the kernel.

        Returns
        -------
        torch.Tensor
            float32 sinogram of shape (num_angles, num_detectors) on the
            device of ``image``.

        Examples
        --------
        >>> image = torch.randn(128, 128, device='cuda', requires_grad=True)
        >>> angles = torch.linspace(0, torch.pi, 180, device='cuda')
        >>> sinogram = ParallelProjectorFunction.apply(image, angles, 128, 1.0)
        """
        device = DeviceManager.get_device(image)
        # The kernel consumes raw float32 buffers, so pin dtype and memory layout.
        image = DeviceManager.ensure_device(image, device).to(dtype=torch.float32).contiguous()
        angles = DeviceManager.ensure_device(angles, device).to(dtype=torch.float32).contiguous()

        Ny, Nx = image.shape
        n_angles = angles.shape[0]

        sinogram = torch.zeros((n_angles, num_detectors), dtype=image.dtype, device=device)
        cos_table, sin_table = _trig_tables(angles, dtype=image.dtype, device=device)

        # The kernel expects offsets expressed in voxel units.
        det_off, cen_x, cen_y = (
            _DTYPE(length / voxel_spacing)
            for length in (detector_offset, center_offset_x, center_offset_y)
        )
        grid, tpb = _grid_2d(n_angles, num_detectors)
        # Launch on the Numba wrapper of PyTorch's current CUDA stream.
        launch_stream = _get_numba_external_stream_for(torch.cuda.current_stream())

        _parallel_2d_forward_kernel[grid, tpb, launch_stream](
            TorchCUDABridge.tensor_to_cuda_array(image), Nx, Ny,
            TorchCUDABridge.tensor_to_cuda_array(sinogram), n_angles, num_detectors,
            _DTYPE(detector_spacing),
            TorchCUDABridge.tensor_to_cuda_array(cos_table),
            TorchCUDABridge.tensor_to_cuda_array(sin_table),
            _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5),
            _DTYPE(voxel_spacing),
            det_off, cen_x, cen_y,
        )

        ctx.save_for_backward(angles)
        ctx.intermediate = (
            num_detectors,
            detector_spacing,
            Ny,
            Nx,
            voxel_spacing,
            detector_offset,
            center_offset_x,
            center_offset_y,
        )
        return sinogram

    @staticmethod
    def backward(ctx, grad_sinogram):
        """Backproject the sinogram gradient onto the image grid.

        Returns the gradient with respect to ``image``, followed by ``None``
        for each of the seven remaining (non-differentiable) inputs.
        """
        (angles,) = ctx.saved_tensors
        (
            num_detectors,
            detector_spacing,
            Ny,
            Nx,
            voxel_spacing,
            detector_offset,
            center_offset_x,
            center_offset_y,
        ) = ctx.intermediate

        device = DeviceManager.get_device(grad_sinogram)
        grad_sinogram = DeviceManager.ensure_device(grad_sinogram, device).to(dtype=torch.float32).contiguous()
        angles = DeviceManager.ensure_device(angles, device).to(dtype=torch.float32).contiguous()
        n_angles = angles.shape[0]

        grad_image = torch.zeros((Ny, Nx), dtype=grad_sinogram.dtype, device=device)
        cos_table, sin_table = _trig_tables(angles, dtype=grad_sinogram.dtype, device=device)

        det_off, cen_x, cen_y = (
            _DTYPE(length / voxel_spacing)
            for length in (detector_offset, center_offset_x, center_offset_y)
        )
        grid, tpb = _grid_2d(n_angles, num_detectors)
        launch_stream = _get_numba_external_stream_for(torch.cuda.current_stream())

        _parallel_2d_backward_kernel[grid, tpb, launch_stream](
            TorchCUDABridge.tensor_to_cuda_array(grad_sinogram), n_angles, num_detectors,
            TorchCUDABridge.tensor_to_cuda_array(grad_image), Nx, Ny,
            _DTYPE(detector_spacing),
            TorchCUDABridge.tensor_to_cuda_array(cos_table),
            TorchCUDABridge.tensor_to_cuda_array(sin_table),
            _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5),
            _DTYPE(voxel_spacing),
            det_off, cen_x, cen_y,
        )
        return (grad_image,) + (None,) * 7
class ParallelBackprojectorFunction(torch.autograd.Function):
    """
    Summary
    -------
    PyTorch autograd function for differentiable 2D parallel beam backprojection.

    Notes
    -----
    Provides a differentiable interface to the CUDA-accelerated Siddon
    ray-tracing method (with interpolation) for parallel beam backprojection.

    - The forward pass builds a 2D image from sinogram data with
      ``_parallel_2d_backward_kernel``.
    - The backward pass forward-projects the image gradient with
      ``_parallel_2d_forward_kernel``.

    Requires CUDA-capable hardware and consistent device placements. The
    sinogram must have exactly one row per entry of ``angles``; otherwise
    ``ValueError`` is raised.

    Examples
    --------
    >>> import torch
    >>> from diffct.differentiable import ParallelBackprojectorFunction
    >>>
    >>> sinogram = torch.randn(180, 128, device='cuda', requires_grad=True)
    >>> angles = torch.linspace(0, torch.pi, 180, device='cuda')
    >>> recon = ParallelBackprojectorFunction.apply(sinogram, angles, 1.0, 128, 128)
    >>> loss = recon.sum()
    >>> loss.backward()
    >>> print(sinogram.grad.shape)  # (180, 128)
    """

    @staticmethod
    def forward(
        ctx,
        sinogram,
        angles,
        detector_spacing=1.0,
        H=128,
        W=128,
        voxel_spacing=1.0,
        detector_offset=0.0,
        center_offset_x=0.0,
        center_offset_y=0.0,
    ):
        """Compute the 2D parallel beam backprojection of a sinogram on the GPU.

        Parameters
        ----------
        ctx : torch.autograd.function.FunctionCtx
            Autograd context. Keeps ``angles`` and the scalar geometry for
            ``backward``.
        sinogram : torch.Tensor
            2D sinogram of shape (num_angles, num_detectors) on a CUDA
            device. It is converted to contiguous float32.
        angles : torch.Tensor
            1D tensor of projection angles in radians, shape (num_angles,),
            on the same CUDA device as ``sinogram``.
        detector_spacing : float, optional
            Physical spacing between detector elements (default: 1.0).
        H : int, optional
            Height of the output image (default: 128).
        W : int, optional
            Width of the output image (default: 128).
        voxel_spacing : float, optional
            Physical size of one pixel, in the same units as
            ``detector_spacing`` (default: 1.0).
        detector_offset : float, optional
            Detector offset in physical units (default: 0.0). It is divided
            by ``voxel_spacing`` before being handed to the kernel.
        center_offset_x, center_offset_y : float, optional
            Center offsets along x and y in physical units (default: 0.0).
            They are also divided by ``voxel_spacing``.

        Returns
        -------
        torch.Tensor
            float32 image of shape (H, W) on the device of ``sinogram``.

        Raises
        ------
        ValueError
            If ``sinogram`` is not 2D, ``angles`` is not 1D, or the number
            of angles differs from the number of sinogram rows.

        Examples
        --------
        >>> sinogram = torch.randn(180, 128, device='cuda', requires_grad=True)
        >>> angles = torch.linspace(0, torch.pi, 180, device='cuda')
        >>> reco = ParallelBackprojectorFunction.apply(sinogram, angles, 1.0, 128, 128)
        """
        # The launch grid and the kernel's angle count come from the sinogram,
        # while the cos/sin tables are built from ``angles``. Reject a mismatch
        # up front. Otherwise the kernel would index trig entries that do not
        # exist, or silently ignore surplus angles.
        if sinogram.dim() != 2:
            raise ValueError(
                "sinogram must be 2D (num_angles, num_detectors), "
                f"got shape {tuple(sinogram.shape)}"
            )
        if angles.dim() != 1 or angles.shape[0] != sinogram.shape[0]:
            raise ValueError(
                "angles must be 1D with one entry per sinogram row; got angles "
                f"shape {tuple(angles.shape)} for sinogram shape {tuple(sinogram.shape)}"
            )

        device = DeviceManager.get_device(sinogram)
        sinogram = DeviceManager.ensure_device(sinogram, device)
        angles = DeviceManager.ensure_device(angles, device)
        # Ensure input is float32 for kernel compatibility
        sinogram = sinogram.to(dtype=torch.float32).contiguous()
        angles = angles.to(dtype=torch.float32).contiguous()

        n_ang, n_det = sinogram.shape
        Ny, Nx = H, W

        reco = torch.zeros((Ny, Nx), dtype=sinogram.dtype, device=device)
        d_cos, d_sin = _trig_tables(angles, dtype=sinogram.dtype, device=device)

        # Numba CUDA array views over the torch buffers.
        d_sino = TorchCUDABridge.tensor_to_cuda_array(sinogram)
        d_reco = TorchCUDABridge.tensor_to_cuda_array(reco)
        d_cos_arr = TorchCUDABridge.tensor_to_cuda_array(d_cos)
        d_sin_arr = TorchCUDABridge.tensor_to_cuda_array(d_sin)

        grid, tpb = _grid_2d(n_ang, n_det)
        cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
        # Launch on the Numba wrapper of PyTorch's current CUDA stream.
        pt_stream = torch.cuda.current_stream()
        numba_stream = _get_numba_external_stream_for(pt_stream)

        # The kernel expects offsets in voxel units.
        det_offset_v = _DTYPE(detector_offset / voxel_spacing)
        center_offset_x_v = _DTYPE(center_offset_x / voxel_spacing)
        center_offset_y_v = _DTYPE(center_offset_y / voxel_spacing)

        _parallel_2d_backward_kernel[grid, tpb, numba_stream](
            d_sino, n_ang, n_det, d_reco, Nx, Ny, _DTYPE(detector_spacing),
            d_cos_arr, d_sin_arr, cx, cy, _DTYPE(voxel_spacing),
            det_offset_v, center_offset_x_v, center_offset_y_v,
        )

        ctx.save_for_backward(angles)
        ctx.intermediate = (
            H,
            W,
            detector_spacing,
            sinogram.shape[0],
            sinogram.shape[1],
            voxel_spacing,
            detector_offset,
            center_offset_x,
            center_offset_y,
        )
        return reco

    @staticmethod
    def backward(ctx, grad_output):
        """Forward-project the image gradient to obtain the sinogram gradient.

        Returns the gradient with respect to ``sinogram``, followed by
        ``None`` for each of the eight remaining inputs.
        """
        (angles,) = ctx.saved_tensors
        (
            H,
            W,
            detector_spacing,
            n_ang,
            n_det,
            voxel_spacing,
            detector_offset,
            center_offset_x,
            center_offset_y,
        ) = ctx.intermediate

        device = DeviceManager.get_device(grad_output)
        grad_output = DeviceManager.ensure_device(grad_output, device)
        angles = DeviceManager.ensure_device(angles, device)
        grad_output = grad_output.to(dtype=torch.float32).contiguous()
        angles = angles.to(dtype=torch.float32).contiguous()
        Ny, Nx = grad_output.shape

        grad_sino = torch.zeros((n_ang, n_det), dtype=grad_output.dtype, device=device)
        d_cos, d_sin = _trig_tables(angles, dtype=grad_output.dtype, device=device)

        # Keep the torch tensors and their Numba views under distinct names
        # (consistent with the other autograd functions).
        d_grad_out = TorchCUDABridge.tensor_to_cuda_array(grad_output)
        d_sino_grad = TorchCUDABridge.tensor_to_cuda_array(grad_sino)
        d_cos_arr = TorchCUDABridge.tensor_to_cuda_array(d_cos)
        d_sin_arr = TorchCUDABridge.tensor_to_cuda_array(d_sin)

        grid, tpb = _grid_2d(n_ang, n_det)
        cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
        pt_stream = torch.cuda.current_stream()
        numba_stream = _get_numba_external_stream_for(pt_stream)

        det_offset_v = _DTYPE(detector_offset / voxel_spacing)
        center_offset_x_v = _DTYPE(center_offset_x / voxel_spacing)
        center_offset_y_v = _DTYPE(center_offset_y / voxel_spacing)

        _parallel_2d_forward_kernel[grid, tpb, numba_stream](
            d_grad_out, Nx, Ny, d_sino_grad, n_ang, n_det, _DTYPE(detector_spacing),
            d_cos_arr, d_sin_arr, cx, cy, _DTYPE(voxel_spacing),
            det_offset_v, center_offset_x_v, center_offset_y_v,
        )
        return grad_sino, None, None, None, None, None, None, None, None
class FanProjectorFunction(torch.autograd.Function):
    """
    Summary
    -------
    Differentiable 2D fan beam forward projector (PyTorch autograd function).

    Notes
    -----
    Wraps the CUDA Siddon ray-tracing kernels (with interpolation) for fan
    beam geometry, where rays diverge from a point X-ray source toward a
    linear detector array.

    - ``forward`` produces the fan beam sinogram with
      ``_fan_2d_forward_kernel``.
    - ``backward`` maps the sinogram gradient onto the image with
      ``_fan_2d_backward_kernel``, called with a trailing weight argument of
      1.0. Per the original documentation this enables geometric ``1/U^2``
      distance weighting.

    Examples
    --------
    >>> import torch
    >>> from diffct.differentiable import FanProjectorFunction
    >>>
    >>> image = torch.randn(256, 256, device='cuda', requires_grad=True)
    >>> angles = torch.linspace(0, 2 * torch.pi, 360, device='cuda')
    >>> sinogram = FanProjectorFunction.apply(image, angles, 512, 1.0, 1500.0, 1000.0)
    >>> loss = sinogram.sum()
    >>> loss.backward()
    >>> print(image.grad.shape)  # (256, 256)
    """

    @staticmethod
    def forward(
        ctx,
        image,
        angles,
        num_detectors,
        detector_spacing,
        sdd,
        sid,
        voxel_spacing=1.0,
        detector_offset=0.0,
        center_offset_x=0.0,
        center_offset_y=0.0,
    ):
        """Project a 2D image into a fan beam sinogram on the GPU.

        Parameters
        ----------
        ctx : torch.autograd.function.FunctionCtx
            Autograd context. Keeps ``angles`` and the scalar geometry for
            ``backward``.
        image : torch.Tensor
            2D image of shape (H, W) on a CUDA device. It is converted to
            contiguous float32.
        angles : torch.Tensor
            1D tensor of projection angles in radians, shape (num_angles,),
            on the same CUDA device as ``image``.
        num_detectors : int
            Number of detector elements (sinogram columns).
        detector_spacing : float
            Physical spacing between detector elements.
        sdd : float
            Source-to-Detector Distance: source to detector, through the
            isocenter.
        sid : float
            Source-to-Isocenter Distance: source to the center of rotation.
        voxel_spacing : float, optional
            Physical size of one pixel, in the units of ``detector_spacing``,
            ``sdd`` and ``sid`` (default: 1.0).
        detector_offset : float, optional
            Detector offset in physical units (default: 0.0). It is divided
            by ``voxel_spacing`` before reaching the kernel.
        center_offset_x, center_offset_y : float, optional
            Center offsets along x and y in physical units (default: 0.0).
            They are also divided by ``voxel_spacing``.

        Returns
        -------
        torch.Tensor
            float32 sinogram of shape (num_angles, num_detectors) on the
            device of ``image``.

        Examples
        --------
        >>> image = torch.randn(256, 256, device='cuda', requires_grad=True)
        >>> angles = torch.linspace(0, 2 * torch.pi, 360, device='cuda')
        >>> sinogram = FanProjectorFunction.apply(image, angles, 512, 1.0, 1500.0, 1000.0)
        """
        device = DeviceManager.get_device(image)
        # The kernel consumes raw float32 buffers, so pin dtype and memory layout.
        image = DeviceManager.ensure_device(image, device).to(dtype=torch.float32).contiguous()
        angles = DeviceManager.ensure_device(angles, device).to(dtype=torch.float32).contiguous()

        Ny, Nx = image.shape
        n_ang = angles.shape[0]

        sinogram = torch.zeros((n_ang, num_detectors), dtype=image.dtype, device=device)
        cos_table, sin_table = _trig_tables(angles, dtype=image.dtype, device=device)

        # The kernel expects offsets expressed in voxel units.
        det_off, cen_x, cen_y = (
            _DTYPE(length / voxel_spacing)
            for length in (detector_offset, center_offset_x, center_offset_y)
        )
        grid, tpb = _grid_2d(n_ang, num_detectors)
        launch_stream = _get_numba_external_stream_for(torch.cuda.current_stream())

        _fan_2d_forward_kernel[grid, tpb, launch_stream](
            TorchCUDABridge.tensor_to_cuda_array(image), Nx, Ny,
            TorchCUDABridge.tensor_to_cuda_array(sinogram), n_ang, num_detectors,
            _DTYPE(detector_spacing),
            TorchCUDABridge.tensor_to_cuda_array(cos_table),
            TorchCUDABridge.tensor_to_cuda_array(sin_table),
            _DTYPE(sdd), _DTYPE(sid),
            _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5),
            _DTYPE(voxel_spacing),
            det_off, cen_x, cen_y,
        )

        ctx.save_for_backward(angles)
        ctx.intermediate = (
            num_detectors,
            detector_spacing,
            Ny,
            Nx,
            sdd,
            sid,
            voxel_spacing,
            detector_offset,
            center_offset_x,
            center_offset_y,
        )
        return sinogram

    @staticmethod
    def backward(ctx, grad_sinogram):
        """Backproject the sinogram gradient onto the image grid.

        Returns the gradient with respect to ``image``, followed by ``None``
        for each of the nine remaining inputs.
        """
        (angles,) = ctx.saved_tensors
        (
            n_det,
            det_spacing,
            Ny,
            Nx,
            sdd,
            sid,
            voxel_spacing,
            detector_offset,
            center_offset_x,
            center_offset_y,
        ) = ctx.intermediate

        device = DeviceManager.get_device(grad_sinogram)
        grad_sinogram = DeviceManager.ensure_device(grad_sinogram, device).to(dtype=torch.float32).contiguous()
        angles = DeviceManager.ensure_device(angles, device).to(dtype=torch.float32).contiguous()
        n_ang = angles.shape[0]

        grad_img = torch.zeros((Ny, Nx), dtype=grad_sinogram.dtype, device=device)
        cos_table, sin_table = _trig_tables(angles, dtype=grad_sinogram.dtype, device=device)

        det_off, cen_x, cen_y = (
            _DTYPE(length / voxel_spacing)
            for length in (detector_offset, center_offset_x, center_offset_y)
        )
        grid, tpb = _grid_2d(n_ang, n_det)
        launch_stream = _get_numba_external_stream_for(torch.cuda.current_stream())

        # NOTE(review): the trailing 1.0 turns on distance weighting in the
        # backprojection kernel. Confirm that _fan_2d_forward_kernel applies
        # the same weighting; if it does not, this backward is not the exact
        # adjoint of forward and the gradients are reweighted.
        _fan_2d_backward_kernel[grid, tpb, launch_stream](
            TorchCUDABridge.tensor_to_cuda_array(grad_sinogram), n_ang, n_det,
            TorchCUDABridge.tensor_to_cuda_array(grad_img), Nx, Ny,
            _DTYPE(det_spacing),
            TorchCUDABridge.tensor_to_cuda_array(cos_table),
            TorchCUDABridge.tensor_to_cuda_array(sin_table),
            _DTYPE(sdd), _DTYPE(sid),
            _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5),
            _DTYPE(voxel_spacing),
            det_off, cen_x, cen_y,
            _DTYPE(1.0),
        )
        return (grad_img,) + (None,) * 9
class FanBackprojectorFunction(torch.autograd.Function):
    """
    Summary
    -------
    PyTorch autograd function for differentiable 2D fan beam backprojection.

    Notes
    -----
    Provides a differentiable interface to the CUDA-accelerated Siddon
    ray-tracing method (with interpolation) for fan beam backprojection.

    - The forward pass distributes sinogram values back into the image along
      divergent ray paths with ``_fan_2d_backward_kernel``, called with a
      trailing weight argument of 1.0 (geometric ``1/U^2`` distance
      weighting, per the original documentation).
    - The backward pass forward-projects the image gradient with
      ``_fan_2d_forward_kernel``.

    The sinogram must have exactly one row per entry of ``angles``;
    otherwise ``ValueError`` is raised.

    Examples
    --------
    >>> import torch
    >>> from diffct.differentiable import FanBackprojectorFunction
    >>>
    >>> sinogram = torch.randn(360, 512, device='cuda', requires_grad=True)
    >>> angles = torch.linspace(0, 2 * torch.pi, 360, device='cuda')
    >>> recon = FanBackprojectorFunction.apply(sinogram, angles, 1.0, 256, 256, 1500.0, 1000.0)
    >>> loss = recon.sum()
    >>> loss.backward()
    >>> print(sinogram.grad.shape)  # (360, 512)
    """

    @staticmethod
    def forward(
        ctx,
        sinogram,
        angles,
        detector_spacing,
        H,
        W,
        sdd,
        sid,
        voxel_spacing=1.0,
        detector_offset=0.0,
        center_offset_x=0.0,
        center_offset_y=0.0,
    ):
        """Compute the 2D fan beam backprojection of a sinogram on the GPU.

        Parameters
        ----------
        ctx : torch.autograd.function.FunctionCtx
            Autograd context. Keeps ``angles`` and the scalar geometry for
            ``backward``.
        sinogram : torch.Tensor
            2D fan beam sinogram of shape (num_angles, num_detectors) on a
            CUDA device. It is converted to contiguous float32.
        angles : torch.Tensor
            1D tensor of projection angles in radians, shape (num_angles,),
            on the same CUDA device as ``sinogram``.
        detector_spacing : float
            Physical spacing between detector elements.
        H : int
            Height of the output image.
        W : int
            Width of the output image.
        sdd : float
            Source-to-Detector Distance: source to detector, through the
            isocenter.
        sid : float
            Source-to-Isocenter Distance: source to the center of rotation.
        voxel_spacing : float, optional
            Physical size of one pixel, in the units of ``detector_spacing``,
            ``sdd`` and ``sid`` (default: 1.0).
        detector_offset : float, optional
            Detector offset in physical units (default: 0.0). It is divided
            by ``voxel_spacing`` before reaching the kernel.
        center_offset_x, center_offset_y : float, optional
            Center offsets along x and y in physical units (default: 0.0).
            They are also divided by ``voxel_spacing``.

        Returns
        -------
        torch.Tensor
            float32 image of shape (H, W) on the device of ``sinogram``.

        Raises
        ------
        ValueError
            If ``sinogram`` is not 2D, ``angles`` is not 1D, or the number
            of angles differs from the number of sinogram rows.

        Examples
        --------
        >>> sinogram = torch.randn(360, 512, device='cuda', requires_grad=True)
        >>> angles = torch.linspace(0, 2*torch.pi, 360, device='cuda')
        >>> reco = FanBackprojectorFunction.apply(
        ...     sinogram, angles, 1.0, 256, 256, 1000.0, 500.0
        ... )
        """
        # The launch grid and the kernel's angle count come from the sinogram,
        # while the cos/sin tables are built from ``angles``. Reject a mismatch
        # up front. Otherwise the kernel would index trig entries that do not
        # exist, or silently ignore surplus angles.
        if sinogram.dim() != 2:
            raise ValueError(
                "sinogram must be 2D (num_angles, num_detectors), "
                f"got shape {tuple(sinogram.shape)}"
            )
        if angles.dim() != 1 or angles.shape[0] != sinogram.shape[0]:
            raise ValueError(
                "angles must be 1D with one entry per sinogram row; got angles "
                f"shape {tuple(angles.shape)} for sinogram shape {tuple(sinogram.shape)}"
            )

        device = DeviceManager.get_device(sinogram)
        sinogram = DeviceManager.ensure_device(sinogram, device)
        angles = DeviceManager.ensure_device(angles, device)
        sinogram = sinogram.to(dtype=torch.float32).contiguous()
        angles = angles.to(dtype=torch.float32).contiguous()

        n_ang, n_det = sinogram.shape
        Ny, Nx = H, W

        reco = torch.zeros((Ny, Nx), dtype=sinogram.dtype, device=device)
        d_cos, d_sin = _trig_tables(angles, dtype=sinogram.dtype, device=device)

        d_sino = TorchCUDABridge.tensor_to_cuda_array(sinogram)
        d_reco = TorchCUDABridge.tensor_to_cuda_array(reco)
        d_cos_arr = TorchCUDABridge.tensor_to_cuda_array(d_cos)
        d_sin_arr = TorchCUDABridge.tensor_to_cuda_array(d_sin)

        grid, tpb = _grid_2d(n_ang, n_det)
        cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
        pt_stream = torch.cuda.current_stream()
        numba_stream = _get_numba_external_stream_for(pt_stream)

        # The kernel expects offsets in voxel units.
        det_offset_v = _DTYPE(detector_offset / voxel_spacing)
        center_offset_x_v = _DTYPE(center_offset_x / voxel_spacing)
        center_offset_y_v = _DTYPE(center_offset_y / voxel_spacing)

        # Trailing 1.0: distance-weight switch of the backprojection kernel.
        _fan_2d_backward_kernel[grid, tpb, numba_stream](
            d_sino, n_ang, n_det, d_reco, Nx, Ny, _DTYPE(detector_spacing),
            d_cos_arr, d_sin_arr, _DTYPE(sdd), _DTYPE(sid), cx, cy,
            _DTYPE(voxel_spacing), det_offset_v, center_offset_x_v, center_offset_y_v,
            _DTYPE(1.0),
        )

        ctx.save_for_backward(angles)
        ctx.intermediate = (
            H,
            W,
            detector_spacing,
            n_ang,
            n_det,
            sdd,
            sid,
            voxel_spacing,
            detector_offset,
            center_offset_x,
            center_offset_y,
        )
        return reco

    @staticmethod
    def backward(ctx, grad_output):
        """Forward-project the image gradient to obtain the sinogram gradient.

        Returns the gradient with respect to ``sinogram``, followed by
        ``None`` for each of the ten remaining inputs.
        """
        (angles,) = ctx.saved_tensors
        (
            H,
            W,
            det_spacing,
            n_ang,
            n_det,
            sdd,
            sid,
            voxel_spacing,
            detector_offset,
            center_offset_x,
            center_offset_y,
        ) = ctx.intermediate

        device = DeviceManager.get_device(grad_output)
        grad_output = DeviceManager.ensure_device(grad_output, device)
        angles = DeviceManager.ensure_device(angles, device)
        grad_output = grad_output.to(dtype=torch.float32).contiguous()
        angles = angles.to(dtype=torch.float32).contiguous()
        Ny, Nx = grad_output.shape

        grad_sino = torch.zeros((n_ang, n_det), dtype=grad_output.dtype, device=device)
        d_cos, d_sin = _trig_tables(angles, dtype=grad_output.dtype, device=device)

        d_grad_out = TorchCUDABridge.tensor_to_cuda_array(grad_output)
        d_sino_grad = TorchCUDABridge.tensor_to_cuda_array(grad_sino)
        d_cos_arr = TorchCUDABridge.tensor_to_cuda_array(d_cos)
        d_sin_arr = TorchCUDABridge.tensor_to_cuda_array(d_sin)

        grid, tpb = _grid_2d(n_ang, n_det)
        cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
        pt_stream = torch.cuda.current_stream()
        numba_stream = _get_numba_external_stream_for(pt_stream)

        det_offset_v = _DTYPE(detector_offset / voxel_spacing)
        center_offset_x_v = _DTYPE(center_offset_x / voxel_spacing)
        center_offset_y_v = _DTYPE(center_offset_y / voxel_spacing)

        # NOTE(review): forward() backprojects with the distance-weight switch
        # set to 1.0, but this call projects with no matching weight argument.
        # Unless _fan_2d_forward_kernel applies the same weighting internally,
        # this is not the exact adjoint of forward(). Confirm against the kernel.
        _fan_2d_forward_kernel[grid, tpb, numba_stream](
            d_grad_out, Nx, Ny, d_sino_grad, n_ang, n_det, _DTYPE(det_spacing),
            d_cos_arr, d_sin_arr, _DTYPE(sdd), _DTYPE(sid), cx, cy,
            _DTYPE(voxel_spacing), det_offset_v, center_offset_x_v, center_offset_y_v,
        )
        return grad_sino, None, None, None, None, None, None, None, None, None, None
class ConeProjectorFunction(torch.autograd.Function):
    """
    Summary
    -------
    Differentiable 3D cone beam forward projector (PyTorch autograd function).

    Notes
    -----
    Wraps the CUDA Siddon ray-tracing kernels (with trilinear interpolation)
    for 3D cone beam geometry. Rays run from a point X-ray source to a 2D
    detector array.

    - ``forward`` produces the projections with ``_cone_3d_forward_kernel``.
    - ``backward`` backprojects the projection gradient with
      ``_cone_3d_backward_kernel``, called with a trailing weight argument of
      1.0. Per the original documentation this enables geometric ``1/U^2``
      distance weighting.

    This can need substantial GPU memory.

    Examples
    --------
    >>> import torch
    >>> from diffct.differentiable import ConeProjectorFunction
    >>>
    >>> volume = torch.randn(128, 128, 128, device='cuda', requires_grad=True)
    >>> angles = torch.linspace(0, 2 * torch.pi, 360, device='cuda')
    >>> projections = ConeProjectorFunction.apply(volume, angles, 256, 256, 1.0, 1.0, 1500.0, 1000.0)
    >>> loss = projections.sum()
    >>> loss.backward()
    >>> print(volume.grad.shape)  # (128, 128, 128)
    """

    @staticmethod
    def forward(
        ctx,
        volume,
        angles,
        det_u,
        det_v,
        du,
        dv,
        sdd,
        sid,
        voxel_spacing=1.0,
        detector_offset_u=0.0,
        detector_offset_v=0.0,
        center_offset_x=0.0,
        center_offset_y=0.0,
        center_offset_z=0.0,
    ):
        """Project a 3D volume into cone beam projections on the GPU.

        Parameters
        ----------
        ctx : torch.autograd.function.FunctionCtx
            Autograd context. Keeps ``angles`` and the scalar geometry for
            ``backward``.
        volume : torch.Tensor
            3D volume of shape (D, H, W) on a CUDA device. It is converted to
            contiguous float32.
        angles : torch.Tensor
            1D tensor of projection angles in radians, shape (num_views,), on
            the same CUDA device as ``volume``.
        det_u, det_v : int
            Number of detector elements along the u-axis (width) and the
            v-axis (height).
        du, dv : float
            Physical detector element spacing along u and v.
        sdd : float
            Source-to-Detector Distance: source to detector, through the
            isocenter.
        sid : float
            Source-to-Isocenter Distance: source to the center of rotation.
        voxel_spacing : float, optional
            Physical size of one voxel, in the units of ``du``, ``dv``,
            ``sdd`` and ``sid`` (default: 1.0).
        detector_offset_u, detector_offset_v : float, optional
            Detector offsets along u and v in physical units (default: 0.0).
            They are divided by ``voxel_spacing`` before reaching the kernel.
        center_offset_x, center_offset_y, center_offset_z : float, optional
            Center offsets along x, y and z in physical units (default: 0.0).
            They are also divided by ``voxel_spacing``.

        Returns
        -------
        torch.Tensor
            float32 projections of shape (num_views, det_u, det_v) on the
            device of ``volume``.

        Examples
        --------
        >>> volume = torch.randn(128, 128, 128, device='cuda', requires_grad=True)
        >>> angles = torch.linspace(0, 2*torch.pi, 360, device='cuda')
        >>> sino = ConeProjectorFunction.apply(
        ...     volume, angles, 256, 256, 1.0, 1.0, 1500.0, 1000.0
        ... )
        """
        device = DeviceManager.get_device(volume)
        volume = DeviceManager.ensure_device(volume, device).to(dtype=torch.float32).contiguous()
        angles = DeviceManager.ensure_device(angles, device).to(dtype=torch.float32).contiguous()

        D, H, W = volume.shape
        n_views = angles.shape[0]

        # Validate memory layout to prevent coordinate system inconsistencies
        _validate_3d_memory_layout(volume, expected_order='DHW')

        sino = torch.zeros((n_views, det_u, det_v), dtype=volume.dtype, device=device)
        cos_table, sin_table = _trig_tables(angles, dtype=volume.dtype, device=device)
        # Reorder (D, H, W) -> (W, H, D) to match the (x, y, z) extents the
        # kernel receives as W, H, D.
        volume_xyz = volume.permute(2, 1, 0).contiguous()

        # The kernel expects offsets expressed in voxel units.
        det_off_u, det_off_v, cen_x, cen_y, cen_z = (
            _DTYPE(length / voxel_spacing)
            for length in (
                detector_offset_u,
                detector_offset_v,
                center_offset_x,
                center_offset_y,
                center_offset_z,
            )
        )
        grid, tpb = _grid_3d(n_views, det_u, det_v)
        launch_stream = _get_numba_external_stream_for(torch.cuda.current_stream())

        _cone_3d_forward_kernel[grid, tpb, launch_stream](
            TorchCUDABridge.tensor_to_cuda_array(volume_xyz), W, H, D,
            TorchCUDABridge.tensor_to_cuda_array(sino), n_views, det_u, det_v,
            _DTYPE(du), _DTYPE(dv),
            TorchCUDABridge.tensor_to_cuda_array(cos_table),
            TorchCUDABridge.tensor_to_cuda_array(sin_table),
            _DTYPE(sdd), _DTYPE(sid),
            _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5),
            _DTYPE(voxel_spacing),
            det_off_u, det_off_v, cen_x, cen_y, cen_z,
        )

        ctx.save_for_backward(angles)
        ctx.intermediate = (
            D,
            H,
            W,
            det_u,
            det_v,
            du,
            dv,
            sdd,
            sid,
            voxel_spacing,
            detector_offset_u,
            detector_offset_v,
            center_offset_x,
            center_offset_y,
            center_offset_z,
        )
        return sino

    @staticmethod
    def backward(ctx, grad_sinogram):
        """Backproject the projection gradient into the volume.

        Returns the gradient with respect to ``volume`` in (D, H, W) order,
        followed by ``None`` for each of the thirteen remaining inputs.
        """
        (angles,) = ctx.saved_tensors
        (
            D,
            H,
            W,
            det_u,
            det_v,
            du,
            dv,
            sdd,
            sid,
            voxel_spacing,
            detector_offset_u,
            detector_offset_v,
            center_offset_x,
            center_offset_y,
            center_offset_z,
        ) = ctx.intermediate

        device = DeviceManager.get_device(grad_sinogram)
        grad_sinogram = DeviceManager.ensure_device(grad_sinogram, device).to(dtype=torch.float32).contiguous()
        angles = DeviceManager.ensure_device(angles, device).to(dtype=torch.float32).contiguous()
        n_views = angles.shape[0]

        # Accumulate in the kernel's (W, H, D) layout; it is permuted back below.
        grad_xyz = torch.zeros((W, H, D), dtype=grad_sinogram.dtype, device=device)
        cos_table, sin_table = _trig_tables(angles, dtype=grad_sinogram.dtype, device=device)

        det_off_u, det_off_v, cen_x, cen_y, cen_z = (
            _DTYPE(length / voxel_spacing)
            for length in (
                detector_offset_u,
                detector_offset_v,
                center_offset_x,
                center_offset_y,
                center_offset_z,
            )
        )
        grid, tpb = _grid_3d(n_views, det_u, det_v)
        launch_stream = _get_numba_external_stream_for(torch.cuda.current_stream())

        # Trailing 1.0: distance-weight switch of the backprojection kernel.
        _cone_3d_backward_kernel[grid, tpb, launch_stream](
            TorchCUDABridge.tensor_to_cuda_array(grad_sinogram), n_views, det_u, det_v,
            TorchCUDABridge.tensor_to_cuda_array(grad_xyz), W, H, D,
            _DTYPE(du), _DTYPE(dv),
            TorchCUDABridge.tensor_to_cuda_array(cos_table),
            TorchCUDABridge.tensor_to_cuda_array(sin_table),
            _DTYPE(sdd), _DTYPE(sid),
            _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5),
            _DTYPE(voxel_spacing),
            det_off_u, det_off_v, cen_x, cen_y, cen_z,
            _DTYPE(1.0),
        )
        grad_vol = grad_xyz.permute(2, 1, 0).contiguous()
        return (grad_vol,) + (None,) * 13
class ConeBackprojectorFunction(torch.autograd.Function):
    """
    Summary
    -------
    PyTorch autograd function for differentiable 3D cone beam backprojection.

    Notes
    -----
    Provides a differentiable interface to the CUDA-accelerated Siddon
    ray-tracing method with interpolation for 3D cone beam backprojection.
    The forward pass computes a 3D reconstruction from cone beam projection
    data using weighted backprojection (geometric ``1/U^2`` term) as the
    adjoint operation. The backward pass computes gradients via 3D cone beam
    forward projection. Requires CUDA-capable hardware and consistent device
    placements.

    This operation may be memory- and computationally-intensive due to 3D
    geometry. Consider using gradient checkpointing, smaller volumes, or
    distributed computing for large-scale applications, and ensure
    sufficient GPU memory is available.

    Examples
    --------
    >>> import torch
    >>> from diffct.differentiable import ConeBackprojectorFunction
    >>>
    >>> projections = torch.randn(360, 256, 256, device='cuda', requires_grad=True)
    >>> angles = torch.linspace(0, 2 * torch.pi, 360, device='cuda')
    >>> D, H, W = 128, 128, 128
    >>> du, dv = 1.0, 1.0
    >>> sdd, sid = 1500.0, 1000.0
    >>> backprojector = ConeBackprojectorFunction.apply
    >>> volume = backprojector(projections, angles, D, H, W, du, dv, sdd, sid)
    >>> loss = volume.sum()
    >>> loss.backward()
    >>> print(f"Projection gradient shape: {projections.grad.shape}")  # (360, 256, 256)
    """

    @staticmethod
    def _offsets_in_voxels(
        voxel_spacing,
        detector_offset_u,
        detector_offset_v,
        center_offset_x,
        center_offset_y,
        center_offset_z,
    ):
        """Convert the five physical offsets to float32 voxel units.

        Returns ``(det_u, det_v, center_x, center_y, center_z)``, each the
        corresponding offset divided by ``voxel_spacing``.
        """
        return tuple(
            _DTYPE(offset / voxel_spacing)
            for offset in (
                detector_offset_u,
                detector_offset_v,
                center_offset_x,
                center_offset_y,
                center_offset_z,
            )
        )

    @staticmethod
    def forward(
        ctx,
        sinogram,
        angles,
        D,
        H,
        W,
        du,
        dv,
        sdd,
        sid,
        voxel_spacing=1.0,
        detector_offset_u=0.0,
        detector_offset_v=0.0,
        center_offset_x=0.0,
        center_offset_y=0.0,
        center_offset_z=0.0,
    ):
        """Compute the 3D cone beam backprojection of a projection sinogram using CUDA acceleration.

        Parameters
        ----------
        sinogram : torch.Tensor
            3D input cone beam projection tensor of shape
            (num_views, det_u, det_v), on a CUDA device. Cast to float32.
        angles : torch.Tensor
            1D tensor of projection angles in radians, shape (num_views,).
            Moved to ``sinogram``'s device and cast to float32.
        D : int
            Depth (z-dimension) of the output reconstruction volume.
        H : int
            Height (y-dimension) of the output reconstruction volume.
        W : int
            Width (x-dimension) of the output reconstruction volume.
        du : float
            Physical spacing between detector elements along the u-axis.
        dv : float
            Physical spacing between detector elements along the v-axis.
        sdd : float
            Source-to-Detector Distance (SDD). The total distance from the
            X-ray source to the detector, passing through the isocenter.
        sid : float
            Source-to-Isocenter Distance (SID). The distance from the X-ray
            source to the center of rotation (isocenter).
        voxel_spacing : float, optional
            Physical size of one voxel (in same units as du, dv, sdd, sid,
            default: 1.0).
        detector_offset_u, detector_offset_v : float, optional
            Detector offsets in physical units (default: 0.0). Divided by
            ``voxel_spacing`` before being passed to the kernel.
        center_offset_x, center_offset_y, center_offset_z : float, optional
            Center offsets in physical units (default: 0.0). Divided by
            ``voxel_spacing`` before being passed to the kernel.

        Returns
        -------
        vol : torch.Tensor
            Float32 tensor of shape (D, H, W) containing the reconstructed
            volume on the same device as `sinogram`.

        Raises
        ------
        ValueError
            If ``sinogram`` is not 3D, if ``angles`` is not 1D with one entry
            per sinogram view, or if ``sinogram`` is not on a CUDA device.

        Notes
        -----
        - The operation is fully differentiable and supports autograd.
        - Cone beam geometry uses a point source and a 2D detector array.
        - Uses the Siddon method with interpolation for 3D ray tracing.

        Examples
        --------
        >>> projections = torch.randn(360, 256, 256, device='cuda', requires_grad=True)
        >>> angles = torch.linspace(0, 2*torch.pi, 360, device='cuda')
        >>> vol = ConeBackprojectorFunction.apply(
        ...     projections, angles, 128, 128, 128, 1.0, 1.0, 1500.0, 1000.0
        ... )
        """
        # Validate shapes before any device work. The kernel grid below spans
        # the sinogram's view count while the cos/sin tables are built from
        # ``angles``. ``backward`` also sizes its gradient from len(angles).
        # A mismatch would read past the trig tables or yield a gradient whose
        # shape differs from ``sinogram``.
        if sinogram.ndim != 3:
            raise ValueError(
                "sinogram must be 3D (num_views, det_u, det_v), "
                f"got shape {tuple(sinogram.shape)}"
            )
        if angles.ndim != 1 or angles.shape[0] != sinogram.shape[0]:
            raise ValueError(
                f"angles must be a 1D tensor of length {sinogram.shape[0]} "
                f"(one per sinogram view), got shape {tuple(angles.shape)}"
            )
        if not sinogram.is_cuda:
            raise ValueError("sinogram must be on CUDA device")

        device = DeviceManager.get_device(sinogram)
        sinogram = DeviceManager.ensure_device(sinogram, device)
        angles = DeviceManager.ensure_device(angles, device)
        sinogram = sinogram.to(dtype=torch.float32).contiguous()
        angles = angles.to(dtype=torch.float32).contiguous()

        n_views, n_u, n_v = sinogram.shape

        # Validate memory layout to prevent coordinate system inconsistencies
        _validate_3d_memory_layout(sinogram, expected_order='VHW')

        # The kernel accumulates into a (W, H, D) buffer; it is permuted back
        # to (D, H, W) before returning.
        vol_perm = torch.zeros((W, H, D), dtype=sinogram.dtype, device=device)
        d_cos, d_sin = _trig_tables(angles, dtype=sinogram.dtype, device=device)

        d_sino = TorchCUDABridge.tensor_to_cuda_array(sinogram)
        d_reco = TorchCUDABridge.tensor_to_cuda_array(vol_perm)
        d_cos_arr = TorchCUDABridge.tensor_to_cuda_array(d_cos)
        d_sin_arr = TorchCUDABridge.tensor_to_cuda_array(d_sin)

        grid, tpb = _grid_3d(n_views, n_u, n_v)
        cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)

        pt_stream = torch.cuda.current_stream()
        numba_stream = _get_numba_external_stream_for(pt_stream)

        (
            det_offset_u_v,
            det_offset_v_v,
            center_offset_x_v,
            center_offset_y_v,
            center_offset_z_v,
        ) = ConeBackprojectorFunction._offsets_in_voxels(
            voxel_spacing,
            detector_offset_u,
            detector_offset_v,
            center_offset_x,
            center_offset_y,
            center_offset_z,
        )

        _cone_3d_backward_kernel[grid, tpb, numba_stream](
            d_sino, n_views, n_u, n_v, d_reco, W, H, D,
            _DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,
            _DTYPE(sdd), _DTYPE(sid), cx, cy, cz, _DTYPE(voxel_spacing),
            det_offset_u_v, det_offset_v_v,
            center_offset_x_v, center_offset_y_v, center_offset_z_v,
            _DTYPE(1.0),
        )

        ctx.save_for_backward(angles)
        ctx.intermediate = (
            D, H, W, n_u, n_v, du, dv, sdd, sid, voxel_spacing,
            detector_offset_u, detector_offset_v,
            center_offset_x, center_offset_y, center_offset_z,
        )

        vol = vol_perm.permute(2, 1, 0).contiguous()
        return vol

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate the volume gradient to the sinogram via forward projection.

        Parameters
        ----------
        ctx : torch.autograd.function.FunctionCtx
            Context holding the saved ``angles`` and the geometry tuple
            stored by ``forward``.
        grad_output : torch.Tensor
            Gradient with respect to the (D, H, W) output volume.

        Returns
        -------
        tuple
            ``grad_sino`` of shape (num_views, n_u, n_v) followed by 14
            ``None`` entries for the non-differentiable forward inputs.
        """
        angles, = ctx.saved_tensors
        (
            D, H, W, n_u, n_v, du, dv, sdd, sid, voxel_spacing,
            detector_offset_u, detector_offset_v,
            center_offset_x, center_offset_y, center_offset_z,
        ) = ctx.intermediate

        device = DeviceManager.get_device(grad_output)
        grad_output = DeviceManager.ensure_device(grad_output, device)
        angles = DeviceManager.ensure_device(angles, device)
        grad_output = grad_output.to(dtype=torch.float32).contiguous()
        angles = angles.to(dtype=torch.float32).contiguous()

        # forward() guarantees len(angles) equals the sinogram's view count,
        # so this gradient matches the sinogram's shape.
        n_views = angles.shape[0]
        grad_sino = torch.zeros((n_views, n_u, n_v), dtype=grad_output.dtype, device=device)
        d_cos, d_sin = _trig_tables(angles, dtype=grad_output.dtype, device=device)

        # Kernel expects the volume in (W, H, D) order.
        grad_output_perm = grad_output.permute(2, 1, 0).contiguous()
        d_grad_out = TorchCUDABridge.tensor_to_cuda_array(grad_output_perm)
        d_sino_grad = TorchCUDABridge.tensor_to_cuda_array(grad_sino)
        d_cos_arr = TorchCUDABridge.tensor_to_cuda_array(d_cos)
        d_sin_arr = TorchCUDABridge.tensor_to_cuda_array(d_sin)

        grid, tpb = _grid_3d(n_views, n_u, n_v)
        cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)

        pt_stream = torch.cuda.current_stream()
        numba_stream = _get_numba_external_stream_for(pt_stream)

        (
            det_offset_u_v,
            det_offset_v_v,
            center_offset_x_v,
            center_offset_y_v,
            center_offset_z_v,
        ) = ConeBackprojectorFunction._offsets_in_voxels(
            voxel_spacing,
            detector_offset_u,
            detector_offset_v,
            center_offset_x,
            center_offset_y,
            center_offset_z,
        )

        # NOTE(review): forward() launches the backprojection kernel with a
        # trailing 1.0 (described above as the 1/U^2 weighting) while this
        # uses the forward projection kernel; confirm the pair are exact
        # adjoints, otherwise these gradients are approximate.
        _cone_3d_forward_kernel[grid, tpb, numba_stream](
            d_grad_out, W, H, D, d_sino_grad, n_views, n_u, n_v,
            _DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,
            _DTYPE(sdd), _DTYPE(sid), cx, cy, cz, _DTYPE(voxel_spacing),
            det_offset_u_v, det_offset_v_v,
            center_offset_x_v, center_offset_y_v, center_offset_z_v,
        )

        return (grad_sino,) + (None,) * 14
def _check_backproject_inputs(sinogram, angles, expected_ndim, layout):
    """Raise ValueError unless sinogram rank, angle count, and device are valid.

    ``layout`` is a human-readable description of the expected sinogram axes,
    used only in the error message.
    """
    if sinogram.ndim != expected_ndim:
        raise ValueError(
            f"sinogram must be {expected_ndim}D {layout}, "
            f"got shape {tuple(sinogram.shape)}"
        )
    n_views = sinogram.shape[0]
    # The kernel grid spans the sinogram's view count while the cos/sin tables
    # are built from ``angles``; a shorter ``angles`` would read past them.
    if angles.ndim != 1 or angles.shape[0] != n_views:
        raise ValueError(
            f"angles must be a 1D tensor of length {n_views} "
            f"(one per sinogram view), got shape {tuple(angles.shape)}"
        )
    if not sinogram.is_cuda:
        raise ValueError("sinogram must be on CUDA device")


def fan_weighted_backproject(
    sinogram,
    angles,
    detector_spacing,
    H,
    W,
    sdd,
    sid,
    voxel_spacing=1.0,
    detector_offset=0.0,
    center_offset_x=0.0,
    center_offset_y=0.0,
):
    """Fan-beam weighted backprojection for analytical FBP pipelines.

    This uses the same Siddon traversal as `FanBackprojectorFunction` but
    enables geometric distance weighting in the accumulation step.

    Parameters
    ----------
    sinogram : torch.Tensor
        2D sinogram of shape (num_angles, num_detectors) on a CUDA device.
        Cast to float32.
    angles : torch.Tensor
        1D tensor of num_angles projection angles in radians. Moved to
        ``sinogram``'s device and cast to float32.
    detector_spacing : float
        Physical spacing between detector elements.
    H, W : int
        Height and width of the output image.
    sdd : float
        Source-to-detector distance.
    sid : float
        Source-to-isocenter distance.
    voxel_spacing : float, optional
        Physical size of one pixel (default: 1.0).
    detector_offset : float, optional
        Detector offset in physical units; divided by ``voxel_spacing``
        before being passed to the kernel (default: 0.0).
    center_offset_x, center_offset_y : float, optional
        Center offsets in physical units; divided by ``voxel_spacing``
        before being passed to the kernel (default: 0.0).

    Returns
    -------
    torch.Tensor
        Float32 image of shape (H, W) on ``sinogram``'s device.

    Raises
    ------
    ValueError
        If ``sinogram`` is not 2D, ``angles`` is not 1D with one entry per
        sinogram row, or ``sinogram`` is not on a CUDA device.
    """
    _check_backproject_inputs(sinogram, angles, 2, "(num_angles, num_detectors)")

    device = sinogram.device
    sinogram = sinogram.to(dtype=torch.float32).contiguous()
    angles = angles.to(dtype=torch.float32, device=device).contiguous()

    n_ang, n_det = sinogram.shape
    Ny, Nx = H, W

    reco = torch.zeros((Ny, Nx), dtype=sinogram.dtype, device=device)
    d_cos, d_sin = _trig_tables(angles, dtype=sinogram.dtype, device=device)

    d_sino = TorchCUDABridge.tensor_to_cuda_array(sinogram)
    d_reco = TorchCUDABridge.tensor_to_cuda_array(reco)
    d_cos_arr = TorchCUDABridge.tensor_to_cuda_array(d_cos)
    d_sin_arr = TorchCUDABridge.tensor_to_cuda_array(d_sin)

    grid, tpb = _grid_2d(n_ang, n_det)
    cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)

    # Physical offsets converted to voxel units.
    det_offset_v = _DTYPE(detector_offset / voxel_spacing)
    center_offset_x_v = _DTYPE(center_offset_x / voxel_spacing)
    center_offset_y_v = _DTYPE(center_offset_y / voxel_spacing)

    pt_stream = torch.cuda.current_stream()
    numba_stream = _get_numba_external_stream_for(pt_stream)

    # The trailing 1.0 is the weighting this function exists to enable.
    _fan_2d_backward_kernel[grid, tpb, numba_stream](
        d_sino, n_ang, n_det, d_reco, Nx, Ny,
        _DTYPE(detector_spacing), d_cos_arr, d_sin_arr,
        _DTYPE(sdd), _DTYPE(sid), cx, cy, _DTYPE(voxel_spacing),
        det_offset_v, center_offset_x_v, center_offset_y_v,
        _DTYPE(1.0),
    )
    return reco


def cone_weighted_backproject(
    sinogram,
    angles,
    D,
    H,
    W,
    du,
    dv,
    sdd,
    sid,
    voxel_spacing=1.0,
    detector_offset_u=0.0,
    detector_offset_v=0.0,
    center_offset_x=0.0,
    center_offset_y=0.0,
    center_offset_z=0.0,
):
    """Cone-beam weighted backprojection for analytical FDK pipelines.

    Parameters
    ----------
    sinogram : torch.Tensor
        3D projections of shape (num_views, det_u, det_v) on a CUDA device.
        Cast to float32.
    angles : torch.Tensor
        1D tensor of num_views projection angles in radians. Moved to
        ``sinogram``'s device and cast to float32.
    D, H, W : int
        Depth, height and width of the output volume.
    du, dv : float
        Physical detector element spacing along u and v.
    sdd : float
        Source-to-detector distance.
    sid : float
        Source-to-isocenter distance.
    voxel_spacing : float, optional
        Physical size of one voxel (default: 1.0).
    detector_offset_u, detector_offset_v : float, optional
        Detector offsets in physical units; divided by ``voxel_spacing``
        before being passed to the kernel (default: 0.0).
    center_offset_x, center_offset_y, center_offset_z : float, optional
        Center offsets in physical units; divided by ``voxel_spacing``
        before being passed to the kernel (default: 0.0).

    Returns
    -------
    torch.Tensor
        Float32 volume of shape (D, H, W) on ``sinogram``'s device.

    Raises
    ------
    ValueError
        If ``sinogram`` is not 3D, ``angles`` is not 1D with one entry per
        view, or ``sinogram`` is not on a CUDA device.
    """
    _check_backproject_inputs(sinogram, angles, 3, "(num_views, det_u, det_v)")

    device = sinogram.device
    sinogram = sinogram.to(dtype=torch.float32).contiguous()
    angles = angles.to(dtype=torch.float32, device=device).contiguous()

    n_views, n_u, n_v = sinogram.shape
    _validate_3d_memory_layout(sinogram, expected_order='VHW')

    # The kernel accumulates into a (W, H, D) buffer; permuted back on return.
    vol_perm = torch.zeros((W, H, D), dtype=sinogram.dtype, device=device)
    d_cos, d_sin = _trig_tables(angles, dtype=sinogram.dtype, device=device)

    d_sino = TorchCUDABridge.tensor_to_cuda_array(sinogram)
    d_reco = TorchCUDABridge.tensor_to_cuda_array(vol_perm)
    d_cos_arr = TorchCUDABridge.tensor_to_cuda_array(d_cos)
    d_sin_arr = TorchCUDABridge.tensor_to_cuda_array(d_sin)

    grid, tpb = _grid_3d(n_views, n_u, n_v)
    cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)

    # Physical offsets converted to voxel units.
    det_offset_u_v = _DTYPE(detector_offset_u / voxel_spacing)
    det_offset_v_v = _DTYPE(detector_offset_v / voxel_spacing)
    center_offset_x_v = _DTYPE(center_offset_x / voxel_spacing)
    center_offset_y_v = _DTYPE(center_offset_y / voxel_spacing)
    center_offset_z_v = _DTYPE(center_offset_z / voxel_spacing)

    pt_stream = torch.cuda.current_stream()
    numba_stream = _get_numba_external_stream_for(pt_stream)

    _cone_3d_backward_kernel[grid, tpb, numba_stream](
        d_sino, n_views, n_u, n_v, d_reco, W, H, D,
        _DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,
        _DTYPE(sdd), _DTYPE(sid), cx, cy, cz, _DTYPE(voxel_spacing),
        det_offset_u_v, det_offset_v_v,
        center_offset_x_v, center_offset_y_v, center_offset_z_v,
        _DTYPE(1.0),
    )
    return vol_perm.permute(2, 1, 0).contiguous()