!20075 [MSLITE] isolate input tensor set replace copy

Merge pull request !20075 from ling/r1.3
This commit is contained in:
i-robot 2021-07-13 09:36:46 +00:00 committed by Gitee
commit 1dcfee0ef0
2 changed files with 20 additions and 19 deletions

View File

@ -35,7 +35,7 @@ void LiteOpActor::RunOpData(OpData<lite::Tensor> *inputs, OpContext<lite::Tensor
return;
}
auto ret = SetInputData();
auto ret = InitInputData();
if (ret != RET_OK) {
input_op_datas_.erase(op_uuid);
context->SetFailed(ret);
@ -249,19 +249,16 @@ void LiteOpActor::MoveTensorInputData(Tensor *dst_tensor, Tensor *src_tensor) {
dst_tensor->FreeData();
dst_tensor->ResetRefCount();
dst_tensor->set_allocator(src_tensor->allocator());
if (src_tensor->allocator() != nullptr) {
src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
}
if (src_tensor->data_c() != nullptr) {
dst_tensor->set_data(src_tensor->MutableData()); /* using MutableData to sync GPU data */
}
dst_tensor->set_own_data(src_tensor->own_data());
if (src_tensor->IsConst() || src_tensor->IsGraphInput()) {
dst_tensor->set_own_data(false);
} else {
src_tensor->DecRefCount();
}
}
void LiteOpActor::MoveTensorListInputData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
MS_ASSERT(src_tensorlist != nullptr);
@ -307,6 +304,7 @@ void LiteOpActor::MoveInputData(Tensor *dst_tensor, Tensor *src_tensor) {
MS_LOG(INFO) << "no need to move.";
return;
}
MS_ASSERT(src_tensor->allocator() != nullptr);
if (src_tensor->data_type() == kObjectTypeTensorType) {
MoveTensorListInputData(reinterpret_cast<TensorList *>(dst_tensor), reinterpret_cast<TensorList *>(src_tensor));
@ -316,10 +314,9 @@ void LiteOpActor::MoveInputData(Tensor *dst_tensor, Tensor *src_tensor) {
return;
}
void LiteOpActor::CopyInputData(Tensor *dst_tensor, Tensor *src_tensor) {
dst_tensor->ResetRefCount();
dst_tensor->MallocData();
memcpy(dst_tensor->data(), src_tensor->data(), src_tensor->Size());
void LiteOpActor::SetInputData(Tensor *dst_tensor, Tensor *src_tensor) {
dst_tensor->set_data(src_tensor->data());
dst_tensor->set_own_data(false);
}
int LiteOpActor::CastInputData(Tensor *dst, Tensor *src) {
@ -427,7 +424,7 @@ void LiteOpActor::SetInputShape() {
}
}
int LiteOpActor::SetInputData() {
int LiteOpActor::InitInputData() {
SetInputShape();
for (size_t i = 0; i < inputs_data_.size(); ++i) {
@ -440,9 +437,13 @@ int LiteOpActor::SetInputData() {
if (NeedCastData(dst_tensor, src_tensor)) {
CastInputData(dst_tensor, src_tensor);
} else if (src_tensor->allocator() == nullptr && !(src_tensor->IsConst()) && !(src_tensor->IsGraphInput())) {
continue;
}
/* same data-type */
if (src_tensor->allocator() == nullptr) {
// delegate graph kernel output tensor
CopyInputData(dst_tensor, src_tensor);
SetInputData(dst_tensor, src_tensor);
} else {
MoveInputData(dst_tensor, src_tensor);
}
@ -684,7 +685,7 @@ void LiteSwitchOpActor::RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *con
return;
}
int ret = SetInputData();
int ret = InitInputData();
if (ret != RET_OK) {
input_op_datas_.erase(op_uuid);
context->SetFailed(ret);

View File

@ -78,7 +78,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {
protected:
void SetInputShape();
int SetInputData();
int InitInputData();
void SetOutputData(OpContext<Tensor> *context);
void AsyncOutput(OpContext<Tensor> *context);
int CompileArrowThroughPartialCall();
@ -97,7 +97,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {
void MoveTensorInputData(Tensor *dst_tensor, Tensor *src_tensor);
void MoveTensorListInputData(TensorList *dst_tensor, TensorList *src_tensor);
void MoveInputData(Tensor *dst_tensor, Tensor *src_tensor);
void CopyInputData(Tensor *dst_tensor, Tensor *src_tensor);
void SetInputData(Tensor *dst_tensor, Tensor *src_tensor);
int CastInputData(Tensor *dst_tensor, Tensor *src_tensor);
bool NeedCastData(Tensor *dst_tensor, Tensor *src_tensor);
int CastTensorInputData(Tensor *dst_tensor, Tensor *src_tensor);