From bcfa1f72b1a4408e2561605819cbe41ef221b5da Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Mon, 16 Nov 2020 11:37:52 +0800 Subject: [PATCH] Add debug and trace info for grad operation. --- .../ccsrc/frontend/optimizer/ad/dfunctor.cc | 35 ++++++++++++------- .../ccsrc/frontend/optimizer/ad/dfunctor.h | 2 -- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 8b16556ef97..7356e1e4043 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -227,17 +227,25 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { std::vector param_adjoints; for (size_t i = 0; i < cnode_morph->size(); i++) { auto node = cnode_morph->input(i); - auto node_adjoint_iter = anfnode_to_adjoin_.find(node); AdjointPtr node_adjoint = nullptr; AnfNodePtr k = nullptr; - if (node_adjoint_iter != anfnode_to_adjoin_.end()) { - node_adjoint = node_adjoint_iter->second; + if (IsValueNode(node)) { + TraceManager::DebugTrace(std::make_shared(cnode_morph->debug_info())); + k = MapToK(node); + TraceManager::EndTrace(); + node_adjoint = std::make_shared(node, k, tape_); + anfnode_to_adjoin_[node] = node_adjoint; } else { - // Input might be a CNode that needs to be handled before hand. - node_adjoint = MapMorphism(node); + auto node_adjoint_iter = anfnode_to_adjoin_.find(node); + if (node_adjoint_iter != anfnode_to_adjoin_.end()) { + node_adjoint = node_adjoint_iter->second; + } else { + // Input might be a CNode that needs to be handled previously. + node_adjoint = MapMorphism(node); + } + MS_EXCEPTION_IF_NULL(node_adjoint); + k = node_adjoint->k(); } - MS_EXCEPTION_IF_NULL(node_adjoint); - k = node_adjoint->k(); if (k == nullptr) { MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << "."; } @@ -270,6 +278,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << "."; return node_adjoint; } + void TensorSetAddress(const ValuePtr &value, std::map *tuple_tensors) { MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa(); if (value->isa()) { @@ -560,7 +569,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { return nullptr; } -// MapToK(func) +// Map func graph to K AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { auto f = func_graph_to_functor_.find(primal); if (f != func_graph_to_functor_.end()) { @@ -586,7 +595,7 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { // Construct representation graph for given node. AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { ScopeGuard scope_guard(primal->scope()); - // MapToK(prim) + // Map primitive to K if (IsValueNode(primal)) { auto value_node = primal->cast(); auto prim = GetValueNode(value_node); @@ -605,7 +614,7 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { } } - // MapToK(func) + // Map func graph to K if (IsValueNode(primal)) { auto func_graph = GetValueNode(primal); auto k_func = MapToK(func_graph); @@ -681,7 +690,7 @@ void DFunctor::MapValueObject() { anfnode_to_adjoin_[node] = adjoint; continue; } - // Skip Return. + // Skip Primitive. if (IsValueNode(node) && GetValueNode(node) == prim::kPrimReturn) { continue; } @@ -796,12 +805,14 @@ void DFunctor::EliminatePrimalGraph() { auto index = it.first->second; auto vnode = cnode->inputs()[index]; if (index != 0) { - MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}"; + MS_LOG(DEBUG) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}"; continue; } cnode->set_input(0, k_vnode); // Replace primal graph with k graph auto construct_wrapper = cnode->func_graph(); + TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0}); + TraceManager::EndTrace(); manager->Replace(cnode, getitem0); } } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index 100da3a29c8..f24459edfb7 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -194,10 +194,8 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { std::vector transf_args; TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); - TraceManager::DebugTrace(std::make_shared(dout->debug_info())); (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); auto out_value = outer->NewCNode(transf_args); - TraceManager::EndTrace(); (void)mng->Replace(out_param, out_value);