# coding=utf-8
# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import decimal

import numpy as np
import torch
from torch import nn
from torch.autograd import Function

from ...utils import logging


logger = logging.get_logger(__name__)


class QuantEmbedding(nn.Module):
    """
    Quantized version of `torch.nn.Embedding`. Adds quantization-specific arguments on top of `torch.nn.Embedding`.

    Args:
        weight_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the quantized weight.
        momentum (`float`, *optional*, defaults to `0.95`):
            Momentum for updating the activation quantization range.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        weight_bit=8,
        momentum=0.95,
        quant_mode=False,
    ):
        super().__init__()
        self.num_ = num_embeddings
        self.dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        self.weight = nn.Parameter(torch.zeros([num_embeddings, embedding_dim]))
        self.register_buffer("weight_scaling_factor", torch.zeros(1))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))

        self.weight_bit = weight_bit
        self.momentum = momentum
        self.quant_mode = quant_mode
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def forward(self, x, positions=None, incremental_state=None):
        if not self.quant_mode:
            return (
                nn.functional.embedding(
                    x,
                    self.weight,
                    self.padding_idx,
                    self.max_norm,
                    self.norm_type,
                    self.scale_grad_by_freq,
                    self.sparse,
                ),
                None,
            )

        w = self.weight
        w_transform = w.data.detach()
        w_min = w_transform.min().expand(1)
        w_max = w_transform.max().expand(1)

        self.weight_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, False)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor
        )

        emb_int = nn.functional.embedding(
            x,
            self.weight_integer,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return emb_int * self.weight_scaling_factor, self.weight_scaling_factor


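# Illustrative usage sketch for QuantEmbedding (doctest-style, assuming a toy
# vocabulary of 10 tokens and a hidden size of 4):
#
#   >>> emb = QuantEmbedding(10, 4, weight_bit=8, quant_mode=True)
#   >>> input_ids = torch.tensor([[1, 2, 3]])
#   >>> output, scaling_factor = emb(input_ids)
#   >>> output.shape
#   torch.Size([1, 3, 4])
#
# `output` is the integer embedding multiplied back by `scaling_factor`, so it
# stays on the same scale as an unquantized `nn.Embedding` output.

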
class QuantAct(nn.Module):
    """
    Quantizes the given activation.

    Args:
        activation_bit (`int`):
            Bitwidth for the quantized activation.
        act_range_momentum (`float`, *optional*, defaults to `0.95`):
            Momentum for updating the activation quantization range.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.
        channel_len (`int`, *optional*):
            Specify the channel length when *per_channel* is set to `True`.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    """

    def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, channel_len=None, quant_mode=False):
        super().__init__()

        self.activation_bit = activation_bit
        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.percentile = False
        self.act_function = SymmetricQuantFunction.apply

        if not self.per_channel:
            self.register_buffer("x_min", torch.zeros(1))
            self.register_buffer("x_max", torch.zeros(1))
            self.register_buffer("act_scaling_factor", torch.zeros(1))
            self.x_min -= 1e-5
            self.x_max += 1e-5
        else:
            raise NotImplementedError("per-channel mode is not currently supported for activation.")

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(activation_bit={self.activation_bit}, "
            f"quant_mode: {self.quant_mode}, Act_min: {self.x_min.item():.2f}, "
            f"Act_max: {self.x_max.item():.2f})"
        )

    def forward(
        self,
        x,
        pre_act_scaling_factor=None,
        identity=None,
        identity_scaling_factor=None,
        specified_min=None,
        specified_max=None,
    ):
        x_act = x if identity is None else identity + x
        # collect running stats if training
        if self.training:
            assert not self.percentile, "percentile mode is not currently supported for activation."
            assert not self.per_channel, "per-channel mode is not currently supported for activation."
            x_min = x_act.data.min()
            x_max = x_act.data.max()

            assert (
                x_max.isnan().sum() == 0 and x_min.isnan().sum() == 0
            ), "NaN detected when computing min/max of the activation"

            # Initialization
            if self.x_min.min() > -1.1e-5 and self.x_max.max() < 1.1e-5:
                self.x_min = self.x_min + x_min
                self.x_max = self.x_max + x_max

            # exponential moving average (EMA)
            # use momentum to prevent the quantized values from changing greatly at every iteration
            elif self.act_range_momentum == -1:
                self.x_min = torch.min(self.x_min, x_min)
                self.x_max = torch.max(self.x_max, x_max)
            else:
                self.x_min = self.x_min * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
                self.x_max = self.x_max * self.act_range_momentum + x_max * (1 - self.act_range_momentum)

        if not self.quant_mode:
            return x_act, None

        x_min = self.x_min if specified_min is None else specified_min
        x_max = self.x_max if specified_max is None else specified_max

        self.act_scaling_factor = symmetric_linear_quantization_params(
            self.activation_bit, x_min, x_max, per_channel=self.per_channel
        )

        if pre_act_scaling_factor is None:
            # this is for the input quantization
            quant_act_int = self.act_function(x, self.activation_bit, self.percentile, self.act_scaling_factor)
        else:
            quant_act_int = FixedPointMul.apply(
                x,
                pre_act_scaling_factor,
                self.activation_bit,
                self.act_scaling_factor,
                identity,
                identity_scaling_factor,
            )

        correct_output_scale = self.act_scaling_factor.view(-1)

        return quant_act_int * correct_output_scale, self.act_scaling_factor


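# Illustrative usage sketch for QuantAct (assumes an 8-bit activation and a
# random float input; in training mode the layer first accumulates the min/max
# range it will quantize to):
#
#   >>> act = QuantAct(8, quant_mode=True)
#   >>> x = torch.randn(2, 3, 4)
#   >>> x_q, act_scaling_factor = act(x)   # x_q = integer activation * scale
#   >>> act_scaling_factor.shape
#   torch.Size([1])

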
class QuantLinear(nn.Module):
    """
    Quantized version of `torch.nn.Linear`. Adds quantization-specific arguments on top of `torch.nn.Linear`.

    Args:
        weight_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the quantized weight.
        bias_bit (`int`, *optional*, defaults to `32`):
            Bitwidth for the quantized bias.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.zeros([out_features, in_features]))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))
        self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
            self.register_buffer("bias_integer", torch.zeros_like(self.bias))
        else:
            # without these, `forward` would raise an AttributeError when `bias=False`
            self.bias = None
            self.bias_integer = None

        self.weight_bit = weight_bit
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.bias_bit = bias_bit
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def __repr__(self):
        s = super().__repr__()
        s = f"({s} weight_bit={self.weight_bit}, quant_mode={self.quant_mode})"
        return s

    def forward(self, x, prev_act_scaling_factor=None):
        if not self.quant_mode:
            return nn.functional.linear(x, weight=self.weight, bias=self.bias), None

        # assert that prev_act_scaling_factor is a scalar tensor
        assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
            "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
            "Please add a QuantAct layer with `per_channel = False` before this QuantLinear layer"
        )

        w = self.weight
        w_transform = w.data.detach()
        if self.per_channel:
            w_min, _ = torch.min(w_transform, dim=1, out=None)
            w_max, _ = torch.max(w_transform, dim=1, out=None)
        else:
            w_min = w_transform.min().expand(1)
            w_max = w_transform.max().expand(1)

        self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
        )

        bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor

        if self.bias is not None:
            self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor)

        prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
        x_int = x / prev_act_scaling_factor

        return (
            nn.functional.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
            bias_scaling_factor,
        )


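# Illustrative usage sketch for QuantLinear (assumes the input has already been
# quantized by a QuantAct layer, which supplies the global activation scaling
# factor that QuantLinear requires):
#
#   >>> act = QuantAct(8, quant_mode=True)
#   >>> fc = QuantLinear(4, 2, weight_bit=8, quant_mode=True)
#   >>> x = torch.randn(3, 4)
#   >>> x_q, act_sf = act(x)
#   >>> y, out_sf = fc(x_q, act_sf)        # y = integer matmul * combined scale
#   >>> y.shape
#   torch.Size([3, 2])

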
class IntGELU(nn.Module):
    """
    Quantized version of `torch.nn.GELU`. Adds quantization-specific arguments on top of `torch.nn.GELU`.

    Args:
        quant_mode (`bool`, *optional*, defaults to `True`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "gelu" or "nonlinear" is given.
    """

    def __init__(self, quant_mode=True, force_dequant="none"):
        super().__init__()
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "gelu"]:
            logger.info("Force dequantize gelu")
            self.quant_mode = False

        if not self.quant_mode:
            self.activation_fn = nn.GELU()

        self.k = 1.4142
        self.const = 14  # dummy integer constant
        self.coeff = [-0.2888, -1.769, 1]  # a(x+b)**2 + c
        self.coeff[2] /= self.coeff[0]

    def int_erf(self, x_int, scaling_factor):
        b_int = torch.floor(self.coeff[1] / scaling_factor)
        c_int = torch.floor(self.coeff[2] / scaling_factor**2)
        sign = torch.sign(x_int)

        abs_int = torch.min(torch.abs(x_int), -b_int)
        y_int = sign * ((abs_int + b_int) ** 2 + c_int)
        scaling_factor = scaling_factor**2 * self.coeff[0]

        # avoid overflow
        y_int = floor_ste.apply(y_int / 2**self.const)
        scaling_factor = scaling_factor * 2**self.const

        return y_int, scaling_factor

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            return self.activation_fn(x), None

        x_int = x / scaling_factor
        sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k)

        shift_int = 1.0 // sigmoid_scaling_factor

        x_int = x_int * (sigmoid_int + shift_int)
        scaling_factor = scaling_factor * sigmoid_scaling_factor / 2

        return x_int * scaling_factor, scaling_factor


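# Illustrative usage sketch for IntGELU (assumes the input was quantized by a
# preceding QuantAct layer; the integer second-order polynomial in `int_erf`
# approximates erf):
#
#   >>> act = QuantAct(16, quant_mode=True)
#   >>> gelu = IntGELU(quant_mode=True)
#   >>> x = torch.randn(2, 8)
#   >>> x_q, sf = act(x)
#   >>> y, y_sf = gelu(x_q, sf)
#   >>> y.shape
#   torch.Size([2, 8])
#
# `y` approximates `nn.GELU()(x)` while being computed from integer operations
# plus a single floating-point rescale.

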
class IntSoftmax(nn.Module):
    """
    Quantized version of `torch.nn.Softmax`. Adds quantization-specific arguments on top of `torch.nn.Softmax`.

    Args:
        output_bit (`int`):
            Bitwidth for the layer output activation.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "softmax" or "nonlinear" is given.
    """

    def __init__(self, output_bit, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.output_bit = output_bit
        self.max_bit = 32
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "softmax"]:
            logger.info("Force dequantize softmax")
            self.quant_mode = False

        self.act = QuantAct(16, quant_mode=self.quant_mode)
        self.x0 = -0.6931  # -ln2
        self.const = 30  # dummy integer constant
        self.coef = [0.35815147, 0.96963238, 1.0]  # ax**2 + bx + c
        self.coef[1] /= self.coef[0]
        self.coef[2] /= self.coef[0]

    def int_polynomial(self, x_int, scaling_factor):
        with torch.no_grad():
            b_int = torch.floor(self.coef[1] / scaling_factor)
            c_int = torch.floor(self.coef[2] / scaling_factor**2)
        z = (x_int + b_int) * x_int + c_int
        scaling_factor = self.coef[0] * scaling_factor**2
        return z, scaling_factor

    def int_exp(self, x_int, scaling_factor):
        with torch.no_grad():
            x0_int = torch.floor(self.x0 / scaling_factor)
        x_int = torch.max(x_int, self.const * x0_int)

        q = floor_ste.apply(x_int / x0_int)
        r = x_int - x0_int * q
        exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
        exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0)
        scaling_factor = exp_scaling_factor / 2**self.const
        return exp_int, scaling_factor

    def forward(self, x, scaling_factor):
        if not self.quant_mode:
            return nn.functional.softmax(x, dim=-1), None

        x_int = x / scaling_factor

        x_int_max, _ = x_int.max(dim=-1, keepdim=True)
        x_int = x_int - x_int_max
        exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor)

        # Avoid overflow
        exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor)
        exp_int = exp / exp_scaling_factor

        exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
        factor = floor_ste.apply(2**self.max_bit / exp_int_sum)
        exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
        scaling_factor = 1 / 2**self.output_bit
        return exp_int * scaling_factor, scaling_factor


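# Illustrative usage sketch for IntSoftmax (assumes 8-bit output and attention
# scores that were quantized upstream, so a scaling factor is available):
#
#   >>> softmax = IntSoftmax(output_bit=8, quant_mode=True)
#   >>> act = QuantAct(16, quant_mode=True)
#   >>> scores, sf = act(torch.randn(1, 4, 4))
#   >>> probs, probs_sf = softmax(scores, sf)
#   >>> bool(torch.all(probs >= 0))
#   True
#
# `probs` is a fixed-point approximation of `nn.functional.softmax(scores, dim=-1)`
# with output scale `1 / 2**output_bit`.

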
class IntLayerNorm(nn.Module):
    """
    Quantized version of `torch.nn.LayerNorm`. Adds quantization-specific arguments on top of `torch.nn.LayerNorm`.

    Args:
        output_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the layer output activation.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "layernorm" or "nonlinear" is given.
    """

    def __init__(self, normalized_shape, eps, output_bit=8, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        self.weight = nn.Parameter(torch.zeros(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

        self.quant_mode = quant_mode
        if force_dequant in ["nonlinear", "layernorm"]:
            logger.info("Force dequantize layernorm")
            self.quant_mode = False

        self.register_buffer("shift", torch.zeros(1))
        self.output_bit = output_bit
        self.max_bit = 32
        self.dim_sqrt = None
        self.activation = QuantAct(self.output_bit, quant_mode=self.quant_mode)

    def set_shift(self, y_int):
        with torch.no_grad():
            y_sq_int = y_int**2
            var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
            shift = (torch.log2(torch.sqrt(var_int / 2**self.max_bit)).ceil()).max()
            shift_old = self.shift
            self.shift = torch.max(self.shift, shift)
            logger.info(f"Dynamic shift adjustment: {int(shift_old)} -> {int(self.shift)}")

    def overflow_fallback(self, y_int):
        """
        This fallback function is called when overflow is detected during training time, and adjusts the `self.shift`
        to avoid overflow in the subsequent runs.
        """
        self.set_shift(y_int)  # adjusts `self.shift`
        y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
        y_sq_int = y_int_shifted**2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
        return var_int

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            mean = x.mean(axis=2, keepdim=True)
            y = x - mean
            var = torch.mean(y**2, axis=2, keepdim=True)
            x = y / torch.sqrt(self.eps + var)
            x = x * self.weight + self.bias
            return x, None

        # compute sqrt of the feature dimension if it is the first run
        if self.dim_sqrt is None:
            n = torch.tensor(x.shape[2], dtype=torch.float)
            self.dim_sqrt = torch.sqrt(n).to(x.device)

        # Normalization: computes mean and variance (std)
        x_int = x / scaling_factor
        mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True))
        y_int = x_int - mean_int
        y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
        y_sq_int = y_int_shifted**2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)

        # overflow handling in training time
        if self.training:
            # if overflow is detected
            if var_int.max() >= 2**self.max_bit:
                var_int = self.overflow_fallback(y_int)
                assert var_int.max() < 2**self.max_bit + 0.1, (
                    "Error detected in overflow handling: "
                    "`var_int` exceeds `self.max_bit` (the maximum possible bit width)"
                )

        # To be replaced with integer-sqrt kernel that produces the same output
        std_int = floor_ste.apply(torch.sqrt(var_int)) * 2**self.shift
        factor = floor_ste.apply(2**31 / std_int)
        y_int = floor_ste.apply(y_int * factor / 2)
        scaling_factor = self.dim_sqrt / 2**30

        # scaling and shifting
        bias = self.bias.data.detach() / (self.weight.data.detach())
        bias_int = floor_ste.apply(bias / scaling_factor)

        y_int = y_int + bias_int
        scaling_factor = scaling_factor * self.weight
        x = y_int * scaling_factor

        return x, scaling_factor


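# Illustrative usage sketch for IntLayerNorm (assumes a hidden size of 8 and an
# input quantized by a preceding QuantAct layer; `weight` and `bias` are
# zero-initialized here and are normally loaded from a checkpoint):
#
#   >>> ln = IntLayerNorm(8, eps=1e-12, output_bit=8, quant_mode=True)
#   >>> _ = ln.weight.data.fill_(1.0)      # emulate trained parameters
#   >>> act = QuantAct(8, quant_mode=True)
#   >>> x_q, sf = act(torch.randn(2, 3, 8))
#   >>> y, y_sf = ln(x_q, sf)
#   >>> y.shape
#   torch.Size([2, 3, 8])

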
def get_percentile_min_max(input, lower_percentile, upper_percentile, output_tensor=False):
    """
    Calculate the percentile max and min values in a given tensor.

    Args:
        input (`torch.Tensor`):
            The target tensor to calculate percentile max and min.
        lower_percentile (`float`):
            For example, 0.1 means that the boundary value of the smallest 0.1% of entries in the tensor is returned
            as the percentile min.
        upper_percentile (`float`):
            For example, 99.9 means that the boundary value of the largest 0.1% of entries in the tensor is returned
            as the percentile max.
        output_tensor (`bool`, *optional*, defaults to `False`):
            If True, this function returns tensors, otherwise it returns values.

    Returns:
        `Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of *input*
    """
    input_length = input.shape[0]

    lower_index = round(input_length * (1 - lower_percentile * 0.01))
    upper_index = round(input_length * upper_percentile * 0.01)

    upper_bound = torch.kthvalue(input, k=upper_index).values

    if lower_percentile == 0:
        lower_bound = upper_bound * 0
        # lower_index += 1
    else:
        lower_bound = -torch.kthvalue(-input, k=lower_index).values

    if not output_tensor:
        lower_bound = lower_bound.item()
        upper_bound = upper_bound.item()
    return lower_bound, upper_bound


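# Illustrative usage sketch for get_percentile_min_max (assumes a flattened 1-d
# tensor, since the percentile indices are computed from `input.shape[0]`):
#
#   >>> w = torch.arange(1000.0)
#   >>> lo, hi = get_percentile_min_max(w, lower_percentile=1.0, upper_percentile=99.0)
#   >>> (lo, hi)
#   (10.0, 989.0)

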
def linear_quantize(input, scale, zero_point, inplace=False):
    """
    Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.

    Args:
        input (`torch.Tensor`):
            Single-precision input tensor to be quantized.
        scale (`torch.Tensor`):
            Scaling factor for quantization.
        zero_point (`torch.Tensor`):
            Shift for quantization.
        inplace (`bool`, *optional*, defaults to `False`):
            Whether to compute inplace or not.

    Returns:
        `torch.Tensor`: Linearly quantized value of *input* according to *scale* and *zero_point*.
    """
    # reshape scale and zeropoint for convolutional weights and activation
    if len(input.shape) == 4:
        scale = scale.view(-1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1)
    # reshape scale and zeropoint for linear weights
    elif len(input.shape) == 2:
        scale = scale.view(-1, 1)
        zero_point = zero_point.view(-1, 1)
    else:
        scale = scale.view(-1)
        zero_point = zero_point.view(-1)
    # quantized = float / scale + zero_point
    if inplace:
        input.mul_(1.0 / scale).add_(zero_point).round_()
        return input
    return torch.round(1.0 / scale * input + zero_point)


def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False):
    """
    Compute the scaling factor with the given quantization range for symmetric quantization.

    Args:
        num_bits (`int`):
            Quantization bitwidth.
        saturation_min (`torch.Tensor`):
            Lower bound for quantization range.
        saturation_max (`torch.Tensor`):
            Upper bound for quantization range.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.

    Returns:
        `torch.Tensor`: Scaling factor that linearly quantizes the given range between *saturation_min* and
        *saturation_max*.
    """
    # this part does not need any gradient computation;
    # we use torch.no_grad() to enforce this
    with torch.no_grad():
        n = 2 ** (num_bits - 1) - 1

        if per_channel:
            scale, _ = torch.max(torch.stack([saturation_min.abs(), saturation_max.abs()], dim=1), dim=1)
            scale = torch.clamp(scale, min=1e-8) / n

        else:
            scale = max(saturation_min.abs(), saturation_max.abs())
            scale = torch.clamp(scale, min=1e-8) / n

    return scale


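# Illustrative sketch of how the two helpers above combine into a symmetric
# quantizer (assumes an 8-bit weight tensor; `zero_point` is zero because the
# quantization is symmetric):
#
#   >>> w = torch.tensor([[-0.5, 0.25], [0.75, -1.0]])
#   >>> scale = symmetric_linear_quantization_params(8, w.min().expand(1), w.max().expand(1))
#   >>> w_int = linear_quantize(w, scale, torch.zeros(1))
#   >>> w_int.abs().max() <= 127           # 8-bit symmetric integer range
#   tensor(True)
#
# SymmetricQuantFunction below wraps exactly this combination, additionally
# clamps to [-n, n - 1], and defines a backward pass that rescales the incoming
# gradient by the scaling factor.

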
class SymmetricQuantFunction(Function):
    """
    Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
    """

    @staticmethod
    def forward(ctx, x, k, percentile_mode, scale):
        """
        Args:
            x (`torch.Tensor`):
                Floating point tensor to be quantized.
            k (`int`):
                Quantization bitwidth.
            percentile_mode (`bool`):
                Whether or not to use percentile calibration.
            scale (`torch.Tensor`):
                Pre-calculated scaling factor for *x*. Note that the current implementation of SymmetricQuantFunction
                requires pre-calculated scaling factor.

        Returns:
            `torch.Tensor`: Symmetric-quantized value of *input*.
        """
        zero_point = torch.tensor(0.0).to(scale.device)

        n = 2 ** (k - 1) - 1
        new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
        new_quant_x = torch.clamp(new_quant_x, -n, n - 1)

        ctx.scale = scale
        return new_quant_x

    @staticmethod
    def backward(ctx, grad_output):
        scale = ctx.scale
        if len(grad_output.shape) == 4:
            scale = scale.view(-1, 1, 1, 1)
        # reshape scale and zeropoint for linear weights
        elif len(grad_output.shape) == 2:
            scale = scale.view(-1, 1)
        else:
            scale = scale.view(-1)

        return grad_output.clone() / scale, None, None, None, None


class floor_ste(Function):
    """
    Straight-through Estimator (STE) for torch.floor()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.floor(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()


class round_ste(Function):
    """
    Straight-through Estimator (STE) for torch.round()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()


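# Illustrative sketch of the straight-through estimator behaviour (assumes a
# single-element input that requires a gradient):
#
#   >>> t = torch.tensor([2.7], requires_grad=True)
#   >>> round_ste.apply(t).backward()
#   >>> t.grad                             # gradient of round() passes through as 1
#   tensor([1.])

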
def batch_frexp(inputs, max_bit=31):
    """
    Decompose the scaling factor into mantissa and twos exponent.

    Args:
        inputs (`torch.Tensor`):
            Target scaling factor to decompose.

    Returns:
        `Tuple(torch.Tensor, torch.Tensor)`: mantissa and exponent
    """

    shape_of_input = inputs.size()

    # transform the input into a 1-d tensor
    inputs = inputs.view(-1)

    output_m, output_e = np.frexp(inputs.cpu().numpy())
    tmp_m = []
    for m in output_m:
        int_m_shifted = int(
            decimal.Decimal(m * (2**max_bit)).quantize(decimal.Decimal("1"), rounding=decimal.ROUND_HALF_UP)
        )
        tmp_m.append(int_m_shifted)
    output_m = np.array(tmp_m)

    output_e = float(max_bit) - output_e

    return (
        torch.from_numpy(output_m).to(inputs.device).view(shape_of_input),
        torch.from_numpy(output_e).to(inputs.device).view(shape_of_input),
    )


class FixedPointMul(Function):
    """
    Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.

    Args:
        pre_act (`torch.Tensor`):
            Input tensor.
        pre_act_scaling_factor (`torch.Tensor`):
            Scaling factor of the input tensor *pre_act*.
        bit_num (`int`):
            Quantization bitwidth.
        z_scaling_factor (`torch.Tensor`):
            Scaling factor of the output tensor.
        identity (`torch.Tensor`, *optional*):
            Identity tensor, if exists.
        identity_scaling_factor (`torch.Tensor`, *optional*):
            Scaling factor of the identity tensor *identity*, if exists.

    Returns:
        `torch.Tensor`: Output tensor (*pre_act* if *identity* is not given, otherwise the addition of *pre_act* and
        *identity*), whose scale is rescaled to *z_scaling_factor*.
    """

    @staticmethod
    def forward(
        ctx,
        pre_act,
        pre_act_scaling_factor,
        bit_num,
        z_scaling_factor,
        identity=None,
        identity_scaling_factor=None,
    ):
        if len(pre_act_scaling_factor.shape) == 3:
            reshape = lambda x: x  # noqa: E731
        else:
            reshape = lambda x: x.view(1, 1, -1)  # noqa: E731
        ctx.identity = identity

        n = 2 ** (bit_num - 1) - 1

        with torch.no_grad():
            pre_act_scaling_factor = reshape(pre_act_scaling_factor)
            if identity is not None:
                identity_scaling_factor = reshape(identity_scaling_factor)

            ctx.z_scaling_factor = z_scaling_factor

            z_int = torch.round(pre_act / pre_act_scaling_factor)
            _A = pre_act_scaling_factor.type(torch.double)
            _B = (z_scaling_factor.type(torch.float)).type(torch.double)
            new_scale = _A / _B
            new_scale = reshape(new_scale)

            m, e = batch_frexp(new_scale)

            output = z_int.type(torch.double) * m.type(torch.double)
            output = torch.round(output / (2.0**e))

            if identity is not None:
                # needs addition of identity activation
                wx_int = torch.round(identity / identity_scaling_factor)

                _A = identity_scaling_factor.type(torch.double)
                _B = (z_scaling_factor.type(torch.float)).type(torch.double)
                new_scale = _A / _B
                new_scale = reshape(new_scale)

                m1, e1 = batch_frexp(new_scale)
                output1 = wx_int.type(torch.double) * m1.type(torch.double)
                output1 = torch.round(output1 / (2.0**e1))

                output = output1 + output

            return torch.clamp(output.type(torch.float), -n - 1, n)

    @staticmethod
    def backward(ctx, grad_output):
        identity_grad = None
        if ctx.identity is not None:
            identity_grad = grad_output.clone() / ctx.z_scaling_factor
        return grad_output.clone() / ctx.z_scaling_factor, None, None, None, None, identity_grad, None
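

# Illustrative sketch of the fixed-point rescaling performed by FixedPointMul
# (assumes a 3-d activation with a scalar scaling factor, as produced by the
# QuantAct layers above; the dyadic approximation should match the plain
# floating-point rescale up to rounding):
#
#   >>> x_int = torch.randint(-128, 127, (1, 2, 4)).float()
#   >>> sf_in = torch.tensor([0.02]).view(1, 1, 1)
#   >>> sf_out = torch.tensor([0.05])
#   >>> x = x_int * sf_in
#   >>> y = FixedPointMul.apply(x, sf_in, 8, sf_out)
#   >>> torch.allclose(y, torch.round(x / sf_out).clamp(-128, 127), atol=1.0)
#   True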