!4859 Add CTCGrerdyDecoder ops for old backend.

Merge pull request !4859 from liuxiao93/Add-ReversSqueuce-EditDistance-CTCGrerdyDecoder
This commit is contained in:
mindspore-ci-bot 2020-08-21 16:05:19 +08:00 committed by Gitee
commit 6e8d3a3b82
9 changed files with 155 additions and 6 deletions

View File

@ -192,6 +192,8 @@ constexpr const char kNameAscendDequant[] = "Dequant";
constexpr const char kNameReverseSequence[] = "ReverseSequence";
constexpr const char kNameEditDistance[] = "EditDistance";
constexpr const char kNameCase[] = "Case";
constexpr const char kNameAssert[] = "Assert";
constexpr const char kNameCTCGreedyDecoder[] = "CTCGreedyDecoder";
class OpAdapterMap {
public:

View File

@ -28,4 +28,13 @@ ATTR_MAP(CTCLoss) = {
{"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits<bool>())}};
OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}};
REG_ADPT_DESC(CTCLoss, kNameCTCLoss, ADPT_DESC(CTCLoss))
// CTCGreedyDecoder
INPUT_MAP(CTCGreedyDecoder) = {{1, INPUT_DESC(inputs)}, {2, INPUT_DESC(sequence_length)}};
ATTR_MAP(CTCGreedyDecoder) = {{"merge_repeated", ATTR_DESC(merge_repeated, AnyTraits<bool>())}};
OUTPUT_MAP(CTCGreedyDecoder) = {{0, OUTPUT_DESC(decoded_indices)},
{1, OUTPUT_DESC(decoded_values)},
{2, OUTPUT_DESC(decoded_shape)},
{3, OUTPUT_DESC(log_probability)}};
REG_ADPT_DESC(CTCGreedyDecoder, kNameCTCGreedyDecoder, ADPT_DESC(CTCGreedyDecoder))
} // namespace mindspore::transform

View File

@ -25,5 +25,8 @@
namespace mindspore::transform {
DECLARE_OP_ADAPTER(CTCLoss)
DECLARE_OP_USE_OUTPUT(CTCLoss)
DECLARE_OP_ADAPTER(CTCGreedyDecoder)
DECLARE_OP_USE_OUTPUT(CTCGreedyDecoder)
} // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_CTC_OPS_DECLARE_H_

View File

@ -23,5 +23,10 @@ INPUT_MAP(Print) = EMPTY_INPUT_MAP;
DYN_INPUT_MAP(Print) = {{1, DYN_INPUT_DESC(x)}};
ATTR_MAP(Print) = EMPTY_ATTR_MAP;
REG_ADPT_DESC(Print, kNamePrint, ADPT_DESC(Print))
INPUT_MAP(Assert) = {{1, INPUT_DESC(input_condition)}};
DYN_INPUT_MAP(Assert) = {{2, DYN_INPUT_DESC(input_data)}};
ATTR_MAP(Assert) = {{"summarize", ATTR_DESC(summarize, AnyTraits<int>())}};
REG_ADPT_DESC(Assert, kNameAssert, ADPT_DESC(Assert))
#endif
} // namespace mindspore::transform

View File

@ -26,6 +26,9 @@ namespace mindspore::transform {
#ifdef ENABLE_GE
DECLARE_OP_ADAPTER(Print)
DECLARE_OP_USE_DYN_INPUT(Print)
DECLARE_OP_ADAPTER(Assert)
DECLARE_OP_USE_DYN_INPUT(Assert)
#endif
} // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_LOGGING_OPS_DECLARE_H_

View File

@ -39,7 +39,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast
_VirtualDiv, _GetTensorSlice,
_HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Debug, Print)
TensorSummary, HistogramSummary, Debug, Print, Assert)
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast
@ -64,7 +64,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
DropoutDoMask, DropoutGrad, Dropout,
DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
Gelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCGreedyDecoder,
LogSoftmax,
MaxPool, DataFormatDimMap,
AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
@ -201,6 +201,7 @@ __all__ = [
'HistogramSummary',
"Debug",
"Print",
"Assert",
'InsertGradientOf',
'HookBackward',
'InvertPermutation',
@ -225,6 +226,7 @@ __all__ = [
'SmoothL1Loss',
'L2Loss',
'CTCLoss',
'CTCGreedyDecoder',
'RNNTLoss',
'ReduceAll',
'ReduceAny',

View File

@ -16,6 +16,7 @@
"""debug_ops"""
from types import FunctionType, MethodType
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
@ -364,3 +365,47 @@ class Debug(Primitive):
def __call__(self, *args, **kwargs):
pass
class Assert(PrimitiveWithInfer):
"""
Asserts that the given condition is true.
If input condition evaluates to false, print the list of tensor in data.
Args:
summarize (int): Print this many entries of each tensor.
Inputs:
- **condition** [Union[Tensor[bool], bool]] - The condition to evaluate.
- **input_data** (Union(tuple[Tensor], list[Tensor])) - The tensors to print out when condition is false.
Examples:
>>> class AssertDemo(nn.Cell):
>>> def __init__(self):
>>> super(AssertDemo, self).__init__()
>>> self.assert = P.Assert(summarize=10)
>>> self.add = P.TensorAdd()
>>>
>>> def construct(self, x, y):
>>> data = self.add(x, y)
>>> self.assert(True, [data])
>>> return data
"""
@prim_attr_register
def __init__(self, summarize=3):
"""init Assert"""
self.summarize = validator.check_value_type("summarize", summarize, [int], self.name)
def infer_shape(self, condition, inputs):
condition_len = len(condition)
validator.check_integer("condition's rank", condition_len, 1, Rel.LE, self.name)
if condition_len == 1:
validator.check_integer("condition[0]", condition[0], 1, Rel.EQ, self.name)
return [1]
def infer_dtype(self, condition, inputs):
validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name)
for dtype in inputs:
validator.check_subclass("input", dtype, [mstype.tensor], self.name)
return mstype.int32

View File

@ -5173,6 +5173,7 @@ class CTCLoss(PrimitiveWithInfer):
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
:math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels`
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
Data type must be float32 or float64.
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
- **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The
@ -5222,10 +5223,6 @@ class CTCLoss(PrimitiveWithInfer):
return batch_size, inputs
def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length):
validator.check_subclass("inputs_dtype", inputs, mstype.tensor, self.name)
validator.check_subclass("labels_indices_dtype", labels_indices, mstype.tensor, self.name)
validator.check_subclass("labels_values_dtype", labels_values, mstype.tensor, self.name)
validator.check_subclass("sequence_length_dtype", sequence_length, mstype.tensor, self.name)
validator.check_tensor_type_same({"inputs_dtype": inputs}, [mstype.float32, mstype.double], self.name)
validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name)
validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
@ -5233,6 +5230,72 @@ class CTCLoss(PrimitiveWithInfer):
return inputs, inputs
class CTCGreedyDecoder(PrimitiveWithInfer):
"""
Performs greedy decoding on the logits given in inputs.
Args:
merge_repeated (bool): If True, merge repeated classes in output. Default: True.
Inputs:
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
:math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels`
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
Data type must be float32 or float64.
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`.
The type must be int32. Each value in the tensor should not greater than `max_time`.
Outputs:
- **decoded_indices** (Tensor) - A tensor with shape of :math:`(total_decoded_outputs, 2)`.
Data type is int64.
- **decoded_values** (Tensor) - A tensor with shape of :math:`(total_decoded_outputs)`,
it stores the decoded classes. Data type is int64.
- **decoded_shape** (Tensor) - The value of tensor is :math:`[batch_size, max_decoded_legth]`.
Data type is int64.
- **log_probability** (Tensor) - A tensor with shape of :math:`(batch_size, 1)`,
containing sequence log-probability. Has the same type as `inputs`.
Examples:
>>> class CTCGreedyDecoderNet(nn.Cell):
>>> def __init__(self):
>>> super(CTCGreedyDecoderNet, self).__init__()
>>> self.ctc_greedy_decoder = P.CTCGreedyDecoder()
>>> self.assert_op = P.Assert(300)
>>>
>>> def construct(self, inputs, sequence_length):
>>> out = self.ctc_greedy_decoder(inputs,sequence_length)
>>> self.assert_op(True, (out[0], out[1], out[2], out[3]))
>>> return out[2]
>>>
>>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
>>> net = CTCGreedyDecoderNet()
>>> output = net(inputs, sequence_length)
"""
@prim_attr_register
def __init__(self, merge_repeated=True):
self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)
def infer_shape(self, inputs_shape, sequence_length_shape):
validator.check_integer("inputs rank", len(inputs_shape), 3, Rel.EQ, self.name)
validator.check_integer("sequence_length rank", len(sequence_length_shape), 1, Rel.EQ, self.name)
validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size',
sequence_length_shape[0], Rel.EQ, self.name)
total_decoded_outputs = -1
decoded_indices_shape = [total_decoded_outputs, 2]
decoded_values = [total_decoded_outputs]
decoded_shape = [2]
log_probability_shape = [inputs_shape[1], 1]
return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape
def infer_dtype(self, inputs_dtype, sequence_length_dtype):
validator.check_tensor_type_same({"inputs_dtype": inputs_dtype}, [mstype.float32, mstype.double], self.name)
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length_dtype}, [mstype.int32], self.name)
decoded_type = mstype.tensor_type(mstype.int64)
return decoded_type, decoded_type, decoded_type, inputs_dtype
class BasicLSTMCell(PrimitiveWithInfer):
r"""
Performs the long short term memory(LSTM) on the input.
@ -5361,6 +5424,7 @@ class InTopK(PrimitiveWithInfer):
>>> result = in_top_k(x1, x2)
[True False]
"""
@prim_attr_register
def __init__(self, k):
"""Init InTopK"""

View File

@ -621,6 +621,18 @@ class UniformNet(nn.Cell):
return out
class CTCGreedyDecoderNet(nn.Cell):
def __init__(self):
super(CTCGreedyDecoderNet, self).__init__()
self.ctc_greedy_decoder = P.CTCGreedyDecoder()
self.assert_op = P.Assert(300)
def construct(self, inputs, sequence_length):
out = self.ctc_greedy_decoder(inputs,sequence_length)
self.assert_op(True, (out[0], out[1], out[2], out[3]))
return out[2]
class StridedSliceNet(nn.Cell):
def __init__(self):
super(StridedSliceNet, self).__init__()
@ -1672,6 +1684,10 @@ test_case_nn_ops = [
Tensor(np.array([1, 2, 3, 4]).astype(np.int32)),
Tensor(np.array([6, 6, 6, 6]).astype(np.int32))],
'desc_bprop': [[4], [6, 4, 6]]}),
('CTCGreedyDecoder', {
'block': CTCGreedyDecoderNet(),
'desc_inputs': [[2, 2, 3], Tensor(np.array([2, 2]).astype(np.int32))],
'skip': ['backward']}),
('L2Loss_1', {
'block': P.L2Loss(),
'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],