diff --git a/example/resnet101_imagenet2012/train.py b/example/resnet101_imagenet2012/train.py index 1401a340005..4ebd9e44193 100755 --- a/example/resnet101_imagenet2012/train.py +++ b/example/resnet101_imagenet2012/train.py @@ -48,7 +48,8 @@ args_opt = parser.parse_args() device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id, + enable_auto_mixed_precision=True) if __name__ == '__main__': if not args_opt.do_eval and args_opt.run_distribute: diff --git a/example/resnet50_cifar10/train.py b/example/resnet50_cifar10/train.py index 86a373c2dc8..b37c794822c 100755 --- a/example/resnet50_cifar10/train.py +++ b/example/resnet50_cifar10/train.py @@ -41,8 +41,8 @@ args_opt = parser.parse_args() device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) -context.set_context(device_id=device_id) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id, + enable_auto_mixed_precision=True) if __name__ == '__main__': if not args_opt.do_eval and args_opt.run_distribute: diff --git a/example/resnet50_imagenet2012/train.py b/example/resnet50_imagenet2012/train.py index 2d39f58cae4..9b3fc7573c6 100755 --- a/example/resnet50_imagenet2012/train.py +++ b/example/resnet50_imagenet2012/train.py @@ -43,8 +43,9 @@ args_opt = parser.parse_args() device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) -context.set_context(device_id=device_id) + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id, + enable_auto_mixed_precision=True) if __name__ == '__main__': if not args_opt.do_eval and args_opt.run_distribute: diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index 13d4696693e..add27f187cb 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -72,7 +72,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { enable_mem_reuse_ = true; enable_gpu_summary_ = true; precompile_only_ = false; - auto_mixed_precision_flag_ = true; + auto_mixed_precision_flag_ = false; enable_pynative_infer_ = false; enable_dynamic_mem_pool_ = true; graph_memory_max_size_ = "0";