diff --git a/model_zoo/official/cv/faster_rcnn/train.py b/model_zoo/official/cv/faster_rcnn/train.py index ac8c0415c74..4856f3e6e95 100644 --- a/model_zoo/official/cv/faster_rcnn/train.py +++ b/model_zoo/official/cv/faster_rcnn/train.py @@ -53,6 +53,8 @@ args_opt = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) if __name__ == '__main__': + if args_opt.device_target == "GPU": + context.set_context(enable_graph_kernel=True) if args_opt.run_distribute: if args_opt.device_target == "Ascend": rank = args_opt.rank_id