open graph_kernel for wide & deep ps mode

This commit is contained in:
zengzitao 2021-05-18 14:38:07 +08:00
parent 7820e12396
commit f3f4833f61
2 changed files with 7 additions and 0 deletions

View File

@ -34,6 +34,7 @@ from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def get_wide_deep_net(config):
"""
Get network of wide&deep model.
@ -54,6 +55,7 @@ class ModelBuilder():
"""
ModelBuilder
"""
def __init__(self):
pass
@ -162,4 +164,6 @@ if __name__ == "__main__":
if wide_deep_config.sparse:
context.set_context(enable_sparse=True)
if wide_deep_config.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
train_and_eval(wide_deep_config)

View File

@ -47,6 +47,7 @@ class ModelBuilder():
"""
ModelBuilder
"""
def __init__(self):
pass
@ -124,6 +125,8 @@ if __name__ == "__main__":
wide_deep_config.sparse = True
if wide_deep_config.sparse:
context.set_context(enable_sparse=True)
if wide_deep_config.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
context.set_ps_context(enable_ps=True)
train_and_eval(wide_deep_config)