!49182 Fix SparseAddGrad bug

Merge pull request !49182 from YijieChen/ops

The change removes the hard-coded 2-D index handling in SparseAddGrad: the CPU kernel now reads the index width from the indices shape (indices_column_) and matches rows with a linear merge over the sorted indices instead of a hash-map lookup, while the GPU kernel compares full dim-wide index rows through a newly allocated workspace buffer (temp_save_ptr).
i-robot 2023-02-23 09:53:14 +00:00 committed by Gitee
commit 0ba09ac448
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 148 additions and 131 deletions

View File

@@ -67,6 +67,7 @@ int SparseAddGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
ResetResource();
auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
indices_column_ = inputs.at(1)->GetShapeVector()[1];
if (ret == KRET_UNKNOWN_OUT_SHAPE) {
if (input_size_list_.size() != kInputNum) {
MS_LOG(ERROR) << "Input size list should be " << kInputNum << ", but got " << input_size_list_.size();
@@ -103,6 +104,26 @@ int SparseAddGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
return ret;
}
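For orientation: indices_column_, captured in Resize above, is the second dimension of the COO indices tensor, whose shape is [nnz, rank]. It is the number of coordinates stored per non-zero element, which the old code hard-coded to 2. A minimal host-side sketch (hypothetical names, not the kernel API):

#include <cstdint>
#include <vector>

// A COO indices tensor of shape [nnz, rank] stores one rank-wide row per
// non-zero element; the kernel's indices_column_ is that second dimension.
int64_t IndicesColumn(const std::vector<int64_t> &indices_shape) {
  return indices_shape[1];  // indices_shape = {nnz, rank}
}
// e.g. a 3-D sparse tensor with 5 non-zeros: IndicesColumn({5, 3}) == 3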
template <typename T, typename S>
int SparseAddGradCpuKernelMod::CompareTwoIndices(const T &a_indices, const T &b_indices, const S *backprop_value,
int64_t *a_row, const int64_t b_row, const size_t dims, S *dx_value,
bool *idx_geq) {
for (int64_t dim = 0; dim < SizeToLong(dims); dim++) {
auto a_idx = a_indices[*a_row * dims + dim];
auto b_idx = b_indices[b_row * dims + dim];
if (a_idx < b_idx) {
*idx_geq = false;
*a_row += 1;
return -1;
} else if (a_idx > b_idx) {
return 1;
}
}
dx_value[*a_row] = backprop_value[b_row];
*a_row += 1;
return 0;
}
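A self-contained sketch of the same contract may help: the helper lexicographically compares row *a_row of a_indices with row b_row of b_indices; when the a-row is strictly smaller it clears *idx_geq and advances a_row, and on an exact match it routes the output gradient into dx_value and advances a_row. This is a re-implementation for illustration, with hypothetical sample data, not the kernel code:

#include <cstdint>
#include <cstdio>

// Illustration only: returns -1/0/+1 as row a <, ==, > row b lexicographically.
template <typename IndexT, typename ValueT>
int CompareRows(const IndexT *a_indices, const IndexT *b_indices, const ValueT *backprop,
                int64_t *a_row, int64_t b_row, size_t dims, ValueT *dx, bool *idx_geq) {
  for (size_t d = 0; d < dims; ++d) {
    auto a = a_indices[*a_row * dims + d];
    auto b = b_indices[b_row * dims + d];
    if (a < b) {  // a-row sorts first: it cannot match any later output row
      *idx_geq = false;
      ++*a_row;
      return -1;
    }
    if (a > b) return 1;  // b-row sorts first: hold a, let the caller advance b
  }
  dx[*a_row] = backprop[b_row];  // exact match: route the gradient
  ++*a_row;
  return 0;
}

int main() {
  // Hypothetical, already-sorted 2-D rows: a = {(0,1), (2,2)}, b = {(0,1), (1,3)}.
  int64_t a[] = {0, 1, 2, 2};
  int64_t b[] = {0, 1, 1, 3};
  float dout[] = {10.f, 20.f};
  float dx[2] = {0.f, 0.f};
  int64_t a_row = 0;
  bool geq = true;
  CompareRows(a, b, dout, &a_row, /*b_row=*/0, /*dims=*/2, dx, &geq);
  std::printf("dx[0]=%g a_row=%lld geq=%d\n", dx[0], static_cast<long long>(a_row), geq);
  // prints: dx[0]=10 a_row=1 geq=1   (rows matched, gradient copied)
}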
template <typename T, typename S>
bool SparseAddGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
@@ -124,37 +145,40 @@ bool SparseAddGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
auto dx1 = reinterpret_cast<T *>(outputs[kDx1Idx]->addr);
auto dx2 = reinterpret_cast<T *>(outputs[kDx2Idx]->addr);
const int64_t x1_indices_num = inputs[kX1IndicesIdx]->size / (sizeof(S) * 2);
const int64_t x2_indices_num = inputs[kX2IndicesIdx]->size / (sizeof(S) * 2);
const int64_t out_indices_num = inputs[kOutIndicesIdx]->size / (sizeof(S) * 2);
const int64_t x1_indices_num = inputs[kX1IndicesIdx]->size / (sizeof(S) * indices_column_);
const int64_t x2_indices_num = inputs[kX2IndicesIdx]->size / (sizeof(S) * indices_column_);
const int64_t out_indices_num = inputs[kOutIndicesIdx]->size / (sizeof(S) * indices_column_);
auto arrayHash = [fn = std::hash<int>{}](const std::array<int, 2> &arr) -> size_t {
return std::accumulate(arr.begin(), arr.end(), 0u, [&](size_t acc, int num) { return (acc << 1) ^ fn(num); });
};
memset(dx1, 0, sizeof(T) * x1_indices_num);
memset(dx2, 0, sizeof(T) * x2_indices_num);
constexpr int dimension_difference = 2;
std::unordered_map<std::array<int, 2>, int, decltype(arrayHash)> out_map(0, arrayHash);
for (int i = 0; i < out_indices_num * dimension_difference; i += dimension_difference) {
std::array<int, 2> index{};
index[0] = out_indices[i];
index[1] = out_indices[i + 1];
out_map[index] = static_cast<int>(i / dimension_difference);
}
int64_t i = 0;
int64_t j = 0;
int64_t k = 0;
bool a_idx_geq;
bool b_idx_geq;
for (int i = 0; i < x1_indices_num * dimension_difference; i += dimension_difference) {
std::array<int, 2> index{};
index[0] = x1_indices[i];
index[1] = x1_indices[i + 1];
if (out_map.find(index) != out_map.end()) {
dx1[static_cast<int>(i / dimension_difference)] = dout[out_map[index]];
while (i < x1_indices_num && j < x2_indices_num && k < out_indices_num) {
a_idx_geq = b_idx_geq = true;
CompareTwoIndices(x1_indices, out_indices, dout, &i, k, indices_column_, dx1, &a_idx_geq);
CompareTwoIndices(x2_indices, out_indices, dout, &j, k, indices_column_, dx2, &b_idx_geq);
if (a_idx_geq && b_idx_geq) {
k += 1;
}
}
for (int i = 0; i < x2_indices_num * dimension_difference; i += dimension_difference) {
std::array<int, 2> index{};
index[0] = x2_indices[i];
index[1] = x2_indices[i + 1];
if (out_map.find(index) != out_map.end()) {
dx2[static_cast<int>(i / dimension_difference)] = dout[out_map[index]];
while (i < x1_indices_num && k < out_indices_num) {
a_idx_geq = true;
CompareTwoIndices(x1_indices, out_indices, dout, &i, k, indices_column_, dx1, &a_idx_geq);
if (a_idx_geq) {
k += 1;
}
}
while (j < x2_indices_num && k < out_indices_num) {
b_idx_geq = true;
CompareTwoIndices(x2_indices, out_indices, dout, &j, k, indices_column_, dx2, &b_idx_geq);
if (b_idx_geq) {
k += 1;
}
}
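Taken together, the three loops above form a two-pointer merge: because SparseAdd emits its output indices in the same lexicographic order as the sorted inputs (an assumption this sketch inherits), every x1 and x2 row can be matched against the output rows in one linear pass, replacing the earlier hash-map probe. A compact restructuring of the same pass, for illustration only:

#include <cstdint>
#include <vector>

// Sketch (not the kernel code): dx1/dx2 receive dout wherever the x1/x2 row
// equals an output row; all index lists are assumed lexicographically sorted.
template <typename IndexT, typename ValueT>
void SparseAddGradMerge(const std::vector<IndexT> &x1, const std::vector<IndexT> &x2,
                        const std::vector<IndexT> &out, const std::vector<ValueT> &dout,
                        size_t rank, std::vector<ValueT> *dx1, std::vector<ValueT> *dx2) {
  auto cmp = [rank](const std::vector<IndexT> &a, size_t ar, const std::vector<IndexT> &b, size_t br) {
    for (size_t d = 0; d < rank; ++d) {
      if (a[ar * rank + d] != b[br * rank + d]) return a[ar * rank + d] < b[br * rank + d] ? -1 : 1;
    }
    return 0;
  };
  const size_t n1 = x1.size() / rank, n2 = x2.size() / rank, no = out.size() / rank;
  dx1->assign(n1, ValueT(0));
  dx2->assign(n2, ValueT(0));
  for (size_t i = 0, j = 0, k = 0; k < no; ++k) {
    while (i < n1 && cmp(x1, i, out, k) < 0) ++i;  // skip x1 rows below out-row k
    if (i < n1 && cmp(x1, i, out, k) == 0) (*dx1)[i++] = dout[k];
    while (j < n2 && cmp(x2, j, out, k) < 0) ++j;  // skip x2 rows below out-row k
    if (j < n2 && cmp(x2, j, out, k) == 0) (*dx2)[j++] = dout[k];
  }
}
// e.g. x1 = {0,0, 1,1}, x2 = {1,1, 2,2}, out = {0,0, 1,1, 2,2}, dout = {1,2,3},
// rank = 2  ->  dx1 = {1,2}, dx2 = {2,3}

The kernel drives the merge from the input side and advances k only when both per-input comparisons report greater-or-equal; the k-driven form above is equivalent for sorted, duplicate-free indices.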

View File

@@ -51,11 +51,15 @@ class SparseAddGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelH
template <typename T, typename S>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
template <typename T, typename S>
int CompareTwoIndices(const T &a_indices, const T &b_indices, const S *backprop_value, int64_t *a_row,
const int64_t b_row, const size_t dims, S *dx_value, bool *idx_geq);
std::vector<size_t> dout_shape_;
std::vector<size_t> x1_indices_shape_;
std::vector<size_t> x2_indices_shape_;
std::vector<size_t> out_indices_shape_;
size_t indices_column_ = 0;
};
} // namespace kernel
} // namespace mindspore

View File

@@ -80,6 +80,7 @@ class SparseAddGradHelperGpuKernel : public GpuKernelHelperBase {
input_size_list_.push_back(sum_indices_size_ * index_bytes_);
output_size_list_.push_back(x1_index_num_ * value_bytes_);
output_size_list_.push_back(x2_index_num_ * value_bytes_);
work_size_list_.push_back(dim_ * index_bytes_);  // workspace for temp_save_ptr: one dim_-wide index row
return 0;
}
@@ -95,6 +96,7 @@ class SparseAddGradHelperGpuKernel : public GpuKernelHelperBase {
T *out_indices_ptr = nullptr;
S *dx1_ptr = nullptr;
S *dx2_ptr = nullptr;
T *temp_save_ptr = nullptr;
int flag = GetDeviceAddress<S>(input_ptrs, kSparseAddGradIndex0, kernel_name_, &dout_ptr);
if (flag != 0) {
return flag;
@@ -124,13 +126,19 @@ class SparseAddGradHelperGpuKernel : public GpuKernelHelperBase {
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(work_ptrs, kSparseAddGradIndex0, kernel_name_, &temp_save_ptr);
if (flag != 0) {
return flag;
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream)),
"For SparseAddGrad, cudaStreamSynchronize failed.");
// call cuda kernel
MS_LOG(INFO) << "For SparseAddGrad, x1_index_num_, x2_index_num_, sum_index_num_, dim_ " << x1_index_num_ << ", "
<< x2_index_num_ << ", " << sum_index_num_ << ", " << dim_;
CalSparseAddGrad(dout_ptr, x1_indices_ptr, x1_index_num_, x2_indices_ptr, x2_index_num_, out_indices_ptr,
sum_index_num_, dx1_ptr, dx2_ptr, dim_, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
sum_index_num_, temp_save_ptr, dx1_ptr, dx2_ptr, dim_, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream)),
"For SparseAddGrad, cudaStreamSynchronize failed.");
return 0;

View File

@@ -18,122 +18,104 @@
#include <algorithm>
template <typename T, typename S>
__global__ void SparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, size_t x2_size,
const T *out_indices, size_t out_size, S *dx1, S *dx2, size_t dim, S init_val) {
size_t stride = gridDim.x * blockDim.x;
size_t threadId = blockIdx.x * blockDim.x + threadIdx.x;
size_t x1_idx = threadId;
while (x1_idx < x1_size) {
size_t idx = x1_idx * dim;
auto x_idx = x1_indices[idx];
auto y_idx = x1_indices[idx + 1];
size_t catch_x1_i = 0;
for (size_t j = 0; j < x1_size; j++) {
auto oj = j * dim;
if (x1_indices[oj] == x_idx && x1_indices[oj + 1] == y_idx) {
if (x1_idx == j) {
break;
} else {
catch_x1_i += 1;
}
}
}
S val = init_val;
size_t same_x1_i = 0;
for (size_t i = 0; i < out_size; i++) {
auto oi = i * dim;
if (out_indices[oi] == x_idx && out_indices[oi + 1] == y_idx) {
if (same_x1_i == catch_x1_i) {
val = dout[i];
break;
} else {
same_x1_i += 1;
}
}
}
dx1[x1_idx] = val;
x1_idx += stride;
const T *out_indices, size_t out_size, T *temp_save_ptr, S *dx1, S *dx2, size_t dim,
S init_val) {
size_t stride = gridDim.x * blockDim.x;
size_t threadId = blockIdx.x * blockDim.x + threadIdx.x;
size_t x1_idx = threadId;
memset(dx1, 0, sizeof(S) * x1_size);  // dx1/dx2 hold S-typed values, so clear sizeof(S) bytes per element
memset(dx2, 0, sizeof(S) * x2_size);
while (x1_idx < x1_size) {
size_t idx = x1_idx * dim;
for (size_t i = 0; i < dim; i++) {
temp_save_ptr[i] = x1_indices[idx + i];
}
for (size_t i = 0; i < out_size; i++) {
auto oi = i * dim;
bool same_flag = true;
for (size_t j = 0; j < dim; j++) {
if (temp_save_ptr[j] != out_indices[oi + j]) {
same_flag = false;
break;
}
}
if (same_flag) {
dx1[x1_idx] = dout[i];
break;
}
}
x1_idx += stride;
}
size_t x2_idx = threadId;
while (x2_idx < x2_size) {
size_t idx = x2_idx * dim;
auto x_idx = x2_indices[idx];
auto y_idx = x2_indices[idx + 1];
size_t catch_x2_i = 0;
for (size_t j = 0; j < x2_size; j++) {
auto oj = j * dim;
if (x2_indices[oj] == x_idx && x2_indices[oj + 1] == y_idx) {
if (x2_idx == j) {
break;
} else {
catch_x2_i += 1;
}
}
}
S val = init_val;
size_t same_x2_i = 0;
for (size_t i = 0; i < out_size; i++) {
auto oi = i * dim;
if (out_indices[oi] == x_idx && out_indices[oi + 1] == y_idx) {
if (same_x2_i == catch_x2_i) {
val = dout[i];
break;
} else {
same_x2_i += 1;
}
}
}
dx2[x2_idx] = val;
x2_idx += stride;
size_t x2_idx = threadId;
while (x2_idx < x2_size) {
size_t idx = x2_idx * dim;
for (size_t i = 0; i < dim; i++) {
temp_save_ptr[i] = x2_indices[idx + i];
}
for (size_t i = 0; i < out_size; i++) {
auto oi = i * dim;
bool same_flag = true;
for (size_t j = 0; j < dim; j++) {
if (temp_save_ptr[j] != out_indices[oi + j]) {
same_flag = false;
break;
}
}
if (same_flag) {
dx2[x2_idx] = dout[i];
break;
}
}
x2_idx += stride;
}
return;
}
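Stripped of the grid-stride loop and the temp_save_ptr staging, the rewritten kernel's matching rule is a brute-force scan that compares full dim-wide rows rather than the previously hard-coded coordinate pair. A host-side C++ sketch of that rule (illustrative, not the device code):

#include <cstddef>

// For each row of x, scan out_indices for an equal dim-wide row and copy
// the corresponding dout value; unmatched rows keep a zero gradient.
template <typename IndexT, typename ValueT>
void MatchRows(const ValueT *dout, const IndexT *x, size_t x_size, const IndexT *out,
               size_t out_size, ValueT *dx, size_t dim) {
  for (size_t r = 0; r < x_size; ++r) {
    dx[r] = ValueT(0);  // default: no matching output row
    for (size_t o = 0; o < out_size; ++o) {
      bool same = true;
      for (size_t d = 0; d < dim; ++d) {
        if (x[r * dim + d] != out[o * dim + d]) { same = false; break; }
      }
      if (same) { dx[r] = dout[o]; break; }  // first full-row match wins
    }
  }
}

Note the O(x_size * out_size * dim) cost, and that the kernel stages each x row in a single shared temp_save_ptr buffer; that shared scratch row is why the launcher below drops to a one-thread launch.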
template <typename T, typename S>
void CalSparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, size_t x2_size,
const T *out_indices, size_t out_size, S *dx1, S *dx2, size_t dim,
const T *out_indices, size_t out_size, T *temp_save_ptr, S *dx1, S *dx2, size_t dim,
const uint32_t &device_id, cudaStream_t cuda_stream) {
size_t max_in_size = std::max(x1_size, x2_size);
SparseAddGrad<<<CUDA_BLOCKS(device_id, max_in_size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, out_size,
dx1, dx2, dim, S(0));
// Single-thread launch: temp_save_ptr is one shared dim-wide scratch row, so concurrent threads would race on it.
dim3 blockSize(1);
dim3 gridSize(1);
SparseAddGrad<<<gridSize, blockSize, 0, cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices,
out_size, temp_save_ptr, dx1, dx2, dim, S(0));
return;
}
template<typename T>
template <typename T>
void CalSparseAddGrad(const cuComplex *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, size_t x2_size,
const T *out_indices, size_t out_size, cuComplex *dx1, cuComplex *dx2, size_t dim,
const uint32_t &device_id, cudaStream_t cuda_stream) {
size_t max_in_size = std::max(x1_size, x2_size);
SparseAddGrad<<<CUDA_BLOCKS(device_id, max_in_size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, out_size,
dx1, dx2, dim, {0, 0});
const T *out_indices, size_t out_size, T *temp_save_ptr, cuComplex *dx1, cuComplex *dx2,
size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream) {
dim3 blockSize(1);
dim3 gridSize(1);
SparseAddGrad<<<gridSize, blockSize, 0, cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices,
out_size, temp_save_ptr, dx1, dx2, dim, {0, 0});
return;
}
template<typename T>
template <typename T>
void CalSparseAddGrad(const cuDoubleComplex *dout, const T *x1_indices, size_t x1_size, const T *x2_indices,
size_t x2_size, const T *out_indices, size_t out_size, cuDoubleComplex *dx1,
cuDoubleComplex *dx2, size_t dim, const uint32_t &device_id,
cudaStream_t cuda_stream) {
size_t max_in_size = std::max(x1_size, x2_size);
SparseAddGrad<<<CUDA_BLOCKS(device_id, max_in_size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, out_size,
dx1, dx2, dim, {0, 0});
size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr, cuDoubleComplex *dx1,
cuDoubleComplex *dx2, size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream) {
dim3 blockSize(1);
dim3 gridSize(1);
SparseAddGrad<<<gridSize, blockSize, 0, cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices,
out_size, temp_save_ptr, dx1, dx2, dim, {0, 0});
return;
}
#define GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(index_type, val_type) \
template CUDA_LIB_EXPORT void CalSparseAddGrad<index_type, val_type>(const val_type *dout, \
const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \
const index_type *out_indices, size_t out_size, val_type *dx1, val_type *dx2, size_t dim, \
const uint32_t &device_id, cudaStream_t cuda_stream);
#define GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(index_type, val_type) \
template CUDA_LIB_EXPORT void CalSparseAddGrad<index_type, val_type>( \
const val_type *dout, const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \
const index_type *out_indices, size_t out_size, index_type *temp_save_ptr, val_type *dx1, val_type *dx2, \
size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream);
#define GPU_SPARSE_ADD_GRAD_COMPLEX_EXPORT_REGISTER(index_type, val_type) \
template CUDA_LIB_EXPORT void CalSparseAddGrad<index_type>(const val_type *dout, \
const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \
const index_type *out_indices, size_t out_size, val_type *dx1, val_type *dx2, size_t dim, \
const uint32_t &device_id, cudaStream_t cuda_stream);
#define GPU_SPARSE_ADD_GRAD_COMPLEX_EXPORT_REGISTER(index_type, val_type) \
template CUDA_LIB_EXPORT void CalSparseAddGrad<index_type>( \
const val_type *dout, const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \
const index_type *out_indices, size_t out_size, index_type *temp_save_ptr, val_type *dx1, val_type *dx2, \
size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream);
GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int8_t)
GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int16_t)
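Each register-macro invocation stamps out one explicit instantiation so the templated launcher is compiled into, and exported from, the CUDA library for that (index type, value type) pair. For example, GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int8_t) above expands to roughly:

template CUDA_LIB_EXPORT void CalSparseAddGrad<int64_t, int8_t>(
    const int8_t *dout, const int64_t *x1_indices, size_t x1_size, const int64_t *x2_indices, size_t x2_size,
    const int64_t *out_indices, size_t out_size, int64_t *temp_save_ptr, int8_t *dx1, int8_t *dx2,
    size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream);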

View File

@@ -19,20 +19,19 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T, typename S>
CUDA_LIB_EXPORT void CalSparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size,
const T *x2_indices, size_t x2_size, const T *out_indices, size_t out_size,
S *dx1, S *dx2, size_t dim,
const uint32_t &device_id, cudaStream_t cuda_stream);
CUDA_LIB_EXPORT void CalSparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices,
size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr, S *dx1,
S *dx2, size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void CalSparseAddGrad(const cuComplex *dout, const T *x1_indices, size_t x1_size,
const T *x2_indices, size_t x2_size, const T *out_indices, size_t out_size,
cuComplex *dx1, cuComplex *dx2, size_t dim,
const uint32_t &device_id, cudaStream_t cuda_stream);
CUDA_LIB_EXPORT void CalSparseAddGrad(const cuComplex *dout, const T *x1_indices, size_t x1_size, const T *x2_indices,
size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr,
cuComplex *dx1, cuComplex *dx2, size_t dim, const uint32_t &device_id,
cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void CalSparseAddGrad(const cuDoubleComplex *dout, const T *x1_indices, size_t x1_size,
const T *x2_indices, size_t x2_size, const T *out_indices, size_t out_size,
cuDoubleComplex *dx1, cuDoubleComplex *dx2, size_t dim,
T *temp_save_ptr, cuDoubleComplex *dx1, cuDoubleComplex *dx2, size_t dim,
const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_ADD_GRAD_IMPL_CUH_