forked from OSSInnovation/mindspore
!3714 stridedslice/stridedslicegrad 4D to 7D
Merge pull request !3714 from panbingao/stridedslice
This commit is contained in:
commit
51fcaf6e61
|
@ -26,7 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr int MAX_DIMS = 4;
|
||||
constexpr int MAX_DIMS = 7;
|
||||
template <typename T>
|
||||
class StridedSliceGpuKernel : public GpuKernel {
|
||||
public:
|
||||
|
@ -65,8 +65,17 @@ class StridedSliceGpuKernel : public GpuKernel {
|
|||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T));
|
||||
output_size_list_.push_back(output_shape_[0] * output_shape_[1] * output_shape_[2] * output_shape_[3] * sizeof(T));
|
||||
size_t size = sizeof(T);
|
||||
for (size_t i = 0; i < MAX_DIMS; i++) {
|
||||
size *= input_shape_[i];
|
||||
}
|
||||
input_size_list_.push_back(size);
|
||||
|
||||
int size1 = sizeof(T);
|
||||
for (size_t i = 0; i < MAX_DIMS; i++) {
|
||||
size1 *= output_shape_[i];
|
||||
}
|
||||
output_size_list_.push_back(size1);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr int MAX_DIMS = 4;
|
||||
constexpr int MAX_DIMS = 7;
|
||||
template <typename T>
|
||||
class StridedSliceGradGpuKernel : public GpuKernel {
|
||||
public:
|
||||
|
@ -66,8 +66,17 @@ class StridedSliceGradGpuKernel : public GpuKernel {
|
|||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(output_shape_[0] * output_shape_[1] * output_shape_[2] * output_shape_[3] * sizeof(T));
|
||||
output_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T));
|
||||
int size = sizeof(T);
|
||||
for (size_t i = 0; i < MAX_DIMS; i++) {
|
||||
size *= output_shape_[i];
|
||||
}
|
||||
input_size_list_.push_back(size);
|
||||
|
||||
int size1 = sizeof(T);
|
||||
for (size_t i = 0; i < MAX_DIMS; i++) {
|
||||
size1 *= input_shape_[i];
|
||||
}
|
||||
output_size_list_.push_back(size1);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -82,18 +82,25 @@ void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<int> i
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void StridedSliceKernel(const int b0, const int b1, const int b2, const int b3, const int s0, const int s1,
|
||||
const int s2, const int s3, const int i0, const int i1, const int i2, const int i3,
|
||||
const int o0, const int o1, const int o2, const int o3, const T *input_addr,
|
||||
T *output_addr) {
|
||||
int output_num = o0 * o1 * o2 * o3;
|
||||
__global__ void StridedSliceKernel(const int b0, const int b1, const int b2, const int b3, const int b4,
|
||||
const int b5, const int b6, const int s0, const int s1, const int s2,
|
||||
const int s3, const int s4, const int s5, const int s6, const int i0,
|
||||
const int i1, const int i2, const int i3, const int i4, const int i5,
|
||||
const int i6, const int o0, const int o1, const int o2, const int o3,
|
||||
const int o4, const int o5, const int o6, const T *input_addr, T *output_addr) {
|
||||
int output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) {
|
||||
int i = pos / (o1 * o2 * o3) % o0;
|
||||
int j = pos / (o2 * o3) % o1;
|
||||
int k = pos / o3 % o2;
|
||||
int l = pos % o3;
|
||||
int i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
|
||||
int j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
|
||||
int k = pos / (o3 * o4 * o5 * o6) % o2;
|
||||
int l = pos / (o4 * o5 * o6) % o3;
|
||||
int m = pos / (o5 * o6) % o4;
|
||||
int n = pos / (o6) % o5;
|
||||
int o = pos % o6;
|
||||
|
||||
int input_idx = (i * s0 + b0) * i1 * i2 * i3 + (j * s1 + b1) * i2 * i3 + (k * s2 + b2) * i3 + (l * s3 + b3);
|
||||
int input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
|
||||
+ (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \
|
||||
+ (n * s5 + b5) * i6 + (o * s6 + b6);
|
||||
output_addr[pos] = input_addr[input_idx];
|
||||
}
|
||||
}
|
||||
|
@ -102,26 +109,36 @@ template <typename T>
|
|||
void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
|
||||
const std::vector<int> &strides, const std::vector<int> &output_shape, const T *input, T *output,
|
||||
cudaStream_t cuda_stream) {
|
||||
int size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3];
|
||||
int size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] \
|
||||
* output_shape[4] * output_shape[5] * output_shape[6];
|
||||
StridedSliceKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
|
||||
begin[0], begin[1], begin[2], begin[3], strides[0], strides[1], strides[2], strides[3], input_shape[0],
|
||||
input_shape[1], input_shape[2], input_shape[3], output_shape[0], output_shape[1], output_shape[2], output_shape[3],
|
||||
input, output);
|
||||
begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6],
|
||||
strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6],
|
||||
input_shape[0], input_shape[1], input_shape[2], input_shape[3], input_shape[4], input_shape[5], input_shape[6],
|
||||
output_shape[0], output_shape[1], output_shape[2], output_shape[3], output_shape[4], output_shape[5],
|
||||
output_shape[6], input, output);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2, const int b3, const int s0,
|
||||
const int s1, const int s2, const int s3, const int i0, const int i1,
|
||||
const int i2, const int i3, const int o0, const int o1, const int o2,
|
||||
const int o3, const T *dy, T *dx) {
|
||||
int output_num = o0 * o1 * o2 * o3;
|
||||
__global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2, const int b3, const int b4,
|
||||
const int b5, const int b6, const int s0, const int s1, const int s2,
|
||||
const int s3, const int s4, const int s5, const int s6, const int i0,
|
||||
const int i1, const int i2, const int i3, const int i4, const int i5,
|
||||
const int i6, const int o0, const int o1, const int o2, const int o3,
|
||||
const int o4, const int o5, const int o6, const T *dy, T *dx) {
|
||||
int output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) {
|
||||
int i = pos / (o1 * o2 * o3) % o0;
|
||||
int j = pos / (o2 * o3) % o1;
|
||||
int k = pos / o3 % o2;
|
||||
int l = pos % o3;
|
||||
int i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
|
||||
int j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
|
||||
int k = pos / (o3 * o4 * o5 * o6) % o2;
|
||||
int l = pos / (o4 * o5 * o6) % o3;
|
||||
int m = pos / (o5 * o6) % o4;
|
||||
int n = pos / (o6) % o5;
|
||||
int o = pos % o6;
|
||||
|
||||
int input_idx = (i * s0 + b0) * i1 * i2 * i3 + (j * s1 + b1) * i2 * i3 + (k * s2 + b2) * i3 + (l * s3 + b3);
|
||||
int input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
|
||||
+ (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \
|
||||
+ (n * s5 + b5) * i6 + (o * s6 + b6);
|
||||
dx[input_idx] = dy[pos];
|
||||
}
|
||||
return;
|
||||
|
@ -130,10 +147,13 @@ __global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2,
|
|||
template <typename T>
|
||||
void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin, const std::vector<int> &strides,
|
||||
const std::vector<int> &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream) {
|
||||
int size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3];
|
||||
int size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3] * dy_shape[4] * dy_shape[5] * dy_shape[6];
|
||||
StridedSliceGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
|
||||
begin[0], begin[1], begin[2], begin[3], strides[0], strides[1], strides[2], strides[3], dx_shape[0], dx_shape[1],
|
||||
dx_shape[2], dx_shape[3], dy_shape[0], dy_shape[1], dy_shape[2], dy_shape[3], dy, dx);
|
||||
begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6],
|
||||
strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6],
|
||||
dx_shape[0], dx_shape[1], dx_shape[2], dx_shape[3], dx_shape[4], dx_shape[5], dx_shape[6],
|
||||
dy_shape[0], dy_shape[1], dy_shape[2], dy_shape[3], dy_shape[4], dy_shape[5], dy_shape[6],
|
||||
dy, dx);
|
||||
}
|
||||
|
||||
template void FillDeviceArray<float>(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -274,3 +274,37 @@ def test_strided_slice_grad():
|
|||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.]]])
|
||||
assert np.allclose(dx[0].asnumpy(), expect)
|
||||
|
||||
x = Tensor(np.arange(0, 1 * 1 * 1 * 2 * 3 * 4 * 5).reshape(1, 1, 1, 2, 3, 4, 5).astype(np.float32))
|
||||
net = StridedSliceNet((0, 0, 0, 1, 1, 2, 2), (1, 1, 1, 2, 3, 3, 4), (1, 1, 1, 1, 1, 1, 1))
|
||||
dx = GradData(net)(x)
|
||||
expect = np.array([[[[[[[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.]],
|
||||
|
||||
[[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.]],
|
||||
|
||||
[[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.]]],
|
||||
|
||||
[[[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.]],
|
||||
|
||||
[[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 1., 1., 0.],
|
||||
[0., 0., 0., 0., 0.]],
|
||||
|
||||
[[0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0.],
|
||||
[0., 0., 1., 1., 0.],
|
||||
[0., 0., 0., 0., 0.]]]]]]])
|
||||
assert np.allclose(dx[0].asnumpy(), expect)
|
||||
|
|
|
@ -93,3 +93,15 @@ def test_stridedslice():
|
|||
y = Tensor(x_np)[:, ::-1]
|
||||
expect = x_np[:, ::-1]
|
||||
assert np.allclose(y.asnumpy(), expect)
|
||||
|
||||
x = Tensor(np.arange(0, 2 * 3 * 4 * 5 * 4 * 3 * 2).reshape(2, 3, 4, 5, 4, 3, 2).astype(np.float32))
|
||||
y = P.StridedSlice()(x, (1, 0, 0, 2, 1, 2, 0), (2, 2, 2, 4, 2, 3, 2), (1, 1, 1, 1, 1, 1, 2))
|
||||
expect = np.array([[[[[[[1498.]]],
|
||||
[[[1522.]]]],
|
||||
[[[[1618.]]],
|
||||
[[[1642.]]]]],
|
||||
[[[[[1978.]]],
|
||||
[[[2002.]]]],
|
||||
[[[[2098.]]],
|
||||
[[[2122.]]]]]]])
|
||||
assert np.allclose(y.asnumpy(), expect)
|
||||
|
|
Loading…
Reference in New Issue