26 lines
650 B
Python
26 lines
650 B
Python
|
from typing import Union
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
class _InsertPoint:
    """Context manager that temporarily redirects a graph's insert point.

    On entry the graph's current insert point is saved and replaced with
    ``insert_point``. On exit the saved insert point is put back, including
    when the ``with`` body raised.
    """

    def __init__(
        self,
        insert_point_graph: torch._C.Graph,
        insert_point: Union[torch._C.Node, torch._C.Block],
    ):
        """Remember the graph and the insert point to switch to.

        Args:
            insert_point_graph: Graph whose insert point is redirected.
            insert_point: Node or block installed as the insert point while
                the context is active.
        """
        self.insert_point = insert_point
        self.g = insert_point_graph
        # Not read anywhere in this class; kept so existing attribute access
        # on the guard keeps working.
        self.guard = None
        # Filled in by __enter__; declared here so the attribute always exists.
        self.prev_insert_point = None

    def __enter__(self):
        """Save the current insert point, install the new one, return self."""
        self.prev_insert_point = self.g.insertPoint()
        self.g.setInsertPoint(self.insert_point)
        # Returning the guard lets ``with ... as guard`` bind a usable object
        # instead of None.
        return self

    def __exit__(self, *args):
        """Restore the saved insert point; exceptions are not suppressed."""
        self.g.setInsertPoint(self.prev_insert_point)
|
||
|
|
||
|
|
||
|
def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]):
    """Create an ``_InsertPoint`` context manager for ``self``.

    ``self`` is forwarded as the graph argument of ``_InsertPoint`` (annotated
    there as ``torch._C.Graph``).
    NOTE(review): presumably this function is attached to the graph type as a
    method elsewhere -- confirm against the caller.

    Args:
        insert_point: Node or block to use as the insert point while the
            returned context manager is active.

    Returns:
        A new ``_InsertPoint`` wrapping ``self`` and ``insert_point``.
    """
    guard = _InsertPoint(insert_point_graph=self, insert_point=insert_point)
    return guard
|