!18760 Remove useless SyncHostToDevice for PyNative mode
Merge pull request !18760 from zyli2020/mindrt_debug
This commit is contained in:
commit
b40bbb49af
|
@ -247,8 +247,11 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
|
|||
auto &device_tensor = device_tensors[i];
|
||||
MS_EXCEPTION_IF_NULL(host_tensor);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
|
||||
if (std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address()) != nullptr) {
|
||||
auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address());
|
||||
if (tensor_device_address != nullptr) {
|
||||
if (tensor_device_address.get() != device_tensor) {
|
||||
MS_LOG(EXCEPTION) << "The device tensor of host queue node should be equal to device address of input tensor";
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -285,16 +285,39 @@ TensorPtr FetchInputTensor(const GraphCompilerInfo &graph_compiler_info, size_t
|
|||
|
||||
void PrepareDataForHostDataSourceActor(const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
|
||||
const AnfNodePtr &node, const TensorPtr &tensor,
|
||||
std::vector<TensorPtr> *host_tensors) {
|
||||
std::vector<TensorPtr> *host_tensors,
|
||||
const DeviceContext *device_context = nullptr,
|
||||
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline) {
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
|
||||
// Fill the host tensors for non weighted parameters.
|
||||
const auto &iter = data_node_position_map.find(node);
|
||||
if (iter != data_node_position_map.end()) {
|
||||
(*host_tensors)[iter->second] = tensor;
|
||||
auto device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
|
||||
if (device_address != nullptr) {
|
||||
AnfAlgo::SetOutputAddr(device_address, 0, node.get());
|
||||
if (iter == data_node_position_map.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
(*host_tensors)[iter->second] = tensor;
|
||||
auto device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
|
||||
if (device_address != nullptr) {
|
||||
AnfAlgo::SetOutputAddr(device_address, 0, node.get());
|
||||
return;
|
||||
}
|
||||
|
||||
if (strategy == GraphExecutionStrategy::kStep) {
|
||||
auto node_device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
|
||||
MS_EXCEPTION_IF_NULL(node_device_address);
|
||||
tensor->set_device_address(node_device_address);
|
||||
UpdateRefCount(node_device_address.get(), true);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if (!device_context->AllocateMemory(node_device_address.get(), node_device_address->GetSize())) {
|
||||
MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope();
|
||||
}
|
||||
|
||||
if (!node_device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||
tensor->data_c(), tensor->device_info().host_format_)) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -492,7 +515,7 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
|
|||
strategy)) {
|
||||
MS_EXCEPTION_IF_NULL(host_data_source_actor);
|
||||
PrepareDataForHostDataSourceActor(host_data_source_actor->data_node_position_map_, input_node, input_tensor,
|
||||
&host_tensors);
|
||||
&host_tensors, device_context, strategy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContex
|
|||
return device_context.get();
|
||||
}
|
||||
|
||||
void DeviceContextManager::UpdataDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key) {
|
||||
void DeviceContextManager::UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key) {
|
||||
std::string old_key_str = old_key.ToString();
|
||||
std::string new_key_str = new_key.ToString();
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class DeviceContextManager {
|
|||
}
|
||||
void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator);
|
||||
DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key);
|
||||
void UpdataDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key);
|
||||
void UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key);
|
||||
void ClearDeviceContexts();
|
||||
|
||||
private:
|
||||
|
|
|
@ -65,7 +65,7 @@ bool GPUDeviceContext::Initialize() {
|
|||
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
|
||||
device_context_key_.device_id_ = IntToUint((*get_local_rank_funcptr)());
|
||||
|
||||
DeviceContextManager::GetInstance().UpdataDeviceContextKey(old_key, device_context_key_);
|
||||
DeviceContextManager::GetInstance().UpdateDeviceContextKey(old_key, device_context_key_);
|
||||
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
|
|
@ -253,7 +253,7 @@ void MsBackend::SetDebugger() { target_sess_->SetDebugger(); }
|
|||
#endif
|
||||
|
||||
MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id)
|
||||
: Backend(backend_name), device_name_(device_name), device_id_(device_id) {
|
||||
: Backend(backend_name), device_name_(device_name) {
|
||||
root_graph_ = nullptr;
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
|
||||
|
|
Loading…
Reference in New Issue