!16708 Handle GraphKernel node in PSEmbeddingCache

From: @dayschan
Reviewed-by: @limingqi107,@gaoxiong1
Signed-off-by: @gaoxiong1
This commit is contained in:
mindspore-ci-bot 2021-05-24 09:04:22 +08:00 committed by Gitee
commit 41eaf3f58d
3 changed files with 31 additions and 9 deletions

View File

@ -1160,6 +1160,18 @@ bool AnfRuntimeAlgorithm::IsNodeInGraphKernel(const AnfNodePtr &node) {
return node->func_graph() != nullptr && node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}
AnfNodePtr AnfRuntimeAlgorithm::GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index) {
auto func_graph = GetCNodeFuncGraph(kernel_with_index.first);
if (func_graph == nullptr) {
return kernel_with_index.first;
}
auto output = func_graph->output();
if (CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
return output->cast<CNodePtr>()->input(kernel_with_index.second + 1);
}
return output;
}
bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
MS_EXCEPTION_IF_NULL(node);
return node->has_default();

View File

@ -214,6 +214,8 @@ class AnfRuntimeAlgorithm {
static bool IsGraphKernel(const AnfNodePtr &node);
// checkout whether the anf node is an inner node of graph kernel.
static bool IsNodeInGraphKernel(const AnfNodePtr &node);
// get the real output of GraphKernel.
static AnfNodePtr GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index);
// check parameter is weight or data
static bool IsParameterWeight(const ParameterPtr &node);
// checkout whether the anf node is include the label_index.

View File

@ -1114,21 +1114,27 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
auto input_index_node_name = AnfAlgo::GetCNodeName(input_index.first);
if (input_index.first->isa<CNode>() && (input_index_node_name != kGetNextOpName)) {
auto cnode =
AnfAlgo::IsGraphKernel(input_index.first) ? AnfAlgo::GetOutputOfGraphkernel(input_index) : input_index.first;
if (!cnode->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index should be a CNode but got "
<< cnode->fullname_with_scope();
}
auto input_index_node_name = AnfAlgo::GetCNodeName(cnode);
if (input_index_node_name != kGetNextOpName) {
bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
if ((!full_batch && (input_index_node_name != kUniqueOpName)) ||
(full_batch && (input_index_node_name != kMinimumOpName))) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
<< ") cache is from " << input_index.first->fullname_with_scope();
<< ") cache is from " << cnode->fullname_with_scope();
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
"parameter server training mode.";
}
}
*first_cache_input_index = input_index.first;
*first_cache_input_index = cnode;
*first_cache_size = size;
MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from "
<< input_index.first->fullname_with_scope() << ", the cache size is " << size;
MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from " << cnode->fullname_with_scope()
<< ", the cache size is " << size;
return;
}
}
@ -1182,7 +1188,9 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
if (input_index.first == first_cache_input_index) {
auto cnode =
AnfAlgo::IsGraphKernel(input_index.first) ? AnfAlgo::GetOutputOfGraphkernel(input_index) : input_index.first;
if (cnode == first_cache_input_index) {
if (!ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the "
@ -1196,10 +1204,10 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
}
} else if (ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
<< input_index.first->fullname_with_scope();
<< cnode->fullname_with_scope();
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
"parameter server training mode.";
} else if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kGetNextOpName)) {
} else if (cnode->isa<CNode>() && (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName)) {
MS_LOG(ERROR) << "The EmbeddingLookup kernel(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
MS_LOG(EXCEPTION) << "All EmbeddingLookup kernels whose input indices are from dataset must enable cache at "
"the same time and parameter 'sparse' must be equal to the value of 'enable_sparse' in "