!18760 Remove useless SyncHostToDevice for PyNative mode

Merge pull request !18760 from zyli2020/mindrt_debug
i-robot 2021-06-24 06:26:20 +00:00 committed by Gitee
commit b40bbb49af
6 changed files with 39 additions and 13 deletions
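
The gist of the change: an input tensor that already carries a device address is no longer copied again by the host queue data source actor (its address is checked against the node's device tensor instead), and in kStep (PyNative) mode PrepareDataForHostDataSourceActor now binds the node's device address to the tensor, allocates memory, and performs the single required SyncHostToDevice itself. Below is a minimal, self-contained sketch of that decision flow; the types and the PrepareInput function are simplified stand-ins for illustration, not the real MindSpore classes.

#include <cstdio>
#include <memory>

// Hypothetical stand-ins for DeviceAddress / Tensor (illustration only).
struct DeviceAddress {
  bool synced = false;
};
struct Tensor {
  std::shared_ptr<DeviceAddress> device_address;  // null while the data only lives on host
};
enum class GraphExecutionStrategy { kPipeline, kStep };

// Sketch of the intended behaviour: copy host data to device at most once, and only when needed.
void PrepareInput(Tensor *tensor, const std::shared_ptr<DeviceAddress> &node_address,
                  GraphExecutionStrategy strategy) {
  if (tensor->device_address != nullptr) {
    return;  // already on device: skip the host-to-device copy
  }
  if (strategy == GraphExecutionStrategy::kStep) {
    tensor->device_address = node_address;  // bind the node's address to the tensor
    node_address->synced = true;            // stands in for the one real SyncHostToDevice
  }
}

int main() {
  auto node_address = std::make_shared<DeviceAddress>();
  Tensor tensor;  // host-only tensor
  PrepareInput(&tensor, node_address, GraphExecutionStrategy::kStep);
  std::printf("synced=%d\n", node_address->synced ? 1 : 0);  // prints: synced=1
  return 0;
}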

@@ -247,8 +247,11 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
     auto &device_tensor = device_tensors[i];
     MS_EXCEPTION_IF_NULL(host_tensor);
     MS_EXCEPTION_IF_NULL(device_tensor);
-    if (std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address()) != nullptr) {
+    auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(host_tensor->device_address());
+    if (tensor_device_address != nullptr) {
+      if (tensor_device_address.get() != device_tensor) {
+        MS_LOG(EXCEPTION) << "The device tensor of host queue node should be equal to device address of input tensor";
+      }
       continue;
     }
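
For reference, a small standalone illustration of the pattern in the hunk above: perform the dynamic_pointer_cast once, keep the resulting shared_ptr, and compare its raw pointer against the device tensor expected for the node before skipping the copy. The types and the AlreadyBoundToNode helper are simplified placeholders, not the MindSpore classes.

#include <memory>
#include <stdexcept>

struct DeviceAddress { virtual ~DeviceAddress() = default; };  // polymorphic base (placeholder)
struct DeviceTensor : DeviceAddress {};

// Returns true when the tensor's address already is the node's device tensor,
// i.e. the host-to-device copy can be skipped; throws on a mismatched address.
bool AlreadyBoundToNode(const std::shared_ptr<DeviceAddress> &tensor_address, const DeviceTensor *node_tensor) {
  auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor_address);
  if (tensor_device_address == nullptr) {
    return false;  // plain host tensor: the copy is still required
  }
  if (tensor_device_address.get() != node_tensor) {
    throw std::runtime_error("device tensor of host queue node should equal the input tensor's device address");
  }
  return true;
}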

@@ -285,16 +285,39 @@ TensorPtr FetchInputTensor(const GraphCompilerInfo &graph_compiler_info, size_t
 void PrepareDataForHostDataSourceActor(const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
                                        const AnfNodePtr &node, const TensorPtr &tensor,
-                                       std::vector<TensorPtr> *host_tensors) {
+                                       std::vector<TensorPtr> *host_tensors,
+                                       const DeviceContext *device_context = nullptr,
+                                       GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline) {
   MS_EXCEPTION_IF_NULL(tensor);
 
   // Fill the host tensors for non weighted parameters.
   const auto &iter = data_node_position_map.find(node);
-  if (iter != data_node_position_map.end()) {
-    (*host_tensors)[iter->second] = tensor;
-
-    auto device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
-    if (device_address != nullptr) {
-      AnfAlgo::SetOutputAddr(device_address, 0, node.get());
+  if (iter == data_node_position_map.end()) {
+    return;
+  }
+
+  (*host_tensors)[iter->second] = tensor;
+
+  auto device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
+  if (device_address != nullptr) {
+    AnfAlgo::SetOutputAddr(device_address, 0, node.get());
+    return;
+  }
+
+  if (strategy == GraphExecutionStrategy::kStep) {
+    auto node_device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
+    MS_EXCEPTION_IF_NULL(node_device_address);
+    tensor->set_device_address(node_device_address);
+    UpdateRefCount(node_device_address.get(), true);
+
+    MS_EXCEPTION_IF_NULL(device_context);
+    if (!device_context->AllocateMemory(node_device_address.get(), node_device_address->GetSize())) {
+      MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope();
+    }
+
+    if (!node_device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0),
+                                               LongToSize(tensor->data().nbytes()), tensor->data_type(),
+                                               tensor->data_c(), tensor->device_info().host_format_)) {
+      MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
    }
   }
 }
@@ -492,7 +515,7 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
                                  strategy)) {
       MS_EXCEPTION_IF_NULL(host_data_source_actor);
       PrepareDataForHostDataSourceActor(host_data_source_actor->data_node_position_map_, input_node, input_tensor,
-                                        &host_tensors);
+                                        &host_tensors, device_context, strategy);
     }
   }
 }
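
Because the two new parameters default to nullptr and kPipeline, call sites that predate this change keep compiling unchanged; the PrepareRun call above passes device_context and strategy explicitly so the kStep branch can allocate memory and sync. A hypothetical sketch of that defaulting pattern (names simplified, not the real signatures):

#include <cassert>
#include <cstdio>

enum class GraphExecutionStrategy { kPipeline, kStep };
struct DeviceContext {};  // placeholder

// Same defaulting idea as PrepareDataForHostDataSourceActor above: only the
// kStep (PyNative) path needs a device context for allocation and syncing.
void PrepareData(const DeviceContext *device_context = nullptr,
                 GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline) {
  if (strategy == GraphExecutionStrategy::kStep) {
    assert(device_context != nullptr);  // kStep requires a device context
    std::printf("kStep: allocate device memory and SyncHostToDevice once\n");
    return;
  }
  std::printf("kPipeline: leave the copy to the host queue data source actor\n");
}

int main() {
  DeviceContext ctx;
  PrepareData();                                     // legacy-style call, defaults apply
  PrepareData(&ctx, GraphExecutionStrategy::kStep);  // PyNative single-op call
  return 0;
}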

@@ -54,7 +54,7 @@ DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContex
   return device_context.get();
 }
 
-void DeviceContextManager::UpdataDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key) {
+void DeviceContextManager::UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key) {
   std::string old_key_str = old_key.ToString();
   std::string new_key_str = new_key.ToString();

@@ -37,7 +37,7 @@ class DeviceContextManager {
   }
   void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator);
   DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key);
-  void UpdataDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key);
+  void UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key);
   void ClearDeviceContexts();
 
  private:

@@ -65,7 +65,7 @@ bool GPUDeviceContext::Initialize() {
     MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
     device_context_key_.device_id_ = IntToUint((*get_local_rank_funcptr)());
 
-    DeviceContextManager::GetInstance().UpdataDeviceContextKey(old_key, device_context_key_);
+    DeviceContextManager::GetInstance().UpdateDeviceContextKey(old_key, device_context_key_);
     auto ms_context = MsContext::GetInstance();
     MS_EXCEPTION_IF_NULL(ms_context);

@@ -253,7 +253,7 @@ void MsBackend::SetDebugger() { target_sess_->SetDebugger(); }
 #endif
 
 MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id)
-    : Backend(backend_name), device_name_(device_name), device_id_(device_id) {
+    : Backend(backend_name), device_name_(device_name) {
   root_graph_ = nullptr;
   auto ms_context = MsContext::GetInstance();
   const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);