forked from mindspore-Ecosystem/mindspore
!20075 [MSLITE] isolate input tensor set replace copy
Merge pull request !20075 from ling/r1.3
This commit is contained in:
commit
1dcfee0ef0
|
@ -35,7 +35,7 @@ void LiteOpActor::RunOpData(OpData<lite::Tensor> *inputs, OpContext<lite::Tensor
|
|||
return;
|
||||
}
|
||||
|
||||
auto ret = SetInputData();
|
||||
auto ret = InitInputData();
|
||||
if (ret != RET_OK) {
|
||||
input_op_datas_.erase(op_uuid);
|
||||
context->SetFailed(ret);
|
||||
|
@ -249,18 +249,15 @@ void LiteOpActor::MoveTensorInputData(Tensor *dst_tensor, Tensor *src_tensor) {
|
|||
dst_tensor->FreeData();
|
||||
dst_tensor->ResetRefCount();
|
||||
dst_tensor->set_allocator(src_tensor->allocator());
|
||||
if (src_tensor->allocator() != nullptr) {
|
||||
|
||||
src_tensor->allocator()->IncRefCount(src_tensor->data(), dst_tensor->ref_count());
|
||||
}
|
||||
|
||||
if (src_tensor->data_c() != nullptr) {
|
||||
dst_tensor->set_data(src_tensor->MutableData()); /* using MutableData to sync GPU data */
|
||||
}
|
||||
|
||||
dst_tensor->set_own_data(src_tensor->own_data());
|
||||
if (src_tensor->IsConst() || src_tensor->IsGraphInput()) {
|
||||
dst_tensor->set_own_data(false);
|
||||
} else {
|
||||
src_tensor->DecRefCount();
|
||||
}
|
||||
}
|
||||
|
||||
void LiteOpActor::MoveTensorListInputData(TensorList *dst_tensorlist, TensorList *src_tensorlist) {
|
||||
|
@ -307,6 +304,7 @@ void LiteOpActor::MoveInputData(Tensor *dst_tensor, Tensor *src_tensor) {
|
|||
MS_LOG(INFO) << "no need to move.";
|
||||
return;
|
||||
}
|
||||
MS_ASSERT(src_tensor->allocator() != nullptr);
|
||||
|
||||
if (src_tensor->data_type() == kObjectTypeTensorType) {
|
||||
MoveTensorListInputData(reinterpret_cast<TensorList *>(dst_tensor), reinterpret_cast<TensorList *>(src_tensor));
|
||||
|
@ -316,10 +314,9 @@ void LiteOpActor::MoveInputData(Tensor *dst_tensor, Tensor *src_tensor) {
|
|||
return;
|
||||
}
|
||||
|
||||
void LiteOpActor::CopyInputData(Tensor *dst_tensor, Tensor *src_tensor) {
|
||||
dst_tensor->ResetRefCount();
|
||||
dst_tensor->MallocData();
|
||||
memcpy(dst_tensor->data(), src_tensor->data(), src_tensor->Size());
|
||||
void LiteOpActor::SetInputData(Tensor *dst_tensor, Tensor *src_tensor) {
|
||||
dst_tensor->set_data(src_tensor->data());
|
||||
dst_tensor->set_own_data(false);
|
||||
}
|
||||
|
||||
int LiteOpActor::CastInputData(Tensor *dst, Tensor *src) {
|
||||
|
@ -427,7 +424,7 @@ void LiteOpActor::SetInputShape() {
|
|||
}
|
||||
}
|
||||
|
||||
int LiteOpActor::SetInputData() {
|
||||
int LiteOpActor::InitInputData() {
|
||||
SetInputShape();
|
||||
|
||||
for (size_t i = 0; i < inputs_data_.size(); ++i) {
|
||||
|
@ -440,9 +437,13 @@ int LiteOpActor::SetInputData() {
|
|||
|
||||
if (NeedCastData(dst_tensor, src_tensor)) {
|
||||
CastInputData(dst_tensor, src_tensor);
|
||||
} else if (src_tensor->allocator() == nullptr && !(src_tensor->IsConst()) && !(src_tensor->IsGraphInput())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
/* same data-type */
|
||||
if (src_tensor->allocator() == nullptr) {
|
||||
// delegate graph kernel output tensor
|
||||
CopyInputData(dst_tensor, src_tensor);
|
||||
SetInputData(dst_tensor, src_tensor);
|
||||
} else {
|
||||
MoveInputData(dst_tensor, src_tensor);
|
||||
}
|
||||
|
@ -684,7 +685,7 @@ void LiteSwitchOpActor::RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *con
|
|||
return;
|
||||
}
|
||||
|
||||
int ret = SetInputData();
|
||||
int ret = InitInputData();
|
||||
if (ret != RET_OK) {
|
||||
input_op_datas_.erase(op_uuid);
|
||||
context->SetFailed(ret);
|
||||
|
|
|
@ -78,7 +78,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {
|
|||
|
||||
protected:
|
||||
void SetInputShape();
|
||||
int SetInputData();
|
||||
int InitInputData();
|
||||
void SetOutputData(OpContext<Tensor> *context);
|
||||
void AsyncOutput(OpContext<Tensor> *context);
|
||||
int CompileArrowThroughPartialCall();
|
||||
|
@ -97,7 +97,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {
|
|||
void MoveTensorInputData(Tensor *dst_tensor, Tensor *src_tensor);
|
||||
void MoveTensorListInputData(TensorList *dst_tensor, TensorList *src_tensor);
|
||||
void MoveInputData(Tensor *dst_tensor, Tensor *src_tensor);
|
||||
void CopyInputData(Tensor *dst_tensor, Tensor *src_tensor);
|
||||
void SetInputData(Tensor *dst_tensor, Tensor *src_tensor);
|
||||
int CastInputData(Tensor *dst_tensor, Tensor *src_tensor);
|
||||
bool NeedCastData(Tensor *dst_tensor, Tensor *src_tensor);
|
||||
int CastTensorInputData(Tensor *dst_tensor, Tensor *src_tensor);
|
||||
|
|
Loading…
Reference in New Issue