99 lines
2.6 KiB
Python
99 lines
2.6 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import os
|
||
|
import tempfile
|
||
|
import textwrap
|
||
|
from functools import lru_cache
|
||
|
|
||
|
if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
|
||
|
|
||
|
@lru_cache(None)
|
||
|
def _record_missing_op(target):
|
||
|
with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd:
|
||
|
fd.write(str(target) + "\n")
|
||
|
|
||
|
else:
|
||
|
|
||
|
def _record_missing_op(target): # type: ignore[misc]
|
||
|
pass
|
||
|
|
||
|
|
||
|
class OperatorIssue(RuntimeError):
    """Base error for the operator-related exceptions in this module.

    Provides ``operator_str`` so subclasses can append a readable description
    of the offending call to their message.
    """

    @staticmethod
    def operator_str(target, args, kwargs):
        """Render an operator call as an indented, newline-separated string.

        The output contains a ``target: ...`` line and one ``args[i]: ...`` line
        per positional argument. A ``kwargs: ...`` line is added only when
        ``kwargs`` is truthy. Every line is prefixed via ``textwrap.indent``.
        """
        described = [f"target: {target}"]
        described.extend(f"args[{idx}]: {value}" for idx, value in enumerate(args))
        if kwargs:
            described.append(f"kwargs: {kwargs}")
        return textwrap.indent("\n".join(described), " ")
|
||
|
|
||
|
|
||
|
class MissingOperatorWithoutDecomp(OperatorIssue):
    """Operator error whose message reports a missing lowering for ``target``.

    Constructing it also passes ``target`` to ``_record_missing_op``.
    """

    def __init__(self, target, args, kwargs):
        _record_missing_op(target)
        details = self.operator_str(target, args, kwargs)
        super().__init__(f"missing lowering\n{details}")
|
||
|
|
||
|
|
||
|
class MissingOperatorWithDecomp(OperatorIssue):
    """Operator error reporting a missing decomposition for ``target``.

    The message describes the call and then adds a hint about registering the
    operator in ``torch._inductor.decompositions``. Constructing it also passes
    ``target`` to ``_record_missing_op``.
    """

    def __init__(self, target, args, kwargs):
        _record_missing_op(target)
        header = f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
        # dedent strips the source-code indentation of the hint text below.
        hint = textwrap.dedent(
            f"""

            There is a decomposition available for {target} in
            torch._decomp.get_decompositions(). Please add this operator to the
            `decompositions` list in torch._inductor.decompositions
            """
        )
        super().__init__(header + hint)
|
||
|
|
||
|
|
||
|
class LoweringException(OperatorIssue):
    """Wraps another exception together with a description of the operator call.

    The message starts with ``<ExceptionType>: <exception text>``, followed by
    the ``operator_str`` rendering of ``target``/``args``/``kwargs``.
    """

    def __init__(self, exc: Exception, target, args, kwargs):
        summary = f"{type(exc).__name__}: {exc}"
        call_desc = self.operator_str(target, args, kwargs)
        super().__init__(f"{summary}\n{call_desc}")
|
||
|
|
||
|
|
||
|
class InvalidCxxCompiler(RuntimeError):
    """Error stating that no working C++ compiler was found.

    The message names the config module and its current ``cpp.cxx`` value.
    """

    def __init__(self):
        # Imported inside the constructor rather than at module level;
        # presumably to avoid an import cycle with the config module — confirm.
        from . import config

        config_module = config.__name__
        cxx_setting = config.cpp.cxx
        super().__init__(
            f"No working C++ compiler found in {config_module}.cpp.cxx: {cxx_setting}"
        )
|
||
|
|
||
|
|
||
|
class CppWrapperCodeGenError(RuntimeError):
    """Error for C++ wrapper code generation; prefixes ``msg`` with a fixed label."""

    def __init__(self, msg: str):
        text = f"C++ wrapper codegen error: {msg}"
        super().__init__(text)
|
||
|
|
||
|
|
||
|
class CppCompileError(RuntimeError):
|
||
|
def __init__(self, cmd: list[str], output: str):
|
||
|
if isinstance(output, bytes):
|
||
|
output = output.decode("utf-8")
|
||
|
|
||
|
super().__init__(
|
||
|
textwrap.dedent(
|
||
|
"""
|
||
|
C++ compile error
|
||
|
|
||
|
Command:
|
||
|
{cmd}
|
||
|
|
||
|
Output:
|
||
|
{output}
|
||
|
"""
|
||
|
)
|
||
|
.strip()
|
||
|
.format(cmd=" ".join(cmd), output=output)
|
||
|
)
|
||
|
|
||
|
|
||
|
class CUDACompileError(CppCompileError):
    """Variant of CppCompileError, used (per its name) for CUDA compile failures.

    Inherits CppCompileError's constructor and message format unchanged.
    """
|