!37808 improve IndexAdd in some cases

Merge pull request !37808 from looop5/index_add_prof
This commit is contained in:
i-robot 2022-07-12 01:37:50 +00:00 committed by Gitee
commit 1d55a86053
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 26 additions and 5 deletions

View File

@ -137,10 +137,10 @@ bool IndexAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
for (size_t i = start; i < end; ++i) {
// calc idx_y in y.shape[axis]
const size_t y_axis_idx = (i / inner_size_) % y_axis_size_;
// calc idx_x in x.shape[axis]
const size_t x_axis_idx = static_cast<size_t>(indices[y_axis_idx]);
// only process add operation when idx_x is valid
if (x_axis_idx < x_axis_size_) {
if (indices[y_axis_idx] >= 0 && static_cast<size_t>(indices[y_axis_idx]) < x_axis_size_) {
// calc idx_x in x.shape[axis]
const size_t x_axis_idx = static_cast<size_t>(indices[y_axis_idx]);
const size_t x_outer_idx = i / y_axis_inner_size;
const size_t x_inner_idx = i % inner_size_;
const size_t x_idx = x_outer_idx * x_axis_inner_size + x_axis_idx * inner_size_ + x_inner_idx;
@ -148,8 +148,29 @@ bool IndexAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
}
}
};
auto task_block = [&](const size_t start, const size_t end) {
for (size_t i = start; i < end; ++i) {
const size_t y_outer_idx = i / y_axis_size_;
const size_t y_axis_idx = i - y_outer_idx * y_axis_size_;
if (indices[y_axis_idx] >= 0 && static_cast<size_t>(indices[y_axis_idx]) < x_axis_size_) {
const size_t x_axis_idx = static_cast<size_t>(indices[y_axis_idx]);
const size_t x_idx = y_outer_idx * x_axis_inner_size + x_axis_idx * inner_size_;
const size_t y_idx = i * inner_size_;
for (size_t j = 0; j < inner_size_; ++j) {
x[x_idx + j] += y[y_idx + j];
}
}
}
};
const float block_size = 1024;
ParallelLaunch(task, y_nums_, block_size, this);
const size_t inner_block_size = 100;
if (inner_size_ > 1 && inner_size_ <= inner_block_size) {
ParallelLaunch(task_block, y_nums_ / inner_size_, block_size / inner_size_, this);
} else {
ParallelLaunch(task, y_nums_, block_size, this);
}
return true;
}

View File

@ -304,7 +304,7 @@ def test_index_add_function():
@pytest.mark.env_onecard
def test_index_add_dynamic():
"""
Feature: test IndexAdd dynamic shape.
Feature: test IndexAdd dynamic shape with set_inputs.
Description: input y is dynamic shape.
Expectation: the result match with numpy result
"""