import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.quant_type import QuantType
from torch.jit._recursive import wrap_cpp_module

__all__ = [
    "script_qconfig",
    "script_qconfig_dict",
    "fuse_conv_bn_jit",
    "prepare_jit",
    "prepare_dynamic_jit",
    "convert_jit",
    "convert_dynamic_jit",
    "quantize_jit",
    "quantize_dynamic_jit",
]


def _check_is_script_module(model):
    if not isinstance(model, torch.jit.ScriptModule):
        raise ValueError('input must be a script module, got: ' + str(type(model)))


def _check_forward_method(model):
    if not model._c._has_method('forward'):
        raise ValueError('input script module does not have forward method')


def script_qconfig(qconfig):
    r"""Instantiate the activation and weight observer modules and script
    them. These observer module instances will be deepcopied during the
    `prepare_jit` step.
    """
    return QConfig(
        activation=torch.jit.script(qconfig.activation())._c,
        weight=torch.jit.script(qconfig.weight())._c)
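
# Illustrative sketch (not part of the original source): scripting a standard
# eager-mode qconfig. `get_default_qconfig` is the usual torch.ao.quantization
# helper; the 'fbgemm' backend string is an assumption.
#
#     from torch.ao.quantization import get_default_qconfig
#     qconfig = get_default_qconfig('fbgemm')
#     scripted = script_qconfig(qconfig)
#     # scripted.activation and scripted.weight now hold scripted observer
#     # instances, ready to be consumed by the observer-insertion JIT pass.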


def script_qconfig_dict(qconfig_dict):
    r"""Helper function used by `prepare_jit`.
    Apply `script_qconfig` to all entries in `qconfig_dict` that are
    not None.
    """
    return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}
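
# Illustrative sketch (assumption, not part of the original source): a
# qconfig_dict maps submodule names to qconfigs; '' applies globally and a
# None value leaves that submodule unquantized. 'fc2' is a hypothetical
# submodule name.
#
#     from torch.ao.quantization import get_default_qconfig
#     qconfig_dict = {
#         '': get_default_qconfig('fbgemm'),  # global default
#         'fc2': None,                        # keep this submodule in fp32
#     }
#     scripted = script_qconfig_dict(qconfig_dict)  # observers instantiated and scripted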


def fuse_conv_bn_jit(model, inplace=False):
    r"""Fuse conv-bn modules.

    Works for eval models only.

    Args:
        model: TorchScript model from scripting or tracing
    """
    torch._C._log_api_usage_once("quantization_api.quantize_jit.fuse_conv_bn_jit")
    model_c = model._c
    model_c = torch._C._jit_pass_fold_convbn(model_c)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model
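
# Illustrative sketch (assumption, not part of the original source): fusing
# conv followed by batchnorm in a scripted eval-mode model. `MyConvBnModel`
# is a hypothetical module containing Conv2d + BatchNorm2d pairs.
#
#     float_model = MyConvBnModel().eval()   # must be in eval mode
#     ts_model = torch.jit.script(float_model)
#     fused_model = fuse_conv_bn_jit(ts_model)  # bn folded into preceding conv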


def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC):
    _check_is_script_module(model)
    _check_forward_method(model)
    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
        raise ValueError('qconfig_dict should only contain names(str) as keys.')
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    model = fuse_conv_bn_jit(model, inplace)
    model_c = torch._C._jit_pass_insert_observers(model._c,
                                                  'forward',
                                                  scripted_qconfig_dict,
                                                  inplace,
                                                  quant_type)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def _prepare_ondevice_jit(model, qconfig_dict, method_name='forward', inplace=False, quant_type=QuantType.STATIC):
    _check_is_script_module(model)
    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
        raise ValueError('qconfig_dict should only contain names(str) as keys.')
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    method_graph = model._c._get_method(method_name).graph
    torch._C._jit_pass_inline(method_graph)
    model = fuse_conv_bn_jit(model, inplace)
    model_c = torch._C._jit_pass_insert_observer_method_for_ondevice_ptq(model._c,
                                                                         method_name,
                                                                         scripted_qconfig_dict,
                                                                         inplace,
                                                                         quant_type)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def prepare_jit(model, qconfig_dict, inplace=False):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_jit")
    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)


def prepare_dynamic_jit(model, qconfig_dict, inplace=False):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_dynamic_jit")
    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)
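
# Illustrative sketch (assumption, not part of the original source): the manual
# prepare -> calibrate -> convert flow that quantize_jit (defined below)
# automates. `float_model` and `calib_batches` are hypothetical.
#
#     from torch.ao.quantization import get_default_qconfig
#     ts_model = torch.jit.script(float_model.eval())
#     prepared = prepare_jit(ts_model, {'': get_default_qconfig('fbgemm')})
#     with torch.no_grad():
#         for batch in calib_batches:  # run calibration data through the observers
#             prepared(batch)
#     quantized = convert_jit(prepared)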


def _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
    return _prepare_ondevice_jit(model, qconfig_dict, method_name, inplace, quant_type=QuantType.DYNAMIC)


def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC,
                 preserved_attrs=None):
    _check_is_script_module(model)
    model.eval()
    model_c = model._c
    model_c = torch._C._jit_pass_insert_quant_dequant(model_c, 'forward', inplace, debug, quant_type)
    if not debug:
        is_xpu = all(p.device.type == 'xpu' for p in model.parameters())
        if not is_xpu:
            # Moving model parameters to CPU since quantized operators
            # are only supported on CPU and XPU right now
            model.cpu()
        if preserved_attrs is None:
            preserved_attrs = []
        model_c = torch._C._jit_pass_quant_finalize(model_c, quant_type, preserved_attrs)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    torch._C._jit_pass_constant_propagation(model.graph)
    torch._C._jit_pass_dce(model.graph)
    return model


def _convert_ondevice_jit(model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC):
    _check_is_script_module(model)
    assert quant_type == QuantType.DYNAMIC, "While this API should work for static quant, it is only tested for dynamic quant."
    assert not method_name.startswith("observe_"), "Pass in a valid method to be quantized, e.g. forward"
    observe_method_name = "observe_" + method_name
    quantize_method_name = "quantize_" + method_name
    model_c = model._c
    model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq(
        model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC)
    model_c = torch._C._jit_pass_quant_finalize_for_ondevice_ptq(model_c, QuantType.DYNAMIC, quantize_method_name)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def convert_jit(model, inplace=False, debug=False, preserved_attrs=None):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit")
    return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC, preserved_attrs=preserved_attrs)


def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit")
    return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC, preserved_attrs=preserved_attrs)
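
# Illustrative sketch (assumption, not part of the original source): dynamic
# quantization needs no calibration between prepare and convert, and
# `preserved_attrs` names attributes to keep through the finalize pass.
# `float_model` and the 'version' attribute are hypothetical.
#
#     from torch.ao.quantization import per_channel_dynamic_qconfig
#     ts_model = torch.jit.script(float_model.eval())
#     prepared = prepare_dynamic_jit(ts_model, {'': per_channel_dynamic_qconfig})
#     quantized = convert_dynamic_jit(prepared, preserved_attrs=['version'])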


def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
    return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC)


def _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=False):
    model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace)
    model = _convert_ondevice_dynamic_jit(model, method_name, inplace)
    return model


def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
    # Always do inplace convert because the Tensor is already
    # copied in prepare_jit when inplace is False
    if quant_type == QuantType.DYNAMIC:
        model = prepare_dynamic_jit(model, qconfig_dict, inplace)
        model = convert_dynamic_jit(model, True, debug)
    else:
        assert run_fn, "Must provide calibration function for post training static quantization"
        assert run_args, "Must provide calibration dataset for post training static quantization"
        model = prepare_jit(model, qconfig_dict, inplace)
        run_fn(model, *run_args)
        model = convert_jit(model, True, debug)

    torch._C._jit_pass_constant_propagation(model.graph)
    torch._C._jit_pass_dce(model.graph)
    return model


def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
    r"""Quantize the input float TorchScript model with
    post training static quantization.

    First it will prepare the model for calibration, then it calls
    `run_fn` which will run the calibration step, and after that we
    convert the model to a quantized model.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
        qconfig for that module as value. An empty key means the qconfig will be applied
        to the whole model unless it's overwritten by a more specific configuration; the
        qconfig for each module is either found in the dictionary or falls back to
        the qconfig of its parent module.

        Right now qconfig_dict is the only way to configure how the model is quantized,
        and it is done at the granularity of modules, that is, we only support one type
        of qconfig for each torch.nn.Module, and the qconfig for a sub module will
        override the qconfig of its parent module; an empty string means global configuration.
        `run_fn`: a calibration function for calibrating the prepared model
        `run_args`: positional arguments for `run_fn`
        `inplace`: carry out model transformations in-place, the original module is
        mutated
        `debug`: flag for producing a debug friendly model (preserve weight attribute)

    Return:
        Quantized TorchScript model.

    Example:
    ```python
    import torch
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization import quantize_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    qconfig = get_default_qconfig('fbgemm')
    def calibrate(model, data_loader):
        model.eval()
        with torch.no_grad():
            for image, target in data_loader:
                model(image)

    quantized_model = quantize_jit(
        ts_model,
        {'': qconfig},
        calibrate,
        [data_loader_test])
    ```
    """
    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit")
    return _quantize_jit(model, qconfig_dict, run_fn, run_args, inplace, debug, quant_type=QuantType.STATIC)


def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
    r"""Quantize the input float TorchScript model with
    post training dynamic quantization.
    Currently only qint8 quantization of torch.nn.Linear is supported.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
        qconfig for that module as value, please see detailed
        descriptions in :func:`~torch.ao.quantization.quantize_jit`
        `inplace`: carry out model transformations in-place, the original module is
        mutated
        `debug`: flag for producing a debug friendly model (preserve weight attribute)

    Return:
        Quantized TorchScript model.

    Example:
    ```python
    import torch
    from torch.ao.quantization import per_channel_dynamic_qconfig
    from torch.ao.quantization import quantize_dynamic_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    quantized_model = quantize_dynamic_jit(
        ts_model,
        {'': per_channel_dynamic_qconfig})
    ```
    """
    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit")
    return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC)


def _quantize_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
    r"""Prepares the input float TorchScript model for
    *on-device* post training dynamic quantization.
    Currently only qint8 quantization of torch.nn.Linear is supported.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
        qconfig for that module as value, please see detailed
        descriptions in :func:`~torch.ao.quantization.quantize_jit`
        `method_name`: name of the method within the model to be prepared for quantization
        `inplace`: carry out model transformations in-place, the original module is
        mutated

    Return:
        TorchScript model that is ready for on-device quantization.
        This means that the returned model has:
        - The method inlined.
        - Observer modules inserted in the model.
        - Packed params inserted in the model. However they are empty, in that they
          don't contain valid quantized weights.
        - observe_<method_name> is added, which observes the values to be quantized.
        - reset_observers_<method_name> is added, to reset observers.
        - quantize_<method_name> is added to the model.
          - This method extracts scales and zero points.
          - Quantizes observed weights.
          - Creates packed params from them and updates the attributes of the model
            with the new values for the packed params.
          - Resets the original fp32 weights to empty tensors using SetAttr.
        - quantized_<method_name> is added to the model.
          - This method uses quantized weights and quantized linear ops instead of fp32 ops.
          - This method should be used for inference post PTQ.
        - Note that all added methods' signatures are the same as that of method_name.

        Later, on device:
        - Run reset_observers_<method_name>
        - Run observe_<method_name>
        - Run quantize_<method_name>
        - Now the model can be saved and loaded later.
        - Run the model with quantized_<method_name>

    Example:
    ```python
    import torch
    from torch.ao.quantization import per_channel_dynamic_qconfig
    from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    quant_ready_model = _quantize_ondevice_dynamic_jit(
        ts_model,
        {'': per_channel_dynamic_qconfig},
        'forward',
        True)
    ```
    """
    return _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=inplace)
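
# Illustrative sketch (assumption, not part of the original source): the
# on-device sequence described in the docstring above, for method_name
# 'forward'. Per the docstring, the added methods share forward's signature,
# so `example_input` below stands in for whatever forward takes.
#
#     quant_ready = _quantize_ondevice_dynamic_jit(
#         ts_model, {'': per_channel_dynamic_qconfig}, 'forward')
#     quant_ready.reset_observers_forward(example_input)   # clear observer state
#     quant_ready.observe_forward(example_input)           # record values to quantize
#     quant_ready.quantize_forward(example_input)          # fill packed params
#     out = quant_ready.quantized_forward(example_input)   # inference with quantized ops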