Fix CPU output tensor sync

This commit is contained in:
baihuawei 2021-04-12 09:14:18 +08:00
parent 7092868a6c
commit 570ce2b811
2 changed files with 5 additions and 7 deletions

View File

@ -64,7 +64,7 @@ HcclKernelFactory &HcclKernelFactory::Get() {
return _this;
}
HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0) {}
HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REDUCE_SUM), root_id_(0), receive_type_(0) {}
HcclKernel::~HcclKernel() {
hccl_kernel_input_shape_list_.clear();

View File

@ -196,20 +196,18 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
}
}
tensor->set_device_address(address);
if (bound_addresses_.find(address) != bound_addresses_.end()) {
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
} else {
if (bound_addresses_.find(address) == bound_addresses_.end()) {
if (infer_type_id != device_type_id) {
size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
ShapeVector data_shape = tensor->shape();
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
address->ptr_ = static_cast<CPUMemoryManager *>(mem_manager_.get())->StaticMemMalloc(tensor_size);
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
} else {
tensor->set_sync_status(kNoNeedSync);
address->ptr_ = nullptr;
}
(void)bound_addresses_.insert(address);
}
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
session::KernelWithIndex node_index(node, index);
tensor->SetNeedWait(true);
tensor->SetIsGraphOutput();
@ -329,7 +327,7 @@ void CPUKernelRuntime::BindOutputTensorAddressPtr(const VectorRef *outputs) {
continue;
}
auto address_ptr = std::dynamic_pointer_cast<device::DeviceAddress>(address);
if (tensor->sync_status() == kNoNeedSync) {
if (address_ptr->ptr_ == nullptr || tensor->sync_status() == kNoNeedSync) {
address_ptr->ptr_ = tensor->data_c();
}
address_ptr->ref_count_ = INIT_NODE_REF;