# coding=utf-8 # Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao, # Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team. # Copyright (c) 20121, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import decimal import numpy as np import torch from torch import nn from torch.autograd import Function from ...utils import logging logger = logging.get_logger(__name__) class QuantEmbedding(nn.Module): """ Quantized version of `torch.nn.Embedding`. Adds quantization-specific arguments on top of `torch.nn.Embedding`. Args: weight_bit (`int`, *optional*, defaults to `8`): Bitwidth for the quantized weight. momentum (`float`, *optional*, defaults to `0.95`): Momentum for updating the activation quantization range. quant_mode (`bool`, *optional*, defaults to `False`): Whether or not the layer is quantized. """ def __init__( self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, weight_bit=8, momentum=0.95, quant_mode=False, ): super().__init__() self.num_ = num_embeddings self.dim = embedding_dim self.padding_idx = padding_idx self.max_norm = max_norm self.norm_type = norm_type self.scale_grad_by_freq = scale_grad_by_freq self.sparse = sparse self.weight = nn.Parameter(torch.zeros([num_embeddings, embedding_dim])) self.register_buffer("weight_scaling_factor", torch.zeros(1)) self.register_buffer("weight_integer", torch.zeros_like(self.weight)) self.weight_bit = weight_bit self.momentum = momentum self.quant_mode = quant_mode self.percentile_mode = False self.weight_function = SymmetricQuantFunction.apply def forward(self, x, positions=None, incremental_state=None): if not self.quant_mode: return ( nn.functional.embedding( x, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse, ), None, ) w = self.weight w_transform = w.data.detach() w_min = w_transform.min().expand(1) w_max = w_transform.max().expand(1) self.weight_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, False) self.weight_integer = self.weight_function( self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor ) emb_int = nn.functional.embedding( x, self.weight_integer, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse, ) return emb_int * self.weight_scaling_factor, self.weight_scaling_factor class QuantAct(nn.Module): """ Quantizes the given activation. Args: activation_bit (`int`): Bitwidth for the quantized activation. act_range_momentum (`float`, *optional*, defaults to `0.95`): Momentum for updating the activation quantization range. per_channel (`bool`, *optional*, defaults to `False`): Whether to or not use channel-wise quantization. channel_len (`int`, *optional*): Specify the channel length when set the *per_channel* True. quant_mode (`bool`, *optional*, defaults to `False`): Whether or not the layer is quantized. """ def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, channel_len=None, quant_mode=False): super().__init__() self.activation_bit = activation_bit self.act_range_momentum = act_range_momentum self.quant_mode = quant_mode self.per_channel = per_channel self.percentile = False self.act_function = SymmetricQuantFunction.apply if not self.per_channel: self.register_buffer("x_min", torch.zeros(1)) self.register_buffer("x_max", torch.zeros(1)) self.register_buffer("act_scaling_factor", torch.zeros(1)) self.x_min -= 1e-5 self.x_max += 1e-5 else: raise NotImplementedError("per-channel mode is not currently supported for activation.") def __repr__(self): return ( f"{self.__class__.__name__}(activation_bit={self.activation_bit}, " f"quant_mode: {self.quant_mode}, Act_min: {self.x_min.item():.2f}, " f"Act_max: {self.x_max.item():.2f})" ) def forward( self, x, pre_act_scaling_factor=None, identity=None, identity_scaling_factor=None, specified_min=None, specified_max=None, ): x_act = x if identity is None else identity + x # collect running stats if training if self.training: assert not self.percentile, "percentile mode is not currently supported for activation." assert not self.per_channel, "per-channel mode is not currently supported for activation." x_min = x_act.data.min() x_max = x_act.data.max() assert ( x_max.isnan().sum() == 0 and x_min.isnan().sum() == 0 ), "NaN detected when computing min/max of the activation" # Initialization if self.x_min.min() > -1.1e-5 and self.x_max.max() < 1.1e-5: self.x_min = self.x_min + x_min self.x_max = self.x_max + x_max # exponential moving average (EMA) # use momentum to prevent the quantized values change greatly every iteration elif self.act_range_momentum == -1: self.x_min = torch.min(self.x_min, x_min) self.x_max = torch.max(self.x_max, x_max) else: self.x_min = self.x_min * self.act_range_momentum + x_min * (1 - self.act_range_momentum) self.x_max = self.x_max * self.act_range_momentum + x_max * (1 - self.act_range_momentum) if not self.quant_mode: return x_act, None x_min = self.x_min if specified_min is None else specified_min x_max = self.x_max if specified_max is None else specified_max self.act_scaling_factor = symmetric_linear_quantization_params( self.activation_bit, x_min, x_max, per_channel=self.per_channel ) if pre_act_scaling_factor is None: # this is for the input quantization quant_act_int = self.act_function(x, self.activation_bit, self.percentile, self.act_scaling_factor) else: quant_act_int = FixedPointMul.apply( x, pre_act_scaling_factor, self.activation_bit, self.act_scaling_factor, identity, identity_scaling_factor, ) correct_output_scale = self.act_scaling_factor.view(-1) return quant_act_int * correct_output_scale, self.act_scaling_factor class QuantLinear(nn.Module): """ Quantized version of `torch.nn.Linear`. Adds quantization-specific arguments on top of `torch.nn.Linear`. Args: weight_bit (`int`, *optional*, defaults to `8`): Bitwidth for the quantized weight. bias_bit (`int`, *optional*, defaults to `32`): Bitwidth for the quantized bias. per_channel (`bool`, *optional*, defaults to `False`): Whether or not to use channel-wise quantization. quant_mode (`bool`, *optional*, defaults to `False`): Whether or not the layer is quantized. """ def __init__( self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False ): super().__init__() self.in_features = in_features self.out_features = out_features self.weight = nn.Parameter(torch.zeros([out_features, in_features])) self.register_buffer("weight_integer", torch.zeros_like(self.weight)) self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features)) if bias: self.bias = nn.Parameter(torch.zeros(out_features)) self.register_buffer("bias_integer", torch.zeros_like(self.bias)) self.weight_bit = weight_bit self.quant_mode = quant_mode self.per_channel = per_channel self.bias_bit = bias_bit self.quant_mode = quant_mode self.percentile_mode = False self.weight_function = SymmetricQuantFunction.apply def __repr__(self): s = super().__repr__() s = f"({s} weight_bit={self.weight_bit}, quant_mode={self.quant_mode})" return s def forward(self, x, prev_act_scaling_factor=None): if not self.quant_mode: return nn.functional.linear(x, weight=self.weight, bias=self.bias), None # assert that prev_act_scaling_factor is a scalar tensor assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), ( "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. " "Please add a QuantAct layer with `per_channel = True` before this QuantAct layer" ) w = self.weight w_transform = w.data.detach() if self.per_channel: w_min, _ = torch.min(w_transform, dim=1, out=None) w_max, _ = torch.max(w_transform, dim=1, out=None) else: w_min = w_transform.min().expand(1) w_max = w_transform.max().expand(1) self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel) self.weight_integer = self.weight_function( self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor ) bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor if self.bias is not None: self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor) prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1) x_int = x / prev_act_scaling_factor return ( nn.functional.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor, bias_scaling_factor, ) class IntGELU(nn.Module): """ Quantized version of `torch.nn.GELU`. Adds quantization-specific arguments on top of `torch.nn.GELU`. Args: quant_mode (`bool`, *optional*, defaults to `False`): Whether or not the layer is quantized. force_dequant (`str`, *optional*, defaults to `"none"`): Force dequantize the layer if either "gelu" or "nonlinear" is given. """ def __init__(self, quant_mode=True, force_dequant="none"): super().__init__() self.quant_mode = quant_mode if force_dequant in ["nonlinear", "gelu"]: logger.info("Force dequantize gelu") self.quant_mode = False if not self.quant_mode: self.activation_fn = nn.GELU() self.k = 1.4142 self.const = 14 # dummy integer constant self.coeff = [-0.2888, -1.769, 1] # a(x+b)**2 + c self.coeff[2] /= self.coeff[0] def int_erf(self, x_int, scaling_factor): b_int = torch.floor(self.coeff[1] / scaling_factor) c_int = torch.floor(self.coeff[2] / scaling_factor**2) sign = torch.sign(x_int) abs_int = torch.min(torch.abs(x_int), -b_int) y_int = sign * ((abs_int + b_int) ** 2 + c_int) scaling_factor = scaling_factor**2 * self.coeff[0] # avoid overflow y_int = floor_ste.apply(y_int / 2**self.const) scaling_factor = scaling_factor * 2**self.const return y_int, scaling_factor def forward(self, x, scaling_factor=None): if not self.quant_mode: return self.activation_fn(x), None x_int = x / scaling_factor sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k) shift_int = 1.0 // sigmoid_scaling_factor x_int = x_int * (sigmoid_int + shift_int) scaling_factor = scaling_factor * sigmoid_scaling_factor / 2 return x_int * scaling_factor, scaling_factor class IntSoftmax(nn.Module): """ Quantized version of `torch.nn.Softmax`. Adds quantization-specific arguments on top of `torch.nn.Softmax`. Args: output_bit (`int`): Bitwidth for the layer output activation. quant_mode (`bool`, *optional*, defaults to `False`): Whether or not the layer is quantized. force_dequant (`str`, *optional*, defaults to `"none"`): Force dequantize the layer if either "softmax" or "nonlinear" is given. """ def __init__(self, output_bit, quant_mode=False, force_dequant="none"): super().__init__() self.output_bit = output_bit self.max_bit = 32 self.quant_mode = quant_mode if force_dequant in ["nonlinear", "softmax"]: logger.info("Force dequantize softmax") self.quant_mode = False self.act = QuantAct(16, quant_mode=self.quant_mode) self.x0 = -0.6931 # -ln2 self.const = 30 # dummy integer constant self.coef = [0.35815147, 0.96963238, 1.0] # ax**2 + bx + c self.coef[1] /= self.coef[0] self.coef[2] /= self.coef[0] def int_polynomial(self, x_int, scaling_factor): with torch.no_grad(): b_int = torch.floor(self.coef[1] / scaling_factor) c_int = torch.floor(self.coef[2] / scaling_factor**2) z = (x_int + b_int) * x_int + c_int scaling_factor = self.coef[0] * scaling_factor**2 return z, scaling_factor def int_exp(self, x_int, scaling_factor): with torch.no_grad(): x0_int = torch.floor(self.x0 / scaling_factor) x_int = torch.max(x_int, self.const * x0_int) q = floor_ste.apply(x_int / x0_int) r = x_int - x0_int * q exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor) exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0) scaling_factor = exp_scaling_factor / 2**self.const return exp_int, scaling_factor def forward(self, x, scaling_factor): if not self.quant_mode: return nn.functional.softmax(x, dim=-1), None x_int = x / scaling_factor x_int_max, _ = x_int.max(dim=-1, keepdim=True) x_int = x_int - x_int_max exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor) # Avoid overflow exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor) exp_int = exp / exp_scaling_factor exp_int_sum = exp_int.sum(dim=-1, keepdim=True) factor = floor_ste.apply(2**self.max_bit / exp_int_sum) exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit)) scaling_factor = 1 / 2**self.output_bit return exp_int * scaling_factor, scaling_factor class IntLayerNorm(nn.Module): """ Quantized version of `torch.nn.LayerNorm`. Adds quantization-specific arguments on top of `torch.nn.LayerNorm`. Args: output_bit (`int`, *optional*, defaults to `8`): Bitwidth for the layer output activation. quant_mode (`bool`, *optional*, defaults to `False`): Whether or not the layer is quantized. force_dequant (`str`, *optional*, defaults to `"none"`): Force dequantize the layer if either "layernorm" or "nonlinear" is given. """ def __init__(self, normalized_shape, eps, output_bit=8, quant_mode=False, force_dequant="none"): super().__init__() self.normalized_shape = normalized_shape self.eps = eps self.weight = nn.Parameter(torch.zeros(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) self.quant_mode = quant_mode if force_dequant in ["nonlinear", "layernorm"]: logger.info("Force dequantize layernorm") self.quant_mode = False self.register_buffer("shift", torch.zeros(1)) self.output_bit = output_bit self.max_bit = 32 self.dim_sqrt = None self.activation = QuantAct(self.output_bit, quant_mode=self.quant_mode) def set_shift(self, y_int): with torch.no_grad(): y_sq_int = y_int**2 var_int = torch.sum(y_sq_int, axis=2, keepdim=True) shift = (torch.log2(torch.sqrt(var_int / 2**self.max_bit)).ceil()).max() shift_old = self.shift self.shift = torch.max(self.shift, shift) logger.info(f"Dynamic shift adjustment: {int(shift_old)} -> {int(self.shift)}") def overflow_fallback(self, y_int): """ This fallback function is called when overflow is detected during training time, and adjusts the `self.shift` to avoid overflow in the subsequent runs. """ self.set_shift(y_int) # adjusts `self.shift` y_int_shifted = floor_ste.apply(y_int / 2**self.shift) y_sq_int = y_int_shifted**2 var_int = torch.sum(y_sq_int, axis=2, keepdim=True) return var_int def forward(self, x, scaling_factor=None): if not self.quant_mode: mean = x.mean(axis=2, keepdim=True) y = x - mean var = torch.mean(y**2, axis=2, keepdim=True) x = y / torch.sqrt(self.eps + var) x = x * self.weight + self.bias return x, None # compute sqrt of the feature dimension if it is the first run if self.dim_sqrt is None: n = torch.tensor(x.shape[2], dtype=torch.float) self.dim_sqrt = torch.sqrt(n).to(x.device) # Normalization: computes mean and variance(std) x_int = x / scaling_factor mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True)) y_int = x_int - mean_int y_int_shifted = floor_ste.apply(y_int / 2**self.shift) y_sq_int = y_int_shifted**2 var_int = torch.sum(y_sq_int, axis=2, keepdim=True) # overflow handling in training time if self.training: # if overflow is detected if var_int.max() >= 2**self.max_bit: var_int = self.overflow_fallback(y_int) assert var_int.max() < 2**self.max_bit + 0.1, ( "Error detected in overflow handling: " "`var_int` exceeds `self.max_bit` (the maximum possible bit width)" ) # To be replaced with integer-sqrt kernel that produces the same output std_int = floor_ste.apply(torch.sqrt(var_int)) * 2**self.shift factor = floor_ste.apply(2**31 / std_int) y_int = floor_ste.apply(y_int * factor / 2) scaling_factor = self.dim_sqrt / 2**30 # scaling and shifting bias = self.bias.data.detach() / (self.weight.data.detach()) bias_int = floor_ste.apply(bias / scaling_factor) y_int = y_int + bias_int scaling_factor = scaling_factor * self.weight x = y_int * scaling_factor return x, scaling_factor def get_percentile_min_max(input, lower_percentile, upper_percentile, output_tensor=False): """ Calculate the percentile max and min values in a given tensor Args: input (`torch.Tensor`): The target tensor to calculate percentile max and min. lower_percentile (`float`): If 0.1, means we return the value of the smallest 0.1% value in the tensor as percentile min. upper_percentile (`float`): If 99.9, means we return the value of the largest 0.1% value in the tensor as percentile max. output_tensor (`bool`, *optional*, defaults to `False`): If True, this function returns tensors, otherwise it returns values. Returns: `Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of *input* """ input_length = input.shape[0] lower_index = round(input_length * (1 - lower_percentile * 0.01)) upper_index = round(input_length * upper_percentile * 0.01) upper_bound = torch.kthvalue(input, k=upper_index).values if lower_percentile == 0: lower_bound = upper_bound * 0 # lower_index += 1 else: lower_bound = -torch.kthvalue(-input, k=lower_index).values if not output_tensor: lower_bound = lower_bound.item() upper_bound = upper_bound.item() return lower_bound, upper_bound def linear_quantize(input, scale, zero_point, inplace=False): """ Quantize single-precision input tensor to integers with the given scaling factor and zeropoint. Args: input (`torch.Tensor`): Single-precision input tensor to be quantized. scale (`torch.Tensor`): Scaling factor for quantization. zero_pint (`torch.Tensor`): Shift for quantization. inplace (`bool`, *optional*, defaults to `False`): Whether to compute inplace or not. Returns: `torch.Tensor`: Linearly quantized value of *input* according to *scale* and *zero_point*. """ # reshape scale and zeropoint for convolutional weights and activation if len(input.shape) == 4: scale = scale.view(-1, 1, 1, 1) zero_point = zero_point.view(-1, 1, 1, 1) # reshape scale and zeropoint for linear weights elif len(input.shape) == 2: scale = scale.view(-1, 1) zero_point = zero_point.view(-1, 1) else: scale = scale.view(-1) zero_point = zero_point.view(-1) # quantized = float / scale + zero_point if inplace: input.mul_(1.0 / scale).add_(zero_point).round_() return input return torch.round(1.0 / scale * input + zero_point) def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False): """ Compute the scaling factor with the given quantization range for symmetric quantization. Args: saturation_min (`torch.Tensor`): Lower bound for quantization range. saturation_max (`torch.Tensor`): Upper bound for quantization range. per_channel (`bool`, *optional*, defaults to `False`): Whether to or not use channel-wise quantization. Returns: `torch.Tensor`: Scaling factor that linearly quantizes the given range between *saturation_min* and *saturation_max*. """ # in this part, we do not need any gradient computation, # in order to enforce this, we put torch.no_grad() with torch.no_grad(): n = 2 ** (num_bits - 1) - 1 if per_channel: scale, _ = torch.max(torch.stack([saturation_min.abs(), saturation_max.abs()], dim=1), dim=1) scale = torch.clamp(scale, min=1e-8) / n else: scale = max(saturation_min.abs(), saturation_max.abs()) scale = torch.clamp(scale, min=1e-8) / n return scale class SymmetricQuantFunction(Function): """ Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth. """ @staticmethod def forward(ctx, x, k, percentile_mode, scale): """ Args: x (`torch.Tensor`): Floating point tensor to be quantized. k (`int`): Quantization bitwidth. percentile_mode (`bool`): Whether or not to use percentile calibration. scale (`torch.Tensor`): Pre-calculated scaling factor for *x*. Note that the current implementation of SymmetricQuantFunction requires pre-calculated scaling factor. Returns: `torch.Tensor`: Symmetric-quantized value of *input*. """ zero_point = torch.tensor(0.0).to(scale.device) n = 2 ** (k - 1) - 1 new_quant_x = linear_quantize(x, scale, zero_point, inplace=False) new_quant_x = torch.clamp(new_quant_x, -n, n - 1) ctx.scale = scale return new_quant_x @staticmethod def backward(ctx, grad_output): scale = ctx.scale if len(grad_output.shape) == 4: scale = scale.view(-1, 1, 1, 1) # reshape scale and zeropoint for linear weights elif len(grad_output.shape) == 2: scale = scale.view(-1, 1) else: scale = scale.view(-1) return grad_output.clone() / scale, None, None, None, None class floor_ste(Function): """ Straight-through Estimator(STE) for torch.floor() """ @staticmethod def forward(ctx, x): return torch.floor(x) @staticmethod def backward(ctx, grad_output): return grad_output.clone() class round_ste(Function): """ Straight-through Estimator(STE) for torch.round() """ @staticmethod def forward(ctx, x): return torch.round(x) @staticmethod def backward(ctx, grad_output): return grad_output.clone() def batch_frexp(inputs, max_bit=31): """ Decompose the scaling factor into mantissa and twos exponent. Args: scaling_factor (`torch.Tensor`): Target scaling factor to decompose. Returns: ``Tuple(torch.Tensor, torch.Tensor)`: mantisa and exponent """ shape_of_input = inputs.size() # trans the input to be a 1-d tensor inputs = inputs.view(-1) output_m, output_e = np.frexp(inputs.cpu().numpy()) tmp_m = [] for m in output_m: int_m_shifted = int( decimal.Decimal(m * (2**max_bit)).quantize(decimal.Decimal("1"), rounding=decimal.ROUND_HALF_UP) ) tmp_m.append(int_m_shifted) output_m = np.array(tmp_m) output_e = float(max_bit) - output_e return ( torch.from_numpy(output_m).to(inputs.device).view(shape_of_input), torch.from_numpy(output_e).to(inputs.device).view(shape_of_input), ) class FixedPointMul(Function): """ Function to perform fixed-point arithmetic that can match integer arithmetic on hardware. Args: pre_act (`torch.Tensor`): Input tensor. pre_act_scaling_factor (`torch.Tensor`): Scaling factor of the input tensor *pre_act*. bit_num (`int`): Quantization bitwidth. z_scaling_factor (`torch.Tensor`): Scaling factor of the output tensor. identity (`torch.Tensor`, *optional*): Identity tensor, if exists. identity_scaling_factor (`torch.Tensor`, *optional*): Scaling factor of the identity tensor *identity*, if exists. Returns: `torch.Tensor`: Output tensor(*pre_act* if *identity* is not given, otherwise the addition of *pre_act* and *identity*), whose scale is rescaled to *z_scaling_factor*. """ @staticmethod def forward( ctx, pre_act, pre_act_scaling_factor, bit_num, z_scaling_factor, identity=None, identity_scaling_factor=None, ): if len(pre_act_scaling_factor.shape) == 3: reshape = lambda x: x # noqa: E731 else: reshape = lambda x: x.view(1, 1, -1) # noqa: E731 ctx.identity = identity n = 2 ** (bit_num - 1) - 1 with torch.no_grad(): pre_act_scaling_factor = reshape(pre_act_scaling_factor) if identity is not None: identity_scaling_factor = reshape(identity_scaling_factor) ctx.z_scaling_factor = z_scaling_factor z_int = torch.round(pre_act / pre_act_scaling_factor) _A = pre_act_scaling_factor.type(torch.double) _B = (z_scaling_factor.type(torch.float)).type(torch.double) new_scale = _A / _B new_scale = reshape(new_scale) m, e = batch_frexp(new_scale) output = z_int.type(torch.double) * m.type(torch.double) output = torch.round(output / (2.0**e)) if identity is not None: # needs addition of identity activation wx_int = torch.round(identity / identity_scaling_factor) _A = identity_scaling_factor.type(torch.double) _B = (z_scaling_factor.type(torch.float)).type(torch.double) new_scale = _A / _B new_scale = reshape(new_scale) m1, e1 = batch_frexp(new_scale) output1 = wx_int.type(torch.double) * m1.type(torch.double) output1 = torch.round(output1 / (2.0**e1)) output = output1 + output return torch.clamp(output.type(torch.float), -n - 1, n) @staticmethod def backward(ctx, grad_output): identity_grad = None if ctx.identity is not None: identity_grad = grad_output.clone() / ctx.z_scaling_factor return grad_output.clone() / ctx.z_scaling_factor, None, None, None, None, identity_grad, None