from functools import wraps
from inspect import unwrap
from typing import Callable, List, Optional
import logging

logger = logging.getLogger(__name__)

__all__ = [
    "PassManager",
    "inplace_wrapper",
    "log_hook",
    "loop_pass",
    "this_before_that_pass_constraint",
    "these_before_those_pass_constraint",
]


# for callables which modify object inplace and return something other than
# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    The return value of `fn` is discarded; the argument that was passed in
    is returned.

    Args:
        fn (Callable[Object, Any])

    Returns:
        wrapped_fn (Callable[Object, Object])
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # Only the side effect of `fn` matters here; its result is dropped.
        fn(gm)
        return gm

    return wrapped_fn


def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Logs callable output.

    This is useful for logging output of passes. Note inplace_wrapper replaces
    the pass output with the modified object. If we want to log the original
    output, apply this wrapper before inplace_wrapper.

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2])
        level: logging level (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2])
    """

    @wraps(fn)
    def wrapped_fn(gm):
        val = fn(gm)
        # Lazy %-style arguments: formatting only happens if the record is
        # actually emitted at `level`.
        logger.log(level, "Ran pass %s\t Return value: %s", fn, val)
        return val

    return wrapped_fn


def loop_pass(
    base_pass: Callable,
    n_iter: Optional[int] = None,
    predicate: Optional[Callable] = None,
):
    """
    Convenience wrapper for passes which need to be applied multiple times.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to loop pass; must be
            positive, otherwise the returned pass raises RuntimeError when
            it is called
        predicate (Callable[Object, bool], optional): evaluated on the
            current output before every iteration; `base_pass` is applied
            for as long as it returns a truthy value

    Returns:
        new_pass (Callable[Object, Object]): wrapper which applies
        `base_pass` repeatedly, feeding each output into the next call
    """
    assert (n_iter is not None) ^ (
        predicate is not None
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        output = source
        if n_iter is not None and n_iter > 0:
            for _ in range(n_iter):
                output = base_pass(output)
        elif predicate is not None:
            while predicate(output):
                output = base_pass(output)
        else:
            # Reached for a non-positive `n_iter` (or, when asserts are
            # disabled, when neither argument was given).
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )
        return output

    return new_pass


# Pass Schedule Constraints:
#
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
):
    """
    Checks every ordered pair of passes against a single constraint.

    For each pass `a` at index `i` and each later pass `b` at index `j > i`,
    `constraint(a, b)` must return a truthy value.

    Args:
        constraint: callable taking (earlier_pass, later_pass)
        passes: the pass schedule to check

    Raises:
        RuntimeError: on the first pair for which the constraint is falsy
    """
    for i, a in enumerate(passes):
        # `start=i + 1` keeps `j` an index into `passes` itself rather than
        # into the slice, so the error message points at the right entry.
        for j, b in enumerate(passes[i + 1 :], start=i + 1):
            if constraint(a, b):
                continue
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {a} before {b}"
                f" but found {a} at index {i} and {b} at index {j} in pass"
                f" list."
            )


def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Defines a partial order ('depends on' function) where `this` must occur
    before `that`.

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): returns False only
        when called as `depends_on(that, this)`, and True otherwise
    """

    def depends_on(a: Callable, b: Callable):
        # The only forbidden ordering is `that` followed by `this`.
        return not (a == that and b == this)

    return depends_on


def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Defines a partial order ('depends on' function) where `these` must occur
    before `those`. Where the inputs are 'unwrapped' before comparison.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool])
    """

    def depends_on(a: Callable, b: Callable):
        # `inspect.unwrap` follows the `__wrapped__` chain set by
        # `functools.wraps`, so wrapped passes compare as their originals.
        return not (unwrap(a) == those and unwrap(b) == these)

    return depends_on


class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B), with A
            scheduled before B, and returns False if that ordering is not
            allowed and True otherwise. See implementation of
            `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Class-level default; set per instance once the schedule is modified or
    # successfully validated.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        """Build a manager from a list of passes, with no constraints."""
        # NOTE(review): this always constructs `PassManager`, not `cls`.
        pm = PassManager(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append a pass to the schedule and mark it as unvalidated."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Append a constraint and mark the schedule as unvalidated."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: List[str]):
        """
        Remove every pass whose `__name__` appears in `_passes`.

        Does nothing when `_passes` is None.
        """
        if _passes is None:
            return
        self.passes = [ps for ps in self.passes if ps.__name__ not in _passes]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """
        Replace every pass whose `__name__` equals `_target.__name__` with
        `_replacement`, keeping the schedule order.
        """
        self.passes = [
            _replacement if ps.__name__ == _target.__name__ else ps
            for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`

        Raises:
            RuntimeError: if any constraint is violated
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """Validate the schedule, then run `source` through every pass in order."""
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out