!4074 fix cumsum bug

Merge pull request !4074 from baihuawei/0806
mindspore-ci-bot 2020-08-10 20:37:01 +08:00 committed by Gitee
commit 4554a80807
4 changed files with 113 additions and 12 deletions
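In short: this change teaches the GPU CumSum kernel to honor the operator's exclusive and reverse attributes and adds a float16 registration. For the exclusive modes, the input is first shifted one step along the accumulation axis (LeftMove for the forward scan, RightMove for the reverse scan), the shifted values are staged through a workspace buffer the size of the input, and the inclusive scan then runs over the staged copy.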


@@ -18,12 +18,85 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-__global__ void CumSumKernel(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
+__global__ void Copy(T *input, T *output, size_t size) {
+  size_t step = blockDim.x * gridDim.x;
+  for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < size; write_index += step) {
+    input[write_index] = output[write_index];
+  }
+}
+template <typename T>
+__global__ void LeftMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                          size_t stride2) {
   size_t num = dim0 * dim2;
   size_t i, k, offset;
+  size_t step = blockDim.x * gridDim.x;
+  for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
+    i = write_index / dim2 % dim0;
+    k = write_index % dim2;
+    offset = i * stride + k;
+    for (size_t j = 0; j < dim1; ++j) {
+      size_t read_index = j * stride2 + offset;
+      if (j == 0) {
+        output[read_index] = 0;
+      } else {
+        size_t read_index2 = (j - 1) * stride2 + offset;
+        output[read_index] = input[read_index2];
+      }
+    }
+  }
+}
+template <typename T>
+__global__ void RightMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
+                          size_t stride2) {
+  size_t num = dim0 * dim2;
+  size_t i, k, offset;
+  size_t step = blockDim.x * gridDim.x;
+  for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
+    i = write_index / dim2 % dim0;
+    k = write_index % dim2;
+    offset = i * stride + k;
+    for (int j = dim1 - 1; j >= 0; --j) {
+      size_t read_index = j * stride2 + offset;
+      if (j == dim1 - 1) {
+        output[read_index] = 0;
+      } else {
+        size_t read_index2 = (j + 1) * stride2 + offset;
+        output[read_index] = input[read_index2];
+      }
+    }
+  }
+}
+template <typename T>
+__global__ void CumSumKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
+                                    size_t stride2) {
+  size_t num = dim0 * dim2;
+  size_t i, k, offset;
+  size_t step = blockDim.x * gridDim.x;
+  for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
+    i = write_index / dim2 % dim0;
+    k = write_index % dim2;
+    offset = i * stride + k;
+    for (int j = dim1 - 1; j >= 0; --j) {
+      size_t read_index = j * stride2 + offset;
+      if (j == dim1 - 1) {
+        output[read_index] = input[read_index];
+      } else {
+        size_t read_index2 = (j + 1) * stride2 + offset;
+        output[read_index] = output[read_index2] + input[read_index];
+      }
+    }
+  }
+}
+template <typename T>
+__global__ void CumSumKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
+                             size_t stride2) {
+  size_t num = dim0 * dim2;
+  size_t i, k, offset;
-  for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num;
-       write_index += blockDim.x * gridDim.x) {
+  size_t step = blockDim.x * gridDim.x;
+  for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
     i = write_index / dim2 % dim0;
     k = write_index % dim2;
     offset = i * stride + k;
@@ -39,12 +112,32 @@ __global__ void CumSumKernel(T *input, T *output, size_t dim0, size_t dim1, size
   }
 }
 template <typename T>
-void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
-            cudaStream_t stream) {
+void CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride,
+            size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream) {
   int size = dim0 * dim2;
-  CumSumKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
+  if (exclusive_) {
+    if (reverse_) {
+      RightMove<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
+      Copy<<<GET_BLOCKS(size * dim1), GET_THREADS, 0, stream>>>(workspace, output, size * dim1);
+      CumSumKernelReverse<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(workspace, output, dim0, dim1, dim2, stride,
+                                                                        stride2);
+    } else {
+      LeftMove<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
+      Copy<<<GET_BLOCKS(size * dim1), GET_THREADS, 0, stream>>>(workspace, output, size * dim1);
+      CumSumKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(workspace, output, dim0, dim1, dim2, stride, stride2);
+    }
+  } else {
+    if (reverse_) {
+      CumSumKernelReverse<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride,
+                                                                        stride2);
+    } else {
+      CumSumKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
+    }
+  }
   return;
 }
-template void CumSum<float>(float *input, float *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
-                            size_t stride2, cudaStream_t stream);
+template void CumSum<float>(const float *input, float *output, float *workspace, size_t dim0, size_t dim1, size_t dim2,
+                            size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
+template void CumSum<half>(const half *input, half *output, half *workspace, size_t dim0, size_t dim1, size_t dim2,
+                           size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
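As a cross-check on the four dispatch paths above, the following single-threaded sketch (illustrative only, not part of the commit) reduces the semantics to a 1-D array; the CUDA kernels generalize it across the dim0 x dim1 x dim2 layout through stride and stride2:

#include <cstddef>
#include <vector>

// Reference semantics: inclusive forward is y[j] = x[0] + ... + x[j];
// exclusive drops x[j] itself; reverse accumulates from the tail instead.
template <typename T>
std::vector<T> CumSumRef(const std::vector<T> &x, bool exclusive, bool reverse) {
  std::vector<T> y(x.size());
  T acc = T(0);
  if (!reverse) {
    for (size_t j = 0; j < x.size(); ++j) {
      y[j] = exclusive ? acc : acc + x[j];  // exclusive == LeftMove, then inclusive scan
      acc += x[j];
    }
  } else {
    for (size_t j = x.size(); j-- > 0;) {
      y[j] = exclusive ? acc : acc + x[j];  // exclusive == RightMove, then reverse scan
      acc += x[j];
    }
  }
  return y;
}

For x = {1, 2, 3}, the inclusive forward result is {1, 3, 6}, the exclusive forward result is {0, 1, 3}, the inclusive reverse result is {6, 5, 3}, and the exclusive reverse result is {5, 3, 0}.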


@@ -17,6 +17,6 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
 template <typename T>
-void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
-            cudaStream_t stream);
+void CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride,
+            size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
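For orientation, a hypothetical call site for this declaration could look as follows. The shape, buffer names, and the contiguous-layout strides (stride = dim1 * dim2, stride2 = dim2, matching index = i * stride + j * stride2 + k in the kernels) are assumptions for illustration; the real wiring is done by CumSumGpuKernel below.

#include <cuda_runtime.h>
// Assumes the header above has been included.

void LaunchExclusiveReverseCumSum(cudaStream_t stream) {
  size_t dim0 = 4, dim1 = 8, dim2 = 16;  // hypothetical shape
  size_t bytes = dim0 * dim1 * dim2 * sizeof(float);
  float *d_in = nullptr, *d_out = nullptr, *d_ws = nullptr;
  cudaMalloc(&d_in, bytes);
  cudaMalloc(&d_out, bytes);
  cudaMalloc(&d_ws, bytes);  // workspace must be as large as the input
  // Exclusive, reversed cumulative sum along the axis of length dim1.
  CumSum(d_in, d_out, d_ws, dim0, dim1, dim2, dim1 * dim2, dim2, true, true, stream);
}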


@@ -20,5 +20,7 @@ namespace mindspore {
 namespace kernel {
 MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       CumSumGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      CumSumGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
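The added registration pairs with the new half instantiation of CumSum in the .cu file above, so the same kernel template now serves float16 tensors as well.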


@@ -27,7 +27,7 @@ namespace kernel {
 template <typename T>
 class CumSumGpuKernel : public GpuKernel {
  public:
-  CumSumGpuKernel() : axis_(0), input_size_0_(0), stride_(0), stride2_(0) {}
+  CumSumGpuKernel() : exclusive_(false), reverse_(false), axis_(0), input_size_0_(0), stride_(0), stride2_(0) {}
   ~CumSumGpuKernel() = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -38,7 +38,8 @@ class CumSumGpuKernel : public GpuKernel {
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
-    CumSum(input_addr, output_addr, dims_[0], dims_[1], dims_[2], stride_, stride2_,
+    T *ws_addr = GetDeviceAddress<T>(workspace, 0);
+    CumSum(input_addr, output_addr, ws_addr, dims_[0], dims_[1], dims_[2], stride_, stride2_, exclusive_, reverse_,
            reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -51,6 +52,8 @@ class CumSumGpuKernel : public GpuKernel {
     input_size_0_ = sizeof(T);
     shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     axis_ = GetAttr<int>(kernel_node, "axis");
+    exclusive_ = GetAttr<bool>(kernel_node, "exclusive");
+    reverse_ = GetAttr<bool>(kernel_node, "reverse");
     int input_dim_length = SizeToInt(shape_.size());
     if (axis_ >= input_dim_length) {
       MS_LOG(EXCEPTION) << "Axis out of bounds.";
@@ -70,6 +73,7 @@ class CumSumGpuKernel : public GpuKernel {
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_0_);
     output_size_list_.push_back(input_size_0_);
+    workspace_size_list_.push_back(input_size_0_);
   }
  private:
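Note that the workspace reserved here matches the full input allocation (input_size_0_), which is exactly what the exclusive path consumes when its Copy launch stages all size * dim1 shifted elements.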
@@ -87,6 +91,8 @@ class CumSumGpuKernel : public GpuKernel {
     stride2_ = dims_[2];
     return;
   }
+  bool exclusive_;
+  bool reverse_;
   int axis_;
   size_t input_size_0_;
   size_t stride_;