!29205 fix error msg for zeros and ones & gathernd bug

Merge pull request !29205 from Simson/push-to-r1.6
This commit is contained in:
i-robot 2022-01-18 06:16:11 +00:00 committed by Gitee
commit 88b1a4704d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 19 additions and 4 deletions

View File

@ -75,10 +75,11 @@ TypePtr GatherNdInferType(const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &arg) { return arg == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
(void)types.emplace("x1", input_args[kInputIndex0]->BuildType());
(void)types.emplace("x2", input_args[kInputIndex0]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, all_types, "GatherNd");
std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64};
auto x_type = input_args[kInputIndex0]->BuildType();
auto indices_type = input_args[kInputIndex1]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, int_types, "GatherNd");
return x_type;
}
} // namespace
AbstractBasePtr GatherNdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -30,6 +30,13 @@ abstract::ShapePtr OnesInferShape(const PrimitivePtr &primitive, const std::vect
auto prim_name = primitive->name();
// check
auto shape_value = input_args[0]->BuildValue();
MS_EXCEPTION_IF_NULL(shape_value);
if (shape_value->isa<ValueList>()) {
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
<< "], the input"
" must be a Int or a tuple with all Int elements, but got "
<< shape_value->ToString();
}
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
return std::make_shared<abstract::Shape>(out_shape);

View File

@ -31,6 +31,13 @@ abstract::ShapePtr ZerosInferShape(const PrimitivePtr &primitive, const std::vec
auto prim_name = primitive->name();
// check
auto shape_value = input_args[0]->BuildValue();
MS_EXCEPTION_IF_NULL(shape_value);
if (shape_value->isa<ValueList>()) {
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
<< "], the input"
" must be a Int or a tuple with all Int elements, but got "
<< shape_value->ToString();
}
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
return std::make_shared<abstract::Shape>(out_shape);