!41971 r1.9 Fix exception raise of addn

Merge pull request !41971 from chenfei_mindspore/r1.9-addn-log-fix
This commit is contained in:
i-robot 2022-09-17 01:11:42 +00:00 committed by Gitee
commit 1e0c228bae
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 9 additions and 8 deletions

View File

@ -76,23 +76,24 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
for (size_t i = 0; i < elements.size(); ++i) {
auto shape = elements[i]->BuildShape();
if (!shape->isa<abstract::Shape>()) {
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', input[" << i
<< "] should be a Tensor, but got:" << elements[i]->ToString();
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', input[" << i
<< "] should be a Tensor, but got:" << elements[i]->ToString();
}
const auto &shape_vec = shape->cast<abstract::ShapePtr>()->shape();
// If any shape is dynamic rank, return a dynamic rank.
if (IsDynamicRank(shape_vec)) {
return std::make_shared<abstract::Shape>(ShapeVector({UNKNOWN_RANK}));
}
// Record input0's shape.
if (i == 0) {
output_shape = shape_vec;
continue;
}
// If any shape is dynamic rank, return a dynamic rank.
if (IsDynamicRank(shape_vec)) {
return std::make_shared<abstract::Shape>(ShapeVector({UNKNOWN_RANK}));
}
// Join input[i] with input[0]
if (!AddNDynShapeJoin(&output_shape, &shape_vec)) {
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', input[" << i << "]:" << shape->ToString()
<< " is not compatible with input0:" << shape_0->ToString();
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', input shape must be same, but got shape of input["
<< i << "]: " << shape->ToString() << ", shape of input[0]: " << shape_0->ToString()
<< ".";
}
}
return std::make_shared<abstract::Shape>(output_shape);