Set trace info of Primitive CNode only for equiv out node, not the whole fprop function.

This commit is contained in:
Zhang Qinghua 2020-11-30 20:13:04 +08:00
parent 0a9899e3a1
commit 8310236ff1
5 changed files with 48 additions and 16 deletions

View File

@ -234,8 +234,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
AdjointPtr node_adjoint = nullptr; AdjointPtr node_adjoint = nullptr;
AnfNodePtr k = nullptr; AnfNodePtr k = nullptr;
if (IsValueNode<Primitive>(node)) { if (IsValueNode<Primitive>(node)) {
TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode_morph->debug_info())); k = MapToK(cnode_morph, i);
k = MapToK(node);
node_adjoint = std::make_shared<Adjoint>(node, k, tape_); node_adjoint = std::make_shared<Adjoint>(node, k, tape_);
anfnode_to_adjoin_[node] = node_adjoint; anfnode_to_adjoin_[node] = node_adjoint;
} else { } else {
@ -597,6 +596,31 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
return NewValueNode(functor->k_graph_); return NewValueNode(functor->k_graph_);
} }
// Construct representation graph for primitive CNode.
AnfNodePtr DFunctor::MapToK(const CNodePtr &primal_user, size_t index) {
auto primal = primal_user->input(index);
ScopeGuard scope_guard(primal->scope());
// Map primitive to K
if (IsValueNode<Primitive>(primal)) {
auto value_node = primal->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
need_cut_ = true;
}
auto k_prim = g_k_prims.KPrimitive(primal_user, value_node, resources_);
if (k_prim != nullptr) {
return NewValueNode(k_prim);
}
// When failed to find k_prim, try k_meta.
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
if (k_meta != nullptr) {
return NewValueNode(k_meta);
}
}
return MapToK(primal);
}
// Construct representation graph for given node. // Construct representation graph for given node.
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
ScopeGuard scope_guard(primal->scope()); ScopeGuard scope_guard(primal->scope());
@ -608,7 +632,7 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
need_cut_ = true; need_cut_ = true;
} }
auto k_prim = g_k_prims.KPrimitive(value_node, resources_); auto k_prim = g_k_prims.KPrimitive(nullptr, value_node, resources_);
if (k_prim != nullptr) { if (k_prim != nullptr) {
return NewValueNode(k_prim); return NewValueNode(k_prim);
} }

View File

@ -81,8 +81,10 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
// Map Anfnode object from D category to K category. // Map AnfNode object from D category to K category.
AnfNodePtr MapToK(const AnfNodePtr &primal); AnfNodePtr MapToK(const AnfNodePtr &primal);
// Map CNode object from D category to K category.
AnfNodePtr MapToK(const CNodePtr &primal_user, size_t index);
// Map FuncGraph object from D category to K category. // Map FuncGraph object from D category to K category.
AnfNodePtr MapToK(const FuncGraphPtr &primal); AnfNodePtr MapToK(const FuncGraphPtr &primal);
// MapObject impls. // MapObject impls.
@ -129,7 +131,8 @@ class KPrim {
KPrim() = default; KPrim() = default;
~KPrim() = default; ~KPrim() = default;
FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); FuncGraphPtr KPrimitive(const CNodePtr &primal_user, const ValueNodePtr &value_node,
const pipeline::ResourceBasePtr &resources);
MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim);
FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop); FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop);
@ -145,7 +148,7 @@ class KPrim {
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
// Given a bprop rule, do the K mapping. // Given a bprop rule, do the K mapping.
template <typename T> template <typename T>
FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g); FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const CNodePtr &cnode);
AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg);
void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
std::vector<AnfNodePtr> *const transf_args); std::vector<AnfNodePtr> *const transf_args);
@ -156,7 +159,7 @@ class KPrim {
}; };
template <typename T> template <typename T>
FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(primal); MS_EXCEPTION_IF_NULL(primal);
MS_EXCEPTION_IF_NULL(bprop_fg); MS_EXCEPTION_IF_NULL(bprop_fg);
CheckBprop(bprop_fg, primal->ToString()); CheckBprop(bprop_fg, primal->ToString());
@ -197,8 +200,13 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); TransformArgs(mng, cloned_bprop_fg, outer, &transf_args);
(void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); (void)transf_args.insert(transf_args.begin(), NewValueNode(primal));
auto out_value = outer->NewCNode(transf_args); CNodePtr out_value = nullptr;
if (cnode != nullptr) { // Set equiv debug info. for Primitive CNode out.
TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info()));
out_value = outer->NewCNode(transf_args);
} else {
out_value = outer->NewCNode(transf_args);
}
(void)mng->Replace(out_param, out_value); (void)mng->Replace(out_param, out_value);
TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info())); TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info()));
@ -207,7 +215,6 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
// We remove all parameters except new_dout. // We remove all parameters except new_dout.
std::vector<AnfNodePtr> newBpropParams = {new_dout}; std::vector<AnfNodePtr> newBpropParams = {new_dout};
cloned_bprop_fg->set_parameters(newBpropParams); cloned_bprop_fg->set_parameters(newBpropParams);
outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)})); outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)}));
return BasicClone(outer); return BasicClone(outer);
} }

View File

@ -64,7 +64,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
} }
FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
auto fg = g_k_prims.KPrimitive(value_node, resources); auto fg = g_k_prims.KPrimitive(nullptr, value_node, resources);
if (fg == nullptr) { if (fg == nullptr) {
return nullptr; return nullptr;
} }

View File

@ -102,7 +102,8 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
} }
FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
const pipeline::ResourceBasePtr &resources) {
if (!IsValueNode<Primitive>(value_node)) { if (!IsValueNode<Primitive>(value_node)) {
MS_LOG(EXCEPTION) << "Primitive node is not valid."; MS_LOG(EXCEPTION) << "Primitive node is not valid.";
} }
@ -141,7 +142,7 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
} }
} }
auto expanded_fg = BpropToK(prim, bprop_fg); auto expanded_fg = BpropToK(prim, bprop_fg, cnode);
if (expanded_fg == nullptr) { if (expanded_fg == nullptr) {
MS_LOG(EXCEPTION) << "Failed convert " << prim->name() MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
<< " prim bprop function to J expanded func graph. NodeInfo: " << " prim bprop function to J expanded func graph. NodeInfo: "
@ -220,7 +221,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check
FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
MS_EXCEPTION_IF_NULL(bprop_fg); MS_EXCEPTION_IF_NULL(bprop_fg);
auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph(); auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph();
auto expanded_fg = BpropToK(fprop_fg, bprop_fg); auto expanded_fg = BpropToK(fprop_fg, bprop_fg, nullptr);
if (expanded_fg == nullptr) { if (expanded_fg == nullptr) {
MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString() MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString()
<< " Cell bprop function to K expanded func graph. NodeInfo: " << " Cell bprop function to K expanded func graph. NodeInfo: "

View File

@ -33,11 +33,11 @@ class TestGradImplementations : public UT::Common {
}; };
TEST_F(TestGradImplementations, TestGetAugmentedGraph) { TEST_F(TestGradImplementations, TestGetAugmentedGraph) {
FuncGraphPtr fg = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr); FuncGraphPtr fg = ad::g_k_prims.KPrimitive(nullptr, NewValueNode(kPrimScalarMul), nullptr);
ASSERT_TRUE(fg != nullptr); ASSERT_TRUE(fg != nullptr);
draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg); draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg);
auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr); auto fg1 = ad::g_k_prims.KPrimitive(nullptr, NewValueNode(kPrimScalarMul), nullptr);
FuncGraphPairMapEquiv equiv_graph; FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node; NodeMapEquiv equiv_node;