From 9923d4794d7bbb310a30772167a6bdfcbca14c80 Mon Sep 17 00:00:00 2001 From: caifubi Date: Wed, 21 Apr 2021 15:25:16 +0800 Subject: [PATCH] Fix bug of Add infershape --- mindspore/core/abstract/prim_maths.cc | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index 81e2841fd54..28524199ed5 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -60,27 +60,6 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr return out->Broaden(); } -AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - ShapePtr shape_x = dyn_cast(args_spec_list[0]->GetShapeTrack()); - MS_EXCEPTION_IF_NULL(shape_x); - std::vector x_dims = shape_x->shape(); - ShapePtr shape_y = dyn_cast(args_spec_list[1]->GetShapeTrack()); - MS_EXCEPTION_IF_NULL(shape_y); - std::vector y_dims = shape_y->shape(); - auto broadcast_shape = BroadcastShape(x_dims, y_dims); - if (broadcast_shape.empty()) { - MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," - << args_spec_list[1]->ToString(); - } - auto out = args_spec_list[0]->Broaden(); - out->set_shape(std::make_shared(broadcast_shape)); - return out; -} - AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: one tensor. @@ -272,6 +251,11 @@ AbstractBasePtr InferImplMul(const AnalysisEnginePtr &engine_ptr, const Primitiv return InferImplBinaryBase(engine_ptr, primitive, args_spec_list); } +AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferImplBinaryBase(engine_ptr, primitive, args_spec_list); +} + AbstractBasePtr InferImplSub(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);