forked from mindspore-Ecosystem/mindspore
fix ops.Select infer shape
This commit is contained in:
parent
890e5a124c
commit
fdbac22d91
|
@ -20,7 +20,7 @@ mindspore.ops.Select
|
|||
- **y** (Tensor) - 第二个被选择的Tensor,shape是 :math:`(x_1, x_2, ..., x_N, ..., x_R)`。
|
||||
|
||||
输出:
|
||||
Tensor,具有与输入 `x` 相同的shape。
|
||||
Tensor,具有与输入 `condition` 相同的shape。
|
||||
|
||||
异常:
|
||||
- **TypeError** - 如果 `x` 或者 `y` 不是Tensor。
|
||||
|
|
|
@ -47,24 +47,14 @@ void SelectImpl(const bool *conds, void *x, void *y, void *result, size_t size)
|
|||
}
|
||||
}
|
||||
|
||||
bool CheckScalarOrTensor(ShapeVector input) {
|
||||
// check 1D tensor with one element or scalar.
|
||||
auto size = input.size();
|
||||
bool flag = true;
|
||||
if (size != 0 && (size != 1 || (input[0] > 0 && input[0] != 1))) {
|
||||
flag = false;
|
||||
}
|
||||
return flag;
|
||||
}
|
||||
|
||||
void SelectInferShapeCheck(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &y_shape,
|
||||
const std::vector<int64_t> &cond_shape, size_t shape_size) {
|
||||
for (size_t i = 0; i < shape_size; i++) {
|
||||
if ((x_shape[i] > 0 && cond_shape[i] > 0 && x_shape[i] != cond_shape[i]) ||
|
||||
(x_shape[i] > 0 && y_shape[i] > 0 && x_shape[i] != y_shape[i])) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For 'Select', shape of tensor condition, x and y must be the same. But got condition shape: " << cond_shape
|
||||
<< ", x shape: " << x_shape << ", y shape: " << y_shape << ".";
|
||||
<< "For 'Select', the shape of 'condition', 'x' and 'y' must be the same. But got 'condition' shape: "
|
||||
<< cond_shape << ", 'x' shape: " << x_shape << ", 'y' shape: " << y_shape << ".";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -79,20 +69,13 @@ abstract::BaseShapePtr SelectInferShape(const PrimitivePtr &, const std::vector<
|
|||
auto cond_shape_size = cond_shape.size();
|
||||
auto x_shape_size = x_shape.size();
|
||||
auto y_shape_size = y_shape.size();
|
||||
if (cond_shape_size != 0 && x_shape_size != 0 && y_shape_size != 0) {
|
||||
if (cond_shape_size != x_shape_size || y_shape_size != x_shape_size) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For 'Select', shape size of tensor condition, x and y must be equal. But got condition size: "
|
||||
<< cond_shape_size << ", x size: " << x_shape_size << ", y size: " << y_shape_size << ".";
|
||||
}
|
||||
SelectInferShapeCheck(x_shape, y_shape, cond_shape, x_shape_size);
|
||||
} else {
|
||||
if (!(CheckScalarOrTensor(cond_shape) && CheckScalarOrTensor(x_shape) && CheckScalarOrTensor(y_shape))) {
|
||||
MS_EXCEPTION(ValueError) << "For 'Select', when any of cond, x, y is of scalar type, "
|
||||
"the rest must be 1D tensor with one element or scalar!";
|
||||
}
|
||||
if (cond_shape_size != x_shape_size || y_shape_size != x_shape_size) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For 'Select', the shape of 'condition', 'x' and 'y' must be the same. But got 'condition' shape: "
|
||||
<< cond_shape << ", 'x' shape: " << x_shape << ", 'y' shape: " << y_shape << ".";
|
||||
}
|
||||
return input_args[kSelectXIndex]->BuildShape();
|
||||
SelectInferShapeCheck(x_shape, y_shape, cond_shape, x_shape_size);
|
||||
return input_args[kSelectCondIndex]->BuildShape();
|
||||
}
|
||||
|
||||
TypePtr SelectInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
|
@ -3088,7 +3088,7 @@ class Select(Primitive):
|
|||
- **y** (Tensor) - The second tensor to be selected and the shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape as `x`.
|
||||
Tensor, has the same shape as `condition`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` or `y` is not a Tensor.
|
||||
|
|
|
@ -585,7 +585,7 @@ class SparseTensorDenseAdd(Primitive):
|
|||
of 'x2' don't meet the parameter description.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor
|
||||
|
@ -645,7 +645,7 @@ class SparseTensorDenseMatmul(Primitive):
|
|||
and shape of `dense` don't meet the parameter description.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> indices = Tensor([[0, 1], [1, 2]], dtype=mindspore.int32)
|
||||
|
|
Loading…
Reference in New Issue