ai-content-maker/.venv/Lib/site-packages/torch/fx/passes/pass_manager.py

258 lines
7.0 KiB
Python

from functools import wraps
from inspect import unwrap
from typing import Callable, List, Optional
import logging
# Module-level logger named after this module; `log_hook` writes to it.
logger = logging.getLogger(__name__)
# Public API of this module. The private helper
# `_validate_pass_schedule_constraint` is intentionally not listed.
__all__ = [
"PassManager",
"inplace_wrapper",
"log_hook",
"loop_pass",
"this_before_that_pass_constraint",
"these_before_those_pass_constraint",
]
# for callables which modify object inplace and return something other than
# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    Args:
        fn (Callable[Object, Any]): pass that acts on its argument in place.
            Whatever `fn` returns is discarded.

    Returns:
        wrapped_fn (Callable[Object, Object]): callable that invokes `fn` on
            its argument and then returns that very same argument object.
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # `fn` is called only for its side effect on `gm`; its return value
        # is deliberately ignored so the caller always gets `gm` back.
        fn(gm)
        return gm

    return wrapped_fn
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Wrap a callable so that every invocation logs the callable and the value
    it returned.

    Handy for recording what a pass produced. Because `inplace_wrapper`
    replaces a pass's output with the modified object, wrap with `log_hook`
    first (innermost) when the original output is what should be logged:

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): callable whose result is logged
        level: logging level passed to the module logger
            (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2]): callable with the same result
            as `fn`
    """

    @wraps(fn)
    def logged_fn(gm):
        result = fn(gm)
        # Lazy %-style arguments: formatting only happens if the record is
        # actually emitted at `level`.
        logger.log(level, "Ran pass %s\t Return value: %s", fn, result)
        return result

    return logged_fn
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
    """
    Build a pass that applies `base_pass` repeatedly.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop;
            each application receives the previous application's output
        n_iter (int, optional): number of times to apply `base_pass`. Must be
            positive, otherwise the returned pass raises RuntimeError when
            called.
        predicate (Callable[Object, bool], optional): evaluated on the
            current value before each application; looping continues while
            it returns a truthy value

    Returns:
        new_pass (Callable[Object, Object]): the looping pass
    """
    has_count = n_iter is not None
    has_predicate = predicate is not None
    assert (
        has_count != has_predicate
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        # Fixed-count mode.
        if n_iter is not None and n_iter > 0:
            current = source
            for _ in range(n_iter):
                current = base_pass(current)
            return current

        # Reaching here without a predicate means n_iter was given but is
        # not positive.
        if predicate is None:
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )

        # Predicate-driven mode.
        current = source
        while predicate(current):
            current = base_pass(current)
        return current

    return new_pass
# Pass Schedule Constraints:
#
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
):
for i, a in enumerate(passes):
for j, b in enumerate(passes[i + 1 :]):
if constraint(a, b):
continue
raise RuntimeError(
f"pass schedule constraint violated. Expected {a} before {b}"
f" but found {a} at index {i} and {b} at index{j} in pass"
f" list."
)
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build a 'depends on' comparison enforcing that `this` occurs before
    `that`.

    The returned function takes two passes `(a, b)` and returns False only
    when `a` equals `that` and `b` equals `this`; every other pair yields
    True.
    """

    def depends_on(a: Callable, b: Callable):
        # The only rejected pair is (that, this).
        return not (a == that and b == this)

    return depends_on
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Build a 'depends on' comparison enforcing that `these` occurs before
    `those`, comparing passes after they have been unwrapped with
    `inspect.unwrap`.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): returns False only for
            a pair `(a, b)` where `a` unwraps to `those` and `b` unwraps to
            `these`
    """

    def depends_on(a: Callable, b: Callable):
        # `b` is only unwrapped when `a` already matched `those`.
        return not (unwrap(a) == those and unwrap(b) == these)

    return depends_on
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B), where A
            is scheduled before B, and returns False if that ordering is not
            allowed and True otherwise. See implementation of
            `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # True once `validate` has succeeded for the current passes/constraints;
    # every mutating method resets it so the next call re-validates.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        """Alternate constructor: build a manager from a list of passes."""
        # Instantiate through `cls` so subclasses get an instance of their
        # own type rather than a plain PassManager.
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append `_pass` to the end of the schedule."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register an additional ordering constraint."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: List[str]):
        """
        Drop every pass whose `__name__` is listed in `_passes`.

        Passes that have no `__name__` attribute (e.g. `functools.partial`
        objects) cannot match a name and are kept. Does nothing when
        `_passes` is None.
        """
        if _passes is None:
            return
        self.passes = [
            ps
            for ps in self.passes
            if getattr(ps, "__name__", None) not in _passes
        ]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """
        Substitute `_replacement` for every pass that matches `_target`.

        A pass matches when it is `_target` itself or when it has the same
        `__name__` as `_target`. Objects without a `__name__` are only
        matched by identity.
        """
        target_name = getattr(_target, "__name__", None)

        def _matches(ps):
            if ps is _target:
                return True
            return (
                target_name is not None
                and getattr(ps, "__name__", None) == target_name
            )

        self.passes = [
            _replacement if _matches(ps) else ps for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`

        The result is cached in `_validated`; a RuntimeError raised by the
        constraint check propagates and leaves the manager unvalidated.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """Validate the schedule, then run each pass in order on `source`,
        feeding every pass the previous pass's output."""
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out