import copy
from typing import Optional, Any, Union, Callable

import torch
import warnings
from torch import Tensor
from .. import functional as F
from .module import Module
from .activation import MultiheadAttention
from .container import ModuleList
from ..init import xavier_uniform_
from .dropout import Dropout
from .linear import Linear
from .normalization import LayerNorm

__all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']

def _generate_square_subsequent_mask(
        sz: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
) -> Tensor:
    r"""Generate a square causal mask for the sequence.

    The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
    """
    if device is None:
        device = torch.device('cpu')
    if dtype is None:
        dtype = torch.float32
    return torch.triu(
        torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
        diagonal=1,
    )

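# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): what the causal mask
# produced by ``_generate_square_subsequent_mask`` looks like for ``sz=3``.
# The demo function name below is hypothetical and exists only for this
# example; the expected values follow directly from ``torch.triu`` with
# ``diagonal=1`` over a matrix filled with ``-inf``.
def _demo_square_subsequent_mask() -> Tensor:
    mask = _generate_square_subsequent_mask(3)
    # The mask is additive: 0.0 where attention is allowed, -inf where it is not.
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])
    expected = torch.tensor([[0., float('-inf'), float('-inf')],
                             [0., 0., float('-inf')],
                             [0., 0., 0.]])
    assert torch.equal(mask, expected)
    return mask
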
def _get_seq_len(
        src: Tensor,
        batch_first: bool
) -> Optional[int]:

    if src.is_nested:
        return None
    else:
        src_size = src.size()
        if len(src_size) == 2:
            # unbatched: S, E
            return src_size[0]
        else:
            # batched: B, S, E if batch_first else S, B, E
            seq_len_pos = 1 if batch_first else 0
            return src_size[seq_len_pos]

class Transformer(Module):
    r"""A transformer model.

    Users can modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            attention and feedforward operations, otherwise after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)

    Note: A full example to apply nn.Transformer module for the word language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    bias, **factory_kwargs)
            encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first, norm_first,
                                                    bias, **factory_kwargs)
            decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

        self.batch_first = batch_first

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
                src_is_causal: Optional[bool] = None, tgt_is_causal: Optional[bool] = None,
                memory_is_causal: bool = False) -> Tensor:
        r"""Take in and process masked source/target sequences.

        .. note::

            If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are
            not allowed to participate in the attention,
            which is the opposite of the definition for :attr:`attn_mask`
            in :func:`torch.nn.functional.scaled_dot_product_attention`.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the Tensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
            src_is_causal: If specified, applies a causal mask as ``src_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``src_is_causal`` provides a hint that ``src_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including breaking forward and backward
                compatibility.
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including breaking forward and backward
                compatibility.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing incorrect
                hints can result in incorrect execution, including
                breaking forward and backward compatibility.

        Shape:
            - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
              `(N, S, E)` if `batch_first=True`.
            - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.
            - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
            - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend to the unmasked
            positions. If a BoolTensor is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
            is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
            the attention. If a BoolTensor is provided, the positions with the
            value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.

            - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
              `(N, T, E)` if `batch_first=True`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is the same as the input sequence
            (i.e. target) length of the decoder.

            where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
            batch size, and :math:`E` is the feature number.

        Examples:
            >>> # xdoctest: +SKIP
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """
        is_batched = src.dim() == 3
        if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")
        elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
                              is_causal=src_is_causal)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask,
                              tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
        return output

    @staticmethod
    def generate_square_subsequent_mask(
            sz: int,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        r"""Generate a square causal mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
        """
        return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

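# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): one common way to call
# ``Transformer`` with a causal ``tgt_mask``, mirroring the class docstring's
# example. The function name and the toy shapes are hypothetical and exist
# only for illustration.
def _demo_transformer_with_causal_tgt_mask() -> Tensor:
    model = Transformer(nhead=16, num_encoder_layers=12)
    src = torch.rand((10, 32, 512))   # (S, N, E) since batch_first=False
    tgt = torch.rand((20, 32, 512))   # (T, N, E)
    # Additive causal mask of shape (T, T); positions above the diagonal are -inf.
    tgt_mask = Transformer.generate_square_subsequent_mask(tgt.size(0))
    out = model(src, tgt, tgt_mask=tgt_mask, tgt_is_causal=True)
    return out  # shape (20, 32, 512): same sequence length as tgt
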
class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers.

    Users can build the BERT (https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically convert to nested tensor
            (and convert back on output). This will improve the overall performance of
            TransformerEncoder when padding rate is high. Default: ``True`` (enabled).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ['norm']

    def __init__(
        self,
        encoder_layer: "TransformerEncoderLayer",
        num_layers: int,
        norm: Optional[Module] = None,
        enable_nested_tensor: bool = True,
        mask_check: bool = True
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        # this attribute saves the value provided at object construction
        self.enable_nested_tensor = enable_nested_tensor
        # this attribute controls whether nested tensors are used
        self.use_nested_tensor = enable_nested_tensor
        self.mask_check = mask_check

        enc_layer = "encoder_layer"
        why_not_sparsity_fast_path = ''
        if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
            why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
        elif encoder_layer.norm_first:
            why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
        elif not encoder_layer.self_attn.batch_first:
            why_not_sparsity_fast_path = (f"{enc_layer}.self_attn.batch_first was not True "
                                          "(use batch_first for better inference performance)")
        elif not encoder_layer.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
        elif encoder_layer.self_attn.in_proj_bias is None:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
        elif not encoder_layer.activation_relu_or_gelu:
            why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True"
        elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps):
            why_not_sparsity_fast_path = f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
        elif encoder_layer.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"

        if enable_nested_tensor and why_not_sparsity_fast_path:
            warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
            self.use_nested_tensor = False

    def forward(
            self,
            src: Tensor,
            mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: Optional[bool] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``is_causal`` provides a hint that ``mask`` is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including breaking forward and backward
                compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype
        )

        mask = F._canonical_mask(
            mask=mask,
            mask_name="mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        output = src
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
        why_not_sparsity_fast_path = ''
        str_first_layer = "self.layers[0]"
        batch_first = first_layer.self_attn.batch_first
        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        if not is_fastpath_enabled:
            why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
        elif not hasattr(self, "use_nested_tensor"):
            why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
        elif not self.use_nested_tensor:
            why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (((not hasattr(self, "mask_check")) or self.mask_check)
                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask were not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )
            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif src.device.type not in _supported_device_type:
                why_not_sparsity_fast_path = f"src device is not one of {_supported_device_type}"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
                convert_to_nested = True
                output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
                src_key_padding_mask_for_layers = None

        seq_len = _get_seq_len(src, batch_first)
        is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)

        for mod in self.layers:
            output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)

        if convert_to_nested:
            output = output.to_padded_tensor(0., src.size())

        if self.norm is not None:
            output = self.norm(output)

        return output

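# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): running a
# ``TransformerEncoder`` in inference mode with a ``src_key_padding_mask``,
# the situation in which the nested-tensor fast path above can apply when its
# other conditions (batch_first layer, eval mode, no grad, etc.) are met.
# The function name, shapes, and padding pattern are hypothetical.
def _demo_encoder_with_padding_mask() -> Tensor:
    layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
    encoder = TransformerEncoder(layer, num_layers=6).eval()
    src = torch.rand(32, 10, 512)  # (N, S, E)
    # True marks padded (ignored) positions; here the last 3 tokens of each
    # sequence are padding, so the kept tokens form a left-aligned prefix.
    src_key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)
    src_key_padding_mask[:, 7:] = True
    with torch.no_grad():
        out = encoder(src, src_key_padding_mask=src_key_padding_mask)
    return out  # shape (32, 10, 512)
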
class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers.

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """

    __constants__ = ['norm']

    def __init__(
        self,
        decoder_layer: "TransformerDecoderLayer",
        num_layers: int,
        norm: Optional[Module] = None
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
                memory_is_causal: bool = False) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layers in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``None``; try to detect a causal mask.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including breaking forward and backward
                compatibility.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing incorrect
                hints can result in incorrect execution, including
                breaking forward and backward compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        output = tgt

        seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
        tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)

        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask,
                         tgt_is_causal=tgt_is_causal,
                         memory_is_causal=memory_is_causal)

        if self.norm is not None:
            output = self.norm(output)

        return output

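# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): driving a
# ``TransformerDecoder`` with a causal ``tgt_mask`` so each target position
# only attends to earlier positions. The function name and shapes are
# hypothetical; the mask comes from the helper defined at the top of this
# file.
def _demo_decoder_with_causal_mask() -> Tensor:
    decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
    decoder = TransformerDecoder(decoder_layer, num_layers=6)
    memory = torch.rand(10, 32, 512)   # (S, N, E) encoder output
    tgt = torch.rand(20, 32, 512)      # (T, N, E)
    tgt_mask = _generate_square_subsequent_mask(tgt.size(0))
    # tgt_is_causal=True is a hint that tgt_mask really is the causal mask.
    return decoder(tgt, memory, tgt_mask=tgt_mask, tgt_is_causal=True)
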
class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    TransformerEncoderLayer can handle either traditional torch.tensor inputs,
    or Nested Tensor inputs. Derived classes are expected to similarly accept
    both input formats. (Not all combinations of inputs are currently
    supported by TransformerEncoderLayer while Nested Tensor is in prototype
    state.)

    If you are implementing a custom layer, you may derive it either from
    the Module or TransformerEncoderLayer class. If your custom layer
    supports both torch.Tensor and Nested Tensor inputs, make its
    implementation a derived class of TransformerEncoderLayer. If your custom
    layer supports only torch.Tensor inputs, derive its implementation from
    Module.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)

    Fast path:
        forward() will use a special optimized implementation described in
        `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
        conditions are met:

        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.nn.functional.relu``, or ``torch.nn.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)

        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.

    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
        https://arxiv.org/abs/2205.14135

    """

    __constants__ = ['norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                            bias=bias, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``src_mask``.
                Default: ``False``.
                Warning:
                ``is_causal`` provides a hint that ``src_mask`` is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including breaking forward and backward
                compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        why_not_sparsity_fast_path = ''
        if not is_fastpath_enabled:
            why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif self.self_attn.in_proj_bias is None:
            why_not_sparsity_fast_path = "self_attn was passed bias=False"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "src_key_padding_mask and src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"
        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all((x.device.type in _supported_device_type) for x in tensor_args):
                why_not_sparsity_fast_path = ("some Tensor argument's device is not one of "
                                              f"{_supported_device_type}")
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )

        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

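# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): calling a
# ``TransformerEncoderLayer`` under the conditions listed in its "Fast path"
# docstring (eval mode, no grad, batched ``batch_first`` input, default relu
# activation). Whether the optimized kernel is actually taken still depends
# on the runtime checks in forward(); the name and shapes are hypothetical.
def _demo_encoder_layer_fast_path_candidate() -> Tensor:
    layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True).eval()
    src = torch.rand(32, 10, 512)  # (N, S, E), batched input
    with torch.no_grad():          # autograd disabled, as the fast path requires
        out = layer(src)
    return out  # shape (32, 10, 512)
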
class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.

    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectively. Otherwise it's done after.
            Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)

    Alternatively, when ``batch_first`` is ``True``:
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> memory = torch.rand(32, 10, 512)
        >>> tgt = torch.rand(32, 20, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    __constants__ = ['norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            bias=bias, **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 bias=bias, **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``False``.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including breaking forward and backward
                compatibility.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing incorrect
                hints can result in incorrect execution, including
                breaking forward and backward compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           is_causal=is_causal,
                           need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                is_causal=is_causal,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)

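# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): the same decoder-layer
# call in post-norm (default) and pre-norm (``norm_first=True``) form, the
# distinction referenced by the "Fig. 1" comment in forward() above. The
# function name and shapes are hypothetical.
def _demo_decoder_layer_norm_first() -> tuple:
    memory = torch.rand(10, 32, 512)   # (S, N, E)
    tgt = torch.rand(20, 32, 512)      # (T, N, E)
    post_ln = TransformerDecoderLayer(d_model=512, nhead=8)
    pre_ln = TransformerDecoderLayer(d_model=512, nhead=8, norm_first=True)
    # Both produce outputs of shape (T, N, E); they differ only in where
    # LayerNorm is applied relative to the attention and feedforward blocks.
    return post_ln(tgt, memory), pre_ln(tgt, memory)
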
def _get_clones(module, N):
    # FIXME: copy.deepcopy() is not defined on nn.module
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError(f"activation should be relu/gelu, not {activation}")

def _detect_is_causal_mask(
        mask: Optional[Tensor],
        is_causal: Optional[bool] = None,
        size: Optional[int] = None,
) -> bool:
    """Return whether the given attention mask is causal.

    Warning:
    If ``is_causal`` is not ``None``, its value will be returned as is. If a
    user supplies an incorrect ``is_causal`` hint,

    ``is_causal=False`` when the mask is in fact a causal attention mask
    may lead to reduced performance relative to what would be achievable
    with ``is_causal=True``;
    ``is_causal=True`` when the mask is in fact not a causal attention mask
    may lead to incorrect and unpredictable execution - in some scenarios,
    a causal mask may be applied based on the hint, in other execution
    scenarios the specified mask may be used. The choice may not appear
    to be deterministic, in that a number of factors like alignment,
    hardware SKU, etc. influence the decision whether to use a mask or
    rely on the hint.

    If ``size`` is not None, check whether the mask is a causal mask of the provided size.
    Otherwise, check for any causal mask.
    """
    # Prevent type refinement
    make_causal = (is_causal is True)

    if is_causal is None and mask is not None:
        sz = size if size is not None else mask.size(-2)
        causal_comparison = _generate_square_subsequent_mask(
            sz, device=mask.device, dtype=mask.dtype)

        # Do not use `torch.equal` so we handle batched masks by
        # broadcasting the comparison.
        if mask.size() == causal_comparison.size():
            make_causal = bool((mask == causal_comparison).all())
        else:
            make_causal = False

    return make_causal

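# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): how
# ``_detect_is_causal_mask`` behaves on a generated causal mask versus an
# arbitrary additive mask. The function name is hypothetical.
def _demo_detect_is_causal_mask() -> None:
    causal = _generate_square_subsequent_mask(5)
    assert _detect_is_causal_mask(causal) is True
    # A zero mask allows every position to attend everywhere, so it is not causal.
    not_causal = torch.zeros(5, 5)
    assert _detect_is_causal_mask(not_causal) is False
    # An explicit hint is returned as-is, without inspecting the mask.
    assert _detect_is_causal_mask(not_causal, is_causal=True) is True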