diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py index 19fd171d309..0af156d3e10 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py @@ -34,6 +34,7 @@ from src.config import WideDeepConfig sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + def get_wide_deep_net(config): """ Get network of wide&deep model. @@ -54,6 +55,7 @@ class ModelBuilder(): """ ModelBuilder """ + def __init__(self): pass @@ -162,4 +164,6 @@ if __name__ == "__main__": if wide_deep_config.sparse: context.set_context(enable_sparse=True) + if wide_deep_config.device_target == "GPU": + context.set_context(enable_graph_kernel=True) train_and_eval(wide_deep_config) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py index f051fb055a4..2d7376f8ddb 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py @@ -47,6 +47,7 @@ class ModelBuilder(): """ ModelBuilder """ + def __init__(self): pass @@ -124,6 +125,8 @@ if __name__ == "__main__": wide_deep_config.sparse = True if wide_deep_config.sparse: context.set_context(enable_sparse=True) + if wide_deep_config.device_target == "GPU": + context.set_context(enable_graph_kernel=True) context.set_ps_context(enable_ps=True) train_and_eval(wide_deep_config)