from . import (
    Nodes,
    ExprNodes,
    FusedNode,
    TreeFragment,
    Pipeline,
    ParseTreeTransforms,
    Naming,
    UtilNodes,
)
from .Errors import error
from . import PyrexTypes
from .UtilityCode import CythonUtilityCode
from .Code import TempitaUtilityCode, UtilityCode
from .Visitor import PrintTree, TreeVisitor, VisitorTransform

numpy_int_types = [
    "NPY_BYTE",
    "NPY_INT8",
    "NPY_SHORT",
    "NPY_INT16",
    "NPY_INT",
    "NPY_INT32",
    "NPY_LONG",
    "NPY_LONGLONG",
    "NPY_INT64",
]
numpy_uint_types = [tp.replace("NPY_", "NPY_U") for tp in numpy_int_types]
# note: half float type is deliberately omitted
numpy_numeric_types = (
    numpy_int_types
    + numpy_uint_types
    + [
        "NPY_FLOAT",
        "NPY_FLOAT32",
        "NPY_DOUBLE",
        "NPY_FLOAT64",
        "NPY_LONGDOUBLE",
    ]
)

# The exact-width integer constants (NPY_INT8, NPY_UINT32, ...). These are the
# only ones that may be reached through a C typedef name such as "int8_t".
_numpy_exact_width_int_types = frozenset(
    tp for tp in numpy_int_types + numpy_uint_types if tp[-1].isdigit()
)


def _numpy_type_postfix(c_name):
    """
    Normalise the C spelling of a numeric type to the suffix used by numpy's
    NPY_* type constants.

    Examples: "int" -> "INT", "unsigned long long" -> "ULONGLONG",
    "signed char" -> "BYTE", "unsigned char" -> "UBYTE", "uint8_t" -> "UINT8".
    Spellings with no safe numpy equivalent (e.g. plain "char", whose
    signedness is implementation-defined) are returned in a form that will
    not match any entry of ``numpy_numeric_types``.
    """
    # 'long long' may be spelled via the PY_LONG_LONG macro in generated C code
    words = c_name.replace("PY_LONG_LONG", "long long").upper().split()
    sign = None
    if words and words[0] in ("SIGNED", "UNSIGNED"):
        sign = words.pop(0)
    postfix = "".join(words)

    if postfix == "CHAR":
        if sign is None:
            return postfix  # plain char: deliberately left unmatched
        postfix = "BYTE"  # numpy's npy_byte / npy_ubyte are signed/unsigned char
    elif (
        sign is None
        and postfix.endswith("_T")
        and "NPY_" + postfix[:-2] in _numpy_exact_width_int_types
    ):
        # C99 exact-width typedefs: "int8_t" -> "INT8", "uint64_t" -> "UINT64"
        postfix = postfix[:-2]

    if sign == "UNSIGNED":
        postfix = "U" + postfix
    return postfix


def _get_type_constant(pos, type_):
    """
    Return the name of the numpy dtype constant (e.g. "NPY_INT8") matching a
    Cython type used as a ufunc argument or result.

    If there is no matching constant, a compile error is reported at ``pos``
    and None is returned.
    """
    if type_.is_complex:
        # 'is' checks don't seem to work for complex types
        if type_ == PyrexTypes.c_float_complex_type:
            return "NPY_CFLOAT"
        elif type_ == PyrexTypes.c_double_complex_type:
            return "NPY_CDOUBLE"
        elif type_ == PyrexTypes.c_longdouble_complex_type:
            return "NPY_CLONGDOUBLE"
    elif type_.is_numeric:
        postfix = _numpy_type_postfix(type_.empty_declaration_code())
        typename = "NPY_%s" % postfix
        if typename in numpy_numeric_types:
            return typename
    elif type_.is_pyobject:
        return "NPY_OBJECT"
    # TODO possible NPY_BOOL to bint but it needs a cast?
    # TODO NPY_DATETIME, NPY_TIMEDELTA, NPY_STRING, NPY_UNICODE and maybe NPY_VOID might be handleable
    error(pos, "Type '%s' cannot be used as a ufunc argument" % type_)


class _FindCFuncDefNode(TreeVisitor):
    """
    Finds the CFuncDefNode in the tree

    The assumption is that there's only one CFuncDefNode
    """

    found_node = None

    def visit_Node(self, node):
        # stop descending once a CFuncDefNode has been found
        if self.found_node:
            return
        self.visitchildren(node)

    def visit_CFuncDefNode(self, node):
        # children of the function are deliberately not visited
        self.found_node = node

    def __call__(self, tree):
        self.visit(tree)
        return self.found_node


def get_cfunc_from_tree(tree):
    """Return the CFuncDefNode found in ``tree``, or None if there is none."""
    return _FindCFuncDefNode()(tree)


class _ArgumentInfo(object):
    """
    Everything related to defining an input/output argument for a ufunc

    type  - PyrexType
    type_constant  - str such as "NPY_INT8" representing numpy dtype constants
    """

    def __init__(self, type, type_constant):
        self.type = type
        self.type_constant = type_constant


class UFuncConversion(object):
    """
    Collects the input/output type information of a single (non-fused)
    C function and generates the ufunc loop utility code for it.
    """

    def __init__(self, node):
        self.node = node
        self.global_scope = node.local_scope.global_scope()

        self.in_definitions = self.get_in_type_info()
        self.out_definitions = self.get_out_type_info()

    def get_in_type_info(self):
        """Return one _ArgumentInfo per argument of the C function."""
        pos = self.node.pos
        return [
            _ArgumentInfo(arg.type, _get_type_constant(pos, arg.type))
            for arg in self.node.args
        ]

    def get_out_type_info(self):
        """
        Return one _ArgumentInfo per ufunc output: each component of a
        ctuple return type, otherwise the single return type.
        """
        if self.node.return_type.is_ctuple:
            components = self.node.return_type.components
        else:
            components = [self.node.return_type]
        pos = self.node.pos
        return [
            _ArgumentInfo(type, _get_type_constant(pos, type))
            for type in components
        ]

    def generate_cy_utility_code(self):
        """
        Load the "UFuncDefinition" Cython utility code specialised for this
        function and return its tree (``get_tree(entries_only=True)``).
        """
        arg_types = [a.type for a in self.in_definitions]
        out_types = [a.type for a in self.out_definitions]
        inline_func_decl = self.node.entry.type.declaration_code(
            self.node.entry.cname, pyrex=True
        )
        # the generated loop calls this function, so it must be emitted
        self.node.entry.used = True

        ufunc_cname = self.global_scope.next_id(self.node.entry.name + "_ufunc_def")

        will_be_called_without_gil = not (
            any(t.is_pyobject for t in arg_types)
            or any(t.is_pyobject for t in out_types)
        )

        context = dict(
            func_cname=ufunc_cname,
            in_types=arg_types,
            out_types=out_types,
            inline_func_call=self.node.entry.cname,
            inline_func_declaration=inline_func_decl,
            nogil=self.node.entry.type.nogil,
            will_be_called_without_gil=will_be_called_without_gil,
        )

        code = CythonUtilityCode.load(
            "UFuncDefinition",
            "UFuncs.pyx",
            context=context,
            outer_module_scope=self.global_scope,
        )

        tree = code.get_tree(entries_only=True)
        return tree

    def use_generic_utility_code(self):
        """Register the function-independent C utility code with the module."""
        # use the invariant C utility code
        self.global_scope.use_utility_code(
            UtilityCode.load_cached("UFuncsInit", "UFuncs_C.c")
        )
        self.global_scope.use_utility_code(
            UtilityCode.load_cached("NumpyImportUFunc", "NumpyImportArray.c")
        )


def convert_to_ufunc(node):
    """
    Convert a CFuncDefNode (or a FusedCFuncDefNode wrapping one) into a ufunc.

    On failure an error is reported and ``node`` is returned unchanged.
    On success, the original module-level entry is removed and a list is
    returned: ``node`` followed by the generated ufunc loop functions and the
    assignment that creates the ufunc object under the original name.
    """
    if isinstance(node, Nodes.CFuncDefNode):
        if node.local_scope.parent_scope.is_c_class_scope:
            error(node.pos, "Methods cannot currently be converted to a ufunc")
            return node
        converters = [UFuncConversion(node)]
        original_node = node
    elif isinstance(node, FusedNode.FusedCFuncDefNode) and isinstance(
        node.node, Nodes.CFuncDefNode
    ):
        if node.node.local_scope.parent_scope.is_c_class_scope:
            error(node.pos, "Methods cannot currently be converted to a ufunc")
            return node
        # one converter per fused specialisation
        converters = [UFuncConversion(n) for n in node.nodes]
        original_node = node.node
    else:
        error(node.pos, "Only C functions can be converted to a ufunc")
        return node

    if not converters:
        # this path probably shouldn't happen; leave the tree untouched
        # (returning None would silently drop the node from the tree)
        return node

    # the name will be re-declared as a Python object holding the ufunc
    del converters[0].global_scope.entries[original_node.entry.name]
    # the generic utility code is generic, so there's no reason to do it multiple times
    converters[0].use_generic_utility_code()
    return [node] + _generate_stats_from_converters(converters, original_node)


def generate_ufunc_initialization(converters, cfunc_nodes, original_node):
    """
    Build the module-level assignment that creates the ufunc object.

    Registers the "UFuncConsts" Tempita utility code with the loop function
    cnames and the flattened per-loop dtype constants, declares a Python
    object variable named after ``original_node`` and returns a
    SingleAssignmentNode assigning ``PyUFunc_FromFuncAndData(...)`` to it.
    All converters are expected to have the same number of inputs and
    outputs (checked with asserts).
    """
    global_scope = converters[0].global_scope
    ufunc_funcs_name = global_scope.next_id(Naming.pyrex_prefix + "funcs")
    ufunc_types_name = global_scope.next_id(Naming.pyrex_prefix + "types")
    ufunc_data_name = global_scope.next_id(Naming.pyrex_prefix + "data")
    type_constants = []
    narg_in = None
    narg_out = None
    for c in converters:
        in_const = [d.type_constant for d in c.in_definitions]
        if narg_in is not None:
            assert narg_in == len(in_const)
        else:
            narg_in = len(in_const)
        type_constants.extend(in_const)
        out_const = [d.type_constant for d in c.out_definitions]
        if narg_out is not None:
            assert narg_out == len(out_const)
        else:
            narg_out = len(out_const)
        type_constants.extend(out_const)

    func_cnames = [cfnode.entry.cname for cfnode in cfunc_nodes]

    context = dict(
        ufunc_funcs_name=ufunc_funcs_name,
        func_cnames=func_cnames,
        ufunc_types_name=ufunc_types_name,
        type_constants=type_constants,
        ufunc_data_name=ufunc_data_name,
    )
    global_scope.use_utility_code(
        TempitaUtilityCode.load("UFuncConsts", "UFuncs_C.c", context=context)
    )

    pos = original_node.pos
    func_name = original_node.entry.name
    docstr = original_node.doc
    # Arguments of PyUFunc_FromFuncAndData: funcs, data, types, ntypes, nin,
    # nout, identity, name, doc, unused. funcs/types are written as calls,
    # presumably to accessors defined by the UFuncConsts utility code.
    args_to_func = '%s(), %s, %s(), %s, %s, %s, PyUFunc_None, "%s", %s, 0' % (
        ufunc_funcs_name,
        ufunc_data_name,
        ufunc_types_name,
        len(func_cnames),
        narg_in,
        narg_out,
        func_name,
        docstr.as_c_string_literal() if docstr else "NULL",
    )

    # The whole argument list is passed as the text of a single ConstNode so
    # that the dummy one-argument signature below can be used.
    call_node = ExprNodes.PythonCapiCallNode(
        pos,
        function_name="PyUFunc_FromFuncAndData",
        # use a dummy type because it's honestly too fiddly
        func_type=PyrexTypes.CFuncType(
            PyrexTypes.py_object_type,
            [PyrexTypes.CFuncTypeArg("dummy", PyrexTypes.c_void_ptr_type, None)],
        ),
        args=[
            ExprNodes.ConstNode(
                pos, type=PyrexTypes.c_void_ptr_type, value=args_to_func
            )
        ],
    )
    lhs_entry = global_scope.declare_var(func_name, PyrexTypes.py_object_type, pos)
    assgn_node = Nodes.SingleAssignmentNode(
        pos,
        lhs=ExprNodes.NameNode(
            pos, name=func_name, type=PyrexTypes.py_object_type, entry=lhs_entry
        ),
        rhs=call_node,
    )
    return assgn_node


def _generate_stats_from_converters(converters, node):
    """
    Return the generated ufunc loop CFuncDefNodes (one per converter)
    followed by the ufunc initialisation assignment.
    """
    stats = []
    for converter in converters:
        tree = converter.generate_cy_utility_code()
        ufunc_node = get_cfunc_from_tree(tree)
        # merge in any utility code
        converter.global_scope.utility_code_list.extend(tree.scope.utility_code_list)
        stats.append(ufunc_node)

    # at this point ``stats`` holds only the loop functions
    stats.append(generate_ufunc_initialization(converters, stats, node))
    return stats