736 lines
23 KiB
Python
736 lines
23 KiB
Python
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""
|
||
|
Generic utilities
|
||
|
"""
|
||
|
|
||
|
import inspect
|
||
|
import tempfile
|
||
|
from collections import OrderedDict, UserDict
|
||
|
from collections.abc import MutableMapping
|
||
|
from contextlib import ExitStack, contextmanager
|
||
|
from dataclasses import fields, is_dataclass
|
||
|
from enum import Enum
|
||
|
from functools import partial
|
||
|
from typing import Any, ContextManager, Iterable, List, Tuple
|
||
|
|
||
|
import numpy as np
|
||
|
from packaging import version
|
||
|
|
||
|
from .import_utils import (
|
||
|
get_torch_version,
|
||
|
is_flax_available,
|
||
|
is_mlx_available,
|
||
|
is_tf_available,
|
||
|
is_torch_available,
|
||
|
is_torch_fx_proxy,
|
||
|
)
|
||
|
|
||
|
|
||
|
if is_flax_available():
|
||
|
import jax.numpy as jnp
|
||
|
|
||
|
|
||
|
class cached_property(property):
    """
    Descriptor that mimics @property but caches output in member variable.

    From tensorflow_datasets

    Built-in in functools from Python 3.8.

    The value computed by the getter is stored on the instance under the attribute `"__cached_" + fget.__name__` and
    returned on every later access. This includes the case where the getter returned `None`.
    """

    # Private sentinel distinguishing "never computed" from a cached `None` value.
    _NOT_COMPUTED = object()

    def __get__(self, obj, objtype=None):
        # See docs.python.org/3/howto/descriptor.html#properties
        if obj is None:
            # Accessed on the class itself: return the descriptor, as `property` does.
            return self
        if self.fget is None:
            raise AttributeError("unreadable attribute")
        attr = "__cached_" + self.fget.__name__
        # Using `None` as the "missing" marker would make getters that return `None` re-run on every access.
        cached = getattr(obj, attr, self._NOT_COMPUTED)
        if cached is self._NOT_COMPUTED:
            cached = self.fget(obj)
            setattr(obj, attr, cached)
        return cached
|
||
|
|
||
|
|
||
|
# vendored from distutils.util
def strtobool(val):
    """Convert a string representation of truth to true (1) or false (0).

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
    The comparison is case-insensitive.
    Raises ValueError if 'val' is anything else (the message shows the lower-cased value).
    """
    normalized = val.lower()
    truth_table = {
        **dict.fromkeys(("y", "yes", "t", "true", "on", "1"), 1),
        **dict.fromkeys(("n", "no", "f", "false", "off", "0"), 0),
    }
    result = truth_table.get(normalized)
    if result is None:
        raise ValueError(f"invalid truth value {normalized!r}")
    return result
|
||
|
|
||
|
|
||
|
def infer_framework_from_repr(x):
    """
    Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
    frameworks in a smart order, without the need to import the frameworks).

    Returns one of `"pt"`, `"tf"`, `"jax"`, `"np"`, `"mlx"`, or `None` when the type repr matches no known prefix.
    """
    type_repr = str(type(x))
    # The jax prefix has no trailing dot, so it matches any module whose name starts with "jax" (e.g. jaxlib).
    prefix_to_framework = (
        ("<class 'torch.", "pt"),
        ("<class 'tensorflow.", "tf"),
        ("<class 'jax", "jax"),
        ("<class 'numpy.", "np"),
        ("<class 'mlx.", "mlx"),
    )
    for prefix, framework in prefix_to_framework:
        if type_repr.startswith(prefix):
            return framework
    return None
|
||
|
|
||
|
|
||
|
def _get_frameworks_and_test_func(x):
    """
    Returns an insertion-ordered dictionary mapping framework name to its test function. The framework guessed from
    the repr of `x` comes first, then NumPy, then the remaining frameworks in declaration order.
    """
    framework_to_test = {
        "pt": is_torch_tensor,
        "tf": is_tf_tensor,
        "jax": is_jax_tensor,
        "np": is_numpy_array,
        "mlx": is_mlx_array,
    }
    preferred_framework = infer_framework_from_repr(x)
    # NumPy is always tried right after the guessed framework (or first, if nothing or NumPy itself was guessed).
    if preferred_framework is None or preferred_framework == "np":
        leading = ["np"]
    else:
        leading = [preferred_framework, "np"]
    ordered = leading + [name for name in framework_to_test if name not in leading]
    return {name: framework_to_test[name] for name in ordered}
|
||
|
|
||
|
|
||
|
def is_tensor(x):
    """
    Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray`, `np.ndarray` or `mlx.array`
    in the order defined by `infer_framework_from_repr`. It also returns True for torch.fx proxies and, when jax is
    available, for jax tracers.
    """
    # Try the framework guessed from the repr first, then NumPy, then the rest.
    test_funcs = _get_frameworks_and_test_func(x).values()
    if any(test_func(x) for test_func in test_funcs):
        return True

    # Tracers
    if is_torch_fx_proxy(x):
        return True
    if not is_flax_available():
        return False

    from jax.core import Tracer

    return isinstance(x, Tracer)
|
||
|
|
||
|
|
||
|
def _is_numpy(x):
|
||
|
return isinstance(x, np.ndarray)
|
||
|
|
||
|
|
||
|
def is_numpy_array(x):
    """
    Tests if `x` is a numpy array or not.
    """
    # NumPy is a hard dependency of this module, so no availability check is needed.
    is_array = _is_numpy(x)
    return is_array
|
||
|
|
||
|
|
||
|
def _is_torch(x):
|
||
|
import torch
|
||
|
|
||
|
return isinstance(x, torch.Tensor)
|
||
|
|
||
|
|
||
|
def is_torch_tensor(x):
    """
    Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.
    """
    # Guard first so that `_is_torch` (which imports torch) is only reached when torch is available.
    if not is_torch_available():
        return False
    return _is_torch(x)
|
||
|
|
||
|
|
||
|
def _is_torch_device(x):
|
||
|
import torch
|
||
|
|
||
|
return isinstance(x, torch.device)
|
||
|
|
||
|
|
||
|
def is_torch_device(x):
    """
    Tests if `x` is a torch device or not. Safe to call even if torch is not installed.
    """
    # Only touch torch once we know it can be imported.
    if not is_torch_available():
        return False
    return _is_torch_device(x)
|
||
|
|
||
|
|
||
|
def _is_torch_dtype(x):
|
||
|
import torch
|
||
|
|
||
|
if isinstance(x, str):
|
||
|
if hasattr(torch, x):
|
||
|
x = getattr(torch, x)
|
||
|
else:
|
||
|
return False
|
||
|
return isinstance(x, torch.dtype)
|
||
|
|
||
|
|
||
|
def is_torch_dtype(x):
    """
    Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
    """
    # Short-circuit before `_is_torch_dtype` imports torch.
    if not is_torch_available():
        return False
    return _is_torch_dtype(x)
|
||
|
|
||
|
|
||
|
def _is_tensorflow(x):
    """Return whether `x` is a `tf.Tensor` (subclasses included). Imports tensorflow lazily."""
    import tensorflow as tf

    tensor_cls = tf.Tensor
    return isinstance(x, tensor_cls)
|
||
|
|
||
|
|
||
|
def is_tf_tensor(x):
    """
    Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed.
    """
    # Avoid importing tensorflow when it is not installed.
    if not is_tf_available():
        return False
    return _is_tensorflow(x)
|
||
|
|
||
|
|
||
|
def _is_tf_symbolic_tensor(x):
    """Return whether `x` is a symbolic (non-eager) TensorFlow tensor. Imports tensorflow lazily."""
    import tensorflow as tf

    # the `is_symbolic_tensor` predicate is only available starting with TF 2.14
    if not hasattr(tf, "is_symbolic_tensor"):
        # Fallback: exact type comparison, so instances of `tf.Tensor` subclasses do not match.
        return type(x) == tf.Tensor
    return tf.is_symbolic_tensor(x)
|
||
|
|
||
|
|
||
|
def is_tf_symbolic_tensor(x):
    """
    Tests if `x` is a tensorflow symbolic tensor or not (ie. not eager). Safe to call even if tensorflow is not
    installed.
    """
    # Bail out before `_is_tf_symbolic_tensor` imports tensorflow.
    if not is_tf_available():
        return False
    return _is_tf_symbolic_tensor(x)
|
||
|
|
||
|
|
||
|
def _is_jax(x):
|
||
|
import jax.numpy as jnp # noqa: F811
|
||
|
|
||
|
return isinstance(x, jnp.ndarray)
|
||
|
|
||
|
|
||
|
def is_jax_tensor(x):
    """
    Tests if `x` is a Jax tensor or not. Safe to call even if jax is not installed.
    """
    # Jax availability is tracked through the flax check in this module.
    if not is_flax_available():
        return False
    return _is_jax(x)
|
||
|
|
||
|
|
||
|
def _is_mlx(x):
    """Return whether `x` is an `mlx.core.array`. Imports mlx lazily."""
    import mlx.core as mx

    array_cls = mx.array
    return isinstance(x, array_cls)
|
||
|
|
||
|
|
||
|
def is_mlx_array(x):
    """
    Tests if `x` is a mlx array or not. Safe to call even when mlx is not installed.
    """
    # Only import mlx (inside `_is_mlx`) when it is available.
    if not is_mlx_available():
        return False
    return _is_mlx(x)
|
||
|
|
||
|
|
||
|
def to_py_obj(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Jax array, MLX array, Numpy array/scalar or python list to a python
    list (or python scalar).

    Dicts (including `UserDict`) are converted value by value into a plain `dict`. Lists and tuples are converted
    element by element into a `list`. Objects that are not recognized are returned unchanged.
    """

    framework_to_py_obj = {
        "pt": lambda obj: obj.detach().cpu().tolist(),
        "tf": lambda obj: obj.numpy().tolist(),
        "jax": lambda obj: np.asarray(obj).tolist(),
        "np": lambda obj: obj.tolist(),
        # `_get_frameworks_and_test_func` can select "mlx", so every framework it returns needs an entry here.
        "mlx": lambda obj: obj.tolist(),
    }

    if isinstance(obj, (dict, UserDict)):
        return {k: to_py_obj(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [to_py_obj(o) for o in obj]

    # This gives us a smart order to test the frameworks with the corresponding tests.
    framework_to_test_func = _get_frameworks_and_test_func(obj)
    for framework, test_func in framework_to_test_func.items():
        if test_func(obj):
            return framework_to_py_obj[framework](obj)

    # NumPy scalars expose `tolist()`, which returns the equivalent python scalar. `np.generic` covers `np.number`
    # as well as scalar types outside it, such as `np.bool_`.
    if isinstance(obj, np.generic):
        return obj.tolist()
    return obj
|
||
|
|
||
|
|
||
|
def to_numpy(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Jax array, MLX array, Numpy array or python list to a Numpy array.

    Dicts (including `UserDict`) are converted value by value into a plain `dict`. Lists and tuples go through
    `np.array`. Objects that are not recognized are returned unchanged.
    """

    framework_to_numpy = {
        "pt": lambda obj: obj.detach().cpu().numpy(),
        "tf": lambda obj: obj.numpy(),
        "jax": lambda obj: np.asarray(obj),
        "np": lambda obj: obj,
        # `_get_frameworks_and_test_func` can select "mlx", so every framework it returns needs an entry here.
        "mlx": lambda obj: np.array(obj),
    }

    if isinstance(obj, (dict, UserDict)):
        return {k: to_numpy(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return np.array(obj)

    # This gives us a smart order to test the frameworks with the corresponding tests.
    framework_to_test_func = _get_frameworks_and_test_func(obj)
    for framework, test_func in framework_to_test_func.items():
        if test_func(obj):
            return framework_to_numpy[framework](obj)

    return obj
|
||
|
|
||
|
|
||
|
class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.

    <Tip warning={true}>

    You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple
    before.

    </Tip>
    """

    def __init_subclass__(cls) -> None:
        """Register subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `ModelOutput` subclasses.
        """
        if is_torch_available():
            if version.parse(get_torch_version()) >= version.parse("2.2"):
                _torch_pytree.register_pytree_node(
                    cls,
                    _model_output_flatten,
                    partial(_model_output_unflatten, output_type=cls),
                    serialized_type_name=f"{cls.__module__}.{cls.__name__}",
                )
            else:
                # torch < 2.2 path: private registration helper, without `serialized_type_name`.
                _torch_pytree._register_pytree_node(
                    cls,
                    _model_output_flatten,
                    partial(_model_output_unflatten, output_type=cls),
                )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Subclasses of ModelOutput must use the @dataclass decorator
        # This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
        # issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
        # Just need to check that the current class is not ModelOutput
        is_modeloutput_subclass = self.__class__ != ModelOutput

        if is_modeloutput_subclass and not is_dataclass(self):
            raise TypeError(
                f"{self.__module__}.{self.__class__.__name__} is not a dataclass."
                " This is a subclass of ModelOutput and so must use the @dataclass decorator."
            )

    def __post_init__(self):
        """Check the ModelOutput dataclass.

        Only occurs if @dataclass decorator has been used.

        Populates the underlying dict with the non-`None` fields. When only the first field is set and it is not a
        tensor, it may be a mapping or an iterable of `(str, value)` pairs. In that case the pairs are unpacked into
        attributes, and into dict entries when the value is not `None`.

        Raises:
            ValueError: if the dataclass has no fields, if any field after the first has a non-`None` default, or if
                an iterable first field mixes `(key, value)` pairs with other elements.
        """
        class_fields = fields(self)

        # Safety and consistency checks
        if not len(class_fields):
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        if not all(field.default is None for field in class_fields[1:]):
            raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and not is_tensor(first_field):
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            # if we provided an iterator as first field and the iterator is a (key, value) iterator
            # set the associated fields
            if first_field_iterator:
                for idx, element in enumerate(iterator):
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        if idx == 0:
                            # If we do not have an iterator of key/values, set it as attribute
                            self[class_fields[0].name] = first_field
                        else:
                            # If we have a mixed iterator, raise an error
                            raise ValueError(
                                f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
                            )
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        # String keys behave like a dict lookup; anything else (int, slice) indexes the tuple of non-None values.
        if isinstance(k, str):
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        # Keep the dict entry in sync when an existing key is reassigned to a non-None value.
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Store the entry in the underlying OrderedDict...
        super().__setitem__(key, value)
        # ...and mirror it as an attribute. Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def __reduce__(self):
        # For dataclass subclasses, pickle the field values as positional constructor args so that unpickling goes
        # through the dataclass `__init__` (and thus `__post_init__`).
        if not is_dataclass(self):
            return super().__reduce__()
        reconstructor, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return reconstructor, args, *remaining

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())
|
||
|
|
||
|
|
||
|
if is_torch_available():
    import torch.utils._pytree as _torch_pytree

    def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
        """Flatten a `ModelOutput` into `(values, context)`, where the context is the ordered list of its keys."""
        values = list(output.values())
        keys = list(output.keys())
        return values, keys

    def _model_output_unflatten(
        values: Iterable[Any],
        context: "_torch_pytree.Context",
        output_type=None,
    ) -> ModelOutput:
        """Rebuild an `output_type` instance by pairing the saved keys (`context`) with `values` as keyword args."""
        kwargs = dict(zip(context, values))
        return output_type(**kwargs)

    # Register the base class itself; `ModelOutput.__init_subclass__` only runs for subclasses.
    if version.parse(get_torch_version()) < version.parse("2.2"):
        # torch < 2.2 path: private registration helper, without `serialized_type_name`.
        _torch_pytree._register_pytree_node(
            ModelOutput,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=ModelOutput),
        )
    else:
        _torch_pytree.register_pytree_node(
            ModelOutput,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=ModelOutput),
            serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
        )
|
||
|
|
||
|
|
||
|
class ExplicitEnum(str, Enum):
    """
    String-valued Enum whose error for an unknown value lists every accepted value.
    """

    @classmethod
    def _missing_(cls, value):
        # Called by Enum when `value` matches no member; replace the default error with a more helpful one.
        valid_values = list(cls._value2member_map_.keys())
        raise ValueError(f"{value} is not a valid {cls.__name__}, please select one of {valid_values}")
|
||
|
|
||
|
|
||
|
class PaddingStrategy(ExplicitEnum):
    """
    Accepted values for the `padding` argument of [`PreTrainedTokenizerBase.__call__`]; handy for tab-completion in
    an IDE.
    """

    # Members are declared in this order; iteration over the enum follows it.
    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"
|
||
|
|
||
|
|
||
|
class TensorType(ExplicitEnum):
    """
    Accepted values for the `return_tensors` argument of [`PreTrainedTokenizerBase.__call__`]; handy for
    tab-completion in an IDE.
    """

    # The values match the framework tags produced by `infer_framework_from_repr`.
    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"
    MLX = "mlx"
|
||
|
|
||
|
|
||
|
class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.

    The managers are entered in list order and exited in reverse order. If one of them suppresses an exception raised
    in the `with` body, the exception stays suppressed.
    """

    def __init__(self, context_managers: List[ContextManager]):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        # Enter on a temporary stack: if a manager's __enter__ raises, `__exit__` will never be called. The `with`
        # below then exits the managers already entered (with the exception info) before the error propagates.
        with ExitStack() as stack:
            for context_manager in self.context_managers:
                stack.enter_context(context_manager)
            # All managers entered successfully: hand their exit callbacks over to `self.stack`.
            self.stack = stack.pop_all()

    def __exit__(self, *args, **kwargs):
        # Forward the return value so that an exception suppressed by an inner manager is not re-raised.
        return self.stack.__exit__(*args, **kwargs)
|
||
|
|
||
|
|
||
|
def can_return_loss(model_class):
    """
    Check if a given model can return loss.

    The model qualifies when its entry point (`call` for TensorFlow, `forward` for PyTorch, `__call__` otherwise) has
    a `return_loss` parameter whose default is exactly `True`.

    Args:
        model_class (`type`): The class of the model.

    Returns:
        `bool`
    """
    framework = infer_framework(model_class)
    # TensorFlow models -> `call`, PyTorch models -> `forward`, Flax models -> `__call__`.
    entry_point_name = {"tf": "call", "pt": "forward"}.get(framework, "__call__")
    signature = inspect.signature(getattr(model_class, entry_point_name))

    return_loss = signature.parameters.get("return_loss")
    return return_loss is not None and return_loss.default is True
|
||
|
|
||
|
|
||
|
def find_labels(model_class):
    """
    Find the labels used by a given model.

    Returns the names of the entry-point parameters that contain `"label"`. For classes whose name contains
    `"QuestionAnswering"`, `start_positions` and `end_positions` are included as well.

    Args:
        model_class (`type`): The class of the model.
    """
    model_name = model_class.__name__
    framework = infer_framework(model_class)
    # TensorFlow models -> `call`, PyTorch models -> `forward`, Flax models -> `__call__`.
    entry_point_name = {"tf": "call", "pt": "forward"}.get(framework, "__call__")
    parameter_names = inspect.signature(getattr(model_class, entry_point_name)).parameters

    is_question_answering = "QuestionAnswering" in model_name
    qa_positions = ("start_positions", "end_positions")
    return [
        name
        for name in parameter_names
        if "label" in name or (is_question_answering and name in qa_positions)
    ]
|
||
|
|
||
|
|
||
|
def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
    """Flatten a nested dict into a single level dict.

    Nested keys are joined as `str(parent) + delimiter + str(child)`. Non-empty nested `MutableMapping` values are
    flattened recursively; empty ones are kept as values. A falsy `parent_key` (the default `""`) means "no prefix",
    so top-level keys are kept as they are (not converted to `str`).

    Args:
        d: the mapping to flatten.
        parent_key: optional prefix prepended to every key.
        delimiter: separator inserted between key components.

    Returns:
        A new `dict`.
    """

    def _flatten_dict(mapping, prefix, has_prefix):
        for k, v in mapping.items():
            key = str(prefix) + delimiter + str(k) if has_prefix else k
            if v and isinstance(v, MutableMapping):
                # Nested keys are always prefixed, even when `key` itself is falsy (e.g. `0` or `""`). Deciding by
                # the truthiness of `key` would silently drop that parent component.
                yield from _flatten_dict(v, key, True)
            else:
                yield key, v

    return dict(_flatten_dict(d, parent_key, bool(parent_key)))
|
||
|
|
||
|
|
||
|
@contextmanager
def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
    """
    Yield `working_dir` unchanged, or, when `use_temp_dir` is True, a fresh temporary directory. The temporary
    directory is removed when the context exits.
    """
    if not use_temp_dir:
        yield working_dir
        return
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield tmp_dir
|
||
|
|
||
|
|
||
|
def transpose(array, axes=None):
    """
    Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Args:
        array: a NumPy array or a torch/TensorFlow/Jax tensor.
        axes: permutation of the dimensions; `None` reverses them, as in `numpy.transpose`.

    Raises:
        ValueError: if `array` is not of a supported type.
    """
    if is_numpy_array(array):
        return np.transpose(array, axes=axes)
    elif is_torch_tensor(array):
        if axes is None:
            # Reverse the dims explicitly instead of using `Tensor.T`, which PyTorch deprecates for tensors that are
            # not 2-D. Reversing 0 or 1 dims is the identity.
            if array.ndim < 2:
                return array
            axes = range(array.ndim - 1, -1, -1)
        return array.permute(*axes)
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.transpose(array, perm=axes)
    elif is_jax_tensor(array):
        return jnp.transpose(array, axes=axes)
    else:
        raise ValueError(f"Type not supported for transpose: {type(array)}.")
|
||
|
|
||
|
|
||
|
def reshape(array, newshape):
    """
    Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Args:
        array: a NumPy array or a torch/TensorFlow/Jax tensor.
        newshape: the target shape (a sequence of ints, or a single int as accepted by `numpy.reshape`).

    Raises:
        ValueError: if `array` is not of a supported type.
    """
    if is_numpy_array(array):
        return np.reshape(array, newshape)
    elif is_torch_tensor(array):
        # `Tensor.reshape` accepts a tuple/list or a single int directly. Passing `newshape` as-is (rather than
        # unpacking it with `*`) also supports the int form.
        return array.reshape(newshape)
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.reshape(array, newshape)
    elif is_jax_tensor(array):
        return jnp.reshape(array, newshape)
    else:
        raise ValueError(f"Type not supported for reshape: {type(array)}.")
|
||
|
|
||
|
|
||
|
def squeeze(array, axis=None):
    """
    Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        ValueError: if `array` is not of a supported type.
    """
    if is_numpy_array(array):
        return np.squeeze(array, axis=axis)
    if is_torch_tensor(array):
        # torch names the argument `dim`; omit it entirely to squeeze every size-1 dimension.
        if axis is None:
            return array.squeeze()
        return array.squeeze(dim=axis)
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.squeeze(array, axis=axis)
    if is_jax_tensor(array):
        return jnp.squeeze(array, axis=axis)
    raise ValueError(f"Type not supported for squeeze: {type(array)}.")
|
||
|
|
||
|
|
||
|
def expand_dims(array, axis):
    """
    Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        ValueError: if `array` is not of a supported type.
    """
    if is_numpy_array(array):
        return np.expand_dims(array, axis)
    if is_torch_tensor(array):
        # torch spells this operation `unsqueeze`.
        return array.unsqueeze(dim=axis)
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.expand_dims(array, axis=axis)
    if is_jax_tensor(array):
        return jnp.expand_dims(array, axis=axis)
    raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
|
||
|
|
||
|
|
||
|
def tensor_size(array):
    """
    Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays.

    For TensorFlow the result is whatever `tf.size` returns (a tensor), not a python int.

    Raises:
        ValueError: if `array` is not of a supported type.
    """
    if is_numpy_array(array):
        return np.size(array)
    if is_torch_tensor(array):
        return array.numel()
    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.size(array)
    if is_jax_tensor(array):
        return array.size
    raise ValueError(f"Type not supported for tensor_size: {type(array)}.")
|
||
|
|
||
|
|
||
|
def add_model_info_to_auto_map(auto_map, repo_id):
    """
    Adds the information of the repo_id to a given auto map.

    Each non-`None` reference that does not already contain `"--"` is rewritten as `f"{repo_id}--{reference}"`.
    Tuple/list values are processed element-wise and stored back as a list. `auto_map` is modified in place and
    also returned.
    """

    def _prefixed(reference):
        # Leave `None` and references that already carry a repo prefix untouched.
        if reference is None or "--" in reference:
            return reference
        return f"{repo_id}--{reference}"

    for key, value in auto_map.items():
        if isinstance(value, (tuple, list)):
            auto_map[key] = [_prefixed(v) for v in value]
        elif value is not None and "--" not in value:
            auto_map[key] = f"{repo_id}--{value}"

    return auto_map
|
||
|
|
||
|
|
||
|
def infer_framework(model_class):
    """
    Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant
    classes are imported or available.

    Walks the MRO and returns `"tf"`, `"pt"` or `"flax"` for the first class whose module or name identifies a
    framework.

    Raises:
        TypeError: if no class in the MRO identifies a framework.
    """
    for klass in inspect.getmro(model_class):
        module, name = klass.__module__, klass.__name__
        if module.startswith(("tensorflow", "keras")) or name == "TFPreTrainedModel":
            return "tf"
        if module.startswith("torch") or name == "PreTrainedModel":
            return "pt"
        if module.startswith(("flax", "jax")) or name == "FlaxPreTrainedModel":
            return "flax"
    raise TypeError(f"Could not infer framework from class {model_class}.")
|