Source code for ppftpy._ppft3

from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Final

from ._utils import _get_config, _get_pq_pz, _get_rpq_rpz, _verify_dtype_non_complex

if TYPE_CHECKING:
    from types import ModuleType

    import numpy as np
    from numpy.typing import NDArray

# Number of array dimensions of a batched input, i.e. data of shape (V, N, N, N);
# compared against ``data.ndim`` to decide whether the leading axis is a batch axis.
__MULTI_MODE_3D: Final = 4


def ppft3(
    data: NDArray, /, *, vectorized: bool = False, scipy_fft: bool = False
) -> NDArray[np.complex128 | np.complex256]:
    """Compute the 3D Pseudo-Polar Fourier Transform.

    This function computes the 3-dimensional Pseudo-Polar Fourier Transform
    [Averbuch2003]_.

    Parameters
    ----------
    data
        The input data, either a single 3D matrix (shape ``(N, N, N)``) or a
        batch of 3D matrices (shape ``(V, N, N, N)``). Each matrix must be
        cubical with even length. The data can be either real or complex.
    vectorized
        If ``True``, computes the transform using a fully vectorized approach.
        This is significantly faster but requires more memory. Use with
        caution for large datasets.
    scipy_fft
        If ``True``, uses SciPy's FFT backend (``scipy.fft``) or a registered
        backend via SciPy's backend control instead of the native array's FFT
        implementation (e.g., ``numpy.fft``). This may improve performance but
        requires SciPy to be installed.

    Returns
    -------
    The computed 3D Pseudo-Polar Fourier Transform. The output shape depends
    on the input:

    - If `data` has shape ``(N, N, N)``, the output has shape
      ``(3, 3N+1, N+1, N+1)``.
    - If `data` has shape ``(V, N, N, N)``, the output has shape
      ``(V, 3, 3N+1, N+1, N+1)``.

    Raises
    ------
    ValueError
        If the input shape is not ``(N, N, N)`` or ``(V, N, N, N)`` with even
        ``N`` and arbitrary ``V``.
    ModuleNotFoundError
        If `scipy_fft=True` but SciPy is not installed.

    See Also
    --------
    rppft3 : The 3D Pseudo-Polar Fourier Transform of real input.
    ppft2 : The 2D Pseudo-Polar Fourier Transform.

    References
    ----------
    .. [Averbuch2003] A. Averbuch and Y. Shkolnisky, "3D Fourier based
       discrete Radon transform," Applied and Computational Harmonic
       Analysis, vol. 15, no. 1, pp. 33-69, Jul. 2003, issn: 1063-5203.
       `doi:10.1016/s1063-5203(03)00030-7
       <https://doi.org/10.1016/s1063-5203(03)00030-7>`_.

    Examples
    --------
    >>> from ppftpy import ppft3
    >>> import numpy as np
    >>> arr = np.random.default_rng().random((4, 4, 4))
    >>> ppft3(arr).shape
    (3, 13, 5, 5)

    Compute PPFT3D for 2 arrays.

    >>> from ppftpy import ppft3
    >>> import numpy as np
    >>> arr = np.random.default_rng().random((2, 4, 4, 4))
    >>> ppft3(arr).shape
    (2, 3, 13, 5, 5)
    """
    # Reject malformed shapes before any backend work is started.
    __verify_data_shape(data)
    if vectorized:
        return __ppft3_vectorized(data, scipy_fft=scipy_fft, real_mode=False)
    return __ppft3_sequential(data, scipy_fft=scipy_fft, real_mode=False)
def rppft3(
    data: NDArray, /, *, vectorized: bool = False, scipy_fft: bool = False
) -> NDArray[np.complex128 | np.complex256]:
    """Compute the 3D Pseudo-Polar Fourier Transform for real input.

    This function computes the 3-dimensional Pseudo-Polar Fourier Transform
    [Averbuch2003]_ for real input. The real PPFT3D computes only the
    non-redundant half of the spectrum.

    Parameters
    ----------
    data
        The input data, either a single 3D matrix (shape ``(N, N, N)``) or a
        batch of 3D matrices (shape ``(V, N, N, N)``). Each matrix must be
        cubical with even length. The data must be real; complex input is
        rejected.
    vectorized
        If ``True``, computes the transform using a fully vectorized approach.
        This is significantly faster but requires more memory. Use with
        caution for large datasets.
    scipy_fft
        If ``True``, uses SciPy's FFT backend (``scipy.fft``) or a registered
        backend via SciPy's backend control instead of the native array's FFT
        implementation (e.g., ``numpy.fft``). This may improve performance but
        requires SciPy to be installed.

    Returns
    -------
    The computed 3D Pseudo-Polar Fourier Transform. Only the non-redundant
    half of the ``3N+1`` samples along the transformed axis is kept, i.e.
    ``(3N+1)//2 + 1 = 3N/2 + 1`` samples. The output shape depends on the
    input:

    - If `data` has shape ``(N, N, N)``, the output has shape
      ``(3, 3N/2+1, N+1, N+1)``.
    - If `data` has shape ``(V, N, N, N)``, the output has shape
      ``(V, 3, 3N/2+1, N+1, N+1)``.

    Raises
    ------
    ValueError
        If the input shape is not ``(N, N, N)`` (single 3D matrix) or
        ``(V, N, N, N)`` (multiple 3D matrices) with even ``N`` and arbitrary
        ``V``.
    ModuleNotFoundError
        If `scipy_fft=True` but SciPy is not installed.
    TypeError
        If the input data is complex.

    See Also
    --------
    ppft3 : The 3D Pseudo-Polar Fourier Transform.
    rppft2 : The 2D Pseudo-Polar Fourier Transform of real input.

    References
    ----------
    .. [Averbuch2003] A. Averbuch and Y. Shkolnisky, "3D Fourier based
       discrete Radon transform," Applied and Computational Harmonic
       Analysis, vol. 15, no. 1, pp. 33-69, Jul. 2003, issn: 1063-5203.
       `doi:10.1016/s1063-5203(03)00030-7
       <https://doi.org/10.1016/s1063-5203(03)00030-7>`_.

    Examples
    --------
    >>> from ppftpy import rppft3
    >>> import numpy as np
    >>> arr = np.random.default_rng().random((4, 4, 4))
    >>> rppft3(arr).shape
    (3, 7, 5, 5)

    Compute PPFT3D for 2 arrays.

    >>> from ppftpy import rppft3
    >>> import numpy as np
    >>> arr = np.random.default_rng().random((2, 4, 4, 4))
    >>> rppft3(arr).shape
    (2, 3, 7, 5, 5)
    """
    # Shape is validated first, then the dtype: the real-input path relies on
    # a real FFT, so complex data must be rejected up front.
    __verify_data_shape(data)
    _verify_dtype_non_complex(data)
    ppft_func = __ppft3_vectorized if vectorized else __ppft3_sequential
    return ppft_func(data, scipy_fft=scipy_fft, real_mode=True)
def __ppft3_vectorized(
    data: NDArray, *, scipy_fft: bool, real_mode: bool
) -> NDArray[np.complex128 | np.complex256]:
    """Compute the transform for every sector (and batch entry) in one pass.

    The three axis-permuted copies of the input are stacked along a leading
    axis and processed together, so all intermediate arrays carry a leading
    ``sectors = 3 * amount`` dimension.

    Parameters
    ----------
    data
        Already validated input of shape ``(N, N, N)`` or ``(V, N, N, N)``.
    scipy_fft
        Forwarded to the configuration and coefficient helpers.
    real_mode
        If ``True``, a real FFT is used along the first transformed axis and
        only ``(3N+1)//2 + 1`` samples are kept on it.

    Returns
    -------
    Array of shape ``(3, mx, N+1, N+1)`` for a single input, or
    ``(V, 3, mx, N+1, N+1)`` for a batch, where ``mx`` is ``3N+1`` or
    ``(3N+1)//2 + 1`` in real mode.
    """
    xp, xp_inner, fft, device, rechunk, compute = _get_config(data, scipy_fft=scipy_fft)

    n = data.shape[-1]
    side = n + 1  # output side length (N + 1)
    padded = side * 3  # FFT length used for the two inner passes
    m = 3 * n + 1  # length of the zero-padded first axis
    mx = m // 2 + 1 if real_mode else m
    md = (2 * n + 1) - n // 2  # first index of the window kept after each inner pass
    idx = (..., slice(md, md + side), slice(None))

    multi = data.ndim == __MULTI_MODE_3D
    amount = data.shape[0] if multi else 1
    sectors = amount * 3

    pq_pz_func = _get_rpq_rpz if real_mode else _get_pq_pz
    pq, pz = pq_pz_func(
        n, dim=3, xp=xp, xp_inner=xp_inner, scipy_fft=scipy_fft, device=device
    )
    # Add a leading and a trailing unit axis so the factors broadcast over
    # the sector axis and the last axis.
    pq = pq[None, :, :, None]
    pz = pz[None, :, :, None]

    # One copy of the data per sector, each with a different cube axis moved
    # to the front; a batch is flattened into the leading axis (sector-major).
    out = xp.stack([data, xp.moveaxis(data, -2, -3), xp.moveaxis(data, -1, -3)])
    out = xp.reshape(out, (-1, n, n, n))

    # Zero-pad axis 1 from n to 3n + 1 samples (n before, n + 1 after).
    zeros = xp_inner.zeros((sectors, n + 1, n, n), device=device)
    out = xp.concat([zeros[:, :-1], out, zeros], axis=1)
    out = fft.ifftshift(out, axes=1)
    if rechunk:
        out = out.rechunk({1: -1})
    out = (
        fft.rfft(out, axis=1)
        if real_mode
        else fft.fftshift(fft.fft(out, axis=1), axes=1)
    )
    if compute:
        out = out.compute()

    # First inner pass: swap the last two axes, pad axis 2 by one zero row,
    # then multiply / FFT / multiply / inverse FFT / window / multiply.
    zeros = xp_inner.zeros((sectors, mx, 1, side), device=device)
    out = xp.concat([xp.moveaxis(out, -2, -1), zeros[..., :-1]], axis=2) * pq
    if rechunk:
        out = out.rechunk({2: -1})
    out = fft.ifft(fft.fft(out, n=padded, axis=2) * pz, axis=2)[idx] * pq
    if compute:
        out = out.compute()

    # Second inner pass: same steps on the remaining axis. The result of this
    # pass is returned without calling ``compute``.
    out = xp.concat([xp.moveaxis(out, -2, -1), zeros], axis=2) * pq
    if rechunk:
        out = out.rechunk({2: -1})
    out = fft.ifft(fft.fft(out, n=padded, axis=2) * pz, axis=2)[idx] * pq

    if multi:
        # Undo the sector-major flattening and put the batch axis first.
        out = xp.moveaxis(xp.reshape(out, (3, amount, mx, side, side)), 0, 1)
    return xp.flip(xp.flip(out, axis=-1), axis=-2)


def __ppft3_sequential(
    data: NDArray, *, scipy_fft: bool, real_mode: bool
) -> NDArray[np.complex128 | np.complex256]:
    """Compute the transform one sector and one slice at a time.

    Produces the same output layout as the vectorized variant but builds it
    from per-sector, per-slice 2D operations, which keeps intermediate arrays
    small.

    Parameters
    ----------
    data
        Already validated input of shape ``(N, N, N)`` or ``(V, N, N, N)``.
    scipy_fft
        Forwarded to the configuration and coefficient helpers.
    real_mode
        If ``True``, a real FFT is used along the first transformed axis and
        only ``(3N+1)//2 + 1`` samples are kept on it.

    Returns
    -------
    Array of shape ``(3, mx, N+1, N+1)`` for a single input, or
    ``(V, 3, mx, N+1, N+1)`` for a batch, where ``mx`` is ``3N+1`` or
    ``(3N+1)//2 + 1`` in real mode.
    """
    xp, xp_inner, fft, device, rechunk, compute = _get_config(data, scipy_fft=scipy_fft)

    n = data.shape[-1]
    side = n + 1  # output side length (N + 1)
    padded = side * 3  # FFT length used for the two inner passes
    m = 3 * n + 1  # length of the zero-padded first axis
    mx = m // 2 + 1 if real_mode else m
    md = (2 * n + 1) - n // 2  # first index of the window kept after each inner pass
    idx = slice(md, md + side)

    multi = data.ndim == __MULTI_MODE_3D
    amount = data.shape[0] if multi else 1

    pq_pz_func = _get_rpq_rpz if real_mode else _get_pq_pz
    pq, pz = pq_pz_func(
        n, dim=3, xp=xp, xp_inner=xp_inner, scipy_fft=scipy_fft, device=device
    )
    # Trailing unit axis so each row of factors broadcasts over a 2D slice.
    pq = pq[..., None]
    pz = pz[..., None]

    data_stacked = data if multi else (data,)

    # Lazy pipeline: for every cube and every sector, move the sector's axis
    # to the front, zero-pad it to 3n + 1 samples and transform it.
    zeros_1 = xp_inner.zeros((n + 1, n, n), device=device)
    ops = (xp.moveaxis(x, i, 0) for x, i in itertools.product(data_stacked, range(3)))
    ops = (xp.concat([zeros_1[:-1], x, zeros_1]) for x in ops)
    ops = (fft.ifftshift(x, axes=0) for x in ops)
    if rechunk:
        ops = (x.rechunk({0: -1}) for x in ops)
    ops = (
        fft.rfft(x, axis=0) if real_mode else fft.fftshift(fft.fft(x, axis=0), axes=0)
        for x in ops
    )
    if compute:
        ops = (x.compute() for x in ops)

    zeros_2 = xp_inner.zeros((1, side), device=device)
    # Positional tail shared by every ``__apply_qz`` call.
    params = zeros_2, padded, idx, xp, fft

    out = xp.stack(
        [
            __pp_sector(x, params, pq, pz, xp=xp, compute=compute, rechunk=rechunk)
            for x in ops
        ]
    )

    if multi:
        # The generator yields cube-major, sector-minor, so a plain reshape
        # recovers the (batch, sector) layout.
        out = xp.reshape(out, (amount, 3, mx, side, side))
    return xp.flip(xp.flip(out, axis=-1), axis=-2)


def __pp_sector(  # noqa: PLR0913
    x: NDArray,
    params: tuple[NDArray, int, slice, ModuleType, ModuleType],
    pq: NDArray,
    pz: NDArray,
    *,
    xp: ModuleType,
    compute: bool,
    rechunk: bool,
) -> NDArray[np.complex128 | np.complex256]:
    """Apply the two inner passes to every 2D slice of one sector.

    Iterates ``x``, ``pq`` and ``pz`` in lockstep along their first axis and
    runs ``__apply_qz`` twice per slice with the same factors. Only the first
    pass honours ``compute``; the second pass is always run with
    ``compute=False``.

    Parameters
    ----------
    x
        One sector after the first-axis transform.
    params
        Positional arguments forwarded to ``__apply_qz``:
        ``(zeros, np3, idx, xp, fft)``.
    pq, pz
        Per-slice multiplicative factors.
    xp
        Array namespace used to stack the per-slice results.
    compute
        Whether the first pass should call ``.compute()`` on its result.
    rechunk
        Whether each pass should rechunk before its FFT.

    Returns
    -------
    The per-slice results stacked along a new leading axis.
    """
    return xp.stack(
        [
            __apply_qz(
                __apply_qz(_x, q, z, *params, compute=compute, rechunk=rechunk),
                q,
                z,
                *params,
                compute=False,
                rechunk=rechunk,
            )
            for _x, q, z in zip(x, pq, pz, strict=False)
        ]
    )


def __apply_qz(  # noqa: PLR0913
    x: NDArray,
    q: NDArray,
    z: NDArray,
    zeros: NDArray,
    np3: int,
    idx: slice,
    xp: ModuleType,
    fft: ModuleType,
    *,
    rechunk: bool,
    compute: bool,
) -> NDArray[np.complex128 | np.complex256]:
    """Run one inner pass on a single 2D slice.

    The slice is transposed, extended by one zero row, multiplied by ``q``,
    transformed with an FFT of length ``np3`` along axis 0, multiplied by
    ``z``, inverse-transformed, cut down to ``idx`` and multiplied by ``q``
    again.

    Parameters
    ----------
    x
        The 2D slice to process.
    q, z
        Multiplicative factors applied before/after and between the FFTs.
    zeros
        A ``(1, K)`` zero row; its first ``x.shape[0]`` columns are appended
        to the transposed slice.
    np3
        FFT length along axis 0.
    idx
        Window of rows kept after the inverse FFT.
    xp, fft
        Array namespace and FFT module.
    rechunk
        Whether to rechunk axis 0 into a single chunk before the FFT.
    compute
        Whether to call ``.compute()`` on the result.

    Returns
    -------
    The processed 2D slice.
    """
    # After the transpose the column count equals x.shape[0], so the zero row
    # is trimmed to that width.
    x = xp.concat([x.T, zeros[:, : x.shape[0]]]) * q
    if rechunk:
        x = x.rechunk({0: -1})
    x = fft.ifft(fft.fft(x, n=np3, axis=0) * z, axis=0)[idx] * q
    if compute:
        x = x.compute()
    return x


def __verify_data_shape(data: NDArray) -> None:
    """Validate that ``data`` is a non-empty even-sided cube or a batch of them.

    Only ``data.ndim`` and ``data.shape`` are inspected, so no element of the
    array is indexed or computed.

    Parameters
    ----------
    data
        Candidate input of shape ``(N, N, N)`` or ``(V, N, N, N)``.

    Raises
    ------
    ValueError
        If ``data`` is not 3- or 4-dimensional, has a zero-length axis, has
        cube sides of different lengths, or has an odd side length.
    """
    if data.ndim not in (3, 4):
        msg = "Input data must be a single NxNxN matrix or an array of NxNxN matrices"
        raise ValueError(msg)
    if 0 in data.shape:
        # An empty batch or an empty cube has nothing to transform; report it
        # as a shape error instead of failing later on an index lookup.
        msg = "Input data must not be empty"
        raise ValueError(msg)
    # The last three axes are the cube axes in both the single and batch case.
    sides = data.shape[-3:]
    if len(set(sides)) != 1:
        msg = "Input data must have sides with same lengths"
        raise ValueError(msg)
    if sides[0] % 2 != 0:
        msg = "Input data must have even sides"
        raise ValueError(msg)