!17193 add check of undetermined type in addn infer

From: @simson_wu
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-05-29 11:03:45 +08:00 committed by Gitee
commit 37f6be5de8
1 changed files with 8 additions and 5 deletions

View File

@ -40,8 +40,8 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
for (size_t i = 0; i < elements.size(); ++i) {
auto shape = elements[i]->BuildShape();
if (*shape != *shape_0) {
MS_LOG(EXCEPTION) << primitive->name() << "Shape of input[" << i << "]: " << shape->ToString()
<< " are not consistent with the shape of input[0]" << shape_0->ToString();
MS_EXCEPTION(ValueError) << primitive->name() << "Shape of input[" << i << "]: " << shape->ToString()
<< " are not consistent with the shape of input[0]" << shape_0->ToString();
}
}
auto in_shape = element0_shape_map[kShape];
@ -62,7 +62,7 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim->name());
std::map<std::string, TypePtr> types;
types.emplace("element_0", elements[0]->BuildType());
for (size_t i = 1; i < elements.size(); ++i) {
for (size_t i = 0; i < elements.size(); ++i) {
if (elements[i]->BuildType()->type_id() == kObjectTypeUndeterminedType) {
return elements[0]->BuildType();
}
@ -82,8 +82,11 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return std::make_shared<abstract::AbstractTensor>(AddNInferType(primitive, input_args),
AddNInferShape(primitive, input_args));
auto abs_type = AddNInferType(primitive, input_args);
if (abs_type->type_id() == kObjectTypeUndeterminedType) {
return std::make_shared<abstract::AbstractUndetermined>();
}
return std::make_shared<abstract::AbstractTensor>(abs_type, AddNInferShape(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer, nullptr, true);
} // namespace ops