!1416 fix summary nodes in child graph

Merge pull request !1416 from Margaret_wangrui/fix-summary-nodes-in-child-graph
This commit is contained in:
mindspore-ci-bot 2020-05-27 14:27:56 +08:00 committed by Gitee
commit 20a0f6ef7c
4 changed files with 56 additions and 44 deletions

View File

@ -736,6 +736,27 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
return final_graph_id_; return final_graph_id_;
} }
void AscendSession::GetSummaryNodes(const KernelGraph *graph,
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(summary);
summary->clear();
// if final graph have no child graph
auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
if (graph_order_iter == graph_execute_orders_.end()) {
SessionBasic::GetSummaryNodes(graph, summary);
return;
}
// for every child graph, find summary nodes
auto graph_order = GetGraphOrder(graph->graph_id());
for (size_t i = 0; i < graph_order.size(); i++) {
auto child_graph = GetGraph(graph_order[i]);
SessionBasic::GetSummaryNodes(child_graph.get(), summary);
}
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
}
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
auto fake_graph = GetGraph(fake_graph_id); auto fake_graph = GetGraph(fake_graph_id);
auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0); auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);

View File

@ -67,6 +67,8 @@ class AscendSession : public SessionBasic {
void SetActive(GraphId, GraphId) override; void SetActive(GraphId, GraphId) override;
// compile child graph when session have multiple child graphs // compile child graph when session have multiple child graphs
void CompileChildGraph(const KernelGraphPtr &child_graph); void CompileChildGraph(const KernelGraphPtr &child_graph);
void GetSummaryNodes(const KernelGraph *graph,
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) override;
private: private:
void InitRuntimeResource(); void InitRuntimeResource();

View File

@ -54,46 +54,6 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
return py_param.ptr(); return py_param.ptr();
} }
void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(summary);
summary->clear();
auto apply_list = TopoSort(graph->get_return());
for (auto &n : apply_list) {
MS_EXCEPTION_IF_NULL(n);
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
auto cnode = n->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() <= kSummaryGetItem) {
MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
}
auto node = cnode->input(kSummaryGetItem);
MS_EXCEPTION_IF_NULL(node);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
}
(*summary)[n->fullname_with_scope()] = item_with_index;
}
}
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
}
bool ExistSummaryNode(const KernelGraph *graph) {
auto ret = graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto all_nodes = DeepLinkedGraphSearch(ret);
for (auto &n : all_nodes) {
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
return true;
}
}
return false;
}
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph, BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) { const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -742,17 +702,44 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
(void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
} }
void SessionBasic::GetSummaryNodes(const KernelGraph *graph,
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(summary);
auto apply_list = TopoSort(graph->get_return());
for (auto &n : apply_list) {
MS_EXCEPTION_IF_NULL(n);
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
auto cnode = n->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() <= kSummaryGetItem) {
MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
}
auto node = cnode->input(kSummaryGetItem);
MS_EXCEPTION_IF_NULL(node);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
}
(*summary)[n->fullname_with_scope()] = item_with_index;
}
}
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
}
void SessionBasic::Summary(KernelGraph *graph) { void SessionBasic::Summary(KernelGraph *graph) {
if (summary_callback_ == nullptr) { if (summary_callback_ == nullptr) {
return; return;
} }
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
bool exist_summary = ExistSummaryNode(graph);
if (!exist_summary) {
return;
}
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs; std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs;
GetSummaryNodes(graph, &summary_outputs); GetSummaryNodes(graph, &summary_outputs);
// do not exist summary node
if (summary_outputs.empty()) {
return;
}
std::map<std::string, tensor::TensorPtr> params_list; std::map<std::string, tensor::TensorPtr> params_list;
// fetch outputs apply kernel in session & run callback functions // fetch outputs apply kernel in session & run callback functions
for (auto &output_item : summary_outputs) { for (auto &output_item : summary_outputs) {

View File

@ -92,6 +92,8 @@ class SessionBasic {
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
virtual void SetActive(GraphId, GraphId) {} virtual void SetActive(GraphId, GraphId) {}
virtual void GetSummaryNodes(const KernelGraph *graph,
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary);
protected: protected:
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,