!29205 fix error msg for zeros and ones & gathernd bug
Merge pull request !29205 from Simson/push-to-r1.6
This commit is contained in:
commit
88b1a4704d
|
@ -75,10 +75,11 @@ TypePtr GatherNdInferType(const std::vector<AbstractBasePtr> &input_args) {
|
|||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &arg) { return arg == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x1", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("x2", input_args[kInputIndex0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, all_types, "GatherNd");
|
||||
std::set<TypePtr> int_types = {kInt8, kInt16, kInt32, kInt64};
|
||||
auto x_type = input_args[kInputIndex0]->BuildType();
|
||||
auto indices_type = input_args[kInputIndex1]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, int_types, "GatherNd");
|
||||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr GatherNdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -30,6 +30,13 @@ abstract::ShapePtr OnesInferShape(const PrimitivePtr &primitive, const std::vect
|
|||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto shape_value = input_args[0]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
if (shape_value->isa<ValueList>()) {
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
|
||||
<< "], the input"
|
||||
" must be a Int or a tuple with all Int elements, but got "
|
||||
<< shape_value->ToString();
|
||||
}
|
||||
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
|
|
|
@ -31,6 +31,13 @@ abstract::ShapePtr ZerosInferShape(const PrimitivePtr &primitive, const std::vec
|
|||
auto prim_name = primitive->name();
|
||||
// check
|
||||
auto shape_value = input_args[0]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
if (shape_value->isa<ValueList>()) {
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
|
||||
<< "], the input"
|
||||
" must be a Int or a tuple with all Int elements, but got "
|
||||
<< shape_value->ToString();
|
||||
}
|
||||
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
|
|
Loading…
Reference in New Issue