From 10abb68498d1ec3171cc65463aaba2132c10d7a9 Mon Sep 17 00:00:00 2001 From: liuxiao Date: Fri, 8 May 2020 18:53:38 +0800 Subject: [PATCH] add ops CTCLoss --- mindspore/ccsrc/transform/convert.cc | 6 +- mindspore/ccsrc/transform/op_declare.cc | 16 +++++ mindspore/ccsrc/transform/op_declare.h | 4 ++ mindspore/ops/_grad/grad_nn_ops.py | 13 +++++ mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/nn_ops.py | 77 ++++++++++++++++++++++++- tests/ut/python/ops/test_ops.py | 7 +++ 7 files changed, 123 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index 91879fe689f..409fce75e55 100644 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -196,6 +196,8 @@ const char kNameBatchToSpace[] = "BatchToSpace"; const char kNameAtan2[] = "Atan2"; const char kNameApplyRMSProp[] = "ApplyRMSProp"; const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; +const char kNameL2Loss[] = "L2Loss"; +const char kNameCTCLoss[] = "CTCLoss"; // -----------------OpAdapter initialization-------------- std::unordered_map &DfGraphConvertor::get_adpt_map() { @@ -391,7 +393,9 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, {string(kNameAtan2), ADPT_DESC(Atan2)}, {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, - {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}}; + {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}, + {string(kNameL2Loss), ADPT_DESC(L2Loss)}, + {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}}; #ifdef ENABLE_GE adpt_map[string(kNamePrint)] = ADPT_DESC(Print); adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 7af932af0e5..281b3cdfe23 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -1227,6 +1227,22 @@ INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}}; +// L2Loss +INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP; +OUTPUT_MAP(L2Loss) = {{0, OUTPUT_DESC(y)}}; + +// CTCLoss +INPUT_MAP(CTCLoss) = {{1, INPUT_DESC(inputs)}, + {2, INPUT_DESC(labels_indices)}, + {3, INPUT_DESC(labels_values)}, + {4, INPUT_DESC(sequence_length)}}; +ATTR_MAP(CTCLoss) = { + {"preprocess_collapse_repeated", ATTR_DESC(preprocess_collapse_repeated, AnyTraits())}, + {"ctc_merge_repeated", ATTR_DESC(ctc_merge_repeated, AnyTraits())}, + {"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits())}}; +OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}}; + #ifdef ENABLE_GE // Print INPUT_MAP(Print) = EMPTY_INPUT_MAP; diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index c31dcc85099..161c6579cee 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -465,6 +465,10 @@ DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) DECLARE_OP_ADAPTER(ApplyCenteredRMSProp) DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp) +DECLARE_OP_ADAPTER(L2Loss) +DECLARE_OP_USE_OUTPUT(L2Loss) +DECLARE_OP_ADAPTER(CTCLoss) +DECLARE_OP_USE_OUTPUT(CTCLoss) #ifdef ENABLE_GE DECLARE_OP_ADAPTER(Print) DECLARE_OP_USE_DYN_INPUT(Print) diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index c79098faa37..c557301285c 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -668,3 +668,16 @@ def get_bprop_dropout(self): return (dx,) return bprop + + +@bprop_getters.register(P.CTCLoss) +def get_bprop_ctc_loss(self): + """Grad definition for `CTCLoss` operation""" + expand = P.ExpandDims() + + def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout): + grad_loss = out[1] + grad = grad_loss * expand(dout[0], -1) + return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length) + + return bprop diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index d00370f490b..235b593c7af 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -55,7 +55,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, DropoutDoMask, DropoutGrad, Dropout, DropoutGenMask, Flatten, FusedBatchNorm, Gelu, Elu, - GetNext, L2Normalize, LayerNorm, L2Loss, + GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, LogSoftmax, MaxPool, AvgPool, Conv2DBackpropInput, ConfusionMulGrad, @@ -172,6 +172,7 @@ __all__ = [ 'Reciprocal', 'SmoothL1Loss', 'L2Loss', + 'CTCLoss', 'ReduceAll', 'ScalarToArray', 'ScalarToTensor', diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index efb77ad3e32..254cc641b0c 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1564,7 +1564,7 @@ class L2Loss(PrimitiveWithInfer): def infer_dtype(self, x_type): validator.check_subclass("x_type", x_type, mstype.tensor, self.name) - valid_types = [mstype.float16, mstype.float32, mstype.double] + valid_types = [mstype.float16, mstype.float32] validator.check_tensor_type_same({'x_type': x_type}, valid_types, self.name) return x_type @@ -2871,3 +2871,78 @@ class DropoutGrad(PrimitiveWithInfer): valid_types = (mstype.float16, mstype.float32) validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name) return dy_dtype + + +class CTCLoss(PrimitiveWithInfer): + """ + Calculates the CTC(Connectionist Temporal Classification) loss. Also calculates the gradient. + + Args: + preprocess_collapse_repeated (bool): If True, repeated labels are collapsed prior to the CTC calculation. + Default: False. + ctc_merge_repeated (bool): If False, during CTC calculation, repeated non-blank labels will not be merged + and are interpreted as individual labels. This is a simplfied version if CTC. + Default: True. + ignore_longer_outputs_than_inputs (bool): If True, sequences with longer outputs than inputs will be ignored. + Default: False. + + Inputs: + - **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is + :math:`(max_time, batch_size, num_class)`. `num_class` should be `num_labels + 1` classes, `num_labels` + indicates the number of actual labels. Blank labels are reserved. + - **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]` + stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2. + - **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The + type must be int32. `labels_values[i]` must in the range of `[0, num_class)`. + - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`. + The type must be int32. Each value in the tensor should not greater than `max_time`. + + Outputs: + - **loss** (Tensor) - A tensor containing log-probabilities, the shape is :math:`(batch_size)`. Has the same + type with `inputs`. + - **gradient** (Tensor) - The gradient of `loss`. Has the same type and shape with `inputs`. + + Examples: + >>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32) + >>> labels_indices = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int64) + >>> labels_values = Tensor(np.array([2, 2]), mindspore.int32) + >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32) + >>> ctc_loss = P.CTCloss() + >>> output = ctc_loss(inputs, labels_indices, labels_values, sequence_length) + """ + @prim_attr_register + def __init__(self, preprocess_collapse_repeated=False, ctc_merge_repeated=False, + ignore_longer_outputs_than_inputs=False): + self.init_prim_io_names(inputs=["inputs", "labels_indices", "labels_values", "sequence_length"], + outputs=["loss", "gradient"]) + validator.check_value_type("preprocess_collapse_repeated", preprocess_collapse_repeated, [bool], self.name) + self.preprocess_collapse_repeated_ = preprocess_collapse_repeated + self.ctc_merge_repeated_ = validator.check_value_type("ctc_merge_repeated", ctc_merge_repeated, + [bool], self.name) + validator.check_value_type("ignore_longer_outputs_than_inputs", + ignore_longer_outputs_than_inputs, [bool], self.name) + self.ignore_longer_outputs_than_inputs_ = ignore_longer_outputs_than_inputs + + def infer_shape(self, inputs, labels_indices, labels_values, sequence_length): + validator.check_integer("inputs rank", len(inputs), 3, Rel.EQ, self.name) + validator.check_integer("labels_indices rank", len(labels_indices), 2, Rel.EQ, self.name) + validator.check_integer("labels_values rank", len(labels_values), 1, Rel.EQ, self.name) + validator.check_integer("sequence_length rank", len(sequence_length), 1, Rel.EQ, self.name) + validator.check('labels_indices size', labels_indices[0], 'labels_values size', + labels_values[0], Rel.EQ, self.name) + validator.check('inputs batch_size', inputs[1], 'sequence_length batch_size', + sequence_length[0], Rel.EQ, self.name) + batch_size = [] + batch_size.append(inputs[1]) + return batch_size, inputs + + def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length): + validator.check_subclass("inputs_dtype", inputs, mstype.tensor, self.name) + validator.check_subclass("labels_indices_dtype", labels_indices, mstype.tensor, self.name) + validator.check_subclass("labels_values_dtype", labels_values, mstype.tensor, self.name) + validator.check_subclass("sequence_length_dtype", sequence_length, mstype.tensor, self.name) + validator.check_tensor_type_same({"inputs_dtype": inputs}, [mstype.float32, mstype.double], self.name) + validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name) + validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name) + validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name) + return inputs, inputs diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 938e943914d..c7d6cd12f35 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -909,6 +909,13 @@ test_case_nn_ops = [ 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], 'desc_bprop': [3, 3], 'skip': ['backward']}), + ('CTCLoss', { + 'block': P.CTCLoss(), + 'desc_inputs': [Tensor(np.ones([6, 4, 6]).astype(np.float32)), + Tensor(np.array([[0, 1], [1, 0], [2, 3], [3, 2]]).astype(np.int64)), + Tensor(np.array([1, 2, 3, 4]).astype(np.int32)), + Tensor(np.array([6, 6, 6, 6]).astype(np.int32))], + 'desc_bprop': [[4], [6, 4, 6]]}), ('L2Loss_1', { 'block': P.L2Loss(), 'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],