!8642 Keep debug info. and trace info. after Grad Operation.

From: @zh_qh
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-16 20:29:13 +08:00 committed by Gitee
commit 610f06b92d
2 changed files with 23 additions and 14 deletions

View File

@ -227,17 +227,25 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
std::vector<AdjointPtr> param_adjoints;
for (size_t i = 0; i < cnode_morph->size(); i++) {
auto node = cnode_morph->input(i);
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
AdjointPtr node_adjoint = nullptr;
AnfNodePtr k = nullptr;
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
node_adjoint = node_adjoint_iter->second;
if (IsValueNode<Primitive>(node)) {
TraceManager::DebugTrace(std::make_shared<TraceEquiv>(cnode_morph->debug_info()));
k = MapToK(node);
TraceManager::EndTrace();
node_adjoint = std::make_shared<Adjoint>(node, k, tape_);
anfnode_to_adjoin_[node] = node_adjoint;
} else {
// Input might be a CNode that needs to be handled before hand.
node_adjoint = MapMorphism(node);
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
node_adjoint = node_adjoint_iter->second;
} else {
// Input might be a CNode that needs to be handled previously.
node_adjoint = MapMorphism(node);
}
MS_EXCEPTION_IF_NULL(node_adjoint);
k = node_adjoint->k();
}
MS_EXCEPTION_IF_NULL(node_adjoint);
k = node_adjoint->k();
if (k == nullptr) {
MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
}
@ -270,6 +278,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << ".";
return node_adjoint;
}
void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) {
MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>();
if (value->isa<tensor::Tensor>()) {
@ -560,7 +569,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
return nullptr;
}
// MapToK(func)
// Map func graph to K
AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
auto f = func_graph_to_functor_.find(primal);
if (f != func_graph_to_functor_.end()) {
@ -586,7 +595,7 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
// Construct representation graph for given node.
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
ScopeGuard scope_guard(primal->scope());
// MapToK(prim)
// Map primitive to K
if (IsValueNode<Primitive>(primal)) {
auto value_node = primal->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node);
@ -605,7 +614,7 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
}
}
// MapToK(func)
// Map func graph to K
if (IsValueNode<FuncGraph>(primal)) {
auto func_graph = GetValueNode<FuncGraphPtr>(primal);
auto k_func = MapToK(func_graph);
@ -681,7 +690,7 @@ void DFunctor::MapValueObject() {
anfnode_to_adjoin_[node] = adjoint;
continue;
}
// Skip Return.
// Skip Primitive.
if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
continue;
}
@ -796,12 +805,14 @@ void DFunctor::EliminatePrimalGraph() {
auto index = it.first->second;
auto vnode = cnode->inputs()[index];
if (index != 0) {
MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}";
MS_LOG(DEBUG) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}";
continue;
}
cnode->set_input(0, k_vnode); // Replace primal graph with k graph
auto construct_wrapper = cnode->func_graph();
TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode->debug_info()));
auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0});
TraceManager::EndTrace();
manager->Replace(cnode, getitem0);
}
}

View File

@ -194,10 +194,8 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
std::vector<AnfNodePtr> transf_args;
TransformArgs(mng, cloned_bprop_fg, outer, &transf_args);
TraceManager::DebugTrace(std::make_shared<TraceEquiv>(dout->debug_info()));
(void)transf_args.insert(transf_args.begin(), NewValueNode(primal));
auto out_value = outer->NewCNode(transf_args);
TraceManager::EndTrace();
(void)mng->Replace(out_param, out_value);