forked from mindspore-Ecosystem/mindspore
!27537 Fixes core dump issue of ScatterAdd and ScatterSub
Merge pull request !27537 from huangbo/master_1209
This commit is contained in:
commit
b16e6bed20
|
@ -52,6 +52,7 @@ void ScatterArithmeticCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
<< "', the dimension of 'input_x' should be greater than or equal to 1, but got "
|
||||
<< input_shape.size() << ".";
|
||||
}
|
||||
input_shape_0 = SizeToInt(input_shape[0]);
|
||||
input_size_ = 1;
|
||||
inner_size_ = 1;
|
||||
if (input_shape.empty()) {
|
||||
|
@ -92,6 +93,9 @@ bool ScatterArithmeticCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr>
|
|||
template <typename T>
|
||||
void ScatterArithmeticCPUKernel<T>::ScatterAdd(T *input, const int *indices, const T *updates) const {
|
||||
for (size_t i = 0; i < indices_size_; i++) {
|
||||
if (indices[i] >= input_shape_0) {
|
||||
continue;
|
||||
}
|
||||
auto base_index_updates = i * inner_size_;
|
||||
auto base_index_input = indices[i] * inner_size_;
|
||||
for (size_t j = 0; j < inner_size_; j++) {
|
||||
|
@ -103,6 +107,9 @@ void ScatterArithmeticCPUKernel<T>::ScatterAdd(T *input, const int *indices, con
|
|||
template <typename T>
|
||||
void ScatterArithmeticCPUKernel<T>::ScatterSub(T *input, const int *indices, const T *updates) const {
|
||||
for (size_t i = 0; i < indices_size_; i++) {
|
||||
if (indices[i] >= input_shape_0) {
|
||||
continue;
|
||||
}
|
||||
auto base_index_updates = i * inner_size_;
|
||||
auto base_index_input = indices[i] * inner_size_;
|
||||
for (size_t j = 0; j < inner_size_; j++) {
|
||||
|
|
|
@ -49,6 +49,7 @@ class ScatterArithmeticCPUKernel : public CPUKernel {
|
|||
using TypeComputeFunc = std::function<void(ScatterArithmeticCPUKernel *, T *, const int *, const T *)>;
|
||||
|
||||
TypeComputeFunc compute_func_;
|
||||
int input_shape_0{0};
|
||||
size_t input_size_{0};
|
||||
size_t inner_size_{0};
|
||||
size_t indices_size_{0};
|
||||
|
|
Loading…
Reference in New Issue