forked from mindspore-Ecosystem/mindspore
!22988 [bug][Ascend] fix accuracy loss of yolov3_resnet18 post training quantization
Merge pull request !22988 from chenzhuo/master
This commit is contained in:
commit
aedf54f585
|
@ -21,6 +21,7 @@ from amct_mindspore.quantize_tool import create_quant_config
|
|||
from amct_mindspore.quantize_tool import quantize_model
|
||||
from amct_mindspore.quantize_tool import save_model
|
||||
import mindspore as ms
|
||||
import mindspore.ops as ops
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
|
@ -47,8 +48,17 @@ def quant_yolov3_resnet(network, dataset, input_data):
|
|||
calibration_network.set_train(False)
|
||||
|
||||
# step4: perform the evaluation of network to do activation calibration
|
||||
concat = ops.Concat()
|
||||
index = 0
|
||||
image_data = []
|
||||
for data in dataset.create_dict_iterator(num_epochs=1):
|
||||
_ = calibration_network(data["image"], data["image_shape"])
|
||||
index += 1
|
||||
if index == 1:
|
||||
image_data = data["image"]
|
||||
else:
|
||||
image_data = concat((image_data, data["image"]))
|
||||
if index == dataset.get_dataset_size():
|
||||
_ = calibration_network(image_data, data["image_shape"])
|
||||
|
||||
# step5: export the air file
|
||||
save_model("results/yolov3_resnet_quant", calibration_network, *input_data)
|
||||
|
@ -87,7 +97,7 @@ def export_yolov3_resnet():
|
|||
else:
|
||||
print("image_dir or anno_path not exits")
|
||||
datasets = create_yolo_dataset(mindrecord_file, is_training=False)
|
||||
ds = datasets.take(1)
|
||||
ds = datasets.take(16)
|
||||
quant_yolov3_resnet(eval_net, ds, inputs)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue