forked from mindspore-Ecosystem/mindspore
fix hete_mix_ctrlflow
This commit is contained in:
parent
005bc7c380
commit
a112376b36
|
@ -222,6 +222,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
|
|||
tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
|
||||
}
|
||||
tensor->set_device_address(address);
|
||||
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
|
||||
if (bound_addresses_.find(address) == bound_addresses_.end()) {
|
||||
if (infer_type_id != device_type_id) {
|
||||
size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
|
||||
|
@ -230,10 +231,11 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
|
|||
address->ptr_ = static_cast<CPUMemoryManager *>(mem_manager_.get())->StaticMemMalloc(tensor_size);
|
||||
address->size_ = tensor_size;
|
||||
address->type_id_ = device_type_id;
|
||||
} else {
|
||||
tensor->set_sync_status(kNoNeedSync);
|
||||
}
|
||||
(void)bound_addresses_.insert(address);
|
||||
}
|
||||
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
|
||||
session::KernelWithIndex node_index(node, index);
|
||||
tensor->SetNeedWait(true);
|
||||
tensor->SetIsGraphOutput();
|
||||
|
@ -369,7 +371,7 @@ void CPUKernelRuntime::BindOutputTensorAddressPtr(const VectorRef *outputs) {
|
|||
continue;
|
||||
}
|
||||
auto address_ptr = std::dynamic_pointer_cast<device::DeviceAddress>(address);
|
||||
if (address_ptr->type_id_ == tensor->data_type_c()) {
|
||||
if (address_ptr->type_id_ == tensor->data_type_c() && tensor->sync_status() == kNoNeedSync) {
|
||||
address_ptr->ptr_ = tensor->data_c();
|
||||
}
|
||||
address_ptr->ref_count_ = INIT_NODE_REF;
|
||||
|
|
Loading…
Reference in New Issue