!15470 Fix bug of InferImplAdd

From: @jojobugfree
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2021-04-22 09:24:56 +08:00 committed by Gitee
commit ca8ab21233
1 changed files with 5 additions and 21 deletions

View File

@ -60,27 +60,6 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr
return out->Broaden(); return out->Broaden();
} }
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
ShapePtr shape_x = dyn_cast<Shape>(args_spec_list[0]->GetShapeTrack());
MS_EXCEPTION_IF_NULL(shape_x);
std::vector<int64_t> x_dims = shape_x->shape();
ShapePtr shape_y = dyn_cast<Shape>(args_spec_list[1]->GetShapeTrack());
MS_EXCEPTION_IF_NULL(shape_y);
std::vector<int64_t> y_dims = shape_y->shape();
auto broadcast_shape = BroadcastShape(x_dims, y_dims);
if (broadcast_shape.empty()) {
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
<< args_spec_list[1]->ToString();
}
auto out = args_spec_list[0]->Broaden();
out->set_shape(std::make_shared<Shape>(broadcast_shape));
return out;
}
AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: one tensor. // Inputs: one tensor.
@ -272,6 +251,11 @@ AbstractBasePtr InferImplMul(const AnalysisEnginePtr &engine_ptr, const Primitiv
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list); return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
} }
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
}
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive, AbstractBasePtr InferImplSub(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list); return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);