forked from mindspore-Ecosystem/mindspore
!21629 Modifying Tensor Device Address Setting Conditions
Merge pull request !21629 from hwjiaorui/tensor-async
This commit is contained in:
commit
52ff091534
|
@ -381,7 +381,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
|
|||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||
}
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
|
||||
AnfAlgo::IsParameterWeight(input_param)) {
|
||||
AnfAlgo::IsParameterWeight(input_param) || kernel_graph->IsUpdatedParameter(input_param)) {
|
||||
tensor->set_device_address(device_address);
|
||||
}
|
||||
if (kernel_graph->IsUpdatedParameter(input_param)) {
|
||||
|
|
|
@ -1362,10 +1362,8 @@ void KernelGraph::SetOptimizerFlag() {
|
|||
continue;
|
||||
}
|
||||
auto param = real_node->cast<ParameterPtr>();
|
||||
if (AnfAlgo::IsParameterWeight(param)) {
|
||||
has_optimizer_ = true;
|
||||
(void)updated_parameters_.insert(param);
|
||||
}
|
||||
has_optimizer_ = true;
|
||||
(void)updated_parameters_.insert(param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -446,6 +446,26 @@ void UpdateGraphAquireGilAttr(const NotNull<KernelGraphPtr> &root_graph) {
|
|||
}
|
||||
return;
|
||||
}
|
||||
|
||||
bool NoPartialInPartialGraph(const AnfNodePtr &partial_node) {
|
||||
MS_EXCEPTION_IF_NULL(partial_node);
|
||||
auto partial_cnode = partial_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_cnode);
|
||||
auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
|
||||
MS_EXCEPTION_IF_NULL(partial_graph);
|
||||
auto graph_nodes = TopoSort(partial_graph->get_return());
|
||||
for (auto &node : graph_nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
|
||||
return false;
|
||||
}
|
||||
if (node->isa<CNode>() && IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(kAnfPrimitiveIndex))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GraphId SessionBasic::graph_sum_ = 0;
|
||||
|
@ -1958,7 +1978,8 @@ void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, cons
|
|||
if (internal_output) {
|
||||
auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
|
||||
for (auto &user : users) {
|
||||
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice) {
|
||||
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
|
||||
NoPartialInPartialGraph(user)) {
|
||||
auto partial_target = AddPartialParametersMap(user);
|
||||
if (partial_target != kNoTarget && partial_target != kernel_target) {
|
||||
unique_target = false;
|
||||
|
|
Loading…
Reference in New Issue