diff --git a/mindspore/lite/src/inner_kernel.h b/mindspore/lite/src/inner_kernel.h index 6ce8ba9c3ff..93c490544be 100644 --- a/mindspore/lite/src/inner_kernel.h +++ b/mindspore/lite/src/inner_kernel.h @@ -106,9 +106,6 @@ class InnerKernel : public Kernel { virtual int FreeInWorkTensor() const { for (auto &in_tensor : this->in_tensors()) { MS_ASSERT(in_tensor != nullptr); - if (in_tensor->root_tensor() == in_tensor) { - continue; - } in_tensor->DecRefCount(); } return lite::RET_OK; diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 2fa67dd7eb4..55456c46c05 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -106,9 +106,6 @@ class LiteKernel { } for (auto &in_tensor : this->in_tensors()) { MS_ASSERT(in_tensor != nullptr); - if (in_tensor->root_tensor() == in_tensor) { - continue; - } in_tensor->DecRefCount(); } } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc index 31c4e212a7b..8c9db2bf95c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.cc @@ -37,12 +37,20 @@ int CarryDataKernel::MoveData(std::vector::iterator dst_begin, MS_LOG(ERROR) << "input tensor or output tensor of merge is nullptr"; return RET_ERROR; } - lite::STATUS ret; - if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) { - ret = MoveTensorListData(reinterpret_cast(dst_tensor), - reinterpret_cast(src_tensor)); + lite::STATUS ret = RET_OK; + if (src_tensor->IsConst() || src_tensor->IsGraphInput()) { + dst_tensor->set_data(src_tensor->data()); + dst_tensor->set_own_data(false); + MS_LOG(ERROR) << "Carry const data and graph inputs."; } else { - ret = MoveTensorData(dst_tensor, src_tensor); + if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) { + MS_LOG(ERROR) << "Carry MoveTensorListData"; + ret = MoveTensorListData(reinterpret_cast(dst_tensor), + reinterpret_cast(src_tensor)); + } else { + MS_LOG(ERROR) << "Carry MoveTensorData"; + ret = MoveTensorData(dst_tensor, src_tensor); + } } if (ret != RET_OK) { MS_LOG(ERROR) << "Move data failed : " << ret; @@ -64,45 +72,33 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_ << "output tensor shape: " << dst_tensor->shape(); return RET_ERROR; } - if (src_tensor->root_tensor() == nullptr) { - if (src_tensor->IsConst() || src_tensor->IsGraphInput() || src_tensor->ref_count() > 1) { - auto dst_data = dst_tensor->MutableData(); - if (dst_data == nullptr) { - MS_LOG(ERROR) << "data of dst tensor is nullptr"; - return RET_ERROR; - } - auto src_data = src_tensor->data_c(); - MS_ASSERT(src_data != nullptr); - memcpy(dst_data, src_data, dst_tensor->Size()); - } else { - dst_tensor->FreeData(); - dst_tensor->set_data(src_tensor->data_c()); - dst_tensor->set_own_data(true); - src_tensor->set_data(nullptr); - src_tensor->set_own_data(true); - } - } else { - dst_tensor->set_root_tensor(src_tensor->root_tensor()); + if (src_tensor->allocator() == nullptr) { + MS_LOG(ERROR) << "src_tensor allocator is nullptr."; + return RET_ERROR; } + + // need replace with increase data ref count + memcpy(dst_tensor->data(), src_tensor->data(), src_tensor->Size()); return RET_OK; } -int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) { +int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist) { // shape may change, because tensors.size() can be change in RunGraph - if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) { + if (dst_tensorlist->data_type() != src_tensorlist->data_type() || + dst_tensorlist->format() != src_tensorlist->format()) { MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible"; - MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs " - << "output tensor data_type: " << dst_tensor->data_type() - << "input tensor format: " << src_tensor->format() << " vs " - << "output tensor format: " << dst_tensor->format(); + MS_LOG(ERROR) << "input tensor data_type: " << src_tensorlist->data_type() << " vs " + << "output tensor data_type: " << dst_tensorlist->data_type() + << "input tensor format: " << src_tensorlist->format() << " vs " + << "output tensor format: " << dst_tensorlist->format(); return RET_ERROR; } // when tensorlist malloc is done. this need to check element_shape compatibility - dst_tensor->set_element_shape(src_tensor->element_shape()); + dst_tensorlist->set_element_shape(src_tensorlist->element_shape()); auto update_data_type = kTypeUnknown; - auto dst_tensor_data_type = dst_tensor->tensors_data_type(); - auto src_tensor_data_type = src_tensor->tensors_data_type(); + auto dst_tensor_data_type = dst_tensorlist->tensors_data_type(); + auto src_tensor_data_type = src_tensorlist->tensors_data_type(); if (dst_tensor_data_type != src_tensor_data_type) { if (src_tensor_data_type != kTypeUnknown && dst_tensor_data_type != kTypeUnknown) { MS_LOG(ERROR) << "input tensorlist and output tensorlist is incompatible"; @@ -111,15 +107,22 @@ int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::Tens update_data_type = dst_tensor_data_type != kTypeUnknown ? dst_tensor_data_type : src_tensor_data_type; } if (update_data_type != kTypeUnknown) { - src_tensor->set_tensors_data_type(update_data_type); - dst_tensor->set_tensors_data_type(update_data_type); + src_tensorlist->set_tensors_data_type(update_data_type); + dst_tensorlist->set_tensors_data_type(update_data_type); } - if (src_tensor->root_tensor() == nullptr) { - dst_tensor->CopyTensorList(*src_tensor, false); - src_tensor->set_tensors({}); - } else { + size_t src_tensorlist_tensors_size = src_tensorlist->tensors().size(); + for (size_t i = 0; i < src_tensorlist_tensors_size; ++i) { + auto &src_tensor = src_tensorlist->tensors()[i]; + auto &dst_tensor = dst_tensorlist->tensors()[i]; + + if (src_tensor->allocator() != nullptr) { + src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count()); + } + dst_tensor->set_own_data(src_tensor->own_data()); + if (src_tensor->data() != nullptr) { + dst_tensor->set_data(src_tensor->data()); + } dst_tensor->set_shape(src_tensor->shape()); - dst_tensor->set_root_tensor(src_tensor->root_tensor()); } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h index 465cb70b502..7f79fd94707 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/carry_data.h @@ -32,8 +32,8 @@ class CarryDataKernel : public InnerKernel { protected: int MoveData(std::vector::iterator dst_begin, std::vector::iterator dst_end, std::vector::iterator src_begin, std::vector::iterator src_limit); - static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor); - static int MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor); + int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor); + int MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist); }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/switch.h b/mindspore/lite/src/runtime/kernel/arm/base/switch.h index 219158014cf..8f9439c0d92 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/switch.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/switch.h @@ -22,11 +22,11 @@ #include "src/tensorlist.h" namespace mindspore::kernel { -class SwitchCPUKernel : public CarryDataKernel { +class SwitchCPUKernel : public InnerKernel { public: SwitchCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx) - : CarryDataKernel(parameter, inputs, outputs, ctx) {} + : InnerKernel(parameter, inputs, outputs, ctx) {} ~SwitchCPUKernel() override = default; int Init() override; int ReSize() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/tensorlist_getitem.cc b/mindspore/lite/src/runtime/kernel/arm/base/tensorlist_getitem.cc index a01511afe2c..ea54d8a9b01 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/tensorlist_getitem.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/tensorlist_getitem.cc @@ -34,9 +34,6 @@ int TensorListGetItemCPUKernel::Run() { MS_ASSERT(in_tensors_.at(1) != nullptr); MS_ASSERT(out_tensors_.at(0) != nullptr); auto input0 = reinterpret_cast(in_tensors_.at(0)); - if (input0->root_tensor() != nullptr) { - input0 = reinterpret_cast(input0->root_tensor()); - } dtype_ = input0->tensors_data_type(); MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr); index_ = reinterpret_cast(in_tensors_.at(1)->data_c())[0]; diff --git a/mindspore/lite/src/tensor.cc b/mindspore/lite/src/tensor.cc index 6277895f8b0..93822eb96e3 100644 --- a/mindspore/lite/src/tensor.cc +++ b/mindspore/lite/src/tensor.cc @@ -278,22 +278,6 @@ std::string Tensor::ToString() const { return oss.str(); } -void Tensor::set_root_tensor(Tensor *tensor) { - this->root_tensor_ = tensor; - if (this->root_tensor_ == this) { - return; - } - if (this->root_tensor_ == nullptr) { - return; - } - this->shape_ = this->root_tensor_->shape_; - this->format_ = this->root_tensor_->format_; - this->data_type_ = this->root_tensor_->data_type_; - this->category_ = this->root_tensor_->category_; - this->quant_params_ = this->root_tensor_->quant_params_; - this->quant_clusters_ = this->root_tensor_->quant_clusters_; -} - int Tensor::MallocData(const AllocatorPtr allocator) { if (this->data_ != nullptr) { return RET_OK; @@ -344,16 +328,6 @@ void *Tensor::ReallocData() { } void *Tensor::MutableData() { - if (this->root_tensor_ != nullptr) { - if (this->root_tensor_ != this && this->root_tensor_->data_ == nullptr) { - MS_LOG(ERROR) << "root tensor has not been malloced"; - return nullptr; - } else if (this->root_tensor_ != this && this->root_tensor_->data_ != nullptr) { - return this->root_tensor_->data_; - } else { - // malloc self - } - } if (this->data_ == nullptr) { auto ret = this->MallocData(); if (ret != 0) { diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index bb7aacc3561..1933aeec957 100644 --- a/mindspore/lite/src/tensor.h +++ b/mindspore/lite/src/tensor.h @@ -119,12 +119,7 @@ class Tensor : public mindspore::tensor::MSTensor { void *data() override { return this->data_; } - virtual void *data_c() const { - if (this->root_tensor_ != nullptr) { - return this->root_tensor_->data_; - } - return data_; - } + virtual void *data_c() const { return data_; } void set_data(void *data) override { this->data_ = data; @@ -188,10 +183,6 @@ class Tensor : public mindspore::tensor::MSTensor { } } - virtual void set_root_tensor(Tensor *tensor); - - Tensor *root_tensor() const { return this->root_tensor_; } - bool IsReady() const { return this->IsConst() || (this->IsGraphInput() && this->data_ != nullptr) || ref_count() >= 1; } @@ -247,7 +238,6 @@ class Tensor : public mindspore::tensor::MSTensor { std::vector quant_params_; std::vector quant_clusters_; AllocatorPtr allocator_ = nullptr; - Tensor *root_tensor_ = nullptr; bool own_data_{false}; float scale_ = 1.0f; }; diff --git a/mindspore/lite/src/tensorlist.cc b/mindspore/lite/src/tensorlist.cc index 472478ccbe8..a325890d6ca 100644 --- a/mindspore/lite/src/tensorlist.cc +++ b/mindspore/lite/src/tensorlist.cc @@ -209,31 +209,8 @@ int TensorList::CheckTensorListParam() { return RET_OK; } -void TensorList::set_root_tensor(Tensor *tensor) { - Tensor::set_root_tensor(tensor); - if (this->data_type_ != kObjectTypeTensorType || tensor == nullptr) { - return; - } - auto root_tensorlist = reinterpret_cast(this->root_tensor_); - this->element_shape_ = root_tensorlist->element_shape_; - this->max_elements_num_ = root_tensorlist->max_elements_num_; - this->tensors_data_type_ = root_tensorlist->tensors_data_type_; -} - Tensor *TensorList::GetTensor(int index) { // return tensor[index] ptr. With this function, you can modify tensors_[index] at will. - if (this->root_tensor_ != nullptr) { - if (this->data_type_ != kObjectTypeTensorType) { - MS_LOG(ERROR) << "root_tensor of tensorlist should be a tensorlist"; - return nullptr; - } - auto root_tensorlist = reinterpret_cast(this->root_tensor_); - if (index < 0 || index >= static_cast(root_tensorlist->tensors_.size())) { - MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!"; - return nullptr; - } - return root_tensorlist->tensors_[index]; - } if (index < 0 || index >= static_cast(this->tensors_.size())) { MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!"; return nullptr; diff --git a/mindspore/lite/src/tensorlist.h b/mindspore/lite/src/tensorlist.h index e285078ce57..d03ee57bd2d 100644 --- a/mindspore/lite/src/tensorlist.h +++ b/mindspore/lite/src/tensorlist.h @@ -109,8 +109,6 @@ class TensorList : public Tensor { bool IsConst() const override; - void set_root_tensor(Tensor *tensor) override; - void set_ref_count(int ref_count) override { ref_count_ = ref_count; for (auto tensor : tensors_) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc index 7a39ffefc14..6c320553eb2 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc @@ -164,11 +164,7 @@ void ConvertOtherTensor(MetaGraphT *graph, uint32_t index, bool *convert_succ, s lite_tensors->emplace_back(lite_tensor.release()); return; } - if (lite_tensor->root_tensor() != nullptr) { - lite_tensor->root_tensor()->set_data(tensorT->data.data()); - } else { - lite_tensor->set_data(tensorT->data.data()); - } + lite_tensor->set_data(tensorT->data.data()); lite_tensors->emplace_back(lite_tensor.release()); }