diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h index 99f67f6dc14..2e5776bf4f9 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h @@ -192,6 +192,8 @@ constexpr const char kNameAscendDequant[] = "Dequant"; constexpr const char kNameReverseSequence[] = "ReverseSequence"; constexpr const char kNameEditDistance[] = "EditDistance"; constexpr const char kNameCase[] = "Case"; +constexpr const char kNameAssert[] = "Assert"; +constexpr const char kNameCTCGreedyDecoder[] = "CTCGreedyDecoder"; class OpAdapterMap { public: diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/ctc_ops_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/ctc_ops_declare.cc index e28c25da299..d718a2ed15d 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/ctc_ops_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/ctc_ops_declare.cc @@ -28,4 +28,13 @@ ATTR_MAP(CTCLoss) = { {"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits())}}; OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}}; REG_ADPT_DESC(CTCLoss, kNameCTCLoss, ADPT_DESC(CTCLoss)) + +// CTCGreedyDecoder +INPUT_MAP(CTCGreedyDecoder) = {{1, INPUT_DESC(inputs)}, {2, INPUT_DESC(sequence_length)}}; +ATTR_MAP(CTCGreedyDecoder) = {{"merge_repeated", ATTR_DESC(merge_repeated, AnyTraits())}}; +OUTPUT_MAP(CTCGreedyDecoder) = {{0, OUTPUT_DESC(decoded_indices)}, + {1, OUTPUT_DESC(decoded_values)}, + {2, OUTPUT_DESC(decoded_shape)}, + {3, OUTPUT_DESC(log_probability)}}; +REG_ADPT_DESC(CTCGreedyDecoder, kNameCTCGreedyDecoder, ADPT_DESC(CTCGreedyDecoder)) } // namespace mindspore::transform diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/ctc_ops_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/ctc_ops_declare.h index 5a7bc6b6d96..074b19592a3 100644 --- 
a/mindspore/ccsrc/transform/graph_ir/op_declare/ctc_ops_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/ctc_ops_declare.h @@ -25,5 +25,8 @@ namespace mindspore::transform { DECLARE_OP_ADAPTER(CTCLoss) DECLARE_OP_USE_OUTPUT(CTCLoss) + +DECLARE_OP_ADAPTER(CTCGreedyDecoder) +DECLARE_OP_USE_OUTPUT(CTCGreedyDecoder) } // namespace mindspore::transform #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_CTC_OPS_DECLARE_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/logging_ops_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/logging_ops_declare.cc index e73a127ca21..8b7757f85db 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/logging_ops_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/logging_ops_declare.cc @@ -23,5 +23,10 @@ INPUT_MAP(Print) = EMPTY_INPUT_MAP; DYN_INPUT_MAP(Print) = {{1, DYN_INPUT_DESC(x)}}; ATTR_MAP(Print) = EMPTY_ATTR_MAP; REG_ADPT_DESC(Print, kNamePrint, ADPT_DESC(Print)) + +INPUT_MAP(Assert) = {{1, INPUT_DESC(input_condition)}}; +DYN_INPUT_MAP(Assert) = {{2, DYN_INPUT_DESC(input_data)}}; +ATTR_MAP(Assert) = {{"summarize", ATTR_DESC(summarize, AnyTraits())}}; +REG_ADPT_DESC(Assert, kNameAssert, ADPT_DESC(Assert)) #endif } // namespace mindspore::transform diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/logging_ops_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/logging_ops_declare.h index 524da0079b4..5419edf18a3 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/logging_ops_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/logging_ops_declare.h @@ -26,6 +26,9 @@ namespace mindspore::transform { #ifdef ENABLE_GE DECLARE_OP_ADAPTER(Print) DECLARE_OP_USE_DYN_INPUT(Print) + +DECLARE_OP_ADAPTER(Assert) +DECLARE_OP_USE_DYN_INPUT(Assert) #endif } // namespace mindspore::transform #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_LOGGING_OPS_DECLARE_H_ diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py 
class Assert(PrimitiveWithInfer):
    """
    Asserts that the given condition is true.

    If the input condition evaluates to false, prints a summary of the tensors in
    `input_data`.

    Args:
        summarize (int): Number of entries of each tensor to print. Default: 3.

    Inputs:
        - **condition** [Union[Tensor[bool], bool]] - The condition to evaluate.
        - **input_data** (Union(tuple[Tensor], list[Tensor])) - The tensors to print out
          when the condition is false.

    Examples:
        >>> class AssertDemo(nn.Cell):
        >>>     def __init__(self):
        >>>         super(AssertDemo, self).__init__()
        >>>         # NOTE: the attribute must not be named `assert` — that is a Python keyword.
        >>>         self.assert_op = P.Assert(summarize=10)
        >>>         self.add = P.TensorAdd()
        >>>
        >>>     def construct(self, x, y):
        >>>         data = self.add(x, y)
        >>>         self.assert_op(True, [data])
        >>>         return data
    """

    @prim_attr_register
    def __init__(self, summarize=3):
        """init Assert"""
        # check_value_type returns the validated value.
        self.summarize = validator.check_value_type("summarize", summarize, [int], self.name)

    def infer_shape(self, condition, inputs):
        # The condition must be a scalar (rank 0) or a single-element 1-D tensor.
        condition_len = len(condition)
        validator.check_integer("condition's rank", condition_len, 1, Rel.LE, self.name)
        if condition_len == 1:
            validator.check_integer("condition[0]", condition[0], 1, Rel.EQ, self.name)
        return [1]

    def infer_dtype(self, condition, inputs):
        validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name)
        for dtype in inputs:
            validator.check_subclass("input", dtype, [mstype.tensor], self.name)
        return mstype.int32
class CTCGreedyDecoder(PrimitiveWithInfer):
    """
    Performs greedy decoding on the logits given in inputs.

    Args:
        merge_repeated (bool): If True, merge repeated classes in output. Default: True.

    Inputs:
        - **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
          :math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes,
          `num_labels` indicates the number of actual labels. Blank labels are reserved.
          Default blank label is `num_classes - 1`. Data type must be float32 or float64.
        - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of
          :math:`(batch_size)`. The type must be int32. Each value in the tensor should not be
          greater than `max_time`.

    Outputs:
        - **decoded_indices** (Tensor) - A tensor with shape of :math:`(total_decoded_outputs, 2)`.
          Data type is int64.
        - **decoded_values** (Tensor) - A tensor with shape of :math:`(total_decoded_outputs)`,
          it stores the decoded classes. Data type is int64.
        - **decoded_shape** (Tensor) - The value of tensor is :math:`[batch_size, max_decoded_length]`.
          Data type is int64.
        - **log_probability** (Tensor) - A tensor with shape of :math:`(batch_size, 1)`,
          containing sequence log-probability. Has the same type as `inputs`.

    Examples:
        >>> class CTCGreedyDecoderNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(CTCGreedyDecoderNet, self).__init__()
        >>>         self.ctc_greedy_decoder = P.CTCGreedyDecoder()
        >>>         self.assert_op = P.Assert(300)
        >>>
        >>>     def construct(self, inputs, sequence_length):
        >>>         out = self.ctc_greedy_decoder(inputs, sequence_length)
        >>>         self.assert_op(True, (out[0], out[1], out[2], out[3]))
        >>>         return out[2]
        >>>
        >>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
        >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
        >>> net = CTCGreedyDecoderNet()
        >>> output = net(inputs, sequence_length)
    """

    @prim_attr_register
    def __init__(self, merge_repeated=True):
        """init CTCGreedyDecoder"""
        self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)

    def infer_shape(self, inputs_shape, sequence_length_shape):
        validator.check_integer("inputs rank", len(inputs_shape), 3, Rel.EQ, self.name)
        validator.check_integer("sequence_length rank", len(sequence_length_shape), 1, Rel.EQ, self.name)
        validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size',
                        sequence_length_shape[0], Rel.EQ, self.name)
        # The number of decoded outputs is data-dependent, hence unknown (-1) at infer time.
        total_decoded_outputs = -1
        decoded_indices_shape = [total_decoded_outputs, 2]
        decoded_values = [total_decoded_outputs]
        decoded_shape = [2]
        log_probability_shape = [inputs_shape[1], 1]
        return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape

    def infer_dtype(self, inputs_dtype, sequence_length_dtype):
        validator.check_tensor_type_same({"inputs_dtype": inputs_dtype}, [mstype.float32, mstype.double], self.name)
        validator.check_tensor_type_same({"sequence_length_dtype": sequence_length_dtype}, [mstype.int32], self.name)
        decoded_type = mstype.tensor_type(mstype.int64)
        return decoded_type, decoded_type, decoded_type, inputs_dtype