From d95eec0951287c4dd5b1e4ab8c627c535134a253 Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Wed, 5 Jan 2022 19:43:37 +0800
Subject: [PATCH] use CudnnGRU for GRU GPU version

---
 mindspore/python/mindspore/nn/layer/rnns.py | 36 +++++++++++++++++++--
 1 file changed, 33 insertions(+), 3 deletions(-)

diff --git a/mindspore/python/mindspore/nn/layer/rnns.py b/mindspore/python/mindspore/nn/layer/rnns.py
index 983c19e6b3c..5647b0d835f 100644
--- a/mindspore/python/mindspore/nn/layer/rnns.py
+++ b/mindspore/python/mindspore/nn/layer/rnns.py
@@ -25,6 +25,7 @@ from mindspore.common.parameter import ParameterTuple, Parameter
 from mindspore.nn.cell import Cell
 from mindspore import log as logger
 from mindspore._checkparam import Validator as validator
+from mindspore.ops.operations._rl_inner_ops import CudnnGRU
 from .rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_cell, _lstm_cell
 from .rnn_utils import _Reverse, _ReverseSequence
 
@@ -194,12 +195,41 @@ class _DynamicRNNTanh(_DynamicRNNBase):
         super().__init__(mode)
 
 
-class _DynamicGRUCPUGPU(_DynamicRNNBase):
+class _DynamicGRUCPUGPU(Cell):
     '''Dynamic GRU module on CPU and GPU'''
     def __init__(self):
-        mode = 'GRU'
-        super().__init__(mode)
+        super().__init__()
+        self.concat = P.Concat()
+        self.is_gpu = context.get_context("device_target") == "GPU"
+
+    def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
+        gate_size, input_size = w_ih.shape
+        hidden_size = gate_size // 3
+        if self.is_gpu and seq_length is None:
+            if b_ih is None:
+                weights = self.concat((
+                    w_ih.view(-1, 1, 1),
+                    w_hh.view(-1, 1, 1)
+                ))
+                has_bias = False
+            else:
+                has_bias = True
+                weights = self.concat((
+                    w_ih.view(-1, 1, 1),
+                    w_hh.view(-1, 1, 1),
+                    b_ih.view(-1, 1, 1),
+                    b_hh.view(-1, 1, 1)
+                ))
+            output, h_n, _, _ = CudnnGRU(input_size, hidden_size, 1, has_bias, False, 0.0)(
+                x,
+                h_0.view(1, *h_0.shape),
+                weights.astype(x.dtype)
+            )
+        else:
+            output, h_n = _DynamicRNNBase('GRU')(x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh)
+
+        return output, h_n
 
 
 class _DynamicGRUAscend(Cell):
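
Note: the CudnnGRU op consumes one flattened weight buffer rather than the separate
w_ih/w_hh/b_ih/b_hh parameters, which is why construct() reshapes each piece to
(-1, 1, 1) and concatenates them along axis 0. Below is a minimal sketch of that
packing arithmetic using NumPy stand-ins for the Parameter tensors; it only mirrors
the concat in this patch (gate_size = 3 * hidden_size for a GRU), while the gate
ordering inside the buffer is left to the cuDNN kernel itself.

    import numpy as np

    input_size, hidden_size = 16, 32
    gate_size = 3 * hidden_size  # reset, update, and new gates stacked

    w_ih = np.random.randn(gate_size, input_size).astype(np.float32)
    w_hh = np.random.randn(gate_size, hidden_size).astype(np.float32)
    b_ih = np.random.randn(gate_size).astype(np.float32)
    b_hh = np.random.randn(gate_size).astype(np.float32)

    # Mirror of the patch's concat: every piece is flattened to (-1, 1, 1)
    # and stacked along axis 0 into one contiguous buffer.
    weights = np.concatenate([
        w_ih.reshape(-1, 1, 1),
        w_hh.reshape(-1, 1, 1),
        b_ih.reshape(-1, 1, 1),
        b_hh.reshape(-1, 1, 1),
    ])

    expected = gate_size * (input_size + hidden_size + 2)
    assert weights.shape == (expected, 1, 1)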
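
Note: _DynamicGRUCPUGPU is an internal cell, so user code reaches the new fast path
through the public nn.GRU layer. A hedged usage sketch follows; it assumes a
CUDA-enabled MindSpore build from around this patch's era, and the shapes shown
(batch=4, seq=10) are illustrative only. The CudnnGRU branch triggers only when the
device target is GPU and no seq_length tensor is passed; otherwise execution falls
back to the pure-Python _DynamicRNNBase('GRU') loop.

    import numpy as np
    import mindspore
    from mindspore import Tensor, context, nn

    # Run on GPU so _DynamicGRUCPUGPU takes the CudnnGRU branch.
    context.set_context(device_target="GPU")

    gru = nn.GRU(input_size=16, hidden_size=32, num_layers=1, batch_first=True)
    x = Tensor(np.ones((4, 10, 16)), mindspore.float32)   # (batch, seq, feature)
    h0 = Tensor(np.zeros((1, 4, 32)), mindspore.float32)  # (layers, batch, hidden)

    output, h_n = gru(x, h0)
    print(output.shape)  # (4, 10, 32)
    print(h_n.shape)     # (1, 4, 32)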