!9480 fix eval symmetric bug

From: @xiaoyisd
Reviewed-by: @liangchenghui,@chujinjin
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2020-12-09 14:34:24 +08:00 committed by Gitee
commit 57da31bfdd
1 changed files with 3 additions and 1 deletions

View File

@ -41,10 +41,12 @@ if __name__ == '__main__':
config_device_target = config_ascend_quant
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=device_id, save_graphs=False)
symmetric_list = [True, False]
elif args_opt.device_target == "GPU":
config_device_target = config_gpu_quant
context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
device_id=device_id, save_graphs=False)
symmetric_list = [False, False]
else:
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
@ -53,7 +55,7 @@ if __name__ == '__main__':
# convert fusion network to quantization aware network
quantizer = QuantizationAwareTraining(bn_fold=True,
per_channel=[True, False],
symmetric=[True, False])
symmetric=symmetric_list)
network = quantizer.quantize(network)
# define network loss
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')