From 3908ce6173545e445bb1e63e7059c08dc7023333 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Wed, 19 Jan 2022 11:15:56 +0800 Subject: [PATCH] fix a bug about pynative mode error on Ascend --- mindspore/core/ops/gelu.cc | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mindspore/core/ops/gelu.cc b/mindspore/core/ops/gelu.cc index 4ca26199253..2a176b6735f 100644 --- a/mindspore/core/ops/gelu.cc +++ b/mindspore/core/ops/gelu.cc @@ -28,9 +28,6 @@ namespace mindspore { namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("gelu infer", SizeToLong(input_args.size()), kEqual, 1, op_name); MS_EXCEPTION_IF_NULL(input_args[0]); auto x = input_args[0]->BuildShape(); MS_EXCEPTION_IF_NULL(x); @@ -39,9 +36,6 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); - auto op_name = prim->name(); - (void)CheckAndConvertUtils::CheckInteger("gelu infer", SizeToLong(input_args.size()), kEqual, 1, op_name); std::map types; const std::set valid_types = {kFloat16, kFloat32}; MS_EXCEPTION_IF_NULL(input_args[0]); @@ -51,7 +45,12 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & } // namespace AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args)); + MS_EXCEPTION_IF_NULL(primitive); + const int64_t input_num = 1; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); + auto infer_type = InferType(primitive, input_args); + auto infer_shape = InferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); } REGISTER_PRIMITIVE_EVAL_IMPL(GeLU, prim::kPrimGeLU, GeLUInfer, nullptr, true); } // namespace ops