!37808 improve IndexAdd in some cases
Merge pull request !37808 from looop5/index_add_prof
This commit is contained in:
commit
1d55a86053
|
@ -137,10 +137,10 @@ bool IndexAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
|
|||
for (size_t i = start; i < end; ++i) {
|
||||
// calc idx_y in y.shape[axis]
|
||||
const size_t y_axis_idx = (i / inner_size_) % y_axis_size_;
|
||||
// calc idx_x in x.shape[axis]
|
||||
const size_t x_axis_idx = static_cast<size_t>(indices[y_axis_idx]);
|
||||
// only process add operation when idx_x is valid
|
||||
if (x_axis_idx < x_axis_size_) {
|
||||
if (indices[y_axis_idx] >= 0 && static_cast<size_t>(indices[y_axis_idx]) < x_axis_size_) {
|
||||
// calc idx_x in x.shape[axis]
|
||||
const size_t x_axis_idx = static_cast<size_t>(indices[y_axis_idx]);
|
||||
const size_t x_outer_idx = i / y_axis_inner_size;
|
||||
const size_t x_inner_idx = i % inner_size_;
|
||||
const size_t x_idx = x_outer_idx * x_axis_inner_size + x_axis_idx * inner_size_ + x_inner_idx;
|
||||
|
@ -148,8 +148,29 @@ bool IndexAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
|
|||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Block-wise variant of the add task: each work item i covers one run of
// inner_size_ consecutive elements of y (starting at i * inner_size_) and adds
// it onto the matching run of x, instead of handling one element per item.
// [start, end) is the half-open range of work items handled by this call.
auto task_block = [&](const size_t start, const size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
// split i into (outer index, position along y's axis)
const size_t y_outer_idx = i / y_axis_size_;
|
||||
// same value as i % y_axis_size_, computed from the quotient above
const size_t y_axis_idx = i - y_outer_idx * y_axis_size_;
|
||||
// entries of indices that are negative or >= x_axis_size_ are skipped silently
if (indices[y_axis_idx] >= 0 && static_cast<size_t>(indices[y_axis_idx]) < x_axis_size_) {
|
||||
// position along x's axis that this run of y is added to
const size_t x_axis_idx = static_cast<size_t>(indices[y_axis_idx]);
|
||||
// offset of the first destination element in x
const size_t x_idx = y_outer_idx * x_axis_inner_size + x_axis_idx * inner_size_;
|
||||
// offset of the first source element in y
const size_t y_idx = i * inner_size_;
|
||||
// NOTE(review): if indices holds duplicate values, different i write the
// same x range; whether such items can run concurrently is not visible
// here -- confirm against the launcher.
for (size_t j = 0; j < inner_size_; ++j) {
|
||||
x[x_idx + j] += y[y_idx + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const float block_size = 1024;
|
||||
ParallelLaunch(task, y_nums_, block_size, this);
|
||||
const size_t inner_block_size = 100;
|
||||
if (inner_size_ > 1 && inner_size_ <= inner_block_size) {
|
||||
ParallelLaunch(task_block, y_nums_ / inner_size_, block_size / inner_size_, this);
|
||||
} else {
|
||||
ParallelLaunch(task, y_nums_, block_size, this);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -304,7 +304,7 @@ def test_index_add_function():
|
|||
@pytest.mark.env_onecard
|
||||
def test_index_add_dynamic():
|
||||
"""
|
||||
Feature: test IndexAdd dynamic shape.
|
||||
Feature: test IndexAdd dynamic shape with set_inputs.
|
||||
Description: input y is dynamic shape.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue