""" Backends in `einops` are organized to meet the following requirements - backends are not imported unless those are actually needed, because - backends may not be installed - importing all available backends will drive to significant memory footprint - backends may be present but installed with errors (but never used), importing may drive to crashes - backend should be either symbolic or imperative - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined - if backend can't provide symbols for shape dimensions, UnknownSize objects are used """ import sys __author__ = "Alex Rogozhnikov" _loaded_backends: dict = {} _type2backend: dict = {} _debug_importing = False def get_backend(tensor) -> "AbstractBackend": """ Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor. If needed, imports package and creates backend """ _type = type(tensor) _result = _type2backend.get(_type, None) if _result is not None: return _result for framework_name, backend in list(_loaded_backends.items()): if backend.is_appropriate_type(tensor): _type2backend[_type] = backend return backend # Find backend subclasses recursively backend_subclasses = [] backends = AbstractBackend.__subclasses__() while backends: backend = backends.pop() backends += backend.__subclasses__() backend_subclasses.append(backend) for BackendSubclass in backend_subclasses: if _debug_importing: print("Testing for subclass of ", BackendSubclass) if BackendSubclass.framework_name not in _loaded_backends: # check that module was already imported. Otherwise it can't be imported if BackendSubclass.framework_name in sys.modules: if _debug_importing: print("Imported backend for ", BackendSubclass.framework_name) backend = BackendSubclass() _loaded_backends[backend.framework_name] = backend if backend.is_appropriate_type(tensor): _type2backend[_type] = backend return backend raise RuntimeError("Tensor type unknown to einops {}".format(type(tensor))) class AbstractBackend: """Base backend class, major part of methods are only for debugging purposes.""" framework_name: str def is_appropriate_type(self, tensor): """helper method should recognize tensors it can handle""" raise NotImplementedError() def from_numpy(self, x): raise NotImplementedError("framework doesn't support imperative execution") def to_numpy(self, x): raise NotImplementedError("framework doesn't support imperative execution") def create_symbol(self, shape): raise NotImplementedError("framework doesn't support symbolic computations") def eval_symbol(self, symbol, input_dict): raise NotImplementedError("framework doesn't support symbolic computations") def arange(self, start, stop): # supplementary method used only in testing, so should implement CPU version raise NotImplementedError("framework doesn't implement arange") def shape(self, x): """shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)""" return x.shape def reshape(self, x, shape): return x.reshape(shape) def transpose(self, x, axes): return x.transpose(axes) def reduce(self, x, operation, axes): return getattr(x, operation)(axis=axes) def stack_on_zeroth_dimension(self, tensors: list): raise NotImplementedError() def add_axis(self, x, new_position): raise NotImplementedError() def add_axes(self, x, n_axes, pos2len): repeats = [1] * n_axes for axis_position, axis_length in pos2len.items(): x = self.add_axis(x, axis_position) repeats[axis_position] = axis_length return self.tile(x, tuple(repeats)) def tile(self, 
class AbstractBackend:
    """Base backend class, major part of methods are only for debugging purposes."""

    framework_name: str

    def is_appropriate_type(self, tensor):
        """helper method should recognize tensors it can handle"""
        raise NotImplementedError()

    def from_numpy(self, x):
        raise NotImplementedError("framework doesn't support imperative execution")

    def to_numpy(self, x):
        raise NotImplementedError("framework doesn't support imperative execution")

    def create_symbol(self, shape):
        raise NotImplementedError("framework doesn't support symbolic computations")

    def eval_symbol(self, symbol, input_dict):
        raise NotImplementedError("framework doesn't support symbolic computations")

    def arange(self, start, stop):
        # supplementary method used only in testing, so should implement CPU version
        raise NotImplementedError("framework doesn't implement arange")

    def shape(self, x):
        """shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)"""
        return x.shape

    def reshape(self, x, shape):
        return x.reshape(shape)

    def transpose(self, x, axes):
        return x.transpose(axes)

    def reduce(self, x, operation, axes):
        return getattr(x, operation)(axis=axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        raise NotImplementedError()

    def add_axis(self, x, new_position):
        raise NotImplementedError()

    def add_axes(self, x, n_axes, pos2len):
        repeats = [1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return self.tile(x, tuple(repeats))

    def tile(self, x, repeats):
        """repeats has the same length as x.shape"""
        raise NotImplementedError()

    def concat(self, tensors, axis: int):
        """concatenates tensors along axis.
        Assume identical across tensors: devices, dtypes and shapes except selected axis."""
        raise NotImplementedError()

    def is_float_type(self, x):
        # some backends (torch) can't compute average for non-floating types.
        # Decided to drop average for all backends if type is not floating
        raise NotImplementedError()

    def layers(self):
        raise NotImplementedError("backend does not provide layers")

    def __repr__(self):
        return "<einops backend for {}>".format(self.framework_name)

    def einsum(self, pattern, *x):
        raise NotImplementedError("backend does not support einsum")
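# Worked example for the default add_axes above (illustrative): for x of shape (2, 3),
# n_axes=3 and pos2len={0: 5}, the loop first inserts a new leading axis -> shape (1, 2, 3),
# sets repeats to [5, 1, 1], and tile() then materializes shape (5, 2, 3).
# Backends with a cheaper broadcast (torch, oneflow, paddle) override this with expand().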
class UnknownSize:
    """pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements"""

    def __floordiv__(self, other):
        return self

    def __eq__(self, other):
        return True  # we don't know actual size

    def __mul__(self, other):
        return self

    def __rmul__(self, other):
        return self

    def __hash__(self):
        return hash(None)


class NumpyBackend(AbstractBackend):
    framework_name = "numpy"

    def __init__(self):
        import numpy

        self.np = numpy

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.np.ndarray)

    def from_numpy(self, x):
        return x

    def to_numpy(self, x):
        return x

    def arange(self, start, stop):
        return self.np.arange(start, stop)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.np.stack(tensors)

    def tile(self, x, repeats):
        return self.np.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.np.concatenate(tensors, axis=axis)

    def is_float_type(self, x):
        return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")

    def add_axis(self, x, new_position):
        return self.np.expand_dims(x, new_position)

    def einsum(self, pattern, *x):
        return self.np.einsum(pattern, *x)


class JaxBackend(NumpyBackend):
    framework_name = "jax"

    def __init__(self):
        super(JaxBackend, self).__init__()
        self.onp = self.np

        import jax.numpy

        self.np = jax.numpy

    def from_numpy(self, x):
        return self.np.asarray(x)

    def to_numpy(self, x):
        return self.onp.asarray(x)


class TorchBackend(AbstractBackend):
    framework_name = "torch"

    def __init__(self):
        import torch

        self.torch = torch
        # importing would register operations in torch._dynamo for torch.compile
        from . import _torch_specific  # noqa

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.torch.Tensor)

    def from_numpy(self, x):
        variable = self.torch.from_numpy(x)
        if self.is_float_type(variable):
            # attach grad only to floating types
            variable.requires_grad = True
        return variable

    def to_numpy(self, x):
        return x.detach().cpu().numpy()

    def arange(self, start, stop):
        return self.torch.arange(start, stop, dtype=self.torch.int64)

    def reduce(self, x, operation, reduced_axes):
        if operation == "min":
            return x.amin(dim=reduced_axes)
        elif operation == "max":
            return x.amax(dim=reduced_axes)
        elif operation == "sum":
            return x.sum(dim=reduced_axes)
        elif operation == "mean":
            return x.mean(dim=reduced_axes)
        elif operation in ("any", "all", "prod"):
            # pytorch supports reducing these operations over only one axis at a time
            for i in list(sorted(reduced_axes))[::-1]:
                x = getattr(x, operation)(dim=i)
            return x
        else:
            raise NotImplementedError("Unknown reduction ", operation)

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.torch.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(repeats)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.torch.cat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.torch.unsqueeze(x, new_position)

    def is_float_type(self, x):
        return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16]

    def layers(self):
        from .layers import torch

        return torch

    def einsum(self, pattern, *x):
        return self.torch.einsum(pattern, *x)
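# Illustrative behaviour of TorchBackend.reduce above: for x of shape (2, 3, 4),
# reduce(x, "sum", (0, 2)) calls x.sum(dim=(0, 2)) and returns shape (3,), while
# reduce(x, "prod", (0, 2)) applies x.prod(dim=2) and then x.prod(dim=0), axis by axis
# starting from the last reduced dimension, arriving at the same shape (3,).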
class CupyBackend(AbstractBackend):
    framework_name = "cupy"

    def __init__(self):
        import cupy

        self.cupy = cupy

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.cupy.ndarray)

    def from_numpy(self, x):
        return self.cupy.asarray(x)

    def to_numpy(self, x):
        return self.cupy.asnumpy(x)

    def arange(self, start, stop):
        return self.cupy.arange(start, stop)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.cupy.stack(tensors)

    def tile(self, x, repeats):
        return self.cupy.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.cupy.concatenate(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.cupy.expand_dims(x, new_position)

    def is_float_type(self, x):
        return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")

    def einsum(self, pattern, *x):
        return self.cupy.einsum(pattern, *x)


class ChainerBackend(AbstractBackend):
    framework_name = "chainer"

    def __init__(self):
        import chainer
        import numpy

        self.numpy = numpy
        self.chainer = chainer

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.chainer.Variable)

    def from_numpy(self, x):
        return self.chainer.Variable(x.astype("float32"))

    def to_numpy(self, x):
        if isinstance(x, self.chainer.Variable):
            x = x.data
        return x

    def arange(self, start, stop):
        return self.numpy.arange(start, stop)

    def reduce(self, x, operation, axes):
        return getattr(self.chainer.functions, operation)(x, axis=axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.chainer.functions.stack(tensors)

    def tile(self, x, repeats):
        return self.chainer.functions.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.chainer.functions.concat(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.chainer.functions.expand_dims(x, new_position)

    def is_float_type(self, x):
        return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")

    def layers(self):
        from .layers import chainer

        return chainer

    def einsum(self, pattern, *x):
        return self.chainer.functions.einsum(pattern, *x)


class HashableTuple:
    """Overcomes non-hashability of symbolic elements"""

    def __init__(self, elements: tuple):
        self.elements = elements

    def __iter__(self):
        for x in self.elements:
            yield x

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, item):
        return self.elements[item]

    # default equality and hash is used (True only with itself, hash taken of id)


class TensorflowBackend(AbstractBackend):
    framework_name = "tensorflow"

    def __init__(self):
        import tensorflow

        self.tf = tensorflow

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, (self.tf.Tensor, self.tf.Variable))

    def from_numpy(self, x):
        assert self.tf.executing_eagerly()
        return self.tf.convert_to_tensor(x)

    def to_numpy(self, x):
        assert self.tf.executing_eagerly()
        return x.numpy()

    def arange(self, start, stop):
        return self.tf.range(start, stop)

    def shape(self, x):
        if self.tf.executing_eagerly():
            return tuple(UnknownSize() if d is None else int(d) for d in x.shape)
        else:
            static_shape = x.shape.as_list()
            tf_shape = self.tf.shape(x)
            # use the static shape where known, otherwise use the TF shape components
            shape = tuple([s or tf_shape[dim] for dim, s in enumerate(static_shape)])
            try:
                hash(shape)
                return shape
            except BaseException:
                # unhashable symbols in shape. Wrap tuple to be hashable.
                return HashableTuple(shape)

    def reduce(self, x, operation, axes):
        return getattr(self.tf, "reduce_" + operation)(x, axis=axes)

    def reshape(self, x, shape):
        return self.tf.reshape(x, shape)

    def transpose(self, x, axes):
        return self.tf.transpose(x, axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.tf.stack(tensors)

    def tile(self, x, repeats):
        return self.tf.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.tf.concat(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.tf.expand_dims(x, new_position)

    def is_float_type(self, x):
        return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")

    def layers(self):
        from .layers import tensorflow

        return tensorflow

    def einsum(self, pattern, *x):
        return self.tf.einsum(pattern, *x)
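# Illustrative note on TensorflowBackend.shape above: in graph mode a tensor with static
# shape [None, 128] yields the tuple (tf.shape(x)[0], 128); if the symbolic first element
# makes the tuple unhashable, it is wrapped in HashableTuple so callers can still use the
# result as a dictionary key.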
class TFKerasBackend(AbstractBackend):
    framework_name = "tensorflow.keras"

    def __init__(self):
        import tensorflow as tf

        self.tf = tf
        self.keras = tf.keras
        self.K = tf.keras.backend

    def is_appropriate_type(self, tensor):
        return self.tf.is_tensor(tensor) and self.K.is_keras_tensor(tensor)

    def create_symbol(self, shape):
        return self.keras.Input(batch_shape=shape)

    def eval_symbol(self, symbol, input_dict):
        model = self.keras.models.Model([var for (var, _) in input_dict], symbol)
        return model.predict_on_batch([val for (_, val) in input_dict])

    def arange(self, start, stop):
        return self.K.arange(start, stop)

    def shape(self, x):
        shape = self.K.shape(x)  # tf tensor
        return HashableTuple(tuple(shape))

    def reduce(self, x, operation, axes):
        return getattr(self.K, operation)(x, axis=axes)

    def reshape(self, x, shape):
        return self.K.reshape(x, shape)

    def transpose(self, x, axes):
        return self.K.permute_dimensions(x, axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.K.stack(tensors)

    def tile(self, x, repeats):
        return self.K.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.K.concatenate(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.K.expand_dims(x, new_position)

    def is_float_type(self, x):
        return "float" in self.K.dtype(x)

    def layers(self):
        from .layers import keras

        return keras
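# Sketch of the symbolic path for tensorflow.keras (illustrative values only):
#
#   import numpy as np
#   backend = TFKerasBackend()
#   symbol = backend.create_symbol(shape=(2, 3))            # keras.Input placeholder
#   result = backend.reduce(symbol, "sum", axes=(1,))       # build ops on the symbol
#   out = backend.eval_symbol(result, [(symbol, np.ones((2, 3), dtype="float32"))])
#
# eval_symbol wires the placeholder(s) and the result into a temporary keras Model and
# evaluates it with predict_on_batch; this create_symbol/eval_symbol pair is the symbolic
# counterpart of from_numpy/to_numpy described in the module docstring.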
class OneFlowBackend(AbstractBackend):
    framework_name = "oneflow"

    def __init__(self):
        import oneflow as flow

        self.flow = flow

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.flow.Tensor)

    def from_numpy(self, x):
        variable = self.flow.from_numpy(x)
        if self.is_float_type(variable):
            # attach grad only to floating types
            variable.requires_grad = True
        return variable

    def to_numpy(self, x):
        return x.detach().cpu().numpy()

    def arange(self, start, stop):
        return self.flow.arange(start, stop, dtype=self.flow.int64)

    def reduce(self, x, operation, reduced_axes):
        for axis in sorted(reduced_axes, reverse=True):
            if operation == "min":
                x, _ = x.min(dim=axis)
            elif operation == "max":
                x, _ = x.max(dim=axis)
            elif operation in ["sum", "mean", "prod", "any", "all"]:
                x = getattr(x, operation)(dim=axis)
            else:
                raise NotImplementedError("Unknown reduction ", operation)
        return x

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.flow.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(*repeats)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.flow.concat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.flow.unsqueeze(x, new_position)

    def is_float_type(self, x):
        return x.dtype in [self.flow.float16, self.flow.float32, self.flow.float64]

    def layers(self):
        from .layers import oneflow

        return oneflow

    def einsum(self, pattern, *x):
        return self.flow.einsum(pattern, *x)


class PaddleBackend(AbstractBackend):
    framework_name = "paddle"

    def __init__(self):
        import paddle

        self.paddle = paddle

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, (self.paddle.Tensor, self.paddle.static.Variable))

    def from_numpy(self, x):
        tensor = self.paddle.to_tensor(x)
        tensor.stop_gradient = False
        return tensor

    def to_numpy(self, x):
        return x.detach().numpy()

    def arange(self, start, stop):
        return self.paddle.arange(start, stop, dtype=self.paddle.int64)

    def reduce(self, x, operation, axes):
        if len(axes) == x.ndim:
            # currently paddle returns a 1d tensor instead of 0d
            return super().reduce(x, operation, axes).squeeze(0)
        else:
            return super().reduce(x, operation, axes)

    def transpose(self, x, axes):
        return x.transpose(axes)

    def add_axes(self, x, n_axes, pos2len):
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(repeats)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.paddle.stack(tensors)

    def reshape(self, x, shape):
        return x.reshape(shape)

    def tile(self, x, repeats):
        return x.tile(repeats)

    def concat(self, tensors, axis: int):
        return self.paddle.concat(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return x.unsqueeze(new_position)

    def is_float_type(self, x):
        return x.dtype in [self.paddle.float16, self.paddle.float32, self.paddle.float64]

    def layers(self):
        from .layers import paddle

        return paddle

    def einsum(self, pattern, *x):
        return self.paddle.einsum(pattern, *x)

    def shape(self, x):
        return tuple(x.shape)


class TinygradBackend(AbstractBackend):
    framework_name = "tinygrad"

    def __init__(self):
        import tinygrad

        self.tinygrad = tinygrad

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.tinygrad.Tensor)

    def from_numpy(self, x):
        return self.tinygrad.Tensor(x)

    def to_numpy(self, x):
        return x.numpy()

    def arange(self, start, stop):
        return self.tinygrad.Tensor.arange(start, stop)

    def shape(self, x):
        return x.shape

    def reshape(self, x, shape):
        return x.reshape(shape)

    def transpose(self, x, axes):
        return x.permute(axes)

    def reduce(self, x, operation, axes):
        # tinygrad reductions are applied one axis at a time, last reduced axis first
        for axis in sorted(axes, reverse=True):
            x = getattr(x, operation)(axis=axis)
        return x

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.tinygrad.Tensor.stack(*tensors)

    def add_axis(self, x, new_position):
        return x.unsqueeze(new_position)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return tensors[0].cat(*tensors[1:], dim=axis) if len(tensors) > 1 else tensors[0]

    def is_float_type(self, x):
        return self.tinygrad.dtypes.is_float(x.dtype)

    def einsum(self, pattern, *x):
        return self.tinygrad.Tensor.einsum(pattern, *x)
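if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of einops' public API):
    # resolves the numpy backend through get_backend and exercises a couple of methods.
    import numpy as _np

    _x = _np.arange(6).reshape(2, 3)
    _backend = get_backend(_x)
    print(_backend)                                # <einops backend for numpy>
    print(_backend.shape(_x))                      # (2, 3)
    print(_backend.reduce(_x, "sum", axes=(1,)))   # [ 3 12]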