forked from mindspore-Ecosystem/mindspore
!16708 Handle GraphKernel node in PSEmbeddingCache
From: @dayschan Reviewed-by: @limingqi107,@gaoxiong1 Signed-off-by: @gaoxiong1
This commit is contained in:
commit
41eaf3f58d
|
@ -1160,6 +1160,18 @@ bool AnfRuntimeAlgorithm::IsNodeInGraphKernel(const AnfNodePtr &node) {
|
|||
return node->func_graph() != nullptr && node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
}
|
||||
|
||||
// Resolves the node that actually produces one output of a GraphKernel node.
//
// kernel_with_index.first  : the node to inspect (a GraphKernel node at the visible call sites).
// kernel_with_index.second : index of the requested output.
//
// Returns kernel_with_index.first unchanged when GetCNodeFuncGraph() yields no sub-graph for it.
// Otherwise returns the sub-graph's output node, or, when that output is a MakeTuple, the tuple
// element selected by the output index.
//
// Raises through MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION) when the sub-graph has no output node or
// when the output index does not address an element of the MakeTuple, instead of dereferencing a
// null pointer or reading past the end of the input list.
AnfNodePtr AnfRuntimeAlgorithm::GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index) {
  auto func_graph = GetCNodeFuncGraph(kernel_with_index.first);
  if (func_graph == nullptr) {
    return kernel_with_index.first;
  }
  auto output = func_graph->output();
  // Callers dereference the returned node immediately, so fail here with a clear message.
  MS_EXCEPTION_IF_NULL(output);
  if (!CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
    return output;
  }
  auto make_tuple = output->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(make_tuple);
  // Tuple elements are addressed at index + 1.
  // NOTE(review): presumably input 0 holds the MakeTuple primitive itself -- confirm.
  const auto input_pos = kernel_with_index.second + 1;
  if (input_pos >= make_tuple->inputs().size()) {
    MS_LOG(EXCEPTION) << "Output index " << kernel_with_index.second << " is out of range for the MakeTuple output "
                      << make_tuple->fullname_with_scope() << " which has " << make_tuple->inputs().size()
                      << " inputs.";
  }
  return make_tuple->input(input_pos);
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
return node->has_default();
|
||||
|
|
|
@ -214,6 +214,8 @@ class AnfRuntimeAlgorithm {
|
|||
static bool IsGraphKernel(const AnfNodePtr &node);
|
||||
// check whether the anf node is an inner node of graph kernel.
|
||||
static bool IsNodeInGraphKernel(const AnfNodePtr &node);
|
||||
// get the real output of GraphKernel.
|
||||
static AnfNodePtr GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index);
|
||||
// check parameter is weight or data
|
||||
static bool IsParameterWeight(const ParameterPtr &node);
|
||||
// check whether the anf node includes the label_index.
|
||||
|
|
|
@ -1114,21 +1114,27 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
|
|||
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
|
||||
MS_EXCEPTION_IF_NULL(input_index.first);
|
||||
}
|
||||
auto input_index_node_name = AnfAlgo::GetCNodeName(input_index.first);
|
||||
if (input_index.first->isa<CNode>() && (input_index_node_name != kGetNextOpName)) {
|
||||
auto cnode =
|
||||
AnfAlgo::IsGraphKernel(input_index.first) ? AnfAlgo::GetOutputOfGraphkernel(input_index) : input_index.first;
|
||||
if (!cnode->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index should be a CNode but got "
|
||||
<< cnode->fullname_with_scope();
|
||||
}
|
||||
auto input_index_node_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (input_index_node_name != kGetNextOpName) {
|
||||
bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
|
||||
if ((!full_batch && (input_index_node_name != kUniqueOpName)) ||
|
||||
(full_batch && (input_index_node_name != kMinimumOpName))) {
|
||||
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
|
||||
<< ") cache is from " << input_index.first->fullname_with_scope();
|
||||
<< ") cache is from " << cnode->fullname_with_scope();
|
||||
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
|
||||
"parameter server training mode.";
|
||||
}
|
||||
}
|
||||
*first_cache_input_index = input_index.first;
|
||||
*first_cache_input_index = cnode;
|
||||
*first_cache_size = size;
|
||||
MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from "
|
||||
<< input_index.first->fullname_with_scope() << ", the cache size is " << size;
|
||||
MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from " << cnode->fullname_with_scope()
|
||||
<< ", the cache size is " << size;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -1182,7 +1188,9 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
|
|||
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
|
||||
MS_EXCEPTION_IF_NULL(input_index.first);
|
||||
}
|
||||
if (input_index.first == first_cache_input_index) {
|
||||
auto cnode =
|
||||
AnfAlgo::IsGraphKernel(input_index.first) ? AnfAlgo::GetOutputOfGraphkernel(input_index) : input_index.first;
|
||||
if (cnode == first_cache_input_index) {
|
||||
if (!ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||
MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
|
||||
MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the "
|
||||
|
@ -1196,10 +1204,10 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
|
|||
}
|
||||
} else if (ps::ps_cache_instance.IsHashTable(param_name)) {
|
||||
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
|
||||
<< input_index.first->fullname_with_scope();
|
||||
<< cnode->fullname_with_scope();
|
||||
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
|
||||
"parameter server training mode.";
|
||||
} else if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kGetNextOpName)) {
|
||||
} else if (cnode->isa<CNode>() && (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName)) {
|
||||
MS_LOG(ERROR) << "The EmbeddingLookup kernel(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
|
||||
MS_LOG(EXCEPTION) << "All EmbeddingLookup kernels whose input indices are from dataset must enable cache at "
|
||||
"the same time and parameter 'sparse' must be equal to the value of 'enable_sparse' in "
|
||||
|
|
Loading…
Reference in New Issue