forked from mindspore-Ecosystem/mindspore
!48309 Fix applyrmsprop infercheck with ascend dynamic shape
Merge pull request !48309 from 冯一航/fix_applyrmsprop_infercheck_with_ascend
This commit is contained in:
commit
c520d8bfd8
|
@ -30,6 +30,7 @@ class ApplyRMSPropInfer : public abstract::OpInferBase {
|
|||
auto op_name = primitive->name();
|
||||
MS_LOG(INFO) << "For '" << op_name << "', it's now doing infer shape.";
|
||||
const int64_t kInputNum = 5;
|
||||
const int64_t kInputNumNormal = 8;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, op_name);
|
||||
auto var_shape = input_args[0]->BuildShape();
|
||||
auto ms_shape = input_args[1]->BuildShape();
|
||||
|
@ -67,9 +68,14 @@ class ApplyRMSPropInfer : public abstract::OpInferBase {
|
|||
<< "', 'grad' must have the same shape as 'var'. But got 'grad' shape: "
|
||||
<< grad_shape->ToString() << ", 'var' shape: " << var_shape->ToString() << ".";
|
||||
}
|
||||
if (input_args.size() >= kInputNumNormal) {
|
||||
auto shape_element = var_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
return shape_element;
|
||||
} else {
|
||||
return std::make_shared<abstract::TupleShape>(
|
||||
std::vector<abstract::BaseShapePtr>{var_shape, ms_shape, mom_shape});
|
||||
}
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
|
|
Loading…
Reference in New Issue