!29216 [MSLITE][DEVELOP] fix bug of scatter_nd op: indices tensor support int64

Merge pull request !29216 from yangruoqi713/scatter
This commit is contained in:
i-robot 2022-01-18 01:25:29 +00:00 committed by Gitee
commit 6474426e59
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 25 additions and 12 deletions

View File

@ -46,8 +46,6 @@ int ScatterNdUpdateCPUKernel::ReSize() {
auto input = in_tensors_.at(kScatterUpdateInputIndex);
auto indices = in_tensors_.at(kScatterIndicesIndex);
auto update = in_tensors_.at(kScatterUpdateIndex);
auto output = out_tensors_.front();
output_ptr_ = output->data();
// check indices shape
int input_rank = static_cast<int>(input->shape().size());
@ -87,15 +85,32 @@ int ScatterNdUpdateCPUKernel::ReSize() {
param_->num_unit *= update_shape.at(i);
}
int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
MS_ASSERT(indices_ptr != nullptr);
auto indices_ptr = indices->data();
if (indices_ptr == nullptr) {
return RET_OK;
}
output_unit_offsets_.clear();
for (int i = 0; i < param_->num_unit; i++) {
int tmp_stride = 0;
for (int j = 0; j < indice_unit_rank; j++) {
tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size;
if (indices->data_type() == kNumberTypeInt || indices->data_type() == kNumberTypeInt32) {
auto indices_data = reinterpret_cast<int *>(indices_ptr);
for (int i = 0; i < param_->num_unit; i++) {
int tmp_stride = 0;
for (int j = 0; j < indice_unit_rank; j++) {
tmp_stride += indices_data[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size;
}
output_unit_offsets_.push_back(tmp_stride);
}
output_unit_offsets_.push_back(tmp_stride);
} else if (indices->data_type() == kNumberTypeInt64) {
auto indices_data = reinterpret_cast<int64_t *>(indices_ptr);
for (int i = 0; i < param_->num_unit; i++) {
int tmp_stride = 0;
for (int j = 0; j < indice_unit_rank; j++) {
tmp_stride += indices_data[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size;
}
output_unit_offsets_.push_back(tmp_stride);
}
} else {
MS_LOG(ERROR) << "Unsupported data type for indices tensor.";
return RET_ERROR;
}
return RET_OK;
}
@ -126,7 +141,7 @@ int ScatterNdUpdateCPUKernel::Run() {
auto in_tensor = in_tensors().front();
auto out_tensor = out_tensors().front();
if (in_tensor->allocator() == nullptr || in_tensor->allocator() != out_tensor->allocator() ||
op_parameter_->is_train_session_) {
in_tensor->own_data() == false || op_parameter_->is_train_session_) {
memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size());
} else {
out_tensor->FreeData();
@ -134,7 +149,6 @@ int ScatterNdUpdateCPUKernel::Run() {
in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count());
out_tensor->set_data(in_tensor->data());
out_tensor->set_own_data(in_tensor->own_data());
output_ptr_ = out_tensor->data();
}
auto indices = in_tensors_.at(kScatterIndicesIndex);
if (!indices->IsConst() && ReSize() != RET_OK) {

View File

@ -39,7 +39,6 @@ class ScatterNdUpdateCPUKernel : public InnerKernel {
private:
ScatterNDParameter *param_ = nullptr;
void *output_ptr_ = nullptr;
std::vector<int> output_unit_offsets_;
};
} // namespace mindspore::kernel