fix error handling of StandardLaplace
This commit is contained in:
parent
30d7af21ae
commit
e433a9ef71
|
@ -24,3 +24,4 @@ mindspore.ops.StandardLaplace
|
|||
- **TypeError** - `shape` 既不是tuple,也不是Tensor。
|
||||
- **ValueError** - `seed` 或 `seed2` 不是非负的int。
|
||||
- **ValueError** - `shape` 为tuple时,包含非正的元素。
|
||||
- **ValueError** - `shape` 为秩不等于1的Tensor。
|
||||
|
|
|
@ -22,3 +22,4 @@ mindspore.ops.standard_laplace
|
|||
- **TypeError** - `shape` 既不是tuple,也不是Tensor。
|
||||
- **ValueError** - `seed` 或 `seed2` 不是非负的int。
|
||||
- **ValueError** - `shape` 为tuple时,包含非正的元素。
|
||||
- **ValueError** - `shape` 为秩不等于1的Tensor。
|
||||
|
|
|
@ -61,6 +61,10 @@ abstract::ShapePtr StandardLaplaceInferShape(const PrimitivePtr &primitive,
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
} else if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>()) {
|
||||
if (!shape_value->isa<AnyValue>() && !shape_value->isa<None>()) {
|
||||
if (x_shape.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||
<< "', rank of the input Tensor shall be 1, but got: " << x_shape.size() << ".";
|
||||
}
|
||||
ShapeVector input_shape = CheckAndConvertUtils::CheckTensorIntValue("input[shape]", shape_value, prim_name);
|
||||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
} else {
|
||||
|
@ -72,7 +76,7 @@ abstract::ShapePtr StandardLaplaceInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For '" << prim_name
|
||||
<< "', input must be a Int, a tuple, or a Tensor with all Int elements, but got: "
|
||||
<< "', input must be a tuple, or a Tensor with all Int elements, but got: "
|
||||
<< input_args[kInputIndex0]->ToString() << ".";
|
||||
}
|
||||
}
|
||||
|
@ -94,7 +98,7 @@ TypePtr StandardLaplaceInferType(const PrimitivePtr &primitive, const std::vecto
|
|||
(void)CheckAndConvertUtils::CheckTensorTypeValid("shape", input_dtype, valid_shape_types, prim_name);
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For '" << prim_name
|
||||
<< "', input must be a Int, a tuple, or a Tensor with all Int elements, but got: "
|
||||
<< "', input must be a tuple, or a Tensor with all Int elements, but got: "
|
||||
<< input_args[kInputIndex0]->ToString() << ".";
|
||||
}
|
||||
return std::make_shared<TensorType>(kFloat32);
|
||||
|
|
|
@ -109,6 +109,7 @@ def standard_laplace(shape, seed=0, seed2=0):
|
|||
TypeError: If shape is neither a tuple nor a Tensor.
|
||||
ValueError: If seed or seed2 is not a non-negative int.
|
||||
ValueError: If shape is a tuple containing non-positive items.
|
||||
ValueError: If shape is a Tensor, and the rank of the Tensor is not equal to 1.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
|
@ -184,6 +184,7 @@ class StandardLaplace(Primitive):
|
|||
TypeError: If shape is neither a tuple nor a Tensor.
|
||||
ValueError: If seed or seed2 is not a non-negative int.
|
||||
ValueError: If shape is a tuple containing non-positive items.
|
||||
ValueError: If shape is a Tensor, and the rank of the Tensor is not equal to 1.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
Loading…
Reference in New Issue