552 lines
21 KiB
Python
552 lines
21 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
""" PyTorch ConvNext model."""
|
||
|
|
||
|
|
||
|
from typing import Optional, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import nn
|
||
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||
|
|
||
|
from ...activations import ACT2FN
|
||
|
from ...modeling_outputs import (
|
||
|
BackboneOutput,
|
||
|
BaseModelOutputWithNoAttention,
|
||
|
BaseModelOutputWithPoolingAndNoAttention,
|
||
|
ImageClassifierOutputWithNoAttention,
|
||
|
)
|
||
|
from ...modeling_utils import PreTrainedModel
|
||
|
from ...utils import (
|
||
|
add_code_sample_docstrings,
|
||
|
add_start_docstrings,
|
||
|
add_start_docstrings_to_model_forward,
|
||
|
logging,
|
||
|
replace_return_docstrings,
|
||
|
)
|
||
|
from ...utils.backbone_utils import BackboneMixin
|
||
|
from .configuration_convnext import ConvNextConfig
|
||
|
|
||
|
|
||
|
logger = logging.get_logger(__name__)

# ---- Constants consumed by the docstring decorators further down in this module ----

# Configuration class named in the generated code samples.
_CONFIG_FOR_DOC = "ConvNextConfig"

# Base model (`ConvNextModel`) code sample: checkpoint to load and the expected output shape it reports.
_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7]

# Image-classification code sample: checkpoint to load and the expected predicted label.
_IMAGE_CLASS_CHECKPOINT = "facebook/convnext-tiny-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
||
|
|
||
|
|
||
|
from ..deprecated._archive_maps import CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.beit.modeling_beit.drop_path
|
||
|
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.

    Args:
        input (`torch.Tensor`):
            Tensor whose first dimension is the batch; any number of trailing dimensions is supported.
        drop_prob (`float`, *optional*, defaults to 0.0):
            Probability of zeroing each sample's path. `0.0` disables dropping.
        training (`bool`, *optional*, defaults to `False`):
            Paths are only dropped when `True`; otherwise `input` is returned unchanged.

    Returns:
        `torch.Tensor`: `input` with each sample either zeroed or scaled by `1 / (1 - drop_prob)`. When
        `drop_prob == 1.0` every sample is zeroed, instead of producing NaNs from a division by zero.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # One random value per sample, broadcast over all remaining dims (works for any rank, not just 4D maps).
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob > 0.0:
        # Rescale the kept samples so the expected value of the output matches the input.
        return input.div(keep_prob) * random_tensor
    # keep_prob == 0: every path is dropped. Dividing by zero here would give inf * 0 = NaN.
    return input * random_tensor
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
|
||
|
class ConvNextDropPath(nn.Module):
    """Stochastic-depth module: randomly drops whole samples of a residual branch while training.

    A thin `nn.Module` wrapper around the functional `drop_path`, so the behaviour follows the module's
    `train()` / `eval()` mode.
    """

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        # Per-sample drop probability forwarded to `drop_path` on every call.
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, drop_prob=self.drop_prob, training=self.training)

    def extra_repr(self) -> str:
        # Shown inside the module's repr, e.g. "ConvNextDropPath(p=0.1)".
        return f"p={self.drop_prob}"
|
||
|
|
||
|
|
||
|
class ConvNextLayerNorm(nn.Module):
    r"""Layer normalization over the channel dimension, for either memory layout.

    `data_format` selects the input layout:

    - `"channels_last"` (default): inputs of shape `(batch_size, height, width, channels)`; this delegates to
      `torch.nn.functional.layer_norm` over the last dimension.
    - `"channels_first"`: inputs of shape `(batch_size, channels, height, width)`. Mean and variance are computed over
      dim 1 in float32. The normalized result is cast back to the input dtype before the learnable affine transform.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(f"Unsupported data format: {data_format}")
        # F.layer_norm expects a shape tuple; only the trailing channel dimension is normalized.
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        layout = self.data_format
        if layout == "channels_last":
            return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if layout != "channels_first":
            # Only reachable if `data_format` was modified after construction; input passes through untouched.
            return x

        orig_dtype = x.dtype
        x32 = x.float()  # compute statistics in float32 for stability under reduced precision
        centered = x32 - x32.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = (centered / torch.sqrt(variance + self.eps)).to(dtype=orig_dtype)
        # Broadcast the per-channel affine parameters over the spatial dims.
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
|
||
|
|
||
|
|
||
|
class ConvNextEmbeddings(nn.Module):
    """Stem of the network: patchify the image with a strided convolution, then normalize the channels.

    Plays the same role as the SwinEmbeddings class in src/transformers/models/swin/modeling_swin.py.
    """

    def __init__(self, config):
        super().__init__()
        first_width = config.hidden_sizes[0]
        # kernel_size == stride, so the patches do not overlap.
        self.patch_embeddings = nn.Conv2d(
            config.num_channels, first_width, kernel_size=config.patch_size, stride=config.patch_size
        )
        self.layernorm = ConvNextLayerNorm(first_width, eps=1e-6, data_format="channels_first")
        self.num_channels = config.num_channels

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        # Reject inputs whose channel dimension (dim 1) disagrees with the configuration.
        if pixel_values.shape[1] != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        return self.layernorm(self.patch_embeddings(pixel_values))
|
||
|
|
||
|
|
||
|
class ConvNextLayer(nn.Module):
    """This corresponds to the `Block` class in the original implementation.

    There are two equivalent implementations:
    (1) [DwConv, LayerNorm (channels_first), 1x1 Conv, activation, 1x1 Conv]; all in (N, C, H, W);
    (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, activation, Linear]; Permute back.

    The authors used (2) as they find it slightly faster in PyTorch, and this module implements (2). The activation is
    taken from `config.hidden_act` (described as GELU for the original implementation).

    Args:
        config ([`ConvNextConfig`]): Model configuration class.
        dim (`int`): Number of input channels.
        drop_path (`float`): Stochastic depth rate. Default: 0.0.
    """

    def __init__(self, config, dim, drop_path=0):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = ACT2FN[config.hidden_act]
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Learnable per-channel scale on the residual branch. It is disabled when layer_scale_init_value <= 0.
        self.layer_scale_parameter = (
            nn.Parameter(config.layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if config.layer_scale_init_value > 0
            else None
        )
        self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        residual = hidden_states
        x = self.dwconv(hidden_states)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.layernorm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.layer_scale_parameter is not None:
            # Broadcasts over the trailing channel dim of the channels-last tensor.
            x = self.layer_scale_parameter * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return residual + self.drop_path(x)
|
||
|
|
||
|
|
||
|
class ConvNextStage(nn.Module):
    """One ConvNeXT stage: an optional downsampling layer followed by `depth` residual `ConvNextLayer` blocks.

    Args:
        config ([`ConvNextConfig`]): Model configuration class.
        in_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        kernel_size (`int`, *optional*, defaults to 2): Kernel size of the downsampling convolution.
        stride (`int`, *optional*, defaults to 2): Stride of the downsampling convolution.
        depth (`int`, *optional*, defaults to 2): Number of residual blocks.
        drop_path_rates (`List[float]`, *optional*): Stochastic depth rate for each block. A missing or empty list
            means no stochastic depth.
    """

    def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
        super().__init__()

        # Downsample (normalize, then strided conv) only when the resolution or the width changes.
        changes_shape = stride > 1 or in_channels != out_channels
        if changes_shape:
            self.downsampling_layer = nn.Sequential(
                ConvNextLayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
            )
        else:
            self.downsampling_layer = nn.Identity()

        if not drop_path_rates:
            drop_path_rates = [0.0] * depth
        blocks = [ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[idx]) for idx in range(depth)]
        self.layers = nn.Sequential(*blocks)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        return self.layers(self.downsampling_layer(hidden_states))
|
||
|
|
||
|
|
||
|
class ConvNextEncoder(nn.Module):
    """Sequence of `ConvNextStage`s.

    Stage 0 uses stride 1; every later stage downsamples with stride 2. Stochastic depth rates increase linearly from 0
    to `config.drop_path_rate` across all blocks of all stages.
    """

    def __init__(self, config):
        super().__init__()
        self.stages = nn.ModuleList()
        # One linearly spaced rate per block, regrouped into one list per stage.
        per_block_rates = torch.linspace(0, config.drop_path_rate, sum(config.depths))
        drop_path_rates = [chunk.tolist() for chunk in per_block_rates.split(config.depths)]

        in_chs = config.hidden_sizes[0]
        for stage_idx in range(config.num_stages):
            out_chs = config.hidden_sizes[stage_idx]
            self.stages.append(
                ConvNextStage(
                    config,
                    in_channels=in_chs,
                    out_channels=out_chs,
                    stride=1 if stage_idx == 0 else 2,
                    depth=config.depths[stage_idx],
                    drop_path_rates=drop_path_rates[stage_idx],
                )
            )
            in_chs = out_chs

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        # When requested, record the input of every stage plus the final output.
        collected = [] if output_hidden_states else None

        for stage in self.stages:
            if collected is not None:
                collected.append(hidden_states)
            hidden_states = stage(hidden_states)

        all_hidden_states = None
        if collected is not None:
            collected.append(hidden_states)
            all_hidden_states = tuple(collected)

        if not return_dict:
            return (hidden_states,) if all_hidden_states is None else (hidden_states, all_hidden_states)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )
|
||
|
|
||
|
|
||
|
class ConvNextPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ConvNextConfig
    base_model_prefix = "convnext"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule in place.

        - `nn.Linear` / `nn.Conv2d`: weights ~ N(0, `config.initializer_range`), bias zeroed.
        - `nn.LayerNorm` and the custom `ConvNextLayerNorm`: reset to identity (weight 1, bias 0).
        - `ConvNextLayer`: its layer-scale vector, when present, is filled with `config.layer_scale_init_value`.

        The last two cases reproduce exactly the values these modules receive in their constructors, so a freshly built
        model is unchanged. Parameters the base class asks to (re)initialize, presumably e.g. ones missing from a
        checkpoint, now also get well-defined values instead of being skipped.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, ConvNextLayerNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ConvNextLayer):
            # layer_scale_parameter is None when config.layer_scale_init_value <= 0.
            if module.layer_scale_parameter is not None:
                module.layer_scale_parameter.data.fill_(self.config.layer_scale_init_value)
|
||
|
|
||
|
|
||
|
CONVNEXT_START_DOCSTRING = r"""
|
||
|
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||
|
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||
|
behavior.
|
||
|
|
||
|
Parameters:
|
||
|
config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
|
||
|
Initializing with a config file does not load the weights associated with the model, only the
|
||
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||
|
"""
|
||
|
|
||
|
CONVNEXT_INPUTS_DOCSTRING = r"""
|
||
|
Args:
|
||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||
|
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
||
|
[`ConvNextImageProcessor.__call__`] for details.
|
||
|
|
||
|
output_hidden_states (`bool`, *optional*):
|
||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||
|
more detail.
|
||
|
return_dict (`bool`, *optional*):
|
||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||
|
"""
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    "The bare ConvNext model outputting raw features without any specific head on top.",
    CONVNEXT_START_DOCSTRING,
)
class ConvNextModel(ConvNextPreTrainedModel):
    # Pipeline: patchify stem -> stages of residual blocks -> global average pool -> LayerNorm on the pooled features.
    # (No class docstring on purpose: `add_start_docstrings` builds __doc__ from the strings above.)

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = ConvNextEmbeddings(config)
        self.encoder = ConvNextEncoder(config)
        # Final LayerNorm; applied only to the pooled (N, C) features, not to the spatial map.
        self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
        # Fall back to the configuration for any flag the caller left unset.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        encoder_outputs = self.encoder(
            self.embeddings(pixel_values),
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs[0]

        # Global average pooling over the spatial dims, (N, C, H, W) -> (N, C), then normalize.
        pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))

        if return_dict:
            return BaseModelOutputWithPoolingAndNoAttention(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
            )

        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    CONVNEXT_START_DOCSTRING,
)
class ConvNextForImageClassification(ConvNextPreTrainedModel):
    # ConvNextModel backbone + a linear classifier on its pooled output.
    # (No class docstring on purpose: `add_start_docstrings` builds __doc__ from the strings above.)

    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.convnext = ConvNextModel(config)

        # Classifier head; with num_labels == 0 the pooled features are returned as "logits" unchanged.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.convnext(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once (persisted on the config) from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                if self.num_labels == 1:
                    loss = MSELoss()(logits.squeeze(), labels.squeeze())
                else:
                    loss = MSELoss()(logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output if loss is None else (loss,) + output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    CONVNEXT_START_DOCSTRING,
)
class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
    # Exposes LayerNorm-ed intermediate feature maps (selected via `out_features`) for downstream heads.
    # (No class docstring on purpose: `add_start_docstrings` builds __doc__ from the strings above.)

    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.embeddings = ConvNextEmbeddings(config)
        self.encoder = ConvNextEncoder(config)
        # Channel width of each candidate feature map: the embedding output (hidden_sizes[0]) then one per stage.
        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes

        # One channels-first LayerNorm per requested output feature map, keyed by stage name.
        self.hidden_states_norms = nn.ModuleDict(
            {
                stage: ConvNextLayerNorm(num_channels, data_format="channels_first")
                for stage, num_channels in zip(self._out_features, self.channels)
            }
        )

        # initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
        >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        embedding_output = self.embeddings(pixel_values)

        # Always collect every hidden state: the feature maps are picked out of them below.
        outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=return_dict)
        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = tuple(
            self.hidden_states_norms[stage](hidden_state)
            for stage, hidden_state in zip(self.stage_names, hidden_states)
            if stage in self.out_features
        )

        if not return_dict:
            return (feature_maps, hidden_states) if output_hidden_states else (feature_maps,)

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=hidden_states if output_hidden_states else None,
            attentions=None,
        )
|