forked from mindspore-Ecosystem/mindspore
Fix CPU output tensor sync
This commit is contained in:
parent
7092868a6c
commit
570ce2b811
|
@ -64,7 +64,7 @@ HcclKernelFactory &HcclKernelFactory::Get() {
|
|||
return _this;
|
||||
}
|
||||
|
||||
// Default constructor: hccl_count_ and root_id_ start at 0, and op_type_
// starts as HCCL_REDUCE_SUM.
HcclKernel::HcclKernel()
    : hccl_count_(0),
      op_type_(HCCL_REDUCE_SUM),
      root_id_(0) {}
|
||||
// Default constructor: hccl_count_, root_id_ and receive_type_ start at 0,
// and op_type_ starts as HCCL_REDUCE_SUM.
HcclKernel::HcclKernel()
    : hccl_count_(0),
      op_type_(HCCL_REDUCE_SUM),
      root_id_(0),
      receive_type_(0) {}
|
||||
|
||||
HcclKernel::~HcclKernel() {
|
||||
hccl_kernel_input_shape_list_.clear();
|
||||
|
|
|
@ -196,20 +196,18 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
|
|||
}
|
||||
}
|
||||
tensor->set_device_address(address);
|
||||
if (bound_addresses_.find(address) != bound_addresses_.end()) {
|
||||
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
|
||||
} else {
|
||||
if (bound_addresses_.find(address) == bound_addresses_.end()) {
|
||||
if (infer_type_id != device_type_id) {
|
||||
size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
|
||||
ShapeVector data_shape = tensor->shape();
|
||||
size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
|
||||
address->ptr_ = static_cast<CPUMemoryManager *>(mem_manager_.get())->StaticMemMalloc(tensor_size);
|
||||
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
|
||||
} else {
|
||||
tensor->set_sync_status(kNoNeedSync);
|
||||
address->ptr_ = nullptr;
|
||||
}
|
||||
(void)bound_addresses_.insert(address);
|
||||
}
|
||||
tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
|
||||
session::KernelWithIndex node_index(node, index);
|
||||
tensor->SetNeedWait(true);
|
||||
tensor->SetIsGraphOutput();
|
||||
|
@ -329,7 +327,7 @@ void CPUKernelRuntime::BindOutputTensorAddressPtr(const VectorRef *outputs) {
|
|||
continue;
|
||||
}
|
||||
auto address_ptr = std::dynamic_pointer_cast<device::DeviceAddress>(address);
|
||||
if (tensor->sync_status() == kNoNeedSync) {
|
||||
if (address_ptr->ptr_ == nullptr || tensor->sync_status() == kNoNeedSync) {
|
||||
address_ptr->ptr_ = tensor->data_c();
|
||||
}
|
||||
address_ptr->ref_count_ = INIT_NODE_REF;
|
||||
|
|
Loading…
Reference in New Issue