LSTM network adapt to cpu target.

This commit is contained in:
caojian05 2020-05-28 19:07:26 +08:00
parent 65fe160845
commit a3cee90b25
1 changed files with 19 additions and 1 deletions

View File

@ -17,7 +17,7 @@ import math
import numpy as np import numpy as np
from mindspore import Parameter, Tensor, nn from mindspore import Parameter, Tensor, nn, context, ParameterTuple
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.ops import operations as P from mindspore.ops import operations as P
@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
if bidirectional: if bidirectional:
num_directions = 2 num_directions = 2
if context.get_context("device_target") == "CPU":
h_list = []
c_list = []
for i in range(num_layers):
hi = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='h' + str(i))
h_list.append(hi)
ci = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='c' + str(i))
c_list.append(ci)
h = ParameterTuple(tuple(h_list))
c = ParameterTuple(tuple(c_list))
return h, c
h = Tensor( h = Tensor(
np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
c = Tensor( c = Tensor(