forked from mindspore-Ecosystem/mindspore
remove root_tensor
This commit is contained in:
parent
7b20a5adf7
commit
b8bc15abe0
|
@ -106,9 +106,6 @@ class InnerKernel : public Kernel {
|
|||
virtual int FreeInWorkTensor() const {
|
||||
for (auto &in_tensor : this->in_tensors()) {
|
||||
MS_ASSERT(in_tensor != nullptr);
|
||||
if (in_tensor->root_tensor() == in_tensor) {
|
||||
continue;
|
||||
}
|
||||
in_tensor->DecRefCount();
|
||||
}
|
||||
return lite::RET_OK;
|
||||
|
|
|
@ -106,9 +106,6 @@ class LiteKernel {
|
|||
}
|
||||
for (auto &in_tensor : this->in_tensors()) {
|
||||
MS_ASSERT(in_tensor != nullptr);
|
||||
if (in_tensor->root_tensor() == in_tensor) {
|
||||
continue;
|
||||
}
|
||||
in_tensor->DecRefCount();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,12 +37,20 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin,
|
|||
MS_LOG(ERROR) << "input tensor or output tensor of merge is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
lite::STATUS ret;
|
||||
if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
|
||||
ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor),
|
||||
reinterpret_cast<lite::TensorList *>(src_tensor));
|
||||
lite::STATUS ret = RET_OK;
|
||||
if (src_tensor->IsConst() || src_tensor->IsGraphInput()) {
|
||||
dst_tensor->set_data(src_tensor->data());
|
||||
dst_tensor->set_own_data(false);
|
||||
MS_LOG(ERROR) << "Carry const data and graph inputs.";
|
||||
} else {
|
||||
ret = MoveTensorData(dst_tensor, src_tensor);
|
||||
if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) {
|
||||
MS_LOG(ERROR) << "Carry MoveTensorListData";
|
||||
ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor),
|
||||
reinterpret_cast<lite::TensorList *>(src_tensor));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Carry MoveTensorData";
|
||||
ret = MoveTensorData(dst_tensor, src_tensor);
|
||||
}
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Move data failed : " << ret;
|
||||
|
@ -64,45 +72,33 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_
|
|||
<< "output tensor shape: " << dst_tensor->shape();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (src_tensor->root_tensor() == nullptr) {
|
||||
if (src_tensor->IsConst() || src_tensor->IsGraphInput() || src_tensor->ref_count() > 1) {
|
||||
auto dst_data = dst_tensor->MutableData();
|
||||
if (dst_data == nullptr) {
|
||||
MS_LOG(ERROR) << "data of dst tensor is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto src_data = src_tensor->data_c();
|
||||
MS_ASSERT(src_data != nullptr);
|
||||
memcpy(dst_data, src_data, dst_tensor->Size());
|
||||
} else {
|
||||
dst_tensor->FreeData();
|
||||
dst_tensor->set_data(src_tensor->data_c());
|
||||
dst_tensor->set_own_data(true);
|
||||
src_tensor->set_data(nullptr);
|
||||
src_tensor->set_own_data(true);
|
||||
}
|
||||
} else {
|
||||
dst_tensor->set_root_tensor(src_tensor->root_tensor());
|
||||
if (src_tensor->allocator() == nullptr) {
|
||||
MS_LOG(ERROR) << "src_tensor allocator is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// need replace with increase data ref count
|
||||
memcpy(dst_tensor->data(), src_tensor->data(), src_tensor->Size());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) {
|
||||
int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist) {
|
||||
// shape may change, because tensors.size() can be change in RunGraph
|
||||
if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) {
|
||||
if (dst_tensorlist->data_type() != src_tensorlist->data_type() ||
|
||||
dst_tensorlist->format() != src_tensorlist->format()) {
|
||||
MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible";
|
||||
MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs "
|
||||
<< "output tensor data_type: " << dst_tensor->data_type()
|
||||
<< "input tensor format: " << src_tensor->format() << " vs "
|
||||
<< "output tensor format: " << dst_tensor->format();
|
||||
MS_LOG(ERROR) << "input tensor data_type: " << src_tensorlist->data_type() << " vs "
|
||||
<< "output tensor data_type: " << dst_tensorlist->data_type()
|
||||
<< "input tensor format: " << src_tensorlist->format() << " vs "
|
||||
<< "output tensor format: " << dst_tensorlist->format();
|
||||
return RET_ERROR;
|
||||
}
|
||||
// when tensorlist malloc is done. this need to check element_shape compatibility
|
||||
dst_tensor->set_element_shape(src_tensor->element_shape());
|
||||
dst_tensorlist->set_element_shape(src_tensorlist->element_shape());
|
||||
|
||||
auto update_data_type = kTypeUnknown;
|
||||
auto dst_tensor_data_type = dst_tensor->tensors_data_type();
|
||||
auto src_tensor_data_type = src_tensor->tensors_data_type();
|
||||
auto dst_tensor_data_type = dst_tensorlist->tensors_data_type();
|
||||
auto src_tensor_data_type = src_tensorlist->tensors_data_type();
|
||||
if (dst_tensor_data_type != src_tensor_data_type) {
|
||||
if (src_tensor_data_type != kTypeUnknown && dst_tensor_data_type != kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "input tensorlist and output tensorlist is incompatible";
|
||||
|
@ -111,15 +107,22 @@ int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::Tens
|
|||
update_data_type = dst_tensor_data_type != kTypeUnknown ? dst_tensor_data_type : src_tensor_data_type;
|
||||
}
|
||||
if (update_data_type != kTypeUnknown) {
|
||||
src_tensor->set_tensors_data_type(update_data_type);
|
||||
dst_tensor->set_tensors_data_type(update_data_type);
|
||||
src_tensorlist->set_tensors_data_type(update_data_type);
|
||||
dst_tensorlist->set_tensors_data_type(update_data_type);
|
||||
}
|
||||
if (src_tensor->root_tensor() == nullptr) {
|
||||
dst_tensor->CopyTensorList(*src_tensor, false);
|
||||
src_tensor->set_tensors({});
|
||||
} else {
|
||||
size_t src_tensorlist_tensors_size = src_tensorlist->tensors().size();
|
||||
for (size_t i = 0; i < src_tensorlist_tensors_size; ++i) {
|
||||
auto &src_tensor = src_tensorlist->tensors()[i];
|
||||
auto &dst_tensor = dst_tensorlist->tensors()[i];
|
||||
|
||||
if (src_tensor->allocator() != nullptr) {
|
||||
src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
|
||||
}
|
||||
dst_tensor->set_own_data(src_tensor->own_data());
|
||||
if (src_tensor->data() != nullptr) {
|
||||
dst_tensor->set_data(src_tensor->data());
|
||||
}
|
||||
dst_tensor->set_shape(src_tensor->shape());
|
||||
dst_tensor->set_root_tensor(src_tensor->root_tensor());
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -32,8 +32,8 @@ class CarryDataKernel : public InnerKernel {
|
|||
protected:
|
||||
int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end,
|
||||
std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit);
|
||||
static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
|
||||
static int MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor);
|
||||
int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor);
|
||||
int MoveTensorListData(lite::TensorList *dst_tensorlist, lite::TensorList *src_tensorlist);
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -22,11 +22,11 @@
|
|||
#include "src/tensorlist.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class SwitchCPUKernel : public CarryDataKernel {
|
||||
class SwitchCPUKernel : public InnerKernel {
|
||||
public:
|
||||
SwitchCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||
: CarryDataKernel(parameter, inputs, outputs, ctx) {}
|
||||
: InnerKernel(parameter, inputs, outputs, ctx) {}
|
||||
~SwitchCPUKernel() override = default;
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
|
|
|
@ -34,9 +34,6 @@ int TensorListGetItemCPUKernel::Run() {
|
|||
MS_ASSERT(in_tensors_.at(1) != nullptr);
|
||||
MS_ASSERT(out_tensors_.at(0) != nullptr);
|
||||
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
|
||||
if (input0->root_tensor() != nullptr) {
|
||||
input0 = reinterpret_cast<lite::TensorList *>(input0->root_tensor());
|
||||
}
|
||||
dtype_ = input0->tensors_data_type();
|
||||
MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr);
|
||||
index_ = reinterpret_cast<int *>(in_tensors_.at(1)->data_c())[0];
|
||||
|
|
|
@ -278,22 +278,6 @@ std::string Tensor::ToString() const {
|
|||
return oss.str();
|
||||
}
|
||||
|
||||
void Tensor::set_root_tensor(Tensor *tensor) {
|
||||
this->root_tensor_ = tensor;
|
||||
if (this->root_tensor_ == this) {
|
||||
return;
|
||||
}
|
||||
if (this->root_tensor_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
this->shape_ = this->root_tensor_->shape_;
|
||||
this->format_ = this->root_tensor_->format_;
|
||||
this->data_type_ = this->root_tensor_->data_type_;
|
||||
this->category_ = this->root_tensor_->category_;
|
||||
this->quant_params_ = this->root_tensor_->quant_params_;
|
||||
this->quant_clusters_ = this->root_tensor_->quant_clusters_;
|
||||
}
|
||||
|
||||
int Tensor::MallocData(const AllocatorPtr allocator) {
|
||||
if (this->data_ != nullptr) {
|
||||
return RET_OK;
|
||||
|
@ -344,16 +328,6 @@ void *Tensor::ReallocData() {
|
|||
}
|
||||
|
||||
void *Tensor::MutableData() {
|
||||
if (this->root_tensor_ != nullptr) {
|
||||
if (this->root_tensor_ != this && this->root_tensor_->data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "root tensor has not been malloced";
|
||||
return nullptr;
|
||||
} else if (this->root_tensor_ != this && this->root_tensor_->data_ != nullptr) {
|
||||
return this->root_tensor_->data_;
|
||||
} else {
|
||||
// malloc self
|
||||
}
|
||||
}
|
||||
if (this->data_ == nullptr) {
|
||||
auto ret = this->MallocData();
|
||||
if (ret != 0) {
|
||||
|
|
|
@ -119,12 +119,7 @@ class Tensor : public mindspore::tensor::MSTensor {
|
|||
|
||||
void *data() override { return this->data_; }
|
||||
|
||||
virtual void *data_c() const {
|
||||
if (this->root_tensor_ != nullptr) {
|
||||
return this->root_tensor_->data_;
|
||||
}
|
||||
return data_;
|
||||
}
|
||||
virtual void *data_c() const { return data_; }
|
||||
|
||||
void set_data(void *data) override {
|
||||
this->data_ = data;
|
||||
|
@ -188,10 +183,6 @@ class Tensor : public mindspore::tensor::MSTensor {
|
|||
}
|
||||
}
|
||||
|
||||
virtual void set_root_tensor(Tensor *tensor);
|
||||
|
||||
Tensor *root_tensor() const { return this->root_tensor_; }
|
||||
|
||||
bool IsReady() const {
|
||||
return this->IsConst() || (this->IsGraphInput() && this->data_ != nullptr) || ref_count() >= 1;
|
||||
}
|
||||
|
@ -247,7 +238,6 @@ class Tensor : public mindspore::tensor::MSTensor {
|
|||
std::vector<LiteQuantParam> quant_params_;
|
||||
std::vector<float> quant_clusters_;
|
||||
AllocatorPtr allocator_ = nullptr;
|
||||
Tensor *root_tensor_ = nullptr;
|
||||
bool own_data_{false};
|
||||
float scale_ = 1.0f;
|
||||
};
|
||||
|
|
|
@ -209,31 +209,8 @@ int TensorList::CheckTensorListParam() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void TensorList::set_root_tensor(Tensor *tensor) {
|
||||
Tensor::set_root_tensor(tensor);
|
||||
if (this->data_type_ != kObjectTypeTensorType || tensor == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto root_tensorlist = reinterpret_cast<TensorList *>(this->root_tensor_);
|
||||
this->element_shape_ = root_tensorlist->element_shape_;
|
||||
this->max_elements_num_ = root_tensorlist->max_elements_num_;
|
||||
this->tensors_data_type_ = root_tensorlist->tensors_data_type_;
|
||||
}
|
||||
|
||||
Tensor *TensorList::GetTensor(int index) {
|
||||
// return tensor[index] ptr. With this function, you can modify tensors_[index] at will.
|
||||
if (this->root_tensor_ != nullptr) {
|
||||
if (this->data_type_ != kObjectTypeTensorType) {
|
||||
MS_LOG(ERROR) << "root_tensor of tensorlist should be a tensorlist";
|
||||
return nullptr;
|
||||
}
|
||||
auto root_tensorlist = reinterpret_cast<TensorList *>(this->root_tensor_);
|
||||
if (index < 0 || index >= static_cast<int>(root_tensorlist->tensors_.size())) {
|
||||
MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
|
||||
return nullptr;
|
||||
}
|
||||
return root_tensorlist->tensors_[index];
|
||||
}
|
||||
if (index < 0 || index >= static_cast<int>(this->tensors_.size())) {
|
||||
MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
|
||||
return nullptr;
|
||||
|
|
|
@ -109,8 +109,6 @@ class TensorList : public Tensor {
|
|||
|
||||
bool IsConst() const override;
|
||||
|
||||
void set_root_tensor(Tensor *tensor) override;
|
||||
|
||||
void set_ref_count(int ref_count) override {
|
||||
ref_count_ = ref_count;
|
||||
for (auto tensor : tensors_) {
|
||||
|
|
|
@ -164,11 +164,7 @@ void ConvertOtherTensor(MetaGraphT *graph, uint32_t index, bool *convert_succ, s
|
|||
lite_tensors->emplace_back(lite_tensor.release());
|
||||
return;
|
||||
}
|
||||
if (lite_tensor->root_tensor() != nullptr) {
|
||||
lite_tensor->root_tensor()->set_data(tensorT->data.data());
|
||||
} else {
|
||||
lite_tensor->set_data(tensorT->data.data());
|
||||
}
|
||||
lite_tensor->set_data(tensorT->data.data());
|
||||
lite_tensors->emplace_back(lite_tensor.release());
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue