rm bool arg of script

This commit is contained in:
chenfei 2020-08-29 10:13:44 +08:00
parent 8f69fb415a
commit 40580cc795
5 changed files with 4 additions and 13 deletions

View File

@ -38,8 +38,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
help='dataset_sink_mode is False or True')
args = parser.parse_args()
if __name__ == "__main__":
@ -67,5 +65,5 @@ if __name__ == "__main__":
raise ValueError("Load param into net fail!")
print("============== Starting Testing ==============")
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
acc = model.eval(ds_eval, dataset_sink_mode=True)
print("============== {} ==============".format(acc))

View File

@ -36,8 +36,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
help='dataset_sink_mode is False or True')
args = parser.parse_args()
if __name__ == "__main__":

View File

@ -41,8 +41,6 @@ parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
help='dataset_sink_mode is False or True')
args = parser.parse_args()
if __name__ == "__main__":
@ -76,5 +74,5 @@ if __name__ == "__main__":
print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)
dataset_sink_mode=True)
print("============== End Training ==============")

View File

@ -32,7 +32,6 @@ parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default=None, help='Run device target')
parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training')
args_opt = parser.parse_args()
if __name__ == '__main__':
@ -51,9 +50,8 @@ if __name__ == '__main__':
# define fusion network
network = mobilenetV2(num_classes=config_device_target.num_classes)
if args_opt.quantization_aware:
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
# define network loss
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')

View File

@ -50,5 +50,4 @@ python ${BASEPATH}/../eval.py \
--device_target=$1 \
--dataset_path=$2 \
--checkpoint_path=$3 \
--quantization_aware=True \
&> infer.log & # dataset val folder path