Enable acceleration by graph kernel for LSTM model on GPU device.

This commit is contained in:
y00451588 2021-04-26 17:06:14 +08:00
parent d6f58cb765
commit 5dd26694d4
1 changed files with 5 additions and 0 deletions

View File

@ -51,11 +51,16 @@ if __name__ == '__main__':
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
help="Run distribute, default is false.")
parser.add_argument("--enable_graph_kernel", type=str, default="true", choices=["true", "false"],
help="Accelerate by graph kernel, default is true.")
args = parser.parse_args()
_enable_graph_kernel = args.enable_graph_kernel == "true" and args.device_target == "GPU"
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
enable_graph_kernel=_enable_graph_kernel,
device_target=args.device_target)
rank = 0