fix SparseMatrixAdd

VectorSL 2022-07-19 19:10:12 +08:00
parent c9937a553f
commit a5222765a7
1 changed file with 19 additions and 8 deletions


@@ -112,6 +112,7 @@ bool SparseMatrixAddCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
<< outputs.size() << " output(s).";
}
const auto a_batch_size = inputs[kABatchPtrIdx]->size / sizeof(T);
const auto a_dense_shape = reinterpret_cast<T *>(inputs[kADenseShapeIdx]->addr);
const auto a_indptr = reinterpret_cast<T *>(inputs[kAIndptrIdx]->addr);
const auto a_indices = reinterpret_cast<T *>(inputs[kAIndicesIdx]->addr);
const auto a_values = reinterpret_cast<S *>(inputs[kAValuesIdx]->addr);
@@ -120,22 +121,27 @@ bool SparseMatrixAddCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
const auto b_values = reinterpret_cast<S *>(inputs[kBValuesIdx]->addr);
const auto alpha = reinterpret_cast<S *>(inputs[kAlphaIdx]->addr);
const auto beta = reinterpret_cast<S *>(inputs[kBetaIdx]->addr);
auto c_indptr = reinterpret_cast<T *>(outputs[kAIndptrIdx]->addr);
auto c_indices = reinterpret_cast<T *>(outputs[kAIndicesIdx]->addr);
auto c_values = reinterpret_cast<S *>(outputs[kAValuesIdx]->addr);
auto batch = static_cast<size_t>(a_batch_size > 1 ? (a_batch_size - 1) : 1);
auto c_indptr = reinterpret_cast<T *>(outputs[kOutIndptr]->addr);
auto c_indices = reinterpret_cast<T *>(outputs[kOutIndices]->addr);
auto c_values = reinterpret_cast<S *>(outputs[kOutValue]->addr);
auto c_dense_shape = reinterpret_cast<T *>(outputs[kOutDenseShape]->addr);
auto c_batch = reinterpret_cast<T *>(outputs[kOutBatch]->addr);
// Consider the dense shapes of the input and output to be the same.
(void)memcpy(c_dense_shape, a_dense_shape, outputs[kOutDenseShape]->size);
auto batch_size = static_cast<size_t>(a_batch_size > 1 ? (a_batch_size - 1) : 1);
c_batch[0] = 0;
// Do the compute: C = alpha * A + beta * B.
c_indptr[0] = 0;
std::set<T> index_set;
size_t c_idx = 0;
S a_v = 0;
S b_v = 0;
size_t tmp_batch = 0;
size_t a_v_idx = 0;
size_t b_v_idx = 0;
for (size_t s = 0; s < batch; s++) { // loop for all batches
for (size_t s = 0; s < batch_size; s++) { // loop for all batches
auto task = [this, &a_indptr, &a_indices, &a_values, &b_indptr, &b_indices, &b_values, &alpha, &beta, &c_indptr,
&c_indices, &c_values, &index_set, &c_idx, &a_v, &b_v, &a_v_idx, &b_v_idx,
&c_indices, &c_values, &index_set, &c_idx, &a_v, &b_v, &a_v_idx, &b_v_idx, &tmp_batch,
&s](size_t start, size_t end) {
for (size_t x = start; x < end; x++) { // one batch
auto i = x + s;
@@ -169,16 +175,21 @@ bool SparseMatrixAddCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
b_v = 0; // Reset the tmp value, real number or zero.
a_v = 0;
}
tmp_batch += index_set.size();
index_set.clear();
}
};
ParallelLaunchAutoSearch(task, row_, this, &parallel_search_info_);
if (s < batch_size - 1) {
c_batch[s + 1] = tmp_batch + c_batch[s];
}
tmp_batch = 0;
}
// Update output shape and type
std::vector<int64_t> out_indptr_shape;
std::vector<int64_t> out_indices_shape;
std::vector<int64_t> out_values_shape;
(void)out_indptr_shape.emplace_back(SizeToLong(batch * (row_ + 1)));
(void)out_indptr_shape.emplace_back(SizeToLong(batch_size * (row_ + 1)));
(void)out_indices_shape.emplace_back(SizeToLong(c_idx));
(void)out_values_shape.emplace_back(SizeToLong(c_idx));
outputs_[kOutIndptr]->SetShapeVector(out_indptr_shape);
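
For reference, the new c_batch output written by this change follows the usual batched-CSR convention: c_batch[0] is 0 and each later entry is the running count of nonzeros produced for the preceding batches of C = alpha * A + beta * B. The standalone sketch below shows that same accumulation in isolation; BuildBatchPointer is a hypothetical helper for illustration, not code from this kernel, and it assumes the per-batch nonzero counts (the kernel's tmp_batch values) are already known.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: given the number of nonzeros produced for each batch of
// the result C, build C's batch pointer with one leading zero followed by the
// cumulative nonzero count after each batch.
std::vector<int64_t> BuildBatchPointer(const std::vector<size_t> &nnz_per_batch) {
  std::vector<int64_t> batch_ptr(nnz_per_batch.size() + 1, 0);
  for (size_t s = 0; s < nnz_per_batch.size(); ++s) {
    // Same recurrence as in the kernel: batch_ptr[s + 1] = nnz of batch s + batch_ptr[s].
    batch_ptr[s + 1] = static_cast<int64_t>(nnz_per_batch[s]) + batch_ptr[s];
  }
  return batch_ptr;
}

int main() {
  // Two batches producing 5 and 3 nonzeros respectively -> batch pointer {0, 5, 8}.
  std::vector<size_t> nnz_per_batch{5, 3};
  auto batch_ptr = BuildBatchPointer(nnz_per_batch);
  return batch_ptr.back() == 8 ? 0 : 1;
}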