forked from mindspore-Ecosystem/mindspore
LSTM network adapt to cpu target.
This commit is contained in:
parent
65fe160845
commit
a3cee90b25
|
@ -17,7 +17,7 @@ import math
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import Parameter, Tensor, nn
|
from mindspore import Parameter, Tensor, nn, context, ParameterTuple
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
|
||||||
if bidirectional:
|
if bidirectional:
|
||||||
num_directions = 2
|
num_directions = 2
|
||||||
|
|
||||||
|
if context.get_context("device_target") == "CPU":
|
||||||
|
h_list = []
|
||||||
|
c_list = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
hi = Parameter(initializer(
|
||||||
|
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
|
||||||
|
[num_directions, batch_size, hidden_size]
|
||||||
|
), name='h' + str(i))
|
||||||
|
h_list.append(hi)
|
||||||
|
ci = Parameter(initializer(
|
||||||
|
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
|
||||||
|
[num_directions, batch_size, hidden_size]
|
||||||
|
), name='c' + str(i))
|
||||||
|
c_list.append(ci)
|
||||||
|
h = ParameterTuple(tuple(h_list))
|
||||||
|
c = ParameterTuple(tuple(c_list))
|
||||||
|
return h, c
|
||||||
|
|
||||||
h = Tensor(
|
h = Tensor(
|
||||||
np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
|
np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
|
||||||
c = Tensor(
|
c = Tensor(
|
||||||
|
|
Loading…
Reference in New Issue