!44271 add Dynamic Shape check of CTCLossV2
Merge pull request !44271 from zhujingxuan/master
This commit is contained in:
commit
d36b653d98
|
@ -43,7 +43,8 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
|||
auto target_lengths_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kIndex3]->BuildShape())[kShape];
|
||||
|
||||
if (IsDynamicRank(log_probs_shape) || IsDynamicRank(targets_shape)) {
|
||||
if (IsDynamicRank(log_probs_shape) || IsDynamicRank(targets_shape) || IsDynamicRank(input_lengths_shape) ||
|
||||
IsDynamicRank(target_lengths_shape)) {
|
||||
std::vector<int64_t> dyn_shape = {abstract::Shape::kShapeRankAny};
|
||||
abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(dyn_shape);
|
||||
abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(dyn_shape);
|
||||
|
@ -58,6 +59,15 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
|||
int64_t C = log_probs_shape[kIndex2];
|
||||
int64_t S = targets_shape[kIndex1];
|
||||
|
||||
int64_t padded_S = (S == abstract::Shape::kShapeDimAny) ? abstract::Shape::kShapeDimAny : (kMulti * S + 1);
|
||||
abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{N});
|
||||
abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{N, T, padded_S});
|
||||
|
||||
if (IsDynamicShape(log_probs_shape) || IsDynamicShape(targets_shape) || IsDynamicShape(input_lengths_shape) ||
|
||||
IsDynamicShape(target_lengths_shape)) {
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{neg_log_shape, log_alpha_shape});
|
||||
}
|
||||
|
||||
(void)CheckAndConvertUtils::CheckValue<size_t>("dim of input_lengths", input_lengths_shape.size(), kEqual, kDim1,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckValue<size_t>("dim of target_lengths", target_lengths_shape.size(), kEqual, kDim1,
|
||||
|
@ -69,10 +79,6 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
|||
auto blank = GetValue<int64_t>(primitive->GetAttr(kAttrBlank));
|
||||
CheckAndConvertUtils::CheckInRange(kAttrBlank, blank, kIncludeLeft, {0, C}, prim_name);
|
||||
|
||||
std::vector<int64_t> out_dim0 = {N};
|
||||
std::vector<int64_t> out_dim1 = {N, T, kMulti * S + 1};
|
||||
abstract::ShapePtr neg_log_shape = std::make_shared<abstract::Shape>(out_dim0);
|
||||
abstract::ShapePtr log_alpha_shape = std::make_shared<abstract::Shape>(out_dim1);
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{neg_log_shape, log_alpha_shape});
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -34,28 +35,15 @@ namespace {
|
|||
abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr size_t kLenLogProbs = 3;
|
||||
constexpr int64_t kInputSize = 7;
|
||||
constexpr size_t kIdx2 = 2;
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize,
|
||||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto log_probs_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto log_probs_shape = log_probs_shape_map[kShape];
|
||||
auto log_probs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
if (IsDynamicRank(log_probs_shape)) {
|
||||
std::vector<int64_t> dyn_shape = {abstract::Shape::kShapeRankAny};
|
||||
return std::make_shared<abstract::Shape>(dyn_shape);
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
|
||||
}
|
||||
if (log_probs_shape.size() != kLenLogProbs) {
|
||||
MS_LOG(EXCEPTION) << "For '" << prim_name
|
||||
<< "', input log_probs's dims must be 3, but got: " << log_probs_shape.size() << ".";
|
||||
}
|
||||
int64_t T = log_probs_shape[0];
|
||||
int64_t N = log_probs_shape[1];
|
||||
int64_t C = log_probs_shape[kIdx2];
|
||||
(void)CheckAndConvertUtils::CheckValue("dim of log_probs", log_probs_shape.size(), kEqual, kLenLogProbs, prim_name);
|
||||
int64_t T = log_probs_shape[kIndex0];
|
||||
int64_t N = log_probs_shape[kIndex1];
|
||||
int64_t C = log_probs_shape[kIndex2];
|
||||
ShapeVector output_shape = {T, N, C};
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
|
@ -77,9 +65,9 @@ TypePtr CTCLossV2GradInferType(const PrimitivePtr &primitive, const std::vector<
|
|||
MIND_API_OPERATOR_IMPL(CTCLossV2Grad, BaseOperator);
|
||||
AbstractBasePtr CTCLossV2GradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr int64_t kInputNum = 7;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
|
||||
auto infer_shape = CTCLossV2GradInferShape(primitive, input_args);
|
||||
auto infer_type = CTCLossV2GradInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
|
|
Loading…
Reference in New Issue