forked from mindspore-Ecosystem/mindspore
!29216 [MSLITE][DEVELOP] fix bug of scatter_nd op: indices tensor support int64
Merge pull request !29216 from yangruoqi713/scatter
This commit is contained in:
commit
6474426e59
|
@ -46,8 +46,6 @@ int ScatterNdUpdateCPUKernel::ReSize() {
|
|||
auto input = in_tensors_.at(kScatterUpdateInputIndex);
|
||||
auto indices = in_tensors_.at(kScatterIndicesIndex);
|
||||
auto update = in_tensors_.at(kScatterUpdateIndex);
|
||||
auto output = out_tensors_.front();
|
||||
output_ptr_ = output->data();
|
||||
|
||||
// check indices shape
|
||||
int input_rank = static_cast<int>(input->shape().size());
|
||||
|
@ -87,15 +85,32 @@ int ScatterNdUpdateCPUKernel::ReSize() {
|
|||
param_->num_unit *= update_shape.at(i);
|
||||
}
|
||||
|
||||
int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
|
||||
MS_ASSERT(indices_ptr != nullptr);
|
||||
auto indices_ptr = indices->data();
|
||||
if (indices_ptr == nullptr) {
|
||||
return RET_OK;
|
||||
}
|
||||
output_unit_offsets_.clear();
|
||||
for (int i = 0; i < param_->num_unit; i++) {
|
||||
int tmp_stride = 0;
|
||||
for (int j = 0; j < indice_unit_rank; j++) {
|
||||
tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size;
|
||||
if (indices->data_type() == kNumberTypeInt || indices->data_type() == kNumberTypeInt32) {
|
||||
auto indices_data = reinterpret_cast<int *>(indices_ptr);
|
||||
for (int i = 0; i < param_->num_unit; i++) {
|
||||
int tmp_stride = 0;
|
||||
for (int j = 0; j < indice_unit_rank; j++) {
|
||||
tmp_stride += indices_data[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size;
|
||||
}
|
||||
output_unit_offsets_.push_back(tmp_stride);
|
||||
}
|
||||
output_unit_offsets_.push_back(tmp_stride);
|
||||
} else if (indices->data_type() == kNumberTypeInt64) {
|
||||
auto indices_data = reinterpret_cast<int64_t *>(indices_ptr);
|
||||
for (int i = 0; i < param_->num_unit; i++) {
|
||||
int tmp_stride = 0;
|
||||
for (int j = 0; j < indice_unit_rank; j++) {
|
||||
tmp_stride += indices_data[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size;
|
||||
}
|
||||
output_unit_offsets_.push_back(tmp_stride);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data type for indices tensor.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -126,7 +141,7 @@ int ScatterNdUpdateCPUKernel::Run() {
|
|||
auto in_tensor = in_tensors().front();
|
||||
auto out_tensor = out_tensors().front();
|
||||
if (in_tensor->allocator() == nullptr || in_tensor->allocator() != out_tensor->allocator() ||
|
||||
op_parameter_->is_train_session_) {
|
||||
in_tensor->own_data() == false || op_parameter_->is_train_session_) {
|
||||
memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size());
|
||||
} else {
|
||||
out_tensor->FreeData();
|
||||
|
@ -134,7 +149,6 @@ int ScatterNdUpdateCPUKernel::Run() {
|
|||
in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count());
|
||||
out_tensor->set_data(in_tensor->data());
|
||||
out_tensor->set_own_data(in_tensor->own_data());
|
||||
output_ptr_ = out_tensor->data();
|
||||
}
|
||||
auto indices = in_tensors_.at(kScatterIndicesIndex);
|
||||
if (!indices->IsConst() && ReSize() != RET_OK) {
|
||||
|
|
|
@ -39,7 +39,6 @@ class ScatterNdUpdateCPUKernel : public InnerKernel {
|
|||
|
||||
private:
|
||||
ScatterNDParameter *param_ = nullptr;
|
||||
void *output_ptr_ = nullptr;
|
||||
std::vector<int> output_unit_offsets_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
Loading…
Reference in New Issue