Fix bug of Add infershape

This commit is contained in:
caifubi 2021-04-21 15:25:16 +08:00
parent d48151ab1e
commit 9923d4794d
1 changed files with 5 additions and 21 deletions

View File

@ -60,27 +60,6 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr
return out->Broaden();
}
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
ShapePtr shape_x = dyn_cast<Shape>(args_spec_list[0]->GetShapeTrack());
MS_EXCEPTION_IF_NULL(shape_x);
std::vector<int64_t> x_dims = shape_x->shape();
ShapePtr shape_y = dyn_cast<Shape>(args_spec_list[1]->GetShapeTrack());
MS_EXCEPTION_IF_NULL(shape_y);
std::vector<int64_t> y_dims = shape_y->shape();
auto broadcast_shape = BroadcastShape(x_dims, y_dims);
if (broadcast_shape.empty()) {
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
<< args_spec_list[1]->ToString();
}
auto out = args_spec_list[0]->Broaden();
out->set_shape(std::make_shared<Shape>(broadcast_shape));
return out;
}
AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: one tensor.
@ -272,6 +251,11 @@ AbstractBasePtr InferImplMul(const AnalysisEnginePtr &engine_ptr, const Primitiv
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
}
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
}
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);