120 lines
4.5 KiB
Python
120 lines
4.5 KiB
Python
|
# Copyright (C) 2003-2005 Peter J. Verveer
|
||
|
#
|
||
|
# Redistribution and use in source and binary forms, with or without
|
||
|
# modification, are permitted provided that the following conditions
|
||
|
# are met:
|
||
|
#
|
||
|
# 1. Redistributions of source code must retain the above copyright
|
||
|
# notice, this list of conditions and the following disclaimer.
|
||
|
#
|
||
|
# 2. Redistributions in binary form must reproduce the above
|
||
|
# copyright notice, this list of conditions and the following
|
||
|
# disclaimer in the documentation and/or other materials provided
|
||
|
# with the distribution.
|
||
|
#
|
||
|
# 3. The name of the author may not be used to endorse or promote
|
||
|
# products derived from this software without specific prior
|
||
|
# written permission.
|
||
|
#
|
||
|
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
|
||
|
# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||
|
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||
|
# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
|
||
|
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
|
||
|
# GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||
|
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
|
||
|
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||
|
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||
|
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||
|
|
||
|
from collections.abc import Iterable
|
||
|
import operator
|
||
|
import warnings
|
||
|
import numpy
|
||
|
|
||
|
|
||
|
def _extend_mode_to_code(mode):
|
||
|
"""Convert an extension mode to the corresponding integer code.
|
||
|
"""
|
||
|
if mode == 'nearest':
|
||
|
return 0
|
||
|
elif mode == 'wrap':
|
||
|
return 1
|
||
|
elif mode in ['reflect', 'grid-mirror']:
|
||
|
return 2
|
||
|
elif mode == 'mirror':
|
||
|
return 3
|
||
|
elif mode == 'constant':
|
||
|
return 4
|
||
|
elif mode == 'grid-wrap':
|
||
|
return 5
|
||
|
elif mode == 'grid-constant':
|
||
|
return 6
|
||
|
else:
|
||
|
raise RuntimeError('boundary mode not supported')
|
||
|
|
||
|
|
||
|
def _normalize_sequence(input, rank):
|
||
|
"""If input is a scalar, create a sequence of length equal to the
|
||
|
rank by duplicating the input. If input is a sequence,
|
||
|
check if its length is equal to the length of array.
|
||
|
"""
|
||
|
is_str = isinstance(input, str)
|
||
|
if not is_str and isinstance(input, Iterable):
|
||
|
normalized = list(input)
|
||
|
if len(normalized) != rank:
|
||
|
err = "sequence argument must have length equal to input rank"
|
||
|
raise RuntimeError(err)
|
||
|
else:
|
||
|
normalized = [input] * rank
|
||
|
return normalized
|
||
|
|
||
|
|
||
|
def _get_output(output, input, shape=None, complex_output=False):
|
||
|
if shape is None:
|
||
|
shape = input.shape
|
||
|
if output is None:
|
||
|
if not complex_output:
|
||
|
output = numpy.zeros(shape, dtype=input.dtype.name)
|
||
|
else:
|
||
|
complex_type = numpy.promote_types(input.dtype, numpy.complex64)
|
||
|
output = numpy.zeros(shape, dtype=complex_type)
|
||
|
elif isinstance(output, (type, numpy.dtype)):
|
||
|
# Classes (like `np.float32`) and dtypes are interpreted as dtype
|
||
|
if complex_output and numpy.dtype(output).kind != 'c':
|
||
|
warnings.warn("promoting specified output dtype to complex", stacklevel=3)
|
||
|
output = numpy.promote_types(output, numpy.complex64)
|
||
|
output = numpy.zeros(shape, dtype=output)
|
||
|
elif isinstance(output, str):
|
||
|
output = numpy.dtype(output)
|
||
|
if complex_output and output.kind != 'c':
|
||
|
raise RuntimeError("output must have complex dtype")
|
||
|
elif not issubclass(output.type, numpy.number):
|
||
|
raise RuntimeError("output must have numeric dtype")
|
||
|
output = numpy.zeros(shape, dtype=output)
|
||
|
elif output.shape != shape:
|
||
|
raise RuntimeError("output shape not correct")
|
||
|
elif complex_output and output.dtype.kind != 'c':
|
||
|
raise RuntimeError("output must have complex dtype")
|
||
|
return output
|
||
|
|
||
|
|
||
|
def _check_axes(axes, ndim):
|
||
|
if axes is None:
|
||
|
return tuple(range(ndim))
|
||
|
elif numpy.isscalar(axes):
|
||
|
axes = (operator.index(axes),)
|
||
|
elif isinstance(axes, Iterable):
|
||
|
for ax in axes:
|
||
|
axes = tuple(operator.index(ax) for ax in axes)
|
||
|
if ax < -ndim or ax > ndim - 1:
|
||
|
raise ValueError(f"specified axis: {ax} is out of range")
|
||
|
axes = tuple(ax % ndim if ax < 0 else ax for ax in axes)
|
||
|
else:
|
||
|
message = "axes must be an integer, iterable of integers, or None"
|
||
|
raise ValueError(message)
|
||
|
if len(tuple(set(axes))) != len(axes):
|
||
|
raise ValueError("axes must be unique")
|
||
|
return axes
|