!49182 Fix SparseAddGrad bug
Merge pull request !49182 from YijieChen/ops
commit 0ba09ac448
@@ -67,6 +67,7 @@ int SparseAddGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
ResetResource();
auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
indices_column_ = inputs.at(1)->GetShapeVector()[1];
if (ret == KRET_UNKNOWN_OUT_SHAPE) {
if (input_size_list_.size() != kInputNum) {
MS_LOG(ERROR) << "Input size list should be " << kInputNum << ", but got " << input_size_list_.size();
@@ -103,6 +104,26 @@ int SparseAddGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
return ret;
}

template <typename T, typename S>
int SparseAddGradCpuKernelMod::CompareTwoIndices(const T &a_indices, const T &b_indices, const S *backprop_value,
int64_t *a_row, const int64_t b_row, const size_t dims, S *dx_value,
bool *idx_geq) {
for (int64_t dim = 0; dim < SizeToLong(dims); dim++) {
auto a_idx = a_indices[*a_row * dims + dim];
auto b_idx = b_indices[b_row * dims + dim];
if (a_idx < b_idx) {
*idx_geq = false;
*a_row += 1;
return -1;
} else if (a_idx > b_idx) {
return 1;
}
}
dx_value[*a_row] = backprop_value[b_row];
*a_row += 1;
return 0;
}

template <typename T, typename S>
bool SparseAddGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
@@ -124,37 +145,40 @@ bool SparseAddGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPt
auto dx1 = reinterpret_cast<T *>(outputs[kDx1Idx]->addr);
auto dx2 = reinterpret_cast<T *>(outputs[kDx2Idx]->addr);

const int64_t x1_indices_num = inputs[kX1IndicesIdx]->size / (sizeof(S) * 2);
const int64_t x2_indices_num = inputs[kX2IndicesIdx]->size / (sizeof(S) * 2);
const int64_t out_indices_num = inputs[kOutIndicesIdx]->size / (sizeof(S) * 2);
const int64_t x1_indices_num = inputs[kX1IndicesIdx]->size / (sizeof(S) * indices_column_);
const int64_t x2_indices_num = inputs[kX2IndicesIdx]->size / (sizeof(S) * indices_column_);
const int64_t out_indices_num = inputs[kOutIndicesIdx]->size / (sizeof(S) * indices_column_);

auto arrayHash = [fn = std::hash<int>{}](const std::array<int, 2> &arr) -> size_t {
return std::accumulate(arr.begin(), arr.end(), 0u, [&](size_t acc, int num) { return (acc << 1) ^ fn(num); });
};
memset(dx1, 0, sizeof(T) * x1_indices_num);
memset(dx2, 0, sizeof(T) * x2_indices_num);

constexpr int dimension_difference = 2;
std::unordered_map<std::array<int, 2>, int, decltype(arrayHash)> out_map(0, arrayHash);
for (int i = 0; i < out_indices_num * dimension_difference; i += dimension_difference) {
std::array<int, 2> index{};
index[0] = out_indices[i];
index[1] = out_indices[i + 1];
out_map[index] = static_cast<int>(i / dimension_difference);
}
int64_t i = 0;
int64_t j = 0;
int64_t k = 0;
bool a_idx_geq;
bool b_idx_geq;

for (int i = 0; i < x1_indices_num * dimension_difference; i += dimension_difference) {
std::array<int, 2> index{};
index[0] = x1_indices[i];
index[1] = x1_indices[i + 1];
if (out_map.find(index) != out_map.end()) {
dx1[static_cast<int>(i / dimension_difference)] = dout[out_map[index]];
while (i < x1_indices_num && j < x2_indices_num && k < out_indices_num) {
a_idx_geq = b_idx_geq = true;
CompareTwoIndices(x1_indices, out_indices, dout, &i, k, indices_column_, dx1, &a_idx_geq);
CompareTwoIndices(x2_indices, out_indices, dout, &j, k, indices_column_, dx2, &b_idx_geq);
if (a_idx_geq && b_idx_geq) {
k += 1;
}
}
for (int i = 0; i < x2_indices_num * dimension_difference; i += dimension_difference) {
std::array<int, 2> index{};
index[0] = x2_indices[i];
index[1] = x2_indices[i + 1];
if (out_map.find(index) != out_map.end()) {
dx2[static_cast<int>(i / dimension_difference)] = dout[out_map[index]];
while (i < x1_indices_num && k < out_indices_num) {
a_idx_geq = true;
CompareTwoIndices(x1_indices, out_indices, dout, &i, k, indices_column_, dx1, &a_idx_geq);
if (a_idx_geq) {
k += 1;
}
}

while (j < x2_indices_num && k < out_indices_num) {
b_idx_geq = true;
CompareTwoIndices(x2_indices, out_indices, dout, &j, k, indices_column_, dx2, &b_idx_geq);
if (b_idx_geq) {
k += 1;
}
}
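
A note on the CPU change above: the previous LaunchKernel matched x1/x2 rows against the sum's indices through an unordered_map keyed on a fixed two-component index, while the new code derives indices_column_ from the indices shape and does a three-cursor merge over the sorted index rows via CompareTwoIndices. Each call compares one input row with the current sum row k, copies dout[k] into the matching dx slot, and k advances only once neither input cursor is still behind it. The standalone sketch below (plain C++ with illustrative names and sample data, not MindSpore code) reproduces that merge for a tiny 2-D COO example:

// Standalone sketch of the sorted-merge idea (illustrative only, not MindSpore code).
// a_indices/b_indices are row-major (rows x dims) COO index matrices, assumed
// lexicographically sorted; backprop_value holds one gradient per b (sum) row.
#include <cstdint>
#include <iostream>
#include <vector>

int CompareTwoIndices(const std::vector<int64_t> &a_indices, const std::vector<int64_t> &b_indices,
                      const std::vector<float> &backprop_value, int64_t *a_row, int64_t b_row,
                      size_t dims, std::vector<float> *dx_value, bool *idx_geq) {
  for (size_t dim = 0; dim < dims; dim++) {
    auto a_idx = a_indices[*a_row * dims + dim];
    auto b_idx = b_indices[b_row * dims + dim];
    if (a_idx < b_idx) {         // a is behind the current sum row: skip it, keep k in place
      *idx_geq = false;
      *a_row += 1;
      return -1;
    } else if (a_idx > b_idx) {  // a is ahead: wait until k catches up
      return 1;
    }
  }
  (*dx_value)[*a_row] = backprop_value[b_row];  // rows match: take the gradient
  *a_row += 1;
  return 0;
}

int main() {
  const size_t dims = 2;
  std::vector<int64_t> x1 = {0, 0, 1, 2};          // rows (0,0), (1,2)
  std::vector<int64_t> x2 = {0, 0, 2, 1};          // rows (0,0), (2,1)
  std::vector<int64_t> out = {0, 0, 1, 2, 2, 1};   // sorted union of x1 and x2 rows
  std::vector<float> dout = {10.f, 20.f, 30.f};
  std::vector<float> dx1(2, 0.f), dx2(2, 0.f);
  int64_t i = 0, j = 0, k = 0;
  const int64_t x1_num = 2, x2_num = 2, out_num = 3;
  while (i < x1_num && j < x2_num && k < out_num) {
    bool a_geq = true, b_geq = true;
    CompareTwoIndices(x1, out, dout, &i, k, dims, &dx1, &a_geq);
    CompareTwoIndices(x2, out, dout, &j, k, dims, &dx2, &b_geq);
    if (a_geq && b_geq) k += 1;
  }
  while (i < x1_num && k < out_num) {   // drain the remaining x1 rows
    bool a_geq = true;
    CompareTwoIndices(x1, out, dout, &i, k, dims, &dx1, &a_geq);
    if (a_geq) k += 1;
  }
  while (j < x2_num && k < out_num) {   // drain the remaining x2 rows
    bool b_geq = true;
    CompareTwoIndices(x2, out, dout, &j, k, dims, &dx2, &b_geq);
    if (b_geq) k += 1;
  }
  std::cout << dx1[0] << " " << dx1[1] << " | " << dx2[0] << " " << dx2[1] << "\n";  // 10 20 | 10 30
}

Running the sketch prints 10 20 | 10 30: each input row receives the dout value of the matching row in the sum's indices, assuming (as the merge does) that all three index lists are lexicographically sorted.
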
@@ -51,11 +51,15 @@ class SparseAddGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelH
template <typename T, typename S>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
template <typename T, typename S>
int CompareTwoIndices(const T &a_indices, const T &b_indices, const S *backprop_value, int64_t *a_row,
const int64_t b_row, const size_t dims, S *dx_value, bool *idx_geq);

std::vector<size_t> dout_shape_;
std::vector<size_t> x1_indices_shape_;
std::vector<size_t> x2_indices_shape_;
std::vector<size_t> out_indices_shape_;
size_t indices_column_ = 0;
};
} // namespace kernel
} // namespace mindspore
@@ -80,6 +80,7 @@ class SparseAddGradHelperGpuKernel : public GpuKernelHelperBase {
input_size_list_.push_back(sum_indices_size_ * index_bytes_);
output_size_list_.push_back(x1_index_num_ * value_bytes_);
output_size_list_.push_back(x2_index_num_ * value_bytes_);
work_size_list_.push_back(dim_ * index_bytes_);
return 0;
}
@@ -95,6 +96,7 @@ class SparseAddGradHelperGpuKernel : public GpuKernelHelperBase {
T *out_indices_ptr = nullptr;
S *dx1_ptr = nullptr;
S *dx2_ptr = nullptr;
T *temp_save_ptr = nullptr;
int flag = GetDeviceAddress<S>(input_ptrs, kSparseAddGradIndex0, kernel_name_, &dout_ptr);
if (flag != 0) {
return flag;
@@ -124,13 +126,19 @@ class SparseAddGradHelperGpuKernel : public GpuKernelHelperBase {
if (flag != 0) {
return flag;
}

flag = GetDeviceAddress<T>(work_ptrs, kSparseAddGradIndex0, kernel_name_, &temp_save_ptr);
if (flag != 0) {
return flag;
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream)),
"For SparseAddGrad, cudaStreamSynchronize failed.");
// call cuda kernel
MS_LOG(INFO) << "For SparseAddGrad, x1_index_num_, x2_index_num_, sum_index_num_, dim_ " << x1_index_num_ << ", "
<< x2_index_num_ << ", " << sum_index_num_ << ", " << dim_;
CalSparseAddGrad(dout_ptr, x1_indices_ptr, x1_index_num_, x2_indices_ptr, x2_index_num_, out_indices_ptr,
sum_index_num_, dx1_ptr, dx2_ptr, dim_, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
sum_index_num_, temp_save_ptr, dx1_ptr, dx2_ptr, dim_, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream)),
"For SparseAddGrad, cudaStreamSynchronize failed.");
return 0;
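
A note on the helper change above: the extra work_size_list_ entry reserves a device scratch buffer of one index row (dim_ components of index_bytes_ each), which Process later fetches from work_ptrs as temp_save_ptr and forwards to CalSparseAddGrad. A minimal sizing sketch with illustrative values (the real dim_ and index_bytes_ come from the operator's index shape and dtype):

// Illustrative only: how the new workspace entry is sized for the scratch index row.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  size_t dim = 2;                               // number of index columns (e.g. 2-D COO)
  size_t index_bytes = sizeof(int64_t);         // bytes per index component
  std::vector<size_t> work_size_list;
  work_size_list.push_back(dim * index_bytes);  // mirrors work_size_list_.push_back(dim_ * index_bytes_)
  std::cout << "scratch row bytes: " << work_size_list[0] << "\n";  // 16 for a 2-D int64 index
  // The framework allocates one device buffer per registered entry; the kernel later
  // stages a single x1/x2 index row in it while comparing against out_indices.
}
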
@@ -18,122 +18,104 @@
#include <algorithm>
template <typename T, typename S>
__global__ void SparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, size_t x2_size,
const T *out_indices, size_t out_size, S *dx1, S *dx2, size_t dim, S init_val) {
size_t stride = gridDim.x * blockDim.x;
size_t threadId = blockIdx.x * blockDim.x + threadIdx.x;
size_t x1_idx = threadId;
while (x1_idx < x1_size) {
size_t idx = x1_idx * dim;
auto x_idx = x1_indices[idx];
auto y_idx = x1_indices[idx + 1];
size_t catch_x1_i = 0;
for (size_t j = 0; j < x1_size; j++) {
auto oj = j * dim;
if (x1_indices[oj] == x_idx && x1_indices[oj + 1] == y_idx) {
if (x1_idx == j) {
break;
} else {
catch_x1_i += 1;
}
}
}
S val = init_val;
size_t same_x1_i = 0;
for (size_t i = 0; i < out_size; i++) {
auto oi = i * dim;
if (out_indices[oi] == x_idx && out_indices[oi + 1] == y_idx) {
if (same_x1_i == catch_x1_i) {
val = dout[i];
break;
} else {
same_x1_i += 1;
}
}
}
dx1[x1_idx] = val;
x1_idx += stride;
const T *out_indices, size_t out_size, T *temp_save_ptr, S *dx1, S *dx2, size_t dim,
S init_val) {
size_t stride = gridDim.x * blockDim.x;
size_t threadId = blockIdx.x * blockDim.x + threadIdx.x;
size_t x1_idx = threadId;
memset(dx1, 0, sizeof(T) * x1_size);
memset(dx2, 0, sizeof(T) * x2_size);
while (x1_idx < x1_size) {
size_t idx = x1_idx * dim;
for (size_t i = 0; i < dim; i++) {
temp_save_ptr[i] = x1_indices[idx + i];
}
for (size_t i = 0; i < out_size; i++) {
auto oi = i * dim;
bool same_flag = true;
for (size_t j = 0; j < dim; j++) {
if (temp_save_ptr[j] != out_indices[oi + j]) {
same_flag = false;
break;
}
}
if (same_flag) {
dx1[x1_idx] = dout[i];
break;
}
}
x1_idx += stride;
}

size_t x2_idx = threadId;
while (x2_idx < x2_size) {
size_t idx = x2_idx * dim;
auto x_idx = x2_indices[idx];
auto y_idx = x2_indices[idx + 1];
size_t catch_x2_i = 0;
for (size_t j = 0; j < x2_size; j++) {
auto oj = j * dim;
if (x2_indices[oj] == x_idx && x2_indices[oj + 1] == y_idx) {
if (x2_idx == j) {
break;
} else {
catch_x2_i += 1;
}
}
}
S val = init_val;
size_t same_x2_i = 0;
for (size_t i = 0; i < out_size; i++) {
auto oi = i * dim;
if (out_indices[oi] == x_idx && out_indices[oi + 1] == y_idx) {
if (same_x2_i == catch_x2_i) {
val = dout[i];
break;
} else {
same_x2_i += 1;
}
}
}
dx2[x2_idx] = val;
x2_idx += stride;
size_t x2_idx = threadId;
while (x2_idx < x2_size) {
size_t idx = x2_idx * dim;
for (size_t i = 0; i < dim; i++) {
temp_save_ptr[i] = x2_indices[idx + i];
}
for (size_t i = 0; i < out_size; i++) {
auto oi = i * dim;
bool same_flag = true;
for (size_t j = 0; j < dim; j++) {
if (temp_save_ptr[j] != out_indices[oi + j]) {
same_flag = false;
break;
}
}
if (same_flag) {
dx2[x2_idx] = dout[i];
break;
}
}
x2_idx += stride;
}
return;
}

template <typename T, typename S>
void CalSparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, size_t x2_size,
const T *out_indices, size_t out_size, S *dx1, S *dx2, size_t dim,
const T *out_indices, size_t out_size, T *temp_save_ptr, S *dx1, S *dx2, size_t dim,
const uint32_t &device_id, cudaStream_t cuda_stream) {
size_t max_in_size = std::max(x1_size, x2_size);
SparseAddGrad<<<CUDA_BLOCKS(device_id, max_in_size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, out_size,
dx1, dx2, dim, S(0));
dim3 blockSize(1);
dim3 gridSize(1);
SparseAddGrad<<<gridSize, blockSize, 0, cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices,
out_size, temp_save_ptr, dx1, dx2, dim, S(0));
return;
}

template<typename T>
template <typename T>
void CalSparseAddGrad(const cuComplex *dout, const T *x1_indices, size_t x1_size, const T *x2_indices, size_t x2_size,
const T *out_indices, size_t out_size, cuComplex *dx1, cuComplex *dx2, size_t dim,
const uint32_t &device_id, cudaStream_t cuda_stream) {
size_t max_in_size = std::max(x1_size, x2_size);
SparseAddGrad<<<CUDA_BLOCKS(device_id, max_in_size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, out_size,
dx1, dx2, dim, {0, 0});
const T *out_indices, size_t out_size, T *temp_save_ptr, cuComplex *dx1, cuComplex *dx2,
size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream) {
dim3 blockSize(1);
dim3 gridSize(1);
SparseAddGrad<<<gridSize, blockSize, 0, cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices,
out_size, temp_save_ptr, dx1, dx2, dim, {0, 0});
return;
}

template<typename T>
template <typename T>
void CalSparseAddGrad(const cuDoubleComplex *dout, const T *x1_indices, size_t x1_size, const T *x2_indices,
size_t x2_size, const T *out_indices, size_t out_size, cuDoubleComplex *dx1,
cuDoubleComplex *dx2, size_t dim, const uint32_t &device_id,
cudaStream_t cuda_stream) {
size_t max_in_size = std::max(x1_size, x2_size);
SparseAddGrad<<<CUDA_BLOCKS(device_id, max_in_size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices, out_size,
dx1, dx2, dim, {0, 0});
size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr, cuDoubleComplex *dx1,
cuDoubleComplex *dx2, size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream) {
dim3 blockSize(1);
dim3 gridSize(1);
SparseAddGrad<<<gridSize, blockSize, 0, cuda_stream>>>(dout, x1_indices, x1_size, x2_indices, x2_size, out_indices,
out_size, temp_save_ptr, dx1, dx2, dim, {0, 0});
return;
}

#define GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(index_type, val_type) \
template CUDA_LIB_EXPORT void CalSparseAddGrad<index_type, val_type>(const val_type *dout, \
const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \
const index_type *out_indices, size_t out_size, val_type *dx1, val_type *dx2, size_t dim, \
const uint32_t &device_id, cudaStream_t cuda_stream);
#define GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(index_type, val_type) \
template CUDA_LIB_EXPORT void CalSparseAddGrad<index_type, val_type>( \
const val_type *dout, const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \
const index_type *out_indices, size_t out_size, index_type *temp_save_ptr, val_type *dx1, val_type *dx2, \
size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream);

#define GPU_SPARSE_ADD_GRAD_COMPLEX_EXPORT_REGISTER(index_type, val_type) \
template CUDA_LIB_EXPORT void CalSparseAddGrad<index_type>(const val_type *dout, \
const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \
const index_type *out_indices, size_t out_size, val_type *dx1, val_type *dx2, size_t dim, \
const uint32_t &device_id, cudaStream_t cuda_stream);
#define GPU_SPARSE_ADD_GRAD_COMPLEX_EXPORT_REGISTER(index_type, val_type) \
template CUDA_LIB_EXPORT void CalSparseAddGrad<index_type>( \
const val_type *dout, const index_type *x1_indices, size_t x1_size, const index_type *x2_indices, size_t x2_size, \
const index_type *out_indices, size_t out_size, index_type *temp_save_ptr, val_type *dx1, val_type *dx2, \
size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream);

GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int8_t)
GPU_SPARSE_ADD_GRAD_EXPORT_REGISTER(int64_t, int16_t)
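
A note on the kernel rewrite above: instead of comparing the hard-coded components idx and idx + 1, the kernel now stages one x1/x2 index row in temp_save_ptr and scans out_indices for a row whose dim components all match, taking the corresponding dout value; the launch in CalSparseAddGrad also drops the grid-stride configuration for gridSize(1)/blockSize(1), presumably because the single shared scratch row cannot be written by many threads at once. The host-side C++ below (illustrative only, not the device code in this commit) mirrors the per-row matching so the generalisation to arbitrary dim is easy to follow:

// Host-side mirror of the per-row matching (illustrative only, not the commit's device code).
#include <cstdint>
#include <iostream>
#include <vector>

void MatchRows(const std::vector<float> &dout, const std::vector<int64_t> &x_indices, size_t x_size,
               const std::vector<int64_t> &out_indices, size_t out_size, size_t dim,
               std::vector<float> *dx) {
  std::vector<int64_t> temp_save(dim);             // plays the role of temp_save_ptr
  for (size_t x_idx = 0; x_idx < x_size; ++x_idx) {
    for (size_t i = 0; i < dim; ++i) {             // stage the current x row
      temp_save[i] = x_indices[x_idx * dim + i];
    }
    (*dx)[x_idx] = 0.f;                            // default when no out row matches
    for (size_t i = 0; i < out_size; ++i) {        // linear scan over the sum's indices
      bool same = true;
      for (size_t j = 0; j < dim; ++j) {
        if (temp_save[j] != out_indices[i * dim + j]) {
          same = false;
          break;
        }
      }
      if (same) {
        (*dx)[x_idx] = dout[i];
        break;
      }
    }
  }
}

int main() {
  std::vector<int64_t> x1 = {0, 0, 1, 2};          // rows (0,0), (1,2)
  std::vector<int64_t> out = {0, 0, 1, 2, 2, 1};   // sum indices
  std::vector<float> dout = {10.f, 20.f, 30.f};
  std::vector<float> dx1(2);
  MatchRows(dout, x1, 2, out, 3, 2, &dx1);
  std::cout << dx1[0] << " " << dx1[1] << "\n";    // 10 20
}

A per-thread scratch row, or comparing straight from x_indices without staging, would let the grid-stride launch over max(x1_size, x2_size) return; that is a possible follow-up rather than something this commit does.
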
@@ -19,20 +19,19 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

template <typename T, typename S>
CUDA_LIB_EXPORT void CalSparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size,
const T *x2_indices, size_t x2_size, const T *out_indices, size_t out_size,
S *dx1, S *dx2, size_t dim,
const uint32_t &device_id, cudaStream_t cuda_stream);
CUDA_LIB_EXPORT void CalSparseAddGrad(const S *dout, const T *x1_indices, size_t x1_size, const T *x2_indices,
size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr, S *dx1,
S *dx2, size_t dim, const uint32_t &device_id, cudaStream_t cuda_stream);

template <typename T>
CUDA_LIB_EXPORT void CalSparseAddGrad(const cuComplex *dout, const T *x1_indices, size_t x1_size,
const T *x2_indices, size_t x2_size, const T *out_indices, size_t out_size,
cuComplex *dx1, cuComplex *dx2, size_t dim,
const uint32_t &device_id, cudaStream_t cuda_stream);
CUDA_LIB_EXPORT void CalSparseAddGrad(const cuComplex *dout, const T *x1_indices, size_t x1_size, const T *x2_indices,
size_t x2_size, const T *out_indices, size_t out_size, T *temp_save_ptr,
cuComplex *dx1, cuComplex *dx2, size_t dim, const uint32_t &device_id,
cudaStream_t cuda_stream);

template <typename T>
CUDA_LIB_EXPORT void CalSparseAddGrad(const cuDoubleComplex *dout, const T *x1_indices, size_t x1_size,
const T *x2_indices, size_t x2_size, const T *out_indices, size_t out_size,
cuDoubleComplex *dx1, cuDoubleComplex *dx2, size_t dim,
T *temp_save_ptr, cuDoubleComplex *dx1, cuDoubleComplex *dx2, size_t dim,
const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_ADD_GRAD_IMPL_CUH_