forked from mindspore-Ecosystem/mindspore
!4859 Add CTCGrerdyDecoder ops for old backend.
Merge pull request !4859 from liuxiao93/Add-ReversSqueuce-EditDistance-CTCGrerdyDecoder
This commit is contained in:
commit
6e8d3a3b82
|
@ -192,6 +192,8 @@ constexpr const char kNameAscendDequant[] = "Dequant";
|
|||
constexpr const char kNameReverseSequence[] = "ReverseSequence";
|
||||
constexpr const char kNameEditDistance[] = "EditDistance";
|
||||
constexpr const char kNameCase[] = "Case";
|
||||
constexpr const char kNameAssert[] = "Assert";
|
||||
constexpr const char kNameCTCGreedyDecoder[] = "CTCGreedyDecoder";
|
||||
|
||||
class OpAdapterMap {
|
||||
public:
|
||||
|
|
|
@ -28,4 +28,13 @@ ATTR_MAP(CTCLoss) = {
|
|||
{"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}};
|
||||
REG_ADPT_DESC(CTCLoss, kNameCTCLoss, ADPT_DESC(CTCLoss))
|
||||
|
||||
// CTCGreedyDecoder
|
||||
INPUT_MAP(CTCGreedyDecoder) = {{1, INPUT_DESC(inputs)}, {2, INPUT_DESC(sequence_length)}};
|
||||
ATTR_MAP(CTCGreedyDecoder) = {{"merge_repeated", ATTR_DESC(merge_repeated, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(CTCGreedyDecoder) = {{0, OUTPUT_DESC(decoded_indices)},
|
||||
{1, OUTPUT_DESC(decoded_values)},
|
||||
{2, OUTPUT_DESC(decoded_shape)},
|
||||
{3, OUTPUT_DESC(log_probability)}};
|
||||
REG_ADPT_DESC(CTCGreedyDecoder, kNameCTCGreedyDecoder, ADPT_DESC(CTCGreedyDecoder))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -25,5 +25,8 @@
|
|||
namespace mindspore::transform {
|
||||
DECLARE_OP_ADAPTER(CTCLoss)
|
||||
DECLARE_OP_USE_OUTPUT(CTCLoss)
|
||||
|
||||
DECLARE_OP_ADAPTER(CTCGreedyDecoder)
|
||||
DECLARE_OP_USE_OUTPUT(CTCGreedyDecoder)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_CTC_OPS_DECLARE_H_
|
||||
|
|
|
@ -23,5 +23,10 @@ INPUT_MAP(Print) = EMPTY_INPUT_MAP;
|
|||
DYN_INPUT_MAP(Print) = {{1, DYN_INPUT_DESC(x)}};
|
||||
ATTR_MAP(Print) = EMPTY_ATTR_MAP;
|
||||
REG_ADPT_DESC(Print, kNamePrint, ADPT_DESC(Print))
|
||||
|
||||
INPUT_MAP(Assert) = {{1, INPUT_DESC(input_condition)}};
|
||||
DYN_INPUT_MAP(Assert) = {{2, DYN_INPUT_DESC(input_data)}};
|
||||
ATTR_MAP(Assert) = {{"summarize", ATTR_DESC(summarize, AnyTraits<int>())}};
|
||||
REG_ADPT_DESC(Assert, kNameAssert, ADPT_DESC(Assert))
|
||||
#endif
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -26,6 +26,9 @@ namespace mindspore::transform {
|
|||
#ifdef ENABLE_GE
|
||||
DECLARE_OP_ADAPTER(Print)
|
||||
DECLARE_OP_USE_DYN_INPUT(Print)
|
||||
|
||||
DECLARE_OP_ADAPTER(Assert)
|
||||
DECLARE_OP_USE_DYN_INPUT(Assert)
|
||||
#endif
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_LOGGING_OPS_DECLARE_H_
|
||||
|
|
|
@ -39,7 +39,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast
|
|||
_VirtualDiv, _GetTensorSlice,
|
||||
_HostAllGather, _HostReduceScatter)
|
||||
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
|
||||
TensorSummary, HistogramSummary, Debug, Print)
|
||||
TensorSummary, HistogramSummary, Debug, Print, Assert)
|
||||
from .control_ops import ControlDepend, GeSwitch, Merge
|
||||
from .inner_ops import ScalarCast
|
||||
|
||||
|
@ -64,7 +64,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
|
|||
DropoutDoMask, DropoutGrad, Dropout,
|
||||
DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
|
||||
Gelu, Elu,
|
||||
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2,
|
||||
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCGreedyDecoder,
|
||||
LogSoftmax,
|
||||
MaxPool, DataFormatDimMap,
|
||||
AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
|
||||
|
@ -201,6 +201,7 @@ __all__ = [
|
|||
'HistogramSummary',
|
||||
"Debug",
|
||||
"Print",
|
||||
"Assert",
|
||||
'InsertGradientOf',
|
||||
'HookBackward',
|
||||
'InvertPermutation',
|
||||
|
@ -225,6 +226,7 @@ __all__ = [
|
|||
'SmoothL1Loss',
|
||||
'L2Loss',
|
||||
'CTCLoss',
|
||||
'CTCGreedyDecoder',
|
||||
'RNNTLoss',
|
||||
'ReduceAll',
|
||||
'ReduceAny',
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""debug_ops"""
|
||||
from types import FunctionType, MethodType
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
||||
|
||||
|
@ -364,3 +365,47 @@ class Debug(Primitive):
|
|||
|
||||
def __call__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class Assert(PrimitiveWithInfer):
|
||||
"""
|
||||
Asserts that the given condition is true.
|
||||
If input condition evaluates to false, print the list of tensor in data.
|
||||
|
||||
Args:
|
||||
summarize (int): Print this many entries of each tensor.
|
||||
|
||||
Inputs:
|
||||
- **condition** [Union[Tensor[bool], bool]] - The condition to evaluate.
|
||||
- **input_data** (Union(tuple[Tensor], list[Tensor])) - The tensors to print out when condition is false.
|
||||
|
||||
Examples:
|
||||
>>> class AssertDemo(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(AssertDemo, self).__init__()
|
||||
>>> self.assert = P.Assert(summarize=10)
|
||||
>>> self.add = P.TensorAdd()
|
||||
>>>
|
||||
>>> def construct(self, x, y):
|
||||
>>> data = self.add(x, y)
|
||||
>>> self.assert(True, [data])
|
||||
>>> return data
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, summarize=3):
|
||||
"""init Assert"""
|
||||
self.summarize = validator.check_value_type("summarize", summarize, [int], self.name)
|
||||
|
||||
def infer_shape(self, condition, inputs):
|
||||
condition_len = len(condition)
|
||||
validator.check_integer("condition's rank", condition_len, 1, Rel.LE, self.name)
|
||||
if condition_len == 1:
|
||||
validator.check_integer("condition[0]", condition[0], 1, Rel.EQ, self.name)
|
||||
return [1]
|
||||
|
||||
def infer_dtype(self, condition, inputs):
|
||||
validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name)
|
||||
for dtype in inputs:
|
||||
validator.check_subclass("input", dtype, [mstype.tensor], self.name)
|
||||
return mstype.int32
|
||||
|
|
|
@ -5173,6 +5173,7 @@ class CTCLoss(PrimitiveWithInfer):
|
|||
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
|
||||
:math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels`
|
||||
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
|
||||
Data type must be float32 or float64.
|
||||
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
|
||||
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
|
||||
- **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The
|
||||
|
@ -5222,10 +5223,6 @@ class CTCLoss(PrimitiveWithInfer):
|
|||
return batch_size, inputs
|
||||
|
||||
def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length):
|
||||
validator.check_subclass("inputs_dtype", inputs, mstype.tensor, self.name)
|
||||
validator.check_subclass("labels_indices_dtype", labels_indices, mstype.tensor, self.name)
|
||||
validator.check_subclass("labels_values_dtype", labels_values, mstype.tensor, self.name)
|
||||
validator.check_subclass("sequence_length_dtype", sequence_length, mstype.tensor, self.name)
|
||||
validator.check_tensor_type_same({"inputs_dtype": inputs}, [mstype.float32, mstype.double], self.name)
|
||||
validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name)
|
||||
validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
|
||||
|
@ -5233,6 +5230,72 @@ class CTCLoss(PrimitiveWithInfer):
|
|||
return inputs, inputs
|
||||
|
||||
|
||||
class CTCGreedyDecoder(PrimitiveWithInfer):
|
||||
"""
|
||||
Performs greedy decoding on the logits given in inputs.
|
||||
|
||||
Args:
|
||||
merge_repeated (bool): If True, merge repeated classes in output. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
|
||||
:math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels`
|
||||
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
|
||||
Data type must be float32 or float64.
|
||||
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`.
|
||||
The type must be int32. Each value in the tensor should not greater than `max_time`.
|
||||
|
||||
Outputs:
|
||||
- **decoded_indices** (Tensor) - A tensor with shape of :math:`(total_decoded_outputs, 2)`.
|
||||
Data type is int64.
|
||||
- **decoded_values** (Tensor) - A tensor with shape of :math:`(total_decoded_outputs)`,
|
||||
it stores the decoded classes. Data type is int64.
|
||||
- **decoded_shape** (Tensor) - The value of tensor is :math:`[batch_size, max_decoded_legth]`.
|
||||
Data type is int64.
|
||||
- **log_probability** (Tensor) - A tensor with shape of :math:`(batch_size, 1)`,
|
||||
containing sequence log-probability. Has the same type as `inputs`.
|
||||
|
||||
Examples:
|
||||
>>> class CTCGreedyDecoderNet(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(CTCGreedyDecoderNet, self).__init__()
|
||||
>>> self.ctc_greedy_decoder = P.CTCGreedyDecoder()
|
||||
>>> self.assert_op = P.Assert(300)
|
||||
>>>
|
||||
>>> def construct(self, inputs, sequence_length):
|
||||
>>> out = self.ctc_greedy_decoder(inputs,sequence_length)
|
||||
>>> self.assert_op(True, (out[0], out[1], out[2], out[3]))
|
||||
>>> return out[2]
|
||||
>>>
|
||||
>>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
|
||||
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
|
||||
>>> net = CTCGreedyDecoderNet()
|
||||
>>> output = net(inputs, sequence_length)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, merge_repeated=True):
|
||||
self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)
|
||||
|
||||
def infer_shape(self, inputs_shape, sequence_length_shape):
|
||||
validator.check_integer("inputs rank", len(inputs_shape), 3, Rel.EQ, self.name)
|
||||
validator.check_integer("sequence_length rank", len(sequence_length_shape), 1, Rel.EQ, self.name)
|
||||
validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size',
|
||||
sequence_length_shape[0], Rel.EQ, self.name)
|
||||
total_decoded_outputs = -1
|
||||
decoded_indices_shape = [total_decoded_outputs, 2]
|
||||
decoded_values = [total_decoded_outputs]
|
||||
decoded_shape = [2]
|
||||
log_probability_shape = [inputs_shape[1], 1]
|
||||
return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape
|
||||
|
||||
def infer_dtype(self, inputs_dtype, sequence_length_dtype):
|
||||
validator.check_tensor_type_same({"inputs_dtype": inputs_dtype}, [mstype.float32, mstype.double], self.name)
|
||||
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length_dtype}, [mstype.int32], self.name)
|
||||
decoded_type = mstype.tensor_type(mstype.int64)
|
||||
return decoded_type, decoded_type, decoded_type, inputs_dtype
|
||||
|
||||
|
||||
class BasicLSTMCell(PrimitiveWithInfer):
|
||||
r"""
|
||||
Performs the long short term memory(LSTM) on the input.
|
||||
|
@ -5361,6 +5424,7 @@ class InTopK(PrimitiveWithInfer):
|
|||
>>> result = in_top_k(x1, x2)
|
||||
[True False]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, k):
|
||||
"""Init InTopK"""
|
||||
|
|
|
@ -621,6 +621,18 @@ class UniformNet(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class CTCGreedyDecoderNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CTCGreedyDecoderNet, self).__init__()
|
||||
self.ctc_greedy_decoder = P.CTCGreedyDecoder()
|
||||
self.assert_op = P.Assert(300)
|
||||
|
||||
def construct(self, inputs, sequence_length):
|
||||
out = self.ctc_greedy_decoder(inputs,sequence_length)
|
||||
self.assert_op(True, (out[0], out[1], out[2], out[3]))
|
||||
return out[2]
|
||||
|
||||
|
||||
class StridedSliceNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(StridedSliceNet, self).__init__()
|
||||
|
@ -1672,6 +1684,10 @@ test_case_nn_ops = [
|
|||
Tensor(np.array([1, 2, 3, 4]).astype(np.int32)),
|
||||
Tensor(np.array([6, 6, 6, 6]).astype(np.int32))],
|
||||
'desc_bprop': [[4], [6, 4, 6]]}),
|
||||
('CTCGreedyDecoder', {
|
||||
'block': CTCGreedyDecoderNet(),
|
||||
'desc_inputs': [[2, 2, 3], Tensor(np.array([2, 2]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
('L2Loss_1', {
|
||||
'block': P.L2Loss(),
|
||||
'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],
|
||||
|
|
Loading…
Reference in New Issue