!8875 Fix bugs in the GatherD and GatherDGrad GPU ops

From: @yuan_shen_zhou
Committed by mindspore-ci-bot on Gitee, 2020-11-24 09:46:50 +08:00
commit 6a04c21456
12 changed files with 309 additions and 93 deletions


@@ -18,6 +18,14 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
GatherGpuFwdKernel, double, int)
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
GatherGpuFwdKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
@@ -34,5 +42,59 @@ MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
GatherGpuFwdKernel, half, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherGpuFwdKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
GatherGpuFwdKernel, int, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
GatherGpuFwdKernel, int8_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
GatherGpuFwdKernel, int8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
GatherGpuFwdKernel, int16_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
GatherGpuFwdKernel, int16_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
GatherGpuFwdKernel, int64_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
GatherGpuFwdKernel, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherGpuFwdKernel, uchar, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
GatherGpuFwdKernel, uchar, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
GatherGpuFwdKernel, bool, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
GatherGpuFwdKernel, bool, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherGpuFwdKernel, uint32_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
GatherGpuFwdKernel, uint32_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
GatherGpuFwdKernel, uint64_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
GatherGpuFwdKernel, uint64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
GatherGpuFwdKernel, uint16_t, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
GatherGpuFwdKernel, uint16_t, int64_t)
} // namespace kernel
} // namespace mindspore
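
Note: the hunks above extend the registration table so GatherD accepts all
common dtypes (double through bool) with either int32 or int64 indices. For
reference, GatherD gathers along a single dimension: each output element
copies the input element whose coordinate along `dim` is replaced by the
matching index value. A minimal NumPy sketch of these semantics (a sketch,
not code from this commit):

import numpy as np

def gather_d_reference(x, dim, index):
    # out[pos] = x[pos with pos[dim] replaced by index[pos]]
    out = np.empty(index.shape, dtype=x.dtype)
    for pos in np.ndindex(*index.shape):
        src = list(pos)
        src[dim] = index[pos]
        out[pos] = x[tuple(src)]
    return out

x = np.array([[1, 2], [3, 4]])
index = np.array([[0, 1], [1, 0]])
print(gather_d_reference(x, 0, index))  # [[1 4] [3 2]]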


@@ -41,7 +41,7 @@ class GatherGpuFwdKernel : public GpuKernel {
S *index_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2],
Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2], dims_[3],
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -83,15 +83,17 @@ class GatherGpuFwdKernel : public GpuKernel {
for (size_t i = 0; i < IntToSize(axis_); i++) {
dim_before_axis *= output_shapes_[i];
}
size_t dim_of_index = output_shapes_[IntToSize(axis_)];
size_t dim_after_index = 1;
size_t dim_at_axis_input = input_shapes_[IntToSize(axis_)];
size_t dim_at_axis_output = output_shapes_[IntToSize(axis_)];
size_t dim_after_axis = 1;
for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
dim_after_index *= output_shapes_[i];
dim_after_axis *= output_shapes_[i];
}
dims_[0] = dim_before_axis;
dims_[1] = dim_of_index;
dims_[2] = dim_after_index;
dims_[1] = dim_at_axis_input;
dims_[2] = dim_at_axis_output;
dims_[3] = dim_after_axis;
return;
}
size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
@@ -109,7 +111,7 @@ class GatherGpuFwdKernel : public GpuKernel {
std::vector<size_t> index_shapes_;
std::vector<size_t> output_shapes_;
size_t dims_[3] = {};
size_t dims_[4] = {};
int axis_;
cudnnHandle_t handle_;
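
Note: the forward kernel's shape bookkeeping grows from three sizes to four.
The old dims_ implicitly assumed input and output share their size at the
gather axis, but GatherD lets index (and hence output) differ from input
there, so the input's and output's axis sizes are now tracked separately.
A worked example of the computation above (illustrative shapes):

# input (2, 3, 4), index/output (2, 5, 4), axis = 1
input_shape, output_shape, axis = (2, 3, 4), (2, 5, 4), 1

dim_before_axis = 1
for s in output_shape[:axis]:
    dim_before_axis *= s                 # 2
dim_at_axis_input = input_shape[axis]    # 3, the read stride's axis size
dim_at_axis_output = output_shape[axis]  # 5, the iteration space's axis size
dim_after_axis = 1
for s in output_shape[axis + 1:]:
    dim_after_axis *= s                  # 4

assert (dim_before_axis, dim_at_axis_input,
        dim_at_axis_output, dim_after_axis) == (2, 3, 5, 4)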


@@ -18,6 +18,14 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
GatherGradGpuKernel, int, double)
MS_REG_GPU_KERNEL_TWO(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
GatherGradGpuKernel, int64_t, double)
MS_REG_GPU_KERNEL_TWO(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),


@@ -41,7 +41,7 @@ class GatherGradGpuKernel : public GpuKernel {
S *grad_addr = GetDeviceAddress<S>(inputs, 1);
S *output_addr = GetDeviceAddress<S>(outputs, 0);
GatherGrad(index_addr, grad_addr, output_addr, dims_[0], dims_[1], dims_[2],
GatherGrad(index_addr, grad_addr, output_addr, dims_[0], dims_[1], dims_[2], dims_[3],
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -84,15 +84,17 @@ class GatherGradGpuKernel : public GpuKernel {
for (size_t i = 0; i < IntToSize(axis_); i++) {
dim_before_axis *= output_shapes_[i];
}
size_t dim_of_indices = output_shapes_[IntToSize(axis_)];
size_t dim_after_indices = 1;
size_t dim_at_axis_index = index_shapes_[IntToSize(axis_)];
size_t dim_at_axis_output = output_shapes_[IntToSize(axis_)];
size_t dim_after_axis = 1;
for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
dim_after_indices *= output_shapes_[i];
dim_after_axis *= output_shapes_[i];
}
dims_[0] = dim_before_axis;
dims_[1] = dim_of_indices;
dims_[2] = dim_after_indices;
dims_[1] = dim_at_axis_index;
dims_[2] = dim_at_axis_output;
dims_[3] = dim_after_axis;
return;
}
size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
@@ -110,7 +112,7 @@ class GatherGradGpuKernel : public GpuKernel {
std::vector<size_t> grad_shapes_;
std::vector<size_t> output_shapes_;
size_t dims_[3] = {};
size_t dims_[4] = {};
int axis_;
cudnnHandle_t handle_;
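
Note: the grad kernel gets the mirror-image fix. Here dims_[1] now comes from
the index shape (the scatter's iteration space) and dims_[2] from the output
shape (the x-shaped gradient being written into); the old code read both from
output_shapes_, which breaks whenever index is smaller than x at the axis.
Continuing the shapes from the forward example (illustrative):

# x/output (2, 3, 4), index/grad (2, 5, 4), axis = 1
index_shape, output_shape, axis = (2, 5, 4), (2, 3, 4), 1
dims = (2, index_shape[axis], output_shape[axis], 4)
assert dims == (2, 5, 3, 4)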


@@ -18,35 +18,125 @@
#include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void GatherKernel(const T *input, const S *index, T *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2) {
size_t num = output_dim0 * output_dim1 * output_dim2;
__global__ void GatherKernel(const T *input, const S *index, T *output, const size_t dim_before_axis,
const size_t dim_at_axis_input, const size_t dim_at_axis_output,
const size_t dim_after_axis) {
size_t num = dim_before_axis * dim_at_axis_output * dim_after_axis;
size_t i, k;
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num;
id += blockDim.x * gridDim.x) {
i = id / (output_dim1 * output_dim2) % output_dim0;
k = id % output_dim2;
i = id / (dim_at_axis_output * dim_after_axis);
k = id % dim_after_axis;
size_t j_read = static_cast<size_t>(index[id]);
size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k;
size_t read_id = i * dim_at_axis_input * dim_after_axis + j_read * dim_after_axis + k;
output[id] = input[read_id];
}
return;
}
template <typename T, typename S>
void Gather(const T *input, const S *index, T *output, const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream) {
size_t size = output_dim0 * output_dim1 * output_dim2;
GatherKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, index, output, output_dim0, output_dim1,
output_dim2);
void Gather(const T *input, const S *index, T *output, const size_t dim_before_axis,
const size_t dim_at_axis_input, const size_t dim_at_axis_output,
const size_t dim_after_axis, cudaStream_t stream) {
size_t size = dim_before_axis * dim_at_axis_output * dim_after_axis;
GatherKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, index, output, dim_before_axis, dim_at_axis_input,
dim_at_axis_output, dim_after_axis);
return;
}
template void Gather<float, int>(const float *input, const int *index, float *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
template void Gather<float, int64_t>(const float *input, const int64_t *index, float *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
template void Gather<half, int>(const half *input, const int *index, half *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
template void Gather<half, int64_t>(const half *input, const int64_t *index, half *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
template void Gather<double, int>(const double *input, const int *index, double *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<double, int64_t>(const double *input, const int64_t *index, double *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<float, int>(const float *input, const int *index, float *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<float, int64_t>(const float *input, const int64_t *index, float *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<half, int>(const half *input, const int *index, half *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<half, int64_t>(const half *input, const int64_t *index, half *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<int64_t, int>(const int64_t *input, const int *index, int64_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<int64_t, int64_t>(const int64_t *input, const int64_t *index, int64_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<int, int>(const int *input, const int *index, int *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<int, int64_t>(const int *input, const int64_t *index, int *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<int16_t, int>(const int16_t *input, const int *index, int16_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<int16_t, int64_t>(const int16_t *input, const int64_t *index, int16_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<int8_t, int>(const int8_t *input, const int *index, int8_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<int8_t, int64_t>(const int8_t *input, const int64_t *index, int8_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<unsigned char, int>(const unsigned char *input, const int *index, unsigned char *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<unsigned char, int64_t>(const unsigned char *input, const int64_t *index, unsigned char *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<bool, int>(const bool *input, const int *index, bool *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<bool, int64_t>(const bool *input, const int64_t *index, bool *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<uint16_t, int>(const uint16_t *input, const int *index, uint16_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<uint16_t, int64_t>(const uint16_t *input, const int64_t *index, uint16_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<uint32_t, int>(const uint32_t *input, const int *index, uint32_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<uint32_t, int64_t>(const uint32_t *input, const int64_t *index, uint32_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<uint64_t, int>(const uint64_t *input, const int *index, uint64_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void Gather<uint64_t, int64_t>(const uint64_t *input, const int64_t *index, uint64_t *output,
const size_t dim_before_axis, const size_t dim_at_axis_input,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
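
Note: both axis sizes must reach the kernel because the flat output id is
decomposed using the output's axis size, while the read offset into the input
is recomposed using the input's axis size. A NumPy sketch of the same index
arithmetic on flattened arrays (a sketch, not the kernel itself):

import numpy as np

def gather_flat(x_flat, index_flat, at_axis_in, at_axis_out, dim_after):
    # Mirrors GatherKernel: id -> (i, j, k), swap j for the index value,
    # then recompose a flat read offset using the input's axis size.
    out = np.empty(index_flat.shape, dtype=x_flat.dtype)
    for idx in range(index_flat.size):
        i = idx // (at_axis_out * dim_after)
        k = idx % dim_after
        j_read = index_flat[idx]
        out[idx] = x_flat[(i * at_axis_in + j_read) * dim_after + k]
    return out

# input (2, 3) gathered along axis 1 with index (2, 2): dims = (2, 3, 2, 1)
x = np.array([10, 20, 30, 40, 50, 60])   # flattened (2, 3)
index = np.array([2, 0, 1, 1])           # flattened (2, 2)
print(gather_flat(x, index, 3, 2, 1))    # [30 10 50 50]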


@@ -17,7 +17,8 @@
#ifndef MINDSPORE_GATHER_GPU_CU_H
#define MINDSPORE_GATHER_GPU_CU_H
template <typename T, typename S>
void Gather(const T *input, const S *index, T *output, const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream);
void Gather(const T *input, const S *index, T *output, const size_t dim_before_axis,
const size_t dim_at_axis_input, const size_t dim_at_axis_output,
const size_t dim_after_axis, cudaStream_t stream);
#endif


@@ -16,44 +16,71 @@
#include <iostream>
#include "backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void GatherGradKernel(const T *index, const S *grad, S *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2) {
size_t num = output_dim0 * output_dim1 * output_dim2;
__global__ void GatherGradKernel(const size_t num, const T *index, const S *grad, S *output,
const size_t dim_before_axis, const size_t dim_at_axis_index,
const size_t dim_at_axis_output, const size_t dim_after_axis) {
size_t i, k;
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num;
id += blockDim.x * gridDim.x) {
i = id / (output_dim1 * output_dim2) % output_dim0;
k = id % output_dim2;
i = id / (dim_at_axis_index * dim_after_axis);
k = id % dim_after_axis;
size_t j_read = static_cast<size_t>(index[id]);
size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k;
output[read_id] = grad[id];
size_t read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k;
MsAtomicAdd(output + read_id, grad[id]);
}
return;
}
template <typename S>
__global__ void InitOutput(const size_t size, S *output) {
S zero = 0;
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) {
output[id] = zero;
}
return;
}
template <typename T, typename S>
void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream) {
size_t size = output_dim0 * output_dim1 * output_dim2;
GatherGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(index, grad, output,
output_dim0, output_dim1, output_dim2);
void GatherGrad(const T *index, const S *grad, S *output, const size_t dim_before_axis,
const size_t dim_at_axis_index, const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream) {
size_t size = dim_before_axis * dim_at_axis_output * dim_after_axis;
InitOutput<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(size, output);
size = dim_before_axis * dim_at_axis_index * dim_after_axis;
GatherGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(size, index, grad, output,
dim_before_axis, dim_at_axis_index,
dim_at_axis_output, dim_after_axis);
return;
}
template void GatherGrad<int, double>(const int *index, const double *grad, double *output,
const size_t dim_before_axis, const size_t dim_at_axis_index,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void GatherGrad<int64_t, double>(const int64_t *index, const double *grad, double *output,
const size_t dim_before_axis, const size_t dim_at_axis_index,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void GatherGrad<int, float>(const int *index, const float *grad, float *output,
const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream);
template void GatherGrad<int, half>(const int *index, const half *grad, half *output,
const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream);
const size_t dim_before_axis, const size_t dim_at_axis_index,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void GatherGrad<int64_t, float>(const int64_t *index, const float *grad, float *output,
const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream);
const size_t dim_before_axis, const size_t dim_at_axis_index,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void GatherGrad<int, half>(const int *index, const half *grad, half *output,
const size_t dim_before_axis, const size_t dim_at_axis_index,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
template void GatherGrad<int64_t, half>(const int64_t *index, const half *grad, half *output,
const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream);
const size_t dim_before_axis, const size_t dim_at_axis_index,
const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
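
Note: this file carries the two substantive GatherDGrad fixes. First, the
output buffer is zero-filled by InitOutput before the scatter, since with the
corrected shapes not every output element is written. Second, the plain store
output[read_id] = grad[id] became MsAtomicAdd: when index repeats a
destination, concurrent threads previously overwrote each other's
contributions instead of summing them. A minimal NumPy sketch of the intended
reference behavior (a sketch, not the kernel itself):

import numpy as np

def gather_d_grad_reference(index, grad, dim, x_shape):
    dx = np.zeros(x_shape, dtype=grad.dtype)  # zero-init, like InitOutput
    for pos in np.ndindex(*index.shape):
        dst = list(pos)
        dst[dim] = index[pos]
        dx[tuple(dst)] += grad[pos]           # accumulate, like MsAtomicAdd
    return dx

index = np.array([[0, 0], [1, 0]])
grad = np.array([[1.0, 2.0], [3.0, 4.0]])
print(gather_d_grad_reference(index, grad, 0, (2, 2)))
# [[1. 6.]
#  [3. 0.]]   column 1 sums 2.0 and 4.0 into row 0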


@@ -17,7 +17,8 @@
#ifndef MINDSPORE_GATHER_GRAD_GPU_CU_H
#define MINDSPORE_GATHER_GRAD_GPU_CU_H
template <typename T, typename S>
void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
void GatherGrad(const T *index, const S *grad, S *output, const size_t dim_before_axis,
const size_t dim_at_axis_index, const size_t dim_at_axis_output, const size_t dim_after_axis,
cudaStream_t stream);
#endif


@@ -19,6 +19,18 @@
#include <cuda_fp16.h>
__device__ static inline double MsAtomicAdd(double *address, const double val) {
  // Native atomicAdd on double requires compute capability 6.0+, so emulate it
  // with a compare-and-swap loop over the value's 64-bit integer bit pattern.
  unsigned long long int* address_as_ull = (unsigned long long int*)address;  // NOLINT
  unsigned long long int old = *address_as_ull;  // NOLINT
  unsigned long long int assumed;  // NOLINT
  do {
    assumed = old;
    // Retry if another thread updated *address between the read and the swap.
    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);  // NOLINT
  return __longlong_as_double(old);
}
__device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); }
__device__ static inline int MsAtomicAdd(int *address, int val) { return atomicAdd(address, val); }


@@ -396,16 +396,11 @@ def get_bprop_gather_v2(self):
@bprop_getters.register(P.GatherD)
def get_bprop_gather_d(self):
    """Generate bprop for GatherD"""
    gather_d = P.GatherD()
    def bprop(x, dim, index, out, dout):
        return P.GatherDGrad(dim)(index, dout)
    def bprop_ascend(x, dim, index, out, dout):
        return (gather_d(dout, dim, index), zeros_like(dim), zeros_like(index))
    if context.get_context('device_target') == 'Ascend':
        return bprop_ascend
        x_shp = shape_op(x)
        dx = G.GatherDGrad(dim, x_shp)(index, dout)
        return dx, zeros_like(dim), zeros_like(index)
    return bprop
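
Note: the GPU bprop previously built dx with GatherDGrad(dim)(index, dout) and
let infer_shape return grad's shape, which is only correct when index matches
x along dim. The rewrite threads the forward input's shape to the primitive so
dx always comes out x-shaped. A sketch of the mismatch (hypothetical shapes):

# x is (3, 5) but the gather selected along dim=0 with a (2, 5) index, so
# dout is (2, 5); shaping dx after dout would silently drop a row, while
# passing shape_op(x) == (3, 5) yields a full dx with zeros at unselected
# positions.
x_shape = (3, 5)
dout_shape = (2, 5)  # == index shape in the backward pass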


@@ -1385,14 +1385,15 @@ class GatherDGrad(PrimitiveWithInfer):
    """Performs grad of GatherD operation."""
    @prim_attr_register
    def __init__(self, dim=0):
    def __init__(self, dim=0, shape=None):
        """Initialize GatherDGrad"""
        validator.check_is_int(dim, int)
        self.add_prim_attr("dim", dim)
        self.out_shape = shape
        self.init_prim_io_names(inputs=['index', 'grad'], outputs=['output'])
    def infer_shape(self, index_shape, grad_shape):
        return grad_shape
        return self.out_shape
    def infer_dtype(self, index_dtype, grad_dtype):
        return grad_dtype


@@ -19,32 +19,37 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops.operations._grad_ops as P
import mindspore.ops.operations as P
import mindspore.ops.operations._grad_ops as G
from mindspore.ops.composite import GradOperation
from mindspore import Tensor
class GatherDGradNet(nn.Cell):
class GatherDNet(nn.Cell):
    def __init__(self, dim=0):
        super(GatherDGradNet, self).__init__()
        self.gather_d_grad = P.GatherDGrad(dim)
        super(GatherDNet, self).__init__()
        self.gather_d = P.GatherD()
        self.dim = dim
    def construct(self, index, grad):
        return self.gather_d_grad(index, grad)
    def construct(self, x, index):
        return self.gather_d(x, self.dim, index)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_graph_int32_fp32():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float32)
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
    net = GatherDGradNet(dim)
    output = net(index, grad)
    net = GatherDNet(dim)
    grad_net = GradOperation(get_all=True, sens_param=True)(net)
    output = grad_net(x, index, grad)
    error = 1e-4
    diff = output.asnumpy() - expect
    diff = output[0].asnumpy() - expect
    assert np.all(diff < error)
@pytest.mark.level0
@@ -52,16 +57,18 @@ def test_gather_grad_graph_int32_fp32():
@pytest.mark.env_onecard
def test_gather_grad_graph_int64_fp32():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float32)
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
    net = GatherDGradNet(dim)
    output = net(index, grad)
    net = GatherDNet(dim)
    grad_net = GradOperation(get_all=True, sens_param=True)(net)
    output = grad_net(x, index, grad)
    error = 1e-4
    diff = output.asnumpy() - expect
    diff = output[0].asnumpy() - expect
    assert np.all(diff < error)
@pytest.mark.level0
@@ -69,16 +76,18 @@ def test_gather_grad_graph_int64_fp32():
@pytest.mark.env_onecard
def test_gather_grad_graph_int32_fp16():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float16)
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
    net = GatherDGradNet(dim)
    output = net(index, grad)
    net = GatherDNet(dim)
    grad_net = GradOperation(get_all=True, sens_param=True)(net)
    output = grad_net(x, index, grad)
    error = 1e-4
    diff = output.asnumpy() - expect
    diff = output[0].asnumpy() - expect
    assert np.all(diff < error)
@pytest.mark.level0
@@ -86,16 +95,18 @@ def test_gather_grad_graph_int32_fp16():
@pytest.mark.env_onecard
def test_gather_grad_graph_int64_fp16():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float16)
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
    net = GatherDGradNet(dim)
    output = net(index, grad)
    net = GatherDNet(dim)
    grad_net = GradOperation(get_all=True, sens_param=True)(net)
    output = grad_net(x, index, grad)
    error = 1e-4
    diff = output.asnumpy() - expect
    diff = output[0].asnumpy() - expect
    assert np.all(diff < error)
@pytest.mark.level0
@@ -103,13 +114,14 @@ def test_gather_grad_graph_int64_fp16():
@pytest.mark.env_onecard
def test_gather_grad_pynative_int32_fp32():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x_shape = (2, 5)
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
    output = P.GatherDGrad(dim)(index, grad)
    output = G.GatherDGrad(dim, x_shape)(index, grad)
    error = 1e-4
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
@@ -119,13 +131,14 @@ def test_gather_grad_pynative_int32_fp32():
@pytest.mark.env_onecard
def test_gather_grad_pynative_int64_fp32():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x_shape = (2, 5)
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
    output = P.GatherDGrad(dim)(index, grad)
    output = G.GatherDGrad(dim, x_shape)(index, grad)
    error = 1e-4
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
@@ -135,13 +148,14 @@ def test_gather_grad_pynative_int64_fp32():
@pytest.mark.env_onecard
def test_gather_grad_pynative_int32_fp16():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x_shape = (2, 5)
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
    output = P.GatherDGrad(dim)(index, grad)
    output = G.GatherDGrad(dim, x_shape)(index, grad)
    error = 1e-4
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
@@ -151,13 +165,14 @@ def test_gather_grad_pynative_int32_fp16():
@pytest.mark.env_onecard
def test_gather_grad_pynative_int64_fp16():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x_shape = (2, 5)
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
    output = P.GatherDGrad(dim)(index, grad)
    output = G.GatherDGrad(dim, x_shape)(index, grad)
    error = 1e-4
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
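
Note: the expected arrays in these tests can be reproduced with a plain NumPy
scatter-add over dim 0, matching the reference behavior sketched for the
gather_grad CUDA file above (an editorial check, not part of the test file):

import numpy as np

index = np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
grad = np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                 [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]], np.float32)

expect = np.zeros((2, 5), np.float32)
for i in range(2):
    for j in range(5):
        expect[index[i][j]][j] += grad[i][j]  # scatter-add along dim 0

print(expect)
# [[0.9031 0.8439 0.2003 0.3198 0.571 ]
#  [0.6949 0.089  0.2779 0.6868 0.4437]]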