From a3cee90b2597b40b6acaf86910d51aed61d02969 Mon Sep 17 00:00:00 2001 From: caojian05 Date: Thu, 28 May 2020 19:07:26 +0800 Subject: [PATCH] Adapt LSTM network to CPU target. --- mindspore/model_zoo/lstm.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/mindspore/model_zoo/lstm.py b/mindspore/model_zoo/lstm.py index 35fe6743030..7368bbf8e5a 100644 --- a/mindspore/model_zoo/lstm.py +++ b/mindspore/model_zoo/lstm.py @@ -17,7 +17,7 @@ import math import numpy as np -from mindspore import Parameter, Tensor, nn +from mindspore import Parameter, Tensor, nn, context, ParameterTuple from mindspore.common.initializer import initializer from mindspore.ops import operations as P @@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): if bidirectional: num_directions = 2 + if context.get_context("device_target") == "CPU": + h_list = [] + c_list = [] + for i in range(num_layers): + hi = Parameter(initializer( + Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)), + [num_directions, batch_size, hidden_size] + ), name='h' + str(i)) + h_list.append(hi) + ci = Parameter(initializer( + Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)), + [num_directions, batch_size, hidden_size] + ), name='c' + str(i)) + c_list.append(ci) + h = ParameterTuple(tuple(h_list)) + c = ParameterTuple(tuple(c_list)) + return h, c + h = Tensor( np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) c = Tensor(