From 678e9f41b8c251cc1cd0c02e7cf744cf11a9776d Mon Sep 17 00:00:00 2001 From: lishanni513 Date: Tue, 11 May 2021 14:35:14 +0800 Subject: [PATCH] Fix bug: Enable Graph Kernel only GPU target --- model_zoo/official/recommend/wide_and_deep/train.py | 5 +++-- model_zoo/official/recommend/wide_and_deep/train_and_eval.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/recommend/wide_and_deep/train.py b/model_zoo/official/recommend/wide_and_deep/train.py index 7366c41cfd2..f10775bbc26 100644 --- a/model_zoo/official/recommend/wide_and_deep/train.py +++ b/model_zoo/official/recommend/wide_and_deep/train.py @@ -88,6 +88,7 @@ def test_train(configure): if __name__ == "__main__": config = WideDeepConfig() config.argparse_init() - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=config.device_target) + _enable_graph_kernel = config.device_target == "GPU" + context.set_context(mode=context.GRAPH_MODE, + enable_graph_kernel=_enable_graph_kernel, device_target=config.device_target) test_train(config) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval.py index 59c797b30dc..df24bff0818 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval.py @@ -105,7 +105,8 @@ def test_train_eval(config): if __name__ == "__main__": wide_deep_config = WideDeepConfig() wide_deep_config.argparse_init() - - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=wide_deep_config.device_target) + _enable_graph_kernel = wide_deep_config.device_target == "GPU" + context.set_context(mode=context.GRAPH_MODE, + enable_graph_kernel=_enable_graph_kernel, device_target=wide_deep_config.device_target) context.set_context(enable_sparse=wide_deep_config.sparse) test_train_eval(wide_deep_config)