forked from mindspore-Ecosystem/mindspore
Fix bug of Add infershape
This commit is contained in:
parent
d48151ab1e
commit
9923d4794d
|
@ -60,27 +60,6 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
return out->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
ShapePtr shape_x = dyn_cast<Shape>(args_spec_list[0]->GetShapeTrack());
|
||||
MS_EXCEPTION_IF_NULL(shape_x);
|
||||
std::vector<int64_t> x_dims = shape_x->shape();
|
||||
ShapePtr shape_y = dyn_cast<Shape>(args_spec_list[1]->GetShapeTrack());
|
||||
MS_EXCEPTION_IF_NULL(shape_y);
|
||||
std::vector<int64_t> y_dims = shape_y->shape();
|
||||
auto broadcast_shape = BroadcastShape(x_dims, y_dims);
|
||||
if (broadcast_shape.empty()) {
|
||||
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
|
||||
<< args_spec_list[1]->ToString();
|
||||
}
|
||||
auto out = args_spec_list[0]->Broaden();
|
||||
out->set_shape(std::make_shared<Shape>(broadcast_shape));
|
||||
return out;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: one tensor.
|
||||
|
@ -272,6 +251,11 @@ AbstractBasePtr InferImplMul(const AnalysisEnginePtr &engine_ptr, const Primitiv
|
|||
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
|
||||
|
|
Loading…
Reference in New Issue