Refine PS mode consistency check

This commit is contained in:
lizhenyu 2022-02-10 20:01:23 +08:00
parent 5e29413428
commit cd12f98c29
1 changed file with 48 additions and 1 deletion

View File

@ -512,6 +512,47 @@ void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
}
}
}
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
// Collect the users of `node` in `kernel_graph` and append them to `node_users_list`.
// A user whose CNode name equals prim::kPrimLoad->name() is not recorded itself; the function recurses
// into it and records that node's users instead.
// User information is read from the graph manager's node_users() table. If the graph has no manager yet,
// a new one is created with MakeManager and attached to the graph.
// All three arguments are validated with MS_EXCEPTION_IF_NULL. Entries already in `node_users_list` are kept.
void GetNodeUsedList(const FuncGraphPtr &kernel_graph, const AnfNodePtr &node,
                     std::vector<AnfNodePtr> *node_users_list) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(node);
  // The output pointer is dereferenced below, so validate it like the other arguments.
  MS_EXCEPTION_IF_NULL(node_users_list);
  auto manager = kernel_graph->manager();
  if (manager == nullptr) {
    auto new_manager = MakeManager({kernel_graph});
    MS_EXCEPTION_IF_NULL(new_manager);
    new_manager->AddFuncGraph(kernel_graph);
    kernel_graph->set_manager(new_manager);
    manager = new_manager;
  }
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    // The manager has no user entry for this node: nothing to collect.
    return;
  }
  // Bind by reference: the user set is only read here, so copying it on every (recursive) call is unnecessary.
  const auto &node_users = iter->second;
  for (const auto &node_user : node_users) {
    if (AnfAlgo::GetCNodeName(node_user.first) == prim::kPrimLoad->name()) {
      // Look through the Load node and collect its users instead of the Load itself.
      GetNodeUsedList(kernel_graph, node_user.first, node_users_list);
    } else {
      node_users_list->push_back(node_user.first);
    }
  }
}
// Check whether the Parameter initialized in server is used by the operator executed on the device side.
bool UseParamInitInServer(const FuncGraphPtr &kernel_graph, const AnfNodePtr &param_node) {
std::vector<AnfNodePtr> node_users_list;
GetNodeUsedList(kernel_graph, param_node, &node_users_list);
// Check if there is real CNode among all users of the node.
return std::any_of(node_users_list.begin(), node_users_list.end(),
[](const AnfNodePtr &node) { return AnfUtils::IsRealKernel(node); });
}
#endif
} // namespace
GraphId SessionBasic::graph_sum_ = 0;
@ -2804,7 +2845,10 @@ void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
if (!ps::PSContext::instance()->is_worker()) {
return;
}
// Check whether the Parameter initialized in server is used by the operator executed on the device side.
CheckPSModeConsistence(kernel_graph);
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
if (!ps::ps_cache_instance.initialized_ps_cache()) {
auto context_ptr = MsContext::GetInstance();
@ -2851,8 +2895,11 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) co
MS_EXCEPTION_IF_NULL(pk_node);
auto param_info_ptr = pk_node->param_info();
const std::string &param_name = pk_node->fullname_with_scope();
// If the Parameter is initialized on the server, and the user of the Parameter contains real CNode which executes
// in device, an error message will be reported, and it is allowed to be used only by the side effect operator.
if (param_info_ptr != nullptr && param_info_ptr->init_in_server() &&
!ps::ps_cache_instance.IsHashTable(param_name)) {
UseParamInitInServer(kernel_graph, input_node) && !ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(EXCEPTION) << "Can not initialize the parameter[" << param_name
<< "] in server, this parameter is used by kernel which executes in device";
}