!19205 ctc loss v2 and ctc loss v2 grad

Merge pull request !19205 from liubuyu/ctcloss
i-robot 2021-07-06 11:09:56 +00:00 committed by Gitee
commit 967a3b8104
12 changed files with 464 additions and 1 deletions

View File

@ -272,6 +272,8 @@ inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgP
inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam");
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
inline const PrimitivePtr kPrimCTCLossV2 = std::make_shared<Primitive>("CTCLossV2");
inline const PrimitivePtr kPrimCTCLossV2Grad = std::make_shared<Primitive>("CTCLossV2Grad");
inline const PrimitivePtr kPrimCTCLoss = std::make_shared<Primitive>(kCTCLoss);
inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection");
inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>(kConv2DTranspose);

View File

@ -0,0 +1,78 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/ctc_loss_v2.h"
#include <vector>
#include <string>
#include <memory>
#include <map>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kLenLogProbs = 3;
constexpr size_t kLenTarget = 2;
constexpr size_t kInputSize = 4;
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto targets_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto log_probs_shape = log_probs_shape_map[kShape];
auto targets_shape = targets_shape_map[kShape];
if (log_probs_shape.size() != kLenLogProbs) {
MS_LOG(EXCEPTION) << "Input log_probs's dims must be 3, but got :" << log_probs_shape.size();
}
if (targets_shape.size() != kLenTarget) {
MS_LOG(EXCEPTION) << "Input targets's dims must be 2, but got :" << targets_shape.size();
}
int64_t T = log_probs_shape[0];
int64_t N = log_probs_shape[1];
int64_t S = targets_shape[1];
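// neg_log_likelihood holds one loss value per batch element: shape (N).
// log_alpha is the forward-variable table of the CTC dynamic program over the
// blank-extended target sequence (length 2 * S + 1): shape (N, T, 2 * S + 1).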
std::vector<int64_t> out_dim0 = {N};
std::vector<int64_t> out_dim1 = {N, T, 2 * S + 1};
abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(out_dim0);
abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(out_dim1);
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{neg_log_shape, log_alpha_shape});
}
TuplePtr CTCLossV2InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto name = primitive->name();
const std::set<TypePtr> valid_types = {kFloat32};
auto type = CheckAndConvertUtils::CheckTypeValid("log_probs", input_args[0]->BuildType(), valid_types, name);
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
}
} // namespace
AbstractBasePtr CTCLossV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(CTCLossV2InferShape(primitive, input_args), CTCLossV2InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(CTCLossV2, prim::kPrimCTCLossV2, CTCLossV2Infer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CTC_LOSS_V2_H_
#define MINDSPORE_CORE_OPS_CTC_LOSS_V2_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCTCLossV2 = "CTCLossV2";
class CTCLossV2 : public PrimitiveC {
public:
CTCLossV2() : PrimitiveC(kNameCTCLossV2) {
InitIOName({"log_probs", "targets", "input_lengths", "target_lengths"}, {"neg_log_likelihood", "log_alpha"});
}
~CTCLossV2() = default;
MS_DECLARE_PARENT(CTCLossV2, PrimitiveC);
void Init() {}
};
AbstractBasePtr CTCLossV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimCTCLossV2Ptr = std::shared_ptr<CTCLossV2>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CTC_LOSS_V2_H_

View File

@ -0,0 +1,73 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/ctc_loss_v2_grad.h"
#include <vector>
#include <string>
#include <memory>
#include <map>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kLenLogProbs = 3;
constexpr size_t kInputSize = 7;
abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto log_probs_shape = log_probs_shape_map[kShape];
if (log_probs_shape.size() != kLenLogProbs) {
MS_LOG(EXCEPTION) << "Input log_probs's dims must be 3, but got :" << log_probs_shape.size();
}
int64_t T = log_probs_shape[0];
int64_t N = log_probs_shape[1];
int64_t C = log_probs_shape[2];
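// The gradient w.r.t. log_probs is produced in (N, T, C) layout; the bprop rule
// transposes it back to the (T, N, C) layout of log_probs.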
ShapeVector output_shape = {N, T, C};
return std::make_shared<abstract::Shape>(output_shape);
}
TypePtr CTCLossV2GradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto name = primitive->name();
const std::set<TypePtr> valid_types = {kFloat32};
std::map<std::string, TypePtr> types;
MS_EXCEPTION_IF_NULL(input_args[0]);
MS_EXCEPTION_IF_NULL(input_args[1]);
types.emplace("grad_out", input_args[0]->BuildType());
types.emplace("log_probs", input_args[1]->BuildType());
auto out_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, name);
return out_type;
}
} // namespace
AbstractBasePtr CTCLossV2GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto infer_shape = CTCLossV2GradInferShape(primitive, input_args);
auto infer_type = CTCLossV2GradInferType(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(CTCLossV2Grad, prim::kPrimCTCLossV2Grad, CTCLossV2GradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CTC_LOSS_V2_GRAD_H_
#define MINDSPORE_CORE_OPS_CTC_LOSS_V2_GRAD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCTCLossV2Grad = "CTCLossV2Grad";
class CTCLossV2Grad : public PrimitiveC {
public:
CTCLossV2Grad() : PrimitiveC(kNameCTCLossV2Grad) {
InitIOName(
{"grad_out", "log_probs", "targets", "input_lengths", "target_lengths", "neg_log_likelihood", "log_alpha"},
{"grad"});
}
~CTCLossV2Grad() = default;
MS_DECLARE_PARENT(CTCLossV2Grad, PrimitiveC);
void Init() {}
};
AbstractBasePtr CTCLossV2GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimCTCLossV2GradPtr = std::shared_ptr<CTCLossV2Grad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CTC_LOSS_V2_GRAD_H_

View File

@ -16,5 +16,6 @@
"""grad experimental impl."""
from .._grad.grad_base import get_bprop_fn
from . import grad_inner_ops
from . import grad_nn_ops
__all__ = ['get_bprop_fn']

View File

@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the grad rules of neural network related operations."""
from .._grad.grad_base import bprop_getters
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@bprop_getters.register(P.CTCLossV2)
def get_bprop_ctc_loss_v2(self):
"""Grad definition for `CTCLossV2` operation"""
transpose = P.Transpose()
ctc_loss_grad = P.CTCLossV2Grad(self.blank, self.reduction, self.zero_infinity)
def bprop(log_probs, targets, input_lengths, target_lengths, out, dout):
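# dout[0] is the incoming gradient of neg_log_likelihood (shape (N)); out[0] and out[1]
# are the forward outputs neg_log_likelihood and log_alpha.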
grad = ctc_loss_grad(dout[0], log_probs, targets, input_lengths, target_lengths, out[0], out[1])
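# CTCLossV2Grad returns the gradient in (N, T, C) layout; transpose back to (T, N, C).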
grad = transpose(grad, (1, 0, 2))
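# The integer-valued targets and length inputs are not differentiable, so they receive zero gradients.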
return grad, zeros_like(targets), zeros_like(input_lengths), zeros_like(target_lengths)
return bprop

View File

@ -382,3 +382,5 @@ from .log_ds import _log_ds_tbe
from .neg_ds import _neg_ds_tbe
from .not_equal_ds import _not_ds_equal_tbe
from .reciprocal_ds import _reciprocal_ds_tbe
from .ctc_loss_v2 import _ctc_loss_v2_tbe
from .ctc_loss_v2_grad import _ctc_loss_v2_grad_tbe

View File

@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTC_LossV2 op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
ctc_loss_v2_info = TBERegOp("CTCLossV2") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("ctc_loss_v2.so") \
.compute_cost(10) \
.kernel_name("ctc_loss_v2") \
.partial_flag(True) \
.attr("blank", "optional", "int", "all", "0") \
.attr("reduction", "optional", "str", "all", "none") \
.attr("zero_infinity", "optional", "bool", "all", "false") \
.input(0, "log_probs", False, "required", "all") \
.input(1, "targets", False, "required", "all") \
.input(2, "input_lengths", False, "required", "all") \
.input(3, "target_lengths", False, "required", "all") \
.output(0, "neg_log_likelihood", False, "required", "all") \
.output(1, "log_alpha", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
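# The dtype_format row above lists the four inputs followed by the two outputs, in
# declaration order: float32 for log_probs and both outputs, int32 for the index/length inputs.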
@op_info_register(ctc_loss_v2_info)
def _ctc_loss_v2_tbe():
"""CTCLossV2 TBE register"""
return

View File

@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTC_LossV2Grad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
ctc_loss_v2_grad_info = TBERegOp("CTCLossV2Grad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("ctc_loss_v2_grad.so") \
.compute_cost(10) \
.kernel_name("ctc_loss_v2_grad") \
.partial_flag(True) \
.attr("blank", "optional", "int", "all", "0") \
.attr("reduction", "optional", "str", "all", "none") \
.attr("zero_infinity", "optional", "bool", "all", "false") \
.input(0, "grad_out", False, "required", "all") \
.input(1, "log_probs", False, "required", "all") \
.input(2, "targets", False, "required", "all") \
.input(3, "input_lengths", False, "required", "all") \
.input(4, "target_lengths", False, "required", "all") \
.input(5, "neg_log_likelihood", False, "required", "all") \
.input(6, "log_alpha", False, "required", "all") \
.output(0, "grad", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
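# The dtype_format row above lists the seven inputs followed by the single output, in declaration
# order: float32 for the probability/loss tensors, int32 for targets and the length inputs.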
@op_info_register(ctc_loss_v2_grad_info)
def _ctc_loss_v2_grad_tbe():
"""CTCLossV2Grad TBE register"""
return

View File

@ -70,7 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
GeLU, Gelu, FastGeLU, FastGelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCLossV2Grad, CTCGreedyDecoder,
LogSoftmax, MaxPool3D, AvgPool3D,
MaxPool, DataFormatDimMap,
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,

View File

@ -8306,6 +8306,100 @@ def _deconv_output_length(input_length, kernel_size, stride_size, dilation_size)
return length
class CTCLossV2(Primitive):
"""
Calculates the CTC (Connectionist Temporal Classification) loss.
The CTC algorithm is proposed in `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with
Recurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_.
Args:
blank (int): The blank label. Default: 0.
reduction (str): The reduction method applied to the output. Currently only 'none' is supported.
Default: "none".
zero_infinity (bool): Whether to zero infinite losses and the associated gradients. Default: False.
Inputs:
- **log_probs** (Tensor) - A tensor of shape (T, N, C), where T is input length, N is batch size and C is the
number of classes (including blank).
- **targets** (Tensor) - A tensor of shape (N, S), where S is the max target length, containing the target
sequences.
- **input_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape (N), containing the lengths of the inputs.
- **target_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape (N), containing the lengths of the
targets.
Outputs:
- **neg_log_likelihood** (Tensor) - A loss value which is differentiable with respect to each input node.
- **log_alpha** (Tensor) - The probabilities of the possible alignment paths from the input to the target.
Raises:
TypeError: If `zero_infinity` is not a bool or `reduction` is not a str.
Supported Platforms:
``Ascend``
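Examples:
>>> # A minimal illustrative call: values are random and shapes follow the descriptions above,
>>> # with T=2 (input length), N=2 (batch size), C=3 (classes including blank) and S=1.
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> log_probs = Tensor(np.random.randn(2, 2, 3).astype(np.float32))
>>> targets = Tensor(np.array([[0], [1]]).astype(np.int32))
>>> input_lengths = Tensor(np.array([2, 2]).astype(np.int32))
>>> target_lengths = Tensor(np.array([1, 1]).astype(np.int32))
>>> ctc_loss = P.CTCLossV2(blank=2, reduction='none', zero_infinity=False)
>>> neg_log_likelihood, log_alpha = ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> # neg_log_likelihood has shape (N,) = (2,) and log_alpha has shape (N, T, 2*S+1) = (2, 2, 3).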
"""
@prim_attr_register
def __init__(self, blank, reduction="none", zero_infinity=False):
"""Initialize CTCLossV2"""
self.init_prim_io_names(inputs=["log_probs", "targets", "input_lengths", "target_lengths"],
outputs=["neg_log_likelihood", "log_alpha"])
validator.check_value_type("blank", blank, [int], self.name)
self.add_prim_attr("blank", blank)
validator.check_value_type("reduction", reduction, [str], self.name)
self.reduction = reduction.lower()
validator.check_string(self.reduction, ['none'], 'reduction', self.name)
self.add_prim_attr("reduction", self.reduction)
validator.check_value_type("zero_infinity", zero_infinity, [bool], self.name)
self.add_prim_attr("zero_infinity", zero_infinity)
class CTCLossV2Grad(Primitive):
"""
Calculates the gradient of the CTC (Connectionist Temporal Classification) loss.
The CTC algorithm is proposed in `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with
Recurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_.
Args:
blank (int): The blank label. Default: 0.
reduction (str): The reduction method applied to the output. Currently only 'none' is supported.
Default: "none".
zero_infinity (bool): Whether to zero infinite losses and the associated gradients. Default: False.
Inputs:
- **grad_out** (Tensor) - Gradient of the loss, a tensor of shape (N), where N is batch size.
- **log_probs** (Tensor) - A tensor of shape (T, N, C), where T is input length, N is batch size and C is the
number of classes (including blank).
- **targets** (Tensor) - A tensor of shape (N, S), where S is the max target length, containing the target
sequences.
- **input_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape (N), containing the lengths of the inputs.
- **target_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape (N), containing the lengths of the
targets.
- **neg_log_likelihood** (Tensor) - A loss value which is differentiable with respect to each input node.
- **log_alpha** (Tensor) - The probabilities of the possible alignment paths from the input to the target.
Outputs:
- **grad** (Tensor) - The gradient of the CTC loss with respect to `log_probs`, a tensor of shape (N, T, C).
Raises:
TypeError: If `zero_infinity` is not a bool or `reduction` is not a str.
Supported Platforms:
``Ascend``
"""
@prim_attr_register
def __init__(self, blank, reduction="none", zero_infinity=False):
"""Initialize CTCLossV2Grad"""
self.init_prim_io_names(inputs=["grad_out", "log_probs", "targets", "input_lengths", "target_lengths",
"neg_log_likelihood", "log_alpha"],
outputs=["grad"])
validator.check_value_type("blank", blank, [int], self.name)
self.add_prim_attr("blank", blank)
validator.check_value_type("reduction", reduction, [str], self.name)
self.add_prim_attr("reduction", reduction)
validator.check_value_type("zero_infinity", zero_infinity, [bool], self.name)
self.add_prim_attr("zero_infinity", zero_infinity)
class Conv3DTranspose(PrimitiveWithInfer):
r"""
Computes a 3D transposed convolution, which is also known as a deconvolution