fix hete_mix_ctrlflow

This commit is contained in:
baihuawei 2021-11-10 16:20:03 +08:00
parent 005bc7c380
commit a112376b36
1 changed file with 4 additions and 2 deletions

View File

@@ -222,6 +222,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
}
tensor->set_device_address(address);
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
if (bound_addresses_.find(address) == bound_addresses_.end()) {
if (infer_type_id != device_type_id) {
size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
@@ -230,10 +231,11 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
address->ptr_ = static_cast<CPUMemoryManager *>(mem_manager_.get())->StaticMemMalloc(tensor_size);
address->size_ = tensor_size;
address->type_id_ = device_type_id;
} else {
tensor->set_sync_status(kNoNeedSync);
}
(void)bound_addresses_.insert(address);
}
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
session::KernelWithIndex node_index(node, index);
tensor->SetNeedWait(true);
tensor->SetIsGraphOutput();
@@ -369,7 +371,7 @@ void CPUKernelRuntime::BindOutputTensorAddressPtr(const VectorRef *outputs) {
continue;
}
auto address_ptr = std::dynamic_pointer_cast<device::DeviceAddress>(address);
if (address_ptr->type_id_ == tensor->data_type_c()) {
if (address_ptr->type_id_ == tensor->data_type_c() && tensor->sync_status() == kNoNeedSync) {
address_ptr->ptr_ = tensor->data_c();
}
address_ptr->ref_count_ = INIT_NODE_REF;