forked from mindspore-Ecosystem/mindspore
Set trace info of Primitive CNode only for equiv out node, not the whole fprop function.
This commit is contained in:
parent
0a9899e3a1
commit
8310236ff1
|
@ -234,8 +234,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
||||||
AdjointPtr node_adjoint = nullptr;
|
AdjointPtr node_adjoint = nullptr;
|
||||||
AnfNodePtr k = nullptr;
|
AnfNodePtr k = nullptr;
|
||||||
if (IsValueNode<Primitive>(node)) {
|
if (IsValueNode<Primitive>(node)) {
|
||||||
TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode_morph->debug_info()));
|
k = MapToK(cnode_morph, i);
|
||||||
k = MapToK(node);
|
|
||||||
node_adjoint = std::make_shared<Adjoint>(node, k, tape_);
|
node_adjoint = std::make_shared<Adjoint>(node, k, tape_);
|
||||||
anfnode_to_adjoin_[node] = node_adjoint;
|
anfnode_to_adjoin_[node] = node_adjoint;
|
||||||
} else {
|
} else {
|
||||||
|
@ -597,6 +596,31 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
|
||||||
return NewValueNode(functor->k_graph_);
|
return NewValueNode(functor->k_graph_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Construct representation graph for primitive CNode.
|
||||||
|
AnfNodePtr DFunctor::MapToK(const CNodePtr &primal_user, size_t index) {
|
||||||
|
auto primal = primal_user->input(index);
|
||||||
|
ScopeGuard scope_guard(primal->scope());
|
||||||
|
// Map primitive to K
|
||||||
|
if (IsValueNode<Primitive>(primal)) {
|
||||||
|
auto value_node = primal->cast<ValueNodePtr>();
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||||
|
if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
|
||||||
|
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
|
||||||
|
need_cut_ = true;
|
||||||
|
}
|
||||||
|
auto k_prim = g_k_prims.KPrimitive(primal_user, value_node, resources_);
|
||||||
|
if (k_prim != nullptr) {
|
||||||
|
return NewValueNode(k_prim);
|
||||||
|
}
|
||||||
|
// When failed to find k_prim, try k_meta.
|
||||||
|
auto k_meta = g_k_prims.KMetaFuncGraph(prim);
|
||||||
|
if (k_meta != nullptr) {
|
||||||
|
return NewValueNode(k_meta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return MapToK(primal);
|
||||||
|
}
|
||||||
|
|
||||||
// Construct representation graph for given node.
|
// Construct representation graph for given node.
|
||||||
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
|
AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
|
||||||
ScopeGuard scope_guard(primal->scope());
|
ScopeGuard scope_guard(primal->scope());
|
||||||
|
@ -608,7 +632,7 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
|
||||||
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
|
MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
|
||||||
need_cut_ = true;
|
need_cut_ = true;
|
||||||
}
|
}
|
||||||
auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
|
auto k_prim = g_k_prims.KPrimitive(nullptr, value_node, resources_);
|
||||||
if (k_prim != nullptr) {
|
if (k_prim != nullptr) {
|
||||||
return NewValueNode(k_prim);
|
return NewValueNode(k_prim);
|
||||||
}
|
}
|
||||||
|
|
|
@ -81,8 +81,10 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
|
||||||
void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
|
void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
|
||||||
AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
|
AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
|
||||||
AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
|
AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
|
||||||
// Map Anfnode object from D category to K category.
|
// Map AnfNode object from D category to K category.
|
||||||
AnfNodePtr MapToK(const AnfNodePtr &primal);
|
AnfNodePtr MapToK(const AnfNodePtr &primal);
|
||||||
|
// Map CNode object from D category to K category.
|
||||||
|
AnfNodePtr MapToK(const CNodePtr &primal_user, size_t index);
|
||||||
// Map FuncGraph object from D category to K category.
|
// Map FuncGraph object from D category to K category.
|
||||||
AnfNodePtr MapToK(const FuncGraphPtr &primal);
|
AnfNodePtr MapToK(const FuncGraphPtr &primal);
|
||||||
// MapObject impls.
|
// MapObject impls.
|
||||||
|
@ -129,7 +131,8 @@ class KPrim {
|
||||||
KPrim() = default;
|
KPrim() = default;
|
||||||
~KPrim() = default;
|
~KPrim() = default;
|
||||||
|
|
||||||
FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
FuncGraphPtr KPrimitive(const CNodePtr &primal_user, const ValueNodePtr &value_node,
|
||||||
|
const pipeline::ResourceBasePtr &resources);
|
||||||
MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim);
|
MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim);
|
||||||
FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop);
|
FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop);
|
||||||
|
|
||||||
|
@ -145,7 +148,7 @@ class KPrim {
|
||||||
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
||||||
// Given a bprop rule, do the K mapping.
|
// Given a bprop rule, do the K mapping.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g);
|
FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const CNodePtr &cnode);
|
||||||
AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg);
|
AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg);
|
||||||
void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
|
void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
|
||||||
std::vector<AnfNodePtr> *const transf_args);
|
std::vector<AnfNodePtr> *const transf_args);
|
||||||
|
@ -156,7 +159,7 @@ class KPrim {
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
|
FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const CNodePtr &cnode) {
|
||||||
MS_EXCEPTION_IF_NULL(primal);
|
MS_EXCEPTION_IF_NULL(primal);
|
||||||
MS_EXCEPTION_IF_NULL(bprop_fg);
|
MS_EXCEPTION_IF_NULL(bprop_fg);
|
||||||
CheckBprop(bprop_fg, primal->ToString());
|
CheckBprop(bprop_fg, primal->ToString());
|
||||||
|
@ -197,8 +200,13 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
|
||||||
TransformArgs(mng, cloned_bprop_fg, outer, &transf_args);
|
TransformArgs(mng, cloned_bprop_fg, outer, &transf_args);
|
||||||
|
|
||||||
(void)transf_args.insert(transf_args.begin(), NewValueNode(primal));
|
(void)transf_args.insert(transf_args.begin(), NewValueNode(primal));
|
||||||
auto out_value = outer->NewCNode(transf_args);
|
CNodePtr out_value = nullptr;
|
||||||
|
if (cnode != nullptr) { // Set equiv debug info. for Primitive CNode out.
|
||||||
|
TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info()));
|
||||||
|
out_value = outer->NewCNode(transf_args);
|
||||||
|
} else {
|
||||||
|
out_value = outer->NewCNode(transf_args);
|
||||||
|
}
|
||||||
(void)mng->Replace(out_param, out_value);
|
(void)mng->Replace(out_param, out_value);
|
||||||
|
|
||||||
TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info()));
|
TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info()));
|
||||||
|
@ -207,7 +215,6 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
|
||||||
// We remove all parameters except new_dout.
|
// We remove all parameters except new_dout.
|
||||||
std::vector<AnfNodePtr> newBpropParams = {new_dout};
|
std::vector<AnfNodePtr> newBpropParams = {new_dout};
|
||||||
cloned_bprop_fg->set_parameters(newBpropParams);
|
cloned_bprop_fg->set_parameters(newBpropParams);
|
||||||
|
|
||||||
outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)}));
|
outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)}));
|
||||||
return BasicClone(outer);
|
return BasicClone(outer);
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,7 +64,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
|
FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
|
||||||
auto fg = g_k_prims.KPrimitive(value_node, resources);
|
auto fg = g_k_prims.KPrimitive(nullptr, value_node, resources);
|
||||||
if (fg == nullptr) {
|
if (fg == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -102,7 +102,8 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
|
||||||
MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
|
MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
|
FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
|
||||||
|
const pipeline::ResourceBasePtr &resources) {
|
||||||
if (!IsValueNode<Primitive>(value_node)) {
|
if (!IsValueNode<Primitive>(value_node)) {
|
||||||
MS_LOG(EXCEPTION) << "Primitive node is not valid.";
|
MS_LOG(EXCEPTION) << "Primitive node is not valid.";
|
||||||
}
|
}
|
||||||
|
@ -141,7 +142,7 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto expanded_fg = BpropToK(prim, bprop_fg);
|
auto expanded_fg = BpropToK(prim, bprop_fg, cnode);
|
||||||
if (expanded_fg == nullptr) {
|
if (expanded_fg == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
|
MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
|
||||||
<< " prim bprop function to J expanded func graph. NodeInfo: "
|
<< " prim bprop function to J expanded func graph. NodeInfo: "
|
||||||
|
@ -220,7 +221,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check
|
||||||
FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
|
FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
|
||||||
MS_EXCEPTION_IF_NULL(bprop_fg);
|
MS_EXCEPTION_IF_NULL(bprop_fg);
|
||||||
auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph();
|
auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph();
|
||||||
auto expanded_fg = BpropToK(fprop_fg, bprop_fg);
|
auto expanded_fg = BpropToK(fprop_fg, bprop_fg, nullptr);
|
||||||
if (expanded_fg == nullptr) {
|
if (expanded_fg == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString()
|
MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString()
|
||||||
<< " Cell bprop function to K expanded func graph. NodeInfo: "
|
<< " Cell bprop function to K expanded func graph. NodeInfo: "
|
||||||
|
|
|
@ -33,11 +33,11 @@ class TestGradImplementations : public UT::Common {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestGradImplementations, TestGetAugmentedGraph) {
|
TEST_F(TestGradImplementations, TestGetAugmentedGraph) {
|
||||||
FuncGraphPtr fg = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr);
|
FuncGraphPtr fg = ad::g_k_prims.KPrimitive(nullptr, NewValueNode(kPrimScalarMul), nullptr);
|
||||||
ASSERT_TRUE(fg != nullptr);
|
ASSERT_TRUE(fg != nullptr);
|
||||||
draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg);
|
draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg);
|
||||||
|
|
||||||
auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr);
|
auto fg1 = ad::g_k_prims.KPrimitive(nullptr, NewValueNode(kPrimScalarMul), nullptr);
|
||||||
|
|
||||||
FuncGraphPairMapEquiv equiv_graph;
|
FuncGraphPairMapEquiv equiv_graph;
|
||||||
NodeMapEquiv equiv_node;
|
NodeMapEquiv equiv_node;
|
||||||
|
|
Loading…
Reference in New Issue