1353 lines
59 KiB
Python
1353 lines
59 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
""" PyTorch Swin Transformer model."""
|
||
|
|
||
|
|
||
|
import collections.abc
|
||
|
import math
|
||
|
import warnings
|
||
|
from dataclasses import dataclass
|
||
|
from typing import Optional, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import nn
|
||
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||
|
|
||
|
from ...activations import ACT2FN
|
||
|
from ...modeling_outputs import BackboneOutput
|
||
|
from ...modeling_utils import PreTrainedModel
|
||
|
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
|
||
|
from ...utils import (
|
||
|
ModelOutput,
|
||
|
add_code_sample_docstrings,
|
||
|
add_start_docstrings,
|
||
|
add_start_docstrings_to_model_forward,
|
||
|
logging,
|
||
|
replace_return_docstrings,
|
||
|
)
|
||
|
from ...utils.backbone_utils import BackboneMixin
|
||
|
from .configuration_swin import SwinConfig
|
||
|
|
||
|
|
||
|
logger = logging.get_logger(__name__)

# Documentation-only constants (config class name, checkpoint ids, expected outputs) used to render the model docs.
# NOTE(review): presumably consumed by the `add_code_sample_docstrings` decorators further down the file — confirm.

# General docstring
_CONFIG_FOR_DOC = "SwinConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/swin-tiny-patch4-window7-224"
# Expected `last_hidden_state` shape shown in the base-model code sample.
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "microsoft/swin-tiny-patch4-window7-224"
# Expected predicted label shown in the image-classification code sample.
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
||
|
|
||
|
|
||
|
from ..deprecated._archive_maps import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||
|
|
||
|
|
||
|
# drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
|
||
|
|
||
|
|
||
|
@dataclass
class SwinEncoderOutput(ModelOutput):
    """
    Swin encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # NOTE(review): ModelOutput subclasses presumably expose fields in declaration order (tuple conversion and
    # integer indexing), so keep this order stable.
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # NOTE(review): SwinSelfAttention computes attention per window, i.e. probabilities of shape
    # `(batch_size * num_windows, num_heads, window_size**2, window_size**2)`; confirm the shape documented above.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||
|
|
||
|
|
||
|
@dataclass
class SwinModelOutput(ModelOutput):
    """
    Swin model's outputs that also contain a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # NOTE(review): ModelOutput subclasses presumably expose fields in declaration order (tuple conversion and
    # integer indexing), so keep this order stable.
    last_hidden_state: torch.FloatTensor = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # NOTE(review): SwinSelfAttention computes attention per window, i.e. probabilities of shape
    # `(batch_size * num_windows, num_heads, window_size**2, window_size**2)`; confirm the shape documented above.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||
|
|
||
|
|
||
|
@dataclass
class SwinMaskedImageModelingOutput(ModelOutput):
    """
    Swin masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MIM) loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # NOTE(review): ModelOutput subclasses presumably expose fields in declaration order (tuple conversion and
    # integer indexing), so keep this order stable.
    loss: Optional[torch.FloatTensor] = None
    reconstruction: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None

    @property
    def logits(self):
        """Deprecated alias of `reconstruction`; emits a `FutureWarning` on every access."""
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
        )
        return self.reconstruction
|
||
|
|
||
|
|
||
|
@dataclass
class SwinImageClassifierOutput(ModelOutput):
    """
    Swin outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # NOTE(review): ModelOutput subclasses presumably expose fields in declaration order (tuple conversion and
    # integer indexing), so keep this order stable.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||
|
|
||
|
|
||
|
def window_partition(input_feature, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows.

    Args:
        input_feature (`torch.Tensor` of shape `(batch_size, height, width, num_channels)`):
            Feature map to split. `height` and `width` must be multiples of `window_size`; otherwise the `view`
            below raises because the element counts do not match.
        window_size (`int`):
            Side length of every window.

    Returns:
        `torch.Tensor` of shape `(batch_size * num_windows, window_size, window_size, num_channels)`. The windows of
        each image are consecutive and ordered row-major over the window grid.
    """
    n, h, w, c = input_feature.shape
    rows = h // window_size
    cols = w // window_size
    # split each spatial axis into (window index, offset inside the window)
    grid = input_feature.view(n, rows, window_size, cols, window_size, c)
    # put the two window-index axes side by side so each window becomes one contiguous chunk
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, c)
|
||
|
|
||
|
|
||
|
def window_reverse(windows, window_size, height, width):
    """
    Stitch square windows back into a channels-last feature map (inverse of `window_partition`).

    Args:
        windows (`torch.Tensor` of shape `(batch_size * num_windows, window_size, window_size, num_channels)`):
            Windows of each image, consecutive and in row-major order over the window grid.
        window_size (`int`):
            Side length of every window.
        height (`int`), width (`int`):
            Spatial size of the reconstructed map; both are expected to be multiples of `window_size`.

    Returns:
        `torch.Tensor` of shape `(batch_size, height, width, num_channels)`.
    """
    channels = windows.shape[-1]
    rows, cols = height // window_size, width // window_size
    # (batch, rows, cols, ws, ws, c) -> interleave window index and in-window offset per spatial axis
    tiled = windows.view(-1, rows, cols, window_size, window_size, channels)
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, height, width, channels)
|
||
|
|
||
|
|
||
|
class SwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.

    Args:
        config ([`SwinConfig`]):
            Model configuration. This class reads `embed_dim`, `use_absolute_embeddings` and `hidden_dropout_prob`;
            `SwinPatchEmbeddings` reads the image/patch settings.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to create the learnable `mask_token` that replaces masked patches when `bool_masked_pos` is
            passed to `forward`.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = SwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        if config.use_absolute_embeddings:
            # The table has one row more than there are patch tokens (a leftover class-token slot at index 0; this
            # model prepends no class token). The shape is kept so existing state dicts still load, and `forward`
            # only adds the patch rows `[:, 1:]`.
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
        else:
            self.position_embeddings = None

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
    ) -> Tuple[torch.Tensor]:
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Input images.
            bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, seq_len)`, *optional*):
                `True` marks patch tokens to replace with `mask_token`. Requires `use_mask_token=True`.

        Returns:
            `tuple`: the embeddings of shape `(batch_size, seq_len, embed_dim)` and the `(height, width)` of the
            patch grid produced by `SwinPatchEmbeddings`.

        Raises:
            ValueError: if `bool_masked_pos` is given but the module was built without a mask token.
        """
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            if self.mask_token is None:
                raise ValueError(
                    "`bool_masked_pos` was provided, but this embedding layer has no mask token; "
                    "build it with `use_mask_token=True`."
                )
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens (blend with a 0/1 mask broadcast over channels)
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            # Skip the unused class-token slot so the table lines up with the patch tokens; adding the full
            # `num_patches + 1` table cannot broadcast against `(batch_size, num_patches, embed_dim)` embeddings.
            embeddings = embeddings + self.position_embeddings[:, 1:]

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions
|
||
|
|
||
|
|
||
|
class SwinPatchEmbeddings(nn.Module):
    """
    Turn `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial patch embeddings
    `hidden_states` of shape `(batch_size, seq_length, hidden_size)` consumed by the Transformer.

    Each non-overlapping `patch_size` tile is projected by a strided convolution. Inputs whose spatial size is not a
    multiple of the patch size are zero-padded on the right/bottom first.
    """

    def __init__(self, config):
        super().__init__()

        def as_pair(value):
            # scalars become (value, value); iterables are kept exactly as given
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = as_pair(config.image_size)
        patch_size = as_pair(config.patch_size)
        grid_h = image_size[0] // patch_size[0]
        grid_w = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = grid_h * grid_w
        self.grid_size = (grid_h, grid_w)

        self.projection = nn.Conv2d(
            config.num_channels, config.embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def maybe_pad(self, pixel_values, height, width):
        """Zero-pad the right and bottom edges so `height`/`width` become multiples of the patch size."""
        leftover_w = width % self.patch_size[1]
        if leftover_w:
            pixel_values = nn.functional.pad(pixel_values, (0, self.patch_size[1] - leftover_w))
        leftover_h = height % self.patch_size[0]
        if leftover_h:
            pixel_values = nn.functional.pad(pixel_values, (0, 0, 0, self.patch_size[0] - leftover_h))
        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        """
        Returns:
            `tuple`: patch embeddings of shape `(batch_size, grid_h * grid_w, embed_dim)` and the `(grid_h, grid_w)`
            size of the (possibly padded) patch grid.

        Raises:
            ValueError: if the channel dimension of `pixel_values` differs from `config.num_channels`.
        """
        _, channels, height, width = pixel_values.shape
        if channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        padded = self.maybe_pad(pixel_values, height, width)
        projected = self.projection(padded)  # (batch, embed_dim, grid_h, grid_w)
        grid_h, grid_w = projected.shape[2], projected.shape[3]
        # (batch, embed_dim, grid_h, grid_w) -> (batch, grid_h * grid_w, embed_dim)
        tokens = projected.flatten(2).transpose(1, 2)
        return tokens, (grid_h, grid_w)
|
||
|
|
||
|
|
||
|
class SwinPatchMerging(nn.Module):
    """
    Patch Merging Layer: concatenate every 2x2 neighbourhood of tokens, normalize, and project `4 * dim` channels
    down to `2 * dim`, halving both spatial sizes.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class, called with the merged channel count `4 * dim`.
    """

    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        merged_channels = 4 * dim
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(merged_channels, 2 * dim, bias=False)
        self.norm = norm_layer(merged_channels)

    def maybe_pad(self, input_feature, height, width):
        """Zero-pad one row/column at the bottom/right when `height`/`width` is odd (channels-last input)."""
        extra_w, extra_h = width % 2, height % 2
        if extra_w or extra_h:
            # pad spec runs from the last dim backwards: channels untouched, then width, then height
            input_feature = nn.functional.pad(input_feature, (0, 0, 0, extra_w, 0, extra_h))
        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
        """
        Args:
            input_feature (`torch.Tensor` of shape `(batch_size, height * width, num_channels)`):
                Flattened token grid.
            input_dimensions (`Tuple[int, int]`):
                The `(height, width)` of the token grid.

        Returns:
            `torch.Tensor` of shape `(batch_size, ceil(height / 2) * ceil(width / 2), 2 * num_channels)`.
        """
        height, width = input_dimensions
        batch_size, _, num_channels = input_feature.shape  # middle axis holds height * width tokens

        grid = input_feature.view(batch_size, height, width, num_channels)
        grid = self.maybe_pad(grid, height, width)
        # 2x2 neighbourhood in the order (even row, even col), (odd row, even col), (even row, odd col), (odd row, odd col)
        quadrants = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(quadrants, -1).view(batch_size, -1, 4 * num_channels)

        return self.reduction(self.norm(merged))
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.

    Args:
        input (`torch.Tensor`): tensor whose first dimension is treated as the batch dimension.
        drop_prob (`float`, *optional*, defaults to 0.0): probability of zeroing an entire sample.
        training (`bool`, *optional*, defaults to `False`): when `False`, `input` is returned unchanged.

    Returns:
        `torch.Tensor` with the shape of `input`: each sample is either all zeros or scaled by `1 / (1 - drop_prob)`,
        which keeps the expected value unchanged.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize: floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0
    output = input.div(keep_prob) * random_tensor
    return output
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swin
class SwinDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        # Probability forwarded to `drop_path`. NOTE(review): the default `None` would raise inside `drop_path`
        # (`1 - None`) once the module is in training mode; the caller visible in this file passes a float.
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # `drop_path` returns the input unchanged when `self.training` is False (eval mode).
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # Text shown inside the module's repr, e.g. `SwinDropPath(p=0.1)`.
        return "p={}".format(self.drop_prob)
|
||
|
|
||
|
|
||
|
class SwinSelfAttention(nn.Module):
    """
    Multi-head self-attention over the tokens of one window, with a learned relative position bias.

    Args:
        config: model configuration; reads `qkv_bias` and `attention_probs_dropout_prob`.
        dim (`int`): channel size of the incoming tokens; must be divisible by `num_heads`.
        num_heads (`int`): number of attention heads.
        window_size (`int` or pair of `int`): window height/width; a scalar means a square window.
    """

    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        if dim % num_heads:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        head_size = dim // num_heads
        self.num_attention_heads = num_heads
        self.attention_head_size = head_size
        self.all_head_size = num_heads * head_size
        if not isinstance(window_size, collections.abc.Iterable):
            window_size = (window_size, window_size)
        self.window_size = window_size

        win_h, win_w = self.window_size[0], self.window_size[1]
        # one learnable bias per head for every possible (dy, dx) offset between two tokens of a window
        self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))

        # pair-wise relative position index for each token inside the window
        grid = torch.stack(meshgrid([torch.arange(win_h), torch.arange(win_w)], indexing="ij")).flatten(1)
        offsets = (grid[:, :, None] - grid[:, None, :]).permute(1, 2, 0).contiguous()  # (tokens, tokens, 2)
        # shift offsets to be non-negative, then flatten (dy, dx) into a single row index of the bias table
        offsets[:, :, 0] += win_h - 1
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= 2 * win_w - 1
        self.register_buffer("relative_position_index", offsets.sum(-1))

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape `(batch, tokens, all_head_size)` into `(batch, heads, tokens, head_size)`."""
        split = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Args:
            hidden_states: windows of shape `(batch_windows, tokens, channels)`.
            attention_mask: optional additive mask of shape `(num_masks, tokens, tokens)`; mask `w` is added to every
                batch entry `b` with `b % num_masks == w`.
            head_mask: optional multiplier applied to the attention probabilities.
            output_attentions: also return the attention probabilities.

        Returns:
            `(context,)` or `(context, attention_probs)`; `context` has shape `(batch_windows, tokens, all_head_size)`.
        """
        batch_windows, tokens, _ = hidden_states.shape

        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # scaled dot-product scores: (batch_windows, heads, tokens, tokens)
        scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.attention_head_size)

        area = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(area, area, -1)
        scores = scores + bias.permute(2, 0, 1).contiguous().unsqueeze(0)

        if attention_mask is not None:
            # Apply the precomputed mask: expose the mask index as its own axis so masks broadcast per window
            num_masks = attention_mask.shape[0]
            scores = scores.view(batch_windows // num_masks, num_masks, self.num_attention_heads, tokens, tokens)
            scores = scores + attention_mask.unsqueeze(1).unsqueeze(0)
            scores = scores.view(-1, self.num_attention_heads, tokens, tokens)

        # softmax, then dropout over whole attended tokens (as in the original Transformer)
        probs = self.dropout(nn.functional.softmax(scores, dim=-1))
        if head_mask is not None:
            probs = probs * head_mask

        context = torch.matmul(probs, value_layer).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)

        if output_attentions:
            return (context, probs)
        return (context,)
|
||
|
|
||
|
|
||
|
class SwinSelfOutput(nn.Module):
    """Output projection of the window attention: a `dim -> dim` linear layer followed by dropout."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted for interface compatibility but unused: no residual connection is added here.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
|
||
|
|
||
|
|
||
|
class SwinAttention(nn.Module):
    """
    Window attention block: `SwinSelfAttention` followed by the `SwinSelfOutput` projection, with head pruning.
    """

    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        self.self = SwinSelfAttention(config, dim, num_heads, window_size)
        self.output = SwinSelfOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """
        Remove attention heads from this block.

        Args:
            heads (iterable of `int`): head indices in the original (unpruned) numbering; heads that were already
                pruned are ignored.
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # The relative position bias table has one column per head. Keep only the surviving heads' columns,
        # otherwise the bias no longer broadcasts against the pruned attention scores in SwinSelfAttention.forward.
        # `index` lists the kept channels, `attention_head_size` consecutive entries per kept head.
        head_size = self.self.attention_head_size
        kept_heads = index[::head_size] // head_size
        bias_table = self.self.relative_position_bias_table
        self.self.relative_position_bias_table = nn.Parameter(
            bias_table.detach()[:, kept_heads.to(bias_table.device)].clone(),
            requires_grad=bias_table.requires_grad,
        )

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Return `(attention_output,)`, plus the attention probabilities when `output_attentions` is set."""
        self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
|
||
|
|
||
|
|
||
|
class SwinIntermediate(nn.Module):
    """First MLP layer of a Swin block: expand `dim` to `int(config.mlp_ratio * dim)` channels, then activate."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        activation = config.hidden_act
        # `hidden_act` is either a registry key looked up in ACT2FN or an already-callable activation
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)
|
||
|
|
||
|
|
||
|
class SwinOutput(nn.Module):
    """Second MLP layer of a Swin block: project `int(config.mlp_ratio * dim)` channels back to `dim`, then dropout."""

    def __init__(self, config, dim):
        super().__init__()
        hidden_width = int(config.mlp_ratio * dim)
        self.dense = nn.Linear(hidden_width, dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        return self.dropout(projected)
|
||
|
|
||
|
|
||
|
class SwinLayer(nn.Module):
|
||
|
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
|
||
|
super().__init__()
|
||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||
|
self.shift_size = shift_size
|
||
|
self.window_size = config.window_size
|
||
|
self.input_resolution = input_resolution
|
||
|
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
|
||
|
self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
|
||
|
self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
|
||
|
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
|
||
|
self.intermediate = SwinIntermediate(config, dim)
|
||
|
self.output = SwinOutput(config, dim)
|
||
|
|
||
|
def set_shift_and_window_size(self, input_resolution):
|
||
|
if min(input_resolution) <= self.window_size:
|
||
|
# if window size is larger than input resolution, we don't partition windows
|
||
|
self.shift_size = 0
|
||
|
self.window_size = min(input_resolution)
|
||
|
|
||
|
def get_attn_mask(self, height, width, dtype):
|
||
|
if self.shift_size > 0:
|
||
|
# calculate attention mask for SW-MSA
|
||
|
img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
|
||
|
height_slices = (
|
||
|
slice(0, -self.window_size),
|
||
|
slice(-self.window_size, -self.shift_size),
|
||
|
slice(-self.shift_size, None),
|
||
|
)
|
||
|
width_slices = (
|
||
|
slice(0, -self.window_size),
|
||
|
slice(-self.window_size, -self.shift_size),
|
||
|
slice(-self.shift_size, None),
|
||
|
)
|
||
|
count = 0
|
||
|
for height_slice in height_slices:
|
||
|
for width_slice in width_slices:
|
||
|
img_mask[:, height_slice, width_slice, :] = count
|
||
|
count += 1
|
||
|
|
||
|
mask_windows = window_partition(img_mask, self.window_size)
|
||
|
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||
|
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||
|
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||
|
else:
|
||
|
attn_mask = None
|
||
|
return attn_mask
|
||
|
|
||
|
def maybe_pad(self, hidden_states, height, width):
|
||
|
pad_right = (self.window_size - width % self.window_size) % self.window_size
|
||
|
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
|
||
|
pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
|
||
|
hidden_states = nn.functional.pad(hidden_states, pad_values)
|
||
|
return hidden_states, pad_values
|
||
|
|
||
|
def forward(
    self,
    hidden_states: torch.Tensor,
    input_dimensions: Tuple[int, int],
    head_mask: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = False,
    always_partition: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run one Swin block: (shifted) window attention, then the MLP, each inside a pre-norm residual connection.

    Args:
        hidden_states (`torch.Tensor`): tokens of shape `(batch_size, height * width, channels)`.
        input_dimensions (`Tuple[int, int]`): `(height, width)` of the token grid.
        head_mask (`torch.FloatTensor`, *optional*): passed unchanged to `self.attention`.
        output_attentions (`bool`, *optional*): also return the second element of the attention outputs
            (presumably the attention weights).
        always_partition (`bool`, *optional*): when `True`, `set_shift_and_window_size` is skipped, so the
            configured window/shift are kept even for inputs smaller than one window.

    Returns:
        `(layer_output,)`, or `(layer_output, attention_outputs[1])` when `output_attentions` is set.
        `layer_output` has the same shape as `hidden_states`.
    """
    if not always_partition:
        self.set_shift_and_window_size(input_dimensions)
    height, width = input_dimensions
    batch_size, _, channels = hidden_states.size()
    residual = hidden_states

    # Pre-norm, restore the 2-D grid, then pad it up to whole windows.
    grid = self.layernorm_before(hidden_states).view(batch_size, height, width, channels)
    grid, pad_values = self.maybe_pad(grid, height, width)
    _, height_pad, width_pad, _ = grid.shape

    shift = self.shift_size
    window = self.window_size

    # Cyclic shift, so the regular window split below produces the shifted windows.
    rolled = torch.roll(grid, shifts=(-shift, -shift), dims=(1, 2)) if shift > 0 else grid

    windows = window_partition(rolled, window).view(-1, window * window, channels)
    attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=grid.dtype)
    if attn_mask is not None:
        attn_mask = attn_mask.to(windows.device)

    attention_outputs = self.attention(windows, attn_mask, head_mask, output_attentions=output_attentions)

    # Stitch windows back into a grid and undo the cyclic shift.
    merged = window_reverse(attention_outputs[0].view(-1, window, window, channels), window, height_pad, width_pad)
    if shift > 0:
        merged = torch.roll(merged, shifts=(shift, shift), dims=(1, 2))

    # Remove the padding added by maybe_pad (index 3 = right pad, index 5 = bottom pad).
    if pad_values[3] > 0 or pad_values[5] > 0:
        merged = merged[:, :height, :width, :].contiguous()

    attention_tokens = merged.view(batch_size, height * width, channels)
    hidden_states = residual + self.drop_path(attention_tokens)

    mlp_output = self.output(self.intermediate(self.layernorm_after(hidden_states)))
    layer_output = hidden_states + mlp_output

    if output_attentions:
        return (layer_output, attention_outputs[1])
    return (layer_output,)
|
||
|
|
||
|
|
||
|
class SwinStage(nn.Module):
    """
    One stage of the Swin encoder: `depth` `SwinLayer` blocks (even blocks use regular windows, odd blocks shifted
    windows) followed by an optional downsampling module.

    Args:
        config ([`SwinConfig`]): model configuration.
        dim (`int`): channel dimension of the tokens processed by this stage.
        input_resolution (`Tuple[int, int]`): `(height, width)` of the token grid, passed to every block and to the
            downsampling module.
        depth (`int`): number of blocks.
        num_heads (`int`): number of attention heads per block.
        drop_path (sequence of `float`): stochastic-depth rate of each block, in block order.
        downsample (`type`, *optional*): class instantiated as
            `downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)`, or `None` for no downsampling.
    """

    def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        self.blocks = nn.ModuleList(
            [
                SwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    shift_size=0 if (i % 2 == 0) else config.window_size // 2,
                )
                for i in range(depth)
            ]
        )

        # `SwinLayer` builds its drop-path module from the global `config.drop_path_rate` and takes no per-block
        # rate, which left `drop_path` unused. Install the per-block rates here so the schedule computed by the
        # encoder takes effect.
        # NOTE(review): assumes SwinDropPath is inactive in eval mode (usual for stochastic depth), so only
        # training behaviour changes — confirm.
        if drop_path is not None:
            for block, rate in zip(self.blocks, drop_path):
                block.drop_path = SwinDropPath(rate) if rate > 0.0 else nn.Identity()

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None

        # Not read anywhere in this class; kept for compatibility.
        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Run every block in order, then the optional downsampling.

        Args:
            hidden_states (`torch.Tensor`): tokens entering the stage.
            input_dimensions (`Tuple[int, int]`): `(height, width)` of the token grid.
            head_mask (*optional*): indexed by block position; `head_mask[i]` goes to block `i`.
            output_attentions (`bool`, *optional*): append the last block's attention outputs.
            always_partition (`bool`, *optional*): forwarded to every block.

        Returns:
            `(hidden_states, hidden_states_before_downsampling, output_dimensions)`, followed by the attention
            outputs of the *last* block when `output_attentions` is set. `output_dimensions` is
            `(height, width, new_height, new_width)`; without downsampling the new size equals the input size.
        """
        height, width = input_dimensions
        # Fallback so an empty stage (depth == 0) still returns a well-formed tuple instead of a NameError.
        layer_outputs = (hidden_states,)
        for i, layer_module in enumerate(self.blocks):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
            )

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            # Round up, presumably matching how the merging layer handles odd sides.
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs
|
||
|
|
||
|
|
||
|
class SwinEncoder(nn.Module):
    """
    Swin encoder: a sequence of `SwinStage`s, where every stage except the last ends with a `SwinPatchMerging`
    downsample.

    Args:
        config ([`SwinConfig`]): model configuration (`depths`, `num_heads`, `embed_dim`, `drop_path_rate`, ...).
        grid_size (`Tuple[int, int]`): `(height, width)` of the patch grid fed to the first stage.
    """

    def __init__(self, config, grid_size):
        super().__init__()
        self.num_layers = len(config.depths)
        self.config = config

        # One stochastic-depth rate per block across the whole network, rising linearly from 0 to drop_path_rate.
        dpr = torch.linspace(0, config.drop_path_rate, sum(config.depths)).tolist()

        stages = []
        first_block = 0
        for i_layer, depth in enumerate(config.depths):
            scale = 2**i_layer
            is_last_stage = i_layer == self.num_layers - 1
            stages.append(
                SwinStage(
                    config=config,
                    dim=int(config.embed_dim * scale),
                    input_resolution=(grid_size[0] // scale, grid_size[1] // scale),
                    depth=depth,
                    num_heads=config.num_heads[i_layer],
                    drop_path=dpr[first_block : first_block + depth],
                    downsample=None if is_last_stage else SwinPatchMerging,
                )
            )
            first_block += depth
        self.layers = nn.ModuleList(stages)

        self.gradient_checkpointing = False

    @staticmethod
    def _to_channels_first(tokens, dimensions):
        # (batch, height * width, channels) -> (batch, channels, height, width)
        batch_size, _, hidden_size = tokens.shape
        return tokens.view(batch_size, *dimensions, hidden_size).permute(0, 3, 1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        always_partition: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, SwinEncoderOutput]:
        """
        Run all stages.

        When `output_hidden_states` is set, the embeddings and then one entry per stage are recorded both as tokens
        and reshaped to `(batch, channels, height, width)`. Per stage, the entry is the pre-downsampling output
        (at the stage's own resolution) if `output_hidden_states_before_downsampling` is set, otherwise the
        downsampled output.

        Returns:
            [`SwinEncoderOutput`] when `return_dict`, otherwise a tuple of the non-`None` values among
            `(last_hidden_state, hidden_states, attentions)`. The reshaped states are only in the dataclass.
        """
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (self._to_channels_first(hidden_states, input_dimensions),)

        for i, stage in enumerate(self.layers):
            stage_head_mask = None if head_mask is None else head_mask[i]
            stage_args = (hidden_states, input_dimensions, stage_head_mask, output_attentions, always_partition)

            if self.gradient_checkpointing and self.training:
                stage_outputs = self._gradient_checkpointing_func(stage.__call__, *stage_args)
            else:
                stage_outputs = stage(*stage_args)

            hidden_states = stage_outputs[0]
            hidden_states_before_downsampling = stage_outputs[1]
            output_dimensions = stage_outputs[2]
            # The next stage works on the (possibly) downsampled grid.
            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states:
                if output_hidden_states_before_downsampling:
                    # Use the stage's original (not downsampled) height and width.
                    all_hidden_states += (hidden_states_before_downsampling,)
                    all_reshaped_hidden_states += (
                        self._to_channels_first(hidden_states_before_downsampling, output_dimensions[:2]),
                    )
                else:
                    all_hidden_states += (hidden_states,)
                    all_reshaped_hidden_states += (self._to_channels_first(hidden_states, input_dimensions),)

            if output_attentions:
                all_self_attentions += stage_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return SwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )
|
||
|
|
||
|
|
||
|
class SwinPreTrainedModel(PreTrainedModel):
    """
    Base class shared by all Swin models. It binds `SwinConfig`, the `"swin"` weight prefix and the weight
    initialization below to the download/loading machinery inherited from `PreTrainedModel`.
    """

    config_class = SwinConfig
    base_model_prefix = "swin"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        Initialize one submodule in place.

        - `nn.LayerNorm`: bias set to 0, weight set to 1.
        - `nn.Linear` / `nn.Conv2d`: weight drawn from a normal with mean 0 and std `config.initializer_range`;
          bias (if present) set to 0.

        Other module types are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return
        # Plain normal rather than the truncated normal of the TF version;
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
|
||
|
|
||
|
|
||
|
# Shared class-level docstring fragment injected by `add_start_docstrings` into the public Swin model classes.
SWIN_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`SwinConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
||
|
|
||
|
# Shared forward-argument docstring fragment injected by `add_start_docstrings_to_model_forward`.
SWIN_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    "The bare Swin Model transformer outputting raw hidden-states without any specific head on top.",
    SWIN_START_DOCSTRING,
    """
        add_pooling_layer (`bool`, *optional*, defaults to `True`):
            Whether or not to apply pooling layer.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether or not to create and apply mask tokens in the embedding layer.
    """,
)
class SwinModel(SwinPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        # Channel count after the last stage: embed_dim doubled once per stage transition.
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = SwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        # Global average over the token dimension; used as the pooled representation.
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # NOTE(review): `SwinEncoder` keeps its stages in `self.layers` (there is no `self.layer`), and each stage
        # holds its attention modules inside `blocks`, so this lookup raises AttributeError for any non-empty input.
        # Left unchanged pending a decision on how layer indices should map onto Swin stages/blocks.
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SwinModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SwinModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            # (batch, tokens, channels) -> (batch, channels, tokens) so the 1-D pool averages over tokens.
            pooled_output = self.pooler(sequence_output.transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            # pooled_output is kept in the tuple even when None, so index positions stay fixed.
            output = (sequence_output, pooled_output) + encoder_outputs[1:]

            return output

        return SwinModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """Swin Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """,
    SWIN_START_DOCSTRING,
)
class SwinForMaskedImageModeling(SwinPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.swin = SwinModel(config, add_pooling_layer=False, use_mask_token=True)

        # Last-stage channel count. The 1x1 conv emits encoder_stride**2 * num_channels values per token, which
        # PixelShuffle rearranges into an encoder_stride x encoder_stride pixel patch per token.
        num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SwinMaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, SwinForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192")
        >>> model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 192, 192]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.swin(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # Reshape to (batch_size, num_channels, height, width); assumes a square final token grid.
        sequence_output = sequence_output.transpose(1, 2)
        batch_size, num_channels, sequence_length = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            # Expand the per-patch mask to per-pixel resolution: (batch, 1, image_size, image_size).
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            # L1 loss averaged over masked pixels and channels only; 1e-5 guards against an all-visible mask.
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            # outputs[2:] skips sequence_output and the (None) pooled output of the base model.
            output = (reconstructed_pixel_values,) + outputs[2:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return SwinMaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    Swin Model transformer with an image classification head on top (a linear layer on top of the average-pooled
    final hidden states) e.g. for ImageNet.
    """,
    SWIN_START_DOCSTRING,
)
class SwinForImageClassification(SwinPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.swin = SwinModel(config)

        # Classifier head (identity when there are no labels)
        self.classifier = (
            nn.Linear(self.swin.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=SwinImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SwinImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.swin(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Average-pooled, layer-normed final hidden states from SwinModel.
        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype, then cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SwinImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    Swin backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    SWIN_START_DOCSTRING,
)
class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
    def __init__(self, config: SwinConfig):
        super().__init__(config)
        super()._init_backbone(config)

        # Channels of the stem (embeddings) followed by each stage: embed_dim * 2**i.
        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
        self.embeddings = SwinEmbeddings(config)
        self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self._out_features, self.channels):
            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output, input_dimensions = self.embeddings(pixel_values)

        # Hidden states are always collected (before downsampling) because the feature maps are taken from them;
        # always_partition keeps the configured window size regardless of feature-map size.
        outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=None,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            always_partition=True,
            return_dict=True,
        )

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                # (B, C, H, W) -> (B, H*W, C) for the per-stage LayerNorm, then back to (B, C, H, W).
                batch_size, num_channels, height, width = hidden_state.shape
                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            # Keep the tuple path consistent with BackboneOutput, which carries the attentions when requested.
            if output_attentions:
                output += (outputs.attentions,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
|