ai-content-maker/.venv/Lib/site-packages/Cython/Compiler/UFuncs.py

287 lines
9.0 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from . import (
Nodes,
ExprNodes,
FusedNode,
TreeFragment,
Pipeline,
ParseTreeTransforms,
Naming,
UtilNodes,
)
from .Errors import error
from . import PyrexTypes
from .UtilityCode import CythonUtilityCode
from .Code import TempitaUtilityCode, UtilityCode
from .Visitor import PrintTree, TreeVisitor, VisitorTransform
# Names of the signed integer numpy dtype constants supported for ufunc arguments.
numpy_int_types = [
    "NPY_BYTE", "NPY_INT8",
    "NPY_SHORT", "NPY_INT16",
    "NPY_INT", "NPY_INT32",
    "NPY_LONG",
    "NPY_LONGLONG", "NPY_INT64",
]
# Unsigned counterparts: the "NPY_" prefix of each signed name becomes "NPY_U".
numpy_uint_types = ["NPY_U" + tp[len("NPY_"):] for tp in numpy_int_types]
# All numeric constants: signed ints, then unsigned ints, then floating point.
# note: half float type is deliberately omitted
numpy_numeric_types = list(numpy_int_types)
numpy_numeric_types.extend(numpy_uint_types)
numpy_numeric_types.extend([
    "NPY_FLOAT", "NPY_FLOAT32",
    "NPY_DOUBLE", "NPY_FLOAT64",
    "NPY_LONGDOUBLE",
])
def _get_type_constant(pos, type_):
    """
    Look up the name of the numpy type constant matching a Cython type.

    pos   - source position used when reporting an unsupported type
    type_ - the type to translate

    Returns a str such as "NPY_CDOUBLE", "NPY_INT32" or "NPY_OBJECT".
    If no constant matches, an error is reported through error() and
    (provided error() returns) the result is None.
    """
    constant = None
    if type_.is_complex:
        # 'is' checks don't seem to work for complex types, so compare with '=='
        complex_constants = (
            (PyrexTypes.c_float_complex_type, "NPY_CFLOAT"),
            (PyrexTypes.c_double_complex_type, "NPY_CDOUBLE"),
            (PyrexTypes.c_longdouble_complex_type, "NPY_CLONGDOUBLE"),
        )
        for complex_type, name in complex_constants:
            if type_ == complex_type:
                constant = name
                break
    elif type_.is_numeric:
        # derive the candidate name from the C declaration, e.g. "NPY_" + "LONGLONG"
        declaration = type_.empty_declaration_code()
        candidate = "NPY_%s" % declaration.upper().replace(" ", "")
        if candidate in numpy_numeric_types:
            constant = candidate
    elif type_.is_pyobject:
        constant = "NPY_OBJECT"

    if constant is None:
        # TODO possible NPY_BOOL to bint but it needs a cast?
        # TODO NPY_DATETIME, NPY_TIMEDELTA, NPY_STRING, NPY_UNICODE and maybe NPY_VOID might be handleable
        error(pos, "Type '%s' cannot be used as a ufunc argument" % type_)
    return constant
class _FindCFuncDefNode(TreeVisitor):
    """
    Visitor that picks the CFuncDefNode out of a tree.

    It relies on the tree containing only one CFuncDefNode. Calling the
    visitor with a tree returns that node, or None if none was seen.
    """

    # stays None until a CFuncDefNode has been visited
    found_node = None

    def visit_Node(self, node):
        # keep descending only while nothing has been recorded yet
        if not self.found_node:
            self.visitchildren(node)

    def visit_CFuncDefNode(self, node):
        # record the node; its children are not visited
        self.found_node = node

    def __call__(self, tree):
        self.visit(tree)
        return self.found_node
def get_cfunc_from_tree(tree):
    """
    Return the CFuncDefNode that _FindCFuncDefNode locates in 'tree'
    (None when the tree holds no such node).
    """
    finder = _FindCFuncDefNode()
    return finder(tree)
class _ArgumentInfo(object):
    """
    Everything related to defining an input/output argument for a ufunc

    type - PyrexType
    type_constant - str such as "NPY_INT8" representing numpy dtype constants
    """

    # NOTE: the parameter name 'type' shadows the builtin inside __init__ only;
    # it is kept because it is part of the constructor's interface.
    def __init__(self, type, type_constant):
        self.type = type
        self.type_constant = type_constant

    def __repr__(self):
        # readable output when argument definitions are printed while debugging
        return "%s(type=%r, type_constant=%r)" % (
            type(self).__name__,
            self.type,
            self.type_constant,
        )
class UFuncConversion(object):
    """
    Collects the information needed to convert one C function node to a ufunc.

    node            - the function node being converted
    global_scope    - global scope reached from the node's local scope
    in_definitions  - list of _ArgumentInfo, one per argument of the function
    out_definitions - list of _ArgumentInfo, one per returned value
    """

    def __init__(self, node):
        self.node = node
        self.global_scope = node.local_scope.global_scope()
        self.in_definitions = self.get_in_type_info()
        self.out_definitions = self.get_out_type_info()

    def get_in_type_info(self):
        """Return a list with one _ArgumentInfo per argument of the function."""
        pos = self.node.pos
        return [
            _ArgumentInfo(arg.type, _get_type_constant(pos, arg.type))
            for arg in self.node.args
        ]

    def get_out_type_info(self):
        """
        Return a list with one _ArgumentInfo per output of the function.

        A ctuple return type contributes one output per component; any
        other return type is a single output.
        """
        return_type = self.node.return_type
        if return_type.is_ctuple:
            components = return_type.components
        else:
            components = [return_type]
        pos = self.node.pos
        return [
            _ArgumentInfo(out_type, _get_type_constant(pos, out_type))
            for out_type in components
        ]

    def generate_cy_utility_code(self):
        """
        Load the "UFuncDefinition" Cython utility code specialised for this
        function and return its tree (obtained with entries_only=True).

        Also marks the function's entry as used.
        """
        arg_types = [a.type for a in self.in_definitions]
        out_types = [a.type for a in self.out_definitions]
        entry = self.node.entry
        inline_func_decl = entry.type.declaration_code(entry.cname, pyrex=True)
        entry.used = True

        ufunc_cname = self.global_scope.next_id(entry.name + "_ufunc_def")

        # true only when no input and no output is a Python object
        will_be_called_without_gil = not (
            any(t.is_pyobject for t in arg_types)
            or any(t.is_pyobject for t in out_types)
        )

        context = dict(
            func_cname=ufunc_cname,
            in_types=arg_types,
            out_types=out_types,
            inline_func_call=entry.cname,
            inline_func_declaration=inline_func_decl,
            nogil=entry.type.nogil,
            will_be_called_without_gil=will_be_called_without_gil,
        )

        code = CythonUtilityCode.load(
            "UFuncDefinition",
            "UFuncs.pyx",
            context=context,
            outer_module_scope=self.global_scope,
        )

        return code.get_tree(entries_only=True)

    def use_generic_utility_code(self):
        """Register the C utility code that does not depend on the converted function."""
        # use the invariant C utility code
        self.global_scope.use_utility_code(
            UtilityCode.load_cached("UFuncsInit", "UFuncs_C.c")
        )
        self.global_scope.use_utility_code(
            UtilityCode.load_cached("NumpyImportUFunc", "NumpyImportArray.c")
        )
def convert_to_ufunc(node):
    """
    Convert a C function node (plain or fused) into the statements defining a ufunc.

    Returns 'node' unchanged, after reporting an error, when it is not a
    C function or when the function lives in a cdef class scope.
    Otherwise returns a list made of 'node' followed by the generated
    statements (None if there turned out to be nothing to convert).
    """
    if isinstance(node, Nodes.CFuncDefNode):
        is_fused = False
        original_node = node
    elif isinstance(node, FusedNode.FusedCFuncDefNode) and isinstance(
        node.node, Nodes.CFuncDefNode
    ):
        is_fused = True
        original_node = node.node
    else:
        error(node.pos, "Only C functions can be converted to a ufunc")
        return node

    if original_node.local_scope.parent_scope.is_c_class_scope:
        error(node.pos, "Methods cannot currently be converted to a ufunc")
        return node

    # a fused function gets one converter per specialisation
    specializations = node.nodes if is_fused else [node]
    converters = [UFuncConversion(n) for n in specializations]
    if not converters:
        return  # this path probably shouldn't happen

    first = converters[0]
    # drop the function's own entry from the global scope
    # (presumably so that the name can be declared again for the ufunc object)
    del first.global_scope.entries[original_node.entry.name]
    # the generic utility code is generic, so there's no reason to do it multiple times
    first.use_generic_utility_code()
    generated = _generate_stats_from_converters(converters, original_node)
    return [node] + generated
def generate_ufunc_initialization(converters, cfunc_nodes, original_node):
    """
    Build the assignment that creates the ufunc object and binds it to the
    function's name in the global scope.

    converters    - non-empty list of UFuncConversion (the first one supplies
                    the global scope); all must agree on the number of inputs
                    and of outputs
    cfunc_nodes   - nodes whose entry cnames are listed as the ufunc functions
    original_node - supplies the position, name and docstring

    Registers the "UFuncConsts" utility code on the global scope and returns a
    SingleAssignmentNode whose right-hand side calls PyUFunc_FromFuncAndData.
    """
    module_scope = converters[0].global_scope
    prefix = Naming.pyrex_prefix
    funcs_name = module_scope.next_id(prefix + "funcs")
    types_name = module_scope.next_id(prefix + "types")
    data_name = module_scope.next_id(prefix + "data")

    # flat list of type constants: inputs then outputs, for each converter in turn
    type_constants = []
    narg_in = narg_out = None
    for conv in converters:
        in_consts = [info.type_constant for info in conv.in_definitions]
        if narg_in is None:
            narg_in = len(in_consts)
        else:
            assert narg_in == len(in_consts)
        type_constants += in_consts

        out_consts = [info.type_constant for info in conv.out_definitions]
        if narg_out is None:
            narg_out = len(out_consts)
        else:
            assert narg_out == len(out_consts)
        type_constants += out_consts

    func_cnames = [cfunc.entry.cname for cfunc in cfunc_nodes]

    consts_context = {
        "ufunc_funcs_name": funcs_name,
        "func_cnames": func_cnames,
        "ufunc_types_name": types_name,
        "type_constants": type_constants,
        "ufunc_data_name": data_name,
    }
    consts_code = TempitaUtilityCode.load(
        "UFuncConsts", "UFuncs_C.c", context=consts_context
    )
    module_scope.use_utility_code(consts_code)

    pos = original_node.pos
    func_name = original_node.entry.name
    docstr = original_node.doc
    doc_literal = docstr.as_c_string_literal() if docstr else "NULL"
    # the complete C argument list, spelled out as text
    args_to_func = '%s(), %s, %s(), %s, %s, %s, PyUFunc_None, "%s", %s, 0' % (
        funcs_name,
        data_name,
        types_name,
        len(func_cnames),
        narg_in,
        narg_out,
        func_name,
        doc_literal,
    )

    void_ptr_type = PyrexTypes.c_void_ptr_type
    py_object_type = PyrexTypes.py_object_type

    # use a dummy type because it's honestly too fiddly: the whole argument
    # text is handed over as one "void*" constant
    dummy_func_type = PyrexTypes.CFuncType(
        py_object_type,
        [PyrexTypes.CFuncTypeArg("dummy", void_ptr_type, None)],
    )
    packed_args = ExprNodes.ConstNode(pos, type=void_ptr_type, value=args_to_func)
    call_node = ExprNodes.PythonCapiCallNode(
        pos,
        function_name="PyUFunc_FromFuncAndData",
        func_type=dummy_func_type,
        args=[packed_args],
    )

    lhs_entry = module_scope.declare_var(func_name, py_object_type, pos)
    target = ExprNodes.NameNode(
        pos, name=func_name, type=py_object_type, entry=lhs_entry
    )
    return Nodes.SingleAssignmentNode(pos, lhs=target, rhs=call_node)
def _generate_stats_from_converters(converters, node):
    """
    Produce the statements generated for a list of converters.

    The result holds the function node found in each converter's utility
    code tree, followed by the ufunc initialization statement. Utility
    code required by each tree is added to the converter's global scope.
    """
    generated = []
    for conv in converters:
        utility_tree = conv.generate_cy_utility_code()
        cfunc = get_cfunc_from_tree(utility_tree)
        # merge in any utility code
        conv.global_scope.utility_code_list.extend(
            utility_tree.scope.utility_code_list
        )
        generated.append(cfunc)

    # at this point 'generated' contains only the function nodes
    init_stat = generate_ufunc_initialization(converters, generated, node)
    generated.append(init_stat)
    return generated