forked from mindspore-Ecosystem/mindspore

GPU Slice kernel performance improvement

parent d9dd6aa0b8  commit 5b7790a2a7
@@ -41,8 +41,9 @@ class SliceGpuFwdKernel : public GpuKernel {
       CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output,
                       reinterpret_cast<cudaStream_t>(stream_ptr));
     } else {
-      CalSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, output,
-               reinterpret_cast<cudaStream_t>(stream_ptr));
+      Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0],
+                    input_shape_[1], input_shape_[2], input_shape_[3], input, output,
+                    reinterpret_cast<cudaStream_t>(stream_ptr));
     }
     return true;
   }
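In the old path, CalSlice launched one small copy kernel per (i, j, k) combination over the three leading dimensions of the slice; the new Slice4DKernel covers the whole output with a single launch. As a rough worked example based on the shape exercised by the new test below (a 32x24x224x224 input sliced to 32x7x224x224), the old code would issue about 32 * 7 * 224 = 50,176 launches of 224-element copies, while the new code issues one launch over roughly 3.5 million output elements.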
@@ -21,11 +21,22 @@
 #include "kernel/gpu/cuda_impl/slice_impl.cuh"
 
 template <typename T>
-__global__ void Slice(const T* input, int p, int start, int length, T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) {
-    output[p + pos] = input[start + pos];
+__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4,
+                        const int l1, const int l2, const int l3, const int l4,
+                        const int d1, const int d2, const int d3, const int d4,
+                        const T *input, T *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) {
+    int i = pos / (l2 * l3 * l4) % l1;
+    int j = pos / (l3 * l4) % l2;
+    int k = pos / l4 % l3;
+    int o = pos % l4;
+
+    int offset = (i + s1) * (d2 * d3 * d4) +
+                 (j + s2) * (d3 * d4) +
+                 (k + s3) * d4 +
+                 (o + s4);
+    output[pos] = input[offset];
   }
-  return;
 }
 template <typename T>
 __global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) {
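The index arithmetic in Slice4D treats pos as a flat index into the l1 x l2 x l3 x l4 output and maps it back to the row-major offset of the matching element in the d1 x d2 x d3 x d4 input, shifted by the begin coordinates s1..s4. A minimal host-side sketch of the same mapping (illustrative only, not part of the commit; the function name is made up):

#include <vector>

// CPU reference for the Slice4D index math.
// s* = slice begin, l* = slice length, d* = full input dimensions, all row-major.
void Slice4DCpuReference(int s1, int s2, int s3, int s4,
                         int l1, int l2, int l3, int l4,
                         int d1, int d2, int d3, int d4,
                         const std::vector<float> &input, std::vector<float> *output) {
  (void)d1;  // as in the kernel, the outermost dimension is not needed for the offset
  for (int pos = 0; pos < l1 * l2 * l3 * l4; ++pos) {
    int i = pos / (l2 * l3 * l4) % l1;  // coordinate along the first output dimension
    int j = pos / (l3 * l4) % l2;       // second
    int k = pos / l4 % l3;              // third
    int o = pos % l4;                   // innermost
    // Row-major offset of the corresponding input element, shifted by the slice begin.
    int offset = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4);
    (*output)[pos] = input[offset];
  }
}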
@@ -64,22 +75,12 @@ void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaSt
   return;
 }
 template <typename T>
-void CalSlice(const size_t input_size, const T* input, const std::vector<int> in_shape, const std::vector<int> begin,
-              const std::vector<int> size, T* output, cudaStream_t cuda_stream) {
-  int block = in_shape[1] * in_shape[2] * in_shape[3];
-  int map = in_shape[2] * in_shape[3];
-  int w = in_shape[3];
-  int length = size[3];
-  int p = 0;
-  for (int i = begin[0]; i < size[0] + begin[0]; i++) {
-    for (int j = begin[1]; j < size[1] + begin[1]; j++) {
-      for (int k = begin[2]; k < size[2] + begin[2]; k++) {
-        Slice<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input, p, i * block + j * map + k * w + begin[3],
-                                                                       length, output);
-        p = p + size[3];
-      }
-    }
-  }
+void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
+                   const int l1, const int l2, const int l3, const int l4,
+                   const int d1, const int d2, const int d3, const int d4,
+                   const T *input, T *output, cudaStream_t stream) {
+  Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4,
+                                                                     d1, d2, d3, d4, input, output);
 }
 template <typename T>
 void CalSliceGrad(const size_t input_size, const T* dy, const std::vector<int> in_shape, const std::vector<int> begin,
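GET_BLOCKS and GET_THREADS come from device/gpu/cuda_common.h (included by the header below) and size the grid over all l1 * l2 * l3 * l4 output elements, so the whole slice is handled by one launch. A sketch of a call with the shape used by the new test, assuming device buffers and a stream already exist (d_input, d_output, LaunchExample are placeholders, not part of the commit):

#include "kernel/gpu/cuda_impl/slice_impl.cuh"

// Slice a 32x24x224x224 float tensor down to 32x7x224x224 starting at (0, 11, 0, 0).
void LaunchExample(const float *d_input, float *d_output, cudaStream_t stream) {
  Slice4DKernel(0, 11, 0, 0,        // s1..s4: slice begin
                32, 7, 224, 224,    // l1..l4: slice size
                32, 24, 224, 224,   // d1..d4: full input shape
                d_input, d_output, stream);
}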
@@ -147,9 +148,10 @@ void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector
 }
 
 template void FillDeviceArray<float>(const size_t input_size, float* addr, const float value, cudaStream_t cuda_stream);
-template void CalSlice<float>(const size_t input_size, const float* input, const std::vector<int> in_shape,
-                              const std::vector<int> begin, const std::vector<int> size, float* output,
-                              cudaStream_t cuda_stream);
+template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
+                            const int l1, const int l2, const int l3, const int l4,
+                            const int d1, const int d2, const int d3, const int d4,
+                            const float *input, float *output, cudaStream_t stream);
 template void CalSliceGrad<float>(const size_t input_size, const float* dy, const std::vector<int> in_shape,
                                   const std::vector<int> begin, const std::vector<int> size, float* output,
                                   cudaStream_t cuda_stream);
@@ -160,9 +162,10 @@ template void CalStridedSliceGrad<float>(const size_t input_size, const float* d
                                          const std::vector<int> begin, const std::vector<int> end,
                                          const std::vector<int> strides, float* dx, cudaStream_t cuda_stream);
 template void FillDeviceArray<half>(const size_t input_size, half* addr, const float value, cudaStream_t cuda_stream);
-template void CalSlice<half>(const size_t input_size, const half* input, const std::vector<int> in_shape,
-                             const std::vector<int> begin, const std::vector<int> size, half* output,
-                             cudaStream_t cuda_stream);
+template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
+                            const int l1, const int l2, const int l3, const int l4,
+                            const int d1, const int d2, const int d3, const int d4,
+                            const half *input, half *output, cudaStream_t stream);
 template void CalSliceGrad<half>(const size_t input_size, const half* dy, const std::vector<int> in_shape,
                                  const std::vector<int> begin, const std::vector<int> size, half* output,
                                  cudaStream_t cuda_stream);
@@ -173,9 +176,10 @@ template void CalStridedSliceGrad<half>(const size_t input_size, const half* dy,
                                         const std::vector<int> begin, const std::vector<int> end,
                                         const std::vector<int> strides, half* dx, cudaStream_t cuda_stream);
 template void FillDeviceArray<int>(const size_t input_size, int* addr, const float value, cudaStream_t cuda_stream);
-template void CalSlice<int>(const size_t input_size, const int* input, const std::vector<int> in_shape,
-                            const std::vector<int> begin, const std::vector<int> size, int* output,
-                            cudaStream_t cuda_stream);
+template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
+                            const int l1, const int l2, const int l3, const int l4,
+                            const int d1, const int d2, const int d3, const int d4,
+                            const int *input, int *output, cudaStream_t stream);
 template void CalSliceGrad<int>(const size_t input_size, const int* dy, const std::vector<int> in_shape,
                                 const std::vector<int> begin, const std::vector<int> size, int* output,
                                 cudaStream_t cuda_stream);
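Note that the old instantiations spell out the element type (CalSlice<float>, CalSlice<half>, CalSlice<int>) while the new ones omit it: T is deduced from the pointer parameters. The explicit instantiations are still required because the template bodies live in the nvcc-compiled .cu translation unit and the calling GPU kernel code only sees the declarations in slice_impl.cuh. Supporting another element type would mean one more line of the same form; a hypothetical example, not part of this commit:

template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
                            const int l1, const int l2, const int l3, const int l4,
                            const int d1, const int d2, const int d3, const int d4,
                            const double *input, double *output, cudaStream_t stream);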
@@ -21,9 +21,12 @@
 #include <vector>
 #include "device/gpu/cuda_common.h"
 
+
 template <typename T>
-void CalSlice(const size_t input_size, const T* input, const std::vector<int> in_shape, const std::vector<int> begin,
-              const std::vector<int> size, T* output, cudaStream_t cuda_stream);
+void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
+                   const int l1, const int l2, const int l3, const int l4,
+                   const int d1, const int d2, const int d3, const int d4,
+                   const T *input, T *output, cudaStream_t stream);
 template <typename T>
 void CalSliceGrad(const size_t input_size, const T* input, const std::vector<int> in_shape,
                   const std::vector<int> begin, const std::vector<int> size, T* output, cudaStream_t cuda_stream);
@@ -43,3 +43,22 @@ def test_slice():
     slice = Slice()
     output = slice(x)
     assert (output.asnumpy() == expect).all()
+
+
+class SliceNet(nn.Cell):
+    def __init__(self):
+        super(SliceNet, self).__init__()
+        self.slice = P.Slice()
+
+    def construct(self, x):
+        return self.slice(x, (0, 11, 0, 0), (32, 7, 224, 224))
+
+
+def test_slice_4d():
+    x_np = np.random.randn(32, 24, 224, 224).astype(np.float32)
+    output_np = x_np[:, 11:18, :, :]
+
+    x_ms = Tensor(x_np)
+    net = SliceNet()
+    output_ms = net(x_ms)
+
+    assert (output_ms.asnumpy() == output_np).all()
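The new test slices only the channel dimension: begin (0, 11, 0, 0) with size (32, 7, 224, 224) on a 32x24x224x224 input selects the same elements as the NumPy expression x_np[:, 11:18, :, :], so an exact element-wise comparison is valid, since the kernel only copies values and performs no arithmetic on them.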