forked from mindspore-Ecosystem/mindspore
!16223 Enable acceleration by Graph Kernel for Wide&Deep only GPU.
From: @lishanni513 Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:
commit
507bf8c5c8
|
@ -88,6 +88,7 @@ def test_train(configure):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config = WideDeepConfig()
|
config = WideDeepConfig()
|
||||||
config.argparse_init()
|
config.argparse_init()
|
||||||
|
_enable_graph_kernel = config.device_target == "GPU"
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=config.device_target)
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
enable_graph_kernel=_enable_graph_kernel, device_target=config.device_target)
|
||||||
test_train(config)
|
test_train(config)
|
||||||
|
|
|
@ -105,7 +105,8 @@ def test_train_eval(config):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
wide_deep_config = WideDeepConfig()
|
wide_deep_config = WideDeepConfig()
|
||||||
wide_deep_config.argparse_init()
|
wide_deep_config.argparse_init()
|
||||||
|
_enable_graph_kernel = wide_deep_config.device_target == "GPU"
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=wide_deep_config.device_target)
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
enable_graph_kernel=_enable_graph_kernel, device_target=wide_deep_config.device_target)
|
||||||
context.set_context(enable_sparse=wide_deep_config.sparse)
|
context.set_context(enable_sparse=wide_deep_config.sparse)
|
||||||
test_train_eval(wide_deep_config)
|
test_train_eval(wide_deep_config)
|
||||||
|
|
Loading…
Reference in New Issue