From e4fbf40afde4feabc12613a5c6f7c55e1a8ecd6c Mon Sep 17 00:00:00 2001 From: fengyihang Date: Wed, 1 Feb 2023 16:59:43 +0800 Subject: [PATCH] fix applyrmsprop infer check --- mindspore/core/ops/apply_rms_prop.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mindspore/core/ops/apply_rms_prop.cc b/mindspore/core/ops/apply_rms_prop.cc index 777f427cef2..1d1f7afcb37 100644 --- a/mindspore/core/ops/apply_rms_prop.cc +++ b/mindspore/core/ops/apply_rms_prop.cc @@ -30,6 +30,7 @@ class ApplyRMSPropInfer : public abstract::OpInferBase { auto op_name = primitive->name(); MS_LOG(INFO) << "For '" << op_name << "', it's now doing infer shape."; const int64_t kInputNum = 5; + const int64_t kInputNumNormal = 8; CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, op_name); auto var_shape = input_args[0]->BuildShape(); auto ms_shape = input_args[1]->BuildShape(); @@ -67,9 +68,14 @@ class ApplyRMSPropInfer : public abstract::OpInferBase { << "', 'grad' must have the same shape as 'var'. But got 'grad' shape: " << grad_shape->ToString() << ", 'var' shape: " << var_shape->ToString() << "."; } - auto shape_element = var_shape->cast(); - MS_EXCEPTION_IF_NULL(shape_element); - return shape_element; + if (input_args.size() >= kInputNumNormal) { + auto shape_element = var_shape->cast(); + MS_EXCEPTION_IF_NULL(shape_element); + return shape_element; + } else { + return std::make_shared( + std::vector{var_shape, ms_shape, mom_shape}); + } } TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override {