forked from mindspore-Ecosystem/mindspore
!1416 fix summary nodes in child graph
Merge pull request !1416 from Margaret_wangrui/fix-summary-nodes-in-child-graph
This commit is contained in:
commit
20a0f6ef7c
|
@ -736,6 +736,27 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
|
||||||
return final_graph_id_;
|
return final_graph_id_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AscendSession::GetSummaryNodes(const KernelGraph *graph,
|
||||||
|
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
|
||||||
|
MS_LOG(DEBUG) << "Update summary Start";
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(summary);
|
||||||
|
summary->clear();
|
||||||
|
// if final graph have no child graph
|
||||||
|
auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
|
||||||
|
if (graph_order_iter == graph_execute_orders_.end()) {
|
||||||
|
SessionBasic::GetSummaryNodes(graph, summary);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// for every child graph, find summary nodes
|
||||||
|
auto graph_order = GetGraphOrder(graph->graph_id());
|
||||||
|
for (size_t i = 0; i < graph_order.size(); i++) {
|
||||||
|
auto child_graph = GetGraph(graph_order[i]);
|
||||||
|
SessionBasic::GetSummaryNodes(child_graph.get(), summary);
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
|
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
|
||||||
auto fake_graph = GetGraph(fake_graph_id);
|
auto fake_graph = GetGraph(fake_graph_id);
|
||||||
auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
|
auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
|
||||||
|
|
|
@ -67,6 +67,8 @@ class AscendSession : public SessionBasic {
|
||||||
void SetActive(GraphId, GraphId) override;
|
void SetActive(GraphId, GraphId) override;
|
||||||
// compile child graph when session have multiple child graphs
|
// compile child graph when session have multiple child graphs
|
||||||
void CompileChildGraph(const KernelGraphPtr &child_graph);
|
void CompileChildGraph(const KernelGraphPtr &child_graph);
|
||||||
|
void GetSummaryNodes(const KernelGraph *graph,
|
||||||
|
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitRuntimeResource();
|
void InitRuntimeResource();
|
||||||
|
|
|
@ -54,46 +54,6 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
|
||||||
return py_param.ptr();
|
return py_param.ptr();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
|
|
||||||
MS_LOG(DEBUG) << "Update summary Start";
|
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
|
||||||
MS_EXCEPTION_IF_NULL(summary);
|
|
||||||
summary->clear();
|
|
||||||
auto apply_list = TopoSort(graph->get_return());
|
|
||||||
for (auto &n : apply_list) {
|
|
||||||
MS_EXCEPTION_IF_NULL(n);
|
|
||||||
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
|
|
||||||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
|
|
||||||
auto cnode = n->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
if (cnode->inputs().size() <= kSummaryGetItem) {
|
|
||||||
MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
|
|
||||||
}
|
|
||||||
auto node = cnode->input(kSummaryGetItem);
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
|
||||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
|
|
||||||
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
|
|
||||||
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
|
|
||||||
}
|
|
||||||
(*summary)[n->fullname_with_scope()] = item_with_index;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool ExistSummaryNode(const KernelGraph *graph) {
|
|
||||||
auto ret = graph->get_return();
|
|
||||||
MS_EXCEPTION_IF_NULL(ret);
|
|
||||||
auto all_nodes = DeepLinkedGraphSearch(ret);
|
|
||||||
for (auto &n : all_nodes) {
|
|
||||||
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
|
|
||||||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
|
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors) {
|
const std::vector<tensor::TensorPtr> &input_tensors) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
@ -742,17 +702,44 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
|
||||||
(void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
|
(void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SessionBasic::GetSummaryNodes(const KernelGraph *graph,
|
||||||
|
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
|
||||||
|
MS_LOG(DEBUG) << "Update summary Start";
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(summary);
|
||||||
|
auto apply_list = TopoSort(graph->get_return());
|
||||||
|
for (auto &n : apply_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(n);
|
||||||
|
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
|
||||||
|
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
|
||||||
|
auto cnode = n->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
if (cnode->inputs().size() <= kSummaryGetItem) {
|
||||||
|
MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
|
||||||
|
}
|
||||||
|
auto node = cnode->input(kSummaryGetItem);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
|
||||||
|
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
|
||||||
|
}
|
||||||
|
(*summary)[n->fullname_with_scope()] = item_with_index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
|
||||||
|
}
|
||||||
|
|
||||||
void SessionBasic::Summary(KernelGraph *graph) {
|
void SessionBasic::Summary(KernelGraph *graph) {
|
||||||
if (summary_callback_ == nullptr) {
|
if (summary_callback_ == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
bool exist_summary = ExistSummaryNode(graph);
|
|
||||||
if (!exist_summary) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs;
|
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs;
|
||||||
GetSummaryNodes(graph, &summary_outputs);
|
GetSummaryNodes(graph, &summary_outputs);
|
||||||
|
// do not exist summary node
|
||||||
|
if (summary_outputs.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
std::map<std::string, tensor::TensorPtr> params_list;
|
std::map<std::string, tensor::TensorPtr> params_list;
|
||||||
// fetch outputs apply kernel in session & run callback functions
|
// fetch outputs apply kernel in session & run callback functions
|
||||||
for (auto &output_item : summary_outputs) {
|
for (auto &output_item : summary_outputs) {
|
||||||
|
|
|
@ -92,6 +92,8 @@ class SessionBasic {
|
||||||
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
|
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
|
||||||
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
|
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
|
||||||
virtual void SetActive(GraphId, GraphId) {}
|
virtual void SetActive(GraphId, GraphId) {}
|
||||||
|
virtual void GetSummaryNodes(const KernelGraph *graph,
|
||||||
|
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||||
|
|
Loading…
Reference in New Issue