# ai-content-maker/.venv/Lib/site-packages/thinc/layers/with_nvtx_range.py
#
# File-viewer header: 46 lines, 1.3 KiB, Python

from typing import Any, Callable, Optional, Tuple, TypeVar
from ..model import Model
from ..util import use_nvtx_range
# Type variable bound to Model: lets with_nvtx_range declare that it returns
# the same concrete Model subtype it was given.
_ModelT = TypeVar("_ModelT", bound=Model)
def with_nvtx_range(
    layer: _ModelT,
    name: Optional[str] = None,
    *,
    forward_color: int = -1,
    backprop_color: int = -1,
) -> _ModelT:
    """Wrap a layer so its forward and backprop passes are marked as NVTX
    ranges for CUDA profiling.

    The layer is modified in place: its forward and init callbacks are
    replaced with wrappers, and the same layer object is returned.

    By default, the name of the layer is used as the name of the range,
    followed by the name of the pass.

    layer (Model): The layer to instrument.
    name (Optional[str]): Label for the ranges. Defaults to the layer's name.
        The ranges are labelled "{name} forward" and "{name} backprop".
    forward_color (int): Color passed to use_nvtx_range for the forward range.
    backprop_color (int): Color passed to use_nvtx_range for the backprop
        range.
    RETURNS (Model): The same layer, with its callbacks replaced.
    """
    name = layer.name if name is None else name

    # Capture the original callbacks before they are replaced below, so the
    # wrappers can delegate to them.
    orig_forward = layer._func
    orig_init = layer.init

    def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]:
        with use_nvtx_range(f"{name} forward", forward_color):
            layer_Y, layer_callback = orig_forward(model, X, is_train=is_train)

        def backprop(dY: Any) -> Any:
            with use_nvtx_range(f"{name} backprop", backprop_color):
                return layer_callback(dY)

        return layer_Y, backprop

    def init(model: Model, X: Any, Y: Any) -> Model:
        # Delegate using the model this callback is invoked with rather than
        # the captured `layer`, mirroring how `forward` passes `model`
        # through. Otherwise a different model object sharing these callbacks
        # (e.g. a copy of the wrapped layer) would initialize the original
        # layer instead of itself.
        if orig_init is None:
            return model
        return orig_init(model, X, Y)

    layer.replace_callbacks(forward, init=init)

    return layer