From 2083307de4aeab575c57a7d807489be0425bfe0f Mon Sep 17 00:00:00 2001 From: chenfei Date: Thu, 13 Aug 2020 17:02:21 +0800 Subject: [PATCH] fix bug of mobilenetv2 quant eval --- model_zoo/official/cv/mobilenetv2_quant/eval.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/cv/mobilenetv2_quant/eval.py b/model_zoo/official/cv/mobilenetv2_quant/eval.py index 0976abbe99b..cfa873a98d3 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/eval.py +++ b/model_zoo/official/cv/mobilenetv2_quant/eval.py @@ -25,7 +25,8 @@ from mindspore.train.quant import quant from src.mobilenetV2 import mobilenetV2 from src.dataset import create_dataset -from src.config import config_ascend +from src.config import config_ascend_quant +from src.config import config_gpu_quant parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') @@ -36,12 +37,15 @@ args_opt = parser.parse_args() if __name__ == '__main__': config_device_target = None + device_id = int(os.getenv('DEVICE_ID')) if args_opt.device_target == "Ascend": - config_device_target = config_ascend - device_id = int(os.getenv('DEVICE_ID')) + config_device_target = config_ascend_quant context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) - + elif args_opt.device_target == "GPU": + config_device_target = config_gpu_quant + context.set_context(mode=context.GRAPH_MODE, device_target="GPU", + device_id=device_id, save_graphs=False) else: raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))