import threading
from typing import Any, Dict

import torch._C._lazy


class DeviceContext:
    """Per-device context object, cached one instance per device string.

    Instances are normally obtained through ``get_device_context`` rather
    than constructed directly, so that each device string maps to a single
    shared object.
    """

    # Cache of already-created contexts, keyed by the device string.
    _CONTEXTS: Dict[str, "DeviceContext"] = {}
    # Guards every read and write of _CONTEXTS.
    _CONTEXTS_LOCK = threading.Lock()

    def __init__(self, device: str) -> None:
        """Store *device* on the new context."""
        self.device = device


def get_device_context(device: Any = None) -> DeviceContext:
    """Return the cached ``DeviceContext`` for *device*, creating it if needed.

    Args:
        device: Device identifier. When ``None``, the value returned by
            ``torch._C._lazy._get_default_device_type()`` is used as the
            cache key; otherwise ``str(device)`` is used.

    Returns:
        The ``DeviceContext`` stored under the resulting key. Repeated calls
        that resolve to the same key return the same object.
    """
    if device is None:
        device = torch._C._lazy._get_default_device_type()
    else:
        device = str(device)

    # Lookup and insertion happen under one lock acquisition so that two
    # callers cannot both create a context for the same key.
    with DeviceContext._CONTEXTS_LOCK:
        devctx = DeviceContext._CONTEXTS.get(device)
        if devctx is None:
            devctx = DeviceContext(device)
            DeviceContext._CONTEXTS[device] = devctx
        return devctx