fix error handling of StandardLaplace

This commit is contained in:
panshaowu 2022-11-02 11:35:04 +08:00
parent 30d7af21ae
commit e433a9ef71
5 changed files with 10 additions and 2 deletions

View File

@ -24,3 +24,4 @@ mindspore.ops.StandardLaplace
- **TypeError** - `shape` 既不是tuple也不是Tensor。
- **ValueError** - `seed``seed2` 不是非负的int。
- **ValueError** - `shape` 为tuple时包含非正的元素。
- **ValueError** - `shape` 为秩不等于1的Tensor。

View File

@ -22,3 +22,4 @@ mindspore.ops.standard_laplace
- **TypeError** - `shape` 既不是tuple也不是Tensor。
- **ValueError** - `seed``seed2` 不是非负的int。
- **ValueError** - `shape` 为tuple时包含非正的元素。
- **ValueError** - `shape` 为秩不等于1的Tensor。

View File

@ -61,6 +61,10 @@ abstract::ShapePtr StandardLaplaceInferShape(const PrimitivePtr &primitive,
return std::make_shared<abstract::Shape>(out_shape);
} else if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>()) {
if (!shape_value->isa<AnyValue>() && !shape_value->isa<None>()) {
if (x_shape.size() != 1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', rank of the input Tensor shall be 1, but got: " << x_shape.size() << ".";
}
ShapeVector input_shape = CheckAndConvertUtils::CheckTensorIntValue("input[shape]", shape_value, prim_name);
return std::make_shared<abstract::Shape>(input_shape);
} else {
@ -72,7 +76,7 @@ abstract::ShapePtr StandardLaplaceInferShape(const PrimitivePtr &primitive,
}
} else {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', input must be a Int, a tuple, or a Tensor with all Int elements, but got: "
<< "', input must be a tuple, or a Tensor with all Int elements, but got: "
<< input_args[kInputIndex0]->ToString() << ".";
}
}
@ -94,7 +98,7 @@ TypePtr StandardLaplaceInferType(const PrimitivePtr &primitive, const std::vecto
(void)CheckAndConvertUtils::CheckTensorTypeValid("shape", input_dtype, valid_shape_types, prim_name);
} else {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', input must be a Int, a tuple, or a Tensor with all Int elements, but got: "
<< "', input must be a tuple, or a Tensor with all Int elements, but got: "
<< input_args[kInputIndex0]->ToString() << ".";
}
return std::make_shared<TensorType>(kFloat32);

View File

@ -109,6 +109,7 @@ def standard_laplace(shape, seed=0, seed2=0):
TypeError: If shape is neither a tuple nor a Tensor.
ValueError: If seed or seed2 is not a non-negative int.
ValueError: If shape is a tuple containing non-positive items.
ValueError: If shape is a Tensor, and the rank of the Tensor is not equal to 1.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

View File

@ -184,6 +184,7 @@ class StandardLaplace(Primitive):
TypeError: If shape is neither a tuple nor a Tensor.
ValueError: If seed or seed2 is not a non-negative int.
ValueError: If shape is a tuple containing non-positive items.
ValueError: If shape is a Tensor, and the rank of the Tensor is not equal to 1.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``