diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 5aec089dfc1..df9d3f136a4 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -892,7 +892,7 @@ def get_bprop_dynamic_rnn(self): return bprop -@bprop_getters.register(inner.DynamicGRUV2) +@bprop_getters.register(P.DynamicGRUV2) def get_bprop_dynamic_gru_v2(self): """Grad definition for `DynamicGRUV2` operation.""" dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip, diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index dd3b99fc700..7cf6960db93 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, ResizeBilinear, Sigmoid, SigmoidCrossEntropyWithLogits, - SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, + SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2, SoftmaxCrossEntropyWithLogits, ROIAlign, SparseSoftmaxCrossEntropyWithLogits, Tanh, TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, @@ -119,6 +119,7 @@ __all__ = [ 'Rsqrt', 'Sqrt', 'Square', + 'DynamicGRUV2', 'SquaredDifference', 'Xdivy', 'Xlogy', diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index eb8cb4fdb66..e3b39bcc758 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -529,159 +529,6 @@ class MatrixSetDiag(PrimitiveWithInfer): return assist_shape -class DynamicGRUV2(PrimitiveWithInfer): - r""" - DynamicGRUV2 Operator. - - Args: - direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. - Only 'UNIDIRECTIONAL' is currently supported. - cell_depth (int): An integer identifying the cell depth in the op. Default: 1. - keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. - cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. - num_proj (int): An integer identifying the num proj in the op. Default: 0. - time_major (bool): A bool identifying the time major in the op. Default: True. - activation (str) : A string identifying the type of activation function in the op. Default: 'tanh'. - Only 'tanh' is currently supported. - gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh. - 'zrh' is another option. - reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True. - is_training (bool): A bool identifying is training in the op. Default: True. - - Inputs: - - **x** (Tensor) - Current words. - Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{input_size})`. - The data type must be float16. - - **weight_input** (Tensor) - Input-hidden weight. - Tensor of shape :math:`(\text{input_size}, 3 \times \text{hidden_size})`. - The data type must be float16. - - **weight_hidden** (Tensor) - Hidden-hidden weight. - Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`. - The data type must be float16. - - **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. - The data type must be float16 or float32. - - **bias_hidden** (Tensor) - Hidden-hidden bias. 
Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. - The data type must be float16 or float32. - - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`. - Only `None` is currently supported. - - **init_h** (Tensor) - Hidden state of initial time. - Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`. - The data type must be float16 or float32. - - Outputs: - - **y** (Tensor) - A Tensor of shape :math: - if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`, - if num_proj == 0 `(num_step, batch_size, hidden_size)`. - Has the same data type with input `bais_type`. - - **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. - Has the same data type with input `bais_type`. - - **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. - Has the same data type with input `bais_type`. - - **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. - Has the same data type with input `bais_type`. - - **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. - Has the same data type with input `bais_type`. - - **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. - Has the same data type with input `bais_type`. - - - If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32. - - If `bias_input` is not `None`, `bias_type` is the date type of `bias_input`. - - If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`. - - Examples: - >>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16)) - >>> weight_i = Tensor(np.random.rand(64, 48).astype(np.float16)) - >>> weight_h = Tensor(np.random.rand(16, 48).astype(np.float16)) - >>> bias_i = Tensor(np.random.rand(48).astype(np.float16)) - >>> bias_h = Tensor(np.random.rand(48).astype(np.float16)) - >>> init_h = Tensor(np.random.rand(8, 16).astype(np.float16)) - >>> dynamic_gru_v2 = ops.DynamicGRUV2() - >>> output = dynamic_gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h) - >>> result = output[0].shape - >>> print(result) - (2, 8, 16) - """ - - @prim_attr_register - def __init__(self, - direction='UNIDIRECTIONAL', - cell_depth=1, - keep_prob=1.0, - cell_clip=-1.0, - num_proj=0, - time_major=True, - activation="tanh", - gate_order="rzh", - reset_after=True, - is_training=True): - self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) - self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) - self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) - self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) - self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) - self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name) - self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) - self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) - self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) - self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) - self.add_prim_attr("io_format", "ND") - - def infer_shape(self, x_shape, 
winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape): - validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name) - validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name) - validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name) - - num_step, batch_size, input_size = x_shape - hidden_size = winput_shape[-1] // 3 - if winput_shape[-1] % 3 != 0: - raise ValueError(f"For {self.name}, weight_input_shape[-1] should multiple of 3.") - - self.placeholder_index = [3, 4, 5] - if binput_shape is not None: - validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name) - validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name) - self.placeholder_index.remove(3) - if bhidden_shape is not None: - validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name) - validator.check("bias_hidden_shape", bhidden_shape, - "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name) - self.placeholder_index.remove(4) - if seq_shape is not None: - raise ValueError(f"For {self.name}, seq_shape should be None.") - - validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name) - validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name) - validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name) - validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]", - whidden_shape[-1], Rel.EQ, self.name) - validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name) - validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name) - if self.num_proj > 0: - y_shape = (num_step, batch_size, min(hidden_size, self.num_proj)) - else: - y_shape = (num_step, batch_size, hidden_size) - out_shape = (num_step, batch_size, hidden_size) - self.add_prim_attr("placeholder_index", self.placeholder_index) - return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape - - def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): - validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) - validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) - validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) - validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name) - b_dtype = mstype.float32 - if binput_dtype is not None: - validator.check_tensor_dtype_valid("bias input dtype", binput_dtype, - (mstype.float16, mstype.float32), self.name) - b_dtype = binput_dtype - elif bhidden_dtype is not None: - validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype, - (mstype.float16, mstype.float32), self.name) - b_dtype = bhidden_dtype - - return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype - - class ConfusionMulGrad(PrimitiveWithInfer): """ `output0` is the dot product result of input0 and input1. 
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 2a585f0a833..1c80c1ae8e9 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -6403,32 +6403,18 @@ class DynamicRNN(PrimitiveWithInfer): - **tanhct** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`). Has the same type with input `b`. + Supported Platforms: + ``Ascend`` + Examples: - >>> import mindspore - >>> import mindspore.nn as nn - >>> import numpy as np - >>> from mindspore import Parameter - >>> from mindspore import Tensor - >>> from mindspore.ops import operations as ops - >>> import mindspore.context as context - >>> context.set_context(mode=context.GRAPH_MODE) - >>> class DynamicRNNNet(nn.Cell): - >>> def __init__(self): - >>> super(DynamicRNNNet, self).__init__() - >>> self.dynamic_rnn = ops.DynamicRNN() - >>> - >>> def construct(self, x, w, b, init_h, init_c): - >>> out = self.dynamic_rnn(x, w, b, None, init_h, init_c) - >>> return out - >>> >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16)) >>> w = Tensor(np.random.rand(96, 128).astype(np.float16)) >>> b = Tensor(np.random.rand(128).astype(np.float16)) >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) - >>> net = DynamicRNNNet() - >>> output = net(x, w, b, init_h, init_c) - >>> output[0].shape + >>> dynamic_rnn = ops.DynamicRNN() + >>> output = dynamic_rnn(x, w, b, None, init_h, init_c) + >>> print(output[0].shape) (2, 16, 32) """ @@ -6493,6 +6479,161 @@ class DynamicRNN(PrimitiveWithInfer): return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype +class DynamicGRUV2(PrimitiveWithInfer): + r""" + DynamicGRUV2 Operator. + + Args: + direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. + Only 'UNIDIRECTIONAL' is currently supported. + cell_depth (int): An integer identifying the cell depth in the op. Default: 1. + keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. + cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. + num_proj (int): An integer identifying the num proj in the op. Default: 0. + time_major (bool): A bool identifying the time major in the op. Default: True. + activation (str): A string identifying the type of activation function in the op. Default: 'tanh'. + Only 'tanh' is currently supported. + gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh'. + 'zrh' is another option. + reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True. + is_training (bool): A bool identifying whether the op is in training mode. Default: True. + + Inputs: + - **x** (Tensor) - Current words. + Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{input_size})`. + The data type must be float16. + - **weight_input** (Tensor) - Input-hidden weight. + Tensor of shape :math:`(\text{input_size}, 3 \times \text{hidden_size})`. + The data type must be float16. + - **weight_hidden** (Tensor) - Hidden-hidden weight. + Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`. + The data type must be float16. + - **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. + The data type must be float16 or float32. + - **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
+ The data type must be float16 or float32. + - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`. + Only `None` is currently supported. + - **init_h** (Tensor) - Hidden state of initial time. + Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`. + The data type must be float16 or float32. + + Outputs: + - **y** (Tensor) - A Tensor of shape: + if num_proj > 0, `(num_step, batch_size, min(hidden_size, num_proj))`; + if num_proj == 0, `(num_step, batch_size, hidden_size)`. + Has the same data type as `bias_type`. + - **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. + Has the same data type as `bias_type`. + - **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. + Has the same data type as `bias_type`. + - **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. + Has the same data type as `bias_type`. + - **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. + Has the same data type as `bias_type`. + - **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. + Has the same data type as `bias_type`. + + - If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32. + - If `bias_input` is not `None`, `bias_type` is the data type of `bias_input`. + - If `bias_input` is `None` and `bias_hidden` is not `None`, `bias_type` is the data type of `bias_hidden`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16)) + >>> weight_i = Tensor(np.random.rand(64, 48).astype(np.float16)) + >>> weight_h = Tensor(np.random.rand(16, 48).astype(np.float16)) + >>> bias_i = Tensor(np.random.rand(48).astype(np.float16)) + >>> bias_h = Tensor(np.random.rand(48).astype(np.float16)) + >>> init_h = Tensor(np.random.rand(8, 16).astype(np.float16)) + >>> dynamic_gru_v2 = ops.DynamicGRUV2() + >>> output = dynamic_gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h) + >>> print(output[0].shape) + (2, 8, 16) + """ + + @prim_attr_register + def __init__(self, + direction='UNIDIRECTIONAL', + cell_depth=1, + keep_prob=1.0, + cell_clip=-1.0, + num_proj=0, + time_major=True, + activation="tanh", + gate_order="rzh", + reset_after=True, + is_training=True): + self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) + self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) + self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) + self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) + self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) + self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name) + self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) + self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) + self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) + self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) + self.add_prim_attr("io_format", "ND") + + def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape,
bhidden_shape, seq_shape, h_shape): + validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name) + validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name) + validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name) + + num_step, batch_size, input_size = x_shape + hidden_size = winput_shape[-1] // 3 + if winput_shape[-1] % 3 != 0: + raise ValueError(f"For {self.name}, weight_input_shape[-1] should be a multiple of 3.") + + self.placeholder_index = [3, 4, 5] + if binput_shape is not None: + validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name) + validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name) + self.placeholder_index.remove(3) + if bhidden_shape is not None: + validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name) + validator.check("bias_hidden_shape", bhidden_shape, + "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name) + self.placeholder_index.remove(4) + if seq_shape is not None: + raise ValueError(f"For {self.name}, seq_shape should be None.") + + validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name) + validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name) + validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name) + validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]", + whidden_shape[-1], Rel.EQ, self.name) + validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name) + validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name) + if self.num_proj > 0: + y_shape = (num_step, batch_size, min(hidden_size, self.num_proj)) + else: + y_shape = (num_step, batch_size, hidden_size) + out_shape = (num_step, batch_size, hidden_size) + self.add_prim_attr("placeholder_index", self.placeholder_index) + return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape + + def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): + validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) + validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) + validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) + validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name) + b_dtype = mstype.float32 + if binput_dtype is not None: + validator.check_tensor_dtype_valid("bias input dtype", binput_dtype, + (mstype.float16, mstype.float32), self.name) + b_dtype = binput_dtype + elif bhidden_dtype is not None: + validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype, + (mstype.float16, mstype.float32), self.name) + b_dtype = bhidden_dtype + + return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype + + class InTopK(PrimitiveWithInfer): r""" Determines whether the targets are in the top `k` predictions.
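Reviewer note (not part of the patch): the shape bookkeeping in infer_shape above is easy to lose track of in diff form. The snippet below is a minimal standalone Python sketch of the same rule; the helper name infer_gru_v2_output_shapes is hypothetical and only illustrates the logic that hidden_size comes from weight_input.shape[-1] // 3 and that the first output depends on num_proj.

    def infer_gru_v2_output_shapes(x_shape, winput_shape, num_proj=0):
        # x is (num_step, batch_size, input_size); weight_input is (input_size, 3 * hidden_size).
        num_step, batch_size, _ = x_shape
        if winput_shape[-1] % 3 != 0:
            raise ValueError("weight_input_shape[-1] should be a multiple of 3.")
        hidden_size = winput_shape[-1] // 3
        if num_proj > 0:
            y_shape = (num_step, batch_size, min(hidden_size, num_proj))
        else:
            y_shape = (num_step, batch_size, hidden_size)
        out_shape = (num_step, batch_size, hidden_size)
        # Output order: y, output_h, update, reset, new, hidden_new
        return (y_shape,) + (out_shape,) * 5

    # Mirrors the docstring example: x (2, 8, 64) with weight_input (64, 48) gives hidden_size 16.
    print(infer_gru_v2_output_shapes((2, 8, 64), (64, 48))[0])  # (2, 8, 16)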
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 96e321efeec..6804924a9c9 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -822,7 +822,7 @@ class DynamicGRUV2Net(nn.Cell): def __init__(self): super(DynamicGRUV2Net, self).__init__() - self.dynamic_gru = inner.DynamicGRUV2() + self.dynamic_gru = P.DynamicGRUV2() def construct(self, x, w_i, w_h, b_i, b_h, init_h): return self.dynamic_gru(x, w_i, w_h, b_i, b_h, None, init_h)
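Reviewer note (not part of the patch): with DynamicGRUV2 exported from mindspore.ops.operations and its bprop registered against the public primitive, caller code that previously instantiated the inner operator can switch to the exported one. The sketch below is an assumed usage example mirroring the docstring example and the updated unit test; it presumes an Ascend backend, since that is the only platform listed under Supported Platforms.

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    # DynamicGRUV2 is only listed for Ascend.
    context.set_context(device_target="Ascend")

    x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
    weight_i = Tensor(np.random.rand(64, 48).astype(np.float16))
    weight_h = Tensor(np.random.rand(16, 48).astype(np.float16))
    bias_i = Tensor(np.random.rand(48).astype(np.float16))
    bias_h = Tensor(np.random.rand(48).astype(np.float16))
    init_h = Tensor(np.random.rand(8, 16).astype(np.float16))

    # Previously the test instantiated inner.DynamicGRUV2(); it now uses the public operator.
    dynamic_gru_v2 = P.DynamicGRUV2()
    y, output_h, update, reset, new, hidden_new = dynamic_gru_v2(
        x, weight_i, weight_h, bias_i, bias_h, None, init_h)
    print(y.shape)  # expected (2, 8, 16)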