ai-content-maker/.venv/Lib/site-packages/scipy/fft/_basic_backend.py

177 lines
6.4 KiB
Python

from scipy._lib._array_api import (
array_namespace, is_numpy, xp_unsupported_param_msg, is_complex
)
from . import _pocketfft
import numpy as np
def _validate_fft_args(workers, plan, norm):
if workers is not None:
raise ValueError(xp_unsupported_param_msg("workers"))
if plan is not None:
raise ValueError(xp_unsupported_param_msg("plan"))
if norm is None:
norm = 'backward'
return norm
# pocketfft is used whenever SCIPY_ARRAY_API is not set,
# or x is a NumPy array or array-like.
# When SCIPY_ARRAY_API is set, we try to use xp.fft for CuPy arrays,
# PyTorch arrays and other array API standard supporting objects.
# If xp.fft does not exist, we attempt to convert to np and back to use pocketfft.
def _execute_1D(func_str, pocketfft_func, x, n, axis, norm, overwrite_x, workers, plan):
xp = array_namespace(x)
if is_numpy(xp):
return pocketfft_func(x, n=n, axis=axis, norm=norm,
overwrite_x=overwrite_x, workers=workers, plan=plan)
norm = _validate_fft_args(workers, plan, norm)
if hasattr(xp, 'fft'):
xp_func = getattr(xp.fft, func_str)
return xp_func(x, n=n, axis=axis, norm=norm)
x = np.asarray(x)
y = pocketfft_func(x, n=n, axis=axis, norm=norm)
return xp.asarray(y)
def _execute_nD(func_str, pocketfft_func, x, s, axes, norm, overwrite_x, workers, plan):
xp = array_namespace(x)
if is_numpy(xp):
return pocketfft_func(x, s=s, axes=axes, norm=norm,
overwrite_x=overwrite_x, workers=workers, plan=plan)
norm = _validate_fft_args(workers, plan, norm)
if hasattr(xp, 'fft'):
xp_func = getattr(xp.fft, func_str)
return xp_func(x, s=s, axes=axes, norm=norm)
x = np.asarray(x)
y = pocketfft_func(x, s=s, axes=axes, norm=norm)
return xp.asarray(y)
def fft(x, n=None, axis=-1, norm=None,
        overwrite_x=False, workers=None, *, plan=None):
    """Dispatch the 1-D ``fft`` through `_execute_1D`.

    ``_pocketfft.fft`` is supplied as the pocketfft implementation; all
    arguments are forwarded unchanged.
    """
    backend = _pocketfft.fft
    return _execute_1D(
        'fft', backend, x,
        n=n, axis=axis, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def ifft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, *,
         plan=None):
    """Dispatch the 1-D ``ifft`` through `_execute_1D`.

    ``_pocketfft.ifft`` is supplied as the pocketfft implementation; all
    arguments are forwarded unchanged.
    """
    backend = _pocketfft.ifft
    return _execute_1D(
        'ifft', backend, x,
        n=n, axis=axis, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def rfft(x, n=None, axis=-1, norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    """Dispatch the 1-D ``rfft`` through `_execute_1D`.

    ``_pocketfft.rfft`` is supplied as the pocketfft implementation; all
    arguments are forwarded unchanged.
    """
    backend = _pocketfft.rfft
    return _execute_1D(
        'rfft', backend, x,
        n=n, axis=axis, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def irfft(x, n=None, axis=-1, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Dispatch the 1-D ``irfft`` through `_execute_1D`.

    ``_pocketfft.irfft`` is supplied as the pocketfft implementation; all
    arguments are forwarded unchanged.
    """
    backend = _pocketfft.irfft
    return _execute_1D(
        'irfft', backend, x,
        n=n, axis=axis, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def hfft(x, n=None, axis=-1, norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    """Dispatch the 1-D ``hfft`` through `_execute_1D`.

    ``_pocketfft.hfft`` is supplied as the pocketfft implementation; all
    arguments are forwarded unchanged.
    """
    backend = _pocketfft.hfft
    return _execute_1D(
        'hfft', backend, x,
        n=n, axis=axis, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def ihfft(x, n=None, axis=-1, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Dispatch the 1-D ``ihfft`` through `_execute_1D`.

    ``_pocketfft.ihfft`` is supplied as the pocketfft implementation; all
    arguments are forwarded unchanged.
    """
    backend = _pocketfft.ihfft
    return _execute_1D(
        'ihfft', backend, x,
        n=n, axis=axis, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def fftn(x, s=None, axes=None, norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    """Dispatch the N-D ``fftn`` through `_execute_nD`.

    ``_pocketfft.fftn`` is supplied as the pocketfft implementation; all
    arguments are forwarded unchanged.
    """
    backend = _pocketfft.fftn
    return _execute_nD(
        'fftn', backend, x,
        s=s, axes=axes, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def ifftn(x, s=None, axes=None, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Dispatch the N-D ``ifftn`` through `_execute_nD`.

    ``_pocketfft.ifftn`` is supplied as the pocketfft implementation; all
    arguments are forwarded unchanged.
    """
    backend = _pocketfft.ifftn
    return _execute_nD(
        'ifftn', backend, x,
        s=s, axes=axes, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def fft2(x, s=None, axes=(-2, -1), norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    """Call `fftn` with `axes` defaulting to the last two axes.

    Every argument is forwarded to `fftn` unchanged.
    """
    return fftn(
        x, s=s, axes=axes, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def ifft2(x, s=None, axes=(-2, -1), norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Call `ifftn` with `axes` defaulting to the last two axes.

    Every argument is forwarded to `ifftn` unchanged.
    """
    return ifftn(
        x, s=s, axes=axes, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def rfftn(x, s=None, axes=None, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Dispatch the N-D ``rfftn`` through `_execute_nD`.

    ``_pocketfft.rfftn`` is supplied as the pocketfft implementation; all
    arguments are forwarded unchanged.
    """
    backend = _pocketfft.rfftn
    return _execute_nD(
        'rfftn', backend, x,
        s=s, axes=axes, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def rfft2(x, s=None, axes=(-2, -1), norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Call `rfftn` with `axes` defaulting to the last two axes.

    Every argument is forwarded to `rfftn` unchanged.
    """
    return rfftn(
        x, s=s, axes=axes, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def irfftn(x, s=None, axes=None, norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    """Dispatch the N-D ``irfftn`` through `_execute_nD`.

    ``_pocketfft.irfftn`` is supplied as the pocketfft implementation; all
    arguments are forwarded unchanged.
    """
    backend = _pocketfft.irfftn
    return _execute_nD(
        'irfftn', backend, x,
        s=s, axes=axes, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def irfft2(x, s=None, axes=(-2, -1), norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    """Call `irfftn` with `axes` defaulting to the last two axes.

    Every argument is forwarded to `irfftn` unchanged.
    """
    return irfftn(
        x, s=s, axes=axes, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def _swap_direction(norm):
if norm in (None, 'backward'):
norm = 'forward'
elif norm == 'forward':
norm = 'backward'
elif norm != 'ortho':
raise ValueError('Invalid norm value %s; should be "backward", '
'"ortho", or "forward".' % norm)
return norm
def hfftn(x, s=None, axes=None, norm=None,
overwrite_x=False, workers=None, *, plan=None):
xp = array_namespace(x)
if is_numpy(xp):
return _pocketfft.hfftn(x, s, axes, norm, overwrite_x, workers, plan=plan)
if is_complex(x, xp):
x = xp.conj(x)
return irfftn(x, s, axes, _swap_direction(norm),
overwrite_x, workers, plan=plan)
def hfft2(x, s=None, axes=(-2, -1), norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    """Call `hfftn` with `axes` defaulting to the last two axes.

    Every argument is forwarded to `hfftn` unchanged.
    """
    return hfftn(
        x, s=s, axes=axes, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )
def ihfftn(x, s=None, axes=None, norm=None,
overwrite_x=False, workers=None, *, plan=None):
xp = array_namespace(x)
if is_numpy(xp):
return _pocketfft.ihfftn(x, s, axes, norm, overwrite_x, workers, plan=plan)
return xp.conj(rfftn(x, s, axes, _swap_direction(norm),
overwrite_x, workers, plan=plan))
def ihfft2(x, s=None, axes=(-2, -1), norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    """Call `ihfftn` with `axes` defaulting to the last two axes.

    Every argument is forwarded to `ihfftn` unchanged.
    """
    return ihfftn(
        x, s=s, axes=axes, norm=norm,
        overwrite_x=overwrite_x, workers=workers, plan=plan,
    )