forked from mindspore-Ecosystem/mindspore
!17193 add check of undetermined type in addn infer
From: @simson_wu Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
37f6be5de8
|
@ -40,8 +40,8 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
|
|||
for (size_t i = 0; i < elements.size(); ++i) {
|
||||
auto shape = elements[i]->BuildShape();
|
||||
if (*shape != *shape_0) {
|
||||
MS_LOG(EXCEPTION) << primitive->name() << "Shape of input[" << i << "]: " << shape->ToString()
|
||||
<< " are not consistent with the shape of input[0]" << shape_0->ToString();
|
||||
MS_EXCEPTION(ValueError) << primitive->name() << "Shape of input[" << i << "]: " << shape->ToString()
|
||||
<< " are not consistent with the shape of input[0]" << shape_0->ToString();
|
||||
}
|
||||
}
|
||||
auto in_shape = element0_shape_map[kShape];
|
||||
|
@ -62,7 +62,7 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
|
|||
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim->name());
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("element_0", elements[0]->BuildType());
|
||||
for (size_t i = 1; i < elements.size(); ++i) {
|
||||
for (size_t i = 0; i < elements.size(); ++i) {
|
||||
if (elements[i]->BuildType()->type_id() == kObjectTypeUndeterminedType) {
|
||||
return elements[0]->BuildType();
|
||||
}
|
||||
|
@ -82,8 +82,11 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(AddNInferType(primitive, input_args),
|
||||
AddNInferShape(primitive, input_args));
|
||||
auto abs_type = AddNInferType(primitive, input_args);
|
||||
if (abs_type->type_id() == kObjectTypeUndeterminedType) {
|
||||
return std::make_shared<abstract::AbstractUndetermined>();
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(abs_type, AddNInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
Loading…
Reference in New Issue