diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h index eb829f73c6b..81910b50914 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h @@ -79,6 +79,9 @@ class SliceGpuFwdKernel : public GpuKernel { if (size_[i] < 0) { size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; } + if (size_[i] == 0) { + size_[i] = begin_[i] + 1; + } } input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu index 78a52149ae3..e49a22bb468 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu @@ -47,7 +47,7 @@ __global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) } template __global__ void StridedSlice(const T* input, int p, int start, int begin, int stride, int ended, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < ((ended - 1 - begin) / stride) + 1; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); pos += blockDim.x * gridDim.x) { output[p + pos] = input[start + pos * stride]; } @@ -55,7 +55,7 @@ __global__ void StridedSlice(const T* input, int p, int start, int begin, int st } template __global__ void StridedSliceGrad(const T* dy, int p, int start, int begin, int stride, int ended, T* dx) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < ((ended - 1 - begin) / stride) + 1; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); pos += blockDim.x * gridDim.x) { dx[start + pos * stride] = dy[p + pos]; } @@ -117,7 +117,7 @@ void CalStridedSlice(const size_t input_size, const T* input, const std::vector< (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; StridedSlice<<>>(input, p, start, begin[3], strides[3], ended, output); - p = p + (end[3] - 1 - begin[3]) / strides[3] + 1; + p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); } } } @@ -141,7 +141,7 @@ void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; StridedSliceGrad<<>>(dy, p, start, begin[3], strides[3], ended, dx); - p = p + (end[3] - 1 - begin[3]) / strides[3] + 1; + p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); } } }