1686 lines
74 KiB
Python
1686 lines
74 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2022 Google AI and The HuggingFace Team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
""" PyTorch OWL-ViT model."""
|
||
|
|
||
|
import warnings
|
||
|
from dataclasses import dataclass
|
||
|
from functools import lru_cache
|
||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import Tensor, nn
|
||
|
|
||
|
from ...activations import ACT2FN
|
||
|
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
|
||
|
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||
|
from ...modeling_utils import PreTrainedModel
|
||
|
from ...utils import (
|
||
|
ModelOutput,
|
||
|
add_start_docstrings,
|
||
|
add_start_docstrings_to_model_forward,
|
||
|
is_vision_available,
|
||
|
logging,
|
||
|
replace_return_docstrings,
|
||
|
)
|
||
|
from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
|
||
|
|
||
|
|
||
|
if is_vision_available():
|
||
|
from transformers.image_transforms import center_to_corners_format
|
||
|
|
||
|
|
||
|
# Module-level logger obtained through the library's own `logging` helper.
logger = logging.get_logger(__name__)

# Checkpoint identifier kept for documentation purposes (presumably consumed by the docstring decorators
# further down the file -- its usage is not visible from this part of the module).
_CHECKPOINT_FOR_DOC = "google/owlvit-base-patch32"
|
||
|
|
||
|
# See all OwlViT models at https://huggingface.co/models?filter=owlvit
|
||
|
|
||
|
from ..deprecated._archive_maps import OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlvit
|
||
|
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """
    Cross-entropy over the rows of `logits`, where the target class of row `i` is index `i`.

    Args:
        logits (`torch.Tensor`):
            Score matrix; `len(logits)` determines how many targets are generated.

    Returns:
        `torch.Tensor`: the cross-entropy loss as computed by `nn.functional.cross_entropy`.
    """
    # Targets are 0..N-1, created on the same device as the scores.
    num_rows = len(logits)
    targets = torch.arange(num_rows, device=logits.device)
    return nn.functional.cross_entropy(logits, targets)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit
|
||
|
def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor:
    """
    Symmetric contrastive loss: the mean of `contrastive_loss` applied to `similarity` and to its transpose.

    Args:
        similarity (`torch.Tensor`):
            2D similarity matrix (it is transposed with `.t()`).

    Returns:
        `torch.Tensor`: the average of the two directional losses.
    """
    loss_rows = contrastive_loss(similarity)
    loss_cols = contrastive_loss(similarity.t())
    return (loss_rows + loss_cols) / 2.0
|
||
|
|
||
|
|
||
|
@dataclass
class OwlViTOutput(ModelOutput):
    """
    Output container holding the image-text similarity results together with both sub-model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of
            [`OwlViTVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`OwlViTTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`OwlViTVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        """Convert to a plain tuple, flattening the two nested sub-model outputs with their own `to_tuple`."""
        nested_keys = ("text_model_output", "vision_model_output")
        return tuple(
            getattr(self, key).to_tuple() if key in nested_keys else self[key] for key in self.keys()
        )
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.detr.modeling_detr._upcast
|
||
|
def _upcast(t: Tensor) -> Tensor:
|
||
|
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
||
|
if t.is_floating_point():
|
||
|
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
||
|
else:
|
||
|
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.detr.modeling_detr.box_area
|
||
|
def box_area(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    """
    # Upcast first so the width * height product is computed in a wider dtype.
    boxes = _upcast(boxes)
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.detr.modeling_detr.box_iou
|
||
|
def box_iou(boxes1, boxes2):
    """
    Pairwise intersection-over-union between two sets of corner-format boxes.

    Args:
        boxes1: boxes indexed as `[N, 4]` in (x1, y1, x2, y2) order.
        boxes2: boxes indexed as `[M, 4]` in (x1, y1, x2, y2) order.

    Returns:
        A tuple `(iou, union)`, both `[N, M]` matrices.
    """
    area1, area2 = box_area(boxes1), box_area(boxes2)

    # Broadcasting boxes1 as [N,1,2] against boxes2 as [M,2] yields every pair's overlap corners.
    overlap_min = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    overlap_max = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Negative extents mean the boxes do not overlap; clamp them to zero.
    overlap_size = (overlap_max - overlap_min).clamp(min=0)  # [N,M,2]
    inter = overlap_size[:, :, 0] * overlap_size[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    return inter / union, union
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
||
|
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)

    Raises:
        `ValueError`: if any box has a bottom-right corner that is not >= its top-left corner.
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    boxes1_valid = (boxes1[:, 2:] >= boxes1[:, :2]).all()
    if not boxes1_valid:
        raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
    boxes2_valid = (boxes2[:, 2:] >= boxes2[:, :2]).all()
    if not boxes2_valid:
        raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
    iou, union = box_iou(boxes1, boxes2)

    # Corners of the smallest box enclosing each (boxes1[i], boxes2[j]) pair.
    hull_min = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_max = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    hull_size = (hull_max - hull_min).clamp(min=0)  # [N,M,2]
    area = hull_size[:, :, 0] * hull_size[:, :, 1]

    # Penalise the part of the enclosing box not covered by the union.
    return iou - (area - union) / area
|
||
|
|
||
|
|
||
|
@dataclass
class OwlViTObjectDetectionOutput(ModelOutput):
    """
    Output type of [`OwlViTForObjectDetection`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
            Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
            bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
            scale-invariant IoU loss.
        loss_dict (`Dict`, *optional*):
            A dictionary containing the individual losses. Useful for logging.
        logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
            Classification logits (including no-object) for all queries.
        pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
            possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to retrieve the
            unnormalized bounding boxes.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
            Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
            image embeddings for each patch.
        class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
            Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
            number of patches is (image_size / patch_size)**2.
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`OwlViTTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`OwlViTVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    loss_dict: Optional[Dict] = None
    logits: torch.FloatTensor = None
    pred_boxes: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    class_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        """Convert to a plain tuple, flattening the two nested sub-model outputs with their own `to_tuple`."""
        nested_keys = ("text_model_output", "vision_model_output")
        return tuple(
            getattr(self, key).to_tuple() if key in nested_keys else self[key] for key in self.keys()
        )
|
||
|
|
||
|
|
||
|
@dataclass
class OwlViTImageGuidedObjectDetectionOutput(ModelOutput):
    """
    Output type of [`OwlViTForObjectDetection.image_guided_detection`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
            Classification logits (including no-object) for all queries.
        target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual target image in the batch
            (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to
            retrieve the unnormalized bounding boxes.
        query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
            Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
            values are normalized in [0, 1], relative to the size of each individual query image in the batch
            (disregarding possible padding). You can use [`~OwlViTImageProcessor.post_process_object_detection`] to
            retrieve the unnormalized bounding boxes.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
            Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
            image embeddings for each patch.
        query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim)`):
            Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes
            image embeddings for each patch.
        class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
            Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total
            number of patches is (image_size / patch_size)**2.
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`OwlViTTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`OwlViTVisionModel`].
    """

    logits: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    query_image_embeds: torch.FloatTensor = None
    target_pred_boxes: torch.FloatTensor = None
    query_pred_boxes: torch.FloatTensor = None
    class_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        """Convert to a plain tuple, flattening the two nested sub-model outputs with their own `to_tuple`."""
        nested_keys = ("text_model_output", "vision_model_output")
        return tuple(
            getattr(self, key).to_tuple() if key in nested_keys else self[key] for key in self.keys()
        )
|
||
|
|
||
|
|
||
|
class OwlViTVisionEmbeddings(nn.Module):
    """
    Turns pixel values into a token sequence: one learned class token followed by one token per image patch,
    with a learned position embedding added to every token.
    """

    def __init__(self, config: OwlViTVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.class_embedding = nn.Parameter(torch.randn(config.hidden_size))

        # Non-overlapping patches: kernel size and stride are both the patch size.
        patch = config.patch_size
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=patch,
            stride=patch,
            bias=False,
        )

        patches_per_side = config.image_size // config.patch_size
        self.num_patches = patches_per_side**2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Non-persistent buffer: follows the module across devices but is not stored in the state dict.
        all_positions = torch.arange(self.num_positions).expand((1, -1))
        self.register_buffer("position_ids", all_positions, persistent=False)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`): image batch fed to the patch convolution.

        Returns:
            `torch.Tensor`: class token + patch tokens with position embeddings added, indexed as
            `(batch_size, 1 + num_patch_tokens, embed_dim)`.
        """
        batch_size = pixel_values.shape[0]
        # conv output: [batch_size, embed_dim, grid_height, grid_width]
        patch_grid = self.patch_embedding(pixel_values)
        # -> [batch_size, grid_height * grid_width, embed_dim]
        patch_tokens = patch_grid.flatten(2).transpose(1, 2)

        class_tokens = self.class_embedding.expand(batch_size, 1, -1)
        tokens = torch.cat([class_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)
|
||
|
|
||
|
|
||
|
class OwlViTTextEmbeddings(nn.Module):
    """Sums learned token embeddings and learned absolute position embeddings for the text tower."""

    def __init__(self, config: OwlViTTextConfig):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # position_ids has shape (1, max_position_embeddings). It is registered as a non-persistent buffer, so it
        # moves with the module across devices but is NOT written to the state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            input_ids (`torch.LongTensor`, *optional*):
                Token indices; used to look up token embeddings when `inputs_embeds` is not given.
            position_ids (`torch.LongTensor`, *optional*):
                Position indices. Defaults to `0..seq_length-1` taken from the registered buffer.
            inputs_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed token embeddings; when given, the token embedding lookup is skipped.

        Returns:
            `torch.Tensor`: token embeddings plus position embeddings.

        Raises:
            `ValueError`: if neither `input_ids` nor `inputs_embeds` is given, or if default position ids are
            requested for a sequence longer than the position embedding table.
        """
        if input_ids is None and inputs_embeds is None:
            # Previously this surfaced as an opaque AttributeError on `None.shape`.
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            max_positions = self.position_ids.shape[-1]
            if seq_length > max_positions:
                # Slicing the buffer would silently truncate and fail later with a shape mismatch.
                raise ValueError(
                    f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                    f"{seq_length} and max_position_embeddings: {max_positions}"
                )
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
|
||
|
|
||
|
|
||
|
class OwlViTAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # Standard 1/sqrt(head_dim) scaling, applied to the queries in `forward`.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, embed_dim)` into `(bsz, num_heads, seq_len, head_dim)`."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _add_mask(
        self, attn_weights: torch.Tensor, mask: torch.Tensor, bsz: int, tgt_len: int, src_len: int
    ) -> torch.Tensor:
        """Validate an additive `(bsz, 1, tgt_len, src_len)` mask and add it to the flattened attention scores."""
        if mask.size() != (bsz, 1, tgt_len, src_len):
            raise ValueError(f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {mask.size()}")
        # Un-flatten the head dimension so the mask broadcasts over heads, then flatten again for `bmm`.
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + mask
        return attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Self-attention over `hidden_states`. Input shape: Batch x Time x Channel

        Args:
            hidden_states (`torch.Tensor`): input of shape `(bsz, tgt_len, embed_dim)`.
            attention_mask (`torch.Tensor`, *optional*): additive mask of shape `(bsz, 1, tgt_len, src_len)`.
            causal_attention_mask (`torch.Tensor`, *optional*): additive mask of shape
                `(bsz, 1, tgt_len, src_len)`; applied before `attention_mask`.
            output_attentions (`bool`, *optional*): whether to also return the attention probabilities.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has shape `(bsz, tgt_len, embed_dim)` and
            `attn_weights` has shape `(bsz, num_heads, tgt_len, src_len)`, or is `None` when `output_attentions`
            is falsy.

        Raises:
            `ValueError`: if a mask or an intermediate tensor does not have the expected size.
        """

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Fold the head dimension into the batch dimension so a single `bmm` handles all heads.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            attn_weights = self._add_mask(attn_weights, causal_attention_mask, bsz, tgt_len, src_len)

        if attention_mask is not None:
            attn_weights = self._add_mask(attn_weights, attention_mask, bsz, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # For int8 compatibility, sometimes the `attn_probs` are in `fp32`
        attn_probs = attn_probs.to(value_states.dtype)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            # Report the same (flattened) shape that is actually being checked.
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Un-fold heads and merge them back into the channel dimension.
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->OwlViT
|
||
|
class OwlViTMLP(nn.Module):
    """Two-layer feed-forward block: `hidden_size -> intermediate_size -> hidden_size` with an activation between."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation is looked up by name from the library's activation registry.
        self.activation_fn = ACT2FN[config.hidden_act]
        hidden, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(hidden, inner)
        self.fc2 = nn.Linear(inner, hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the activation, then `fc2` to `hidden_states`."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->OwlViT
|
||
|
class OwlViTEncoderLayer(nn.Module):
    """Pre-norm transformer layer: LayerNorm -> self-attention -> residual, then LayerNorm -> MLP -> residual."""

    def __init__(self, config: OwlViTConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = OwlViTAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = OwlViTMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): second mask, passed through unchanged to the
                self-attention module alongside `attention_mask`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        # Attention sub-block; the residual uses the un-normalized input.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with its own residual connection.
        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        hidden_states = hidden_states + mlp_out

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
||
|
|
||
|
|
||
|
class OwlViTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = OwlViTConfig
    base_model_prefix = "owlvit"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OwlViTEncoderLayer"]

    def _init_weights(self, module):
        """Initialize the weights"""
        # Global multiplier applied to every standard deviation below.
        factor = self.config.initializer_factor

        if isinstance(module, OwlViTTextEmbeddings):
            embed_std = factor * 0.02
            module.token_embedding.weight.data.normal_(mean=0.0, std=embed_std)
            module.position_embedding.weight.data.normal_(mean=0.0, std=embed_std)
        elif isinstance(module, OwlViTVisionEmbeddings):
            range_std = module.config.initializer_range * factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=range_std)
            nn.init.normal_(module.position_embedding.weight, std=range_std)
        elif isinstance(module, OwlViTAttention):
            # q/k/v projections are additionally scaled down by the depth of the stack.
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            for in_proj in (module.q_proj, module.k_proj, module.v_proj):
                nn.init.normal_(in_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, OwlViTMLP):
            hidden_size = module.config.hidden_size
            in_proj_std = (hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, OwlViTModel):
            nn.init.normal_(module.text_projection.weight, std=module.text_embed_dim**-0.5 * factor)
            nn.init.normal_(module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * factor)

        # These two checks are independent of the chain above and apply to any matching module.
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
|
||
|
|
||
|
|
||
|
OWLVIT_START_DOCSTRING = r"""
|
||
|
|
||
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||
|
etc.)
|
||
|
|
||
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||
|
and behavior.
|
||
|
|
||
|
Parameters:
|
||
|
config ([`OwlViTConfig`]): Model configuration class with all the parameters of the model.
|
||
|
Initializing with a config file does not load the weights associated with the model, only the
|
||
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||
|
"""
|
||
|
|
||
|
OWLVIT_TEXT_INPUTS_DOCSTRING = r"""
|
||
|
Args:
|
||
|
input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
|
||
|
Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
|
||
|
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
|
||
|
IDs?](../glossary#input-ids)
|
||
|
attention_mask (`torch.Tensor` of shape `(batch_size, num_max_text_queries, sequence_length)`, *optional*):
|
||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||
|
- 1 for tokens that are **not masked**,
|
||
|
- 0 for tokens that are **masked**.
|
||
|
[What are attention masks?](../glossary#attention-mask)
|
||
|
output_attentions (`bool`, *optional*):
|
||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||
|
tensors for more detail.
|
||
|
output_hidden_states (`bool`, *optional*):
|
||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||
|
more detail.
|
||
|
return_dict (`bool`, *optional*):
|
||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||
|
"""
|
||
|
|
||
|
OWLVIT_VISION_INPUTS_DOCSTRING = r"""
|
||
|
Args:
|
||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||
|
Pixel values.
|
||
|
output_attentions (`bool`, *optional*):
|
||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||
|
tensors for more detail.
|
||
|
output_hidden_states (`bool`, *optional*):
|
||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||
|
more detail.
|
||
|
return_dict (`bool`, *optional*):
|
||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||
|
"""
|
||
|
|
||
|
OWLVIT_INPUTS_DOCSTRING = r"""
|
||
|
Args:
|
||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||
|
Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
|
||
|
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
|
||
|
IDs?](../glossary#input-ids)
|
||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||
|
- 1 for tokens that are **not masked**,
|
||
|
- 0 for tokens that are **masked**.
|
||
|
[What are attention masks?](../glossary#attention-mask)
|
||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||
|
Pixel values.
|
||
|
return_loss (`bool`, *optional*):
|
||
|
Whether or not to return the contrastive loss.
|
||
|
output_attentions (`bool`, *optional*):
|
||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||
|
tensors for more detail.
|
||
|
output_hidden_states (`bool`, *optional*):
|
||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||
|
more detail.
|
||
|
return_dict (`bool`, *optional*):
|
||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||
|
"""
|
||
|
|
||
|
OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING = r"""
|
||
|
Args:
|
||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||
|
Pixel values.
|
||
|
input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*):
|
||
|
Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
|
||
|
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
|
||
|
IDs?](../glossary#input-ids).
|
||
|
attention_mask (`torch.Tensor` of shape `(batch_size, num_max_text_queries, sequence_length)`, *optional*):
|
||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||
|
- 1 for tokens that are **not masked**,
|
||
|
- 0 for tokens that are **masked**.
|
||
|
[What are attention masks?](../glossary#attention-mask)
|
||
|
output_hidden_states (`bool`, *optional*):
|
||
|
Whether or not to return the last hidden state. See `text_model_last_hidden_state` and
|
||
|
`vision_model_last_hidden_state` under returned tensors for more detail.
|
||
|
return_dict (`bool`, *optional*):
|
||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||
|
"""
|
||
|
|
||
|
OWLVIT_IMAGE_GUIDED_OBJECT_DETECTION_INPUTS_DOCSTRING = r"""
|
||
|
Args:
|
||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||
|
Pixel values.
|
||
|
query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||
|
Pixel values of query image(s) to be detected. Pass in one query image per target image.
|
||
|
output_attentions (`bool`, *optional*):
|
||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||
|
tensors for more detail.
|
||
|
output_hidden_states (`bool`, *optional*):
|
||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||
|
more detail.
|
||
|
return_dict (`bool`, *optional*):
|
||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||
|
"""
|
||
|
|
||
|
|
||
|
class OwlViTEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`OwlViTEncoderLayer`].

    Args:
        config: OwlViTConfig
    """

    def __init__(self, config: OwlViTConfig):
        super().__init__()
        # `forward` falls back to `self.config` whenever an output flag is left as `None`, so the config has to be
        # stored on the module; without it a call such as `encoder(inputs_embeds)` raises `AttributeError`.
        self.config = config
        self.layers = nn.ModuleList([OwlViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`).
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            # Record the input of every layer; the output of the last layer is appended after the loop.
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # NOTE(review): `_gradient_checkpointing_func` is not defined in this class; presumably it is
                # attached when gradient checkpointing is enabled on the parent model - confirm.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            # Entries that were not requested are dropped from the tuple rather than returned as `None`.
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
|
||
|
|
||
|
|
||
|
class OwlViTTextTransformer(nn.Module):
    """
    Text tower: token/position embeddings, a causally masked [`OwlViTEncoder`] and a final layer norm. The pooled
    output is the hidden state at the position of the largest token id of each sequence.
    """

    def __init__(self, config: OwlViTTextConfig):
        super().__init__()
        self.config = config
        self.embeddings = OwlViTTextEmbeddings(config)
        self.encoder = OwlViTEncoder(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTTextConfig)
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
        """
        # Unset flags fall back to the text config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The shape is captured before flattening and is the one handed to the causal-mask helper below.
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # num_samples, seq_len = input_shape where num_samples = batch_size * num_max_text_queries
        # OWLVIT's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _create_4d_causal_attention_mask(
            input_shape, hidden_states.dtype, device=hidden_states.device
        )
        if attention_mask is not None:
            # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = self.final_layer_norm(encoder_outputs[0])

        # take features from the end of tokens embedding (end of token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        sample_positions = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
        eos_positions = input_ids.to(torch.int).argmax(dim=-1).to(last_hidden_state.device)
        pooled_output = last_hidden_state[sample_positions, eos_positions]

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||
|
|
||
|
|
||
|
class OwlViTTextModel(OwlViTPreTrainedModel):
    """Stand-alone wrapper exposing the OWL-ViT text tower ([`OwlViTTextTransformer`]) as a pretrained model."""

    config_class = OwlViTTextConfig

    def __init__(self, config: OwlViTTextConfig):
        super().__init__(config)
        self.text_model = OwlViTTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module of the wrapped text tower."""
        text_embeddings = self.text_model.embeddings
        return text_embeddings.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped text tower with `value`."""
        text_embeddings = self.text_model.embeddings
        text_embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTTextConfig)
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:
        ```python
        >>> from transformers import AutoProcessor, OwlViTTextModel

        >>> model = OwlViTTextModel.from_pretrained("google/owlvit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
        >>> inputs = processor(
        ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
        ... )
        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""

        # Get embeddings for all text queries in all batch samples
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return text_outputs
|
||
|
|
||
|
|
||
|
class OwlViTVisionTransformer(nn.Module):
    """
    Vision tower: patch embeddings, a pre-encoder layer norm, an [`OwlViTEncoder`] and a post layer norm that is
    applied to the first token of the last hidden state to build the pooled output.
    """

    def __init__(self, config: OwlViTVisionConfig):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.embeddings = OwlViTVisionEmbeddings(config)
        self.pre_layernorm = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.encoder = OwlViTEncoder(config)
        self.post_layernorm = nn.LayerNorm(hidden_size, eps=norm_eps)

    @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTVisionConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
        """
        # Unset flags fall back to the vision config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Cast the input to the expected `dtype`
        patch_weight_dtype = self.embeddings.patch_embedding.weight.dtype
        pixel_values = pixel_values.to(patch_weight_dtype)

        hidden_states = self.pre_layernorm(self.embeddings(pixel_values))

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The returned last hidden state is NOT layer-normed; only the first-token pooled output is.
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||
|
|
||
|
|
||
|
class OwlViTVisionModel(OwlViTPreTrainedModel):
    """Stand-alone wrapper exposing the OWL-ViT vision tower ([`OwlViTVisionTransformer`]) as a pretrained model."""

    config_class = OwlViTVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: OwlViTVisionConfig):
        super().__init__(config)
        self.vision_model = OwlViTVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch embedding module of the wrapped vision tower."""
        vision_embeddings = self.vision_model.embeddings
        return vision_embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, OwlViTVisionModel

        >>> model = OwlViTVisionModel.from_pretrained("google/owlvit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        # Pure delegation: flag resolution happens inside the vision tower.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return vision_outputs
|
||
|
|
||
|
|
||
|
@add_start_docstrings(OWLVIT_START_DOCSTRING)
class OwlViTModel(OwlViTPreTrainedModel):
    # No class docstring on purpose: the decorator above builds `__doc__` from OWLVIT_START_DOCSTRING.
    config_class = OwlViTConfig

    def __init__(self, config: OwlViTConfig):
        super().__init__(config)

        text_config = config.text_config
        if not isinstance(text_config, OwlViTTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type OwlViTTextConfig but is of type"
                f" {type(text_config)}."
            )

        vision_config = config.vision_config
        if not isinstance(vision_config, OwlViTVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type OwlViTVisionConfig but is of type"
                f" {type(vision_config)}."
            )

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        # The two towers, followed by bias-free projections into the shared embedding space.
        self.text_model = OwlViTTextTransformer(text_config)
        self.vision_model = OwlViTVisionTransformer(vision_config)

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        # Stored in log space; `forward` exponentiates it before scaling the similarity logits.
        self.logit_scale = nn.Parameter(torch.tensor(config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`OwlViTTextModel`].

        Examples:
        ```python
        >>> from transformers import AutoProcessor, OwlViTModel

        >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
        >>> inputs = processor(
        ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
        ... )
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Get embeddings for all text queries in all batch samples
        text_output = self.text_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=return_dict)
        # Index 1 is the pooled output for both the tuple and the ModelOutput form.
        return self.text_projection(text_output[1])

    @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`OwlViTVisionModel`].

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, OwlViTModel

        >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(images=image, return_tensors="pt")
        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output for both the tuple and the ModelOutput form.
        return self.visual_projection(vision_outputs[1])

    @add_start_docstrings_to_model_forward(OWLVIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=OwlViTOutput, config_class=OwlViTConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_base_image_embeds: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, OwlViTOutput]:
        r"""
        Returns:

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, OwlViTModel

        >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Get embeddings for all text queries in all batch samples
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project both pooled outputs (index 1) into the shared embedding space.
        text_embeds = self.text_projection(text_outputs[1])
        image_embeds = self.visual_projection(vision_outputs[1])

        # normalized features
        image_embeds = image_embeds / torch.linalg.norm(image_embeds, ord=2, dim=-1, keepdim=True)
        text_embeds_norm = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)

        # cosine similarity as logits and set it on the correct device
        logit_scale = self.logit_scale.exp().to(image_embeds.device)

        logits_per_text = torch.matmul(text_embeds_norm, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = owlvit_loss(logits_per_text) if return_loss else None

        if return_base_image_embeds:
            warnings.warn(
                "`return_base_image_embeds` is deprecated and will be removed in v4.27 of Transformers, one can"
                " obtain the base (unprojected) image embeddings from outputs.vision_model_output.",
                FutureWarning,
            )
            # Deprecated path: return layer-normed per-token vision states and leave the text embeds unnormalized.
            image_embeds = self.vision_model.post_layernorm(vision_outputs[0])
        else:
            text_embeds = text_embeds_norm

        if return_dict:
            return OwlViTOutput(
                loss=loss,
                logits_per_image=logits_per_image,
                logits_per_text=logits_per_text,
                text_embeds=text_embeds,
                image_embeds=image_embeds,
                text_model_output=text_outputs,
                vision_model_output=vision_outputs,
            )

        output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
        if loss is None:
            return output
        return (loss,) + output
|
||
|
|
||
|
|
||
|
class OwlViTBoxPredictionHead(nn.Module):
    """
    Three-layer MLP (GELU after the first two layers) mapping image features of width
    `config.vision_config.hidden_size` to `out_dim` raw box values per feature vector.
    """

    def __init__(self, config: OwlViTConfig, out_dim: int = 4):
        super().__init__()

        hidden_width = config.vision_config.hidden_size
        self.dense0 = nn.Linear(hidden_width, hidden_width)
        self.dense1 = nn.Linear(hidden_width, hidden_width)
        self.gelu = nn.GELU()
        self.dense2 = nn.Linear(hidden_width, out_dim)

    def forward(self, image_features: torch.Tensor) -> torch.FloatTensor:
        """Apply `dense0 -> GELU -> dense1 -> GELU -> dense2` to `image_features` and return the result."""
        hidden = self.gelu(self.dense0(image_features))
        hidden = self.gelu(self.dense1(hidden))
        return self.dense2(hidden)
|
||
|
|
||
|
|
||
|
class OwlViTClassPredictionHead(nn.Module):
    """
    Scores image patch embeddings against query embeddings.

    Image features are projected to the text hidden size, both sides are L2-normalized, and their dot products are
    shifted and scaled by per-patch values predicted from the un-projected image features.
    """

    def __init__(self, config: OwlViTConfig):
        super().__init__()

        out_dim = config.text_config.hidden_size
        self.query_dim = config.vision_config.hidden_size

        self.dense0 = nn.Linear(self.query_dim, out_dim)
        self.logit_shift = nn.Linear(self.query_dim, 1)
        self.logit_scale = nn.Linear(self.query_dim, 1)
        self.elu = nn.ELU()

    def forward(
        self,
        image_embeds: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor],
        query_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            image_embeds:
                Image features whose last dimension is `config.vision_config.hidden_size`; the first two dimensions
                are read as `(batch_size, num_patches)` when no queries are given.
            query_embeds:
                Query embeddings compared against the projected image features over the last dimension. When `None`,
                all-zero logits of shape `(batch_size, num_patches, query_dim)` are returned.
            query_mask:
                Optional mask over queries; logits of queries whose mask value equals 0 are set to `-1e6`.

        Returns:
            `(pred_logits, image_class_embeds)`. `image_class_embeds` is the projection of `image_embeds`,
            L2-normalized unless `query_embeds` is `None`.
        """
        image_class_embeds = self.dense0(image_embeds)
        if query_embeds is None:
            batch_size, num_patches = image_class_embeds.shape[:2]
            # Allocate directly on the target device instead of creating the tensor on CPU and copying it over.
            pred_logits = torch.zeros((batch_size, num_patches, self.query_dim), device=image_class_embeds.device)
            return (pred_logits, image_class_embeds)

        # Normalize image and text features
        image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6)
        query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)

        # Get class predictions
        pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)

        # Apply a learnable shift and scale to logits
        logit_shift = self.logit_shift(image_embeds)
        logit_scale = self.logit_scale(image_embeds)
        # ELU(x) + 1 is strictly positive, so the scale never flips the sign of the shifted logits.
        logit_scale = self.elu(logit_scale) + 1
        pred_logits = (pred_logits + logit_shift) * logit_scale

        if query_mask is not None:
            if query_mask.ndim > 1:
                # Insert a patch axis so the per-query mask broadcasts over all patches.
                query_mask = torch.unsqueeze(query_mask, dim=-2)

            # The masked logits are always returned as float32. The fill value is a 0-dim float32 tensor, which
            # gives the same values as a float64 round trip without the extra memory or the need for float64
            # support on the device.
            pred_logits = pred_logits.to(torch.float32)
            masked_value = torch.full((), -1e6, dtype=torch.float32, device=pred_logits.device)
            pred_logits = torch.where(query_mask == 0, masked_value, pred_logits)

        return (pred_logits, image_class_embeds)
|
||
|
|
||
|
|
||
|
class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
||
|
config_class = OwlViTConfig
|
||
|
|
||
|
def __init__(self, config: OwlViTConfig):
    """Build the OWL-ViT backbone, the class and box heads, and precompute the per-patch box bias."""
    super().__init__(config)

    vision_config = config.vision_config

    self.owlvit = OwlViTModel(config)
    self.class_head = OwlViTClassPredictionHead(config)
    self.box_head = OwlViTBoxPredictionHead(config)

    self.layer_norm = nn.LayerNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)
    self.sigmoid = nn.Sigmoid()

    # Number of patches along one side of the image grid (integer division of image size by patch size).
    self.sqrt_num_patches = vision_config.image_size // vision_config.patch_size
    # Plain tensor attribute, not a registered buffer; `box_predictor` moves it to the right device on use.
    self.box_bias = self.compute_box_bias(self.sqrt_num_patches)
|
||
|
|
||
|
@staticmethod
|
||
|
def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor:
|
||
|
# Create grid coordinates using torch
|
||
|
x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
|
||
|
y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
|
||
|
xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
|
||
|
|
||
|
# Stack the coordinates and divide by num_patches
|
||
|
box_coordinates = torch.stack((xx, yy), dim=-1)
|
||
|
box_coordinates /= num_patches
|
||
|
|
||
|
# Flatten (h, w, 2) -> (h*w, 2)
|
||
|
box_coordinates = box_coordinates.view(-1, 2)
|
||
|
|
||
|
return box_coordinates
|
||
|
|
||
|
def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor:
    """
    Compute the additive bias applied to raw box predictions before the sigmoid.

    Not memoized: an `lru_cache` on an instance method stores `self` in its keys and so keeps whole model
    instances alive for the lifetime of the cache.

    Args:
        num_patches:
            Number of patches along one side of the square feature grid.
        feature_map:
            Deprecated; must be `None`.

    Returns:
        Tensor of shape `(num_patches * num_patches, 4)`: two columns of position bias followed by two columns of
        size bias.

    Raises:
        ValueError: If `feature_map` is passed.
    """
    if feature_map is not None:
        raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead")

    # The box center is biased to its position on the feature grid
    box_coordinates = self.normalize_grid_corner_coordinates(num_patches)
    box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)

    # Unnormalize xy: a logit-style transform; the 1e-4 offsets keep both logs finite at the borders.
    box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)

    # The box size is biased to the patch size
    box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
    box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)

    # Compute box bias
    return torch.cat([box_coord_bias, box_size_bias], dim=-1)
|
||
|
|
||
|
def box_predictor(
    self,
    image_feats: torch.FloatTensor,
    feature_map: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Args:
        image_feats:
            Features extracted from the image, returned by the `image_text_embedder` method.
        feature_map:
            A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method. Only its
            device is used here.
    Returns:
        pred_boxes:
            Tensor of predicted boxes after the sigmoid, i.e. with values in (0, 1).
    """
    # Bounding box detection head [batch_size, num_boxes, 4].
    raw_boxes = self.box_head(image_feats)

    # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
    grid_bias = self.box_bias.to(feature_map.device)
    # In-place add, so the head output keeps its own dtype.
    raw_boxes += grid_bias
    return self.sigmoid(raw_boxes)
|
||
|
|
||
|
def class_predictor(
    self,
    image_feats: torch.FloatTensor,
    query_embeds: Optional[torch.FloatTensor] = None,
    query_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor]:
    """
    Args:
        image_feats:
            Features extracted from the `image_text_embedder`.
        query_embeds:
            Text query embeddings.
        query_mask:
            Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
    Returns:
        The `(pred_logits, image_class_embeds)` pair produced by the class head.
    """
    # Thin wrapper: the class head does all the work.
    logits, class_embeds = self.class_head(image_feats, query_embeds, query_mask)
    return (logits, class_embeds)
|
||
|
|
||
|
def image_text_embedder(
    self,
    input_ids: torch.Tensor,
    pixel_values: torch.FloatTensor,
    attention_mask: torch.Tensor,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
) -> Tuple[torch.FloatTensor]:
    """
    Run the OWL-ViT backbone on text and image and build the spatial image feature map.

    Returns:
        `(text_embeds, image_embeds, outputs)` where `image_embeds` has shape
        `(batch_size, sqrt_num_patches, sqrt_num_patches, hidden_size)` and `outputs` is the raw backbone output.
    """
    # Encode text and image
    outputs = self.owlvit(
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )

    # Get image embeddings: layer-norm every token of the vision tower's last hidden state.
    vision_hidden_state = outputs.vision_model_output[0]
    normed_tokens = self.owlvit.vision_model.post_layernorm(vision_hidden_state)

    # Split off the class token (first position) and resize it to the shape of the patch tokens.
    patch_tokens = normed_tokens[:, 1:, :]
    class_token_out = torch.broadcast_to(normed_tokens[:, :1, :], normed_tokens[:, :-1].shape)

    # Merge image embedding with class tokens
    image_embeds = self.layer_norm(patch_tokens * class_token_out)

    # Resize to [batch_size, num_patches, num_patches, hidden_size]
    grid_side = self.sqrt_num_patches
    image_embeds = image_embeds.reshape((image_embeds.shape[0], grid_side, grid_side, image_embeds.shape[-1]))

    # Fourth entry from the end of the backbone output is the text embeddings.
    text_embeds = outputs[-4]

    return (text_embeds, image_embeds, outputs)
|
||
|
|
||
|
def image_embedder(
    self,
    pixel_values: torch.FloatTensor,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
) -> Tuple[torch.FloatTensor]:
    """
    Run only the vision tower and build the spatial image feature map.

    Args:
        pixel_values:
            Input images, passed unchanged to the vision tower.
        output_attentions:
            Forwarded to the vision tower; when `None` the tower applies its own default.
        output_hidden_states:
            Forwarded to the vision tower; when `None` the tower applies its own default.

    Returns:
        `(image_embeds, vision_outputs)` where `image_embeds` has shape
        `(batch_size, sqrt_num_patches, sqrt_num_patches, hidden_size)` and `vision_outputs` is the raw output of
        the vision tower.
    """
    # Get OwlViTModel vision embeddings (same as CLIP). The output flags must be forwarded, otherwise a caller
    # asking for attentions or hidden states silently gets none in `vision_outputs`.
    vision_outputs = self.owlvit.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )

    # Apply post_layernorm to last_hidden_state, return non-projected output
    last_hidden_state = vision_outputs[0]
    image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)

    # Resize class token
    class_token_out = torch.broadcast_to(image_embeds[:, :1, :], image_embeds[:, :-1].shape)

    # Merge image embedding with class tokens
    image_embeds = image_embeds[:, 1:, :] * class_token_out
    image_embeds = self.layer_norm(image_embeds)

    # Resize to [batch_size, num_patches, num_patches, hidden_size]
    new_size = (
        image_embeds.shape[0],
        self.sqrt_num_patches,
        self.sqrt_num_patches,
        image_embeds.shape[-1],
    )
    image_embeds = image_embeds.reshape(new_size)

    return (image_embeds, vision_outputs)
|
||
|
|
||
|
def embed_image_query(
    self, query_image_features: torch.FloatTensor, query_feature_map: torch.FloatTensor
) -> Tuple[Optional[torch.FloatTensor], Optional[torch.Tensor], torch.FloatTensor]:
    """
    Pick, for each query image, the class embedding of one predicted box to use as the query.

    For every query image the predicted boxes are compared against the reference box
    `[0, 0, 1, 1]`. Boxes whose IoU reaches 80% of the best IoU are candidates, and among
    them the one whose class embedding has the lowest dot product with the mean class
    embedding of that image is selected.

    Args:
        query_image_features: Query image features, passed to `self.class_predictor` and
            `self.box_predictor`; its first dimension is iterated as the query batch.
        query_feature_map: Query feature map, passed to `self.box_predictor`.

    Returns:
        A tuple `(query_embeds, box_indices, pred_boxes)`. `query_embeds` and `box_indices`
        are the stacked per-query selections, or both `None` when no query produced a
        candidate. `pred_boxes` is the output of `self.box_predictor` (before the
        conversion to corner format used internally).
    """
    _, class_embeds = self.class_predictor(query_image_features)
    pred_boxes = self.box_predictor(query_image_features, query_feature_map)
    pred_boxes_as_corners = center_to_corners_format(pred_boxes)

    # The reference box is identical for every query image, so allocate it once
    # on the boxes' device instead of once per loop iteration.
    each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_as_corners.device)

    # Loop over query images
    best_class_embeds = []
    best_box_indices = []

    for i in range(query_image_features.shape[0]):
        each_query_pred_boxes = pred_boxes_as_corners[i]
        each_query_class_embeds = class_embeds[i]
        ious, _ = box_iou(each_query_box, each_query_pred_boxes)

        # If there are no overlapping boxes, fall back to generalized IoU
        if torch.all(ious[0] == 0.0):
            ious = generalized_box_iou(each_query_box, each_query_pred_boxes)

        # Use an adaptive threshold to include all boxes within 80% of the best IoU
        # NOTE(review): if the fallback yields only negative values, `max * 0.8` is larger
        # than the max itself, nothing passes the threshold and this query is skipped, so
        # the stacked outputs end up with fewer rows than query images. Confirm whether
        # that is intended.
        iou_threshold = torch.max(ious) * 0.8

        selected_inds = (ious[0] >= iou_threshold).nonzero()
        if selected_inds.numel():
            selected_embeddings = each_query_class_embeds[selected_inds.squeeze(1)]
            mean_embeds = torch.mean(each_query_class_embeds, dim=0)
            mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings)
            # Keep the candidate that is least similar to the image's mean embedding
            best_box_ind = selected_inds[torch.argmin(mean_sim)]
            best_class_embeds.append(each_query_class_embeds[best_box_ind])
            best_box_indices.append(best_box_ind)

    if best_class_embeds:
        query_embeds = torch.stack(best_class_embeds)
        box_indices = torch.stack(best_box_indices)
    else:
        query_embeds, box_indices = None, None

    return query_embeds, box_indices, pred_boxes
|
||
|
|
||
|
@add_start_docstrings_to_model_forward(OWLVIT_IMAGE_GUIDED_OBJECT_DETECTION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OwlViTImageGuidedObjectDetectionOutput, config_class=OwlViTConfig)
def image_guided_detection(
    self,
    pixel_values: torch.FloatTensor,
    query_pixel_values: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> OwlViTImageGuidedObjectDetectionOutput:
    r"""
    Returns:

    Examples:
    ```python
    >>> import requests
    >>> from PIL import Image
    >>> import torch
    >>> from transformers import AutoProcessor, OwlViTForObjectDetection

    >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch16")
    >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)
    >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg"
    >>> query_image = Image.open(requests.get(query_url, stream=True).raw)
    >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt")
    >>> with torch.no_grad():
    ...     outputs = model.image_guided_detection(**inputs)
    >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
    >>> target_sizes = torch.Tensor([image.size[::-1]])
    >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
    >>> results = processor.post_process_image_guided_detection(
    ...     outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
    ... )
    >>> i = 0  # Retrieve predictions for the first image
    >>> boxes, scores = results[i]["boxes"], results[i]["scores"]
    >>> for box, score in zip(boxes, scores):
    ...     box = [round(i, 2) for i in box.tolist()]
    ...     print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
    Detected similar object with confidence 0.856 at location [10.94, 50.4, 315.8, 471.39]
    Detected similar object with confidence 1.0 at location [334.84, 25.33, 636.16, 374.71]
    ```"""
    # Any flag the caller left unset falls back to the model config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    # Feature map of the query images first (only the map is kept, no output flags passed),
    # then feature map plus vision outputs of the target images.
    query_feature_map = self.image_embedder(pixel_values=query_pixel_values)[0]
    feature_map, vision_outputs = self.image_embedder(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    # Flatten the target patch grid into a sequence of patch features.
    target_batch, _, target_side, target_dim = feature_map.shape
    image_feats = feature_map.reshape(target_batch, target_side * target_side, target_dim)

    # Same flattening for the query patch grid.
    query_batch, _, query_side, query_dim = query_feature_map.shape
    query_image_feats = query_feature_map.reshape(query_batch, query_side * query_side, query_dim)

    # Top class embedding per query image; the best box indices are not needed here.
    query_embeds, _, query_pred_boxes = self.embed_image_query(query_image_feats, query_feature_map)

    # Predict object classes [batch_size, num_patches, num_queries+1]
    pred_logits, class_embeds = self.class_predictor(image_feats=image_feats, query_embeds=query_embeds)

    # Predict object boxes
    target_pred_boxes = self.box_predictor(image_feats, feature_map)

    if return_dict:
        return OwlViTImageGuidedObjectDetectionOutput(
            image_embeds=feature_map,
            query_image_embeds=query_feature_map,
            target_pred_boxes=target_pred_boxes,
            query_pred_boxes=query_pred_boxes,
            logits=pred_logits,
            class_embeds=class_embeds,
            text_model_output=None,
            vision_model_output=vision_outputs,
        )

    # Tuple output: same fields in a fixed order, with `None` entries dropped.
    fields = (
        feature_map,
        query_feature_map,
        target_pred_boxes,
        query_pred_boxes,
        pred_logits,
        class_embeds,
        vision_outputs.to_tuple(),
    )
    return tuple(field for field in fields if field is not None)
|
||
|
|
||
|
@add_start_docstrings_to_model_forward(OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig)
def forward(
    self,
    input_ids: torch.Tensor,
    pixel_values: torch.FloatTensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> OwlViTObjectDetectionOutput:
    r"""
    Returns:

    Examples:
    ```python
    >>> import requests
    >>> from PIL import Image
    >>> import torch
    >>> from transformers import AutoProcessor, OwlViTForObjectDetection

    >>> processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
    >>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)
    >>> texts = [["a photo of a cat", "a photo of a dog"]]
    >>> inputs = processor(text=texts, images=image, return_tensors="pt")
    >>> outputs = model(**inputs)

    >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
    >>> target_sizes = torch.Tensor([image.size[::-1]])
    >>> # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
    >>> results = processor.post_process_object_detection(
    ...     outputs=outputs, threshold=0.1, target_sizes=target_sizes
    ... )

    >>> i = 0  # Retrieve predictions for the first image for the corresponding text queries
    >>> text = texts[i]
    >>> boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]

    >>> for box, score, label in zip(boxes, scores, labels):
    ...     box = [round(i, 2) for i in box.tolist()]
    ...     print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
    Detected a photo of a cat with confidence 0.707 at location [324.97, 20.44, 640.58, 373.29]
    Detected a photo of a cat with confidence 0.717 at location [1.46, 55.26, 315.55, 472.17]
    ```"""
    # Any flag the caller left unset falls back to the model config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    # Embed images and text queries
    query_embeds, feature_map, outputs = self.image_text_embedder(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    # Text and vision model outputs
    text_outputs = outputs.text_model_output
    vision_outputs = outputs.vision_model_output

    # Flatten the patch grid into a sequence of patch features.
    batch_size, _, grid_side, hidden_dim = feature_map.shape
    image_feats = feature_map.reshape(batch_size, grid_side * grid_side, hidden_dim)

    # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
    max_text_queries = input_ids.shape[0] // batch_size
    query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])

    # If first token is 0, then this is a padded query [batch_size, num_queries].
    first_tokens = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])[..., 0]
    query_mask = first_tokens > 0

    # Predict object classes [batch_size, num_patches, num_queries+1]
    pred_logits, class_embeds = self.class_predictor(image_feats, query_embeds, query_mask)

    # Predict object boxes
    pred_boxes = self.box_predictor(image_feats, feature_map)

    if return_dict:
        return OwlViTObjectDetectionOutput(
            image_embeds=feature_map,
            text_embeds=query_embeds,
            pred_boxes=pred_boxes,
            logits=pred_logits,
            class_embeds=class_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )

    # Tuple output: same fields in a fixed order, with `None` entries dropped.
    fields = (
        pred_logits,
        pred_boxes,
        query_embeds,
        feature_map,
        class_embeds,
        text_outputs.to_tuple(),
        vision_outputs.to_tuple(),
    )
    return tuple(field for field in fields if field is not None)
|