remove root_tensor

This commit is contained in:
mengyuanli 2021-07-28 16:08:25 +08:00
parent 7b20a5adf7
commit b8bc15abe0
11 changed files with 49 additions and 120 deletions

View File

@ -106,9 +106,6 @@ class InnerKernel : public Kernel {
virtual int FreeInWorkTensor() const { virtual int FreeInWorkTensor() const {
for (auto &in_tensor : this->in_tensors()) { for (auto &in_tensor : this->in_tensors()) {
MS_ASSERT(in_tensor != nullptr); MS_ASSERT(in_tensor != nullptr);
if (in_tensor->root_tensor() == in_tensor) {
continue;
}
in_tensor->DecRefCount(); in_tensor->DecRefCount();
} }
return lite::RET_OK; return lite::RET_OK;

View File

@ -106,9 +106,6 @@ class LiteKernel {
} }
for (auto &in_tensor : this->in_tensors()) { for (auto &in_tensor : this->in_tensors()) {
MS_ASSERT(in_tensor != nullptr); MS_ASSERT(in_tensor != nullptr);
if (in_tensor->root_tensor() == in_tensor) {
continue;
}
in_tensor->DecRefCount(); in_tensor->DecRefCount();
} }
} }

View File

@ -37,12 +37,20 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
MS_LOG(ERROR) << "input tensor or output tensor of merge is nullptr"; MS_LOG(ERROR) << "input tensor or output tensor of merge is nullptr";
return RET_ERROR; return RET_ERROR;
} }
lite::STATUS ret; lite::STATUS ret = RET_OK;
if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) { if (src_tensor->IsConst() || src_tensor->IsGraphInput()) {
ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor), dst_tensor->set_data(src_tensor->data());
reinterpret_cast<lite::TensorList *>(src_tensor)); dst_tensor->set_own_data(false);
MS_LOG(ERROR) << "Carry const data and graph inputs.";
} else { } else {
ret = MoveTensorData(dst_tensor, src_tensor); if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
MS_LOG(ERROR) << "Carry MoveTensorListData";
ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor),
reinterpret_cast<lite::TensorList *>(src_tensor));
} else {
MS_LOG(ERROR) << "Carry MoveTensorData";
ret = MoveTensorData(dst_tensor, src_tensor);
}
} }
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Move data failed : " << ret; MS_LOG(ERROR) << "Move data failed : " << ret;
@ -64,45 +72,33 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_
<< "output tensor shape: " << dst_tensor->shape(); << "output tensor shape: " << dst_tensor->shape();
return RET_ERROR; return RET_ERROR;
} }
if (src_tensor->root_tensor() == nullptr) { if (src_tensor->allocator() == nullptr) {
if (src_tensor->IsConst() || src_tensor->IsGraphInput() || src_tensor->ref_count() > 1) { MS_LOG(ERROR) << "src_tensor allocator is nullptr.";
auto dst_data = dst_tensor->MutableData(); return RET_ERROR;
if (dst_data == nullptr) {
MS_LOG(ERROR) << "data of dst tensor is nullptr";
return RET_ERROR;
}
auto src_data = src_tensor->data_c();
MS_ASSERT(src_data != nullptr);
memcpy(dst_data, src_data, dst_tensor->Size());
} else {
dst_tensor->FreeData();
dst_tensor->set_data(src_tensor->data_c());
dst_tensor->set_own_data(true);
src_tensor->set_data(nullptr);
src_tensor->set_own_data(true);
}
} else {
dst_tensor->set_root_tensor(src_tensor->root_tensor());
} }
// need replace with increase data ref count
memcpy(dst_tensor->data(), src_tensor->data(), src_tensor->Size());
return RET_OK; return RET_OK;
} }
int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) { int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist) {
// shape may change, because tensors.size() can be change in RunGraph // shape may change, because tensors.size() can be change in RunGraph
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) { if (dst_tensorlist->data_type() != src_tensorlist->data_type() ||
dst_tensorlist->format() != src_tensorlist->format()) {
MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible"; MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs " MS_LOG(ERROR) << "input tensor data_type: " << src_tensorlist->data_type() << " vs "
<< "output tensor data_type: " << dst_tensor->data_type() << "output tensor data_type: " << dst_tensorlist->data_type()
<< "input tensor format: " << src_tensor->format() << " vs " << "input tensor format: " << src_tensorlist->format() << " vs "
<< "output tensor format: " << dst_tensor->format(); << "output tensor format: " << dst_tensorlist->format();
return RET_ERROR; return RET_ERROR;
} }
// when tensorlist malloc is done. this need to check element_shape compatibility // when tensorlist malloc is done. this need to check element_shape compatibility
dst_tensor->set_element_shape(src_tensor->element_shape()); dst_tensorlist->set_element_shape(src_tensorlist->element_shape());
auto update_data_type = kTypeUnknown; auto update_data_type = kTypeUnknown;
auto dst_tensor_data_type = dst_tensor->tensors_data_type(); auto dst_tensor_data_type = dst_tensorlist->tensors_data_type();
auto src_tensor_data_type = src_tensor->tensors_data_type(); auto src_tensor_data_type = src_tensorlist->tensors_data_type();
if (dst_tensor_data_type != src_tensor_data_type) { if (dst_tensor_data_type != src_tensor_data_type) {
if (src_tensor_data_type != kTypeUnknown && dst_tensor_data_type != kTypeUnknown) { if (src_tensor_data_type != kTypeUnknown && dst_tensor_data_type != kTypeUnknown) {
MS_LOG(ERROR) << "input tensorlist and output tensorlist is incompatible"; MS_LOG(ERROR) << "input tensorlist and output tensorlist is incompatible";
@ -111,15 +107,22 @@ int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::Tens
update_data_type = dst_tensor_data_type != kTypeUnknown ? dst_tensor_data_type : src_tensor_data_type; update_data_type = dst_tensor_data_type != kTypeUnknown ? dst_tensor_data_type : src_tensor_data_type;
} }
if (update_data_type != kTypeUnknown) { if (update_data_type != kTypeUnknown) {
src_tensor->set_tensors_data_type(update_data_type); src_tensorlist->set_tensors_data_type(update_data_type);
dst_tensor->set_tensors_data_type(update_data_type); dst_tensorlist->set_tensors_data_type(update_data_type);
} }
if (src_tensor->root_tensor() == nullptr) { size_t src_tensorlist_tensors_size = src_tensorlist->tensors().size();
dst_tensor->CopyTensorList(*src_tensor, false); for (size_t i = 0; i < src_tensorlist_tensors_size; ++i) {
src_tensor->set_tensors({}); auto &src_tensor = src_tensorlist->tensors()[i];
} else { auto &dst_tensor = dst_tensorlist->tensors()[i];
if (src_tensor->allocator() != nullptr) {
src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
}
dst_tensor->set_own_data(src_tensor->own_data());
if (src_tensor->data() != nullptr) {
dst_tensor->set_data(src_tensor->data());
}
dst_tensor->set_shape(src_tensor->shape()); dst_tensor->set_shape(src_tensor->shape());
dst_tensor->set_root_tensor(src_tensor->root_tensor());
} }
return RET_OK; return RET_OK;
} }

View File

@ -32,8 +32,8 @@ class CarryDataKernel : public InnerKernel {
protected: protected:
int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end, int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end,
std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit); std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit);
static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor); int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
static int MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor); int MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist);
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel

View File

@ -22,11 +22,11 @@
#include "src/tensorlist.h" #include "src/tensorlist.h"
namespace mindspore::kernel { namespace mindspore::kernel {
class SwitchCPUKernel : public CarryDataKernel { class SwitchCPUKernel : public InnerKernel {
public: public:
SwitchCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, SwitchCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: CarryDataKernel(parameter, inputs, outputs, ctx) {} : InnerKernel(parameter, inputs, outputs, ctx) {}
~SwitchCPUKernel() override = default; ~SwitchCPUKernel() override = default;
int Init() override; int Init() override;
int ReSize() override; int ReSize() override;

View File

@ -34,9 +34,6 @@ int TensorListGetItemCPUKernel::Run() {
MS_ASSERT(in_tensors_.at(1) != nullptr); MS_ASSERT(in_tensors_.at(1) != nullptr);
MS_ASSERT(out_tensors_.at(0) != nullptr); MS_ASSERT(out_tensors_.at(0) != nullptr);
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0)); auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
if (input0->root_tensor() != nullptr) {
input0 = reinterpret_cast<lite::TensorList *>(input0->root_tensor());
}
dtype_ = input0->tensors_data_type(); dtype_ = input0->tensors_data_type();
MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr); MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr);
index_ = reinterpret_cast<int *>(in_tensors_.at(1)->data_c())[0]; index_ = reinterpret_cast<int *>(in_tensors_.at(1)->data_c())[0];

View File

@ -278,22 +278,6 @@ std::string Tensor::ToString() const {
return oss.str(); return oss.str();
} }
void Tensor::set_root_tensor(Tensor *tensor) {
this->root_tensor_ = tensor;
if (this->root_tensor_ == this) {
return;
}
if (this->root_tensor_ == nullptr) {
return;
}
this->shape_ = this->root_tensor_->shape_;
this->format_ = this->root_tensor_->format_;
this->data_type_ = this->root_tensor_->data_type_;
this->category_ = this->root_tensor_->category_;
this->quant_params_ = this->root_tensor_->quant_params_;
this->quant_clusters_ = this->root_tensor_->quant_clusters_;
}
int Tensor::MallocData(const AllocatorPtr allocator) { int Tensor::MallocData(const AllocatorPtr allocator) {
if (this->data_ != nullptr) { if (this->data_ != nullptr) {
return RET_OK; return RET_OK;
@ -344,16 +328,6 @@ void *Tensor::ReallocData() {
} }
void *Tensor::MutableData() { void *Tensor::MutableData() {
if (this->root_tensor_ != nullptr) {
if (this->root_tensor_ != this && this->root_tensor_->data_ == nullptr) {
MS_LOG(ERROR) << "root tensor has not been malloced";
return nullptr;
} else if (this->root_tensor_ != this && this->root_tensor_->data_ != nullptr) {
return this->root_tensor_->data_;
} else {
// malloc self
}
}
if (this->data_ == nullptr) { if (this->data_ == nullptr) {
auto ret = this->MallocData(); auto ret = this->MallocData();
if (ret != 0) { if (ret != 0) {

View File

@ -119,12 +119,7 @@ class Tensor : public mindspore::tensor::MSTensor {
void *data() override { return this->data_; } void *data() override { return this->data_; }
virtual void *data_c() const { virtual void *data_c() const { return data_; }
if (this->root_tensor_ != nullptr) {
return this->root_tensor_->data_;
}
return data_;
}
void set_data(void *data) override { void set_data(void *data) override {
this->data_ = data; this->data_ = data;
@ -188,10 +183,6 @@ class Tensor : public mindspore::tensor::MSTensor {
} }
} }
virtual void set_root_tensor(Tensor *tensor);
Tensor *root_tensor() const { return this->root_tensor_; }
bool IsReady() const { bool IsReady() const {
return this->IsConst() || (this->IsGraphInput() && this->data_ != nullptr) || ref_count() >= 1; return this->IsConst() || (this->IsGraphInput() && this->data_ != nullptr) || ref_count() >= 1;
} }
@ -247,7 +238,6 @@ class Tensor : public mindspore::tensor::MSTensor {
std::vector<LiteQuantParam> quant_params_; std::vector<LiteQuantParam> quant_params_;
std::vector<float> quant_clusters_; std::vector<float> quant_clusters_;
AllocatorPtr allocator_ = nullptr; AllocatorPtr allocator_ = nullptr;
Tensor *root_tensor_ = nullptr;
bool own_data_{false}; bool own_data_{false};
float scale_ = 1.0f; float scale_ = 1.0f;
}; };

View File

@ -209,31 +209,8 @@ int TensorList::CheckTensorListParam() {
return RET_OK; return RET_OK;
} }
void TensorList::set_root_tensor(Tensor *tensor) {
Tensor::set_root_tensor(tensor);
if (this->data_type_ != kObjectTypeTensorType || tensor == nullptr) {
return;
}
auto root_tensorlist = reinterpret_cast<TensorList *>(this->root_tensor_);
this->element_shape_ = root_tensorlist->element_shape_;
this->max_elements_num_ = root_tensorlist->max_elements_num_;
this->tensors_data_type_ = root_tensorlist->tensors_data_type_;
}
Tensor *TensorList::GetTensor(int index) { Tensor *TensorList::GetTensor(int index) {
// return tensor[index] ptr. With this function, you can modify tensors_[index] at will. // return tensor[index] ptr. With this function, you can modify tensors_[index] at will.
if (this->root_tensor_ != nullptr) {
if (this->data_type_ != kObjectTypeTensorType) {
MS_LOG(ERROR) << "root_tensor of tensorlist should be a tensorlist";
return nullptr;
}
auto root_tensorlist = reinterpret_cast<TensorList *>(this->root_tensor_);
if (index < 0 || index >= static_cast<int>(root_tensorlist->tensors_.size())) {
MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
return nullptr;
}
return root_tensorlist->tensors_[index];
}
if (index < 0 || index >= static_cast<int>(this->tensors_.size())) { if (index < 0 || index >= static_cast<int>(this->tensors_.size())) {
MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!"; MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
return nullptr; return nullptr;

View File

@ -109,8 +109,6 @@ class TensorList : public Tensor {
bool IsConst() const override; bool IsConst() const override;
void set_root_tensor(Tensor *tensor) override;
void set_ref_count(int ref_count) override { void set_ref_count(int ref_count) override {
ref_count_ = ref_count; ref_count_ = ref_count;
for (auto tensor : tensors_) { for (auto tensor : tensors_) {

View File

@ -164,11 +164,7 @@ void ConvertOtherTensor(MetaGraphT *graph, uint32_t index, bool *convert_succ, s
lite_tensors->emplace_back(lite_tensor.release()); lite_tensors->emplace_back(lite_tensor.release());
return; return;
} }
if (lite_tensor->root_tensor() != nullptr) { lite_tensor->set_data(tensorT->data.data());
lite_tensor->root_tensor()->set_data(tensorT->data.data());
} else {
lite_tensor->set_data(tensorT->data.data());
}
lite_tensors->emplace_back(lite_tensor.release()); lite_tensors->emplace_back(lite_tensor.release());
} }