forked from mindspore-Ecosystem/mindspore
!9480 fix eval symmetric bug
From: @xiaoyisd Reviewed-by: @liangchenghui,@chujinjin Signed-off-by: @liangchenghui
This commit is contained in:
commit
57da31bfdd
|
@ -41,10 +41,12 @@ if __name__ == '__main__':
|
|||
config_device_target = config_ascend_quant
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
|
||||
device_id=device_id, save_graphs=False)
|
||||
symmetric_list = [True, False]
|
||||
elif args_opt.device_target == "GPU":
|
||||
config_device_target = config_gpu_quant
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
|
||||
device_id=device_id, save_graphs=False)
|
||||
symmetric_list = [False, False]
|
||||
else:
|
||||
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
|
||||
|
||||
|
@ -53,7 +55,7 @@ if __name__ == '__main__':
|
|||
# convert fusion network to quantization aware network
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
symmetric=symmetric_list)
|
||||
network = quantizer.quantize(network)
|
||||
# define network loss
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
|
|
Loading…
Reference in New Issue