refine ps mode consistence check
This commit is contained in:
parent
5e29413428
commit
cd12f98c29
|
@ -512,6 +512,47 @@ void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
|
||||
// Get all users of this node
|
||||
void GetNodeUsedList(const FuncGraphPtr &kernel_graph, const AnfNodePtr &node,
|
||||
std::vector<AnfNodePtr> *node_users_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto manager = kernel_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
auto new_manager = MakeManager({kernel_graph});
|
||||
MS_EXCEPTION_IF_NULL(new_manager);
|
||||
new_manager->AddFuncGraph(kernel_graph);
|
||||
kernel_graph->set_manager(new_manager);
|
||||
manager = new_manager;
|
||||
}
|
||||
|
||||
auto iter = manager->node_users().find(node);
|
||||
if (iter == manager->node_users().end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto node_users = iter->second;
|
||||
for (const auto &node_user : node_users) {
|
||||
if (AnfAlgo::GetCNodeName(node_user.first) == prim::kPrimLoad->name()) {
|
||||
GetNodeUsedList(kernel_graph, node_user.first, node_users_list);
|
||||
} else {
|
||||
node_users_list->push_back(node_user.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check whether the Parameter initialized in server is used by the operator executed on the device side.
|
||||
bool UseParamInitInServer(const FuncGraphPtr &kernel_graph, const AnfNodePtr ¶m_node) {
|
||||
std::vector<AnfNodePtr> node_users_list;
|
||||
GetNodeUsedList(kernel_graph, param_node, &node_users_list);
|
||||
|
||||
// Check if there is real CNode among all users of the node.
|
||||
return std::any_of(node_users_list.begin(), node_users_list.end(),
|
||||
[](const AnfNodePtr &node) { return AnfUtils::IsRealKernel(node); });
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
GraphId SessionBasic::graph_sum_ = 0;
|
||||
|
@ -2804,7 +2845,10 @@ void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
|
|||
if (!ps::PSContext::instance()->is_worker()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check whether the Parameter initialized in server is used by the operator executed on the device side.
|
||||
CheckPSModeConsistence(kernel_graph);
|
||||
|
||||
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
||||
if (!ps::ps_cache_instance.initialized_ps_cache()) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
|
@ -2851,8 +2895,11 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) co
|
|||
MS_EXCEPTION_IF_NULL(pk_node);
|
||||
auto param_info_ptr = pk_node->param_info();
|
||||
const std::string ¶m_name = pk_node->fullname_with_scope();
|
||||
|
||||
// If the Parameter is initialized on the server, and the user of the Parameter contains real CNode which executes
|
||||
// in device, an error message will be reported, and it is allowed to be used only by the side effect operator.
|
||||
if (param_info_ptr != nullptr && param_info_ptr->init_in_server() &&
|
||||
!ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||
UseParamInitInServer(kernel_graph, input_node) && !ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||
MS_LOG(EXCEPTION) << "Can not initialize the parameter[" << param_name
|
||||
<< "] in server, this parameter is used by kernel which executes in device";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue