# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import collections
import functools
import inspect
import math
import operator
import os
import random
import warnings
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx._compatibility import compatibility
from torch.fx.proxy import ParameterProxy

from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_BACKBONE_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
    MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_IMAGE_MAPPING_NAMES,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
from ..utils import (
    ENV_VARS_TRUE_VALUES,
    TORCH_FX_REQUIRED_VERSION,
    get_torch_version,
    is_peft_available,
    is_torch_fx_available,
)


if is_peft_available():
    from peft import PeftModel


logger = logging.get_logger(__name__)
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES


def _generate_supported_model_class_names(
    model_name: Type[PretrainedConfig],
    supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[str]:
    task_mapping = {
        "default": MODEL_MAPPING_NAMES,
        "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
        "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
        "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
        "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
        "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
        "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
        "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
        "document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
        "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
        "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
        "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
        "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
        "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
        "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
        "semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
        "backbone": MODEL_FOR_BACKBONE_MAPPING_NAMES,
        "image-feature-extraction": MODEL_FOR_IMAGE_MAPPING_NAMES,
    }

    if supported_tasks is None:
        supported_tasks = task_mapping.keys()
    if isinstance(supported_tasks, str):
        supported_tasks = [supported_tasks]

    model_class_names = []
    for task in supported_tasks:
        class_name = task_mapping[task].get(model_name, None)
        if class_name:
            model_class_names.append(class_name)

    return model_class_names


_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
    "altclip",
    "albert",
    "bart",
    "bert",
    "blenderbot",
    "blenderbot-small",
    "bloom",
    "clip",
    "convnext",
    "deberta",
    "deberta-v2",
    "dinov2",
    "distilbert",
    "donut-swin",
    "electra",
    "gpt2",
    "gpt_neo",
    "gptj",
    "hubert",
    "layoutlm",
    "llama",
    "cohere",
    "lxmert",
    "m2m_100",
    "marian",
    "mbart",
    "megatron-bert",
    "mistral",
    "mixtral",
    "mobilebert",
    "mt5",
    "nezha",
    "opt",
    "pegasus",
    "plbart",
    "qwen2",
    "qwen2_moe",
    "resnet",
    "roberta",
    "segformer",
    "speech_to_text",
    "speech_to_text_2",
    "swin",
    "t5",
    "trocr",
    "vit",
    "xglm",
    "wav2vec2",
    # "xlnet",
]

_FX_SUPPORTED_MODELS_WITH_KV_CACHE = ["llama", "opt"]

_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
    if isinstance(item, dict):
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
    else:
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))

_SPECIAL_SUPPORTED_MODELS = [
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "CLIPVisionModel",
    "CLIPVisionModelWithProjection",
    "AltCLIPTextModel",
    "AltCLIPVisionModel",
    "GitVisionModel",
    "GPT2DoubleHeadsModel",
    "Speech2Text2Decoder",
    "TrOCRDecoder",
    "PeftModelForCausalLM",
    "PeftModelForSeq2SeqLM",
    # TODO: add support for them as it should be quite easy to do so (small blocking issues).
    # XLNetForQuestionAnswering,
]
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
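
# Illustrative note (not part of the original file): `_generate_supported_model_class_names` expands a model type
# into the architecture class names registered for each task mapping. Assuming "bert" is registered as usual, a
# call could look like:
#
#     _generate_supported_model_class_names("bert", supported_tasks="masked-lm")
#     # -> ["BertForMaskedLM"]
#
# `_SUPPORTED_MODELS` above is simply the de-duplicated, sorted union of all such class names.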


def torch_nn_embedding(self, input):
    return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)


def torch_nn_functional_embedding(
    input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
    return torch.empty(*input.shape, weight.shape[-1], device="meta", dtype=weight.dtype)


def torch_nn_layernorm(self, input):
    return input


def torch_nn_groupnorm(self, input):
    return input


def torch_nn_linear(self, input):
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")


def torch_relu(x):
    return x


def torch_nn_relu(self, x):
    return x
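
# Illustrative sketch (not part of the original file): the overrides above never touch real data, they only
# reproduce the output *shape* on the "meta" device. For example, a hypothetical check could look like:
#
#     layer = torch.nn.Linear(4, 8)
#     out = torch_nn_linear(layer, torch.empty(2, 3, 4, device="meta"))
#     assert out.shape == (2, 3, 8) and out.device.type == "meta"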


def torch_nn_functional_relu(x, inplace=False):
    if inplace:
        raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
    return x


def torch_where(condition, x, y):
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")


def torch_abs(input, *, out=None):
    if out is not None:
        raise ValueError("Don't support in-place abs for MetaTensor analysis")
    return input


def torch_arange(*args, **kwargs):
    n = len(args)
    step = 1
    if n == 1:
        start = 0
        end = args[0]
    elif n == 2:
        start, end = args
    else:
        start, end, step = args
    if isinstance(start, float):
        start = int(start)
    if isinstance(end, float):
        end = int(end)
    if isinstance(step, float):
        step = int(step)
    step = kwargs.get("step", step)
    dtype = kwargs.get("dtype")
    return torch.empty((end - start) // step, dtype=dtype, device="meta")


def torch_full(*args, **kwargs):
    args = list(args)
    # We set the fill value to 1 as its value is not important as long as it's not a tensor on the `meta` device.
    if len(args) > 1:
        args[1] = 1
    else:
        kwargs["fill_value"] = 1
    kwargs_without_device = dict(kwargs)
    kwargs_without_device.pop("device", None)
    return torch.full(*args, **kwargs_without_device, device="meta")


def torch_cat(tensors, dim=None, axis=None, *, out=None):
    if dim is None and axis is None:
        dim = 0
    if dim is None and axis is not None:
        dim = axis
    if dim < 0:
        dim = tensors[0].dim() + dim
    shapes = [t.shape for t in tensors]
    shape = list(shapes[0])
    concatenated_dim = sum(shape[dim] for shape in shapes)
    final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
    return torch.empty(final_shape, device="meta")


def torch_stack(tensors, dim=None, axis=None, *, out=None):
    if dim is None and axis is None:
        dim = 0
    if dim is None and axis is not None:
        dim = axis
    if dim < 0:
        dim = tensors[0].dim() + 1 + dim
    shape = list(tensors[0].shape)
    shape.insert(dim, len(tensors))
    return torch.empty(shape, device="meta")


def torch_add(input, other, *, alpha=1, out=None):
    if not isinstance(input, torch.Tensor):
        return torch.empty_like(other, device="meta")
    if not isinstance(other, torch.Tensor):
        return torch.empty_like(input, device="meta")
    max_length = max(input.dim(), other.dim())
    # Broadcasting aligns shapes from the trailing dimension, so pad the shorter shape on the left.
    input_shape = [1] * (max_length - input.dim()) + list(input.shape)
    other_shape = [1] * (max_length - other.dim()) + list(other.shape)
    shape = []
    for i in range(max_length):
        shape.append(max(input_shape[i], other_shape[i]))
    return torch.empty(shape, device="meta")


def torch_mul(input, other, *, out=None):
    return torch_add(input, other, out=out)


def torch_tensor_mul(self, other):
    return torch_mul(self, other)


def torch_matmul(input, other, *, out=None):
    d1 = input.dim()
    d2 = other.dim()
    shape = None
    if d1 == 1 and d2 == 1:
        shape = None
    elif d1 == 2 and d2 == 2:
        shape = (input.size(0), other.size(1))
    elif d1 == 1 and d2 == 2:
        shape = (other.size(1),)
    elif d1 == 2 and d2 == 1:
        shape = (input.size(0),)
    else:
        max_length = max(input.dim(), other.dim())
        shape1 = list(input.shape)
        shape2 = list(other.shape)
        if d1 == 1:
            shape1 = [1] + shape1
        if d2 == 1:
            shape2.append(1)
        shape1 = [-1] * (max_length - d1) + list(input.shape)
        shape2 = [-1] * (max_length - d2) + list(other.shape)
        shape = []
        for i in range(max_length):
            shape.append(max(shape1[i], shape2[i]))
        shape[-2] = shape1[-2]
        shape[-1] = shape2[-1]
        if d1 == 1:
            shape.pop(-2)
        if d2 == 1:
            shape.pop(-1)
    if shape is None:
        return torch.tensor(0.0, device="meta")
    return torch.empty(*shape, device="meta")


def torch_bmm(input, mat2, *, out=None):
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch_size, n, m = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")


def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
    if out is not None:
        raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")
    return torch_bmm(batch1, batch2)


def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
    return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)


def torch_einsum(equation, *operands):
    # TODO: infer shape without performing the computation, this might be quite hard.
    concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
    return torch.einsum(equation, *concrete_operands).to("meta")


def torch_tensor_repeat(self, *sizes):
    shape = list(self.shape)
    for i, x in enumerate(sizes):
        shape[i] *= x
    return torch.empty(shape, device="meta")


def torch_repeat_interleave(*args, dim=None, output_size=None):
    num_args = len(args)
    if num_args == 1:
        shape = [output_size if output_size is not None else args[0].sum()]
    else:
        shape = list(args[0].shape)
        if dim is None:
            if num_args > 2:
                dim = args[2]
            else:
                shape = [sum(shape)]
                dim = 0
        repeats = args[1]
        if isinstance(repeats, int) or torch.numel(repeats) == 1:
            shape[dim] *= int(repeats)
        else:
            shape[dim] = output_size if output_size is not None else repeats.sum()
    return torch.empty(*shape, device="meta")


def torch_index_select(input, dim, index, *, out=None):
    shape = list(input.shape)
    shape[dim] = len(index)
    return torch.empty(*shape, device="meta")


def torch_tensor_index_select(self, dim, index):
    return torch_index_select(self, dim, index)


def torch_gather(input, dim, index, *, sparse_grad=False, out=None):
    shape = list(input.shape)
    shape[dim] = index.shape[dim]
    return torch.empty(*shape, device="meta")


def torch_tensor_gather(self, dim, index):
    return torch_gather(self, dim, index)


def torch_roll(input, shifts, dims=None):
    return input


def torch_flip(input, dims):
    return input


def torch_tensor_flip(self, dims):
    return self


def torch_nn_conv1d(self, input):
    l_in = input.shape[-1]
    shape = None
    padding = self.padding
    if padding == "valid":
        padding = (0, 0)
    if padding == "same":
        shape = list(input.shape)
    if shape is None:
        shape = list(input.shape)
        l_out = math.floor(
            (l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
        )
        shape[-1] = l_out
    shape[-2] = self.out_channels
    return torch.empty(shape, device="meta")


def torch_nn_conv2d(self, input):
    h_in, w_in = input.shape[-2:]
    shape = None
    padding = self.padding
    if padding == "valid":
        padding = (0, 0)
    if padding == "same":
        shape = list(input.shape)
    if shape is None:
        shape = list(input.shape)
        h_out = math.floor(
            (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
        )
        w_out = math.floor(
            (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
        )
        shape[-2:] = [h_out, w_out]
    shape[-3] = self.out_channels
    return torch.empty(shape, device="meta")


def torch_squeeze(input, dim=None):
    shape = list(input.shape)
    if dim is not None:
        if dim < 0:
            dim = input.dim() + dim
        if shape[dim] == 1:
            shape.pop(dim)
    else:
        new_shape = []
        for dim_value in shape:
            if dim_value == 1:
                continue
            new_shape.append(dim_value)
        shape = new_shape
    return torch.empty(shape, device="meta")


def torch_tensor_squeeze(self, dim=None):
    return torch_squeeze(self, dim)


def torch_unsqueeze(input, dim):
    shape = list(input.shape)
    if dim < 0:
        dim = input.dim() + 1 + dim
    shape.insert(dim, 1)
    return torch.empty(shape, device="meta")


def torch_tensor_unsqueeze(self, dim):
    return torch_unsqueeze(self, dim)


def torch_unique_consecutive(input, **kwargs):
    output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
    if isinstance(output, torch.Tensor):
        return output.to("meta")
    else:
        return tuple(map(lambda x: x.to("meta"), output))


def torch_nn_functional_one_hot(tensor, num_classes=-1):
    if num_classes < 0:
        raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
    shape = list(tensor.shape) + [num_classes]
    return torch.empty(shape, device="meta")


def torch_nn_functional_scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
):
    target_length = query.shape[-2]
    head_dim = value.shape[-1]
    return torch.empty((*query.shape[:-2], target_length, head_dim), device="meta")


def torch_nn_mseloss(self, input, target):
    if self.reduction == "none":
        shape = target.shape
    else:
        shape = (1,)
    return torch.empty(shape, device="meta")


def torch_nn_crossentropyloss(self, input, target):
    if self.reduction == "none":
        shape = target.shape
    else:
        shape = (1,)
    return torch.empty(shape, device="meta")


def torch_nn_bcewithlogitsloss(self, input, target):
    if self.reduction == "none":
        shape = target.shape
    else:
        shape = (1,)
    return torch.empty(shape, device="meta")


def operator_getitem(a, b):
    def to_concrete(t):
        if isinstance(t, torch.Tensor):
            concrete = torch.ones_like(t, device="cpu")
            if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
                concrete = concrete.to(torch.int64)
            return concrete
        return t

    if isinstance(a, torch.Tensor):
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        else:
            b = to_concrete(b)
        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
    return operator.getitem(a, b)


_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
    torch.nn.Embedding: torch_nn_embedding,
    torch.nn.functional.embedding: torch_nn_functional_embedding,
    torch.nn.LayerNorm: torch_nn_layernorm,
    torch.nn.GroupNorm: torch_nn_groupnorm,
    torch.nn.Linear: torch_nn_linear,
    torch.relu: torch_relu,
    torch.nn.functional.relu: torch_nn_functional_relu,
    torch.nn.ReLU: torch_nn_relu,
    torch.where: torch_where,
    torch.abs: torch_abs,
    torch.arange: torch_arange,
    torch.full: torch_full,
    torch.cat: torch_cat,
    torch.stack: torch_stack,
    torch.add: torch_add,
    torch.mul: torch_mul,
    torch.Tensor.mul: torch_tensor_mul,
    torch.matmul: torch_matmul,
    torch.bmm: torch_bmm,
    torch.baddbmm: torch_baddbmm,
    torch.Tensor.baddbmm: torch_tensor_baddbmm,
    torch.einsum: torch_einsum,
    torch.Tensor.repeat: torch_tensor_repeat,
    torch.repeat_interleave: torch_repeat_interleave,
    torch.roll: torch_roll,
    torch.flip: torch_flip,
    torch.Tensor.flip: torch_tensor_flip,
    torch.index_select: torch_index_select,
    torch.Tensor.index_select: torch_tensor_index_select,
    torch.gather: torch_gather,
    torch.Tensor.gather: torch_tensor_gather,
    torch.nn.Conv1d: torch_nn_conv1d,
    torch.nn.Conv2d: torch_nn_conv2d,
    torch.squeeze: torch_squeeze,
    torch.Tensor.squeeze: torch_tensor_squeeze,
    torch.unsqueeze: torch_unsqueeze,
    torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
    torch.unique_consecutive: torch_unique_consecutive,
    torch.nn.functional.one_hot: torch_nn_functional_one_hot,
    torch.nn.MSELoss: torch_nn_mseloss,
    torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
    torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
    operator.getitem: operator_getitem,
}

if is_torch_greater_or_equal_than_2_0:
    _MANUAL_META_OVERRIDES[
        torch.nn.functional.scaled_dot_product_attention
    ] = torch_nn_functional_scaled_dot_product_attention
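
# Illustrative sketch (not part of the original file): new overrides follow the same convention as the entries
# above, i.e. a callable with the same signature as the wrapped op that returns a tensor on the "meta" device
# (or the input itself for shape-preserving ops). For instance, an elementwise op such as `torch.tanh` could
# hypothetically be covered with:
#
#     def torch_tanh(input, *, out=None):
#         return input
#
#     _MANUAL_META_OVERRIDES[torch.tanh] = torch_tanh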


class HFProxy(Proxy):
    """
    Proxy that uses metadata to handle data-dependent control-flow.
    """

    def install_metadata(self, metadata):
        self._metadata = metadata

    @property
    def shape(self):
        return self.tracer.create_proxy("call_method", "size", (self,), {})

    @property
    def device(self):
        # Hack so we can track when devices are used. During meta-tensor propagation,
        # replace these values with a constant 'meta'
        return MetaDeviceAttribute(self, "device")

    def __len__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return len(self._metadata)
        return super().__len__()

    def __bool__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return self._metadata
        return super().__bool__()

    def __getattr__(self, k):
        if k == "_metadata":
            return self.__getattribute__(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return HFAttribute(self, k)

    def __setitem__(self, indices, values):
        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

    def __contains__(self, key):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return key in self._metadata
        return super().__contains__(key)


class HFAttribute(HFProxy):
    def __init__(self, root, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node = None

        if hasattr(self.root, "_metadata"):
            self.install_metadata(getattr(self.root._metadata, attr))

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)


class MetaDeviceAttribute(HFAttribute):
    pass


def _proxies_to_metas(v):
    """Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
    if isinstance(v, MetaDeviceAttribute):
        return "meta"
    if isinstance(v, torch.fx.Proxy):
        if not (isinstance(v, HFProxy) and hasattr(v, "_metadata")):
            raise RuntimeError(f"No metadata was found for {v}")
        return v._metadata
    return v


def _gen_constructor_wrapper(target):
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, Proxy):
                nonlocal proxy
                proxy = v

        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target
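
# Illustrative note (not part of the original file): `_gen_constructor_wrapper` is what lets calls like
# `torch.arange(seq_length)` be recorded in the traced graph even when `seq_length` is a Proxy. The patched
# constructor inspects its arguments: if any of them is a Proxy it emits a "call_function" node, otherwise it
# falls through to the original constructor. A hypothetical direct use could look like:
#
#     wrapper, orig = _gen_constructor_wrapper(torch.arange)
#     wrapper(10)  # no Proxy involved -> behaves exactly like torch.arange(10)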


def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
    if forbidden_values is None:
        forbidden_values = []
    value = random.randint(low, high)
    while value in forbidden_values:
        value = random.randint(low, high)
    return value


class HFTracer(Tracer):
    """
    Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
    regular PyTorch torch.fx.Proxy.
    """

    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes: bool = True
    allow_insert_stateless_mods: bool = True
    _TORCH_METHODS_TO_PATCH = [
        "arange",
        "zeros",
        "ones",
        "full",
        "full_like",
        "eye",
        "empty",
        "tensor",
        "clamp",
        "finfo",
        "tril",
    ]
    supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

    def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
        super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)

        if not is_torch_fx_available():
            raise ImportError(
                f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
                f"{TORCH_FX_REQUIRED_VERSION} is supported."
            )

    def _generate_dummy_input(
        self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """Generates dummy input for model inference recording."""
        # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
        # from pickle, or from the "__class__" attribute in the general case.
        model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
        device = model.device
        inputs_dict = {}

        # when tracing a model with KV cache, we simply need to ensure that the KV cache length is larger than one to
        # rightfully pass certain controlflows (Example: https://github.com/huggingface/transformers/blob/5c8d941d66734811d2ef6f57f15b44f7fb7a98c4/src/transformers/modeling_attn_mask_utils.py#L162).
        # After tracing, the model can then still be used with arbitrary lengths different than the one used during tracing.
        kv_cache_length = 5

        if input_name in ["labels", "start_positions", "end_positions"]:
            batch_size = shape[0]
            if model_class_name in [
                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
                *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
                *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
            ]:
                inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class_name in [
                *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
                *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
                "XLNetForQuestionAnswering",
            ]:
                inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
                inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
                if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
                    raise ValueError(
                        "Could not retrieve the problem type for the sequence classification task, please set "
                        'model.config.problem_type to one of the following values: "regression", '
                        '"single_label_classification", or "multi_label_classification".'
                    )

                if model.config.problem_type == "regression":
                    labels_shape = (batch_size, model.config.num_labels)
                    labels_dtype = torch.float32
                elif model.config.problem_type == "single_label_classification":
                    labels_shape = (batch_size,)
                    labels_dtype = torch.long
                elif model.config.problem_type == "multi_label_classification":
                    labels_shape = (batch_size, model.config.num_labels)
                    labels_dtype = torch.float32
                else:
                    raise ValueError(
                        'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
                        f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
                    )
                inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)

            elif model_class_name in [
                *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES),
                "GPT2DoubleHeadsModel",
                "PeftModelForCausalLM",
                "PeftModelForSeq2SeqLM",
            ]:
                inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
            elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
                inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device)
            else:
                raise NotImplementedError(
                    f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
                )
elif "pixel_values" in input_name:
|
||
|
batch_size = shape[0]
|
||
|
image_size = getattr(model.config, "image_size", None)
|
||
|
if image_size is None:
|
||
|
if hasattr(model.config, "vision_config"):
|
||
|
image_size = model.config.vision_config.image_size
|
||
|
elif hasattr(model.config, "encoder"):
|
||
|
image_size = model.config.encoder.image_size
|
||
|
else:
|
||
|
image_size = (_generate_random_int(), _generate_random_int())
|
||
|
|
||
|
# If no num_channels is in the config, use some arbitrary value.
|
||
|
num_channels = getattr(model.config, "num_channels", 3)
|
||
|
if not isinstance(image_size, collections.abc.Iterable):
|
||
|
image_size = (image_size, image_size)
|
||
|
height, width = image_size
|
||
|
inputs_dict[input_name] = torch.zeros(
|
||
|
batch_size, num_channels, height, width, dtype=torch.float32, device=device
|
||
|
)
|
||
|
elif "bbox" in input_name:
|
||
|
inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
|
||
|
elif "input_features" in input_name:
|
||
|
inputs_dict[input_name] = torch.zeros(
|
||
|
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
|
||
|
)
|
||
|
elif "visual_feats" in input_name:
|
||
|
inputs_dict[input_name] = torch.zeros(
|
||
|
shape
|
||
|
+ [
|
||
|
model.config.visual_feat_dim,
|
||
|
],
|
||
|
dtype=torch.float,
|
||
|
device=device,
|
||
|
)
|
||
|
elif "visual_pos" in input_name:
|
||
|
inputs_dict[input_name] = torch.zeros(
|
||
|
shape
|
||
|
+ [
|
||
|
model.config.visual_pos_dim,
|
||
|
],
|
||
|
dtype=torch.float,
|
||
|
device=device,
|
||
|
)
|
||
|
elif "inputs" in input_name:
|
||
|
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
|
||
|
elif "input_values" in input_name:
|
||
|
batch_size, _ = shape
|
||
|
# Generating big sequence length for audio inputs.
|
||
|
seq_length = _generate_random_int(low=10000, high=20000)
|
||
|
inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
|
||
|
elif "mask" in input_name:
|
||
|
if "past_key_values" in input_names:
|
||
|
mask_shape = [shape[0], shape[1] + kv_cache_length]
|
||
|
else:
|
||
|
mask_shape = shape
|
||
|
|
||
|
inputs_dict[input_name] = torch.zeros(mask_shape, dtype=torch.long, device=device)
|
||
|
elif "ids" in input_name:
|
||
|
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
||
|
elif "past_key_values" in input_name:
|
||
|
if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
|
||
|
raise NotImplementedError(
|
||
|
f"Symbolic trace with past_key_values input is not supported yet for the model {model.config.model_type}. Please open an issue or a PR in Transformers repository if you would like to see the support added."
|
||
|
)
|
||
|
num_heads = model.config.num_attention_heads
|
||
|
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||
|
|
||
|
cache_shape = (shape[0], num_heads, kv_cache_length, head_dim)
|
||
|
pkv = tuple(
|
||
|
(
|
||
|
torch.rand(cache_shape, dtype=torch.float, device=device),
|
||
|
torch.rand(cache_shape, dtype=torch.float, device=device),
|
||
|
)
|
||
|
for i in range(model.config.num_hidden_layers)
|
||
|
)
|
||
|
inputs_dict[input_name] = pkv
|
||
|
else:
|
||
|
shape_with_hidden_size = shape + [model.config.hidden_size]
|
||
|
inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)
|
||
|
|
||
|
return inputs_dict
|
||
|
|
||
|
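
    # Illustrative note (not part of the original file): for a BERT-like encoder traced with
    # input_names=["input_ids", "attention_mask"], `_generate_dummy_input` would typically produce something like
    #
    #     {"input_ids": torch.zeros(batch_size, sequence_length, dtype=torch.long),
    #      "attention_mask": torch.zeros(batch_size, sequence_length, dtype=torch.long)}
    #
    # with `batch_size` and `sequence_length` drawn from `_generate_random_int()`. The tensors only need plausible
    # shapes and dtypes, since they are moved to the "meta" device before tracing.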

    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

        if kind == "placeholder" and target in self.meta_args:
            rv.install_metadata(self.meta_args[target])
            return rv

        if target in self.orig_fns:
            # NOTE: tensor constructors in PyTorch define the `device` argument as
            # *kwargs-only*. That is why this works. If you add methods to
            # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
            # this will break and you will likely see issues where we cannot infer
            # the size of the output.
            if "device" in kwargs:
                kwargs["device"] = "meta"

        try:
            args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
            kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)

            if kind == "call_function":
                meta_target = _MANUAL_META_OVERRIDES.get(target, target)
                meta_out = meta_target(*args_metas, **kwargs_metas)
                if isinstance(meta_out, torch.Tensor):
                    meta_out = meta_out.to(device="meta")
            elif kind == "call_method":
                method = getattr(args_metas[0].__class__, target)
                meta_target = _MANUAL_META_OVERRIDES.get(method, method)
                meta_out = meta_target(*args_metas, **kwargs_metas)
            elif kind == "call_module":
                if not hasattr(self, "orig_forward"):
                    raise AttributeError(f"{self} does not have an attribute called orig_forward")
                self._disable_module_getattr = True
                try:
                    mod = self.root.get_submodule(target)
                    mod_type = type(mod)
                    if mod_type in _MANUAL_META_OVERRIDES:
                        meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
                    else:
                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
                finally:
                    self._disable_module_getattr = False
            elif kind == "get_attr":
                self._disable_module_getattr = True
                try:
                    attr_itr = self.root
                    atoms = target.split(".")
                    for atom in atoms:
                        attr_itr = getattr(attr_itr, atom)
                    if isinstance(attr_itr, torch.Tensor):
                        meta_out = attr_itr.to(device="meta")
                    else:
                        meta_out = attr_itr
                finally:
                    self._disable_module_getattr = False
            else:
                return rv

            if not isinstance(rv, Proxy):
                raise ValueError("Don't support composite output yet")
            rv.install_metadata(meta_out)
        except Exception as e:
            if _IS_IN_DEBUG_MODE:
                warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

        return rv

    # Replaced by .getattr from PyTorch 1.13
    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if getattr(self, "_disable_module_getattr", False):
            return attr_val
        else:

            def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
                for n, p in collection_to_search:
                    if attr_val is p:
                        if n not in parameter_proxy_cache:
                            kwargs = {}
                            if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                                kwargs["proxy_factory_fn"] = (
                                    None
                                    if not self.param_shapes_constant
                                    else lambda node: ParameterProxy(self, node, n, attr_val)
                                )
                            val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                            parameter_proxy_cache[n] = val_proxy
                        return parameter_proxy_cache[n]
                return None

            if isinstance(attr_val, torch.nn.Parameter):
                maybe_parameter_proxy = maybe_get_proxy_for_attr(
                    attr_val, self.root.named_parameters(), parameter_proxy_cache
                )
                if maybe_parameter_proxy is not None:
                    return maybe_parameter_proxy

            if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
                maybe_buffer_proxy = maybe_get_proxy_for_attr(
                    attr_val, self.root.named_buffers(), parameter_proxy_cache
                )
                if maybe_buffer_proxy is not None:
                    return maybe_buffer_proxy

            return attr_val

    # Needed for PyTorch 1.13+
    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
        return self._module_getattr(attr, attr_val, parameter_proxy_cache)

    def call_module(self, m, forward, args, kwargs):
        self.orig_forward = forward
        return super().call_module(m, forward, args, kwargs)

    def proxy(self, node):
        return HFProxy(node, self)

    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
        dummy_inputs: Optional[Dict[str, Any]] = None,
        complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
    ) -> Graph:
        """
        Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
        `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
        the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
        `torch.nn.Module` instance to use as the root and add embedded constants to.

        Args:
            root (`torch.nn.Module` or `Callable`):
                Either a `torch.nn.Module` or a function to be traced through. If root is not a
                [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
            concrete_args (`Dict[str, Any]`, *optional*):
                Concrete arguments that should not be treated as Proxies.
            dummy_inputs (`Dict[str, Any]`, *optional*):
                The dummy inputs needed to handle data-dependent control-flow if `root` is not a
                [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
                [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
            complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
                If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
                `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.

        Returns:
            `torch.fx.Graph`:
                A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.

        """
        sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)

        if concrete_args is None:
            concrete_args = {}

        if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
            for param in sig.parameters.values():
                if param.name in dummy_inputs:
                    continue
                if param.default is inspect.Parameter.empty:
                    raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
            concrete_args.update(
                {
                    p.name: p.default
                    for p in sig.parameters.values()
                    if (p.name not in dummy_inputs and p.name not in concrete_args)
                }
            )

        input_names = sig.parameters.keys() - concrete_args.keys()

        # Creating a random input shape to generate dummy inputs.
        batch_size = _generate_random_int()
        sequence_length = _generate_random_int()
        shape = [batch_size, sequence_length]

        if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
            num_choices = _generate_random_int(low=2, high=5)
            shape.insert(1, num_choices)

        inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
        for input_name in input_names:
            if input_name in inputs:
                continue
            # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
            # be able to use HFTracer._generate_dummy_input.
            if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
                ("_deserialize_graph_module", "_CodeOnlyModule")
            ):
                inputs.update(self._generate_dummy_input(root, input_name, shape, input_names=input_names))
            else:
                raise RuntimeError(
                    f"Could not generate input named {input_name} because root is not a"
                    " transformers.PreTrainedModel."
                )

        concrete_metas = {
            input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
            for input_name, input_ in inputs.items()
        }
        for param in sig.parameters.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
                concrete_metas[f"**{param.name}"] = {}
        self.meta_args = concrete_metas
        self.patched_torch_methods = {
            target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
        }
        self.orig_fns = set()

        for name, (wrapper, orig) in self.patched_torch_methods.items():
            setattr(torch, name, wrapper)
            self.orig_fns.add(orig)

        try:
            self.graph = super().trace(root, concrete_args=concrete_args)
        finally:
            for name, (_, orig) in self.patched_torch_methods.items():
                setattr(torch, name, orig)

        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in input_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
                    to_visit = [node]
                    to_delete = collections.OrderedDict()
                    while to_visit:
                        n = to_visit.pop(0)
                        to_delete[n] = None
                        to_visit += list(n.users.keys())

                    for user in reversed(to_delete.keys()):
                        self.graph.erase_node(user)

            # TODO: solves GraphModule creation.
            # Without this, return type annotation "Tuple" is causing code execution failure.
            if node.op == "output":
                node.type = None

        return self.graph

    def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
        """
        Whether the module was instantiated with Proxies. If that is the case, such module cannot be a leaf module
        because its attributes are input-dependent.
        """
        return any(isinstance(attr, Proxy) for attr in mod.__dict__.values())

    def _insert_module_as_submodule(self, mod: nn.Module) -> str:
        """
        Helper method which tries to insert a module that was not declared as submodule.
        """
        # If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
        # It is not possible to insert such modules, those should be traced through.
        if self._stateless_mod_instanciation_depends_on_proxies(mod):
            return ""
        idx = 0
        mod_name = mod.__class__.__name__.lower()
        path = f"{mod_name}_{idx}"
        already_inserted = False
        while hasattr(self.root, path):
            if getattr(self.root, path) is mod:
                already_inserted = True
                break
            path = f"{mod_name}_{idx}"
            idx += 1

        # No need to add multiple instances of the same module.
        if not already_inserted:
            self.root.add_module(path, mod)
        return path

    def path_of_module(self, mod: nn.Module) -> str:
        """
        Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root`
        has a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return
        the string "foo.bar".

        Args:
            mod (str): The `Module` to retrieve the qualified name for.
        """
        try:
            return super().path_of_module(mod)
        except NameError as e:
            if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
                path = self._insert_module_as_submodule(mod)
                return path
            raise e

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        return (not self._stateless_mod_instanciation_depends_on_proxies(m)) and super().is_leaf_module(
            m, module_qualified_name
        )

    @compatibility(is_backward_compatible=True)
    def keys(self, obj: "Proxy") -> Any:
        """Called when a proxy object has the keys() method called.

        This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in
        your custom tracer.
        """
        attribute = HFAttribute(obj, "keys")()
        if obj.node.target == "**kwargs":
            return attribute._metadata
        return attribute


def get_concrete_args(model: nn.Module, input_names: List[str]):
    sig = inspect.signature(model.forward)

    if not (set(input_names) <= set(sig.parameters.keys())):
        formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
        formatted_allowed_input_names = ", ".join(sig.parameters.keys())
        raise ValueError(
            f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
            f" {formatted_allowed_input_names}"
        )

    return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}


def is_model_supported(model: PreTrainedModel):
    return model.__class__.__name__ in _SUPPORTED_MODELS


def check_if_model_is_supported(model: PreTrainedModel):
    if not is_model_supported(model):
        supported_model_names = ", ".join(_SUPPORTED_MODELS)
        raise NotImplementedError(
            f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
        )


def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    disable_check: bool = False,
    tracer_cls: Type[HFTracer] = HFTracer,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PretrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*):
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model, this is mostly useful for debugging
            purposes.
        tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
            The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.

    Returns:
        `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example:

    ```python
    from transformers.utils.fx import symbolic_trace

    traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
    ```
    """
    if input_names is None:
        input_names = model.dummy_inputs.keys()

    input_names = list(input_names)
    concrete_args = get_concrete_args(model, input_names)

    if not disable_check:
        check_if_model_is_supported(model)

    # Tracing.
    tracer = tracer_cls()
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, traced_graph)

    traced.config = model.config
    # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
    # _generate_dummy_input, where the model class is needed.
    traced.class_for_deserialization = model.__class__
    traced.device = model.device

    return traced
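
# Illustrative end-to-end usage sketch (not part of the original file); running it would download pretrained
# weights, so it is kept as a comment:
#
#     from transformers import BertModel
#     from transformers.utils.fx import symbolic_trace
#
#     model = BertModel.from_pretrained("bert-base-uncased")
#     traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
#     print(traced.graph)  # the recorded torch.fx graph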