LSTM network adapt to cpu target.
This commit is contained in:
parent
65fe160845
commit
a3cee90b25
|
@ -17,7 +17,7 @@ import math
|
|||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Parameter, Tensor, nn
|
||||
from mindspore import Parameter, Tensor, nn, context, ParameterTuple
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
|
|||
if bidirectional:
|
||||
num_directions = 2
|
||||
|
||||
if context.get_context("device_target") == "CPU":
|
||||
h_list = []
|
||||
c_list = []
|
||||
for i in range(num_layers):
|
||||
hi = Parameter(initializer(
|
||||
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
|
||||
[num_directions, batch_size, hidden_size]
|
||||
), name='h' + str(i))
|
||||
h_list.append(hi)
|
||||
ci = Parameter(initializer(
|
||||
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
|
||||
[num_directions, batch_size, hidden_size]
|
||||
), name='c' + str(i))
|
||||
c_list.append(ci)
|
||||
h = ParameterTuple(tuple(h_list))
|
||||
c = ParameterTuple(tuple(c_list))
|
||||
return h, c
|
||||
|
||||
h = Tensor(
|
||||
np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
|
||||
c = Tensor(
|
||||
|
|
Loading…
Reference in New Issue