779 lines
31 KiB
Python
779 lines
31 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2022 KAIST and The HuggingFace Inc. team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
""" PyTorch GLPN model."""
|
||
|
|
||
|
|
||
|
import math
|
||
|
from typing import List, Optional, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import nn
|
||
|
|
||
|
from ...activations import ACT2FN
|
||
|
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput
|
||
|
from ...modeling_utils import PreTrainedModel
|
||
|
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||
|
from ...utils import (
|
||
|
add_code_sample_docstrings,
|
||
|
add_start_docstrings,
|
||
|
add_start_docstrings_to_model_forward,
|
||
|
logging,
|
||
|
replace_return_docstrings,
|
||
|
)
|
||
|
from .configuration_glpn import GLPNConfig
|
||
|
|
||
|
|
||
|
logger = logging.get_logger(__name__)

# Config class name handed to the docstring decorators (`config_class=...`) further down.
_CONFIG_FOR_DOC = "GLPNConfig"

# Checkpoint and expected output shape used for the generated `GLPNModel.forward` code sample.
_CHECKPOINT_FOR_DOC = "vinvino02/glpn-kitti"
_EXPECTED_OUTPUT_SHAPE = [1, 512, 15, 20]
|
||
|
|
||
|
|
||
|
from ..deprecated._archive_maps import GLPN_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.beit.modeling_beit.drop_path
|
||
|
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along dimension 0 is either zeroed entirely or kept and divided by `1 - drop_prob`.

    Args:
        input: Tensor whose first dimension indexes the samples.
        drop_prob: Probability with which a sample is zeroed.
        training: Samples are only dropped when this is `True`.

    Returns:
        `input` itself when `drop_prob` is `0.0` or `training` is false, otherwise a new tensor of the same shape.
    """
    if not training or drop_prob == 0.0:
        return input

    survival_rate = 1 - drop_prob
    # One random draw per sample; the trailing singleton dims broadcast it over the rest of the tensor, so this
    # works for any rank, not only 4D feature maps.
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    mask = survival_rate + torch.rand(mask_shape, dtype=input.dtype, device=input.device)
    mask.floor_()  # values in [survival_rate, 1 + survival_rate) -> 0 or 1
    return input.div(survival_rate) * mask
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.segformer.modeling_segformer.SegformerDropPath
|
||
|
class GLPNDropPath(nn.Module):
    """Module form of `drop_path`: stochastic depth per sample, active only while the module is in training mode."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        # Probability forwarded to `drop_path` on every call.
        self.drop_prob = drop_prob

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # `self.training` is the flag maintained by `nn.Module.train()` / `nn.Module.eval()`.
        return drop_path(hidden_states, drop_prob=self.drop_prob, training=self.training)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.segformer.modeling_segformer.SegformerOverlapPatchEmbeddings
|
||
|
class GLPNOverlapPatchEmbeddings(nn.Module):
    """Construct the overlapping patch embeddings."""

    def __init__(self, patch_size, stride, num_channels, hidden_size):
        super().__init__()
        # The convolution is padded by half the patch size on each side.
        conv_kwargs = {"kernel_size": patch_size, "stride": stride, "padding": patch_size // 2}
        self.proj = nn.Conv2d(num_channels, hidden_size, **conv_kwargs)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, pixel_values):
        """
        Embed `pixel_values` and return `(embeddings, height, width)`.

        `embeddings` has shape `(batch_size, height * width, hidden_size)`, where `height` and `width` are the
        spatial dimensions of the projected feature map.
        """
        feature_map = self.proj(pixel_values)
        _batch, _channels, height, width = feature_map.shape
        # (batch, channels, height, width) -> (batch, channels, height*width) -> (batch, height*width, channels),
        # i.e. a token sequence that a Transformer layer can consume
        tokens = feature_map.flatten(2).transpose(1, 2)
        return self.layer_norm(tokens), height, width
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.segformer.modeling_segformer.SegformerEfficientSelfAttention
|
||
|
class GLPNEfficientSelfAttention(nn.Module):
    """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
    paper](https://arxiv.org/abs/2102.12122).

    Queries are computed from the full token sequence. When `sequence_reduction_ratio > 1`, keys and values are
    computed from a sequence that was first shortened by a strided convolution and then layer-normalised.
    """

    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})"
            )

        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size, self.all_head_size)
        self.key = nn.Linear(self.hidden_size, self.all_head_size)
        self.value = nn.Linear(self.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.sr_ratio = sequence_reduction_ratio
        if sequence_reduction_ratio > 1:
            # Kernel size == stride: every `sr_ratio x sr_ratio` window of the feature map becomes one token.
            self.sr = nn.Conv2d(
                hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
            )
            self.layer_norm = nn.LayerNorm(hidden_size)

    def transpose_for_scores(self, hidden_states):
        """Split the last dimension into heads: `(batch, seq, all_head_size) -> (batch, heads, seq, head_size)`."""
        split_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return hidden_states.view(split_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        height,
        width,
        output_attentions=False,
    ):
        """
        Attend over `hidden_states` of shape `(batch_size, height * width, hidden_size)`.

        Returns `(context,)`, or `(context, attention_probs)` when `output_attentions` is true.
        """
        queries = self.transpose_for_scores(self.query(hidden_states))

        key_value_input = hidden_states
        if self.sr_ratio > 1:
            batch_size, _seq_len, num_channels = hidden_states.shape
            # tokens -> (batch_size, num_channels, height, width) feature map
            feature_map = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
            # sequence reduction
            feature_map = self.sr(feature_map)
            # feature map -> (batch_size, reduced_seq_len, num_channels) tokens
            key_value_input = feature_map.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
            key_value_input = self.layer_norm(key_value_input)

        keys = self.transpose_for_scores(self.key(key_value_input))
        values = self.transpose_for_scores(self.value(key_value_input))

        # Scaled dot product between "query" and "key" gives the raw attention scores.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of the values, then merge the heads back into a single feature dimension.
        context = torch.matmul(attention_probs, values).permute(0, 2, 1, 3).contiguous()
        context = context.view(context.size()[:-2] + (self.all_head_size,))

        if output_attentions:
            return (context, attention_probs)
        return (context,)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.segformer.modeling_segformer.SegformerSelfOutput
|
||
|
class GLPNSelfOutput(nn.Module):
    """Output projection of the attention block: a linear layer followed by dropout."""

    def __init__(self, config, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # `input_tensor` is part of the signature but is not used in the computation.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.segformer.modeling_segformer.SegformerAttention with Segformer->GLPN
|
||
|
class GLPNAttention(nn.Module):
    """Efficient self-attention followed by its output projection, with support for pruning attention heads."""

    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
        super().__init__()
        self.self = GLPNEfficientSelfAttention(
            config=config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        self.output = GLPNSelfOutput(config, hidden_size=hidden_size)
        # Heads removed so far; passed to `find_pruneable_heads_and_indices` on later pruning calls.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the attention heads listed in `heads` from the query/key/value and output projections."""
        if len(heads) == 0:
            return

        attention = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attention.num_attention_heads, attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers: q/k/v lose output features, the output projection loses input features.
        for projection in ("query", "key", "value"):
            setattr(attention, projection, prune_linear_layer(getattr(attention, projection), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        attention.num_attention_heads = attention.num_attention_heads - len(heads)
        attention.all_head_size = attention.attention_head_size * attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, height, width, output_attentions=False):
        """Return `(attention_output,)`, followed by the attention probabilities when they were requested."""
        context, *attentions = self.self(hidden_states, height, width, output_attentions)
        attention_output = self.output(context, hidden_states)
        return (attention_output, *attentions)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.segformer.modeling_segformer.SegformerDWConv
|
||
|
class GLPNDWConv(nn.Module):
    """3x3 depth-wise convolution applied to a token sequence that is temporarily viewed as a 2D feature map."""

    def __init__(self, dim=768):
        super().__init__()
        # groups == channels: each channel is convolved with its own 3x3 filter
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

    def forward(self, hidden_states, height, width):
        """Map `(batch_size, height * width, num_channels)` tokens to tokens of the same shape."""
        batch_size, _seq_len, num_channels = hidden_states.shape
        feature_map = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
        return self.dwconv(feature_map).flatten(2).transpose(1, 2)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.segformer.modeling_segformer.SegformerMixFFN with Segformer->GLPN
|
||
|
class GLPNMixFFN(nn.Module):
    """
    Feed-forward block: linear -> depth-wise convolution -> activation -> dropout -> linear -> dropout.

    Args:
        config: Configuration providing `hidden_act` (an activation name looked up in `ACT2FN`, or a callable) and
            `hidden_dropout_prob`.
        in_features: Size of the last dimension of the input tokens.
        hidden_features: Width of the intermediate representation. Defaults to `in_features`.
        out_features: Size of the last dimension of the output tokens. Defaults to `in_features`.
    """

    def __init__(self, config, in_features, hidden_features=None, out_features=None):
        super().__init__()
        # `None` is the declared default for both sizes; resolve it before building any layer, otherwise
        # `nn.Linear(in_features, None)` fails.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.dense1 = nn.Linear(in_features, hidden_features)
        self.dwconv = GLPNDWConv(hidden_features)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
        self.dense2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, height, width):
        """Transform tokens of shape `(batch_size, height * width, in_features)`."""
        hidden_states = self.dense1(hidden_states)
        # the depth-wise convolution needs the spatial layout of the tokens
        hidden_states = self.dwconv(hidden_states, height, width)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.segformer.modeling_segformer.SegformerLayer with Segformer->GLPN
|
||
|
class GLPNLayer(nn.Module):
    """This corresponds to the Block class in the original implementation.

    Pre-norm Transformer block: attention and the mix feed-forward network are each preceded by a layer norm and
    wrapped in a residual connection with stochastic depth.
    """

    def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(hidden_size)
        self.attention = GLPNAttention(
            config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        # A rate of zero needs no stochastic depth module at all.
        if drop_path > 0.0:
            self.drop_path = GLPNDropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.layer_norm_2 = nn.LayerNorm(hidden_size)
        self.mlp = GLPNMixFFN(config, in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio))

    def forward(self, hidden_states, height, width, output_attentions=False):
        """Return `(layer_output,)`, followed by the attention probabilities when they were requested."""
        # in GLPN, layernorm is applied before self-attention
        attention_output, *attentions = self.attention(
            self.layer_norm_1(hidden_states),
            height,
            width,
            output_attentions=output_attentions,
        )

        # first residual connection (with stochastic depth)
        hidden_states = self.drop_path(attention_output) + hidden_states

        # second residual connection (with stochastic depth)
        mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
        layer_output = self.drop_path(mlp_output) + hidden_states

        return (layer_output, *attentions)
|
||
|
|
||
|
|
||
|
class GLPNEncoder(nn.Module):
    """
    Hierarchical encoder made of `config.num_encoder_blocks` stages. Each stage embeds its input into patches,
    runs a stack of `GLPNLayer`s over them, applies a layer norm and reshapes the tokens back into a feature map.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        num_stages = config.num_encoder_blocks

        # stochastic depth decay rule: one rate per layer, increasing linearly over the whole encoder
        drop_path_rates = [rate.item() for rate in torch.linspace(0, config.drop_path_rate, sum(config.depths))]

        # patch embeddings: the first stage reads the image, later stages read the previous stage's feature map
        self.patch_embeddings = nn.ModuleList(
            [
                GLPNOverlapPatchEmbeddings(
                    patch_size=config.patch_sizes[i],
                    stride=config.strides[i],
                    num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
                    hidden_size=config.hidden_sizes[i],
                )
                for i in range(num_stages)
            ]
        )

        # Transformer blocks: each block consists of `config.depths[i]` layers
        stages = []
        first_layer = 0  # index into `drop_path_rates` of the current stage's first layer
        for i in range(num_stages):
            depth = config.depths[i]
            stages.append(
                nn.ModuleList(
                    [
                        GLPNLayer(
                            config,
                            hidden_size=config.hidden_sizes[i],
                            num_attention_heads=config.num_attention_heads[i],
                            drop_path=drop_path_rates[first_layer + j],
                            sequence_reduction_ratio=config.sr_ratios[i],
                            mlp_ratio=config.mlp_ratios[i],
                        )
                        for j in range(depth)
                    ]
                )
            )
            first_layer += depth
        self.block = nn.ModuleList(stages)

        # Layer norms, one per stage
        self.layer_norm = nn.ModuleList([nn.LayerNorm(config.hidden_sizes[i]) for i in range(num_stages)])

    def forward(
        self,
        pixel_values,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Encode `pixel_values`.

        Returns a `BaseModelOutput` when `return_dict` is true, otherwise a tuple holding the last feature map
        followed by whichever of the per-stage feature maps and per-layer attentions were requested.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        batch_size = pixel_values.shape[0]
        hidden_states = pixel_values

        for embedding_layer, block_layer, norm_layer in zip(self.patch_embeddings, self.block, self.layer_norm):
            # first, obtain patch embeddings
            hidden_states, height, width = embedding_layer(hidden_states)
            # second, send embeddings through blocks
            for layer in block_layer:
                layer_outputs = layer(hidden_states, height, width, output_attentions)
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attentions += (layer_outputs[1],)
            # third, apply layer norm
            hidden_states = norm_layer(hidden_states)
            # fourth, reshape back to (batch_size, num_channels, height, width)
            hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
            )
        return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
||
|
|
||
|
|
||
|
class GLPNPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GLPNConfig
    base_model_prefix = "glpn"
    main_input_name = "pixel_values"

    # Adapted from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # the padding row, if any, stays at zero
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            # identity affine transform: zero bias, unit weight
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
||
|
|
||
|
|
||
|
# Shared introduction prepended to the class docstrings by `add_start_docstrings`.
GLPN_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`GLPNConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared description of the `forward` arguments, attached by `add_start_docstrings_to_model_forward`.
GLPN_INPUTS_DOCSTRING = r"""

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`GLPNImageProcessor.__call__`] for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    "The bare GLPN encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
    GLPN_START_DOCSTRING,
)
class GLPNModel(GLPNPreTrainedModel):
    # NOTE: the class docstring is assembled by `add_start_docstrings` above, so explanatory text lives in comments.

    # Copied from transformers.models.segformer.modeling_segformer.SegformerModel.__init__ with Segformer->GLPN
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # hierarchical Transformer encoder
        self.encoder = GLPNEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # The encoder stores its layers in `self.encoder.block`, a ModuleList holding one ModuleList of layers per
        # stage; it has no flat `layer` attribute. `layer_num` is therefore interpreted as an index into the layers
        # enumerated stage by stage, in order.
        layers = [layer for stage in self.encoder.block for layer in stage]
        for layer_num, heads in heads_to_prune.items():
            layers[layer_num].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(GLPN_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    # Copied from transformers.models.segformer.modeling_segformer.SegformerModel.forward
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        # Arguments left as `None` fall back to the corresponding config values.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
||
|
|
||
|
|
||
|
class GLPNSelectiveFeatureFusion(nn.Module):
    """
    Selective Feature Fusion module, as explained in the [paper](https://arxiv.org/abs/2201.07436) (section 3.4). This
    module adaptively selects and integrates local and global features by attaining an attention map for each feature.
    """

    def __init__(self, in_channel=64):
        super().__init__()
        half_channel = int(in_channel / 2)

        def conv_bn_relu(num_in, num_out):
            # 3x3 convolution that keeps the spatial size, followed by batch norm and ReLU
            return nn.Sequential(
                nn.Conv2d(in_channels=num_in, out_channels=num_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(num_out),
                nn.ReLU(),
            )

        # The first layer sees local and global features stacked along the channel dimension.
        self.convolutional_layer1 = conv_bn_relu(int(in_channel * 2), in_channel)
        self.convolutional_layer2 = conv_bn_relu(in_channel, half_channel)
        # Two output channels: one attention map for each of the two inputs.
        self.convolutional_layer3 = nn.Conv2d(
            in_channels=half_channel, out_channels=2, kernel_size=3, stride=1, padding=1
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, local_features, global_features):
        """Blend the two feature maps with per-pixel weights predicted from their concatenation."""
        # concatenate features along the channel dimension and pass them through the convolutional layers
        features = torch.cat((local_features, global_features), dim=1)
        for layer in (self.convolutional_layer1, self.convolutional_layer2, self.convolutional_layer3):
            features = layer(features)

        # apply sigmoid to get two-channel attention map
        attn = self.sigmoid(features)
        local_weight = attn[:, 0, :, :].unsqueeze(1)
        global_weight = attn[:, 1, :, :].unsqueeze(1)

        # construct hybrid features by adding element-wise
        return local_features * local_weight + global_features * global_weight
|
||
|
|
||
|
|
||
|
class GLPNDecoderStage(nn.Module):
    """
    One decoder stage: project the incoming features to `out_channels`, optionally fuse them with the output of
    the previous stage, and upsample the result by a factor of 2.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 projection is only needed when the channel counts differ.
        should_skip = in_channels == out_channels
        self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1) if not should_skip else nn.Identity()
        self.fusion = GLPNSelectiveFeatureFusion(out_channels)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    def forward(self, hidden_state, residual=None):
        """
        Args:
            hidden_state: Feature map to decode.
            residual: Feature map to fuse with, or `None` to skip the fusion step.

        Returns:
            The (optionally fused) feature map, upsampled once by a factor of 2.
        """
        hidden_state = self.convolution(hidden_state)
        if residual is not None:
            hidden_state = self.fusion(hidden_state, residual)
        return self.upsample(hidden_state)
|
||
|
|
||
|
|
||
|
class GLPNDecoder(nn.Module):
    """Chain of `GLPNDecoderStage`s that consumes the encoder feature maps from the last stage to the first."""

    def __init__(self, config):
        super().__init__()
        out_channels = config.decoder_hidden_size

        # we use features from end -> start, so the stages are built for the reversed hidden sizes
        self.stages = nn.ModuleList(
            [GLPNDecoderStage(hidden_size, out_channels) for hidden_size in config.hidden_sizes[::-1]]
        )
        # don't fuse in first stage
        self.stages[0].fusion = None

        self.final_upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
        """Return the output of every stage; the last entry is additionally passed through `final_upsample`."""
        stage_outputs = []
        previous = None  # the first stage is called with no residual
        for stage, hidden_state in zip(self.stages, hidden_states[::-1]):
            previous = stage(hidden_state, previous)
            stage_outputs.append(previous)

        stage_outputs[-1] = self.final_upsample(previous)
        return stage_outputs
|
||
|
|
||
|
|
||
|
class SiLogLoss(nn.Module):
    r"""
    Implements the Scale-invariant log scale loss [Eigen et al., 2014](https://arxiv.org/abs/1406.2283).

    $$L=\sqrt{\frac{1}{n} \sum_{i} d_{i}^{2}-\frac{\lambda}{n^{2}}\left(\sum_{i} d_{i}\right)^{2}}$$ where
    $d_{i}=\log y_{i}^{*}-\log y_{i}$, $y_{i}$ is the prediction, $y_{i}^{*}$ the target, and the sums run over the
    $n$ pixels whose target is strictly positive. With the default $\lambda=0.5$ the second term equals
    $\frac{1}{2 n^{2}}\left(\sum_{i} d_{i}\right)^{2}$.

    Args:
        lambd: Weight $\lambda$ of the squared-mean term.
    """

    def __init__(self, lambd=0.5):
        super().__init__()
        self.lambd = lambd

    def forward(self, pred, target):
        # Only pixels with a strictly positive target take part in the loss.
        valid_mask = (target > 0).detach()
        diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
        loss = torch.sqrt(torch.pow(diff_log, 2).mean() - self.lambd * torch.pow(diff_log.mean(), 2))

        return loss
|
||
|
|
||
|
|
||
|
class GLPNDepthEstimationHead(nn.Module):
    """
    Depth head: two 3x3 convolutions with a ReLU in between reduce the selected decoder feature map to a single
    channel, which a sigmoid scaled by `config.max_depth` turns into a depth map.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        channels = config.decoder_hidden_size
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Return the predicted depth with the channel dimension squeezed out."""
        # pick the decoder feature map designated by the config
        features = hidden_states[self.config.head_in_index]
        logits = self.head(features)
        # the sigmoid bounds the prediction to the range [0, max_depth]
        scaled = torch.sigmoid(logits) * self.config.max_depth
        return scaled.squeeze(dim=1)
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """GLPN Model transformer with a lightweight depth estimation head on top e.g. for KITTI, NYUv2.""",
    GLPN_START_DOCSTRING,
)
class GLPNForDepthEstimation(GLPNPreTrainedModel):
    # NOTE: the class docstring is assembled by `add_start_docstrings` above, so explanatory text lives in comments.

    def __init__(self, config):
        super().__init__(config)

        # encoder -> decoder -> depth head
        self.glpn = GLPNModel(config)
        self.decoder = GLPNDecoder(config)
        self.head = GLPNDepthEstimationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GLPN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, GLPNForDepthEstimation
        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-kitti")
        >>> model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-kitti")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     predicted_depth = outputs.predicted_depth

        >>> # interpolate to original size
        >>> prediction = torch.nn.functional.interpolate(
        ...     predicted_depth.unsqueeze(1),
        ...     size=image.size[::-1],
        ...     mode="bicubic",
        ...     align_corners=False,
        ... )

        >>> # visualize the prediction
        >>> output = prediction.squeeze().cpu().numpy()
        >>> formatted = (output * 255 / np.max(output)).astype("uint8")
        >>> depth = Image.fromarray(formatted)
        ```"""
        # Arguments left as `None` fall back to the corresponding config values.
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        outputs = self.glpn(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )
        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        predicted_depth = self.head(self.decoder(encoder_hidden_states))

        loss = None if labels is None else SiLogLoss()(predicted_depth, labels)

        if return_dict:
            return DepthEstimatorOutput(
                loss=loss,
                predicted_depth=predicted_depth,
                hidden_states=outputs.hidden_states if output_hidden_states else None,
                attentions=outputs.attentions,
            )

        # Tuple output: the hidden states sit at index 1 and are only passed on when the caller asked for them.
        first_extra = 1 if output_hidden_states else 2
        output = (predicted_depth,) + outputs[first_extra:]
        return output if loss is None else (loss,) + output
|