From 40580cc7951a1c18ac67d4fcb432385c0030dd96 Mon Sep 17 00:00:00 2001 From: chenfei Date: Sat, 29 Aug 2020 10:13:44 +0800 Subject: [PATCH] rm bool arg of script --- model_zoo/official/cv/lenet_quant/eval_quant.py | 4 +--- model_zoo/official/cv/lenet_quant/export.py | 2 -- model_zoo/official/cv/lenet_quant/train_quant.py | 4 +--- model_zoo/official/cv/mobilenetv2_quant/eval.py | 6 ++---- .../cv/mobilenetv2_quant/scripts/run_infer_quant.sh | 1 - 5 files changed, 4 insertions(+), 13 deletions(-) diff --git a/model_zoo/official/cv/lenet_quant/eval_quant.py b/model_zoo/official/cv/lenet_quant/eval_quant.py index f545a8a23a8..3aca04b7d31 100644 --- a/model_zoo/official/cv/lenet_quant/eval_quant.py +++ b/model_zoo/official/cv/lenet_quant/eval_quant.py @@ -38,8 +38,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide path where the trained ckpt file') -parser.add_argument('--dataset_sink_mode', type=bool, default=True, - help='dataset_sink_mode is False or True') args = parser.parse_args() if __name__ == "__main__": @@ -67,5 +65,5 @@ if __name__ == "__main__": raise ValueError("Load param into net fail!") print("============== Starting Testing ==============") - acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) + acc = model.eval(ds_eval, dataset_sink_mode=True) print("============== {} ==============".format(acc)) diff --git a/model_zoo/official/cv/lenet_quant/export.py b/model_zoo/official/cv/lenet_quant/export.py index b3fd007fed1..4fad84c9eb6 100644 --- a/model_zoo/official/cv/lenet_quant/export.py +++ b/model_zoo/official/cv/lenet_quant/export.py @@ -36,8 +36,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide path where the trained ckpt file') -parser.add_argument('--dataset_sink_mode', type=bool, default=True, - help='dataset_sink_mode is False or True') args = parser.parse_args() if __name__ == "__main__": diff --git a/model_zoo/official/cv/lenet_quant/train_quant.py b/model_zoo/official/cv/lenet_quant/train_quant.py index 51d37cc1bfa..2e9654d2bed 100644 --- a/model_zoo/official/cv/lenet_quant/train_quant.py +++ b/model_zoo/official/cv/lenet_quant/train_quant.py @@ -41,8 +41,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data", help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide path where the trained ckpt file') -parser.add_argument('--dataset_sink_mode', type=bool, default=True, - help='dataset_sink_mode is False or True') args = parser.parse_args() if __name__ == "__main__": @@ -76,5 +74,5 @@ if __name__ == "__main__": print("============== Starting Training ==============") model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], - dataset_sink_mode=args.dataset_sink_mode) + dataset_sink_mode=True) print("============== End Training ==============") diff --git a/model_zoo/official/cv/mobilenetv2_quant/eval.py b/model_zoo/official/cv/mobilenetv2_quant/eval.py index e6b0875c75f..427b3abdbf9 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/eval.py +++ b/model_zoo/official/cv/mobilenetv2_quant/eval.py @@ -32,7 +32,6 @@ parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--device_target', type=str, default=None, help='Run device target') -parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training') args_opt = parser.parse_args() if __name__ == '__main__': @@ -51,9 +50,8 @@ if __name__ == '__main__': # define fusion network network = mobilenetV2(num_classes=config_device_target.num_classes) - if args_opt.quantization_aware: - # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + # convert fusion network to quantization aware network + network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) # define network loss loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') diff --git a/model_zoo/official/cv/mobilenetv2_quant/scripts/run_infer_quant.sh b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_infer_quant.sh index f8f3c106199..308723af2ae 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/scripts/run_infer_quant.sh +++ b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_infer_quant.sh @@ -50,5 +50,4 @@ python ${BASEPATH}/../eval.py \ --device_target=$1 \ --dataset_path=$2 \ --checkpoint_path=$3 \ - --quantization_aware=True \ &> infer.log & # dataset val folder path