forked from mindspore-Ecosystem/mindspore
!8642 Keep debug info. and trace info. after Grad Operation.
From: @zh_qh Reviewed-by: Signed-off-by:
This commit is contained in:
commit
610f06b92d
|
@ -227,17 +227,25 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
|||
std::vector<AdjointPtr> param_adjoints;
|
||||
for (size_t i = 0; i < cnode_morph->size(); i++) {
|
||||
auto node = cnode_morph->input(i);
|
||||
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
|
||||
AdjointPtr node_adjoint = nullptr;
|
||||
AnfNodePtr k = nullptr;
|
||||
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
|
||||
node_adjoint = node_adjoint_iter->second;
|
||||
if (IsValueNode<Primitive>(node)) {
|
||||
TraceManager::DebugTrace(std::make_shared<TraceEquiv>(cnode_morph->debug_info()));
|
||||
k = MapToK(node);
|
||||
TraceManager::EndTrace();
|
||||
node_adjoint = std::make_shared<Adjoint>(node, k, tape_);
|
||||
anfnode_to_adjoin_[node] = node_adjoint;
|
||||
} else {
|
||||
// Input might be a CNode that needs to be handled before hand.
|
||||
node_adjoint = MapMorphism(node);
|
||||
auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
|
||||
if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
|
||||
node_adjoint = node_adjoint_iter->second;
|
||||
} else {
|
||||
// Input might be a CNode that needs to be handled previously.
|
||||
node_adjoint = MapMorphism(node);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(node_adjoint);
|
||||
k = node_adjoint->k();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(node_adjoint);
|
||||
k = node_adjoint->k();
|
||||
if (k == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
|
||||
}
|
||||
|
@ -270,6 +278,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
|||
MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << ".";
|
||||
return node_adjoint;
|
||||
}
|
||||
|
||||
void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) {
|
||||
MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>();
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
|
@ -560,7 +569,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// MapToK(func)
|
||||
// Map func graph to K
|
||||
AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
|
||||
auto f = func_graph_to_functor_.find(primal);
|
||||
if (f != func_graph_to_functor_.end()) {
|
||||
|
@ -586,7 +595,7 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
|
|||
// Construct representation graph for given node.
|
||||
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
|
||||
ScopeGuard scope_guard(primal->scope());
|
||||
// MapToK(prim)
|
||||
// Map primitive to K
|
||||
if (IsValueNode<Primitive>(primal)) {
|
||||
auto value_node = primal->cast<ValueNodePtr>();
|
||||
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
|
@ -605,7 +614,7 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
|
|||
}
|
||||
}
|
||||
|
||||
// MapToK(func)
|
||||
// Map func graph to K
|
||||
if (IsValueNode<FuncGraph>(primal)) {
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(primal);
|
||||
auto k_func = MapToK(func_graph);
|
||||
|
@ -681,7 +690,7 @@ void DFunctor::MapValueObject() {
|
|||
anfnode_to_adjoin_[node] = adjoint;
|
||||
continue;
|
||||
}
|
||||
// Skip Return.
|
||||
// Skip Primitive.
|
||||
if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
|
||||
continue;
|
||||
}
|
||||
|
@ -796,12 +805,14 @@ void DFunctor::EliminatePrimalGraph() {
|
|||
auto index = it.first->second;
|
||||
auto vnode = cnode->inputs()[index];
|
||||
if (index != 0) {
|
||||
MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}";
|
||||
MS_LOG(DEBUG) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}";
|
||||
continue;
|
||||
}
|
||||
cnode->set_input(0, k_vnode); // Replace primal graph with k graph
|
||||
auto construct_wrapper = cnode->func_graph();
|
||||
TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode->debug_info()));
|
||||
auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0});
|
||||
TraceManager::EndTrace();
|
||||
manager->Replace(cnode, getitem0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -194,10 +194,8 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
|
|||
std::vector<AnfNodePtr> transf_args;
|
||||
TransformArgs(mng, cloned_bprop_fg, outer, &transf_args);
|
||||
|
||||
TraceManager::DebugTrace(std::make_shared<TraceEquiv>(dout->debug_info()));
|
||||
(void)transf_args.insert(transf_args.begin(), NewValueNode(primal));
|
||||
auto out_value = outer->NewCNode(transf_args);
|
||||
TraceManager::EndTrace();
|
||||
|
||||
(void)mng->Replace(out_param, out_value);
|
||||
|
||||
|
|
Loading…
Reference in New Issue