forked from mindspore-Ecosystem/mindspore
!1416 fix summary nodes in child graph
Merge pull request !1416 from Margaret_wangrui/fix-summary-nodes-in-child-graph
This commit is contained in:
commit
20a0f6ef7c
|
@ -736,6 +736,27 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
|
|||
return final_graph_id_;
|
||||
}
|
||||
|
||||
void AscendSession::GetSummaryNodes(const KernelGraph *graph,
|
||||
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
|
||||
MS_LOG(DEBUG) << "Update summary Start";
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(summary);
|
||||
summary->clear();
|
||||
// if final graph have no child graph
|
||||
auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
|
||||
if (graph_order_iter == graph_execute_orders_.end()) {
|
||||
SessionBasic::GetSummaryNodes(graph, summary);
|
||||
return;
|
||||
}
|
||||
// for every child graph, find summary nodes
|
||||
auto graph_order = GetGraphOrder(graph->graph_id());
|
||||
for (size_t i = 0; i < graph_order.size(); i++) {
|
||||
auto child_graph = GetGraph(graph_order[i]);
|
||||
SessionBasic::GetSummaryNodes(child_graph.get(), summary);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
|
||||
}
|
||||
|
||||
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
|
||||
auto fake_graph = GetGraph(fake_graph_id);
|
||||
auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
|
||||
|
|
|
@ -67,6 +67,8 @@ class AscendSession : public SessionBasic {
|
|||
void SetActive(GraphId, GraphId) override;
|
||||
// compile child graph when session have multiple child graphs
|
||||
void CompileChildGraph(const KernelGraphPtr &child_graph);
|
||||
void GetSummaryNodes(const KernelGraph *graph,
|
||||
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) override;
|
||||
|
||||
private:
|
||||
void InitRuntimeResource();
|
||||
|
|
|
@ -54,46 +54,6 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
|
|||
return py_param.ptr();
|
||||
}
|
||||
|
||||
void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
|
||||
MS_LOG(DEBUG) << "Update summary Start";
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(summary);
|
||||
summary->clear();
|
||||
auto apply_list = TopoSort(graph->get_return());
|
||||
for (auto &n : apply_list) {
|
||||
MS_EXCEPTION_IF_NULL(n);
|
||||
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
|
||||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
|
||||
auto cnode = n->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() <= kSummaryGetItem) {
|
||||
MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
|
||||
}
|
||||
auto node = cnode->input(kSummaryGetItem);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
|
||||
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
|
||||
}
|
||||
(*summary)[n->fullname_with_scope()] = item_with_index;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
|
||||
}
|
||||
|
||||
bool ExistSummaryNode(const KernelGraph *graph) {
|
||||
auto ret = graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto all_nodes = DeepLinkedGraphSearch(ret);
|
||||
for (auto &n : all_nodes) {
|
||||
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
|
||||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -742,17 +702,44 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
|
|||
(void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
|
||||
}
|
||||
|
||||
void SessionBasic::GetSummaryNodes(const KernelGraph *graph,
|
||||
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
|
||||
MS_LOG(DEBUG) << "Update summary Start";
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(summary);
|
||||
auto apply_list = TopoSort(graph->get_return());
|
||||
for (auto &n : apply_list) {
|
||||
MS_EXCEPTION_IF_NULL(n);
|
||||
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
|
||||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
|
||||
auto cnode = n->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() <= kSummaryGetItem) {
|
||||
MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
|
||||
}
|
||||
auto node = cnode->input(kSummaryGetItem);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
|
||||
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
|
||||
}
|
||||
(*summary)[n->fullname_with_scope()] = item_with_index;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
|
||||
}
|
||||
|
||||
void SessionBasic::Summary(KernelGraph *graph) {
|
||||
if (summary_callback_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
bool exist_summary = ExistSummaryNode(graph);
|
||||
if (!exist_summary) {
|
||||
return;
|
||||
}
|
||||
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs;
|
||||
GetSummaryNodes(graph, &summary_outputs);
|
||||
// do not exist summary node
|
||||
if (summary_outputs.empty()) {
|
||||
return;
|
||||
}
|
||||
std::map<std::string, tensor::TensorPtr> params_list;
|
||||
// fetch outputs apply kernel in session & run callback functions
|
||||
for (auto &output_item : summary_outputs) {
|
||||
|
|
|
@ -92,6 +92,8 @@ class SessionBasic {
|
|||
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
|
||||
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
|
||||
virtual void SetActive(GraphId, GraphId) {}
|
||||
virtual void GetSummaryNodes(const KernelGraph *graph,
|
||||
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary);
|
||||
|
||||
protected:
|
||||
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
|
|
Loading…
Reference in New Issue