258 lines
7.0 KiB
Python
258 lines
7.0 KiB
Python
from functools import wraps
|
|
from inspect import unwrap
|
|
from typing import Callable, List, Optional
|
|
import logging
|
|
|
|
# Module-level logger; used by ``log_hook`` to report pass outputs.
logger: logging.Logger = logging.getLogger(__name__)

# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "PassManager",
    "inplace_wrapper",
    "log_hook",
    "loop_pass",
    "this_before_that_pass_constraint",
    "these_before_those_pass_constraint",
]
|
|
|
|
# For callables which modify an object in place and return something other
# than the object on which they act.
|
|
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    The wrapped callable is invoked only for its side effect on the argument;
    whatever it returns is discarded and the (mutated) argument itself is
    returned.

    Args:
        fn (Callable[Object, Any]): pass that mutates its argument in place

    Returns:
        wrapped_fn (Callable[Object, Object]): wrapper that calls ``fn`` and
            returns its argument; carries ``fn``'s metadata via
            ``functools.wraps``
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # The pass's own return value (e.g. a "changed" flag) is intentionally
        # dropped: pass chains need the object itself to feed the next pass.
        fn(gm)
        return gm

    return wrapped_fn
|
|
|
|
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Logs callable output.

    Wraps ``fn`` so that every call logs, at ``level`` via the module
    ``logger``, the callable and the value it returned. That value is then
    passed through unchanged.

    This is useful for logging output of passes. ``inplace_wrapper`` replaces
    the pass output with the modified object, so to log the pass's original
    output apply this wrapper first (innermost), as in:

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): callable whose output is logged
        level: logging level (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2]): logging wrapper around ``fn``
    """

    def wrapped_fn(gm):
        result = fn(gm)
        logger.log(level, "Ran pass %s\t Return value: %s", fn, result)
        return result

    # Same effect as decorating with @wraps(fn): copy fn's metadata onto the wrapper.
    return wraps(fn)(wrapped_fn)
|
|
|
|
|
|
|
|
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
    """
    Convenience wrapper for passes which need to be applied multiple times.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to loop pass; must be
            positive by the time the returned pass is called
        predicate (Callable[Object, bool], optional): evaluated on the current
            output before each application of ``base_pass``; the pass keeps
            being applied while it returns a truthy value

    Returns:
        new_pass (Callable[Object, Object]): the looping pass, carrying
            ``base_pass``'s metadata via ``functools.wraps``

    Raises:
        AssertionError: (from ``loop_pass`` itself) if both or neither of
            `n_iter` and `predicate` are given.
        RuntimeError: (from the returned pass) if `n_iter` was given but is
            not positive.
    """
    # Explicit check rather than a bare ``assert``: asserts are stripped under
    # ``python -O``. AssertionError is kept so existing callers see the same type.
    if (n_iter is None) == (predicate is None):
        raise AssertionError(
            "Exactly one of `n_iter` or `predicate` must be specified."
        )

    @wraps(base_pass)
    def new_pass(source):
        output = source
        if n_iter is not None and n_iter > 0:
            for _ in range(n_iter):
                output = base_pass(output)
        elif predicate is not None:
            while predicate(output):
                output = base_pass(output)
        else:
            # Reached only when a non-positive n_iter was supplied.
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )
        return output

    return new_pass
|
|
|
|
|
|
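# Pass Schedule Constraints:
#
# Implemented as 'depends on' operators. A constraint is a callable
# `constraint(a, b)` that is evaluated for every pair of passes where `a` is
# scheduled before `b`; it returns False when that ordering is not allowed. A
# constraint is satisfied iff it returns True for every such pair, i.e. the
# pass list has a valid partial ordering according to this comparison operator.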
|
|
def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
):
    """
    Check ``passes`` against a single schedule constraint.

    ``constraint(a, b)`` is called for every ordered pair where ``a`` appears
    earlier in ``passes`` than ``b``. A falsy result means that ordering is
    not allowed, i.e. ``b`` was expected to come before ``a``.

    Args:
        constraint (Callable[[Callable, Callable], bool]): 'depends on'
            operator, e.g. from ``this_before_that_pass_constraint``
        passes (List[Callable]): pass schedule to check

    Raises:
        RuntimeError: on the first violating pair. The message names both
            passes and their actual indices in ``passes``.
    """
    for i, a in enumerate(passes):
        # ``start=i + 1`` makes ``j`` an index into ``passes`` itself (not into
        # the slice), so the error reports the true position of ``b``.
        for j, b in enumerate(passes[i + 1 :], start=i + 1):
            if constraint(a, b):
                continue
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {b} before {a}"
                f" but found {a} at index {i} and {b} at index {j} in pass"
                f" list."
            )
|
|
|
|
|
|
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Defines a partial order ('depends on' function) where `this` must occur
    before `that`.

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Callable, Callable], bool]): returns False only
        for the pair ``(that, this)``, meaning `that` was scheduled ahead of
        `this`. Returns True for every other pair.
    """

    def depends_on(a: Callable, b: Callable):
        # Reject exactly the reversed ordering; all other pairs are allowed.
        return not (a == that and b == this)

    return depends_on
|
|
|
|
|
|
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Defines a partial order ('depends on' function) where `these` must occur
    before `those`. The inputs are 'unwrapped' (via ``inspect.unwrap``) before
    comparison, so wrappers built with ``functools.wraps`` match the pass they
    wrap.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool])
    """

    def depends_on(a: Callable, b: Callable):
        # False only when (unwrapped) `those` precedes (unwrapped) `these`.
        return not (unwrap(a) == those and unwrap(b) == these)

    return depends_on
|
|
|
|
|
|
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns the modified object.
            The list is copied, so later ``add_pass``/``remove_pass``/
            ``replace_pass`` calls never mutate the caller's list.
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B), where A is
            scheduled before B. It returns True if that ordering is allowed
            and False otherwise. See implementation of
            `this_before_that_pass_constraint` for example. Copied like
            ``passes``.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Cached result of ``validate``; reset by every mutating method below.
    _validated: bool = False

    def __init__(
        self,
        passes: Optional[List[Callable]] = None,
        constraints: Optional[List[Callable]] = None,
    ):
        # Copy rather than alias: the mutators append/rebind in place, and an
        # aliased caller list could also change behind the ``_validated`` cache.
        self.passes = list(passes) if passes else []
        self.constraints = list(constraints) if constraints else []

    @classmethod
    def build_from_passlist(cls, passes):
        """Alternate constructor: build a manager of type ``cls`` from ``passes`` only."""
        # ``cls`` (not ``PassManager``) so subclasses get instances of themselves.
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append ``_pass`` to the end of the schedule."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register an additional schedule constraint."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: List[str]):
        """
        Remove every pass whose ``__name__`` is in ``_passes``.

        A ``None`` argument is a no-op.
        """
        if _passes is None:
            return
        self.passes = [ps for ps in self.passes if ps.__name__ not in _passes]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """Replace every pass whose ``__name__`` matches ``_target.__name__`` with ``_replacement``."""
        self.passes = [
            _replacement if ps.__name__ == _target.__name__ else ps
            for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`.

        The result is cached until a mutating method resets it. Direct
        mutation of ``self.passes`` is not detected.

        Raises:
            RuntimeError: if any constraint is violated.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """
        Validate the schedule, then run every pass in order.

        Each pass receives the previous pass's output (``source`` for the
        first one). The final output is returned.
        """
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out
|