!23592 fix GetAllVisitedCNode bug

Merge pull request !23592 from baihuawei/fixbugs
This commit is contained in:
i-robot 2021-09-17 01:14:02 +00:00 committed by Gitee
commit 3d8dade612
4 changed files with 19 additions and 9 deletions

View File

@ -2123,9 +2123,16 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A
return device_shape;
}
void AnfRuntimeAlgorithm::GetAllVisitedCNode(const CNodePtr &anf_node, std::vector<AnfNodePtr> *used_kernels) {
void AnfRuntimeAlgorithm::GetAllVisitedCNode(const CNodePtr &anf_node, std::vector<AnfNodePtr> *used_kernels,
std::set<AnfNodePtr> *visited) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(used_kernels);
MS_EXCEPTION_IF_NULL(visited);
if (visited->find(anf_node) != visited->end()) {
MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
return;
}
visited->insert(anf_node);
auto input_size = anf_node->inputs().size() - 1;
for (size_t i = 0; i < input_size; ++i) {
auto input = AnfAlgo::GetInputNode(anf_node, i);
@ -2134,7 +2141,7 @@ void AnfRuntimeAlgorithm::GetAllVisitedCNode(const CNodePtr &anf_node, std::vect
}
auto input_cnode = input->cast<CNodePtr>();
if (!IsRealKernelCNode(input_cnode) || opt::IsNopNode(input_cnode)) {
GetAllVisitedCNode(input_cnode, used_kernels);
GetAllVisitedCNode(input_cnode, used_kernels, visited);
} else {
used_kernels->push_back(input);
}

View File

@ -299,7 +299,8 @@ class AnfRuntimeAlgorithm {
// Find real input nodes.
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited);
static void GetAllVisitedCNode(const CNodePtr &cnode, std::vector<AnfNodePtr> *used_kernels);
static void GetAllVisitedCNode(const CNodePtr &cnode, std::vector<AnfNodePtr> *used_kernels,
std::set<AnfNodePtr> *visited);
static void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph);
static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
// Get node real inputs, skip `MakeTuple`, `TupleGetItem`, `Depend`, `Load`, `UpdateState` etc.

View File

@ -807,8 +807,9 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph *graph) {
auto wait_stream = stream_id_map_[curr_stream_id];
auto stream_num = stream_id_map_.size();
std::vector<bool> stream_hit(stream_num, false);
std::vector<AnfNodePtr> visited_kernels;
AnfAlgo::GetAllVisitedCNode(kernel, &visited_kernels);
std::vector<AnfNodePtr> used_kernels;
std::set<AnfNodePtr> visited_kernels;
AnfAlgo::GetAllVisitedCNode(kernel, &used_kernels, &visited_kernels);
bool found_depend = false;
for (int k = SizeToInt(i) - 1; k >= 0; --k) {
auto pre_cnode = kernels[k];
@ -817,7 +818,7 @@ void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph *graph) {
found_depend = true;
continue;
}
for (auto &visited : visited_kernels) {
for (auto &visited : used_kernels) {
if (visited == pre_cnode && !stream_hit[pre_cnode_stream_id]) {
stream_hit[pre_cnode_stream_id] = true;
found_depend = true;

View File

@ -213,9 +213,10 @@ void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<
}
operator_info << ") ";
}
operator_info << "is not support. This error means the current type is not supported, please refer to the MindSpore "
"doc for supported types";
MS_EXCEPTION(TypeError) << operator_info.str() << " Trace: " << trace::DumpSourceLines(kernel_node);
operator_info
<< "is not support. This error means the current input type is not supported, please refer to the MindSpore "
"doc for supported types.\n";
MS_EXCEPTION(TypeError) << operator_info.str() << "Trace: " << trace::DumpSourceLines(kernel_node);
}
void UpdateDynamicKernelBuildInfoAndAttrs(const CNodePtr &kernel_node) {