forked from mindspore-Ecosystem/mindspore
!1070 Adapt aicpu op CTCLoss and TBE op L2Loss for GE.
Merge pull request !1070 from liuxiao/ops-for-VM
This commit is contained in:
commit
ba3d48817e
|
@ -196,6 +196,8 @@ const char kNameBatchToSpace[] = "BatchToSpace";
|
|||
const char kNameAtan2[] = "Atan2";
|
||||
const char kNameApplyRMSProp[] = "ApplyRMSProp";
|
||||
const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";
|
||||
const char kNameL2Loss[] = "L2Loss";
|
||||
const char kNameCTCLoss[] = "CTCLoss";
|
||||
|
||||
// -----------------OpAdapter initialization--------------
|
||||
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
|
||||
|
@ -391,7 +393,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)},
|
||||
{string(kNameAtan2), ADPT_DESC(Atan2)},
|
||||
{string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)},
|
||||
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}};
|
||||
{string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)},
|
||||
{string(kNameL2Loss), ADPT_DESC(L2Loss)},
|
||||
{string(kNameCTCLoss), ADPT_DESC(CTCLoss)}};
|
||||
#ifdef ENABLE_GE
|
||||
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
||||
adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
|
||||
|
|
|
@ -1227,6 +1227,22 @@ INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)},
|
|||
ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
|
||||
|
||||
// L2Loss
|
||||
INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(L2Loss) = {{0, OUTPUT_DESC(y)}};
|
||||
|
||||
// CTCLoss
|
||||
INPUT_MAP(CTCLoss) = {{1, INPUT_DESC(inputs)},
|
||||
{2, INPUT_DESC(labels_indices)},
|
||||
{3, INPUT_DESC(labels_values)},
|
||||
{4, INPUT_DESC(sequence_length)}};
|
||||
ATTR_MAP(CTCLoss) = {
|
||||
{"preprocess_collapse_repeated", ATTR_DESC(preprocess_collapse_repeated, AnyTraits<bool>())},
|
||||
{"ctc_merge_repeated", ATTR_DESC(ctc_merge_repeated, AnyTraits<bool>())},
|
||||
{"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}};
|
||||
|
||||
#ifdef ENABLE_GE
|
||||
// Print
|
||||
INPUT_MAP(Print) = EMPTY_INPUT_MAP;
|
||||
|
|
|
@ -465,6 +465,10 @@ DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
|
|||
DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
|
||||
DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
|
||||
DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
|
||||
DECLARE_OP_ADAPTER(L2Loss)
|
||||
DECLARE_OP_USE_OUTPUT(L2Loss)
|
||||
DECLARE_OP_ADAPTER(CTCLoss)
|
||||
DECLARE_OP_USE_OUTPUT(CTCLoss)
|
||||
#ifdef ENABLE_GE
|
||||
DECLARE_OP_ADAPTER(Print)
|
||||
DECLARE_OP_USE_DYN_INPUT(Print)
|
||||
|
|
|
@ -668,3 +668,16 @@ def get_bprop_dropout(self):
|
|||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.CTCLoss)
|
||||
def get_bprop_ctc_loss(self):
|
||||
"""Grad definition for `CTCLoss` operation"""
|
||||
expand = P.ExpandDims()
|
||||
|
||||
def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout):
|
||||
grad_loss = out[1]
|
||||
grad = grad_loss * expand(dout[0], -1)
|
||||
return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -55,7 +55,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
|||
DropoutDoMask, DropoutGrad, Dropout,
|
||||
DropoutGenMask, Flatten, FusedBatchNorm,
|
||||
Gelu, Elu,
|
||||
GetNext, L2Normalize, LayerNorm, L2Loss,
|
||||
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss,
|
||||
LogSoftmax,
|
||||
MaxPool,
|
||||
AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
|
||||
|
@ -172,6 +172,7 @@ __all__ = [
|
|||
'Reciprocal',
|
||||
'SmoothL1Loss',
|
||||
'L2Loss',
|
||||
'CTCLoss',
|
||||
'ReduceAll',
|
||||
'ScalarToArray',
|
||||
'ScalarToTensor',
|
||||
|
|
|
@ -1564,7 +1564,7 @@ class L2Loss(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
|
||||
valid_types = [mstype.float16, mstype.float32, mstype.double]
|
||||
valid_types = [mstype.float16, mstype.float32]
|
||||
validator.check_tensor_type_same({'x_type': x_type}, valid_types, self.name)
|
||||
return x_type
|
||||
|
||||
|
@ -2871,3 +2871,78 @@ class DropoutGrad(PrimitiveWithInfer):
|
|||
valid_types = (mstype.float16, mstype.float32)
|
||||
validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
|
||||
return dy_dtype
|
||||
|
||||
|
||||
class CTCLoss(PrimitiveWithInfer):
|
||||
"""
|
||||
Calculates the CTC(Connectionist Temporal Classification) loss. Also calculates the gradient.
|
||||
|
||||
Args:
|
||||
preprocess_collapse_repeated (bool): If True, repeated labels are collapsed prior to the CTC calculation.
|
||||
Default: False.
|
||||
ctc_merge_repeated (bool): If False, during CTC calculation, repeated non-blank labels will not be merged
|
||||
and are interpreted as individual labels. This is a simplfied version if CTC.
|
||||
Default: True.
|
||||
ignore_longer_outputs_than_inputs (bool): If True, sequences with longer outputs than inputs will be ignored.
|
||||
Default: False.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
|
||||
:math:`(max_time, batch_size, num_class)`. `num_class` should be `num_labels + 1` classes, `num_labels`
|
||||
indicates the number of actual labels. Blank labels are reserved.
|
||||
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
|
||||
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
|
||||
- **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The
|
||||
type must be int32. `labels_values[i]` must in the range of `[0, num_class)`.
|
||||
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`.
|
||||
The type must be int32. Each value in the tensor should not greater than `max_time`.
|
||||
|
||||
Outputs:
|
||||
- **loss** (Tensor) - A tensor containing log-probabilities, the shape is :math:`(batch_size)`. Has the same
|
||||
type with `inputs`.
|
||||
- **gradient** (Tensor) - The gradient of `loss`. Has the same type and shape with `inputs`.
|
||||
|
||||
Examples:
|
||||
>>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32)
|
||||
>>> labels_indices = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int64)
|
||||
>>> labels_values = Tensor(np.array([2, 2]), mindspore.int32)
|
||||
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
|
||||
>>> ctc_loss = P.CTCloss()
|
||||
>>> output = ctc_loss(inputs, labels_indices, labels_values, sequence_length)
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, preprocess_collapse_repeated=False, ctc_merge_repeated=False,
|
||||
ignore_longer_outputs_than_inputs=False):
|
||||
self.init_prim_io_names(inputs=["inputs", "labels_indices", "labels_values", "sequence_length"],
|
||||
outputs=["loss", "gradient"])
|
||||
validator.check_value_type("preprocess_collapse_repeated", preprocess_collapse_repeated, [bool], self.name)
|
||||
self.preprocess_collapse_repeated_ = preprocess_collapse_repeated
|
||||
self.ctc_merge_repeated_ = validator.check_value_type("ctc_merge_repeated", ctc_merge_repeated,
|
||||
[bool], self.name)
|
||||
validator.check_value_type("ignore_longer_outputs_than_inputs",
|
||||
ignore_longer_outputs_than_inputs, [bool], self.name)
|
||||
self.ignore_longer_outputs_than_inputs_ = ignore_longer_outputs_than_inputs
|
||||
|
||||
def infer_shape(self, inputs, labels_indices, labels_values, sequence_length):
|
||||
validator.check_integer("inputs rank", len(inputs), 3, Rel.EQ, self.name)
|
||||
validator.check_integer("labels_indices rank", len(labels_indices), 2, Rel.EQ, self.name)
|
||||
validator.check_integer("labels_values rank", len(labels_values), 1, Rel.EQ, self.name)
|
||||
validator.check_integer("sequence_length rank", len(sequence_length), 1, Rel.EQ, self.name)
|
||||
validator.check('labels_indices size', labels_indices[0], 'labels_values size',
|
||||
labels_values[0], Rel.EQ, self.name)
|
||||
validator.check('inputs batch_size', inputs[1], 'sequence_length batch_size',
|
||||
sequence_length[0], Rel.EQ, self.name)
|
||||
batch_size = []
|
||||
batch_size.append(inputs[1])
|
||||
return batch_size, inputs
|
||||
|
||||
def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length):
|
||||
validator.check_subclass("inputs_dtype", inputs, mstype.tensor, self.name)
|
||||
validator.check_subclass("labels_indices_dtype", labels_indices, mstype.tensor, self.name)
|
||||
validator.check_subclass("labels_values_dtype", labels_values, mstype.tensor, self.name)
|
||||
validator.check_subclass("sequence_length_dtype", sequence_length, mstype.tensor, self.name)
|
||||
validator.check_tensor_type_same({"inputs_dtype": inputs}, [mstype.float32, mstype.double], self.name)
|
||||
validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name)
|
||||
validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
|
||||
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name)
|
||||
return inputs, inputs
|
||||
|
|
|
@ -909,6 +909,13 @@ test_case_nn_ops = [
|
|||
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
|
||||
'desc_bprop': [3, 3],
|
||||
'skip': ['backward']}),
|
||||
('CTCLoss', {
|
||||
'block': P.CTCLoss(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 4, 6]).astype(np.float32)),
|
||||
Tensor(np.array([[0, 1], [1, 0], [2, 3], [3, 2]]).astype(np.int64)),
|
||||
Tensor(np.array([1, 2, 3, 4]).astype(np.int32)),
|
||||
Tensor(np.array([6, 6, 6, 6]).astype(np.int32))],
|
||||
'desc_bprop': [[4], [6, 4, 6]]}),
|
||||
('L2Loss_1', {
|
||||
'block': P.L2Loss(),
|
||||
'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],
|
||||
|
|
Loading…
Reference in New Issue