!4859 Add CTCGreedyDecoder ops for old backend.
Merge pull request !4859 from liuxiao93/Add-ReversSqueuce-EditDistance-CTCGrerdyDecoder
This commit is contained in:
commit
6e8d3a3b82
|
@ -192,6 +192,8 @@ constexpr const char kNameAscendDequant[] = "Dequant";
|
||||||
constexpr const char kNameReverseSequence[] = "ReverseSequence";
|
constexpr const char kNameReverseSequence[] = "ReverseSequence";
|
||||||
constexpr const char kNameEditDistance[] = "EditDistance";
|
constexpr const char kNameEditDistance[] = "EditDistance";
|
||||||
constexpr const char kNameCase[] = "Case";
|
constexpr const char kNameCase[] = "Case";
|
||||||
|
constexpr const char kNameAssert[] = "Assert";
|
||||||
|
constexpr const char kNameCTCGreedyDecoder[] = "CTCGreedyDecoder";
|
||||||
|
|
||||||
class OpAdapterMap {
|
class OpAdapterMap {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -28,4 +28,13 @@ ATTR_MAP(CTCLoss) = {
|
||||||
{"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits<bool>())}};
|
{"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits<bool>())}};
|
||||||
OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}};
|
OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}};
|
||||||
REG_ADPT_DESC(CTCLoss, kNameCTCLoss, ADPT_DESC(CTCLoss))
|
REG_ADPT_DESC(CTCLoss, kNameCTCLoss, ADPT_DESC(CTCLoss))
|
||||||
|
|
||||||
|
// CTCGreedyDecoder
|
||||||
|
INPUT_MAP(CTCGreedyDecoder) = {{1, INPUT_DESC(inputs)}, {2, INPUT_DESC(sequence_length)}};
|
||||||
|
ATTR_MAP(CTCGreedyDecoder) = {{"merge_repeated", ATTR_DESC(merge_repeated, AnyTraits<bool>())}};
|
||||||
|
OUTPUT_MAP(CTCGreedyDecoder) = {{0, OUTPUT_DESC(decoded_indices)},
|
||||||
|
{1, OUTPUT_DESC(decoded_values)},
|
||||||
|
{2, OUTPUT_DESC(decoded_shape)},
|
||||||
|
{3, OUTPUT_DESC(log_probability)}};
|
||||||
|
REG_ADPT_DESC(CTCGreedyDecoder, kNameCTCGreedyDecoder, ADPT_DESC(CTCGreedyDecoder))
|
||||||
} // namespace mindspore::transform
|
} // namespace mindspore::transform
|
||||||
|
|
|
@ -25,5 +25,8 @@
|
||||||
namespace mindspore::transform {
|
namespace mindspore::transform {
|
||||||
DECLARE_OP_ADAPTER(CTCLoss)
|
DECLARE_OP_ADAPTER(CTCLoss)
|
||||||
DECLARE_OP_USE_OUTPUT(CTCLoss)
|
DECLARE_OP_USE_OUTPUT(CTCLoss)
|
||||||
|
|
||||||
|
DECLARE_OP_ADAPTER(CTCGreedyDecoder)
|
||||||
|
DECLARE_OP_USE_OUTPUT(CTCGreedyDecoder)
|
||||||
} // namespace mindspore::transform
|
} // namespace mindspore::transform
|
||||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_CTC_OPS_DECLARE_H_
|
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_CTC_OPS_DECLARE_H_
|
||||||
|
|
|
@ -23,5 +23,10 @@ INPUT_MAP(Print) = EMPTY_INPUT_MAP;
|
||||||
DYN_INPUT_MAP(Print) = {{1, DYN_INPUT_DESC(x)}};
|
DYN_INPUT_MAP(Print) = {{1, DYN_INPUT_DESC(x)}};
|
||||||
ATTR_MAP(Print) = EMPTY_ATTR_MAP;
|
ATTR_MAP(Print) = EMPTY_ATTR_MAP;
|
||||||
REG_ADPT_DESC(Print, kNamePrint, ADPT_DESC(Print))
|
REG_ADPT_DESC(Print, kNamePrint, ADPT_DESC(Print))
|
||||||
|
|
||||||
|
INPUT_MAP(Assert) = {{1, INPUT_DESC(input_condition)}};
|
||||||
|
DYN_INPUT_MAP(Assert) = {{2, DYN_INPUT_DESC(input_data)}};
|
||||||
|
ATTR_MAP(Assert) = {{"summarize", ATTR_DESC(summarize, AnyTraits<int>())}};
|
||||||
|
REG_ADPT_DESC(Assert, kNameAssert, ADPT_DESC(Assert))
|
||||||
#endif
|
#endif
|
||||||
} // namespace mindspore::transform
|
} // namespace mindspore::transform
|
||||||
|
|
|
@ -26,6 +26,9 @@ namespace mindspore::transform {
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
DECLARE_OP_ADAPTER(Print)
|
DECLARE_OP_ADAPTER(Print)
|
||||||
DECLARE_OP_USE_DYN_INPUT(Print)
|
DECLARE_OP_USE_DYN_INPUT(Print)
|
||||||
|
|
||||||
|
DECLARE_OP_ADAPTER(Assert)
|
||||||
|
DECLARE_OP_USE_DYN_INPUT(Assert)
|
||||||
#endif
|
#endif
|
||||||
} // namespace mindspore::transform
|
} // namespace mindspore::transform
|
||||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_LOGGING_OPS_DECLARE_H_
|
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_LOGGING_OPS_DECLARE_H_
|
||||||
|
|
|
@ -39,7 +39,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast
|
||||||
_VirtualDiv, _GetTensorSlice,
|
_VirtualDiv, _GetTensorSlice,
|
||||||
_HostAllGather, _HostReduceScatter)
|
_HostAllGather, _HostReduceScatter)
|
||||||
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
|
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
|
||||||
TensorSummary, HistogramSummary, Debug, Print)
|
TensorSummary, HistogramSummary, Debug, Print, Assert)
|
||||||
from .control_ops import ControlDepend, GeSwitch, Merge
|
from .control_ops import ControlDepend, GeSwitch, Merge
|
||||||
from .inner_ops import ScalarCast
|
from .inner_ops import ScalarCast
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
|
||||||
DropoutDoMask, DropoutGrad, Dropout,
|
DropoutDoMask, DropoutGrad, Dropout,
|
||||||
DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
|
DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
|
||||||
Gelu, Elu,
|
Gelu, Elu,
|
||||||
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2,
|
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCGreedyDecoder,
|
||||||
LogSoftmax,
|
LogSoftmax,
|
||||||
MaxPool, DataFormatDimMap,
|
MaxPool, DataFormatDimMap,
|
||||||
AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
|
AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
|
||||||
|
@ -201,6 +201,7 @@ __all__ = [
|
||||||
'HistogramSummary',
|
'HistogramSummary',
|
||||||
"Debug",
|
"Debug",
|
||||||
"Print",
|
"Print",
|
||||||
|
"Assert",
|
||||||
'InsertGradientOf',
|
'InsertGradientOf',
|
||||||
'HookBackward',
|
'HookBackward',
|
||||||
'InvertPermutation',
|
'InvertPermutation',
|
||||||
|
@ -225,6 +226,7 @@ __all__ = [
|
||||||
'SmoothL1Loss',
|
'SmoothL1Loss',
|
||||||
'L2Loss',
|
'L2Loss',
|
||||||
'CTCLoss',
|
'CTCLoss',
|
||||||
|
'CTCGreedyDecoder',
|
||||||
'RNNTLoss',
|
'RNNTLoss',
|
||||||
'ReduceAll',
|
'ReduceAll',
|
||||||
'ReduceAny',
|
'ReduceAny',
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
"""debug_ops"""
|
"""debug_ops"""
|
||||||
from types import FunctionType, MethodType
|
from types import FunctionType, MethodType
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
|
from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
||||||
|
|
||||||
|
@ -364,3 +365,47 @@ class Debug(Primitive):
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Assert(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Asserts that the given condition is true.
|
||||||
|
If the input condition evaluates to false, prints the list of tensors in `input_data`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
summarize (int): Print this many entries of each tensor.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **condition** (Union[Tensor[bool], bool]) - The condition to evaluate.
|
||||||
|
- **input_data** (Union[tuple[Tensor], list[Tensor]]) - The tensors to print out when condition is false.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> class AssertDemo(nn.Cell):
|
||||||
|
>>> def __init__(self):
|
||||||
|
>>> super(AssertDemo, self).__init__()
|
||||||
|
>>> self.assert1 = P.Assert(summarize=10)
|
||||||
|
>>> self.add = P.TensorAdd()
|
||||||
|
>>>
|
||||||
|
>>> def construct(self, x, y):
|
||||||
|
>>> data = self.add(x, y)
|
||||||
|
>>> self.assert1(True, [data])
|
||||||
|
>>> return data
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, summarize=3):
|
||||||
|
"""init Assert"""
|
||||||
|
self.summarize = validator.check_value_type("summarize", summarize, [int], self.name)
|
||||||
|
|
||||||
|
def infer_shape(self, condition, inputs):
|
||||||
|
condition_len = len(condition)
|
||||||
|
validator.check_integer("condition's rank", condition_len, 1, Rel.LE, self.name)
|
||||||
|
if condition_len == 1:
|
||||||
|
validator.check_integer("condition[0]", condition[0], 1, Rel.EQ, self.name)
|
||||||
|
return [1]
|
||||||
|
|
||||||
|
def infer_dtype(self, condition, inputs):
|
||||||
|
validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name)
|
||||||
|
for dtype in inputs:
|
||||||
|
validator.check_subclass("input", dtype, [mstype.tensor], self.name)
|
||||||
|
return mstype.int32
|
||||||
|
|
|
@ -5173,6 +5173,7 @@ class CTCLoss(PrimitiveWithInfer):
|
||||||
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
|
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
|
||||||
:math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels`
|
:math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels`
|
||||||
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
|
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
|
||||||
|
Data type must be float32 or float64.
|
||||||
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
|
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
|
||||||
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
|
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
|
||||||
- **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The
|
- **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The
|
||||||
|
@ -5222,10 +5223,6 @@ class CTCLoss(PrimitiveWithInfer):
|
||||||
return batch_size, inputs
|
return batch_size, inputs
|
||||||
|
|
||||||
def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length):
|
def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length):
|
||||||
validator.check_subclass("inputs_dtype", inputs, mstype.tensor, self.name)
|
|
||||||
validator.check_subclass("labels_indices_dtype", labels_indices, mstype.tensor, self.name)
|
|
||||||
validator.check_subclass("labels_values_dtype", labels_values, mstype.tensor, self.name)
|
|
||||||
validator.check_subclass("sequence_length_dtype", sequence_length, mstype.tensor, self.name)
|
|
||||||
validator.check_tensor_type_same({"inputs_dtype": inputs}, [mstype.float32, mstype.double], self.name)
|
validator.check_tensor_type_same({"inputs_dtype": inputs}, [mstype.float32, mstype.double], self.name)
|
||||||
validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name)
|
validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name)
|
||||||
validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
|
validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
|
||||||
|
@ -5233,6 +5230,72 @@ class CTCLoss(PrimitiveWithInfer):
|
||||||
return inputs, inputs
|
return inputs, inputs
|
||||||
|
|
||||||
|
|
||||||
|
class CTCGreedyDecoder(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Performs greedy decoding on the logits given in inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
merge_repeated (bool): If True, merge repeated classes in output. Default: True.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
|
||||||
|
:math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels`
|
||||||
|
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
|
||||||
|
Data type must be float32 or float64.
|
||||||
|
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`.
|
||||||
|
The type must be int32. Each value in the tensor should not be greater than `max_time`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- **decoded_indices** (Tensor) - A tensor with shape of :math:`(total_decoded_outputs, 2)`.
|
||||||
|
Data type is int64.
|
||||||
|
- **decoded_values** (Tensor) - A tensor with shape of :math:`(total_decoded_outputs)`,
|
||||||
|
it stores the decoded classes. Data type is int64.
|
||||||
|
- **decoded_shape** (Tensor) - The value of tensor is :math:`[batch_size, max_decoded_length]`.
|
||||||
|
Data type is int64.
|
||||||
|
- **log_probability** (Tensor) - A tensor with shape of :math:`(batch_size, 1)`,
|
||||||
|
containing sequence log-probability. Has the same type as `inputs`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> class CTCGreedyDecoderNet(nn.Cell):
|
||||||
|
>>> def __init__(self):
|
||||||
|
>>> super(CTCGreedyDecoderNet, self).__init__()
|
||||||
|
>>> self.ctc_greedy_decoder = P.CTCGreedyDecoder()
|
||||||
|
>>> self.assert_op = P.Assert(300)
|
||||||
|
>>>
|
||||||
|
>>> def construct(self, inputs, sequence_length):
|
||||||
|
>>> out = self.ctc_greedy_decoder(inputs,sequence_length)
|
||||||
|
>>> self.assert_op(True, (out[0], out[1], out[2], out[3]))
|
||||||
|
>>> return out[2]
|
||||||
|
>>>
|
||||||
|
>>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
|
||||||
|
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
|
||||||
|
>>> net = CTCGreedyDecoderNet()
|
||||||
|
>>> output = net(inputs, sequence_length)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, merge_repeated=True):
|
||||||
|
self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)
|
||||||
|
|
||||||
|
def infer_shape(self, inputs_shape, sequence_length_shape):
|
||||||
|
validator.check_integer("inputs rank", len(inputs_shape), 3, Rel.EQ, self.name)
|
||||||
|
validator.check_integer("sequence_length rank", len(sequence_length_shape), 1, Rel.EQ, self.name)
|
||||||
|
validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size',
|
||||||
|
sequence_length_shape[0], Rel.EQ, self.name)
|
||||||
|
total_decoded_outputs = -1
|
||||||
|
decoded_indices_shape = [total_decoded_outputs, 2]
|
||||||
|
decoded_values = [total_decoded_outputs]
|
||||||
|
decoded_shape = [2]
|
||||||
|
log_probability_shape = [inputs_shape[1], 1]
|
||||||
|
return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, inputs_dtype, sequence_length_dtype):
|
||||||
|
validator.check_tensor_type_same({"inputs_dtype": inputs_dtype}, [mstype.float32, mstype.double], self.name)
|
||||||
|
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length_dtype}, [mstype.int32], self.name)
|
||||||
|
decoded_type = mstype.tensor_type(mstype.int64)
|
||||||
|
return decoded_type, decoded_type, decoded_type, inputs_dtype
|
||||||
|
|
||||||
|
|
||||||
class BasicLSTMCell(PrimitiveWithInfer):
|
class BasicLSTMCell(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
Performs the long short term memory(LSTM) on the input.
|
Performs the long short term memory(LSTM) on the input.
|
||||||
|
@ -5361,6 +5424,7 @@ class InTopK(PrimitiveWithInfer):
|
||||||
>>> result = in_top_k(x1, x2)
|
>>> result = in_top_k(x1, x2)
|
||||||
[True False]
|
[True False]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, k):
|
def __init__(self, k):
|
||||||
"""Init InTopK"""
|
"""Init InTopK"""
|
||||||
|
|
|
@ -621,6 +621,18 @@ class UniformNet(nn.Cell):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class CTCGreedyDecoderNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CTCGreedyDecoderNet, self).__init__()
|
||||||
|
self.ctc_greedy_decoder = P.CTCGreedyDecoder()
|
||||||
|
self.assert_op = P.Assert(300)
|
||||||
|
|
||||||
|
def construct(self, inputs, sequence_length):
|
||||||
|
out = self.ctc_greedy_decoder(inputs,sequence_length)
|
||||||
|
self.assert_op(True, (out[0], out[1], out[2], out[3]))
|
||||||
|
return out[2]
|
||||||
|
|
||||||
|
|
||||||
class StridedSliceNet(nn.Cell):
|
class StridedSliceNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(StridedSliceNet, self).__init__()
|
super(StridedSliceNet, self).__init__()
|
||||||
|
@ -1672,6 +1684,10 @@ test_case_nn_ops = [
|
||||||
Tensor(np.array([1, 2, 3, 4]).astype(np.int32)),
|
Tensor(np.array([1, 2, 3, 4]).astype(np.int32)),
|
||||||
Tensor(np.array([6, 6, 6, 6]).astype(np.int32))],
|
Tensor(np.array([6, 6, 6, 6]).astype(np.int32))],
|
||||||
'desc_bprop': [[4], [6, 4, 6]]}),
|
'desc_bprop': [[4], [6, 4, 6]]}),
|
||||||
|
('CTCGreedyDecoder', {
|
||||||
|
'block': CTCGreedyDecoderNet(),
|
||||||
|
'desc_inputs': [[2, 2, 3], Tensor(np.array([2, 2]).astype(np.int32))],
|
||||||
|
'skip': ['backward']}),
|
||||||
('L2Loss_1', {
|
('L2Loss_1', {
|
||||||
'block': P.L2Loss(),
|
'block': P.L2Loss(),
|
||||||
'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],
|
'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],
|
||||||
|
|
Loading…
Reference in New Issue