!40218 deal with duplicate index

Merge pull request !40218 from looop5/index_add_acc
This commit is contained in:
i-robot 2022-08-11 06:26:39 +00:00 committed by Gitee
commit 231e0eb928
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 33 additions and 4 deletions

View File

@ -20,16 +20,35 @@
#include <memory>
#include <utility>
#include <map>
#include <unordered_set>
#include "mindspore/core/ops/index_add.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "include/common/thread_pool.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kIndexAddInputsNum = 3;
constexpr size_t kIndexAddOutputsNum = 1;
// Reports whether any value occurs more than once among the first `len` entries of `indices`.
//
// indices: pointer to the index buffer; validated with MS_EXCEPTION_IF_NULL before it is read.
// len:     number of entries to examine.
// Returns true as soon as a repeated value is encountered, false if every examined value is distinct
// (including the case len == 0).
bool HasDuplicateIndex(const int32_t *indices, size_t len) {
  MS_EXCEPTION_IF_NULL(indices);
  std::unordered_set<int32_t> seen;
  for (size_t i = 0; i < len; ++i) {
    // insert() reports through `.second` whether the value was newly added, so one hash lookup
    // per element is enough; a separate find() before insert() would hash every value twice.
    if (!seen.insert(indices[i]).second) {
      return true;
    }
  }
  return false;
}
// Computes how many blocks each worker should handle so that `total_block` blocks are spread over
// the kernel threads reported by the actor manager's inner thread pool.
//
// total_block: total number of blocks to distribute.
// Returns total_block divided by the thread count, rounded up. A reported thread count of 0 is
// treated as 1 so the division is always well defined.
size_t CalcSizePerThread(size_t total_block) {
  const size_t reported_threads = GetActorMgrInnerThreadPool()->GetKernelThreadNum();
  const size_t workers = (reported_threads == 0) ? 1 : reported_threads;
  // Ceiling division: adding (workers - 1) before dividing rounds any remainder up.
  return (total_block + workers - 1) / workers;
}
} // namespace
bool IndexAddCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@ -100,6 +119,7 @@ void IndexAddCpuKernelMod::CheckParams() {
x_nums_ = 1;
y_nums_ = 1;
inner_size_ = 1;
outer_size_ = 1;
for (size_t i = 0; i < x_shape_.size(); ++i) {
if (x_shape_[i] <= 0 || y_shape_[i] <= 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'x' shape[" << i << "] or 'y' shape [" << i
@ -113,7 +133,9 @@ void IndexAddCpuKernelMod::CheckParams() {
}
x_nums_ *= LongToSize(x_shape_[i]);
y_nums_ *= LongToSize(y_shape_[i]);
if (i > axis) {
if (i < axis) {
outer_size_ *= LongToSize(x_shape_[i]);
} else if (i > axis) {
inner_size_ *= LongToSize(x_shape_[i]);
}
}
@ -165,9 +187,15 @@ bool IndexAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
}
};
auto heavy_task_block = [this, task_block](const size_t start, const size_t end) {
task_block(start * y_axis_size_, end * y_axis_size_);
};
const float block_size = 1024;
const size_t inner_block_size = 100;
if (inner_size_ > 1 && inner_size_ <= inner_block_size) {
if (HasDuplicateIndex(indices, y_axis_size_)) {
ParallelLaunch(heavy_task_block, outer_size_, CalcSizePerThread(outer_size_), this);
} else if (inner_size_ > 1 && inner_size_ <= inner_block_size) {
ParallelLaunch(task_block, y_nums_ / inner_size_, block_size / inner_size_, this);
} else {
ParallelLaunch(task, y_nums_, block_size, this);

View File

@ -63,6 +63,7 @@ class IndexAddCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper
size_t x_nums_{1};
size_t y_nums_{1};
size_t inner_size_{1};
size_t outer_size_{1};
size_t x_axis_size_{1};
size_t y_axis_size_{1};
};

View File

@ -148,7 +148,7 @@ class PadNet(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [np.bool_, np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32,