forked from mindspore-Ecosystem/mindspore
use CudnnGRU for GRU GPU version
This commit is contained in:
parent ed0400e8d3
commit d95eec0951
@@ -25,6 +25,7 @@ from mindspore.common.parameter import ParameterTuple, Parameter
 from mindspore.nn.cell import Cell
 from mindspore import log as logger
 from mindspore._checkparam import Validator as validator
+from mindspore.ops.operations._rl_inner_ops import CudnnGRU
 from .rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_cell, _lstm_cell
 from .rnn_utils import _Reverse, _ReverseSequence
 
@@ -194,12 +195,41 @@ class _DynamicRNNTanh(_DynamicRNNBase):
         super().__init__(mode)
 
 
-class _DynamicGRUCPUGPU(_DynamicRNNBase):
+class _DynamicGRUCPUGPU(Cell):
     '''Dynamic GRU module on CPU and GPU'''
     def __init__(self):
-        mode = 'GRU'
-        super().__init__(mode)
+        super().__init__()
+        self.concat = P.Concat()
+        self.is_gpu = context.get_context("device_target") == "GPU"
+
+    def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
+        gate_size, input_size = w_ih.shape
+        hidden_size = gate_size // 3
+        if self.is_gpu and seq_length is None:
+            if b_ih is None:
+                weights = self.concat((
+                    w_ih.view(-1, 1, 1),
+                    w_hh.view(-1, 1, 1)
+                ))
+                has_bias = False
+            else:
+                has_bias = True
+                weights = self.concat((
+                    w_ih.view(-1, 1, 1),
+                    w_hh.view(-1, 1, 1),
+                    b_ih.view(-1, 1, 1),
+                    b_hh.view(-1, 1, 1)
+                ))
+            output, h_n, _, _ = CudnnGRU(input_size, hidden_size, 1, has_bias, False, 0.0)(
+                x,
+                h_0.view(1, *h_0.shape),
+                weights.astype(x.dtype)
+            )
+        else:
+            output, h_n = _DynamicRNNBase('GRU')(x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh)
+
+        return output, h_n
 
 
 class _DynamicGRUAscend(Cell):
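For context, the new construct() flattens every GRU parameter into a single buffer before invoking the fused kernel, since the cuDNN RNN interface takes one contiguous weight blob rather than separate per-gate matrices. Below is a minimal shape sketch of that layout in plain NumPy, with hypothetical sizes chosen only for illustration (the real cell operates on MindSpore tensors):

import numpy as np

# Hypothetical sizes for illustration only.
input_size, hidden_size = 16, 32
gate_size = 3 * hidden_size                 # GRU packs 3 gates per parameter

w_ih = np.zeros((gate_size, input_size))    # input-hidden weights, (3H, I)
w_hh = np.zeros((gate_size, hidden_size))   # hidden-hidden weights, (3H, H)
b_ih = np.zeros(gate_size)                  # input-hidden bias, (3H,)
b_hh = np.zeros(gate_size)                  # hidden-hidden bias, (3H,)

# Mirrors the construct() logic: each parameter is viewed as (-1, 1, 1) and
# concatenated along axis 0 into one flat buffer for the cuDNN kernel.
weights = np.concatenate([p.reshape(-1, 1, 1) for p in (w_ih, w_hh, b_ih, b_hh)])
assert weights.shape == (gate_size * (input_size + hidden_size + 2), 1, 1)

Note the dispatch in the diff: the fused CudnnGRU path is taken only on GPU and only when seq_length is None; every other case falls back to the step-by-step _DynamicRNNBase('GRU') loop, so both branches accept the same unflattened parameters at the call site.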