bug fix in quantization aware training auto create graph
parent d383ade6f9
commit 3446940142
@@ -855,7 +855,7 @@ class ActQuant(_QuantActivation):
                                                   symmetric=symmetric,
                                                   narrow_range=narrow_range,
                                                   quant_delay=quant_delay)
-        self.act = activation
+        self.act = activation()
 
     def construct(self, x):
         x = self.act(x)
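The one-character change above is the whole fix: the auto-create-graph pass hands ActQuant its activation as a class rather than an instance, so it must be instantiated before construct() can call it. A minimal standalone sketch of the failure mode, using a hypothetical ReLU-like callable instead of the real MindSpore cells:

class ReLU:
    """Stand-in for an activation cell; __call__ runs the forward pass."""
    def __call__(self, x):
        return max(0.0, x)

class ActQuantSketch:
    def __init__(self, activation=ReLU):
        # Before the fix, `self.act = activation` kept the *class* itself,
        # so self.act(x) below would try to construct a ReLU with x as an
        # argument (a TypeError) instead of applying the activation.
        self.act = activation()

    def construct(self, x):
        return self.act(x)

print(ActQuantSketch().construct(-3.0))  # 0.0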
@@ -921,7 +921,7 @@ class HSwishQuant(_QuantActivation):
                                                          narrow_range=narrow_range,
                                                          quant_delay=quant_delay)
         if isinstance(activation, nn.HSwish):
-            self.act = activation
+            self.act = activation()
         else:
             raise ValueError("Activation should be `nn.HSwish`")
 
@@ -990,7 +990,7 @@ class HSigmoidQuant(_QuantActivation):
                                                          narrow_range=narrow_range,
                                                          quant_delay=quant_delay)
         if isinstance(activation, nn.HSwish):
-            self.act = activation
+            self.act = activation()
         else:
             raise ValueError("Activation should be `nn.HSigmoid`")
 
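Note that in both hunks above the committed guard still tests isinstance(activation, ...) while the new assignment calls activation(), and the HSigmoidQuant branch checks nn.HSwish even though its error message names nn.HSigmoid. A sketch of a guard consistent with the new assignment, assuming callers may pass either the activation class or an instance (the helper and its argument names are illustrative, not MindSpore API):

import inspect

def make_act(activation, expected):
    # Class passed in (the auto-create-graph case): instantiate it.
    if inspect.isclass(activation) and issubclass(activation, expected):
        return activation()
    # Instance passed in: use it directly.
    if isinstance(activation, expected):
        return activation
    raise ValueError("Activation should be `{}`".format(expected.__name__))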
@@ -114,7 +114,6 @@ class ConvertToQuantNetwork:
     def run(self):
         self.network.update_cell_prefix()
         network = self._convert_subcells2quant(self.network)
-        network = _AddFakeQuantInput(network)
         self.network.update_cell_type("quant")
         return network
 
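For context, ConvertToQuantNetwork.run() backs the public convert_quant_network API named in the docstrings below; with this change the fake-quant input wrapper is no longer appended in run() itself. A hedged usage sketch (the model and the exact import path are placeholders, not taken from this diff):

from mindspore.train.quant import quant as qat

net = MyFloatNetwork()                       # hypothetical float32 model
quant_net = qat.convert_quant_network(net)   # rewrites cells into quant cells
quant_net.set_train(True)                    # then train as usual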
@@ -275,16 +274,20 @@ class ExportToQuantInferNetwork:
     Args:
         network (Cell): MindSpore network API `convert_quant_network`.
         inputs (Tensor): Input tensors of the `quantization aware training network`.
+        mean (int): Input data mean. Default: 127.5.
+        std_dev (int, float): Input data variance. Default: 127.5.
 
     Returns:
         Cell, GEIR backend Infer network.
     """
     __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
 
-    def __init__(self,
-                 network,
-                 *inputs):
+    def __init__(self, network, mean, std_dev, *inputs):
         network = validator.check_isinstance('network', network, (nn.Cell,))
+        # quantize for inputs: q = f / scale + zero_point
+        # dequantize for outputs: f = (q - zero_point) * scale
+        self.input_scale = round(mean)
+        self.input_zero_point = 1 / std_dev
         self.data_type = mstype.int8
         self.network = copy.deepcopy(network)
         self.all_parameters = {p.name: p for p in self.network.get_parameters()}
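The new mean/std_dev parameters feed the affine quantization described in the inline comments. A numeric sketch of that round trip with the 127.5 defaults; note the committed __init__ stores round(mean) in input_scale and 1/std_dev in input_zero_point, while the sketch below pairs the values the way the formula comments read, purely to illustrate the arithmetic:

mean, std_dev = 127.5, 127.5
scale, zero_point = 1 / std_dev, round(mean)  # scale ~ 0.00784, zero_point = 128

def quantize(f):
    # q = f / scale + zero_point
    return round(f / scale + zero_point)

def dequantize(q):
    # f = (q - zero_point) * scale
    return (q - zero_point) * scale

q = quantize(0.5)        # 0.5 / (1/127.5) + 128 = 191.75 -> 192
print(q, dequantize(q))  # 192 0.50196..., close to the original 0.5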
@@ -395,7 +398,7 @@ class ExportToQuantInferNetwork:
         return network
 
 
-def export(network, *inputs, file_name, file_format='GEIR'):
+def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='GEIR'):
     """
     Exports MindSpore quantization predict model to deploy with GEIR.
 
@@ -403,12 +406,17 @@ def export(network, *inputs, file_name, file_format='GEIR'):
         network (Cell): MindSpore network produced by `convert_quant_network`.
         inputs (Tensor): Inputs of the `quantization aware training network`.
         file_name (str): File name of model to export.
+        mean (int): Input data mean. Default: 127.5.
+        std_dev (int, float): Input data variance. Default: 127.5.
         file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.
 
             - GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model.
     """
     supported_device = ["Ascend"]
     supported_formats = ['GEIR']
 
+    mean = validator.check_type("mean", mean, (int, float))
+    std_dev = validator.check_type("std_dev", std_dev, (int, float))
+
     if context.get_context('device_target') not in supported_device:
         raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
@@ -418,7 +426,7 @@ def export(network, *inputs, file_name, file_format='GEIR'):
     network.set_train(False)
 
     if file_format == 'GEIR':
-        exporter = ExportToQuantInferNetwork(network, *inputs)
+        exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
         deploy_net = exporter.run()
         serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
 
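Finally, a hedged end-to-end sketch of the updated export() entry point; quant_net and the input shape are placeholders, while the keyword arguments are exactly the ones this commit adds:

import numpy as np
from mindspore import Tensor

dummy_input = Tensor(np.ones([1, 1, 32, 32], dtype=np.float32))
export(quant_net, dummy_input, file_name="quant_model",
       mean=127.5, std_dev=127.5, file_format='GEIR')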