# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Provides phony for arbitrary dependency in an autograd graph."""
from typing import Dict, List, Tuple
|
||
|
|
||
|
import torch
|
||
|
from torch import Tensor
|
||
|
|
||
|
from .stream import default_stream, use_stream
|
||
|
|
||
|
# Public API of this module.
__all__: List[str] = ["get_phony"]


# Cache of phony tensors, keyed by (device, requires_grad); at most one
# phony is ever allocated per device/grad-flag combination.
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
    """Get a phony. Phony is tensor without space.

    It is useful to make arbitrary dependency in a autograd graph because it doesn't require any
    gradient accumulation.

    .. note::

        Phonies for each device are cached. If an autograd function gets a phony
        internally, the phony must be detached to be returned. Otherwise, the
        autograd engine will mutate the cached phony in-place::

            class Phonify(torch.autograd.Function):
                @staticmethod
                def forward(ctx, input):
                    phony = get_phony(input.device, requires_grad=False)
                    return phony.detach()  # detach() is necessary.

    """
    cache_key = (device, requires_grad)

    phony = _phonies.get(cache_key)
    if phony is None:
        # Cache miss: allocate a zero-sized tensor under the device's
        # default stream, then remember it for subsequent calls.
        with use_stream(default_stream(device)):
            phony = torch.empty(0, device=device, requires_grad=requires_grad)
        _phonies[cache_key] = phony

    return phony