forked from mindspore-Ecosystem/mindspore
open graph_kernel for wide & deep ps mode
This commit is contained in:
parent
7820e12396
commit
f3f4833f61
|
@ -34,6 +34,7 @@ from src.config import WideDeepConfig
|
|||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
def get_wide_deep_net(config):
|
||||
"""
|
||||
Get network of wide&deep model.
|
||||
|
@ -54,6 +55,7 @@ class ModelBuilder():
|
|||
"""
|
||||
ModelBuilder
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
@ -162,4 +164,6 @@ if __name__ == "__main__":
|
|||
|
||||
if wide_deep_config.sparse:
|
||||
context.set_context(enable_sparse=True)
|
||||
if wide_deep_config.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
train_and_eval(wide_deep_config)
|
||||
|
|
|
@ -47,6 +47,7 @@ class ModelBuilder():
|
|||
"""
|
||||
ModelBuilder
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
@ -124,6 +125,8 @@ if __name__ == "__main__":
|
|||
wide_deep_config.sparse = True
|
||||
if wide_deep_config.sparse:
|
||||
context.set_context(enable_sparse=True)
|
||||
if wide_deep_config.device_target == "GPU":
|
||||
context.set_context(enable_graph_kernel=True)
|
||||
context.set_ps_context(enable_ps=True)
|
||||
|
||||
train_and_eval(wide_deep_config)
|
||||
|
|
Loading…
Reference in New Issue