forked from mindspore-Ecosystem/mindspore
!8875 Fix some bug about GatherD and GatherDGrad op on gpu
From: @yuan_shen_zhou Reviewed-by: Signed-off-by:
This commit is contained in:
commit
6a04c21456
|
@ -18,6 +18,14 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherGpuFwdKernel, double, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherGpuFwdKernel, double, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
@ -34,5 +42,59 @@ MS_REG_GPU_KERNEL_TWO(
|
|||
GatherD,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherGpuFwdKernel, half, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherGpuFwdKernel, int, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherGpuFwdKernel, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
|
||||
GatherGpuFwdKernel, int8_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
|
||||
GatherGpuFwdKernel, int8_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
|
||||
GatherGpuFwdKernel, int16_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
|
||||
GatherGpuFwdKernel, int16_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
GatherGpuFwdKernel, int64_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
GatherGpuFwdKernel, int64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
|
||||
GatherGpuFwdKernel, uchar, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
|
||||
GatherGpuFwdKernel, uchar, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
GatherGpuFwdKernel, bool, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
GatherGpuFwdKernel, bool, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
GatherGpuFwdKernel, uint32_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
|
||||
GatherGpuFwdKernel, uint32_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
|
||||
GatherGpuFwdKernel, uint64_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
GatherGpuFwdKernel, uint64_t, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
|
||||
GatherGpuFwdKernel, uint16_t, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
|
||||
GatherGpuFwdKernel, uint16_t, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -41,7 +41,7 @@ class GatherGpuFwdKernel : public GpuKernel {
|
|||
S *index_addr = GetDeviceAddress<S>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2],
|
||||
Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2], dims_[3],
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
@ -83,15 +83,17 @@ class GatherGpuFwdKernel : public GpuKernel {
|
|||
for (size_t i = 0; i < IntToSize(axis_); i++) {
|
||||
dim_before_axis *= output_shapes_[i];
|
||||
}
|
||||
size_t dim_of_index = output_shapes_[IntToSize(axis_)];
|
||||
size_t dim_after_index = 1;
|
||||
size_t dim_at_axis_input = input_shapes_[IntToSize(axis_)];
|
||||
size_t dim_at_axis_output = output_shapes_[IntToSize(axis_)];
|
||||
size_t dim_after_axis = 1;
|
||||
for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
|
||||
dim_after_index *= output_shapes_[i];
|
||||
dim_after_axis *= output_shapes_[i];
|
||||
}
|
||||
|
||||
dims_[0] = dim_before_axis;
|
||||
dims_[1] = dim_of_index;
|
||||
dims_[2] = dim_after_index;
|
||||
dims_[1] = dim_at_axis_input;
|
||||
dims_[2] = dim_at_axis_output;
|
||||
dims_[3] = dim_after_axis;
|
||||
return;
|
||||
}
|
||||
size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
|
||||
|
@ -109,7 +111,7 @@ class GatherGpuFwdKernel : public GpuKernel {
|
|||
std::vector<size_t> index_shapes_;
|
||||
std::vector<size_t> output_shapes_;
|
||||
|
||||
size_t dims_[3] = {};
|
||||
size_t dims_[4] = {};
|
||||
int axis_;
|
||||
cudnnHandle_t handle_;
|
||||
|
||||
|
|
|
@ -18,6 +18,14 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherGradGpuKernel, int, double)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
GatherGradGpuKernel, int64_t, double)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherDGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
|
|
@ -41,7 +41,7 @@ class GatherGradGpuKernel : public GpuKernel {
|
|||
S *grad_addr = GetDeviceAddress<S>(inputs, 1);
|
||||
S *output_addr = GetDeviceAddress<S>(outputs, 0);
|
||||
|
||||
GatherGrad(index_addr, grad_addr, output_addr, dims_[0], dims_[1], dims_[2],
|
||||
GatherGrad(index_addr, grad_addr, output_addr, dims_[0], dims_[1], dims_[2], dims_[3],
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
@ -84,15 +84,17 @@ class GatherGradGpuKernel : public GpuKernel {
|
|||
for (size_t i = 0; i < IntToSize(axis_); i++) {
|
||||
dim_before_axis *= output_shapes_[i];
|
||||
}
|
||||
size_t dim_of_indices = output_shapes_[IntToSize(axis_)];
|
||||
size_t dim_after_indices = 1;
|
||||
size_t dim_at_axis_index = index_shapes_[IntToSize(axis_)];
|
||||
size_t dim_at_axis_output = output_shapes_[IntToSize(axis_)];
|
||||
size_t dim_after_axis = 1;
|
||||
for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
|
||||
dim_after_indices *= output_shapes_[i];
|
||||
dim_after_axis *= output_shapes_[i];
|
||||
}
|
||||
|
||||
dims_[0] = dim_before_axis;
|
||||
dims_[1] = dim_of_indices;
|
||||
dims_[2] = dim_after_indices;
|
||||
dims_[1] = dim_at_axis_index;
|
||||
dims_[2] = dim_at_axis_output;
|
||||
dims_[3] = dim_after_axis;
|
||||
return;
|
||||
}
|
||||
size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
|
||||
|
@ -110,7 +112,7 @@ class GatherGradGpuKernel : public GpuKernel {
|
|||
std::vector<size_t> grad_shapes_;
|
||||
std::vector<size_t> output_shapes_;
|
||||
|
||||
size_t dims_[3] = {};
|
||||
size_t dims_[4] = {};
|
||||
int axis_;
|
||||
cudnnHandle_t handle_;
|
||||
|
||||
|
|
|
@ -18,35 +18,125 @@
|
|||
#include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T, typename S>
|
||||
__global__ void GatherKernel(const T *input, const S *index, T *output, const size_t output_dim0,
|
||||
const size_t output_dim1, const size_t output_dim2) {
|
||||
size_t num = output_dim0 * output_dim1 * output_dim2;
|
||||
__global__ void GatherKernel(const T *input, const S *index, T *output, const size_t dim_before_axis,
|
||||
const size_t dim_at_axis_input, const size_t dim_at_axis_output,
|
||||
const size_t dim_after_axis) {
|
||||
size_t num = dim_before_axis * dim_at_axis_output * dim_after_axis;
|
||||
size_t i, k;
|
||||
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num;
|
||||
id += blockDim.x * gridDim.x) {
|
||||
i = id / (output_dim1 * output_dim2) % output_dim0;
|
||||
k = id % output_dim2;
|
||||
i = id / (dim_at_axis_output * dim_after_axis);
|
||||
k = id % dim_after_axis;
|
||||
|
||||
size_t j_read = static_cast<size_t>(index[id]);
|
||||
size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k;
|
||||
size_t read_id = i * dim_at_axis_input * dim_after_axis + j_read * dim_after_axis + k;
|
||||
output[id] = input[read_id];
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T, typename S>
|
||||
void Gather(const T *input, const S *index, T *output, const size_t output_dim0, const size_t output_dim1,
|
||||
const size_t output_dim2, cudaStream_t stream) {
|
||||
size_t size = output_dim0 * output_dim1 * output_dim2;
|
||||
GatherKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, index, output, output_dim0, output_dim1,
|
||||
output_dim2);
|
||||
void Gather(const T *input, const S *index, T *output, const size_t dim_before_axis,
|
||||
const size_t dim_at_axis_input, const size_t dim_at_axis_output,
|
||||
const size_t dim_after_axis, cudaStream_t stream) {
|
||||
size_t size = dim_before_axis * dim_at_axis_output * dim_after_axis;
|
||||
GatherKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, index, output, dim_before_axis, dim_at_axis_input,
|
||||
dim_at_axis_output, dim_after_axis);
|
||||
return;
|
||||
}
|
||||
|
||||
template void Gather<float, int>(const float *input, const int *index, float *output, const size_t output_dim0,
|
||||
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
|
||||
template void Gather<float, int64_t>(const float *input, const int64_t *index, float *output, const size_t output_dim0,
|
||||
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
|
||||
template void Gather<half, int>(const half *input, const int *index, half *output, const size_t output_dim0,
|
||||
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
|
||||
template void Gather<half, int64_t>(const half *input, const int64_t *index, half *output, const size_t output_dim0,
|
||||
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
|
||||
template void Gather<double, int>(const double *input, const int *index, double *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<double, int64_t>(const double *input, const int64_t *index, double *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<float, int>(const float *input, const int *index, float *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<float, int64_t>(const float *input, const int64_t *index, float *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<half, int>(const half *input, const int *index, half *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<half, int64_t>(const half *input, const int64_t *index, half *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<int64_t, int>(const int64_t *input, const int *index, int64_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<int64_t, int64_t>(const int64_t *input, const int64_t *index, int64_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<int, int>(const int *input, const int *index, int *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<int, int64_t>(const int *input, const int64_t *index, int *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<int16_t, int>(const int16_t *input, const int *index, int16_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<int16_t, int64_t>(const int16_t *input, const int64_t *index, int16_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<int8_t, int>(const int8_t *input, const int *index, int8_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<int8_t, int64_t>(const int8_t *input, const int64_t *index, int8_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<unsigned char, int>(const unsigned char *input, const int *index, unsigned char *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<unsigned char, int64_t>(const unsigned char *input, const int64_t *index, unsigned char *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<bool, int>(const bool *input, const int *index, bool *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<bool, int64_t>(const bool *input, const int64_t *index, bool *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<uint16_t, int>(const uint16_t *input, const int *index, uint16_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<uint16_t, int64_t>(const uint16_t *input, const int64_t *index, uint16_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<uint32_t, int>(const uint32_t *input, const int *index, uint32_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<uint32_t, int64_t>(const uint32_t *input, const int64_t *index, uint32_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<uint64_t, int>(const uint64_t *input, const int *index, uint64_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void Gather<uint64_t, int64_t>(const uint64_t *input, const int64_t *index, uint64_t *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_input,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
|
|
|
@ -17,7 +17,8 @@
|
|||
#ifndef MINDSPORE_GATHER_GPU_CU_H
|
||||
#define MINDSPORE_GATHER_GPU_CU_H
|
||||
template <typename T, typename S>
|
||||
void Gather(const T *input, const S *index, T *output, const size_t output_dim0, const size_t output_dim1,
|
||||
const size_t output_dim2, cudaStream_t stream);
|
||||
void Gather(const T *input, const S *index, T *output, const size_t dim_before_axis,
|
||||
const size_t dim_at_axis_input, const size_t dim_at_axis_output,
|
||||
const size_t dim_after_axis, cudaStream_t stream);
|
||||
|
||||
#endif
|
||||
|
|
|
@ -16,44 +16,71 @@
|
|||
|
||||
#include <iostream>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void GatherGradKernel(const T *index, const S *grad, S *output, const size_t output_dim0,
|
||||
const size_t output_dim1, const size_t output_dim2) {
|
||||
size_t num = output_dim0 * output_dim1 * output_dim2;
|
||||
__global__ void GatherGradKernel(const size_t num, const T *index, const S *grad, S *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_index,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis) {
|
||||
size_t i, k;
|
||||
|
||||
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num;
|
||||
id += blockDim.x * gridDim.x) {
|
||||
i = id / (output_dim1 * output_dim2) % output_dim0;
|
||||
k = id % output_dim2;
|
||||
i = id / (dim_at_axis_index * dim_after_axis);
|
||||
k = id % dim_after_axis;
|
||||
|
||||
size_t j_read = static_cast<size_t>(index[id]);
|
||||
size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k;
|
||||
output[read_id] = grad[id];
|
||||
size_t read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k;
|
||||
MsAtomicAdd(output + read_id, grad[id]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
__global__ void InitOutput(const size_t size, S *output) {
|
||||
S zero = 0;
|
||||
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) {
|
||||
output[id] = zero;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0,
|
||||
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream) {
|
||||
size_t size = output_dim0 * output_dim1 * output_dim2;
|
||||
GatherGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(index, grad, output,
|
||||
output_dim0, output_dim1, output_dim2);
|
||||
void GatherGrad(const T *index, const S *grad, S *output, const size_t dim_before_axis,
|
||||
const size_t dim_at_axis_index, const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream) {
|
||||
size_t size = dim_before_axis * dim_at_axis_output * dim_after_axis;
|
||||
InitOutput<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(size, output);
|
||||
|
||||
size = dim_before_axis * dim_at_axis_index * dim_after_axis;
|
||||
GatherGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(size, index, grad, output,
|
||||
dim_before_axis, dim_at_axis_index,
|
||||
dim_at_axis_output, dim_after_axis);
|
||||
return;
|
||||
}
|
||||
|
||||
template void GatherGrad<int, double>(const int *index, const double *grad, double *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_index,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void GatherGrad<int64_t, double>(const int64_t *index, const double *grad, double *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_index,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void GatherGrad<int, float>(const int *index, const float *grad, float *output,
|
||||
const size_t output_dim0, const size_t output_dim1,
|
||||
const size_t output_dim2, cudaStream_t stream);
|
||||
|
||||
template void GatherGrad<int, half>(const int *index, const half *grad, half *output,
|
||||
const size_t output_dim0, const size_t output_dim1,
|
||||
const size_t output_dim2, cudaStream_t stream);
|
||||
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_index,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void GatherGrad<int64_t, float>(const int64_t *index, const float *grad, float *output,
|
||||
const size_t output_dim0, const size_t output_dim1,
|
||||
const size_t output_dim2, cudaStream_t stream);
|
||||
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_index,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void GatherGrad<int, half>(const int *index, const half *grad, half *output,
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_index,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
template void GatherGrad<int64_t, half>(const int64_t *index, const half *grad, half *output,
|
||||
const size_t output_dim0, const size_t output_dim1,
|
||||
const size_t output_dim2, cudaStream_t stream);
|
||||
const size_t dim_before_axis, const size_t dim_at_axis_index,
|
||||
const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
|
|
|
@ -17,7 +17,8 @@
|
|||
#ifndef MINDSPORE_GATHER_GRAD_GPU_CU_H
|
||||
#define MINDSPORE_GATHER_GRAD_GPU_CU_H
|
||||
template <typename T, typename S>
|
||||
void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0,
|
||||
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
|
||||
void GatherGrad(const T *index, const S *grad, S *output, const size_t dim_before_axis,
|
||||
const size_t dim_at_axis_index, const size_t dim_at_axis_output, const size_t dim_after_axis,
|
||||
cudaStream_t stream);
|
||||
|
||||
#endif
|
||||
|
|
|
@ -19,6 +19,18 @@
|
|||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
__device__ static inline double MsAtomicAdd(double *address, const double val) {
|
||||
unsigned long long int* address_as_ull = (unsigned long long int*)address; // NOLINT
|
||||
unsigned long long int old = *address_as_ull; // NOLINT
|
||||
unsigned long long int assumed; // NOLINT
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
|
||||
}
|
||||
while (assumed != old); // NOLINT
|
||||
return __longlong_as_double(old);
|
||||
}
|
||||
|
||||
__device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); }
|
||||
|
||||
__device__ static inline int MsAtomicAdd(int *address, int val) { return atomicAdd(address, val); }
|
||||
|
|
|
@ -396,16 +396,11 @@ def get_bprop_gather_v2(self):
|
|||
@bprop_getters.register(P.GatherD)
|
||||
def get_bprop_gather_d(self):
|
||||
"""Generate bprop for GatherD"""
|
||||
gather_d = P.GatherD()
|
||||
|
||||
def bprop(x, dim, index, out, dout):
|
||||
return P.GatherDGrad(dim)(index, dout)
|
||||
|
||||
def bprop_ascend(x, dim, index, out, dout):
|
||||
return (gather_d(dout, dim, index), zeros_like(dim), zeros_like(index))
|
||||
|
||||
if context.get_context('device_target') == 'Ascend':
|
||||
return bprop_ascend
|
||||
x_shp = shape_op(x)
|
||||
dx = G.GatherDGrad(dim, x_shp)(index, dout)
|
||||
return dx, zeros_like(dim), zeros_like(index)
|
||||
|
||||
return bprop
|
||||
|
||||
|
|
|
@ -1385,14 +1385,15 @@ class GatherDGrad(PrimitiveWithInfer):
|
|||
"""Performs grad of GatherD operation."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, dim=0):
|
||||
def __init__(self, dim=0, shape=None):
|
||||
"""Initialize GatherDGrad"""
|
||||
validator.check_is_int(dim, int)
|
||||
self.add_prim_attr("dim", dim)
|
||||
self.out_shape = shape
|
||||
self.init_prim_io_names(inputs=['index', 'grad'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, index_shape, grad_shape):
|
||||
return grad_shape
|
||||
return self.out_shape
|
||||
|
||||
def infer_dtype(self, index_dtype, grad_dtype):
|
||||
return grad_dtype
|
||||
|
|
|
@ -19,32 +19,37 @@ import pytest
|
|||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore as ms
|
||||
import mindspore.ops.operations._grad_ops as P
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.ops.operations._grad_ops as G
|
||||
from mindspore.ops.composite import GradOperation
|
||||
from mindspore import Tensor
|
||||
|
||||
class GatherDGradNet(nn.Cell):
|
||||
class GatherDNet(nn.Cell):
|
||||
def __init__(self, dim=0):
|
||||
super(GatherDGradNet, self).__init__()
|
||||
self.gather_d_grad = P.GatherDGrad(dim)
|
||||
super(GatherDNet, self).__init__()
|
||||
self.gather_d = P.GatherD()
|
||||
self.dim = dim
|
||||
|
||||
def construct(self, index, grad):
|
||||
return self.gather_d_grad(index, grad)
|
||||
def construct(self, x, index):
|
||||
return self.gather_d(x, self.dim, index)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_grad_graph_int32_fp32():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float32)
|
||||
dim = 0
|
||||
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
|
||||
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
|
||||
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
|
||||
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
|
||||
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
|
||||
net = GatherDGradNet(dim)
|
||||
output = net(index, grad)
|
||||
net = GatherDNet(dim)
|
||||
grad_net = GradOperation(get_all=True, sens_param=True)(net)
|
||||
output = grad_net(x, index, grad)
|
||||
error = 1e-4
|
||||
diff = output.asnumpy() - expect
|
||||
diff = output[0].asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -52,16 +57,18 @@ def test_gather_grad_graph_int32_fp32():
|
|||
@pytest.mark.env_onecard
|
||||
def test_gather_grad_graph_int64_fp32():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float32)
|
||||
dim = 0
|
||||
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
|
||||
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
|
||||
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
|
||||
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
|
||||
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
|
||||
net = GatherDGradNet(dim)
|
||||
output = net(index, grad)
|
||||
net = GatherDNet(dim)
|
||||
grad_net = GradOperation(get_all=True, sens_param=True)(net)
|
||||
output = grad_net(x, index, grad)
|
||||
error = 1e-4
|
||||
diff = output.asnumpy() - expect
|
||||
diff = output[0].asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -69,16 +76,18 @@ def test_gather_grad_graph_int64_fp32():
|
|||
@pytest.mark.env_onecard
|
||||
def test_gather_grad_graph_int32_fp16():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float16)
|
||||
dim = 0
|
||||
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
|
||||
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
|
||||
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
|
||||
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
|
||||
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
|
||||
net = GatherDGradNet(dim)
|
||||
output = net(index, grad)
|
||||
net = GatherDNet(dim)
|
||||
grad_net = GradOperation(get_all=True, sens_param=True)(net)
|
||||
output = grad_net(x, index, grad)
|
||||
error = 1e-4
|
||||
diff = output.asnumpy() - expect
|
||||
diff = output[0].asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -86,16 +95,18 @@ def test_gather_grad_graph_int32_fp16():
|
|||
@pytest.mark.env_onecard
|
||||
def test_gather_grad_graph_int64_fp16():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x = Tensor(np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]), ms.float16)
|
||||
dim = 0
|
||||
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
|
||||
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
|
||||
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
|
||||
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
|
||||
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
|
||||
net = GatherDGradNet(dim)
|
||||
output = net(index, grad)
|
||||
net = GatherDNet(dim)
|
||||
grad_net = GradOperation(get_all=True, sens_param=True)(net)
|
||||
output = grad_net(x, index, grad)
|
||||
error = 1e-4
|
||||
diff = output.asnumpy() - expect
|
||||
diff = output[0].asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
|
||||
@pytest.mark.level0
|
||||
|
@ -103,13 +114,14 @@ def test_gather_grad_graph_int64_fp16():
|
|||
@pytest.mark.env_onecard
|
||||
def test_gather_grad_pynative_int32_fp32():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
x_shape = (2, 5)
|
||||
dim = 0
|
||||
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
|
||||
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
|
||||
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
|
||||
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
|
||||
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
|
||||
output = P.GatherDGrad(dim)(index, grad)
|
||||
output = G.GatherDGrad(dim, x_shape)(index, grad)
|
||||
error = 1e-4
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
|
@ -119,13 +131,14 @@ def test_gather_grad_pynative_int32_fp32():
|
|||
@pytest.mark.env_onecard
|
||||
def test_gather_grad_pynative_int64_fp32():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
x_shape = (2, 5)
|
||||
dim = 0
|
||||
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
|
||||
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
|
||||
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
|
||||
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
|
||||
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
|
||||
output = P.GatherDGrad(dim)(index, grad)
|
||||
output = G.GatherDGrad(dim, x_shape)(index, grad)
|
||||
error = 1e-4
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
|
@ -135,13 +148,14 @@ def test_gather_grad_pynative_int64_fp32():
|
|||
@pytest.mark.env_onecard
|
||||
def test_gather_grad_pynative_int32_fp16():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
x_shape = (2, 5)
|
||||
dim = 0
|
||||
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
|
||||
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
|
||||
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
|
||||
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
|
||||
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
|
||||
output = P.GatherDGrad(dim)(index, grad)
|
||||
output = G.GatherDGrad(dim, x_shape)(index, grad)
|
||||
error = 1e-4
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
|
@ -151,13 +165,14 @@ def test_gather_grad_pynative_int32_fp16():
|
|||
@pytest.mark.env_onecard
|
||||
def test_gather_grad_pynative_int64_fp16():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
x_shape = (2, 5)
|
||||
dim = 0
|
||||
index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
|
||||
grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
|
||||
[0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
|
||||
expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
|
||||
[0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
|
||||
output = P.GatherDGrad(dim)(index, grad)
|
||||
output = G.GatherDGrad(dim, x_shape)(index, grad)
|
||||
error = 1e-4
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
|
|
Loading…
Reference in New Issue