!21629 Modifying Tensor Device Address Setting Conditions

Merge pull request !21629 from hwjiaorui/tensor-async
This commit is contained in:
i-robot 2021-08-12 01:34:45 +00:00 committed by Gitee
commit 52ff091534
3 changed files with 25 additions and 6 deletions

View File

@ -381,7 +381,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_param)) {
AnfAlgo::IsParameterWeight(input_param) || kernel_graph->IsUpdatedParameter(input_param)) {
tensor->set_device_address(device_address);
}
if (kernel_graph->IsUpdatedParameter(input_param)) {

View File

@ -1362,10 +1362,8 @@ void KernelGraph::SetOptimizerFlag() {
continue;
}
auto param = real_node->cast<ParameterPtr>();
if (AnfAlgo::IsParameterWeight(param)) {
has_optimizer_ = true;
(void)updated_parameters_.insert(param);
}
has_optimizer_ = true;
(void)updated_parameters_.insert(param);
}
}
}

View File

@ -446,6 +446,26 @@ void UpdateGraphAquireGilAttr(const NotNull<KernelGraphPtr> &root_graph) {
}
return;
}
bool NoPartialInPartialGraph(const AnfNodePtr &partial_node) {
MS_EXCEPTION_IF_NULL(partial_node);
auto partial_cnode = partial_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
MS_EXCEPTION_IF_NULL(partial_graph);
auto graph_nodes = TopoSort(partial_graph->get_return());
for (auto &node : graph_nodes) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
return false;
}
if (node->isa<CNode>() && IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(kAnfPrimitiveIndex))) {
return false;
}
}
return true;
}
} // namespace
GraphId SessionBasic::graph_sum_ = 0;
@ -1958,7 +1978,8 @@ void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, cons
if (internal_output) {
auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
for (auto &user : users) {
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice) {
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
NoPartialInPartialGraph(user)) {
auto partial_target = AddPartialParametersMap(user);
if (partial_target != kNoTarget && partial_target != kernel_target) {
unique_target = false;