forked from mindspore-Ecosystem/mindspore
!47341 fix processing illegal inputs bug of StandardLaplace Op
Merge pull request !47341 from panshaowu/master
This commit is contained in:
commit
9ee6b4e1da
|
@ -47,16 +47,34 @@ namespace {
|
|||
abstract::ShapePtr StandardLaplaceInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto input = input_args[kInputIndex0];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if (!input->isa<abstract::AbstractTensor>() && !input->isa<abstract::AbstractTuple>()) {
|
||||
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
|
||||
auto prim_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
|
||||
auto shape_value = input_args[kInputIndex0]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
if (input_args[kInputIndex0]->isa<abstract::AbstractTuple>()) {
|
||||
if (IsValueKnown(shape_value)) {
|
||||
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", shape_value, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
} else {
|
||||
constexpr int dynamic_rank_value = -2;
|
||||
ShapeVector shape = {dynamic_rank_value};
|
||||
return std::make_shared<abstract::Shape>(shape);
|
||||
}
|
||||
} else if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>()) {
|
||||
if (IsValueKnown(shape_value)) {
|
||||
ShapeVector input_shape = CheckAndConvertUtils::CheckTensorIntValue("input[shape]", shape_value, prim_name);
|
||||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
} else {
|
||||
constexpr int dynamic_rank_value = -2;
|
||||
ShapeVector shape = {dynamic_rank_value};
|
||||
return std::make_shared<abstract::Shape>(shape);
|
||||
}
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For '" << prim_name
|
||||
<< "', input must be a tuple, or a Tensor with all Int elements, but got: "
|
||||
<< input_args[kInputIndex0]->ToString() << ".";
|
||||
}
|
||||
|
||||
auto out_shape = GetShapeValue(primitive, input);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr StandardLaplaceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
Loading…
Reference in New Issue