# mypy: ignore-errors

import functools
import importlib
import logging
import os
import tempfile

import torch

from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend

log = logging.getLogger(__name__)


@register_backend
@fake_tensor_unsupported
def tvm(gm, example_inputs, *, scheduler=None, trials=20000):
    import tvm  # type: ignore[import]
    from tvm import relay  # type: ignore[import]
    from tvm.contrib import graph_executor  # type: ignore[import]

    jit_mod = torch.jit.trace(gm, example_inputs)
    device = device_from_inputs(example_inputs)
    shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
    example_outputs = gm(*example_inputs)
    if len(example_outputs) == 0:
        log.warning("Explicitly fall back to eager due to zero output")
        return gm.forward
    mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
    if device.type == "cuda":
        dev = tvm.cuda(device.index)
        target = tvm.target.cuda()
    else:
        dev = tvm.cpu(0)
        target = tvm.target.Target(llvm_target())

    if scheduler is None:
        scheduler = os.environ.get("TVM_SCHEDULER", None)

    if scheduler == "auto_scheduler":
        from tvm import auto_scheduler

        # The auto_scheduler APIs below expect a log-file *path*, so keep only
        # the temporary file's name and remove the empty placeholder file so
        # the existence check reflects whether tuning records are available.
        log_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
        os.unlink(log_file)

        if not os.path.exists(log_file):
            tasks, task_weights = auto_scheduler.extract_tasks(
                mod["main"], params, target
            )
            for task in tasks:
                print(task.compute_dag)
            if len(tasks) == 0:
                print("No tasks")
            else:
                tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
                if not os.path.exists(log_file):
                    assert trials > 0
                    tune_option = auto_scheduler.TuningOptions(
                        num_measure_trials=trials,
                        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
                        early_stopping=2000,
                    )
                    try:
                        tuner.tune(tune_option)
                    except Exception:
                        if os.path.exists(log_file):
                            os.unlink(log_file)
                        raise
        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(
                opt_level=3, config={"relay.backend.use_auto_scheduler": True}
            ):
                lib = relay.build(mod, target=target, params=params)
    elif scheduler == "meta_schedule":
        from tvm import meta_schedule as ms

        with tempfile.TemporaryDirectory() as work_dir:
            if device.type != "cuda":
                # meta_schedule needs num-cores to be specified;
                # here we use the maximum physical core count.
                target = tvm.target.Target(
                    f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}"
                )
            # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch
            # once USE_PT_TVMDSOOP is updated and turned on by default in TVM.
            database = ms.relay_integration.tune_relay(
                mod=mod,
                target=target,
                work_dir=work_dir,
                max_trials_global=20000,
                num_trials_per_iter=64,
                params=params,
                strategy="evolutionary",
            )
            lib = ms.relay_integration.compile_relay(
                database=database,
                mod=mod,
                target=target,
                params=params,
            )
    elif scheduler == "default" or not scheduler:
        # no autotuning
        with tvm.transform.PassContext(opt_level=10):
            lib = relay.build(mod, target=target, params=params)
    else:
        raise NotImplementedError(
            "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
            "There are three available options: default, auto_scheduler and meta_schedule."
        )
    m = graph_executor.GraphModule(lib["default"](dev))

    def to_torch_tensor(nd_tensor):
        """A helper function to transfer an NDArray to torch.tensor."""
        if nd_tensor.dtype == "bool":
            # DLPack does not support boolean, so it can't be handled by
            # torch.utils.dlpack.from_dlpack. Work around this by going through
            # numpy, although this brings additional data copy overhead.
            return torch.from_numpy(nd_tensor.numpy())
        return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())

    def to_tvm_tensor(torch_tensor):
        """A helper function to transfer a torch.tensor to NDArray."""
        if torch_tensor.dtype == torch.bool:
            # Same reason as above: fall back to a numpy conversion, which
            # could introduce data copy overhead.
            return tvm.nd.array(torch_tensor.cpu().numpy())
        return tvm.nd.from_dlpack(torch_tensor)

    def exec_tvm(*i_args):
        args = [a.contiguous() for a in i_args]
        shape_info, _ = m.get_input_info()
        active_inputs = {name for name, _ in shape_info.items()}
        for idx, arg in enumerate(args, 0):
            # Only non-scalar tensors are bound as graph inputs.
            if arg.dim() != 0:
                if arg.requires_grad:
                    arg = arg.detach()
                inp_name = f"inp_{idx}"
                if inp_name not in active_inputs:
                    log.warning(
                        "input %s skipped as not found in tvm's runtime library",
                        inp_name,
                    )
                    continue
                m.set_input(
                    inp_name,
                    to_tvm_tensor(arg),
                )
        m.run()
        return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())]

    return exec_tvm


tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule")
tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler")


def has_tvm():
    try:
        importlib.import_module("tvm")
        return True
    except ImportError:
        return False


@functools.lru_cache(None)
def llvm_target():
    # Choose an x86 LLVM target based on the host CPU's AVX-512 support.
    with open("/proc/cpuinfo") as f:
        cpuinfo = f.read()
    if "avx512" in cpuinfo:
        return "llvm -mcpu=skylake-avx512"
    return "llvm -mcpu=core-avx2"
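

# Usage sketch (illustrative only, not part of the backend registration): the
# backend above is normally selected by name through torch.compile. The toy
# module, input shape, and scheduler choice below are assumptions made for the
# example, and running it requires a working TVM installation.
if __name__ == "__main__" and has_tvm():

    class _Demo(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x @ x.transpose(0, 1))

    # "default" skips autotuning; "auto_scheduler" / "meta_schedule" tune first.
    os.environ.setdefault("TVM_SCHEDULER", "default")
    compiled = torch.compile(_Demo(), backend="tvm")
    print(compiled(torch.randn(8, 8)))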