forked from mindspore-Ecosystem/mindspore
!15470 Fix bug of InferImplAdd
From: @jojobugfree Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
ca8ab21233
|
@ -60,27 +60,6 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr
|
||||||
return out->Broaden();
|
return out->Broaden();
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
|
||||||
// Inputs: two tensors.
|
|
||||||
const std::string op_name = primitive->name();
|
|
||||||
CheckArgsSize(op_name, args_spec_list, 2);
|
|
||||||
ShapePtr shape_x = dyn_cast<Shape>(args_spec_list[0]->GetShapeTrack());
|
|
||||||
MS_EXCEPTION_IF_NULL(shape_x);
|
|
||||||
std::vector<int64_t> x_dims = shape_x->shape();
|
|
||||||
ShapePtr shape_y = dyn_cast<Shape>(args_spec_list[1]->GetShapeTrack());
|
|
||||||
MS_EXCEPTION_IF_NULL(shape_y);
|
|
||||||
std::vector<int64_t> y_dims = shape_y->shape();
|
|
||||||
auto broadcast_shape = BroadcastShape(x_dims, y_dims);
|
|
||||||
if (broadcast_shape.empty()) {
|
|
||||||
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
|
|
||||||
<< args_spec_list[1]->ToString();
|
|
||||||
}
|
|
||||||
auto out = args_spec_list[0]->Broaden();
|
|
||||||
out->set_shape(std::make_shared<Shape>(broadcast_shape));
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
// Inputs: one tensor.
|
// Inputs: one tensor.
|
||||||
|
@ -272,6 +251,11 @@ AbstractBasePtr InferImplMul(const AnalysisEnginePtr &engine_ptr, const Primitiv
|
||||||
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
|
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
|
return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
|
||||||
|
|
Loading…
Reference in New Issue