!48309 Fix applyrmsprop infercheck with ascend dynamic shape

Merge pull request !48309 from 冯一航/fix_applyrmsprop_infercheck_with_ascend
This commit is contained in:
i-robot 2023-02-02 02:06:11 +00:00 committed by Gitee
commit c520d8bfd8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 9 additions and 3 deletions

View File

@ -30,6 +30,7 @@ class ApplyRMSPropInfer : public abstract::OpInferBase {
auto op_name = primitive->name();
MS_LOG(INFO) << "For '" << op_name << "', it's now doing infer shape.";
const int64_t kInputNum = 5;
const int64_t kInputNumNormal = 8;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, op_name);
auto var_shape = input_args[0]->BuildShape();
auto ms_shape = input_args[1]->BuildShape();
@ -67,9 +68,14 @@ class ApplyRMSPropInfer : public abstract::OpInferBase {
<< "', 'grad' must have the same shape as 'var'. But got 'grad' shape: "
<< grad_shape->ToString() << ", 'var' shape: " << var_shape->ToString() << ".";
}
if (input_args.size() >= kInputNumNormal) {
auto shape_element = var_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
} else {
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{var_shape, ms_shape, mom_shape});
}
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {