!17144 Enable acceleration by Graph Kernel for VGG16 only GPU.

From: @xixixian
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-05-29 09:21:40 +08:00 committed by Gitee
commit 43463b9ff1
2 changed files with 6 additions and 3 deletions

View File

@ -119,8 +119,9 @@ def merge_args(args, cloud_args):
def test(cloud_args=None): def test(cloud_args=None):
"""test""" """test"""
args = parse_args(cloud_args) args = parse_args(cloud_args)
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, _enable_graph_kernel = args.device_target == "GPU"
device_target=args.device_target, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=_enable_graph_kernel,
enable_auto_mixed_precision=True, device_target=args.device_target, save_graphs=False)
if os.getenv('DEVICE_ID', "not_set").isdigit() and args.device_target == "Ascend": if os.getenv('DEVICE_ID', "not_set").isdigit() and args.device_target == "Ascend":
context.set_context(device_id=int(os.getenv('DEVICE_ID'))) context.set_context(device_id=int(os.getenv('DEVICE_ID')))

View File

@ -126,7 +126,9 @@ def merge_args(args_opt, cloud_args):
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) _enable_graph_kernel = args.device_target == "GPU"
context.set_context(mode=context.GRAPH_MODE,
enable_graph_kernel=_enable_graph_kernel, device_target=args.device_target)
device_num = int(os.environ.get("DEVICE_NUM", 1)) device_num = int(os.environ.get("DEVICE_NUM", 1))
if args.is_distributed: if args.is_distributed:
if args.device_target == "Ascend": if args.device_target == "Ascend":