fix reduce ops axis multiple bug in GPU

This commit is contained in:
buxue 2021-11-22 18:03:42 +08:00
parent 6f559516ea
commit 89a688f3be
4 changed files with 13 additions and 10 deletions

View File

@ -116,6 +116,8 @@ class ArrayReduceGpuKernel : public GpuKernel {
axis < 0 ? axis_.push_back(axis + input_dim_length) : axis_.push_back(axis);
}
std::sort(axis_.begin(), axis_.end());
auto multiple_pos = std::unique(axis_.begin(), axis_.end());
axis_.erase(multiple_pos, axis_.end());
}
} else if (prim->GetAttr("axis")->isa<Int64Imm>()) {
int axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));

View File

@ -508,7 +508,7 @@ class Tensor(Tensor_):
Check all tensor elements along a given axis evaluate to True.
Args:
axis (Union[None, int, tuple(int)): Dimensions of reduction,
axis (Union[None, int, tuple(int)]): Dimensions of reduction,
when the axis is None or empty tuple, reduce all dimensions. Default: ().
keep_dims (bool): Whether to keep the reduced dimensions. Default: False.
@ -540,7 +540,7 @@ class Tensor(Tensor_):
Check any tensor element along a given axis evaluate to True.
Args:
axis (Union[None, int, tuple(int)): Dimensions of reduction,
axis (Union[None, int, tuple(int)]): Dimensions of reduction,
when the axis is None or empty tuple, reduce all dimensions. Default: ().
keep_dims (bool): Whether to keep the reduced dimensions. Default: False.

View File

@ -480,10 +480,11 @@ class _Reduce(PrimitiveWithInfer):
if np_reduce_func is not None:
value = input_x['value'].asnumpy()
if isinstance(axis_v, int):
axis_v = (axis_v,)
elif not axis_v:
axis_v = [i for i in range(len(input_x['shape']))]
axis_v = tuple(axis_v)
pass
elif axis_v:
axis_v = tuple(set(axis_v))
else:
axis_v = tuple(range(len(input_x['shape'])))
value = np_reduce_func(value, axis_v, keepdims=self.keep_dims)
value = np.array(value)
value = Tensor(value)

View File

@ -58,7 +58,7 @@ axis7 = (1, 2)
keep_dims7 = True
x8 = np.random.rand(2, 1, 1, 4).astype(np.float32)
axis8 = (1, 2)
axis8 = (1, 2, 2, 1, 1, 2)
keep_dims8 = True
x9 = np.random.rand(2, 1, 1, 4).astype(np.float32)
@ -70,7 +70,7 @@ axis10 = (0, 1, 2, 3)
keep_dims10 = False
x11 = np.random.rand(1, 1, 1, 1).astype(np.float32)
axis11 = (0, 1, 2, 3)
axis11 = (0, 1, 1, 2, 3, 2, 2, 0, 3)
keep_dims11 = False
x12 = np.random.rand(2, 3, 4, 4).astype(np.float32)
@ -226,7 +226,7 @@ def test_ReduceSum():
assert np.all(diff7 < error7)
assert output[7].shape == expect7.shape
expect8 = np.sum(x8, axis=axis8, keepdims=keep_dims8)
expect8 = np.sum(x8, axis=tuple(set(axis8)), keepdims=keep_dims8)
diff8 = abs(output[8].asnumpy() - expect8)
error8 = np.ones(shape=expect8.shape) * 1.0e-5
assert np.all(diff8 < error8)
@ -244,7 +244,7 @@ def test_ReduceSum():
assert np.all(diff10 < error10)
assert output[10].shape == expect10.shape
expect11 = np.sum(x11, axis=axis11, keepdims=keep_dims11)
expect11 = np.sum(x11, axis=tuple(set(axis11)), keepdims=keep_dims11)
diff11 = abs(output[11].asnumpy() - expect11)
error11 = np.ones(shape=expect11.shape) * 1.0e-5
assert np.all(diff11 < error11)