429 lines
16 KiB
Python
429 lines
16 KiB
Python
|
import ast
|
||
|
from collections import defaultdict, OrderedDict
|
||
|
import contextlib
|
||
|
import sys
|
||
|
from types import SimpleNamespace
|
||
|
|
||
|
import numpy as np
|
||
|
import operator
|
||
|
|
||
|
from numba.core import types, targetconfig, ir, rewrites, compiler
|
||
|
from numba.core.typing import npydecl
|
||
|
from numba.np.ufunc.dufunc import DUFunc
|
||
|
|
||
|
|
||
|
def _is_ufunc(func):
|
||
|
return isinstance(func, (np.ufunc, DUFunc))
|
||
|
|
||
|
|
||
|
@rewrites.register_rewrite('after-inference')
|
||
|
class RewriteArrayExprs(rewrites.Rewrite):
|
||
|
'''The RewriteArrayExprs class is responsible for finding array
|
||
|
expressions in Numba intermediate representation code, and
|
||
|
rewriting those expressions to a single operation that will expand
|
||
|
into something similar to a ufunc call.
|
||
|
'''
|
||
|
def __init__(self, state, *args, **kws):
|
||
|
super(RewriteArrayExprs, self).__init__(state, *args, **kws)
|
||
|
# Install a lowering hook if we are using this rewrite.
|
||
|
special_ops = state.targetctx.special_ops
|
||
|
if 'arrayexpr' not in special_ops:
|
||
|
special_ops['arrayexpr'] = _lower_array_expr
|
||
|
|
||
|
def match(self, func_ir, block, typemap, calltypes):
|
||
|
"""
|
||
|
Using typing and a basic block, search the basic block for array
|
||
|
expressions.
|
||
|
Return True when one or more matches were found, False otherwise.
|
||
|
"""
|
||
|
# We can trivially reject everything if there are no
|
||
|
# calls in the type results.
|
||
|
if len(calltypes) == 0:
|
||
|
return False
|
||
|
|
||
|
self.crnt_block = block
|
||
|
self.typemap = typemap
|
||
|
# { variable name: IR assignment (of a function call or operator) }
|
||
|
self.array_assigns = OrderedDict()
|
||
|
# { variable name: IR assignment (of a constant) }
|
||
|
self.const_assigns = {}
|
||
|
|
||
|
assignments = block.find_insts(ir.Assign)
|
||
|
for instr in assignments:
|
||
|
target_name = instr.target.name
|
||
|
expr = instr.value
|
||
|
# Does it assign an expression to an array variable?
|
||
|
if (isinstance(expr, ir.Expr) and
|
||
|
isinstance(typemap.get(target_name, None), types.Array)):
|
||
|
self._match_array_expr(instr, expr, target_name)
|
||
|
elif isinstance(expr, ir.Const):
|
||
|
# Track constants since we might need them for an
|
||
|
# array expression.
|
||
|
self.const_assigns[target_name] = expr
|
||
|
|
||
|
return len(self.array_assigns) > 0
|
||
|
|
||
|
def _match_array_expr(self, instr, expr, target_name):
|
||
|
"""
|
||
|
Find whether the given assignment (*instr*) of an expression (*expr*)
|
||
|
to variable *target_name* is an array expression.
|
||
|
"""
|
||
|
# We've matched a subexpression assignment to an
|
||
|
# array variable. Now see if the expression is an
|
||
|
# array expression.
|
||
|
expr_op = expr.op
|
||
|
array_assigns = self.array_assigns
|
||
|
|
||
|
if ((expr_op in ('unary', 'binop')) and (
|
||
|
expr.fn in npydecl.supported_array_operators)):
|
||
|
# It is an array operator that maps to a ufunc.
|
||
|
# check that all args have internal types
|
||
|
if all(self.typemap[var.name].is_internal
|
||
|
for var in expr.list_vars()):
|
||
|
array_assigns[target_name] = instr
|
||
|
|
||
|
elif ((expr_op == 'call') and (expr.func.name in self.typemap)):
|
||
|
# It could be a match for a known ufunc call.
|
||
|
func_type = self.typemap[expr.func.name]
|
||
|
if isinstance(func_type, types.Function):
|
||
|
func_key = func_type.typing_key
|
||
|
if _is_ufunc(func_key):
|
||
|
# If so, check whether an explicit output is passed.
|
||
|
if not self._has_explicit_output(expr, func_key):
|
||
|
# If not, match it as a (sub)expression.
|
||
|
array_assigns[target_name] = instr
|
||
|
|
||
|
def _has_explicit_output(self, expr, func):
|
||
|
"""
|
||
|
Return whether the *expr* call to *func* (a ufunc) features an
|
||
|
explicit output argument.
|
||
|
"""
|
||
|
nargs = len(expr.args) + len(expr.kws)
|
||
|
if expr.vararg is not None:
|
||
|
# XXX *args unsupported here, assume there may be an explicit
|
||
|
# output
|
||
|
return True
|
||
|
return nargs > func.nin
|
||
|
|
||
|
def _get_array_operator(self, ir_expr):
|
||
|
ir_op = ir_expr.op
|
||
|
if ir_op in ('unary', 'binop'):
|
||
|
return ir_expr.fn
|
||
|
elif ir_op == 'call':
|
||
|
return self.typemap[ir_expr.func.name].typing_key
|
||
|
raise NotImplementedError(
|
||
|
"Don't know how to find the operator for '{0}' expressions.".format(
|
||
|
ir_op))
|
||
|
|
||
|
def _get_operands(self, ir_expr):
|
||
|
'''Given a Numba IR expression, return the operands to the expression
|
||
|
in order they appear in the expression.
|
||
|
'''
|
||
|
ir_op = ir_expr.op
|
||
|
if ir_op == 'binop':
|
||
|
return ir_expr.lhs, ir_expr.rhs
|
||
|
elif ir_op == 'unary':
|
||
|
return ir_expr.list_vars()
|
||
|
elif ir_op == 'call':
|
||
|
return ir_expr.args
|
||
|
raise NotImplementedError(
|
||
|
"Don't know how to find the operands for '{0}' expressions.".format(
|
||
|
ir_op))
|
||
|
|
||
|
def _translate_expr(self, ir_expr):
|
||
|
'''Translate the given expression from Numba IR to an array expression
|
||
|
tree.
|
||
|
'''
|
||
|
ir_op = ir_expr.op
|
||
|
if ir_op == 'arrayexpr':
|
||
|
return ir_expr.expr
|
||
|
operands_or_args = [self.const_assigns.get(op_var.name, op_var)
|
||
|
for op_var in self._get_operands(ir_expr)]
|
||
|
return self._get_array_operator(ir_expr), operands_or_args
|
||
|
|
||
|
def _handle_matches(self):
|
||
|
'''Iterate over the matches, trying to find which instructions should
|
||
|
be rewritten, deleted, or moved.
|
||
|
'''
|
||
|
replace_map = {}
|
||
|
dead_vars = set()
|
||
|
used_vars = defaultdict(int)
|
||
|
for instr in self.array_assigns.values():
|
||
|
expr = instr.value
|
||
|
arr_inps = []
|
||
|
arr_expr = self._get_array_operator(expr), arr_inps
|
||
|
new_expr = ir.Expr(op='arrayexpr',
|
||
|
loc=expr.loc,
|
||
|
expr=arr_expr,
|
||
|
ty=self.typemap[instr.target.name])
|
||
|
new_instr = ir.Assign(new_expr, instr.target, instr.loc)
|
||
|
replace_map[instr] = new_instr
|
||
|
self.array_assigns[instr.target.name] = new_instr
|
||
|
for operand in self._get_operands(expr):
|
||
|
operand_name = operand.name
|
||
|
if operand.is_temp and operand_name in self.array_assigns:
|
||
|
child_assign = self.array_assigns[operand_name]
|
||
|
child_expr = child_assign.value
|
||
|
child_operands = child_expr.list_vars()
|
||
|
for operand in child_operands:
|
||
|
used_vars[operand.name] += 1
|
||
|
arr_inps.append(self._translate_expr(child_expr))
|
||
|
if child_assign.target.is_temp:
|
||
|
dead_vars.add(child_assign.target.name)
|
||
|
replace_map[child_assign] = None
|
||
|
elif operand_name in self.const_assigns:
|
||
|
arr_inps.append(self.const_assigns[operand_name])
|
||
|
else:
|
||
|
used_vars[operand.name] += 1
|
||
|
arr_inps.append(operand)
|
||
|
return replace_map, dead_vars, used_vars
|
||
|
|
||
|
def _get_final_replacement(self, replacement_map, instr):
|
||
|
'''Find the final replacement instruction for a given initial
|
||
|
instruction by chasing instructions in a map from instructions
|
||
|
to replacement instructions.
|
||
|
'''
|
||
|
replacement = replacement_map[instr]
|
||
|
while replacement in replacement_map:
|
||
|
replacement = replacement_map[replacement]
|
||
|
return replacement
|
||
|
|
||
|
def apply(self):
|
||
|
'''When we've found array expressions in a basic block, rewrite that
|
||
|
block, returning a new, transformed block.
|
||
|
'''
|
||
|
# Part 1: Figure out what instructions should be rewritten
|
||
|
# based on the matches found.
|
||
|
replace_map, dead_vars, used_vars = self._handle_matches()
|
||
|
# Part 2: Using the information above, rewrite the target
|
||
|
# basic block.
|
||
|
result = self.crnt_block.copy()
|
||
|
result.clear()
|
||
|
delete_map = {}
|
||
|
for instr in self.crnt_block.body:
|
||
|
if isinstance(instr, ir.Assign):
|
||
|
if instr in replace_map:
|
||
|
replacement = self._get_final_replacement(
|
||
|
replace_map, instr)
|
||
|
if replacement:
|
||
|
result.append(replacement)
|
||
|
for var in replacement.value.list_vars():
|
||
|
var_name = var.name
|
||
|
if var_name in delete_map:
|
||
|
result.append(delete_map.pop(var_name))
|
||
|
if used_vars[var_name] > 0:
|
||
|
used_vars[var_name] -= 1
|
||
|
|
||
|
else:
|
||
|
result.append(instr)
|
||
|
elif isinstance(instr, ir.Del):
|
||
|
instr_value = instr.value
|
||
|
if used_vars[instr_value] > 0:
|
||
|
used_vars[instr_value] -= 1
|
||
|
delete_map[instr_value] = instr
|
||
|
elif instr_value not in dead_vars:
|
||
|
result.append(instr)
|
||
|
else:
|
||
|
result.append(instr)
|
||
|
if delete_map:
|
||
|
for instr in delete_map.values():
|
||
|
result.insert_before_terminator(instr)
|
||
|
return result
|
||
|
|
||
|
|
||
|
_unaryops = {
|
||
|
operator.pos: ast.UAdd,
|
||
|
operator.neg: ast.USub,
|
||
|
operator.invert: ast.Invert,
|
||
|
}
|
||
|
|
||
|
_binops = {
|
||
|
operator.add: ast.Add,
|
||
|
operator.sub: ast.Sub,
|
||
|
operator.mul: ast.Mult,
|
||
|
operator.truediv: ast.Div,
|
||
|
operator.mod: ast.Mod,
|
||
|
operator.or_: ast.BitOr,
|
||
|
operator.rshift: ast.RShift,
|
||
|
operator.xor: ast.BitXor,
|
||
|
operator.lshift: ast.LShift,
|
||
|
operator.and_: ast.BitAnd,
|
||
|
operator.pow: ast.Pow,
|
||
|
operator.floordiv: ast.FloorDiv,
|
||
|
}
|
||
|
|
||
|
|
||
|
_cmpops = {
|
||
|
operator.eq: ast.Eq,
|
||
|
operator.ne: ast.NotEq,
|
||
|
operator.lt: ast.Lt,
|
||
|
operator.le: ast.LtE,
|
||
|
operator.gt: ast.Gt,
|
||
|
operator.ge: ast.GtE,
|
||
|
}
|
||
|
|
||
|
|
||
|
def _arr_expr_to_ast(expr):
|
||
|
'''Build a Python expression AST from an array expression built by
|
||
|
RewriteArrayExprs.
|
||
|
'''
|
||
|
if isinstance(expr, tuple):
|
||
|
op, arr_expr_args = expr
|
||
|
ast_args = []
|
||
|
env = {}
|
||
|
for arg in arr_expr_args:
|
||
|
ast_arg, child_env = _arr_expr_to_ast(arg)
|
||
|
ast_args.append(ast_arg)
|
||
|
env.update(child_env)
|
||
|
if op in npydecl.supported_array_operators:
|
||
|
if len(ast_args) == 2:
|
||
|
if op in _binops:
|
||
|
return ast.BinOp(
|
||
|
ast_args[0], _binops[op](), ast_args[1]), env
|
||
|
if op in _cmpops:
|
||
|
return ast.Compare(
|
||
|
ast_args[0], [_cmpops[op]()], [ast_args[1]]), env
|
||
|
else:
|
||
|
assert op in _unaryops
|
||
|
return ast.UnaryOp(_unaryops[op](), ast_args[0]), env
|
||
|
elif _is_ufunc(op):
|
||
|
fn_name = "__ufunc_or_dufunc_{0}".format(
|
||
|
hex(hash(op)).replace("-", "_"))
|
||
|
fn_ast_name = ast.Name(fn_name, ast.Load())
|
||
|
env[fn_name] = op # Stash the ufunc or DUFunc in the environment
|
||
|
ast_call = ast.Call(fn_ast_name, ast_args, [])
|
||
|
return ast_call, env
|
||
|
elif isinstance(expr, ir.Var):
|
||
|
return ast.Name(expr.name, ast.Load(),
|
||
|
lineno=expr.loc.line,
|
||
|
col_offset=expr.loc.col if expr.loc.col else 0), {}
|
||
|
elif isinstance(expr, ir.Const):
|
||
|
return ast.Constant(expr.value), {}
|
||
|
raise NotImplementedError(
|
||
|
"Don't know how to translate array expression '%r'" % (expr,))
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
|
||
|
def _legalize_parameter_names(var_list):
|
||
|
"""
|
||
|
Legalize names in the variable list for use as a Python function's
|
||
|
parameter names.
|
||
|
"""
|
||
|
var_map = OrderedDict()
|
||
|
for var in var_list:
|
||
|
old_name = var.name
|
||
|
new_name = var.scope.redefine(old_name, loc=var.loc).name
|
||
|
new_name = new_name.replace("$", "_").replace(".", "_")
|
||
|
# Caller should ensure the names are unique
|
||
|
if new_name in var_map:
|
||
|
raise AssertionError(f"{new_name!r} not unique")
|
||
|
var_map[new_name] = var, old_name
|
||
|
var.name = new_name
|
||
|
param_names = list(var_map)
|
||
|
try:
|
||
|
yield param_names
|
||
|
finally:
|
||
|
# Make sure the old names are restored, to avoid confusing
|
||
|
# other parts of Numba (see issue #1466)
|
||
|
for var, old_name in var_map.values():
|
||
|
var.name = old_name
|
||
|
|
||
|
|
||
|
class _EraseInvalidLineRanges(ast.NodeTransformer):
|
||
|
def generic_visit(self, node: ast.AST) -> ast.AST:
|
||
|
node = super().generic_visit(node)
|
||
|
if hasattr(node, "lineno"):
|
||
|
if getattr(node, "end_lineno", None) is not None:
|
||
|
if node.lineno > node.end_lineno:
|
||
|
del node.lineno
|
||
|
del node.end_lineno
|
||
|
return node
|
||
|
|
||
|
|
||
|
def _fix_invalid_lineno_ranges(astree: ast.AST):
|
||
|
"""Inplace fixes invalid lineno ranges.
|
||
|
"""
|
||
|
# Make sure lineno and end_lineno are present
|
||
|
ast.fix_missing_locations(astree)
|
||
|
# Delete invalid lineno ranges
|
||
|
_EraseInvalidLineRanges().visit(astree)
|
||
|
# Make sure lineno and end_lineno are present
|
||
|
ast.fix_missing_locations(astree)
|
||
|
|
||
|
|
||
|
def _lower_array_expr(lowerer, expr):
|
||
|
'''Lower an array expression built by RewriteArrayExprs.
|
||
|
'''
|
||
|
expr_name = "__numba_array_expr_%s" % (hex(hash(expr)).replace("-", "_"))
|
||
|
expr_filename = expr.loc.filename
|
||
|
expr_var_list = expr.list_vars()
|
||
|
# The expression may use a given variable several times, but we
|
||
|
# should only create one parameter for it.
|
||
|
expr_var_unique = sorted(set(expr_var_list), key=lambda var: var.name)
|
||
|
|
||
|
# Arguments are the names external to the new closure
|
||
|
expr_args = [var.name for var in expr_var_unique]
|
||
|
|
||
|
# 1. Create an AST tree from the array expression.
|
||
|
with _legalize_parameter_names(expr_var_unique) as expr_params:
|
||
|
ast_args = [ast.arg(param_name, None)
|
||
|
for param_name in expr_params]
|
||
|
# Parse a stub function to ensure the AST is populated with
|
||
|
# reasonable defaults for the Python version.
|
||
|
ast_module = ast.parse('def {0}(): return'.format(expr_name),
|
||
|
expr_filename, 'exec')
|
||
|
assert hasattr(ast_module, 'body') and len(ast_module.body) == 1
|
||
|
ast_fn = ast_module.body[0]
|
||
|
ast_fn.args.args = ast_args
|
||
|
ast_fn.body[0].value, namespace = _arr_expr_to_ast(expr.expr)
|
||
|
_fix_invalid_lineno_ranges(ast_module)
|
||
|
|
||
|
# 2. Compile the AST module and extract the Python function.
|
||
|
code_obj = compile(ast_module, expr_filename, 'exec')
|
||
|
exec(code_obj, namespace)
|
||
|
impl = namespace[expr_name]
|
||
|
|
||
|
# 3. Now compile a ufunc using the Python function as kernel.
|
||
|
|
||
|
context = lowerer.context
|
||
|
builder = lowerer.builder
|
||
|
outer_sig = expr.ty(*(lowerer.typeof(name) for name in expr_args))
|
||
|
inner_sig_args = []
|
||
|
for argty in outer_sig.args:
|
||
|
if isinstance(argty, types.Optional):
|
||
|
argty = argty.type
|
||
|
if isinstance(argty, types.Array):
|
||
|
inner_sig_args.append(argty.dtype)
|
||
|
else:
|
||
|
inner_sig_args.append(argty)
|
||
|
inner_sig = outer_sig.return_type.dtype(*inner_sig_args)
|
||
|
|
||
|
flags = targetconfig.ConfigStack().top_or_none()
|
||
|
flags = compiler.Flags() if flags is None else flags.copy() # make sure it's a clone or a fresh instance
|
||
|
# Follow the Numpy error model. Note this also allows e.g. vectorizing
|
||
|
# division (issue #1223).
|
||
|
flags.error_model = 'numpy'
|
||
|
cres = context.compile_subroutine(builder, impl, inner_sig, flags=flags,
|
||
|
caching=False)
|
||
|
|
||
|
# Create kernel subclass calling our native function
|
||
|
from numba.np import npyimpl
|
||
|
|
||
|
class ExprKernel(npyimpl._Kernel):
|
||
|
def generate(self, *args):
|
||
|
arg_zip = zip(args, self.outer_sig.args, inner_sig.args)
|
||
|
cast_args = [self.cast(val, inty, outty)
|
||
|
for val, inty, outty in arg_zip]
|
||
|
result = self.context.call_internal(
|
||
|
builder, cres.fndesc, inner_sig, cast_args)
|
||
|
return self.cast(result, inner_sig.return_type,
|
||
|
self.outer_sig.return_type)
|
||
|
|
||
|
# create a fake ufunc object which is enough to trick numpy_ufunc_kernel
|
||
|
ufunc = SimpleNamespace(nin=len(expr_args), nout=1, __name__=expr_name)
|
||
|
ufunc.nargs = ufunc.nin + ufunc.nout
|
||
|
|
||
|
args = [lowerer.loadvar(name) for name in expr_args]
|
||
|
return npyimpl.numpy_ufunc_kernel(
|
||
|
context, builder, outer_sig, args, ufunc, ExprKernel)
|