!40402 [MS][LITE]optimize subgraph resize

Merge pull request !40402 from mengyuanli/pc500
This commit is contained in:
i-robot 2022-08-18 11:26:11 +00:00 committed by Gitee
commit aecaffab18
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
12 changed files with 321 additions and 130 deletions

View File

@ -24,6 +24,7 @@
#ifdef ENABLE_FP16
#include "src/litert/kernel/cpu/fp16/fp16_op_handler.h"
#endif
#include "nnacl/base/cast_base.h"
namespace mindspore {
namespace lite {
int OutputTensor2TensorC(const std::vector<lite::Tensor *> &tensors, std::vector<TensorC *> *tensors_c,
@ -269,11 +270,11 @@ std::vector<mindspore::MSTensor> LiteTensorsToMSTensors(const std::vector<lite::
return tensors;
}
void MoveCommonTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
int MoveCommonTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
MS_ASSERT(src_tensor != dst_tensor);
if (src_tensor->data() == dst_tensor->data()) {
MS_LOG(DEBUG) << "no need to move data.";
return;
return RET_OK;
}
dst_tensor->FreeData();
dst_tensor->ResetRefCount();
@ -285,20 +286,23 @@ void MoveCommonTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
dst_tensor->set_own_data(src_tensor->own_data());
src_tensor->DecRefCount();
return RET_OK;
}
void MoveTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
int MoveTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
if (src_tensor == dst_tensor) {
MS_LOG(INFO) << "no need to move.";
return;
return RET_OK;
}
MS_ASSERT(src_tensor->allocator() != nullptr);
auto ret = RET_OK;
if (src_tensor->data_type() == kObjectTypeTensorType) {
MoveTensorListTensorData(reinterpret_cast<TensorList *>(dst_tensor), reinterpret_cast<TensorList *>(src_tensor));
ret =
MoveTensorListTensorData(reinterpret_cast<TensorList *>(dst_tensor), reinterpret_cast<TensorList *>(src_tensor));
} else {
MoveCommonTensorData(dst_tensor, src_tensor);
ret = MoveCommonTensorData(dst_tensor, src_tensor);
}
return;
return ret;
}
void SetCommonTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
@ -306,25 +310,15 @@ void SetCommonTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
dst_tensor->set_own_data(false);
}
void SetTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
int SetTensorData(Tensor *dst_tensor, Tensor *src_tensor) {
auto ret = RET_OK;
if (src_tensor->data_type() == kObjectTypeTensorType) {
SetTensorListTensorData(reinterpret_cast<TensorList *>(dst_tensor), reinterpret_cast<TensorList *>(src_tensor));
ret =
SetTensorListTensorData(reinterpret_cast<TensorList *>(dst_tensor), reinterpret_cast<TensorList *>(src_tensor));
} else {
SetCommonTensorData(dst_tensor, src_tensor);
}
}
void SetTensorShape(Tensor *dst, Tensor *src) {
dst->set_shape(src->shape());
dst->set_format(src->format());
}
bool NeedCastData(Tensor *dst_tensor, Tensor *src_tensor) {
if (dst_tensor->data_type() != kObjectTypeTensorType && src_tensor->data_type() != kObjectTypeTensorType &&
dst_tensor->data_type() != src_tensor->data_type()) {
return true;
}
return NeedCastTensorListData(dst_tensor, src_tensor);
return ret;
}
int CastTensorData(Tensor *dst, Tensor *src, bool support_fp16) {
@ -342,7 +336,6 @@ int CastTensorData(Tensor *dst, Tensor *src, bool support_fp16) {
int CastCommonTensorData(Tensor *dst, Tensor *src, bool support_fp16) {
dst->ReallocData();
dst->ResetRefCount();
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
if (dst->shape() != src->shape()) {
MS_LOG(ERROR) << "dst tensor: " << dst->tensor_name() << " shape: " << dst->shape() << " vs "
<< "src tensor: " << src->tensor_name() << " shape: " << src->shape();
@ -353,27 +346,72 @@ int CastCommonTensorData(Tensor *dst, Tensor *src, bool support_fp16) {
auto src_nums_size = src->ElementsNum();
auto dst_data_type = static_cast<int>(dst->data_type());
auto src_data_type = static_cast<int>(src->data_type());
// Some case dst data type is unknown, we will set to float32. In this case, need case is true, but actually no need
// cast data
if (dst_data_type == src_data_type) {
memcpy(dst_data, src_data, src_nums_size);
return RET_OK;
}
if (dst_data_type == kNumberTypeFloat32 && src_data_type == kNumberTypeFloat16) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
Float16ToFloat32_fp16_handler(src_data, dst_data, src_nums_size, support_fp16);
#else
MS_LOG(ERROR) << "not enable fp16.";
return RET_NOT_SUPPORT;
#endif
} else if (dst_data_type == kNumberTypeFloat16 && src_data_type == kNumberTypeFloat32) {
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
Float32ToFloat16_fp16_handler(src_data, dst_data, src_nums_size, support_fp16);
#else
MS_LOG(ERROR) << "not enable fp16.";
return RET_NOT_SUPPORT;
#endif
} else if (dst_data_type == kNumberTypeFloat32 && src_data_type == kNumberTypeInt32) {
Int32ToFloat32(static_cast<const int32_t *>(src_data), static_cast<float *>(dst_data), src_nums_size);
} else if (dst_data_type == kNumberTypeInt32 && src_data_type == kNumberTypeFloat32) {
Float32ToInt32(static_cast<const float *>(src_data), static_cast<int32_t *>(dst_data), src_nums_size);
} else {
MS_LOG(ERROR) << "not support dst_data_type: " << dst_data_type << " src_data_type: " << src_data_type;
return RET_NOT_SUPPORT;
}
return RET_OK;
#endif
return RET_ERROR;
}
// Decide whether src data must be cast before being handed to dst.
// While either side's dtype is still unknown, casting is deferred: the
// runtime fills in the real dtype later, so report "no cast" here.
bool NeedCastData(Tensor *dst_tensor, Tensor *src_tensor) {
  const bool dtype_pending = IsUnKnownDtype(dst_tensor) || IsUnKnownDtype(src_tensor);
  if (dtype_pending) {
    MS_LOG(INFO) << "Type unknown, no need cast.";
    return false;
  }
  return !IsSameDtype(dst_tensor, src_tensor);
}
#ifndef CONTROLFLOW_TENSORLIST_CLIP
bool NeedCastTensorListData(Tensor *dst_tensor, Tensor *src_tensor) {
if (dst_tensor->data_type() == kObjectTypeTensorType && src_tensor->data_type() == kObjectTypeTensorType &&
reinterpret_cast<TensorList *>(dst_tensor)->tensors_data_type() !=
reinterpret_cast<TensorList *>(src_tensor)->tensors_data_type()) {
return true;
// Propagate shape information from src to dst.
// - tensor -> tensor: copy shape and format.
// - tensorlist -> tensorlist: copy the element shape, derive the outer shape
//   from the real element-tensor count, and backfill an unknown element dtype.
// - any tensor/tensorlist mix is an error.
// Returns RET_OK on success, RET_ERROR on a mismatched pairing or failed cast.
// BUG FIX: removed the unreachable trailing `return false;` — dead residue
// (a bool returned from an int function) left behind after a refactor.
int SetTensorShape(Tensor *dst, Tensor *src) {
  const bool dst_is_list = (dst->data_type() == kObjectTypeTensorType);
  const bool src_is_list = (src->data_type() == kObjectTypeTensorType);
  if (!dst_is_list && !src_is_list) {
    dst->set_shape(src->shape());
    dst->set_format(src->format());
    return RET_OK;
  }
  if (dst_is_list && src_is_list) {
    auto input_tensorlist = reinterpret_cast<TensorList *>(dst);
    auto input_data_tensorlist = reinterpret_cast<TensorList *>(src);
    MS_CHECK_FALSE_MSG(input_tensorlist == nullptr, RET_ERROR, "cast to tensorlist failed.");
    MS_CHECK_FALSE_MSG(input_data_tensorlist == nullptr, RET_ERROR, "cast to tensorlist failed.");
    input_tensorlist->set_element_shape(input_data_tensorlist->element_shape());
    // because some model shape is not same as tensors().size(), we need the real shape, which is the tensors().size().
    int real_shape_val = static_cast<int>(input_data_tensorlist->tensors().size());
    std::vector<int> real_shape{real_shape_val};
    input_tensorlist->set_shape(real_shape);
    // hard code for some model
    if (input_data_tensorlist->tensors_data_type() != kTypeUnknown &&
        input_tensorlist->tensors_data_type() == kTypeUnknown) {
      input_tensorlist->set_tensors_data_type(input_data_tensorlist->tensors_data_type());
    }
    return RET_OK;
  }
  MS_LOG(ERROR) << "not able to set tensor shape between tensor and tensorlist.";
  return RET_ERROR;
}
int CastTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensorlist, bool support_fp16) {
@ -397,33 +435,15 @@ int CastTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensorl
dst_tensorlist->ResetRefCount();
for (size_t i = 0; i < src_tensorlist->tensors().size(); ++i) {
auto &src_tensor = src_tensorlist->tensors()[i];
auto &dst_tensor = dst_tensorlist->tensors()[i];
CastCommonTensorData(dst_tensor, src_tensor, support_fp16);
auto src_tensor = src_tensorlist->tensors()[i];
auto dst_tensor = dst_tensorlist->tensors()[i];
auto ret = CastCommonTensorData(dst_tensor, src_tensor, support_fp16);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "cast tensor data failed.");
}
return RET_OK;
}
void SetTensorListShape(Tensor *dst, Tensor *src) {
auto input_tensorlist = reinterpret_cast<TensorList *>(dst);
auto input_data_tensorlist = reinterpret_cast<TensorList *>(src);
if (input_data_tensorlist == nullptr || input_tensorlist == nullptr) {
MS_LOG(ERROR) << "cast to tensorlist failed.";
return;
}
input_tensorlist->FreeTensorListData();
input_tensorlist->set_element_shape(input_data_tensorlist->element_shape());
input_tensorlist->set_shape(input_data_tensorlist->shape());
std::vector<std::vector<int>> tensor_shape{};
std::transform(input_data_tensorlist->tensors().begin(), input_data_tensorlist->tensors().end(),
std::back_inserter(tensor_shape), [](const Tensor *tensor_item) { return tensor_item->shape(); });
if (input_data_tensorlist->shape().empty()) {
return;
}
input_tensorlist->MallocTensorListData(input_data_tensorlist->tensors_data_type(), tensor_shape);
}
void MoveTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
int MoveTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
MS_ASSERT(src_tensorlist != nullptr);
MS_ASSERT(dst_tensorlist != nullptr);
dst_tensorlist->FreeData();
@ -437,13 +457,16 @@ void MoveTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensor
<< " tesnors size: " << src_tensorlist_tensors_size
<< " vs dst tensorlist: " << dst_tensorlist->tensor_name()
<< " tensors size: " << dst_tensorlist_tensors_size;
return;
return RET_ERROR;
}
// hard code for some model
dst_tensorlist->set_tensors_data_type(src_tensorlist->tensors_data_type());
dst_tensorlist->set_own_data(src_tensorlist->own_data());
for (size_t i = 0; i < src_tensorlist_tensors_size; ++i) {
auto &src_tensor = src_tensorlist->tensors()[i];
auto &dst_tensor = dst_tensorlist->tensors()[i];
auto src_tensor = src_tensorlist->tensors()[i];
auto dst_tensor = dst_tensorlist->tensors()[i];
dst_tensor->set_own_data(src_tensor->own_data());
if (src_tensor->data() != nullptr) {
@ -457,14 +480,17 @@ void MoveTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensor
} else {
src_tensorlist->DecRefCount();
}
return RET_OK;
}
void SetTensorListTensorData(TensorList *dst_tensor_list, TensorList *src_tensor_list) {
dst_tensor_list->FreeTensorListData();
// Make dst alias src's element tensors without taking ownership:
// dst first releases its own elements, then points at src's tensors and
// mirrors the element dtype and element shape. own_data(false) guarantees
// dst will never free the shared buffers.
int SetTensorListTensorData(TensorList *dst_tensor_list, TensorList *src_tensor_list) {
  auto free_ret = dst_tensor_list->FreeTensorListData();
  MS_CHECK_FALSE_MSG(free_ret != RET_OK, free_ret, "FreeTensorListData failed.");
  dst_tensor_list->set_own_data(false);
  dst_tensor_list->set_tensors(src_tensor_list->tensors());
  dst_tensor_list->set_tensors_data_type(src_tensor_list->tensors_data_type());
  dst_tensor_list->set_element_shape(src_tensor_list->element_shape());
  return RET_OK;
}
void FreeTensorListC(TensorListC *tensorlist_c, std::shared_ptr<Allocator> allocator) {
@ -628,28 +654,94 @@ void SetTensorListTensorDataType(const TypeId &data_type, Tensor *tensor) {
}
}
// True when both tensors carry the same data type. For two tensorlists the
// element dtype is compared; a tensor paired with a tensorlist never matches.
bool IsSameDtype(const Tensor *input_1, const Tensor *input_2) {
  const bool first_is_list = (input_1->data_type() == kObjectTypeTensorType);
  const bool second_is_list = (input_2->data_type() == kObjectTypeTensorType);
  if (first_is_list != second_is_list) {
    return false;
  }
  if (!first_is_list) {
    return input_1->data_type() == input_2->data_type();
  }
  auto list_1 = reinterpret_cast<const TensorList *>(input_1);
  auto list_2 = reinterpret_cast<const TensorList *>(input_2);
  return list_1->tensors_data_type() == list_2->tensors_data_type();
}
// True when both tensors have the same shape. Two tensorlists must agree on
// both the outer shape and the element shape; a tensor paired with a
// tensorlist never matches.
bool IsSameShape(const Tensor *input_1, const Tensor *input_2) {
  const bool first_is_list = (input_1->data_type() == kObjectTypeTensorType);
  const bool second_is_list = (input_2->data_type() == kObjectTypeTensorType);
  if (first_is_list != second_is_list) {
    return false;
  }
  if (!first_is_list) {
    return input_1->shape() == input_2->shape();
  }
  auto list_1 = reinterpret_cast<const TensorList *>(input_1);
  auto list_2 = reinterpret_cast<const TensorList *>(input_2);
  return list_1->shape() == list_2->shape() && list_1->element_shape() == list_2->element_shape();
}
// (Re)allocate backing storage for `tensor` according to its current shape.
// Plain tensors: free the old buffer and malloc a fresh one.
// Tensorlists: free the element tensors, then rebuild one element tensor per
// list slot from the stored element shape and element dtype.
// Returns RET_OK immediately for empty tensors (ElementsNum() <= 0).
int MallocTensorData(Tensor *tensor) {
auto ret = RET_OK;
if (tensor->data_type() != kObjectTypeTensorType) {
tensor->FreeData();
auto size = tensor->ElementsNum();
// Empty tensor: nothing to allocate.
if (size <= 0) {
return RET_OK;
}
ret = tensor->MallocData();
} else {
auto tensor_list = reinterpret_cast<TensorList *>(tensor);
ret = tensor_list->FreeTensorListData();
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "free tensor list data failed.");
// For a tensorlist the shape is {tensors().size()}, so ElementsNum()
// is the slot count of the list.
auto size = tensor->ElementsNum();
if (size <= 0) {
return RET_OK;
}
// Every slot gets the list's single element_shape().
// NOTE(review): assumes all elements share one shape — confirm for
// lists built from heterogeneous tensors.
std::vector<std::vector<int>> tensors_shape{};
for (int i = 0; i < size; ++i) {
tensors_shape.push_back(tensor_list->element_shape());
}
ret = tensor_list->MallocTensorListData(tensor_list->tensors_data_type(), tensors_shape);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "malloc tensor list data failed.");
}
return ret;
}
bool IsUnKnownDtype(const Tensor *input) {
if (input->data_type() == kTypeUnknown) {
return true;
} else if (input->data_type() == kObjectTypeTensorType) {
auto input_tensor_list = reinterpret_cast<const TensorList *>(input);
return input_tensor_list->tensors_data_type() == kTypeUnknown;
}
return false;
}
#else
bool NeedCastTensorListData(Tensor *dst_tensor, Tensor *src_tensor) { return false; }
// CONTROLFLOW_TENSORLIST_CLIP build: only plain tensor -> tensor shape
// propagation is supported; any tensorlist involvement is rejected.
int SetTensorShape(Tensor *dst, Tensor *src) {
  if (dst->data_type() == kObjectTypeTensorType || src->data_type() == kObjectTypeTensorType) {
    MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
    return RET_ERROR;
  }
  dst->set_shape(src->shape());
  dst->set_format(src->format());
  return RET_OK;
}
int CastTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensorlist, bool support_fp16) {
MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
return RET_OK;
return RET_ERROR;
}
void SetTensorListShape(Tensor *dst, Tensor *src) {
int MoveTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
return;
return RET_ERROR;
}
void MoveTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
int SetTensorListTensorData(TensorList *dst_tensor_list, TensorList *src_tensor_list) {
MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
return;
}
void SetTensorListTensorData(TensorList *dst_tensor_list, TensorList *src_tensor_list) {
MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
return;
return RET_ERROR;
}
void FreeTensorListC(TensorListC *tensorlist_c, std::shared_ptr<Allocator> allocator) {
@ -697,6 +789,47 @@ void SetTensorListTensorDataType(const TypeId &data_type, Tensor *tensor) {
return;
}
// Clipped build: tensorlists are unsupported, so dtype comparison is only
// defined for plain tensors.
bool IsSameDtype(const Tensor *input_1, const Tensor *input_2) {
  if (input_1->data_type() == kObjectTypeTensorType || input_2->data_type() == kObjectTypeTensorType) {
    MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
    return false;
  }
  return input_1->data_type() == input_2->data_type();
}
// Clipped build: tensorlists are unsupported, so shape comparison is only
// defined for plain tensors.
bool IsSameShape(const Tensor *input_1, const Tensor *input_2) {
  if (input_1->data_type() == kObjectTypeTensorType || input_2->data_type() == kObjectTypeTensorType) {
    MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
    return false;
  }
  return input_1->shape() == input_2->shape();
}
// Clipped build: allocate storage for a plain tensor only; tensorlists are
// rejected. Empty tensors (ElementsNum() <= 0) succeed without allocating.
int MallocTensorData(Tensor *tensor) {
  if (tensor->data_type() == kObjectTypeTensorType) {
    MS_LOG(ERROR) << unsupport_controlflow_tensorlist_log;
    return RET_ERROR;
  }
  tensor->FreeData();
  const auto element_count = tensor->ElementsNum();
  if (element_count <= 0) {
    return RET_OK;
  }
  return tensor->MallocData();
}
// Clipped build: no tensorlists exist, so only the tensor's own dtype tag
// can be unresolved.
bool IsUnKnownDtype(const Tensor *input) { return input->data_type() == kTypeUnknown; }
#endif
} // namespace lite
} // namespace mindspore

View File

@ -45,15 +45,13 @@ int CheckTensorsInvalid(const std::vector<Tensor *> &tensors);
int CheckGraphInputShapes(const std::vector<Tensor *> &inputs,
const std::unordered_map<Tensor *, std::vector<int>> &input_shape_map);
std::vector<mindspore::MSTensor> LiteTensorsToMSTensors(const std::vector<lite::Tensor *> &lite_tensors);
void MoveCommonTensorData(Tensor *dst_tensor, Tensor *src_tensor);
void MoveTensorData(Tensor *dst_tensor, Tensor *src_tensor);
void SetTensorData(Tensor *dst_tensor, Tensor *src_tensor);
int MoveCommonTensorData(Tensor *dst_tensor, Tensor *src_tensor);
int MoveTensorData(Tensor *dst_tensor, Tensor *src_tensor);
int SetTensorData(Tensor *dst_tensor, Tensor *src_tensor);
void SetCommonTensorData(Tensor *dst_tensor, Tensor *src_tensor);
void MoveTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensorlist);
void SetTensorListTensorData(TensorList *dst_tensor_list, TensorList *src_tensor_list);
void SetTensorShape(Tensor *dst, Tensor *src);
void SetTensorListShape(Tensor *dst, Tensor *src);
bool NeedCastTensorListData(Tensor *dst_tensor, Tensor *src_tensor);
int MoveTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensorlist);
int SetTensorListTensorData(TensorList *dst_tensor_list, TensorList *src_tensor_list);
int SetTensorShape(Tensor *dst, Tensor *src);
bool NeedCastData(Tensor *dst_tensor, Tensor *src_tensor);
int CastTensorData(Tensor *dst, Tensor *src, bool support_fp16);
int CastCommonTensorData(Tensor *dst, Tensor *src, bool support_fp16);
@ -64,6 +62,10 @@ int DecodeTensorLsit(Tensor *tensor, const int *src_data);
Tensor *CreateTensorList(const std::vector<int> &shape, const Category &src_category, const void *src_data);
int CopyTensorListTensorDataType(TensorList *dst_tensorlist, TensorList *src_tensorlist);
void SetTensorListTensorDataType(const TypeId &data_type, Tensor *tensor);
bool IsSameDtype(const Tensor *input_1, const Tensor *input_2);
bool IsUnKnownDtype(const Tensor *input);
bool IsSameShape(const Tensor *input_1, const Tensor *input_2);
int MallocTensorData(Tensor *tensor);
} // namespace lite
} // namespace mindspore

View File

@ -59,18 +59,16 @@ int LiteEntranceOpActor::InitInputData() {
}
int LiteEntranceOpActor::SetInputShape() {
auto ret = RET_OK;
for (size_t i = 0; i < inputs_data_.size(); ++i) {
auto &output_tensor = kernel_->out_tensors()[i + 1];
if (output_tensor->shape() == inputs_data_[i]->shape()) {
continue;
}
if (output_tensor->data_type() == kObjectTypeTensorType) {
SetTensorListShape(output_tensor, inputs_data_[i]);
} else {
SetTensorShape(output_tensor, inputs_data_[i]);
}
ret = SetTensorShape(output_tensor, inputs_data_[i]);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "set input shape failed.");
}
return RET_OK;
return ret;
}
void LiteEntranceOpActor::AsyncOutput(OpContext<Tensor> *context) {

View File

@ -60,17 +60,14 @@ int LiteExitOpActor::InitInputData() {
}
int LiteExitOpActor::SetInputShape() {
auto ret = RET_OK;
for (size_t i = 1; i < inputs_data_.size(); ++i) {
auto &output_tensor = kernel_->out_tensors()[i - 1];
if (output_tensor->shape() == inputs_data_[i]->shape()) {
continue;
}
if (output_tensor->data_type() == kObjectTypeTensorType) {
SetTensorListShape(output_tensor, inputs_data_[i]);
} else {
SetTensorShape(output_tensor, inputs_data_[i]);
}
ret = SetTensorShape(output_tensor, inputs_data_[i]);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "set input shape failed.");
}
return RET_OK;
}

View File

@ -22,20 +22,24 @@
namespace mindspore::kernel {
int IdentityKernel::Run() {
auto ret = lite::RET_OK;
for (size_t i = 0; i < in_tensors().size(); ++i) {
auto src_tensor = in_tensors()[i];
auto dst_tensor = out_tensors()[i];
if (NeedCastData(dst_tensor, src_tensor)) {
CastTensorData(dst_tensor, src_tensor, support_fp16_);
ret = CastTensorData(dst_tensor, src_tensor, support_fp16_);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "identity cast failed.");
continue;
}
if (src_tensor->allocator() == nullptr || src_tensor->IsGraphInput()) {
SetTensorData(dst_tensor, src_tensor);
ret = SetTensorData(dst_tensor, src_tensor);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "identity set tensor data failed.");
} else {
MoveTensorData(dst_tensor, src_tensor);
ret = MoveTensorData(dst_tensor, src_tensor);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "identity move tensor data failed.");
}
}
return lite::RET_OK;
return ret;
}
int IdentityKernel::PreProcess() {
@ -43,16 +47,22 @@ int IdentityKernel::PreProcess() {
MS_LOG(ERROR) << "output kernel in_tensors size is not same as out_tensors size.";
return lite::RET_ERROR;
}
auto ret = lite::RET_OK;
for (size_t i = 0; i < in_tensors().size(); ++i) {
auto src_tensor = in_tensors()[i];
auto dst_tensor = out_tensors()[i];
if (src_tensor->data_type() == kObjectTypeTensorType) {
SetTensorListShape(dst_tensor, src_tensor);
} else {
SetTensorShape(dst_tensor, src_tensor);
bool need_resize = false;
if (!IsSameShape(src_tensor, dst_tensor)) {
need_resize = true;
}
ret = SetTensorShape(dst_tensor, src_tensor);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "set input shape failed.");
if (need_resize) {
ret = lite::MallocTensorData(dst_tensor);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "malloc dst tensor data failed.");
}
}
return lite::RET_OK;
return ret;
}
int IdentityKernel::PostProcess() { return lite::RET_OK; }

View File

@ -109,8 +109,9 @@ int TensorListSetItemCPUKernel::Run() {
auto dst = output0_->GetTensor(i);
if (dst == nullptr) {
dst = lite::Tensor::CopyTensor(*input2_, true, ms_context_->allocator);
auto &tensors = output0_->tensors();
auto tensors = output0_->tensors();
tensors.emplace_back(dst);
output0_->set_tensors(tensors);
} else {
dst->set_data_type(input2_->data_type());
dst->set_shape(input2_->shape());
@ -133,8 +134,9 @@ int TensorListSetItemCPUKernel::Run() {
// merge move data will delete tensors
if (dst == nullptr) {
dst = lite::Tensor::CopyTensor(*src, src->data() != nullptr, ms_context_->allocator);
auto &tensors = output0_->tensors();
auto tensors = output0_->tensors();
tensors.emplace_back(dst);
output0_->set_tensors(tensors);
continue;
}

View File

@ -309,17 +309,14 @@ int LiteOpActor::CompileArrowThroughOutputTensors(
}
int LiteOpActor::SetInputShape() {
auto ret = RET_OK;
for (size_t i = 0; i < inputs_data_.size(); ++i) {
auto &input_tensor = kernel_->in_tensors()[i];
if (input_tensor->shape() == inputs_data_[i]->shape()) {
continue;
}
if (input_tensor->data_type() == kObjectTypeTensorType) {
SetTensorListShape(input_tensor, inputs_data_[i]);
} else {
SetTensorShape(input_tensor, inputs_data_[i]);
}
ret = SetTensorShape(input_tensor, inputs_data_[i]);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "set input shape failed.");
}
return RET_OK;
}
@ -335,31 +332,53 @@ int LiteOpActor::AssignInputData() {
}
if (NeedCastData(dst_tensor, src_tensor)) {
ret = CastTensorData(dst_tensor, src_tensor, support_fp16_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "cast tensor data failed.";
return ret;
}
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "CastTensorData failed.");
continue;
}
/* same data-type */
if (src_tensor->allocator() == nullptr || src_tensor->IsGraphInput()) {
// delegate graph kernel output tensor
(void)SetTensorData(dst_tensor, src_tensor);
ret = SetTensorData(dst_tensor, src_tensor);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "SetTensorData failed.");
} else {
(void)MoveTensorData(dst_tensor, src_tensor);
ret = MoveTensorData(dst_tensor, src_tensor);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "MoveTensorData failed.");
}
}
return ret;
}
int LiteOpActor::InitInputData() {
auto ret = SetInputShape();
if (ret != RET_OK) {
MS_LOG(ERROR) << "set input shape failed.";
return ret;
// A resize is required when any incoming tensor's shape differs from the
// corresponding subgraph input tensor's current shape.
bool LiteOpActor::NeedResize() {
  for (size_t idx = 0; idx < inputs_data_.size(); ++idx) {
    if (!IsSameShape(kernel_->in_tensors()[idx], inputs_data_[idx])) {
      return true;
    }
  }
  return false;
}
return AssignInputData();
// Prepare this actor's subgraph inputs for the next run:
// 1. detect a shape change BEFORE SetInputShape overwrites the input shapes,
// 2. re-malloc the subgraph input buffers when shapes changed,
// 3. copy/move/cast the incoming data in,
// 4. resize the subgraph and re-malloc node outputs when shapes changed.
// BUG FIX: the MallocNodesOutputSpace() status was discarded and the stale
// ReSize() status re-checked (with the wrong "MallocSubgraphInputs" message);
// the result is now captured and checked with the correct message.
int LiteOpActor::InitInputData() {
  bool need_resize = NeedResize();
  auto ret = SetInputShape();
  MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "Set input shape failed.");
  if (need_resize) {
    auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel_);
    MS_CHECK_FALSE_MSG(subgraph_kernel == nullptr, RET_ERROR, "Lite actor, cast kernel to subgraph kernel failed.");
    ret = subgraph_kernel->MallocSubgraphInputs();
    MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "Subgraph kernel MallocSubgraphInputs failed.");
  }
  ret = AssignInputData();
  MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "Subgraph kernel AssignInputData failed.");
  if (need_resize) {
    auto subgraph_kernel = reinterpret_cast<kernel::SubGraphKernel *>(kernel_);
    MS_CHECK_FALSE_MSG(subgraph_kernel == nullptr, RET_ERROR, "Lite actor, cast kernel to subgraph kernel failed.");
    ret = subgraph_kernel->ReSize();
    MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "Subgraph kernel Resize failed.");
    ret = subgraph_kernel->MallocNodesOutputSpace();
    MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "Subgraph kernel MallocNodesOutputSpace failed.");
  }
  return RET_OK;
}
void LiteOpActor::AsyncOutput(OpContext<Tensor> *context) {

View File

@ -76,6 +76,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {
}
protected:
virtual bool NeedResize();
virtual int SetInputShape();
virtual int InitInputData();
virtual int AssignInputData();

View File

@ -81,14 +81,9 @@ int SubGraphKernel::Execute(const KernelCallBack &before, const KernelCallBack &
int SubGraphKernel::ReSize() {
for (auto kernel : nodes_) {
if (kernel == nullptr) {
MS_LOG(ERROR) << "input kernel is nullptr!";
return RET_ERROR;
}
if (kernel->subgraph_type() != kernel::kNotSubGraph) {
MS_LOG(ERROR) << "all nodes in should be kernel";
return RET_ERROR;
}
MS_CHECK_FALSE_MSG(kernel == nullptr, RET_ERROR, "input kernel is nullptr.");
MS_CHECK_FALSE_MSG(kernel->subgraph_type() != kernel::kNotSubGraph, RET_ERROR,
"all nodes in should be kernel in subgraph kernels");
std::vector<lite::Tensor *> inputs = kernel->in_tensors();
std::vector<lite::Tensor *> outputs = kernel->out_tensors();
for (auto &output : outputs) {
@ -132,6 +127,33 @@ int SubGraphKernel::ReSize() {
}
return RET_OK;
}
int SubGraphKernel::MallocNodesOutputSpace() {
for (auto node : nodes_) {
MS_CHECK_FALSE_MSG(node == nullptr, RET_ERROR, "input kernel is nullptr.");
MS_CHECK_FALSE_MSG(node->subgraph_type() != kernel::kNotSubGraph, RET_ERROR,
"all nodes in should be kernel in subgraph kernels");
std::vector<lite::Tensor *> outputs = node->out_tensors();
for (auto &output : outputs) {
auto ret = lite::MallocTensorData(output);
if (ret != RET_OK) {
return ret;
}
}
}
return RET_OK;
}
// Re-allocate the buffers of this subgraph's external input tensors, stopping
// at the first allocation failure and propagating its status.
int SubGraphKernel::MallocSubgraphInputs() {
  for (auto input : in_tensors()) {
    auto malloc_ret = lite::MallocTensorData(input);
    if (malloc_ret != RET_OK) {
      return malloc_ret;
    }
  }
  return RET_OK;
}
void SubGraphKernel::InitInputTensorInitRefCount() {
for (auto &input : this->in_tensors()) {
int input_init_refcount = input->init_ref_count();

View File

@ -88,9 +88,12 @@ class SubGraphKernel : public KernelExec {
int Execute(const KernelCallBack &before, const KernelCallBack &after) override;
// called after Run
int ReSize() override;
virtual int MallocNodesOutputSpace();
virtual int MallocSubgraphInputs();
void InitOutTensorInitRefCount(const std::vector<KernelExec *> *mask_kernels) override;
void InitInputTensorInitRefCount();

View File

@ -73,6 +73,10 @@ int TensorList::MallocTensorListData(TypeId dtype, const std::vector<std::vector
return RET_ERROR;
}
}
if (this->shape().size() == 0) {
MS_LOG(INFO) << "tensorlist has no elements, no need malloc data.";
return RET_OK;
}
if (this->shape().size() != 1) {
MS_LOG(ERROR) << "tensorlist shape:" << this->shape().size() << " must be one-dimensional";
return RET_ERROR;

View File

@ -69,7 +69,7 @@ class TensorList : public Tensor {
void set_element_shape(const std::vector<int> &shape) { element_shape_ = shape; }
std::vector<int> &element_shape() { return element_shape_; }
std::vector<int> element_shape() const { return element_shape_; }
void set_max_elements_num(int ele_num) { max_elements_num_ = ele_num; }
@ -93,7 +93,7 @@ class TensorList : public Tensor {
TypeId tensors_data_type() const { return tensors_data_type_; }
std::vector<Tensor *> &tensors() { return tensors_; }
std::vector<Tensor *> tensors() { return tensors_; }
void set_tensors(const std::vector<Tensor *> &tensors) { this->tensors_ = tensors; }