!19205 ctc loss v2 and ctc los v2 grad
Merge pull request !19205 from liubuyu/ctcloss
This commit is contained in:
commit
967a3b8104
|
@ -272,6 +272,8 @@ inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgP
|
|||
inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam");
|
||||
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
|
||||
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
|
||||
inline const PrimitivePtr kPrimCTCLossV2 = std::make_shared<Primitive>("CTCLossV2");
|
||||
inline const PrimitivePtr kPrimCTCLossV2Grad = std::make_shared<Primitive>("CTCLossV2Grad");
|
||||
inline const PrimitivePtr kPrimCTCLoss = std::make_shared<Primitive>(kCTCLoss);
|
||||
inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection");
|
||||
inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>(kConv2DTranspose);
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/ctc_loss_v2.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kLenLogProbs = 3;
|
||||
constexpr size_t kLenTarget = 2;
|
||||
constexpr size_t kInputSize = 4;
|
||||
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto targets_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto log_probs_shape = log_probs_shape_map[kShape];
|
||||
auto targets_shape = targets_shape_map[kShape];
|
||||
if (log_probs_shape.size() != kLenLogProbs) {
|
||||
MS_LOG(EXCEPTION) << "Input log_probs's dims must be 3, but got :" << log_probs_shape.size();
|
||||
}
|
||||
if (targets_shape.size() != kLenTarget) {
|
||||
MS_LOG(EXCEPTION) << "Input targets's dims must be 2, but got :" << targets_shape.size();
|
||||
}
|
||||
int64_t T = log_probs_shape[0];
|
||||
int64_t N = log_probs_shape[1];
|
||||
int64_t S = targets_shape[1];
|
||||
|
||||
ShapeVector output_shape;
|
||||
std::vector<int64_t> out_dim0 = {N};
|
||||
std::vector<int64_t> out_dim1 = {N, T, 2 * S + 1};
|
||||
abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(out_dim0);
|
||||
abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(out_dim1);
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{neg_log_shape, log_alpha_shape});
|
||||
}
|
||||
|
||||
TuplePtr CTCLossV2InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto name = primitive->name();
|
||||
const std::set<TypePtr> valid_types = {kFloat32};
|
||||
auto type = CheckAndConvertUtils::CheckTypeValid("log_probs", input_args[0]->BuildType(), valid_types, name);
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr CTCLossV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(CTCLossV2InferShape(primitive, input_args), CTCLossV2InferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CTCLossV2, prim::kPrimCTCLossV2, CTCLossV2Infer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CTC_LOSS_V2_H_
|
||||
#define MINDSPORE_CORE_OPS_CTC_LOSS_V2_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCTCLossV2 = "CTCLossV2";
|
||||
class CTCLossV2 : public PrimitiveC {
|
||||
public:
|
||||
CTCLossV2() : PrimitiveC(kNameCTCLossV2) {
|
||||
InitIOName({"log_probs", "targets", "input_lengths", "target_lengths"}, {"neg_log_likelihood", "log_alpha"});
|
||||
}
|
||||
~CTCLossV2() = default;
|
||||
MS_DECLARE_PARENT(CTCLossV2, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr CTCLossV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimCTCLossV2Ptr = std::shared_ptr<CTCLossV2>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CTC_LOSS_V2_H_
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/ctc_loss_v2_grad.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kLenLogProbs = 3;
|
||||
constexpr size_t kInputSize = 7;
|
||||
abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto log_probs_shape = log_probs_shape_map[kShape];
|
||||
if (log_probs_shape.size() != kLenLogProbs) {
|
||||
MS_LOG(EXCEPTION) << "Input log_probs's dims must be 3, but got :" << log_probs_shape.size();
|
||||
}
|
||||
int64_t T = log_probs_shape[0];
|
||||
int64_t N = log_probs_shape[1];
|
||||
int64_t C = log_probs_shape[2];
|
||||
ShapeVector output_shape = {N, T, C};
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
|
||||
TypePtr CTCLossV2GradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto name = primitive->name();
|
||||
const std::set<TypePtr> valid_types = {kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
||||
types.emplace("grad_out", input_args[0]->BuildType());
|
||||
types.emplace("log_probs", input_args[1]->BuildType());
|
||||
auto out_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, name);
|
||||
return out_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr CTCLossV2GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_shape = CTCLossV2GradInferShape(primitive, input_args);
|
||||
auto infer_type = CTCLossV2GradInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CTCLossV2Grad, prim::kPrimCTCLossV2Grad, CTCLossV2GradInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CTC_LOSS_V2_GRAD_H_
|
||||
#define MINDSPORE_CORE_OPS_CTC_LOSS_V2_GRAD_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCTCLossV2Grad = "CTCLossV2Grad";
|
||||
class CTCLossV2Grad : public PrimitiveC {
|
||||
public:
|
||||
CTCLossV2Grad() : PrimitiveC(kNameCTCLossV2Grad) {
|
||||
InitIOName(
|
||||
{"grad_out", "log_probs", "targets", "input_lengths", "target_lengths", "neg_log_likelihood", "log_alpha"},
|
||||
{"grad"});
|
||||
}
|
||||
~CTCLossV2Grad() = default;
|
||||
MS_DECLARE_PARENT(CTCLossV2Grad, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr CTCLossV2GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimCTCLossV2Ptr = std::shared_ptr<CTCLossV2Grad>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CTC_LOSS_V2_GRAD_H_
|
|
@ -16,5 +16,6 @@
|
|||
"""grad experimental impl."""
|
||||
from .._grad.grad_base import get_bprop_fn
|
||||
from . import grad_inner_ops
|
||||
from . import grad_nn_ops
|
||||
|
||||
__all__ = ['get_bprop_fn']
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Define the grad rules of neural network related operations."""
|
||||
from .._grad.grad_base import bprop_getters
|
||||
from .. import operations as P
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
|
||||
|
||||
@bprop_getters.register(P.CTCLossV2)
|
||||
def get_bprop_ctc_loss_v2(self):
|
||||
"""Grad definition for `CTCLossV2` operation"""
|
||||
transpose = P.Transpose()
|
||||
ctc_loss_grad = P.CTCLossV2Grad(self.blank, self.reduction, self.zero_infinity)
|
||||
|
||||
def bprop(log_probs, targets, input_lengths, target_lengths, out, dout):
|
||||
grad = ctc_loss_grad(dout[1], log_probs, targets, input_lengths, target_lengths, out[0], out[1])
|
||||
grad = transpose(grad, (1, 0, 2))
|
||||
return grad, zeros_like(targets), zeros_like(input_lengths), zeros_like(target_lengths)
|
||||
|
||||
return bprop
|
|
@ -382,3 +382,5 @@ from .log_ds import _log_ds_tbe
|
|||
from .neg_ds import _neg_ds_tbe
|
||||
from .not_equal_ds import _not_ds_equal_tbe
|
||||
from .reciprocal_ds import _reciprocal_ds_tbe
|
||||
from .ctc_loss_v2 import _ctc_loss_v2_tbe
|
||||
from .ctc_loss_v2_grad import _ctc_loss_v2_grad_tbe
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""CTC_LossV2 op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
ctc_loss_v2_info = TBERegOp("CTCLossV2") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("ctc_loss_v2.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("ctc_loss_v2") \
|
||||
.partial_flag(True) \
|
||||
.attr("blank", "optional", "int", "all", "0") \
|
||||
.attr("reduction", "optional", "str", "all", "none") \
|
||||
.attr("zero_infinity", "optional", "bool", "all", "false") \
|
||||
.input(0, "log_probs", False, "required", "all") \
|
||||
.input(1, "targets", False, "required", "all") \
|
||||
.input(2, "input_lengths", False, "required", "all") \
|
||||
.input(3, "target_lengths", False, "required", "all") \
|
||||
.output(0, "neg_log_likelihood", False, "required", "all") \
|
||||
.output(1, "log_alpha", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(ctc_loss_v2_info)
|
||||
def _ctc_loss_v2_tbe():
|
||||
"""CTCLossV2 TBE register"""
|
||||
return
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""CTC_LossV2Grad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
ctc_loss_v2_grad_info = TBERegOp("CTCLossV2Grad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("ctc_loss_v2_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("ctc_loss_v2_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("blank", "optional", "int", "all", "0") \
|
||||
.attr("reduction", "optional", "str", "all", "none") \
|
||||
.attr("zero_infinity", "optional", "bool", "all", "false") \
|
||||
.input(0, "grad_out", False, "required", "all") \
|
||||
.input(1, "log_probs", False, "required", "all") \
|
||||
.input(2, "targets", False, "required", "all") \
|
||||
.input(3, "input_lengths", False, "required", "all") \
|
||||
.input(4, "target_lengths", False, "required", "all") \
|
||||
.input(5, "neg_log_likelihood", False, "required", "all") \
|
||||
.input(6, "log_alpha", False, "required", "all") \
|
||||
.output(0, "grad", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||
DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
@op_info_register(ctc_loss_v2_grad_info)
|
||||
def _ctc_loss_v2_grad_tbe():
|
||||
"""CTCLossV2Grad TBE register"""
|
||||
return
|
|
@ -70,7 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
|
|||
DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
|
||||
InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
|
||||
GeLU, Gelu, FastGeLU, FastGelu, Elu,
|
||||
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
|
||||
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCLossV2Grad, CTCGreedyDecoder,
|
||||
LogSoftmax, MaxPool3D, AvgPool3D,
|
||||
MaxPool, DataFormatDimMap,
|
||||
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
|
||||
|
|
|
@ -8306,6 +8306,100 @@ def _deconv_output_length(input_length, kernel_size, stride_size, dilation_size)
|
|||
return length
|
||||
|
||||
|
||||
class CTCLossV2(Primitive):
|
||||
"""
|
||||
Calculates the CTC (Connectionist Temporal Classification) loss and the gradient.
|
||||
|
||||
The CTC algorithm is proposed in `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with
|
||||
Recurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_.
|
||||
|
||||
Args:
|
||||
blank (int): The blank label. Default: 0.
|
||||
reduction (string): Apply specific reduction method to the output. Currently only support 'none'.
|
||||
Default: "none".
|
||||
zero_infinity (bool): Whether to set infinite loss and correlation gradient to zero. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **log_probs** (Tensor) - A tensor of shape (T, N, C), where T is input length, N is batch size and C is number
|
||||
of classes (including blank).
|
||||
- **targets** (Tensor) - A tensor of shape (N, S), where S is max target length, means the target sequences.
|
||||
- **input_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape(N). It means the lengths of the input.
|
||||
- **target_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape(N). It means the lengths of the target.
|
||||
|
||||
Outputs:
|
||||
- **neg_log_likelihood** (Tensor) - A loss value which is differentiable with respect to each input node.
|
||||
- **log_alpha** (Tensor) - The probability of possible trace of input to target.
|
||||
|
||||
Raises:
|
||||
TypeError: If `zero_infinity` is not a bool, reduction is not string.
|
||||
|
||||
Supported Platforms:
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, blank, reduction="none", zero_infinity=False):
|
||||
"""Initialize CTCLossV2"""
|
||||
self.init_prim_io_names(inputs=["log_probs", "targets", "input_lengths", "target_lengths"],
|
||||
outputs=["neg_log_likelihood", "log_alpha"])
|
||||
validator.check_value_type("blank", blank, [int], self.name)
|
||||
self.add_prim_attr("blank", blank)
|
||||
validator.check_value_type("reduction", reduction, [str], self.name)
|
||||
self.reduction = self.reduction.lower()
|
||||
validator.check_string(self.reduction, ['none'], 'reduction', self.name)
|
||||
self.add_prim_attr("reduction", self.reduction)
|
||||
validator.check_value_type("zero_infinity", zero_infinity, [bool], self.name)
|
||||
self.add_prim_attr("zero_infinity", zero_infinity)
|
||||
|
||||
|
||||
class CTCLossV2Grad(Primitive):
|
||||
"""
|
||||
Calculates the gradient of CTC (Connectionist Temporal Classification) loss.
|
||||
|
||||
The CTC algorithm is proposed in `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with
|
||||
Recurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_.
|
||||
|
||||
Args:
|
||||
blank (int): The blank label. Default: 0.
|
||||
reduction (string): Apply specific reduction method to the output. Currently only support 'none'.
|
||||
Default: "none".
|
||||
zero_infinity (bool): Whether to set infinite loss and correlation gradient to zero. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **grad_out** (Tenosr) - Gradient renewal codfficient, A tensor for shape (N), where N is batch size.
|
||||
- **log_probs** (Tensor) - A tensor of shape (T, N, C), where T is input length, N is batch size and C is number
|
||||
of classes (including blank).
|
||||
- **targets** (Tensor) - A tensor of shape (N, S), where S is max target length, means the target sequences.
|
||||
- **input_lengths** (Union(tuple, Tensor)) - A tuple or Tensor of shape(N). It means the lengths of the input.
|
||||
- **target_lengths** (Union(tuple, Tensor)) - A tuple or Tensor of shape(N). It means the lengths of the target.
|
||||
- **log_alpha** (Tensor) - The probability of possible trace of input to target.
|
||||
- **neg_log_likelihood** (Tensor) - A loss value which is differentiable with respect to each input node.
|
||||
|
||||
Outputs:
|
||||
- **grad** (Tensor) - The grad of Connectionist Temporal Classification Loss
|
||||
|
||||
Raises:
|
||||
TypeError: If `zero_infinity` is not a bool, reduction is not string.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, blank, reduction="none", zero_infinity=False):
|
||||
"""Initialize CTCLossV2Grad"""
|
||||
self.init_prim_io_names(inputs=["grad_out", "log_probs", "targets", "input_lengths", "target_lengths",
|
||||
"neg_log_likelihood", "log_alpha"],
|
||||
outputs=["grad"])
|
||||
validator.check_value_type("blank", blank, [int], self.name)
|
||||
self.add_prim_attr("blank", blank)
|
||||
validator.check_value_type("reduction", reduction, [str], self.name)
|
||||
self.add_prim_attr("reduction", reduction)
|
||||
validator.check_value_type("zero_infinity", zero_infinity, [bool], self.name)
|
||||
self.add_prim_attr("zero_infinity", zero_infinity)
|
||||
|
||||
|
||||
|
||||
class Conv3DTranspose(PrimitiveWithInfer):
|
||||
r"""
|
||||
Computes a 3D transposed convolution, which is also known as a deconvolution
|
||||
|
|
Loading…
Reference in New Issue