forked from mindspore-Ecosystem/mindspore
!15541 fix bug of scatter operators: multithread operation will cause input data update error
From: @dragon_d Reviewed-by: @wuxuejian,@liangchenghui Signed-off-by: @wuxuejian
This commit is contained in:
commit
c81ecab938
|
@ -83,50 +83,40 @@ bool ScatterArithmeticCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr>
|
|||
|
||||
template <typename T>
|
||||
void ScatterArithmeticCPUKernel<T>::ScatterAdd(T *input, const int *indices, const T *updates) {
|
||||
auto task = [this, input, indices, updates](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
for (size_t i = 0; i < indices_size_; i++) {
|
||||
auto base_index_updates = i * inner_size_;
|
||||
auto base_index_input = indices[i] * inner_size_;
|
||||
for (size_t j = 0; j < inner_size_; j++) {
|
||||
input[base_index_input + j] += updates[base_index_updates + j];
|
||||
}
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, indices_size_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ScatterArithmeticCPUKernel<T>::ScatterSub(T *input, const int *indices, const T *updates) {
|
||||
auto task = [this, input, indices, updates](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
for (size_t i = 0; i < indices_size_; i++) {
|
||||
auto base_index_updates = i * inner_size_;
|
||||
auto base_index_input = indices[i] * inner_size_;
|
||||
for (size_t j = 0; j < inner_size_; j++) {
|
||||
input[base_index_input + j] -= updates[base_index_updates + j];
|
||||
}
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, indices_size_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ScatterArithmeticCPUKernel<T>::ScatterMul(T *input, const int *indices, const T *updates) {
|
||||
auto task = [this, input, indices, updates](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
for (size_t i = 0; i < indices_size_; i++) {
|
||||
auto base_index_updates = i * inner_size_;
|
||||
auto base_index_input = indices[i] * inner_size_;
|
||||
for (size_t j = 0; j < inner_size_; j++) {
|
||||
input[base_index_input + j] *= updates[base_index_updates + j];
|
||||
}
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, indices_size_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ScatterArithmeticCPUKernel<T>::ScatterDiv(T *input, const int *indices, const T *updates) {
|
||||
auto task = [this, input, indices, updates](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
for (size_t i = 0; i < indices_size_; i++) {
|
||||
for (size_t j = 0; j < inner_size_; j++) {
|
||||
auto dividend = input[indices[i] * inner_size_ + j];
|
||||
auto divisor = updates[i * inner_size_ + j];
|
||||
|
@ -147,14 +137,11 @@ void ScatterArithmeticCPUKernel<T>::ScatterDiv(T *input, const int *indices, con
|
|||
input[indices[i] * inner_size_ + j] = dividend / divisor;
|
||||
}
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, indices_size_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ScatterArithmeticCPUKernel<T>::ScatterMax(T *input, const int *indices, const T *updates) {
|
||||
auto task = [this, input, indices, updates](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
for (size_t i = 0; i < indices_size_; i++) {
|
||||
auto base_index_updates = i * inner_size_;
|
||||
auto base_index_input = indices[i] * inner_size_;
|
||||
for (size_t j = 0; j < inner_size_; j++) {
|
||||
|
@ -163,14 +150,11 @@ void ScatterArithmeticCPUKernel<T>::ScatterMax(T *input, const int *indices, con
|
|||
: updates[base_index_updates + j];
|
||||
}
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, indices_size_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ScatterArithmeticCPUKernel<T>::ScatterMin(T *input, const int *indices, const T *updates) {
|
||||
auto task = [this, input, indices, updates](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
for (size_t i = 0; i < indices_size_; i++) {
|
||||
auto base_index_updates = i * inner_size_;
|
||||
auto base_index_input = indices[i] * inner_size_;
|
||||
for (size_t j = 0; j < inner_size_; j++) {
|
||||
|
@ -179,22 +163,17 @@ void ScatterArithmeticCPUKernel<T>::ScatterMin(T *input, const int *indices, con
|
|||
: updates[base_index_updates + j];
|
||||
}
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, indices_size_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ScatterArithmeticCPUKernel<T>::ScatterUpdate(T *input, const int *indices, const T *updates) {
|
||||
auto task = [this, input, indices, updates](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
for (size_t i = 0; i < indices_size_; i++) {
|
||||
auto base_index_updates = i * inner_size_;
|
||||
auto base_index_input = indices[i] * inner_size_;
|
||||
for (size_t j = 0; j < inner_size_; j++) {
|
||||
input[base_index_input + j] = updates[base_index_updates + j];
|
||||
}
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, indices_size_);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue