fix sparsematrixadd
This commit is contained in:
parent
c9937a553f
commit
a5222765a7
|
@ -112,6 +112,7 @@ bool SparseMatrixAddCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
|
|||
<< outputs.size() << " output(s).";
|
||||
}
|
||||
const auto a_batch_size = inputs[kABatchPtrIdx]->size / sizeof(T);
|
||||
const auto a_dense_shape = reinterpret_cast<T *>(inputs[kADenseShapeIdx]->addr);
|
||||
const auto a_indptr = reinterpret_cast<T *>(inputs[kAIndptrIdx]->addr);
|
||||
const auto a_indices = reinterpret_cast<T *>(inputs[kAIndicesIdx]->addr);
|
||||
const auto a_values = reinterpret_cast<S *>(inputs[kAValuesIdx]->addr);
|
||||
|
@ -120,22 +121,27 @@ bool SparseMatrixAddCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
|
|||
const auto b_values = reinterpret_cast<S *>(inputs[kBValuesIdx]->addr);
|
||||
const auto alpha = reinterpret_cast<S *>(inputs[kAlphaIdx]->addr);
|
||||
const auto beta = reinterpret_cast<S *>(inputs[kBetaIdx]->addr);
|
||||
auto c_indptr = reinterpret_cast<T *>(outputs[kAIndptrIdx]->addr);
|
||||
auto c_indices = reinterpret_cast<T *>(outputs[kAIndicesIdx]->addr);
|
||||
auto c_values = reinterpret_cast<S *>(outputs[kAValuesIdx]->addr);
|
||||
|
||||
auto batch = static_cast<size_t>(a_batch_size > 1 ? (a_batch_size - 1) : 1);
|
||||
auto c_indptr = reinterpret_cast<T *>(outputs[kOutIndptr]->addr);
|
||||
auto c_indices = reinterpret_cast<T *>(outputs[kOutIndices]->addr);
|
||||
auto c_values = reinterpret_cast<S *>(outputs[kOutValue]->addr);
|
||||
auto c_dense_shape = reinterpret_cast<T *>(outputs[kOutDenseShape]->addr);
|
||||
auto c_batch = reinterpret_cast<T *>(outputs[kOutBatch]->addr);
|
||||
// Assume the dense shapes of the input and the output are the same.
|
||||
(void)memcpy(c_dense_shape, a_dense_shape, outputs[kOutDenseShape]->size);
|
||||
auto batch_size = static_cast<size_t>(a_batch_size > 1 ? (a_batch_size - 1) : 1);
|
||||
c_batch[0] = 0;
|
||||
// Do the compute: C = alpha * A + beta * B.
|
||||
c_indptr[0] = 0;
|
||||
std::set<T> index_set;
|
||||
size_t c_idx = 0;
|
||||
S a_v = 0;
|
||||
S b_v = 0;
|
||||
size_t tmp_batch = 0;
|
||||
size_t a_v_idx = 0;
|
||||
size_t b_v_idx = 0;
|
||||
for (size_t s = 0; s < batch; s++) { // loop for all batches
|
||||
for (size_t s = 0; s < batch_size; s++) { // loop for all batches
|
||||
auto task = [this, &a_indptr, &a_indices, &a_values, &b_indptr, &b_indices, &b_values, &alpha, &beta, &c_indptr,
|
||||
&c_indices, &c_values, &index_set, &c_idx, &a_v, &b_v, &a_v_idx, &b_v_idx,
|
||||
&c_indices, &c_values, &index_set, &c_idx, &a_v, &b_v, &a_v_idx, &b_v_idx, &tmp_batch,
|
||||
&s](size_t start, size_t end) {
|
||||
for (size_t x = start; x < end; x++) { // one batch
|
||||
auto i = x + s;
|
||||
|
@ -169,16 +175,21 @@ bool SparseMatrixAddCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
|
|||
b_v = 0; // Reset the tmp value, real number or zero.
|
||||
a_v = 0;
|
||||
}
|
||||
tmp_batch += index_set.size();
|
||||
index_set.clear();
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, row_, this, &parallel_search_info_);
|
||||
if (s < batch_size - 1) {
|
||||
c_batch[s + 1] = tmp_batch + c_batch[s];
|
||||
}
|
||||
tmp_batch = 0;
|
||||
}
|
||||
// Update output shape and type
|
||||
std::vector<int64_t> out_indptr_shape;
|
||||
std::vector<int64_t> out_indices_shape;
|
||||
std::vector<int64_t> out_values_shape;
|
||||
(void)out_indptr_shape.emplace_back(SizeToLong(batch * (row_ + 1)));
|
||||
(void)out_indptr_shape.emplace_back(SizeToLong(batch_size * (row_ + 1)));
|
||||
(void)out_indices_shape.emplace_back(SizeToLong(c_idx));
|
||||
(void)out_values_shape.emplace_back(SizeToLong(c_idx));
|
||||
outputs_[kOutIndptr]->SetShapeVector(out_indptr_shape);
|
||||
|
|
Loading…
Reference in New Issue