From 471b03775fb8dc45a109a0ea746e6b781681c21f Mon Sep 17 00:00:00 2001 From: chenzhuo Date: Tue, 7 Sep 2021 09:32:43 +0800 Subject: [PATCH] fix acc loss of yolov3_resnet18 post training quantization --- .../ascend310_quant_infer/post_quant.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/model_zoo/official/cv/yolov3_resnet18/ascend310_quant_infer/post_quant.py b/model_zoo/official/cv/yolov3_resnet18/ascend310_quant_infer/post_quant.py index 45f0ad637e2..ca00080a41a 100644 --- a/model_zoo/official/cv/yolov3_resnet18/ascend310_quant_infer/post_quant.py +++ b/model_zoo/official/cv/yolov3_resnet18/ascend310_quant_infer/post_quant.py @@ -21,6 +21,7 @@ from amct_mindspore.quantize_tool import create_quant_config from amct_mindspore.quantize_tool import quantize_model from amct_mindspore.quantize_tool import save_model import mindspore as ms +import mindspore.ops as ops from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net @@ -47,8 +48,17 @@ def quant_yolov3_resnet(network, dataset, input_data): calibration_network.set_train(False) # step4: perform the evaluation of network to do activation calibration + concat = ops.Concat() + index = 0 + image_data = [] for data in dataset.create_dict_iterator(num_epochs=1): - _ = calibration_network(data["image"], data["image_shape"]) + index += 1 + if index == 1: + image_data = data["image"] + else: + image_data = concat((image_data, data["image"])) + if index == dataset.get_dataset_size(): + _ = calibration_network(image_data, data["image_shape"]) # step5: export the air file save_model("results/yolov3_resnet_quant", calibration_network, *input_data) @@ -87,7 +97,7 @@ def export_yolov3_resnet(): else: print("image_dir or anno_path not exits") datasets = create_yolo_dataset(mindrecord_file, is_training=False) - ds = datasets.take(1) + ds = datasets.take(16) quant_yolov3_resnet(eval_net, ds, inputs)