892 lines
37 KiB
Python
892 lines
37 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2021 Facebook AI Research (FAIR), Ross Wightman, The HuggingFace Inc. team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
""" PyTorch DeiT model."""
|
||
|
|
||
|
|
||
|
import collections.abc
|
||
|
import math
|
||
|
from dataclasses import dataclass
|
||
|
from typing import Optional, Set, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import nn
|
||
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||
|
|
||
|
from ...activations import ACT2FN
|
||
|
from ...modeling_outputs import (
|
||
|
BaseModelOutput,
|
||
|
BaseModelOutputWithPooling,
|
||
|
ImageClassifierOutput,
|
||
|
MaskedImageModelingOutput,
|
||
|
)
|
||
|
from ...modeling_utils import PreTrainedModel
|
||
|
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||
|
from ...utils import (
|
||
|
ModelOutput,
|
||
|
add_code_sample_docstrings,
|
||
|
add_start_docstrings,
|
||
|
add_start_docstrings_to_model_forward,
|
||
|
logging,
|
||
|
replace_return_docstrings,
|
||
|
)
|
||
|
from .configuration_deit import DeiTConfig
|
||
|
|
||
|
|
||
|
# Module-level logger obtained from the library's `logging` utility, keyed by this module's name.
logger = logging.get_logger(__name__)

# General docstring
# Passed as `config_class` to the docstring decorators applied to the `forward` methods in this file.
_CONFIG_FOR_DOC = "DeiTConfig"

# Base docstring
# Checkpoint name and expected output shape handed to `add_code_sample_docstrings` on `DeiTModel.forward`.
_CHECKPOINT_FOR_DOC = "facebook/deit-base-distilled-patch16-224"
_EXPECTED_OUTPUT_SHAPE = [1, 198, 768]

# Image classification docstring
# NOTE(review): not referenced in the visible part of this file; presumably consumed by an
# image-classification head's code-sample docstring further down -- confirm before removing.
_IMAGE_CLASS_CHECKPOINT = "facebook/deit-base-distilled-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
||
|
|
||
|
|
||
|
from ..deprecated._archive_maps import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||
|
|
||
|
|
||
|
class DeiTEmbeddings(nn.Module):
    """
    Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.

    The sequence returned by `forward` is laid out as `[CLS, distillation, patch_0, ..., patch_{N-1}]`, i.e. it is two
    tokens longer than the patch sequence produced by the patch embedding module.

    Args:
        config (`DeiTConfig`):
            Supplies `hidden_size` and `hidden_dropout_prob`; it is also forwarded to `DeiTPatchEmbeddings`.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to create the learnable mask token that replaces masked patches in `forward`.
    """

    def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = DeiTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        # +2 accounts for the CLS and distillation tokens prepended in `forward`.
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
        """
        Embed `pixel_values` into a token sequence.

        Args:
            pixel_values (`torch.Tensor`):
                Input handed unchanged to the patch embedding module.
            bool_masked_pos (`torch.BoolTensor`, *optional*):
                One flag per patch; patches flagged `True` are replaced by the mask token before the special tokens
                are prepended.

        Returns:
            `torch.Tensor`: `[CLS, distillation, patches...]` plus position embeddings, after dropout.

        Raises:
            ValueError: If `bool_masked_pos` is given but the module was created with `use_mask_token=False`.
        """
        embeddings = self.patch_embeddings(pixel_values)
        batch_size, seq_length, _ = embeddings.size()

        if bool_masked_pos is not None:
            if self.mask_token is None:
                # Without this guard the call below fails with an opaque AttributeError on `None`.
                raise ValueError(
                    "`bool_masked_pos` was provided but this model has no mask token. "
                    "Instantiate the embeddings with `use_mask_token=True` to use masked positions."
                )
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
        embeddings = embeddings + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
|
||
|
|
||
|
|
||
|
class DeiTPatchEmbeddings(nn.Module):
    """
    Convert `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states`
    (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` that a Transformer consumes.

    A single strided convolution whose kernel and stride both equal the patch size performs the projection.
    """

    def __init__(self, config):
        super().__init__()

        def _as_pair(value):
            # A scalar size means a square; anything iterable is taken as-is.
            if isinstance(value, collections.abc.Iterable):
                return value
            return (value, value)

        image_size = _as_pair(config.image_size)
        patch_size = _as_pair(config.patch_size)
        patches_along_width = image_size[1] // patch_size[1]
        patches_along_height = image_size[0] // patch_size[0]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = patches_along_width * patches_along_height

        self.projection = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Project an image batch to a sequence of patch embeddings.

        Raises:
            ValueError: If the channel count or the spatial size differs from what the module was configured with.
        """
        _, num_channels, height, width = pixel_values.shape

        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        size_matches = height == self.image_size[0] and width == self.image_size[1]
        if not size_matches:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )

        feature_map = self.projection(pixel_values)
        # Collapse the spatial grid into one axis, then move it in front of the channel axis.
        patch_sequence = feature_map.flatten(2)
        return patch_sequence.transpose(1, 2)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
|
||
|
class DeiTSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    Args:
        config (`DeiTConfig`):
            Supplies `hidden_size`, `num_attention_heads`, `qkv_bias` and `attention_probs_dropout_prob`.

    Raises:
        ValueError: If `hidden_size` is not divisible by `num_attention_heads` and the config has no
            `embedding_size` attribute.
    """

    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            # The replacement field must hold a bare expression: a trailing comma would format a 1-tuple
            # and render the size as "(768,)" in the message.
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into `(num_heads, head_size)` and move the head axis to position 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """
        Apply self-attention to `hidden_states`.

        Args:
            hidden_states (`torch.Tensor`):
                Input sequence; query, key and value are all projected from it.
            head_mask (`torch.Tensor`, *optional*):
                Multiplied element-wise into the attention probabilities (after dropout).
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether to also return the attention probabilities.

        Returns:
            `tuple`: `(context_layer,)`, or `(context_layer, attention_probs)` when `output_attentions` is set.
        """
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Undo the head split: move the head axis back next to the head-size axis and merge the two.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
|
||
|
class DeiTSelfOutput(nn.Module):
    """
    Output projection of the attention block: a dense layer followed by dropout.

    No residual is added here; the residual connection is defined in DeiTLayer instead (unlike other models), due
    to the layernorm applied before each block.
    """

    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` and apply dropout. `input_tensor` is accepted but not used."""
        projected = self.dense(hidden_states)
        return self.dropout(projected)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT
|
||
|
class DeiTAttention(nn.Module):
    """Attention block: `DeiTSelfAttention` followed by the `DeiTSelfOutput` projection, with head pruning."""

    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.attention = DeiTSelfAttention(config)
        self.output = DeiTSelfOutput(config)
        # Heads removed so far; passed to the pruning helper on later calls.
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given attention heads from the projections; a no-op for an empty collection."""
        if len(heads) == 0:
            return

        self_attention = self.attention
        heads, index = find_pruneable_heads_and_indices(
            heads, self_attention.num_attention_heads, self_attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        for projection_name in ("query", "key", "value"):
            pruned = prune_linear_layer(getattr(self_attention, projection_name), index)
            setattr(self_attention, projection_name, pruned)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self_attention.num_attention_heads = self_attention.num_attention_heads - len(heads)
        self_attention.all_head_size = self_attention.attention_head_size * self_attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Run self-attention, project its first output, and pass through any further outputs unchanged."""
        attention_results = self.attention(hidden_states, head_mask, output_attentions)
        projected = self.output(attention_results[0], hidden_states)
        # Anything after the first element (the attention probabilities, when requested) is kept as-is.
        return (projected,) + attention_results[1:]
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
|
||
|
class DeiTIntermediate(nn.Module):
    """First half of the feed-forward block: expand to `intermediate_size`, then apply the activation."""

    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # A string names an entry of ACT2FN; any other value is used directly as the activation.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `activation(dense(hidden_states))`."""
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT
|
||
|
class DeiTOutput(nn.Module):
    """Second half of the feed-forward block: project back to `hidden_size`, apply dropout, add the residual."""

    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `dropout(dense(hidden_states)) + input_tensor`."""
        projected = self.dropout(self.dense(hidden_states))
        # `input_tensor` is the residual branch.
        return projected + input_tensor
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT
|
||
|
class DeiTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation.

    Pre-norm transformer block: layernorm -> attention -> residual, then layernorm -> feed-forward -> residual.
    """

    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = DeiTAttention(config)
        self.intermediate = DeiTIntermediate(config)
        self.output = DeiTOutput(config)
        norm_eps = config.layer_norm_eps
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Apply the block; returns `(layer_output,)` followed by whatever extra outputs the attention produced."""
        # In DeiT, layernorm is applied before self-attention.
        normed_input = self.layernorm_before(hidden_states)
        attention_results = self.attention(
            normed_input,
            head_mask,
            output_attentions=output_attentions,
        )

        # First residual connection: attention output added to the un-normalised input.
        hidden_states = attention_results[0] + hidden_states

        # A second layernorm precedes the feed-forward part.
        feed_forward_input = self.layernorm_after(hidden_states)
        expanded = self.intermediate(feed_forward_input)

        # The second residual connection happens inside `self.output`.
        layer_output = self.output(expanded, hidden_states)

        # Re-attach the attention weights, if the attention returned any.
        return (layer_output,) + attention_results[1:]
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT
|
||
|
class DeiTEncoder(nn.Module):
    """Sequential stack of `config.num_hidden_layers` `DeiTLayer` blocks."""

    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.config = config
        blocks = [DeiTLayer(config) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(blocks)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """
        Run `hidden_states` through every layer in order.

        `head_mask`, when given, is indexed by layer position. With `output_hidden_states` the input of every layer
        plus the final output are collected; with `output_attentions` the second output of every layer is collected.
        Returns a `BaseModelOutput`, or a tuple of the non-`None` fields when `return_dict` is falsy.
        """
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        for index, block in enumerate(self.layer):
            if output_hidden_states:
                # Record the input to this layer.
                collected_states += (hidden_states,)

            block_mask = None if head_mask is None else head_mask[index]

            if self.gradient_checkpointing and self.training:
                block_outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    block_mask,
                    output_attentions,
                )
            else:
                block_outputs = block(hidden_states, block_mask, output_attentions)

            hidden_states = block_outputs[0]

            if output_attentions:
                collected_attentions += (block_outputs[1],)

        if output_hidden_states:
            # Record the output of the last layer as well.
            collected_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=collected_states,
                attentions=collected_attentions,
            )
        candidates = (hidden_states, collected_states, collected_attentions)
        return tuple(value for value in candidates if value is not None)
|
||
|
|
||
|
|
||
|
class DeiTPreTrainedModel(PreTrainedModel):
    """
    Abstract base for the DeiT models: handles weight initialization and provides, through `PreTrainedModel`, a
    simple interface for downloading and loading pretrained models.
    """

    # Class-level settings read by the `PreTrainedModel` machinery (their semantics are defined there).
    config_class = DeiTConfig
    base_model_prefix = "deit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeiTLayer"]

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            target_dtype = module.weight.dtype
            fp32_weight = module.weight.data.to(torch.float32)
            initialised = nn.init.trunc_normal_(fp32_weight, mean=0.0, std=self.config.initializer_range)
            module.weight.data = initialised.to(target_dtype)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.LayerNorm):
            # Layer norms start as the identity transform: zero shift, unit scale.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
||
|
|
||
|
|
||
|
# Shared class-level documentation; passed to `add_start_docstrings` on the model classes in this file.
DEIT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`DeiTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared argument documentation; passed to `add_start_docstrings_to_model_forward` on the `forward` methods.
DEIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`DeiTImageProcessor.__call__`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    "The bare DeiT Model transformer outputting raw hidden-states without any specific head on top.",
    DEIT_START_DOCSTRING,
)
class DeiTModel(DeiTPreTrainedModel):
    # Pipeline: embeddings -> encoder -> final layernorm -> optional pooler.
    def __init__(self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
        super().__init__(config)
        self.config = config

        self.embeddings = DeiTEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DeiTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        if add_pooling_layer:
            self.pooler = DeiTPooler(config)
        else:
            self.pooler = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> DeiTPatchEmbeddings:
        # The patch projection plays the role of the input embedding layer.
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_index, head_ids in heads_to_prune.items():
            self.encoder.layer[layer_index].attention.prune_heads(head_ids)

    @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        # Any flag left as None falls back to the value stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        projection_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != projection_dtype:
            pixel_values = pixel_values.to(projection_dtype)

        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.layernorm(encoder_outputs[0])

        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output)
        else:
            pooled_output = None

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: the pooled output is left out entirely when it is None.
        if pooled_output is not None:
            head_outputs = (sequence_output, pooled_output)
        else:
            head_outputs = (sequence_output,)
        return head_outputs + encoder_outputs[1:]
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DeiT
|
||
|
class DeiTPooler(nn.Module):
    """Pool a sequence by passing the hidden state of its first token through a dense layer and a tanh."""

    def __init__(self, config: DeiTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # "Pooling" here simply means selecting the hidden state at sequence index 0.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """DeiT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """,
    DEIT_START_DOCSTRING,
)
class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
    def __init__(self, config: DeiTConfig) -> None:
        super().__init__(config)

        # The backbone is built without a pooler and with a mask token so that `bool_masked_pos` can be used.
        self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)

        # 1x1 convolution producing `encoder_stride**2 * num_channels` channels, which PixelShuffle then
        # rearranges into an image upsampled by a factor of `encoder_stride`.
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=config.hidden_size,
                out_channels=config.encoder_stride**2 * config.num_channels,
                kernel_size=1,
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, MaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DeiTForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
        >>> model = DeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 224, 224]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deit(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Reshape to (batch_size, num_channels, height, width)
        # The embeddings prepend two special tokens (CLS at index 0, distillation at index 1), so the patch
        # tokens start at index 2. Slicing `1:-1` would keep the distillation token and drop the last patch,
        # shifting every patch by one position relative to `bool_masked_pos`.
        sequence_output = sequence_output[:, 2:]
        batch_size, sequence_length, num_channels = sequence_output.shape
        # Assumes a square patch grid.
        height = width = int(sequence_length**0.5)
        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            # Blow the per-patch mask up to pixel resolution and add a channel axis for broadcasting.
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            # Mean L1 error over masked pixels only; the epsilon guards against an all-False mask.
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[1:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    DEIT_START_DOCSTRING,
)
class DeiTForImageClassification(DeiTPreTrainedModel):
    # The class-level documentation is assembled by `add_start_docstrings` above.

    def __init__(self, config: DeiTConfig) -> None:
        """Build the DeiT backbone (without pooling layer) and a single classification head."""
        super().__init__(config)

        self.num_labels = config.num_labels
        self.deit = DeiTModel(config, add_pooling_layer=False)

        # Classification head: a linear projection when labels are configured,
        # otherwise a pass-through so the raw token state is returned.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DeiTForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random
        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
        >>> model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: Polaroid camera, Polaroid Land camera
        ```"""
        # Fall back to the configured default when the caller did not choose an output format.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.deit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Only the token at index 0 feeds the classifier;
        # we don't use the distillation token
        first_token_state = encoder_outputs[0][:, 0, :]
        logits = self.classifier(first_token_state)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # When no problem type is configured, infer it from the number of labels
            # and the label dtype, then store it on the config for subsequent calls.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred_type = "single_label_classification"
                else:
                    inferred_type = "multi_label_classification"
                self.config.problem_type = inferred_type

            problem_type = self.config.problem_type
            if problem_type == "regression":
                # A single regression target is compared after squeezing both tensors.
                if self.num_labels == 1:
                    predictions, targets = logits.squeeze(), labels.squeeze()
                else:
                    predictions, targets = logits, labels
                loss = MSELoss()(predictions, targets)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)
            # Any other configured problem type leaves `loss` as None.

        if return_dict:
            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple output: logits followed by whatever the backbone returned after its first element,
        # with the loss prepended when one was computed.
        tuple_output = (logits,) + encoder_outputs[1:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output
|
||
|
|
||
|
|
||
|
@dataclass
class DeiTForImageClassificationWithTeacherOutput(ModelOutput):
    """
    Output type of [`DeiTForImageClassificationWithTeacher`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores as the average of the cls_logits and distillation logits.
        cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
            class token).
        distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
            distillation token).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    # Every field defaults to `None`, so each one is annotated as `Optional[...]`
    # to keep the declared type consistent with the default value.
    logits: Optional[torch.FloatTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
    distillation_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
    the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.

    .. warning::

           This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
           supported.
    """,
    DEIT_START_DOCSTRING,
)
class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
    # The class-level documentation is assembled by `add_start_docstrings` above.

    def __init__(self, config: DeiTConfig) -> None:
        """Build the DeiT backbone (without pooling layer) and the two classification heads."""
        super().__init__(config)

        self.num_labels = config.num_labels
        self.deit = DeiTModel(config, add_pooling_layer=False)

        def _build_head() -> nn.Module:
            # A linear projection when labels are configured, otherwise a pass-through.
            if config.num_labels > 0:
                return nn.Linear(config.hidden_size, config.num_labels)
            return nn.Identity()

        # Classifier heads (class-token head is created first, then the distillation head)
        self.cls_classifier = _build_head()
        self.distillation_classifier = _build_head()

        # Initialize weights and apply final processing
        self.post_init()

    # NOTE: no docstring is written on `forward`; the decorators below generate it.
    @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=DeiTForImageClassificationWithTeacherOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]:
        # Fall back to the configured default when the caller did not choose an output format.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.deit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_states = encoder_outputs[0]

        # Index 0 of the sequence feeds the class head, index 1 feeds the distillation head.
        cls_logits = self.cls_classifier(token_states[:, 0, :])
        distillation_logits = self.distillation_classifier(token_states[:, 1, :])

        # during inference, return the average of both classifier predictions
        logits = (cls_logits + distillation_logits) / 2

        if return_dict:
            return DeiTForImageClassificationWithTeacherOutput(
                logits=logits,
                cls_logits=cls_logits,
                distillation_logits=distillation_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple output: the three logit tensors followed by whatever the backbone
        # returned after its first element.
        return (logits, cls_logits, distillation_logits) + encoder_outputs[1:]
|