!44271 add Dynamic Shape check of CTCLossV2

Merge pull request !44271 from zhujingxuan/master
i-robot 2022-10-20 15:48:29 +00:00 committed by Gitee
commit d36b653d98
2 changed files with 21 additions and 27 deletions
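
For readers outside the MindSpore codebase: the diffs below lean on two sentinel values that abstract::Shape uses to encode unknown shape information. Below is a minimal standalone sketch of that convention; the constants mirror MindSpore's actual values, but these helper definitions are illustrative re-implementations, not the real utilities.

// Illustrative re-implementations of the dynamic-shape predicates used
// in the diffs below; the real ones live in MindSpore's shape utilities.
#include <algorithm>
#include <cstdint>
#include <vector>

constexpr int64_t kShapeDimAny = -1;   // a single dimension is unknown
constexpr int64_t kShapeRankAny = -2;  // even the rank is unknown

bool IsDynamicRank(const std::vector<int64_t> &shape) {
  return std::any_of(shape.begin(), shape.end(),
                     [](int64_t dim) { return dim == kShapeRankAny; });
}

bool IsDynamicShape(const std::vector<int64_t> &shape) {
  return std::any_of(shape.begin(), shape.end(),
                     [](int64_t dim) { return dim == kShapeDimAny; });
}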


@@ -43,7 +43,8 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
   auto target_lengths_shape =
     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kIndex3]->BuildShape())[kShape];
-  if (IsDynamicRank(log_probs_shape) || IsDynamicRank(targets_shape)) {
+  if (IsDynamicRank(log_probs_shape) || IsDynamicRank(targets_shape) || IsDynamicRank(input_lengths_shape) ||
+      IsDynamicRank(target_lengths_shape)) {
     std::vector<int64_t> dyn_shape = {abstract::Shape::kShapeRankAny};
     abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(dyn_shape);
     abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(dyn_shape);
@@ -58,6 +59,15 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
   int64_t C = log_probs_shape[kIndex2];
   int64_t S = targets_shape[kIndex1];
+  int64_t padded_S = (S == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : (kMulti * S + 1);
+  abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{N});
+  abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{N, T, padded_S});
+  if (IsDynamicShape(log_probs_shape) || IsDynamicShape(targets_shape) || IsDynamicShape(input_lengths_shape) ||
+      IsDynamicShape(target_lengths_shape)) {
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{neg_log_shape, log_alpha_shape});
+  }
   (void)CheckAndConvertUtils::CheckValue<size_t>("dim of input_lengths", input_lengths_shape.size(), kEqual, kDim1,
                                                  prim_name);
   (void)CheckAndConvertUtils::CheckValue<size_t>("dim of target_lengths", target_lengths_shape.size(), kEqual, kDim1,
@@ -69,10 +79,6 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
   auto blank = GetValue<int64_t>(primitive->GetAttr(kAttrBlank));
   CheckAndConvertUtils::CheckInRange(kAttrBlank, blank, kIncludeLeft, {0, C}, prim_name);
-  std::vector<int64_t> out_dim0 = {N};
-  std::vector<int64_t> out_dim1 = {N, T, kMulti * S + 1};
-  abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(out_dim0);
-  abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(out_dim1);
   return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{neg_log_shape, log_alpha_shape});
 }
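
The net effect of the hunks above: when any of the four inputs has a fully unknown rank, both outputs fall back to rank-unknown; when only some dimensions are unknown, the known parts of neg_log_likelihood (N,) and log_alpha (N, T, 2*S + 1) are still propagated, and the early return skips the dimension checks that would otherwise reject the -1 placeholders. A hypothetical standalone walk-through of that propagation; the shapes and the kMulti constant follow the diff, while main() and the printed strings are only for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kShapeDimAny = -1;  // unknown-dimension sentinel, as above
constexpr int64_t kMulti = 2;         // log_alpha's last dim is 2 * S + 1

int main() {
  // log_probs is (T, N, C), targets is (N, S); N and S are dynamic here.
  std::vector<int64_t> log_probs_shape = {50, kShapeDimAny, 20};
  std::vector<int64_t> targets_shape = {kShapeDimAny, 10};
  int64_t T = log_probs_shape[0];
  int64_t N = log_probs_shape[1];
  int64_t S = targets_shape[1];
  // Mirrors the diff: an unknown S stays unknown, a known S is padded.
  int64_t padded_S = (S == kShapeDimAny) ? kShapeDimAny : (kMulti * S + 1);
  std::cout << "neg_log_likelihood: {" << N << "}\n";  // {-1}
  std::cout << "log_alpha: {" << N << ", " << T << ", " << padded_S << "}\n";  // {-1, 50, 21}
  return 0;
}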


@@ -24,6 +24,7 @@
 #include "utils/check_convert_utils.h"
 #include "abstract/ops/primitive_infer_map.h"
 #include "mindapi/src/helper.h"
+#include "include/common/utils/utils.h"
 
 namespace mindspore {
 namespace ops {
@@ -34,28 +35,15 @@ namespace {
 abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
   constexpr size_t kLenLogProbs = 3;
-  constexpr int64_t kInputSize = 7;
-  constexpr size_t kIdx2 = 2;
-  MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize,
-                                           prim_name);
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
-  auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
-  auto log_probs_shape = log_probs_shape_map[kShape];
+  auto log_probs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
   if (IsDynamicRank(log_probs_shape)) {
-    std::vector<int64_t> dyn_shape = {abstract::Shape::kShapeRankAny};
-    return std::make_shared<abstract::Shape>(dyn_shape);
+    return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
   }
-  if (log_probs_shape.size() != kLenLogProbs) {
-    MS_LOG(EXCEPTION) << "For '" << prim_name
-                      << "', input log_probs's dims must be 3, but got: " << log_probs_shape.size() << ".";
-  }
-  int64_t T = log_probs_shape[0];
-  int64_t N = log_probs_shape[1];
-  int64_t C = log_probs_shape[kIdx2];
+  (void)CheckAndConvertUtils::CheckValue("dim of log_probs", log_probs_shape.size(), kEqual, kLenLogProbs, prim_name);
+  int64_t T = log_probs_shape[kIndex0];
+  int64_t N = log_probs_shape[kIndex1];
+  int64_t C = log_probs_shape[kIndex2];
   ShapeVector output_shape = {T, N, C};
   return std::make_shared<abstract::Shape>(output_shape);
 }
@@ -77,9 +65,9 @@ TypePtr CTCLossV2GradInferType(const PrimitivePtr &primitive, const std::vector<
 MIND_API_OPERATOR_IMPL(CTCLossV2Grad, BaseOperator);
 AbstractBasePtr CTCLossV2GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) {
-  for (auto item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
+  MS_EXCEPTION_IF_NULL(primitive);
+  constexpr int64_t kInputNum = 7;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
   auto infer_shape = CTCLossV2GradInferShape(primitive, input_args);
   auto infer_type = CTCLossV2GradInferType(primitive, input_args);
   return abstract::MakeAbstract(infer_shape, infer_type);
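
The grad refactor keeps the same contract while consolidating validation: the input-count and null checks move into CTCLossV2GradInfer via CheckInputArgs, and the shape function reduces to "the gradient has log_probs' (T, N, C) shape, or rank-unknown if log_probs is dynamic-rank". A sketch of that contract in isolation; the helper name and structure here are invented for illustration.

#include <cstdint>
#include <vector>

constexpr int64_t kShapeRankAny = -2;  // rank-unknown sentinel, as above

// Hypothetical condensed view of CTCLossV2GradInferShape's output rule.
std::vector<int64_t> InferCtcGradShape(const std::vector<int64_t> &log_probs_shape) {
  // A dynamic-rank input short-circuits to a rank-unknown output.
  if (log_probs_shape.size() == 1 && log_probs_shape[0] == kShapeRankAny) {
    return {kShapeRankAny};
  }
  // Otherwise log_probs is validated to be 3-D, and the gradient matches it.
  return {log_probs_shape[0], log_probs_shape[1], log_probs_shape[2]};
}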