!47341 fix processing illegal inputs bug of StandardLaplace Op

Merge pull request !47341 from panshaowu/master
This commit is contained in:
i-robot 2022-12-30 03:49:50 +00:00 committed by Gitee
commit 9ee6b4e1da
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 25 additions and 7 deletions

View File

@ -47,16 +47,34 @@ namespace {
abstract::ShapePtr StandardLaplaceInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto input = input_args[kInputIndex0];
MS_EXCEPTION_IF_NULL(input);
if (!input->isa<abstract::AbstractTensor>() && !input->isa<abstract::AbstractTuple>()) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
auto prim_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
auto shape_value = input_args[kInputIndex0]->BuildValue();
MS_EXCEPTION_IF_NULL(shape_value);
if (input_args[kInputIndex0]->isa<abstract::AbstractTuple>()) {
if (IsValueKnown(shape_value)) {
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
return std::make_shared<abstract::Shape>(out_shape);
} else {
constexpr int dynamic_rank_value = -2;
ShapeVector shape = {dynamic_rank_value};
return std::make_shared<abstract::Shape>(shape);
}
} else if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>()) {
if (IsValueKnown(shape_value)) {
ShapeVector input_shape = CheckAndConvertUtils::CheckTensorIntValue("input[shape]", shape_value, prim_name);
return std::make_shared<abstract::Shape>(input_shape);
} else {
constexpr int dynamic_rank_value = -2;
ShapeVector shape = {dynamic_rank_value};
return std::make_shared<abstract::Shape>(shape);
}
} else {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', input must be a tuple, or a Tensor with all Int elements, but got: "
<< input_args[kInputIndex0]->ToString() << ".";
}
auto out_shape = GetShapeValue(primitive, input);
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr StandardLaplaceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {