!22988 [bug][Ascend] fix accuracy loss of yolov3_resnet18 post training quantization

Merge pull request !22988 from chenzhuo/master
This commit is contained in:
i-robot 2021-09-07 11:07:17 +00:00 committed by Gitee
commit aedf54f585
1 changed files with 12 additions and 2 deletions

View File

@ -21,6 +21,7 @@ from amct_mindspore.quantize_tool import create_quant_config
from amct_mindspore.quantize_tool import quantize_model
from amct_mindspore.quantize_tool import save_model
import mindspore as ms
import mindspore.ops as ops
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
@ -47,8 +48,17 @@ def quant_yolov3_resnet(network, dataset, input_data):
calibration_network.set_train(False)
# step4: perform the evaluation of network to do activation calibration
concat = ops.Concat()
index = 0
image_data = []
for data in dataset.create_dict_iterator(num_epochs=1):
_ = calibration_network(data["image"], data["image_shape"])
index += 1
if index == 1:
image_data = data["image"]
else:
image_data = concat((image_data, data["image"]))
if index == dataset.get_dataset_size():
_ = calibration_network(image_data, data["image_shape"])
# step5: export the air file
save_model("results/yolov3_resnet_quant", calibration_network, *input_data)
@ -87,7 +97,7 @@ def export_yolov3_resnet():
else:
print("image_dir or anno_path not exits")
datasets = create_yolo_dataset(mindrecord_file, is_training=False)
ds = datasets.take(1)
ds = datasets.take(16)
quant_yolov3_resnet(eval_net, ds, inputs)