forked from mindspore-Ecosystem/mindspore
!1489 GPU fix slice
Merge pull request !1489 from VectorSL/gpu-fix-slice
This commit is contained in:
commit
02f33a17b5
|
@ -79,6 +79,9 @@ class SliceGpuFwdKernel : public GpuKernel {
|
|||
if (size_[i] < 0) {
|
||||
size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
|
||||
}
|
||||
if (size_[i] == 0) {
|
||||
size_[i] = begin_[i] + 1;
|
||||
}
|
||||
}
|
||||
|
||||
input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
|
||||
|
|
|
@ -47,7 +47,7 @@ __global__ void SliceGrad(const T* dy, int p, int start, int length, T* output)
|
|||
}
|
||||
template <typename T>
|
||||
__global__ void StridedSlice(const T* input, int p, int start, int begin, int stride, int ended, T* output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < ((ended - 1 - begin) / stride) + 1;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast<float>(ended - begin) / stride);
|
||||
pos += blockDim.x * gridDim.x) {
|
||||
output[p + pos] = input[start + pos * stride];
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ __global__ void StridedSlice(const T* input, int p, int start, int begin, int st
|
|||
}
|
||||
template <typename T>
|
||||
__global__ void StridedSliceGrad(const T* dy, int p, int start, int begin, int stride, int ended, T* dx) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < ((ended - 1 - begin) / stride) + 1;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast<float>(ended - begin) / stride);
|
||||
pos += blockDim.x * gridDim.x) {
|
||||
dx[start + pos * stride] = dy[p + pos];
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ void CalStridedSlice(const size_t input_size, const T* input, const std::vector<
|
|||
(strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3];
|
||||
StridedSlice<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input, p, start, begin[3], strides[3],
|
||||
ended, output);
|
||||
p = p + (end[3] - 1 - begin[3]) / strides[3] + 1;
|
||||
p = p + std::ceil(static_cast<float>(end[3] - begin[3]) / strides[3]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector
|
|||
(strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3];
|
||||
StridedSliceGrad<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(dy, p, start, begin[3], strides[3],
|
||||
ended, dx);
|
||||
p = p + (end[3] - 1 - begin[3]) / strides[3] + 1;
|
||||
p = p + std::ceil(static_cast<float>(end[3] - begin[3]) / strides[3]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue