Remove the call to GRUV2 in nn.GRU

This commit is contained in:
liuluobin 2022-11-22 15:57:38 +08:00
parent ac3cb5a914
commit 820b19fa3f
2 changed files with 61 additions and 92 deletions

View File

@ -28,7 +28,7 @@ from mindspore.common.parameter import ParameterTuple, Parameter
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore import log as logger from mindspore import log as logger
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore.ops.operations._rl_inner_ops import CudnnGRU, GRUV2 from mindspore.ops.operations._rl_inner_ops import CudnnGRU
from mindspore.nn.layer.rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_cell, _lstm_cell from mindspore.nn.layer.rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_cell, _lstm_cell
from mindspore.nn.layer.rnn_utils import _Reverse, _ReverseSequence from mindspore.nn.layer.rnn_utils import _Reverse, _ReverseSequence
@ -237,19 +237,15 @@ class _DynamicGRUCPUGPU(Cell):
b_ih.view(-1, 1, 1), b_ih.view(-1, 1, 1),
b_hh.view(-1, 1, 1) b_hh.view(-1, 1, 1)
)) ))
if seq_length is None: output, h_n, _, _ = CudnnGRU(input_size, hidden_size, 1, has_bias, False, 0.0)(
output, h_n, _, _ = CudnnGRU(input_size, hidden_size, 1, has_bias, False, 0.0)( x,
x, h_0.view(1, *h_0.shape),
h_0.view(1, *h_0.shape), weights.astype(x.dtype)
weights.astype(x.dtype) )
) if seq_length is not None:
else: h_n = get_hidden(output, seq_length)
output, h_n, _, _ = GRUV2(input_size, hidden_size, 1, has_bias, False, 0.0, self.training)( mask = sequence_mask(seq_length, x.shape[0])
x, output = select_by_mask(output, mask)
h_0.view(1, *h_0.shape),
weights.astype(x.dtype),
seq_length
)
else: else:
output, h_n = _DynamicRNNBase('GRU')(x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh) output, h_n = _DynamicRNNBase('GRU')(x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh)

View File

@ -18,28 +18,59 @@ import pytest
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops.operations._rl_inner_ops import GRUV2
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, input_size, hidden_size, num_layers, bidirectional): def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional):
super().__init__() super().__init__()
self.gru_nn = nn.GRU(input_size, hidden_size, num_layers, True, True, 0.0, bidirectional) self.gru_nn = nn.GRU(input_size, hidden_size, num_layers, has_bias, False, 0.0, bidirectional)
def construct(self, x, h, seq_lengths): def construct(self, x, h, seq_lengths):
_, hy = self.gru_nn(x, h, seq_lengths) output, hy = self.gru_nn(x, h, seq_lengths)
return hy return output, hy
class NetGruV2(nn.Cell):
    """Cell wrapper that calls the GRUV2 primitive directly with a fixed weight tensor.

    Args:
        input_size: forwarded to GRUV2 as its first positional argument.
        hidden_size: forwarded to GRUV2 as its second positional argument.
        num_layers: forwarded to GRUV2 as its third positional argument.
        has_bias: forwarded to GRUV2 as its fourth positional argument.
        weights: weight tensor kept on the cell and supplied on every call.
        is_train: forwarded to GRUV2 as its last positional argument.
    """

    def __init__(self, input_size, hidden_size, num_layers, has_bias, weights, is_train):
        super().__init__()
        # NOTE(review): the literals False and 0.0 are presumably the
        # bidirectional flag and the dropout rate -- confirm against the
        # GRUV2 signature.
        self.gruv2 = GRUV2(input_size, hidden_size, num_layers, has_bias, False, 0.0, is_train)
        self.weights = weights

    def construct(self, x, h_0, seq_len):
        """Run GRUV2 and return its first two results as (output, h_n).

        The stored weights are cast to the dtype of ``x`` before the call;
        the last two values returned by GRUV2 are discarded.
        """
        typed_weights = self.weights.astype(x.dtype)
        output, h_n, _, _ = self.gruv2(x, h_0, typed_weights, seq_len)
        return output, h_n
def get_weights_from_gru(gru_nn, has_bias):
    """Pack the first-layer parameters of ``gru_nn`` into one concatenated tensor.

    Each parameter is reshaped with ``view(-1, 1, 1)`` and the pieces are
    concatenated in the order w_ih, w_hh and, only when ``has_bias`` is
    true, b_ih, b_hh.

    Args:
        gru_nn: object exposing ``w_ih_list`` and ``w_hh_list`` (and
            ``b_ih_list`` / ``b_hh_list`` when ``has_bias`` is true); only
            index 0 of each list is read.
        has_bias: whether the bias parameters are appended after the weights.

    Returns:
        The result of ``ops.concat`` over the reshaped pieces.
    """
    pieces = [gru_nn.w_ih_list[0], gru_nn.w_hh_list[0]]
    if has_bias:
        # Bias lists are only touched when requested, as in the two-branch form.
        pieces.extend([gru_nn.b_ih_list[0], gru_nn.b_hh_list[0]])
    return ops.concat(tuple(piece.view(-1, 1, 1) for piece in pieces))
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("is_train", [True, False]) @pytest.mark.parametrize("is_train", [True, False])
def test_gruv2_op_float32_1(is_train): @pytest.mark.parametrize("dtype", [ms.float16, ms.float32])
def test_gruv2_op(has_bias, is_train, dtype):
""" """
Feature: test GRUV2 with using float32 Feature: test GRUV2
Description: num_layers=1, bidirectional=False Description: num_layers=1, bidirectional=False
Expectation: the result match with expect. Expectation: the result is equal to nn.GRU.
""" """
batch_size = 3 batch_size = 3
max_seq_length = 5 max_seq_length = 5
@ -50,76 +81,17 @@ def test_gruv2_op_float32_1(is_train):
num_directions = 2 if bidirectional else 1 num_directions = 2 if bidirectional else 1
seq_lengths = Tensor([5, 3, 2], ms.int32) seq_lengths = Tensor([5, 3, 2], ms.int32)
np.random.seed(1) x = Tensor(np.random.normal(0.0, 1.0, (max_seq_length, batch_size, input_size)), dtype)
x = Tensor(np.random.normal(0.0, 1.0, (batch_size, max_seq_length, input_size)), ms.float32) h0 = Tensor(np.random.normal(0.0, 1.0, (num_layers * num_directions, batch_size, hidden_size)), dtype)
h0 = Tensor(np.random.normal(0.0, 1.0, (num_layers * num_directions, batch_size, hidden_size)), ms.float32) net = Net(input_size, hidden_size, num_layers, has_bias, bidirectional).set_train(is_train)
net = Net(input_size, hidden_size, num_layers, bidirectional) weights = get_weights_from_gru(net.gru_nn, has_bias)
net.set_train(is_train) gruv2_net = NetGruV2(input_size, hidden_size, num_layers, has_bias, weights, is_train)
me_hy = net(x, h0, seq_lengths).asnumpy() expect_output, expect_hy = net(x, h0, seq_lengths)
expect_hy = np.array([[[0.23690273, -0.42312058, 0.2012992], me_output, me_hy = gruv2_net(x, h0, seq_lengths)
[0.5544311, -0.28084755, -0.03353014],
[0.12614538, -0.26933774, 0.11727069]]], np.float32)
assert np.allclose(me_hy, expect_hy, 0.0001, 0.0001)
rtol, atol = (1e-3, 1e-3) if dtype == ms.float16 else (1e-4, 1e-4)
@pytest.mark.level0 assert np.allclose(me_output.asnumpy(), expect_output.asnumpy(), rtol, atol)
@pytest.mark.platform_x86_gpu_training assert np.allclose(me_hy.asnumpy(), expect_hy.asnumpy(), rtol, atol)
@pytest.mark.env_onecard
@pytest.mark.parametrize("is_train", [True, False])
def test_gruv2_op_float32_2(is_train):
"""
Feature: test GRUV2 with using float32
Description: num_layers=3, bidirectional=True
Expectation: the result match with expect.
"""
batch_size = 3
max_seq_length = 5
input_size = 10
hidden_size = 1
num_layers = 1
bidirectional = True
num_directions = 2 if bidirectional else 1
seq_lengths = Tensor([5, 3, 2], ms.int32)
np.random.seed(4)
x = Tensor(np.random.normal(0.0, 1.0, (batch_size, max_seq_length, input_size)), ms.float32)
h0 = Tensor(np.random.normal(0.0, 1.0, (num_layers * num_directions, batch_size, hidden_size)), ms.float32)
net = Net(input_size, hidden_size, num_layers, bidirectional)
net.set_train(is_train)
me_hy = net(x, h0, seq_lengths).asnumpy()
expect_hy = np.array([[[0.32341897], [0.83405745], [0.22347865]], [[-0.40905663], [-0.8938196], [-0.8207804]]],
np.float32)
assert np.allclose(me_hy, expect_hy, 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("is_train", [True, False])
def test_gruv2_op_float16(is_train):
"""
Feature: test GRUV2 with using float16
Description: num_layers=1, bidirectional=False
Expectation: the result match with expect.
"""
batch_size = 3
max_seq_length = 5
input_size = 10
hidden_size = 3
num_layers = 1
bidirectional = False
num_directions = 2 if bidirectional else 1
seq_lengths = Tensor([5, 3, 2], ms.int32)
np.random.seed(1)
x = Tensor(np.random.normal(0.0, 1.0, (batch_size, max_seq_length, input_size)), ms.float16)
h0 = Tensor(np.random.normal(0.0, 1.0, (num_layers * num_directions, batch_size, hidden_size)), ms.float16)
net = Net(input_size, hidden_size, num_layers, bidirectional)
net.set_train(is_train)
me_hy = net(x, h0, seq_lengths).asnumpy()
expect_hy = np.array([[[0.2368, -0.4233, 0.2017], [0.5547, -0.281, -0.03323], [0.1263, -0.2693, 0.1175]]],
np.float16)
assert np.allclose(me_hy, expect_hy, 0.001, 0.001)
@pytest.mark.level0 @pytest.mark.level0
@ -141,8 +113,9 @@ def test_gruv2_op_float64_exception():
seq_lengths = Tensor([5, 3, 2], ms.int32) seq_lengths = Tensor([5, 3, 2], ms.int32)
np.random.seed(1) np.random.seed(1)
x = Tensor(np.random.normal(0.0, 1.0, (batch_size, max_seq_length, input_size)), ms.float64) x = Tensor(np.random.normal(0.0, 1.0, (max_seq_length, batch_size, input_size)), ms.float64)
h0 = Tensor(np.random.normal(0.0, 1.0, (num_layers * num_directions, batch_size, hidden_size)), ms.float64) h0 = Tensor(np.random.normal(0.0, 1.0, (num_layers * num_directions, batch_size, hidden_size)), ms.float64)
net = Net(input_size, hidden_size, num_layers, bidirectional) weights = Tensor(np.random.normal(0.0, 1.0, (3 * hidden_size * (input_size + hidden_size), 1, 1)), ms.float64)
net = NetGruV2(input_size, hidden_size, num_layers, False, weights, False)
with pytest.raises(TypeError): with pytest.raises(TypeError):
net(x, h0, seq_lengths) net(x, h0, seq_lengths)