fix ops.Select infer shape

This commit is contained in:
lilinjie 2022-10-24 00:38:10 +08:00
parent 890e5a124c
commit fdbac22d91
4 changed files with 12 additions and 29 deletions

View File

@ -20,7 +20,7 @@ mindspore.ops.Select
- **y** (Tensor) - 第二个被选择的Tensorshape是 :math:`(x_1, x_2, ..., x_N, ..., x_R)`
输出:
Tensor具有与输入 `x` 相同的shape。
Tensor具有与输入 `condition` 相同的shape。
异常:
- **TypeError** - 如果 `x` 或者 `y` 不是Tensor。

View File

@ -47,24 +47,14 @@ void SelectImpl(const bool *conds, void *x, void *y, void *result, size_t size)
}
}
bool CheckScalarOrTensor(ShapeVector input) {
// check 1D tensor with one element or scalar.
auto size = input.size();
bool flag = true;
if (size != 0 && (size != 1 || (input[0] > 0 && input[0] != 1))) {
flag = false;
}
return flag;
}
void SelectInferShapeCheck(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
const std::vector<int64_t> &cond_shape, size_t shape_size) {
for (size_t i = 0; i < shape_size; i++) {
if ((x_shape[i] > 0 && cond_shape[i] > 0 && x_shape[i] != cond_shape[i]) ||
(x_shape[i] > 0 && y_shape[i] > 0 && x_shape[i] != y_shape[i])) {
MS_EXCEPTION(ValueError)
<< "For 'Select', shape of tensor condition, x and y must be the same. But got condition shape: " << cond_shape
<< ", x shape: " << x_shape << ", y shape: " << y_shape << ".";
<< "For 'Select', the shape of 'condition', 'x' and 'y' must be the same. But got 'condition' shape: "
<< cond_shape << ", 'x' shape: " << x_shape << ", 'y' shape: " << y_shape << ".";
}
}
}
@ -79,20 +69,13 @@ abstract::BaseShapePtr SelectInferShape(const PrimitivePtr &, const std::vector<
auto cond_shape_size = cond_shape.size();
auto x_shape_size = x_shape.size();
auto y_shape_size = y_shape.size();
if (cond_shape_size != 0 && x_shape_size != 0 && y_shape_size != 0) {
if (cond_shape_size != x_shape_size || y_shape_size != x_shape_size) {
MS_EXCEPTION(ValueError)
<< "For 'Select', shape size of tensor condition, x and y must be equal. But got condition size: "
<< cond_shape_size << ", x size: " << x_shape_size << ", y size: " << y_shape_size << ".";
}
SelectInferShapeCheck(x_shape, y_shape, cond_shape, x_shape_size);
} else {
if (!(CheckScalarOrTensor(cond_shape) && CheckScalarOrTensor(x_shape) && CheckScalarOrTensor(y_shape))) {
MS_EXCEPTION(ValueError) << "For 'Select', when any of cond, x, y is of scalar type, "
"the rest must be 1D tensor with one element or scalar!";
}
if (cond_shape_size != x_shape_size || y_shape_size != x_shape_size) {
MS_EXCEPTION(ValueError)
<< "For 'Select', the shape of 'condition', 'x' and 'y' must be the same. But got 'condition' shape: "
<< cond_shape << ", 'x' shape: " << x_shape << ", 'y' shape: " << y_shape << ".";
}
return input_args[kSelectXIndex]->BuildShape();
SelectInferShapeCheck(x_shape, y_shape, cond_shape, x_shape_size);
return input_args[kSelectCondIndex]->BuildShape();
}
TypePtr SelectInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {

View File

@ -3088,7 +3088,7 @@ class Select(Primitive):
- **y** (Tensor) - The second tensor to be selected and the shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
Outputs:
Tensor, has the same shape as `x`.
Tensor, has the same shape as `condition`.
Raises:
TypeError: If `x` or `y` is not a Tensor.

View File

@ -585,7 +585,7 @@ class SparseTensorDenseAdd(Primitive):
of 'x2' don't meet the parameter description.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor
@ -645,7 +645,7 @@ class SparseTensorDenseMatmul(Primitive):
and shape of `dense` don't meet the parameter description.
Supported Platforms:
``Ascend`` ``CPU``
``CPU``
Examples:
>>> indices = Tensor([[0, 1], [1, 2]], dtype=mindspore.int32)