remove root_tensor

This commit is contained in:
mengyuanli 2021-07-28 16:08:25 +08:00
parent 7b20a5adf7
commit b8bc15abe0
11 changed files with 49 additions and 120 deletions

View File

@ -106,9 +106,6 @@ class InnerKernel : public Kernel {
virtual int FreeInWorkTensor() const {
for (auto &in_tensor : this->in_tensors()) {
MS_ASSERT(in_tensor != nullptr);
if (in_tensor->root_tensor() == in_tensor) {
continue;
}
in_tensor->DecRefCount();
}
return lite::RET_OK;

View File

@ -106,9 +106,6 @@ class LiteKernel {
}
for (auto &in_tensor : this->in_tensors()) {
MS_ASSERT(in_tensor != nullptr);
if (in_tensor->root_tensor() == in_tensor) {
continue;
}
in_tensor->DecRefCount();
}
}

View File

@ -37,13 +37,21 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
MS_LOG(ERROR) << "input tensor or output tensor of merge is nullptr";
return RET_ERROR;
}
lite::STATUS ret;
lite::STATUS ret = RET_OK;
if (src_tensor->IsConst() || src_tensor->IsGraphInput()) {
dst_tensor->set_data(src_tensor->data());
dst_tensor->set_own_data(false);
MS_LOG(ERROR) << "Carry const data and graph inputs.";
} else {
if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
MS_LOG(ERROR) << "Carry MoveTensorListData";
ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor),
reinterpret_cast<lite::TensorList *>(src_tensor));
} else {
MS_LOG(ERROR) << "Carry MoveTensorData";
ret = MoveTensorData(dst_tensor, src_tensor);
}
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "Move data failed : " << ret;
return ret;
@ -64,45 +72,33 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_
<< "output tensor shape: " << dst_tensor->shape();
return RET_ERROR;
}
if (src_tensor->root_tensor() == nullptr) {
if (src_tensor->IsConst() || src_tensor->IsGraphInput() || src_tensor->ref_count() > 1) {
auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) {
MS_LOG(ERROR) << "data of dst tensor is nullptr";
if (src_tensor->allocator() == nullptr) {
MS_LOG(ERROR) << "src_tensor allocator is nullptr.";
return RET_ERROR;
}
auto src_data = src_tensor->data_c();
MS_ASSERT(src_data != nullptr);
memcpy(dst_data, src_data, dst_tensor->Size());
} else {
dst_tensor->FreeData();
dst_tensor->set_data(src_tensor->data_c());
dst_tensor->set_own_data(true);
src_tensor->set_data(nullptr);
src_tensor->set_own_data(true);
}
} else {
dst_tensor->set_root_tensor(src_tensor->root_tensor());
}
// need replace with increase data ref count
memcpy(dst_tensor->data(), src_tensor->data(), src_tensor->Size());
return RET_OK;
}
int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist) {
// shape may change, because tensors.size() can be change in RunGraph
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) {
if (dst_tensorlist->data_type() != src_tensorlist->data_type() ||
dst_tensorlist->format() != src_tensorlist->format()) {
MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
<< "output tensor data_type: " << dst_tensor->data_type()
<< "input tensor format: " << src_tensor->format() << " vs "
<< "output tensor format: " << dst_tensor->format();
MS_LOG(ERROR) << "input tensor data_type: " << src_tensorlist->data_type() << " vs "
<< "output tensor data_type: " << dst_tensorlist->data_type()
<< "input tensor format: " << src_tensorlist->format() << " vs "
<< "output tensor format: " << dst_tensorlist->format();
return RET_ERROR;
}
// when tensorlist malloc is done. this need to check element_shape compatibility
dst_tensor->set_element_shape(src_tensor->element_shape());
dst_tensorlist->set_element_shape(src_tensorlist->element_shape());
auto update_data_type = kTypeUnknown;
auto dst_tensor_data_type = dst_tensor->tensors_data_type();
auto src_tensor_data_type = src_tensor->tensors_data_type();
auto dst_tensor_data_type = dst_tensorlist->tensors_data_type();
auto src_tensor_data_type = src_tensorlist->tensors_data_type();
if (dst_tensor_data_type != src_tensor_data_type) {
if (src_tensor_data_type != kTypeUnknown && dst_tensor_data_type != kTypeUnknown) {
MS_LOG(ERROR) << "input tensorlist and output tensorlist is incompatible";
@ -111,15 +107,22 @@ int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::Tens
update_data_type = dst_tensor_data_type != kTypeUnknown ? dst_tensor_data_type : src_tensor_data_type;
}
if (update_data_type != kTypeUnknown) {
src_tensor->set_tensors_data_type(update_data_type);
dst_tensor->set_tensors_data_type(update_data_type);
src_tensorlist->set_tensors_data_type(update_data_type);
dst_tensorlist->set_tensors_data_type(update_data_type);
}
size_t src_tensorlist_tensors_size = src_tensorlist->tensors().size();
for (size_t i = 0; i < src_tensorlist_tensors_size; ++i) {
auto &src_tensor = src_tensorlist->tensors()[i];
auto &dst_tensor = dst_tensorlist->tensors()[i];
if (src_tensor->allocator() != nullptr) {
src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
}
dst_tensor->set_own_data(src_tensor->own_data());
if (src_tensor->data() != nullptr) {
dst_tensor->set_data(src_tensor->data());
}
if (src_tensor->root_tensor() == nullptr) {
dst_tensor->CopyTensorList(*src_tensor, false);
src_tensor->set_tensors({});
} else {
dst_tensor->set_shape(src_tensor->shape());
dst_tensor->set_root_tensor(src_tensor->root_tensor());
}
return RET_OK;
}

View File

@ -32,8 +32,8 @@ class CarryDataKernel : public InnerKernel {
protected:
int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end,
std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit);
static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
static int MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor);
int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
int MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist);
};
} // namespace mindspore::kernel

View File

@ -22,11 +22,11 @@
#include "src/tensorlist.h"
namespace mindspore::kernel {
class SwitchCPUKernel : public CarryDataKernel {
class SwitchCPUKernel : public InnerKernel {
public:
SwitchCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: CarryDataKernel(parameter, inputs, outputs, ctx) {}
: InnerKernel(parameter, inputs, outputs, ctx) {}
~SwitchCPUKernel() override = default;
int Init() override;
int ReSize() override;

View File

@ -34,9 +34,6 @@ int TensorListGetItemCPUKernel::Run() {
MS_ASSERT(in_tensors_.at(1) != nullptr);
MS_ASSERT(out_tensors_.at(0) != nullptr);
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
if (input0->root_tensor() != nullptr) {
input0 = reinterpret_cast<lite::TensorList *>(input0->root_tensor());
}
dtype_ = input0->tensors_data_type();
MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr);
index_ = reinterpret_cast<int *>(in_tensors_.at(1)->data_c())[0];

View File

@ -278,22 +278,6 @@ std::string Tensor::ToString() const {
return oss.str();
}
void Tensor::set_root_tensor(Tensor *tensor) {
this->root_tensor_ = tensor;
if (this->root_tensor_ == this) {
return;
}
if (this->root_tensor_ == nullptr) {
return;
}
this->shape_ = this->root_tensor_->shape_;
this->format_ = this->root_tensor_->format_;
this->data_type_ = this->root_tensor_->data_type_;
this->category_ = this->root_tensor_->category_;
this->quant_params_ = this->root_tensor_->quant_params_;
this->quant_clusters_ = this->root_tensor_->quant_clusters_;
}
int Tensor::MallocData(const AllocatorPtr allocator) {
if (this->data_ != nullptr) {
return RET_OK;
@ -344,16 +328,6 @@ void *Tensor::ReallocData() {
}
void *Tensor::MutableData() {
if (this->root_tensor_ != nullptr) {
if (this->root_tensor_ != this && this->root_tensor_->data_ == nullptr) {
MS_LOG(ERROR) << "root tensor has not been malloced";
return nullptr;
} else if (this->root_tensor_ != this && this->root_tensor_->data_ != nullptr) {
return this->root_tensor_->data_;
} else {
// malloc self
}
}
if (this->data_ == nullptr) {
auto ret = this->MallocData();
if (ret != 0) {

View File

@ -119,12 +119,7 @@ class Tensor : public mindspore::tensor::MSTensor {
void *data() override { return this->data_; }
virtual void *data_c() const {
if (this->root_tensor_ != nullptr) {
return this->root_tensor_->data_;
}
return data_;
}
virtual void *data_c() const { return data_; }
void set_data(void *data) override {
this->data_ = data;
@ -188,10 +183,6 @@ class Tensor : public mindspore::tensor::MSTensor {
}
}
virtual void set_root_tensor(Tensor *tensor);
Tensor *root_tensor() const { return this->root_tensor_; }
bool IsReady() const {
return this->IsConst() || (this->IsGraphInput() && this->data_ != nullptr) || ref_count() >= 1;
}
@ -247,7 +238,6 @@ class Tensor : public mindspore::tensor::MSTensor {
std::vector<LiteQuantParam> quant_params_;
std::vector<float> quant_clusters_;
AllocatorPtr allocator_ = nullptr;
Tensor *root_tensor_ = nullptr;
bool own_data_{false};
float scale_ = 1.0f;
};

View File

@ -209,31 +209,8 @@ int TensorList::CheckTensorListParam() {
return RET_OK;
}
void TensorList::set_root_tensor(Tensor *tensor) {
Tensor::set_root_tensor(tensor);
if (this->data_type_ != kObjectTypeTensorType || tensor == nullptr) {
return;
}
auto root_tensorlist = reinterpret_cast<TensorList *>(this->root_tensor_);
this->element_shape_ = root_tensorlist->element_shape_;
this->max_elements_num_ = root_tensorlist->max_elements_num_;
this->tensors_data_type_ = root_tensorlist->tensors_data_type_;
}
Tensor *TensorList::GetTensor(int index) {
// return tensor[index] ptr. With this function, you can modify tensors_[index] at will.
if (this->root_tensor_ != nullptr) {
if (this->data_type_ != kObjectTypeTensorType) {
MS_LOG(ERROR) << "root_tensor of tensorlist should be a tensorlist";
return nullptr;
}
auto root_tensorlist = reinterpret_cast<TensorList *>(this->root_tensor_);
if (index < 0 || index >= static_cast<int>(root_tensorlist->tensors_.size())) {
MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
return nullptr;
}
return root_tensorlist->tensors_[index];
}
if (index < 0 || index >= static_cast<int>(this->tensors_.size())) {
MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
return nullptr;

View File

@ -109,8 +109,6 @@ class TensorList : public Tensor {
bool IsConst() const override;
void set_root_tensor(Tensor *tensor) override;
void set_ref_count(int ref_count) override {
ref_count_ = ref_count;
for (auto tensor : tensors_) {

View File

@ -164,11 +164,7 @@ void ConvertOtherTensor(MetaGraphT *graph, uint32_t index, bool *convert_succ, s
lite_tensors->emplace_back(lite_tensor.release());
return;
}
if (lite_tensor->root_tensor() != nullptr) {
lite_tensor->root_tensor()->set_data(tensorT->data.data());
} else {
lite_tensor->set_data(tensorT->data.data());
}
lite_tensors->emplace_back(lite_tensor.release());
}