541 lines
21 KiB
Python
541 lines
21 KiB
Python
import operator
|
|
import types
|
|
|
|
import torch
|
|
from torch._export import capture_pre_autograd_graph
|
|
from torch.fx import (
|
|
GraphModule,
|
|
Node,
|
|
)
|
|
from torch.nn.utils.fusion import fuse_conv_bn_weights
|
|
from typing import Any, Callable, Dict, Optional, Tuple, List, Union
|
|
from torch.utils._pytree import LeafSpec
|
|
from torch.export.unflatten import _AttrKind, _assign_attr
|
|
|
|
# Makes sure that quantized_decomposed ops are registered
|
|
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
|
|
|
|
from torch.ao.quantization.quantizer import QuantizationAnnotation
|
|
|
|
|
|
# Public API of this module; every other helper here is private (underscore-prefixed).
__all__ = [
    "fold_bn_weights_into_conv_node",
    "get_aten_graph_module",
    "remove_tensor_overload_for_qdq_ops",
]
|
|
|
|
_QUANTIZE_OPS = [
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
|
|
torch.ops.quantized_decomposed.quantize_per_channel.default,
|
|
]
|
|
|
|
|
|
_DEQUANTIZE_OPS = [
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
|
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
|
]
|
|
|
|
# Example inputs for conv-bn1d patterns
|
|
_conv1d_bn_example_inputs = (
|
|
torch.randn(1, 1, 3), # x
|
|
torch.randn(1, 1, 1), # conv_weight
|
|
torch.randn(1), # conv_bias
|
|
torch.randn(1), # bn_weight
|
|
torch.randn(1), # bn_bias
|
|
torch.randn(1), # bn_running_mean
|
|
torch.randn(1), # bn_running_var
|
|
)
|
|
|
|
# Example inputs for conv-bn2d patterns
|
|
_conv2d_bn_example_inputs = (
|
|
torch.randn(1, 1, 3, 3), # x
|
|
torch.randn(1, 1, 1, 1), # conv_weight
|
|
torch.randn(1), # conv_bias
|
|
torch.randn(1), # bn_weight
|
|
torch.randn(1), # bn_bias
|
|
torch.randn(1), # bn_running_mean
|
|
torch.randn(1), # bn_running_var
|
|
)
|
|
|
|
def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
    """
    Report whether ``dest`` leads back to ``source`` through quant-workflow ops only.

    Starting at ``dest``, repeatedly step to the first argument while the current
    node is a quantize, dequantize or ``choose_qparams.tensor`` op, then compare
    the node reached with ``source``.

    Raises:
        ValueError: if a quant-workflow op on the path has a non-Node first argument.
    """
    walkable_targets = [
        *_QUANTIZE_OPS,
        *_DEQUANTIZE_OPS,
        torch.ops.quantized_decomposed.choose_qparams.tensor,
    ]
    node = dest
    while node.target in walkable_targets:
        parent = node.args[0]
        if not isinstance(parent, torch.fx.Node):
            raise ValueError(f"expected arg[0] of quant workflow ops to be a node but found {parent}")
        node = parent
    return node == source
|
|
|
|
|
|
def _find_q_dq_node_for_user(
    produer: torch.fx.Node, user: torch.fx.Node
) -> Tuple[Any, Any]:
    """
    Find q, dq pair corresponding to [producer -> q -> dq -> user]

    The dequantize node is looked up among the positional args of ``user`` first
    and then among its keyword-argument values. The first dequantize node that is
    connected back to ``produer`` (via ``_is_connected``) is selected.

    Args:
        produer: the producer node (the misspelled name is kept so existing
            keyword callers keep working).
        user: the node that consumes the dequantized value.

    Returns:
        ``(q_node, dq_node)``. ``(None, None)`` when no connected dequantize node is
        found. ``q_node`` is None when the dequantize node's input is not one of
        ``_QUANTIZE_OPS``.

    Raises:
        ValueError: propagated from ``_is_connected`` when a quant-workflow op
            on the path has a non-Node first argument.
    """
    # Search kwargs *values*: iterating the kwargs dict directly yields only the
    # keyword names (strings), so dq nodes passed by keyword were never found.
    candidates = (*user.args, *user.kwargs.values())
    dq_node = next(
        (
            arg
            for arg in candidates
            if isinstance(arg, torch.fx.Node)
            and arg.op == "call_function"
            and arg.target in _DEQUANTIZE_OPS
            and _is_connected(produer, arg)
        ),
        None,
    )
    if dq_node is None:
        return (None, None)

    # _is_connected has already verified that dq_node.args[0] is a Node.
    q_node = None
    dq_input = dq_node.args[0]
    if dq_input.op == "call_function" and dq_input.target in _QUANTIZE_OPS:
        q_node = dq_input
    return (q_node, dq_node)
|
|
|
|
|
|
|
|
def _is_sym_size_node(node: Node):
|
|
return (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops.aten.sym_size.default
|
|
or node.target == torch.ops.aten.sym_numel.default
|
|
or node.target == torch.ops.aten.sym_numel
|
|
or node.target == torch.ops.aten.sym_size
|
|
)
|
|
|
|
|
|
def _filter_sym_size_users(node: torch.fx.Node) -> List[torch.fx.Node]:
    """Return the users of ``node`` excluding sym_size / sym_numel nodes."""
    return [user for user in node.users if _is_sym_size_node(user) is False]
|
|
|
|
|
|
def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool:
|
|
if annotation is None:
|
|
return False
|
|
input_qspec_map = annotation.input_qspec_map
|
|
output_qspec = annotation.output_qspec
|
|
if len(input_qspec_map) == 0 and output_qspec is None:
|
|
return False
|
|
return True
|
|
|
|
|
|
def _get_tensor_constant_from_node(node, m):
|
|
if node is None:
|
|
return None
|
|
assert node.op == "get_attr"
|
|
target_atoms = node.target.split('.')
|
|
attr_itr = m
|
|
for i, atom in enumerate(target_atoms):
|
|
if not hasattr(attr_itr, atom):
|
|
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
|
|
attr_itr = getattr(attr_itr, atom)
|
|
return attr_itr
|
|
|
|
def _get_all_arguments(orig_args, orig_kwargs, args_schema):
|
|
all_args = []
|
|
for i, schema in enumerate(args_schema):
|
|
if schema.name in orig_kwargs:
|
|
all_args.append(orig_kwargs[schema.name])
|
|
elif not schema.kwarg_only and i < len(orig_args):
|
|
all_args.append(orig_args[i])
|
|
else:
|
|
all_args.append(schema.default_value)
|
|
return all_args
|
|
|
|
def _is_supported_batch_norm_for_training(node: Node):
|
|
"""
|
|
Return True if the given node refers to an aten batch norm op QAT supports.
|
|
"""
|
|
supported_ops = [
|
|
torch.ops.aten._native_batch_norm_legit.default,
|
|
# Note: we won't need this op anymore after batch norm consolidation
|
|
# For now, we need to continue to support it because it gives better
|
|
# training numerics than `_native_batch_norm_legit`
|
|
torch.ops.aten.cudnn_batch_norm.default,
|
|
torch.ops.aten.miopen_batch_norm.default,
|
|
]
|
|
return node.target in supported_ops
|
|
|
|
# TODO: rename this to _is_conv_node
|
|
def _is_conv(n: Node):
|
|
"""
|
|
Return whether the node refers to an aten conv op.
|
|
"""
|
|
return n.op == "call_function" and n.target in [
|
|
torch.ops.aten.conv1d.default,
|
|
torch.ops.aten.conv2d.default,
|
|
]
|
|
|
|
# TODO: rename this to _is_conv_transpose_node
|
|
def _is_conv_transpose(n: Node):
|
|
"""
|
|
Return whether the node refers to an aten conv_transpose op.
|
|
"""
|
|
return n.op == "call_function" and n.target in [
|
|
torch.ops.aten.conv_transpose1d,
|
|
torch.ops.aten.conv_transpose2d,
|
|
]
|
|
|
|
def _is_bn_node(n: Node):
    """
    Return True for the eval-mode ``_native_batch_norm_legit_no_training`` op or
    any training batch norm op accepted by ``_is_supported_batch_norm_for_training``.
    """
    if n.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
        return True
    return _is_supported_batch_norm_for_training(n)
|
|
|
|
def fold_bn_weights_into_conv_node(
    conv_node: Node,
    conv_weight_node: Node,
    conv_bias_node: Optional[Node],
    bn_node: Node,
    m: GraphModule
) -> None:
    """
    Fold the batch norm parameters of ``bn_node`` into the conv that feeds it.

    The fused weight and bias (from ``fuse_conv_bn_weights``) are registered as
    parameters on ``m``. They overwrite the conv's weight attribute, and either
    its existing bias attribute or, when the conv has no bias, a new
    ``<weight_attr>_bias`` parameter. In the no-bias case a new ``get_attr`` node
    is also inserted before ``conv_node`` and wired in as the conv's bias
    argument. Every ``getitem(bn_node, 0)`` user is then replaced by
    ``conv_node``. ``bn_node`` itself stays in the graph; it becomes removable by
    dead code elimination once its other outputs are unused.

    Args:
        conv_node: the conv ``call_function`` node (conv_transpose is detected
            via ``_is_conv_transpose``).
        conv_weight_node: ``get_attr`` node holding the conv weight.
        conv_bias_node: ``get_attr`` node holding the conv bias, or None.
        bn_node: the batch norm node whose parameters are folded.
        m: the GraphModule that owns these nodes; modified in place.

    Raises:
        ValueError: if ``bn_node`` is not a supported batch norm op.
    """
    # eval bn args: input, weight, bias, running mean, running var, momentum, eps
    # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
    # Validate the bn op before doing any other work.
    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
        eps_arg_index = 6
    elif _is_supported_batch_norm_for_training(bn_node):
        eps_arg_index = 7
    else:
        raise ValueError(f"BN node target is unexpected: {bn_node.target}")

    # conv args: input, weight, bias, stride, padding, dilation, ...
    conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
    conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
    transpose = _is_conv_transpose(conv_node)

    bn_args_schema = bn_node.target._schema.arguments  # type: ignore[union-attr]
    bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
    bn_w = _get_tensor_constant_from_node(bn_args[1], m)
    bn_b = _get_tensor_constant_from_node(bn_args[2], m)
    bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
    bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
    bn_eps = bn_args[eps_arg_index]

    fused_weight, fused_bias = fuse_conv_bn_weights(
        conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
    )

    # update the weight and bias for conv
    conv_args = list(conv_node.args)
    # filling in the default bias argument
    if len(conv_args) == 2:
        conv_args.append(None)

    # Register the fused tensors as parameters on m, overwriting the conv's
    # existing weight (and bias, if present) attributes.
    weight_attr_name = conv_weight_node.target
    assert isinstance(weight_attr_name, str)
    _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER)
    if conv_bias_node is not None:
        bias_attr_name = conv_bias_node.target
        _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER)
    else:
        bias_attr_name = weight_attr_name + "_bias"
        _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER)
        with m.graph.inserting_before(conv_node):
            get_bias_node = m.graph.get_attr(bias_attr_name)
        # NOTE: here we assume the bias of conv is not quantized!
        conv_args[2] = get_bias_node
    conv_node.args = tuple(conv_args)

    # native_batch_norm has 3 outputs, we expect getitem calls on the output
    # and we want to replace the uses of getitem 0 with the output of conv
    #
    # Before:
    # conv -> bn - (first output) -> users1
    #          \ - (second output) -> users2
    #          \ - (third output) -> users3
    # After:
    # conv -> (first output) -> users1
    #       bn -
    #          \ - (second output) -> users2
    #          \ - (third output) -> users3
    # if users2 and users3 are empty then bn will be removed through dead code elimination
    # Iterate over a snapshot so graph rewiring cannot disturb the iteration.
    for user in list(bn_node.users):
        if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
            continue
        user.replace_all_uses_with(conv_node)
|
|
|
|
# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
    """
    Fold every eval-mode batch norm (``_native_batch_norm_legit_no_training``)
    whose input is a conv node (per ``_is_conv``) into that conv, in place.
    Afterwards, dead code is eliminated and ``m`` is recompiled.

    Returns immediately, without touching the graph, if ``m`` contains no batch
    norm node at all (per ``_is_bn_node``).
    """
    if not any(_is_bn_node(node) for node in m.graph.nodes):
        return

    eval_bn_target = torch.ops.aten._native_batch_norm_legit_no_training.default
    for bn_node in m.graph.nodes:
        if bn_node.op != "call_function" or bn_node.target != eval_bn_target:
            continue
        conv_node = bn_node.args[0]
        if not _is_conv(conv_node):
            continue
        conv_args = conv_node.args
        weight_node = conv_args[1]
        bias_node = conv_args[2] if len(conv_args) > 2 else None
        fold_bn_weights_into_conv_node(conv_node, weight_node, bias_node, bn_node, m)

    m.graph.eliminate_dead_code()
    m.recompile()
|
|
|
|
def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
|
|
# TODO: move this information to fx node itself
|
|
node_name_to_scope: Dict[str, Tuple[str, type]] = {}
|
|
for n in model.graph.nodes:
|
|
nn_module_stack = n.meta.get("nn_module_stack", None)
|
|
current_scope = ("", type(None))
|
|
if nn_module_stack:
|
|
bt = list(nn_module_stack.values())[-1]
|
|
current_scope = (bt[0].split(".")[-1], bt[1])
|
|
node_name_to_scope[n.name] = current_scope
|
|
return node_name_to_scope
|
|
|
|
def get_aten_graph_module(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],
    is_cuda: bool = False,
    **kwargs,
) -> GraphModule:
    """
    Convert the pattern to an FX graph with decomposed aten ops.

    ``pattern`` is traced with ``capture_pre_autograd_graph`` on
    ``example_inputs``, with ``kwargs`` forwarded as the call's keyword inputs.
    When ``is_cuda`` is set, tensor inputs are moved to CUDA first. Dead code is
    eliminated and the module is recompiled before it is returned.
    """
    if is_cuda:
        example_inputs = tuple(
            arg.cuda() if isinstance(arg, torch.Tensor) else arg
            for arg in example_inputs
        )
    graph_module = capture_pre_autograd_graph(pattern, example_inputs, kwargs)
    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()
    return graph_module
|
|
|
|
def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
|
|
""" Remove .tensor overload for quantize/dequantize ops so that we can
|
|
use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
|
|
"""
|
|
_MAP = {
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
|
|
torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
|
|
torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
|
|
torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
|
|
torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
|
|
}
|
|
for n in match_pattern.graph.nodes:
|
|
if n.op != "call_function":
|
|
continue
|
|
if n.target in _MAP:
|
|
n.target = _MAP[n.target]
|
|
|
|
def _is_literal(arg):
|
|
if isinstance(arg, (int, float)):
|
|
return True
|
|
if isinstance(arg, (tuple, list)):
|
|
return all(map(_is_literal, arg))
|
|
return False
|
|
|
|
def _replace_literals_with_new_placeholders(
    gm: torch.fx.GraphModule,
    merge_dup: bool = False,
    exclude_literals: Optional[List[Any]] = None
):
    """Replace the literals in the graph with placeholder nodes that are created on the fly while we
    traverse the graph, so that the literal arguments in the graph can be matched and replaced.

    To use this, the pattern and replacement graph should have the exact same number of literal args
    and they should be used in the exact same order in the pattern and replacement graph.

    If the literal arguments are not used in the same order in pattern and replacement graph, please
    use `_replace_literals_with_existing_placeholders` instead.

    Only positional ``node.args`` are inspected; keyword arguments are left untouched. ``gm`` must
    carry an ``_in_spec`` pytree spec. One ``LeafSpec`` is appended to its first child spec for every
    new placeholder, and the specs are then refreshed via ``__post_init__``.

    Args:
        `gm`: input GraphModule that we'll transform (modified in place and returned)
        `merge_dup`: boolean flag to indicate that if the same literal appears multiple times in
            the graph, whether they should correspond to the same placeholder or not
        `exclude_literals`: a list of literals that will not be replaced with placeholders

    Example:

        # 1. Original Graph
        def pattern(self, x):
            return x + 3

        def replacement(self, x):
            return x - 3

        example_inputs = (torch.randn(1, 3, 3, 3),)
        pattern_gm = get_aten_graph_module(pattern, example_inputs)
        replacement_gm = get_aten_graph_module(replacement, example_inputs)

        # 2. Before calling replace literals we'll see the following graph:
        def pattern(self, x):
            return x + 3

        def replacement(self, x):
            return x - 3

        pattern_gm = _replace_literals_with_new_placeholders(pattern_gm)
        replacement_gm = _replace_literals_with_new_placeholders(replacement_gm)

        # 3. After replacing literals with new placeholder nodes

        def pattern(self, x, new_ph):
            return x + new_ph

        def replacement(self, x, new_ph):
            return x - new_ph

    """
    def _dedup_key(literal):
        # List literals (e.g. a conv ``stride=[1, 1]``) are unhashable and used to
        # crash the ``merge_dup`` dict lookup; freeze them (recursively) into a
        # hashable, type-tagged key. Scalar keys are unchanged.
        if isinstance(literal, (list, tuple)):
            return (type(literal), tuple(_dedup_key(item) for item in literal))
        return literal

    last_ph = None
    cnt = 0
    literal_to_ph: Dict[Any, Node] = {}
    if exclude_literals is None:
        exclude_literals = []

    in_spec = gm._in_spec
    args_spec = in_spec.children_specs[0]
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            last_ph = node
            cnt += 1
            continue
        # Every new placeholder is inserted immediately after ``last_ph``. That
        # appears to leave them in reverse creation order; pattern and
        # replacement graphs are processed identically, so they stay aligned.
        # NOTE(review): confirm before relying on the exact placeholder order.
        with gm.graph.inserting_after(last_ph):
            new_args = []
            for arg in node.args:
                if _is_literal(arg) and arg not in exclude_literals:
                    key = _dedup_key(arg)
                    if merge_dup and key in literal_to_ph:
                        new_args.append(literal_to_ph[key])
                    else:
                        ph_node = gm.graph.placeholder("arg" + str(cnt))
                        new_args.append(ph_node)
                        args_spec.children_specs.append(LeafSpec())
                        cnt += 1
                        if merge_dup:
                            literal_to_ph[key] = ph_node
                else:
                    new_args.append(arg)
            new_args = tuple(new_args)

        node.args = new_args

    # Update `num_nodes`, `num_leaves`, `num_children`.
    args_spec.__post_init__()
    in_spec.__post_init__()
    return gm
|
|
|
|
|
|
def _replace_literals_with_existing_placeholders(
    gm: torch.fx.GraphModule,
    exclude_literals: Optional[List[Any]] = None,
    literal_to_ph_idx: Optional[Dict[Union[float, int, bool, torch.dtype], int]] = None
):
    """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments
    in the graph can be matched and replaced.

    To use this, all literal args in the graph should be unique and each of them should correspond
    to exactly one placeholder node.

    Only positional ``node.args`` of ``call_function`` nodes are inspected. A literal is replaced by
    ``placeholders[literal_to_ph_idx[literal]]`` (placeholders in graph order) when it appears in
    ``literal_to_ph_idx`` and not in ``exclude_literals``. Unhashable literals (lists) are never in
    the mapping and are left as-is. ``gm`` is modified in place and returned.

    Example:

        # 1. Original Graph
        def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
            return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)

        def replacement(x_i8, scale, zero_point, quant_min, quant_max):
            x_i8 = torch.clamp(x_i8, quant_min, quant_max)
            return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)

        example_inputs = (
            torch.randn(1, 3, 3, 3),
            1.0,
            0,
            -128,
            127,
        )
        pattern_gm = get_aten_graph_module(pattern, example_inputs)
        replacement_gm = get_aten_graph_module(replacement, example_inputs)

        # 2. Before calling replace literals we'll see the following graph:
        def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
            # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
            return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127)

        def replacement(x_i8, scale, zero_point, quant_min, quant_max):
            # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
            x_i8 = torch.clamp(x_i8, -128, 127)
            return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32)

        # Note that literal args appear in different order in pattern and replacement graph, so
        # we can't use _replace_literals_with_new_placeholders

        literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4}
        pattern_gm = _replace_literals_with_existing_placeholders(
            pattern_gm, literal_to_ph_idx=literal_to_ph_idx)
        replacement_gm = _replace_literals_with_existing_placeholders(
            replacement_gm, literal_to_ph_idx=literal_to_ph_idx)

        # 3. After replacing literals with existing placeholder nodes

        def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
            return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)

        def replacement(x_i8, scale, zero_point, quant_min, quant_max):
            x_i8 = torch.clamp(x_i8, quant_min, quant_max)
            return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
    """
    if exclude_literals is None:
        exclude_literals = []

    if literal_to_ph_idx is None:
        literal_to_ph_idx = {}

    def _lookup_ph_idx(literal) -> Optional[int]:
        # `_is_literal` accepts lists (e.g. conv `stride=[1, 1]`), which are
        # unhashable: a plain `in` test on the dict raised TypeError. Such a
        # literal can never be a key, so treat it as "no mapping".
        try:
            return literal_to_ph_idx.get(literal)
        except TypeError:
            return None

    phs = [node for node in gm.graph.nodes if node.op == "placeholder"]

    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        new_args = []
        for arg in node.args:
            ph_idx = None
            if _is_literal(arg) and arg not in exclude_literals:
                ph_idx = _lookup_ph_idx(arg)
            new_args.append(arg if ph_idx is None else phs[ph_idx])
        node.args = tuple(new_args)
    return gm
|
|
|
|
# TODO: Handle this in export itself and don't wrap the model in another GraphModule
# in prepare and convert
def _disallow_eval_train(model: GraphModule):
    """
    Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
    This is useful for exported models, where these methods don't actually behave as expected.

    ``train`` and ``eval`` are rebound on this ``model`` instance only (via
    ``types.MethodType``) to stubs that raise ``NotImplementedError`` with the
    guidance message below. The same ``model`` object is returned.
    """
    # Message shown to users who call train()/eval() on the exported model.
    error_message = \
        """
Calling train() or eval() is not supported for exported models.
Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead.

If you cannot replace the calls to `model.train()` and `model.eval()`, you may override
the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`,
which does the above automatically for you. Note that this has limited effect on switching
behavior between train and eval modes, and should be used only for special ops such as dropout
and batchnorm.
"""

    # `mode` is accepted but ignored; both stubs always raise.
    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)

    def _eval(self, mode: bool = True):
        raise NotImplementedError(error_message)

    # Instance-level rebinding: other GraphModule instances are unaffected.
    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
    return model
|