use CudnnGRU for GRU GPU version

lvyufeng 2022-01-05 19:43:37 +08:00
parent ed0400e8d3
commit d95eec0951
1 changed file with 33 additions and 3 deletions


@@ -25,6 +25,7 @@ from mindspore.common.parameter import ParameterTuple, Parameter
 from mindspore.nn.cell import Cell
 from mindspore import log as logger
 from mindspore._checkparam import Validator as validator
+from mindspore.ops.operations._rl_inner_ops import CudnnGRU
 from .rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_cell, _lstm_cell
 from .rnn_utils import _Reverse, _ReverseSequence
@@ -194,12 +195,41 @@ class _DynamicRNNTanh(_DynamicRNNBase):
         super().__init__(mode)
 
-class _DynamicGRUCPUGPU(_DynamicRNNBase):
+class _DynamicGRUCPUGPU(Cell):
     '''Dynamic GRU module on CPU and GPU'''
     def __init__(self):
-        mode = 'GRU'
-        super().__init__(mode)
+        super().__init__()
+        self.concat = P.Concat()
+        self.is_gpu = context.get_context("device_target") == "GPU"
+
+    def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
+        gate_size, input_size = w_ih.shape
+        hidden_size = gate_size // 3
+        if self.is_gpu and seq_length is None:
+            if b_ih is None:
+                weights = self.concat((
+                    w_ih.view(-1, 1, 1),
+                    w_hh.view(-1, 1, 1)
+                ))
+                has_bias = False
+            else:
+                has_bias = True
+                weights = self.concat((
+                    w_ih.view(-1, 1, 1),
+                    w_hh.view(-1, 1, 1),
+                    b_ih.view(-1, 1, 1),
+                    b_hh.view(-1, 1, 1)
+                ))
+            output, h_n, _, _ = CudnnGRU(input_size, hidden_size, 1, has_bias, False, 0.0)(
+                x,
+                h_0.view(1, *h_0.shape),
+                weights.astype(x.dtype)
+            )
+        else:
+            output, h_n = _DynamicRNNBase('GRU')(x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh)
+        return output, h_n
 
 class _DynamicGRUAscend(Cell):
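
The core of the change is the GPU fast path: when running on GPU with no explicit seq_length, the per-gate parameters are flattened to shape (-1, 1, 1) and concatenated in the order w_ih, w_hh, b_ih, b_hh to form the single flat weight tensor that CudnnGRU consumes; otherwise execution falls back to the generic _DynamicRNNBase('GRU') loop. Below is a minimal NumPy-only sketch of that packing step (the sizes and random values are illustrative assumptions, not part of the commit):

import numpy as np

# Illustrative sizes (assumptions, not from the commit).
input_size, hidden_size = 4, 8
gate_size = 3 * hidden_size  # a GRU stacks 3 gates, hence hidden_size = gate_size // 3

w_ih = np.random.randn(gate_size, input_size).astype(np.float32)
w_hh = np.random.randn(gate_size, hidden_size).astype(np.float32)
b_ih = np.random.randn(gate_size).astype(np.float32)
b_hh = np.random.randn(gate_size).astype(np.float32)

# NumPy analogue of self.concat((w_ih.view(-1, 1, 1), ...)) in the diff:
# every parameter is flattened into a (-1, 1, 1) column and the pieces are
# stacked end to end into one flat buffer.
weights = np.concatenate([p.reshape(-1, 1, 1) for p in (w_ih, w_hh, b_ih, b_hh)])

# gate_size*input_size + gate_size*hidden_size + 2*gate_size elements in total.
print(weights.shape)  # (96 + 192 + 48, 1, 1) == (336, 1, 1)

CudnnGRU then receives this buffer cast to x's dtype, together with h_0 reshaped to (1, batch, hidden_size); the remaining positional arguments at the call site (1, has_bias, False, 0.0) read as a single unidirectional layer with no dropout.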