!1489 GPU fix slice

Merge pull request !1489 from VectorSL/gpu-fix-slice
This commit is contained in:
mindspore-ci-bot 2020-05-27 14:20:30 +08:00 committed by Gitee
commit 02f33a17b5
2 changed files with 7 additions and 4 deletions

View File

@ -79,6 +79,9 @@ class SliceGpuFwdKernel : public GpuKernel {
if (size_[i] < 0) { if (size_[i] < 0) {
size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
} }
if (size_[i] == 0) {
size_[i] = begin_[i] + 1;
}
} }
input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T); input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);

View File

@ -47,7 +47,7 @@ __global__ void SliceGrad(const T* dy, int p, int start, int length, T* output)
} }
template <typename T> template <typename T>
__global__ void StridedSlice(const T* input, int p, int start, int begin, int stride, int ended, T* output) { __global__ void StridedSlice(const T* input, int p, int start, int begin, int stride, int ended, T* output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < ((ended - 1 - begin) / stride) + 1; for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast<float>(ended - begin) / stride);
pos += blockDim.x * gridDim.x) { pos += blockDim.x * gridDim.x) {
output[p + pos] = input[start + pos * stride]; output[p + pos] = input[start + pos * stride];
} }
@ -55,7 +55,7 @@ __global__ void StridedSlice(const T* input, int p, int start, int begin, int st
} }
template <typename T> template <typename T>
__global__ void StridedSliceGrad(const T* dy, int p, int start, int begin, int stride, int ended, T* dx) { __global__ void StridedSliceGrad(const T* dy, int p, int start, int begin, int stride, int ended, T* dx) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < ((ended - 1 - begin) / stride) + 1; for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast<float>(ended - begin) / stride);
pos += blockDim.x * gridDim.x) { pos += blockDim.x * gridDim.x) {
dx[start + pos * stride] = dy[p + pos]; dx[start + pos * stride] = dy[p + pos];
} }
@ -117,7 +117,7 @@ void CalStridedSlice(const size_t input_size, const T* input, const std::vector<
(strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3];
StridedSlice<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input, p, start, begin[3], strides[3], StridedSlice<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input, p, start, begin[3], strides[3],
ended, output); ended, output);
p = p + (end[3] - 1 - begin[3]) / strides[3] + 1; p = p + std::ceil(static_cast<float>(end[3] - begin[3]) / strides[3]);
} }
} }
} }
@ -141,7 +141,7 @@ void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector
(strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3];
StridedSliceGrad<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(dy, p, start, begin[3], strides[3], StridedSliceGrad<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(dy, p, start, begin[3], strides[3],
ended, dx); ended, dx);
p = p + (end[3] - 1 - begin[3]) / strides[3] + 1; p = p + std::ceil(static_cast<float>(end[3] - begin[3]) / strides[3]);
} }
} }
} }