ai-content-maker/.venv/Lib/site-packages/numba/np/ufunc/array_exprs.py

429 lines
16 KiB
Python
Raw Permalink Normal View History

2024-05-03 04:18:51 +03:00
import ast
from collections import defaultdict, OrderedDict
import contextlib
import sys
from types import SimpleNamespace
import numpy as np
import operator
from numba.core import types, targetconfig, ir, rewrites, compiler
from numba.core.typing import npydecl
from numba.np.ufunc.dufunc import DUFunc
def _is_ufunc(func):
    """Return True when *func* is a NumPy ufunc or a Numba DUFunc."""
    accepted_kinds = (np.ufunc, DUFunc)
    return isinstance(func, accepted_kinds)
@rewrites.register_rewrite('after-inference')
class RewriteArrayExprs(rewrites.Rewrite):
    '''The RewriteArrayExprs class is responsible for finding array
    expressions in Numba intermediate representation code, and
    rewriting those expressions to a single operation that will expand
    into something similar to a ufunc call.

    ``match`` records, for one basic block, every assignment of an
    array-typed expression that is either a supported array operator or a
    ufunc call without an explicit output. ``apply`` then folds matched
    assignments whose targets are temporaries into the expressions that
    consume them, producing ``arrayexpr`` expressions that are lowered by
    ``_lower_array_expr``.
    '''

    def __init__(self, state, *args, **kws):
        super().__init__(state, *args, **kws)
        # Install a lowering hook if we are using this rewrite.
        special_ops = state.targetctx.special_ops
        if 'arrayexpr' not in special_ops:
            special_ops['arrayexpr'] = _lower_array_expr

    def match(self, func_ir, block, typemap, calltypes):
        """
        Using typing and a basic block, search the basic block for array
        expressions.
        Return True when one or more matches were found, False otherwise.

        As a side effect, stores *block* and *typemap* on the instance and
        resets ``array_assigns`` / ``const_assigns``; ``apply`` consumes
        these afterwards.
        """
        # We can trivially reject everything if there are no
        # calls in the type results.
        if not calltypes:
            return False

        self.crnt_block = block
        self.typemap = typemap
        # { variable name: IR assignment (of a function call or operator) }
        self.array_assigns = OrderedDict()
        # { variable name: ir.Const value assigned to that variable }
        self.const_assigns = {}

        for instr in block.find_insts(ir.Assign):
            target_name = instr.target.name
            expr = instr.value
            # Does it assign an expression to an array variable?
            if (isinstance(expr, ir.Expr) and
                    isinstance(typemap.get(target_name, None), types.Array)):
                self._match_array_expr(instr, expr, target_name)
            elif isinstance(expr, ir.Const):
                # Track constants since we might need them for an
                # array expression.
                self.const_assigns[target_name] = expr

        return len(self.array_assigns) > 0

    def _match_array_expr(self, instr, expr, target_name):
        """
        Find whether the given assignment (*instr*) of an expression (*expr*)
        to variable *target_name* is an array expression.

        On a match, records *instr* in ``self.array_assigns`` under
        *target_name*.
        """
        # We've matched a subexpression assignment to an
        # array variable. Now see if the expression is an
        # array expression.
        expr_op = expr.op
        array_assigns = self.array_assigns

        if ((expr_op in ('unary', 'binop')) and (
                expr.fn in npydecl.supported_array_operators)):
            # It is an array operator that maps to a ufunc.
            # check that all args have internal types
            if all(self.typemap[var.name].is_internal
                   for var in expr.list_vars()):
                array_assigns[target_name] = instr

        elif ((expr_op == 'call') and (expr.func.name in self.typemap)):
            # It could be a match for a known ufunc call.
            func_type = self.typemap[expr.func.name]
            if isinstance(func_type, types.Function):
                func_key = func_type.typing_key
                if _is_ufunc(func_key):
                    # If so, check whether an explicit output is passed.
                    if not self._has_explicit_output(expr, func_key):
                        # If not, match it as a (sub)expression.
                        array_assigns[target_name] = instr

    def _has_explicit_output(self, expr, func):
        """
        Return whether the *expr* call to *func* (a ufunc) features an
        explicit output argument.

        Passing more positional plus keyword arguments than ``func.nin``
        is taken to mean an output was supplied.
        """
        if expr.vararg is not None:
            # XXX *args unsupported here, assume there may be an explicit
            # output
            return True
        nargs = len(expr.args) + len(expr.kws)
        return nargs > func.nin

    def _get_array_operator(self, ir_expr):
        '''Return the operator of *ir_expr*: the ``operator`` function for
        'unary'/'binop' expressions, or the callee's typing key (the ufunc
        or DUFunc) for 'call' expressions.
        '''
        ir_op = ir_expr.op
        if ir_op in ('unary', 'binop'):
            return ir_expr.fn
        elif ir_op == 'call':
            return self.typemap[ir_expr.func.name].typing_key
        raise NotImplementedError(
            "Don't know how to find the operator for '{0}' expressions.".format(
                ir_op))

    def _get_operands(self, ir_expr):
        '''Given a Numba IR expression, return the operands to the expression
        in order they appear in the expression.
        '''
        ir_op = ir_expr.op
        if ir_op == 'binop':
            return ir_expr.lhs, ir_expr.rhs
        elif ir_op == 'unary':
            return ir_expr.list_vars()
        elif ir_op == 'call':
            return ir_expr.args
        raise NotImplementedError(
            "Don't know how to find the operands for '{0}' expressions.".format(
                ir_op))

    def _translate_expr(self, ir_expr):
        '''Translate the given expression from Numba IR to an array expression
        tree.

        An already-rewritten 'arrayexpr' yields its stored tree; otherwise
        an ``(operator, operands)`` pair is built, with operands that are
        known constants replaced by their ``ir.Const``.
        '''
        ir_op = ir_expr.op
        if ir_op == 'arrayexpr':
            return ir_expr.expr
        operands_or_args = [self.const_assigns.get(op_var.name, op_var)
                            for op_var in self._get_operands(ir_expr)]
        return self._get_array_operator(ir_expr), operands_or_args

    def _handle_matches(self):
        '''Iterate over the matches, trying to find which instructions should
        be rewritten, deleted, or moved.

        Returns ``(replace_map, dead_vars, used_vars)``:

        - *replace_map* maps each matched ``ir.Assign`` to its 'arrayexpr'
          replacement, or to ``None`` when the assignment was folded into a
          consuming expression and should be dropped;
        - *dead_vars* holds the names of temporaries whose defining
          assignment was folded away;
        - *used_vars* counts, per variable name, references taken over by
          the fused expressions.
        '''
        replace_map = {}
        dead_vars = set()
        used_vars = defaultdict(int)
        for instr in self.array_assigns.values():
            expr = instr.value
            arr_inps = []
            arr_expr = self._get_array_operator(expr), arr_inps
            new_expr = ir.Expr(op='arrayexpr',
                               loc=expr.loc,
                               expr=arr_expr,
                               ty=self.typemap[instr.target.name])
            new_instr = ir.Assign(new_expr, instr.target, instr.loc)
            replace_map[instr] = new_instr
            # Later matches that consume this variable must see the
            # rewritten 'arrayexpr' form so that nested trees compose.
            # (Overwrites an existing key, so iteration is unaffected.)
            self.array_assigns[instr.target.name] = new_instr
            for operand in self._get_operands(expr):
                operand_name = operand.name
                if operand.is_temp and operand_name in self.array_assigns:
                    # Fold the child array expression into this one.
                    child_assign = self.array_assigns[operand_name]
                    child_expr = child_assign.value
                    for child_operand in child_expr.list_vars():
                        used_vars[child_operand.name] += 1
                    arr_inps.append(self._translate_expr(child_expr))
                    if child_assign.target.is_temp:
                        dead_vars.add(child_assign.target.name)
                        replace_map[child_assign] = None
                elif operand_name in self.const_assigns:
                    arr_inps.append(self.const_assigns[operand_name])
                else:
                    used_vars[operand_name] += 1
                    arr_inps.append(operand)
        return replace_map, dead_vars, used_vars

    def _get_final_replacement(self, replacement_map, instr):
        '''Find the final replacement instruction for a given initial
        instruction by chasing instructions in a map from instructions
        to replacement instructions.
        '''
        replacement = replacement_map[instr]
        while replacement in replacement_map:
            replacement = replacement_map[replacement]
        return replacement

    def apply(self):
        '''When we've found array expressions in a basic block, rewrite that
        block, returning a new, transformed block.
        '''
        # Part 1: Figure out what instructions should be rewritten
        # based on the matches found.
        replace_map, dead_vars, used_vars = self._handle_matches()
        # Part 2: Using the information above, rewrite the target
        # basic block.
        result = self.crnt_block.copy()
        result.clear()
        # { variable name: ir.Del postponed because a fused expression
        #   still references the variable }
        delete_map = {}
        for instr in self.crnt_block.body:
            if isinstance(instr, ir.Assign):
                if instr in replace_map:
                    replacement = self._get_final_replacement(
                        replace_map, instr)
                    # None means the assignment was folded into a consumer.
                    if replacement is not None:
                        result.append(replacement)
                        for var in replacement.value.list_vars():
                            var_name = var.name
                            # Emit a postponed Del right after the fused
                            # expression that reads the variable.
                            if var_name in delete_map:
                                result.append(delete_map.pop(var_name))
                            if used_vars[var_name] > 0:
                                used_vars[var_name] -= 1
                else:
                    result.append(instr)
            elif isinstance(instr, ir.Del):
                instr_value = instr.value
                if used_vars[instr_value] > 0:
                    # Presumably still read by a fused expression placed
                    # later in the block: postpone the Del.
                    used_vars[instr_value] -= 1
                    delete_map[instr_value] = instr
                elif instr_value not in dead_vars:
                    result.append(instr)
                # Dels of folded-away temporaries are dropped.
            else:
                result.append(instr)
        # Any Del still pending goes just before the block terminator.
        for instr in delete_map.values():
            result.insert_before_terminator(instr)
        return result
# Unary ``operator`` functions accepted as array operators, mapped to the
# ``ast`` unary-operator node class used to rebuild the Python expression.
_unaryops = {
    operator.pos: ast.UAdd,         # +a
    operator.neg: ast.USub,         # -a
    operator.invert: ast.Invert,    # ~a
}

# Binary arithmetic/bitwise ``operator`` functions -> ``ast`` BinOp classes.
_binops = {
    operator.add: ast.Add,              # a + b
    operator.sub: ast.Sub,              # a - b
    operator.mul: ast.Mult,             # a * b
    operator.truediv: ast.Div,          # a / b
    operator.mod: ast.Mod,              # a % b
    operator.or_: ast.BitOr,            # a | b
    operator.rshift: ast.RShift,        # a >> b
    operator.xor: ast.BitXor,           # a ^ b
    operator.lshift: ast.LShift,        # a << b
    operator.and_: ast.BitAnd,          # a & b
    operator.pow: ast.Pow,              # a ** b
    operator.floordiv: ast.FloorDiv,    # a // b
}

# Comparison ``operator`` functions -> ``ast`` comparison-operator classes.
_cmpops = {
    operator.eq: ast.Eq,       # a == b
    operator.ne: ast.NotEq,    # a != b
    operator.lt: ast.Lt,       # a < b
    operator.le: ast.LtE,      # a <= b
    operator.gt: ast.Gt,       # a > b
    operator.ge: ast.GtE,      # a >= b
}
def _arr_expr_to_ast(expr):
    '''Build a Python expression AST from an array expression built by
    RewriteArrayExprs.

    *expr* is one of:

    - an ``(op, args)`` tuple, where *op* is an ``operator`` function from
      ``npydecl.supported_array_operators`` or a ufunc/DUFunc, and *args*
      is a list of nested array expressions;
    - an ``ir.Var``, translated to a name reference;
    - an ``ir.Const``, translated to a literal.

    Returns a ``(node, env)`` pair. *env* maps each generated global name to
    the ufunc/DUFunc object it stands for; it must be supplied as globals
    when the generated code is executed.

    Raises NotImplementedError for expressions or operators it cannot
    translate.
    '''
    if isinstance(expr, tuple):
        op, arr_expr_args = expr
        ast_args = []
        env = {}
        # Translate the operands first, merging their environments.
        for arg in arr_expr_args:
            ast_arg, child_env = _arr_expr_to_ast(arg)
            ast_args.append(ast_arg)
            env.update(child_env)
        if op in npydecl.supported_array_operators:
            if len(ast_args) == 2:
                if op in _binops:
                    return ast.BinOp(
                        ast_args[0], _binops[op](), ast_args[1]), env
                if op in _cmpops:
                    return ast.Compare(
                        ast_args[0], [_cmpops[op]()], [ast_args[1]]), env
            elif op in _unaryops:
                return ast.UnaryOp(_unaryops[op](), ast_args[0]), env
            # An operator with no AST mapping falls through to the
            # NotImplementedError below (not an ``assert``, which would be
            # stripped under ``python -O`` and surface as a KeyError).
        elif _is_ufunc(op):
            # '-' from a negative hash is not valid in an identifier.
            fn_name = "__ufunc_or_dufunc_{0}".format(
                hex(hash(op)).replace("-", "_"))
            fn_ast_name = ast.Name(fn_name, ast.Load())
            env[fn_name] = op  # Stash the ufunc or DUFunc in the environment
            ast_call = ast.Call(fn_ast_name, ast_args, [])
            return ast_call, env
    elif isinstance(expr, ir.Var):
        return ast.Name(expr.name, ast.Load(),
                        lineno=expr.loc.line,
                        col_offset=expr.loc.col if expr.loc.col else 0), {}
    elif isinstance(expr, ir.Const):
        return ast.Constant(expr.value), {}
    raise NotImplementedError(
        "Don't know how to translate array expression '%r'" % (expr,))
@contextlib.contextmanager
def _legalize_parameter_names(var_list):
    """
    Legalize names in the variable list for use as a Python function's
    parameter names.

    Each variable in *var_list* is temporarily renamed. A fresh name is
    obtained from ``var.scope.redefine``, and its ``$`` and ``.`` characters
    are replaced with ``_``. The new names are yielded as a list in
    *var_list* order, and every original name is restored on exit.

    Raises AssertionError if two variables legalize to the same name. Any
    variables renamed before the clash are restored before the error
    propagates.
    """
    var_map = OrderedDict()
    # The renaming loop sits inside ``try`` so that a failure part-way
    # through still restores the names already changed.
    try:
        for var in var_list:
            old_name = var.name
            new_name = var.scope.redefine(old_name, loc=var.loc).name
            new_name = new_name.replace("$", "_").replace(".", "_")
            # Caller should ensure the names are unique
            if new_name in var_map:
                raise AssertionError(f"{new_name!r} not unique")
            var_map[new_name] = var, old_name
            var.name = new_name
        yield list(var_map)
    finally:
        # Make sure the old names are restored, to avoid confusing
        # other parts of Numba (see issue #1466)
        for var, old_name in var_map.values():
            var.name = old_name
class _EraseInvalidLineRanges(ast.NodeTransformer):
    """Strip source positions that ``compile()`` would reject.

    Since Python 3.11, compiling an AST validates node positions. It raises
    ``ValueError`` when a node's line range is reversed
    (``lineno > end_lineno``). It also raises when a single-line node's
    column range is reversed (``col_offset > end_col_offset``).

    For such a node, all four position attributes are deleted, so that a
    following ``ast.fix_missing_locations`` copies the enclosing node's span
    as a whole. Keeping stale column offsets while refilling only the line
    numbers could itself produce an invalid same-line column range.
    """

    # Every source-position attribute an AST node may carry.
    _POSITION_ATTRS = ("lineno", "end_lineno", "col_offset", "end_col_offset")

    @staticmethod
    def _has_invalid_range(node):
        """Return True if *node* carries a line/column range that is reversed."""
        lineno = getattr(node, "lineno", None)
        end_lineno = getattr(node, "end_lineno", None)
        if lineno is None or end_lineno is None:
            return False
        if lineno > end_lineno:
            return True
        col_offset = getattr(node, "col_offset", None)
        end_col_offset = getattr(node, "end_col_offset", None)
        return (lineno == end_lineno
                and col_offset is not None
                and end_col_offset is not None
                and col_offset > end_col_offset)

    def generic_visit(self, node: ast.AST) -> ast.AST:
        node = super().generic_visit(node)
        if self._has_invalid_range(node):
            for attr in self._POSITION_ATTRS:
                # Only delete values actually set on the instance; optional
                # attributes may fall back to a class-level None default.
                if getattr(node, attr, None) is not None:
                    delattr(node, attr)
        return node
def _fix_invalid_lineno_ranges(astree: ast.AST):
    """Repair the source positions of *astree* in place.

    Three passes run over the tree:

    1. fill in missing positions from enclosing nodes;
    2. erase invalid position ranges with ``_EraseInvalidLineRanges``;
    3. fill the erased positions in again from enclosing nodes.
    """
    fill_missing = ast.fix_missing_locations
    passes = (fill_missing, _EraseInvalidLineRanges().visit, fill_missing)
    for run_pass in passes:
        run_pass(astree)
def _lower_array_expr(lowerer, expr):
    '''Lower an array expression built by RewriteArrayExprs.

    Steps:

    1. Turn the expression tree into Python source, as an AST for
       ``def <name>(<params>): return <expr>``. There is one parameter per
       distinct IR variable, under a temporarily legalized name.
    2. Compile and exec that AST to obtain a scalar Python function.
    3. Compile the function as a native subroutine over the element types
       (Optional unwrapped, Array replaced by its dtype). Then drive it
       through ``npyimpl.numpy_ufunc_kernel`` as if it were a ufunc, with
       one input per distinct variable and one output.

    *lowerer* provides the target context, IR builder, variable types and
    variable loading. *expr* is the 'arrayexpr' ``ir.Expr`` whose ``.expr``
    holds the tree and whose ``.ty`` is the result type.

    Returns whatever ``npyimpl.numpy_ufunc_kernel`` returns (presumably the
    lowered output array value -- confirm against npyimpl).
    '''
    # Unique-ish function name; '-' from a negative hash is made identifier-safe.
    expr_name = "__numba_array_expr_%s" % (hex(hash(expr)).replace("-", "_"))
    expr_filename = expr.loc.filename
    expr_var_list = expr.list_vars()
    # The expression may use a given variable several times, but we
    # should only create one parameter for it.
    expr_var_unique = sorted(set(expr_var_list), key=lambda var: var.name)

    # Arguments are the names external to the new closure
    expr_args = [var.name for var in expr_var_unique]

    # 1. Create an AST tree from the array expression.
    # The IR variables are renamed only inside this ``with``, so the AST
    # (which reads var.name) must be built here; the names are restored
    # on exit.
    with _legalize_parameter_names(expr_var_unique) as expr_params:
        ast_args = [ast.arg(param_name, None)
                    for param_name in expr_params]
        # Parse a stub function to ensure the AST is populated with
        # reasonable defaults for the Python version.
        ast_module = ast.parse('def {0}(): return'.format(expr_name),
                               expr_filename, 'exec')
        assert hasattr(ast_module, 'body') and len(ast_module.body) == 1
        ast_fn = ast_module.body[0]
        ast_fn.args.args = ast_args
        # Replace the stub's bare ``return`` with the translated expression;
        # ``namespace`` maps generated names to the ufuncs they reference.
        ast_fn.body[0].value, namespace = _arr_expr_to_ast(expr.expr)
        _fix_invalid_lineno_ranges(ast_module)

    # 2. Compile the AST module and extract the Python function.
    # The ufunc environment doubles as the module globals.
    code_obj = compile(ast_module, expr_filename, 'exec')
    exec(code_obj, namespace)
    impl = namespace[expr_name]

    # 3. Now compile a ufunc using the Python function as kernel.

    context = lowerer.context
    builder = lowerer.builder
    # Outer signature: array-level argument types -> expression result type.
    outer_sig = expr.ty(*(lowerer.typeof(name) for name in expr_args))
    # Inner (scalar kernel) signature: unwrap Optional, use dtype for arrays.
    inner_sig_args = []
    for argty in outer_sig.args:
        if isinstance(argty, types.Optional):
            argty = argty.type
        if isinstance(argty, types.Array):
            inner_sig_args.append(argty.dtype)
        else:
            inner_sig_args.append(argty)
    inner_sig = outer_sig.return_type.dtype(*inner_sig_args)

    flags = targetconfig.ConfigStack().top_or_none()
    flags = compiler.Flags() if flags is None else flags.copy() # make sure it's a clone or a fresh instance
    # Follow the Numpy error model. Note this also allows e.g. vectorizing
    # division (issue #1223).
    flags.error_model = 'numpy'
    cres = context.compile_subroutine(builder, impl, inner_sig, flags=flags,
                                      caching=False)

    # Create kernel subclass calling our native function
    from numba.np import npyimpl

    class ExprKernel(npyimpl._Kernel):
        # Per-element body: cast inputs from the outer to the inner types,
        # call the compiled scalar function, and cast its result back to the
        # outer return type. Uses the enclosing ``builder``, ``cres`` and
        # ``inner_sig`` captured by closure.
        def generate(self, *args):
            arg_zip = zip(args, self.outer_sig.args, inner_sig.args)
            cast_args = [self.cast(val, inty, outty)
                         for val, inty, outty in arg_zip]
            result = self.context.call_internal(
                builder, cres.fndesc, inner_sig, cast_args)
            return self.cast(result, inner_sig.return_type,
                             self.outer_sig.return_type)

    # create a fake ufunc object which is enough to trick numpy_ufunc_kernel
    # (it only needs nin/nout/nargs and a __name__ here)
    ufunc = SimpleNamespace(nin=len(expr_args), nout=1, __name__=expr_name)
    ufunc.nargs = ufunc.nin + ufunc.nout

    args = [lowerer.loadvar(name) for name in expr_args]
    return npyimpl.numpy_ufunc_kernel(
        context, builder, outer_sig, args, ufunc, ExprKernel)