From 9a12fab404d1a46c05269eaf2e30f17b6bc9b700 Mon Sep 17 00:00:00 2001 From: chenfei Date: Tue, 18 Aug 2020 22:42:18 +0800 Subject: [PATCH] fix bug of freeze_bn default value --- mindspore/train/quant/quant.py | 2 +- model_zoo/official/cv/lenet_quant/eval_quant.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index 5f46fa3d5b8..767dc9bdadc 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -476,7 +476,7 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format=' def convert_quant_network(network, bn_fold=True, - freeze_bn=1e7, + freeze_bn=10000000, quant_delay=(0, 0), num_bits=(8, 8), per_channel=(False, False), diff --git a/model_zoo/official/cv/lenet_quant/eval_quant.py b/model_zoo/official/cv/lenet_quant/eval_quant.py index 96740dc51fa..9f586d9e198 100644 --- a/model_zoo/official/cv/lenet_quant/eval_quant.py +++ b/model_zoo/official/cv/lenet_quant/eval_quant.py @@ -50,7 +50,8 @@ if __name__ == "__main__": # define fusion network network = LeNet5Fusion(cfg.num_classes) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) + network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, + per_channel=[True, False]) # define loss net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")