46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
from typing import Any, Callable, Optional, Tuple, TypeVar
|
|
|
|
from ..model import Model
|
|
from ..util import use_nvtx_range
|
|
|
|
# Type variable bound to Model: with_nvtx_range is annotated to return the
# same concrete layer type it was given.
_ModelT = TypeVar("_ModelT", bound=Model)
|
|
|
|
|
|
def with_nvtx_range(
    layer: _ModelT,
    name: Optional[str] = None,
    *,
    forward_color: int = -1,
    backprop_color: int = -1,
) -> _ModelT:
    """Instrument a layer so that its forward and backprop passes each run
    inside an NVTX range, for CUDA profiling.

    The callbacks of `layer` are replaced in place and that same object is
    returned; no separate wrapper model is created.

    layer: The layer to instrument.
    name: Prefix for the range labels; falls back to `layer.name` when None.
        The ranges are labelled "<name> forward" and "<name> backprop".
    forward_color: Color value handed to `use_nvtx_range` for the forward range.
    backprop_color: Color value handed to `use_nvtx_range` for the backprop range.
    RETURNS: The `layer` object that was passed in.
    """
    if name is None:
        name = layer.name

    forward_label = f"{name} forward"
    backprop_label = f"{name} backprop"

    # Capture the layer's current callbacks before they are swapped out below.
    wrapped_forward = layer._func
    wrapped_init = layer.init

    def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]:
        # Only the wrapped forward call itself sits inside the forward range.
        with use_nvtx_range(forward_label, forward_color):
            output, wrapped_backprop = wrapped_forward(model, X, is_train=is_train)

        def backprop(dY: Any) -> Any:
            with use_nvtx_range(backprop_label, backprop_color):
                return wrapped_backprop(dY)

        return output, backprop

    def init(_model: Model, X: Any, Y: Any) -> Model:
        # The captured `layer` is used rather than the `_model` argument.
        if wrapped_init is None:
            return layer
        return wrapped_init(layer, X, Y)

    layer.replace_callbacks(forward, init=init)
    return layer
|