forked from mindspore-Ecosystem/mindspore
Add support for the stop_gradient primitive, and defer back propagation to the End API
This commit is contained in:
parent 76bc0734d5
commit b40816e66a
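
To make "defer BackPropagation to the End API" concrete, here is a minimal sketch of the call sequence the new interface implies. It is illustrative only: the driver function, its parameters, and the record-keeping container are hypothetical, and GradPynativeCellEnd is assumed to mirror the (weights, grad_inputs, grad_weights) parameters of KPynativeCellImpl::Finish shown below.

// Hypothetical PyNative driver: record ops first, back-propagate once at the end.
#include <tuple>
#include <vector>
#include "frontend/optimizer/ad/kpynative.h"

namespace mindspore {
namespace ad {
FuncGraphPtr BuildGradGraphSketch(const AnfNodePtrList &cell_inputs,
                                  const std::vector<std::tuple<CNodePtr, ValuePtrList, ValuePtr>> &recorded_ops,
                                  const AnfNodePtrList &weights) {
  // Begin: start recording for this cell's forward run.
  auto k_cell = GradPynativeCellBegin(cell_inputs);
  // Per op: only adjoint/user book-keeping happens here; no bprop graph is emitted yet.
  for (const auto &[cnode, op_args, out] : recorded_ops) {
    GradPynativeOp(k_cell, cnode, op_args, out);
  }
  // End: propagate stop_gradient flags, run the single deferred BackPropagate pass,
  // and return the gradient funcgraph (the tape).
  return GradPynativeCellEnd(k_cell, weights, /*grad_inputs=*/true, /*grad_weights=*/true);
}
}  // namespace ad
}  // namespace mindspore
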
@@ -64,7 +64,12 @@ AnfNodePtr Adjoint::primal() { return primal_; }
AnfNodePtr Adjoint::dout() { return dout_hole_; }

AnfNodePtr Adjoint::RealDout() { return dout_; }
AnfNodePtr Adjoint::RealDout() {
  if (dout_ != nullptr) {
    return dout_;
  }
  return dout_hole_;
}

void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) {
  dout_user_.emplace_back(std::make_pair(user, index));
@@ -25,9 +25,6 @@
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/operator/ops.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "utils/ms_context.h"
#include "utils/info.h"
#include "debug/anf_ir_dump.h"
#include "debug/trace.h"
@@ -36,6 +33,32 @@ namespace mindspore {
namespace ad {
extern KPrim g_k_prims;

class PynativeAdjoint {
 public:
  PynativeAdjoint(const AdjointPtr &adjoint, const ValuePtrList &op_args, const ValuePtr &out,
                  const FuncGraphPtr &bprop_fg)
      : adjoint_(adjoint), op_args_(op_args), out_(out), bprop_fg_(bprop_fg) {}

  AnfNodePtrList &users() { return users_; }
  AdjointPtr &adjoint() { return adjoint_; }
  const ValuePtrList &op_args() { return op_args_; }
  const ValuePtr &out() { return out_; }
  const FuncGraphPtr &bprop_fg() { return bprop_fg_; }
  void ReplaceDoutHole() { adjoint_->CallDoutHole(); }
  AnfNodePtr RealDout() { return adjoint_->RealDout(); }
  void AccumulateDout(const AnfNodePtr &dout_factor) { adjoint_->AccumulateDout(dout_factor); }

 private:
  AnfNodePtrList users_;
  AdjointPtr adjoint_;
  // cache these arguments from ad caller.
  const ValuePtrList op_args_;
  const ValuePtr out_;
  // bprop_fg passed from the ad caller; it may be a user-defined backpropagation funcgraph.
  const FuncGraphPtr bprop_fg_;
};
using PynativeAdjointPtr = std::shared_ptr<PynativeAdjoint>;

class KPynativeCellImpl : public KPynativeCell {
 public:
  explicit KPynativeCellImpl(const AnfNodePtrList &cell_inputs) : cell_inputs_(cell_inputs) {

@@ -45,21 +68,26 @@ class KPynativeCellImpl : public KPynativeCell {
    }
  }
  ~KPynativeCellImpl() override = default;
  bool KPynativeOp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out);
  bool KPynativeWithBProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
  bool KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out);
  bool KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                          const FuncGraphPtr &bprop_fg);
  FuncGraphPtr Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights);

 private:
  FuncGraphPtr tape_;
  std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
  OrderedMap<AnfNodePtr, PynativeAdjointPtr> anfnode_to_adjoin_;
  AnfNodePtrList cell_inputs_;
  // Last cnode of this Cell; it may be a primitive op or a cell with a user-defined bprop.
  AnfNodePtr last_node_;
  AnfNodePtr last_node_{nullptr};
  bool need_propagate_stop_gradient_{false};

  bool BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                    const FuncGraphPtr &bprop_fg);
  void PropagateStopGradient();
  bool AllReferencesStopped(const CNodePtr &curr_cnode);
  // Back propagate for all nodes;
  bool BackPropagate();
  bool BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app);
  bool BuildBProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
                  const FuncGraphPtr &bprop_fg);
};
using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>;

@@ -74,6 +102,9 @@ FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePt
}

FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights) {
  // propagate stop_gradient flag to cnode before back propagate;
  PropagateStopGradient();

  for (size_t i = 0; i < weights.size(); ++i) {
    tape_->add_parameter();
  }
@@ -86,10 +117,8 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_
  // Set dout of last node to sens;
  last_node_adjoint_iter->second->AccumulateDout(sens_param);

  // Replace dout hole of all adjoint.
  for (auto &adjoint_iter : anfnode_to_adjoin_) {
    adjoint_iter.second->CallDoutHole();
  }
  // BackPropagate sensitivity;
  BackPropagate();

  // Return the gradient;
  AnfNodePtrList node_list{NewValueNode(prim::kPrimMakeTuple)};
@@ -133,45 +162,82 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_
  return tape_;
}

bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
                    const ValuePtr &out) {
  auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
  return k_cell_impl->KPynativeOp(c_node, op_args, out);
  return k_cell_impl->KPynativeOp(cnode, op_args, out);
}

bool KPynativeCellImpl::KPynativeOp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out) {
  MS_EXCEPTION_IF_NULL(c_node);
  auto prim = GetCNodePrimitive(c_node);
bool KPynativeCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(EXCEPTION) << "should be primitive, but: " << c_node->DebugString();
    MS_LOG(EXCEPTION) << "should be primitive, but: " << cnode->DebugString();
  }
  if (IsPrimitiveEquals(prim, prim::kPrimStopGradient) || IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
    need_propagate_stop_gradient_ = true;
  }

  auto bprop_fg = g_k_prims.GetBprop(prim);
  MS_EXCEPTION_IF_NULL(bprop_fg);
  BuildBProp(c_node, op_args, out, bprop_fg);
  BuildAdjoint(cnode, op_args, out, bprop_fg);

  return true;
}

bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
                           const ValuePtr &out, const FuncGraphPtr &bprop_fg) {
  auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
  return k_cell_impl->KPynativeWithBProp(c_node, op_args, out, bprop_fg);
  return k_cell_impl->KPynativeWithBProp(cnode, op_args, out, bprop_fg);
}

bool KPynativeCellImpl::KPynativeWithBProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
bool KPynativeCellImpl::KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                                           const FuncGraphPtr &bprop_fg) {
  MS_EXCEPTION_IF_NULL(c_node);
  auto primal_fg = GetCNodeFuncGraph(c_node);
  MS_EXCEPTION_IF_NULL(cnode);
  auto primal_fg = GetCNodeFuncGraph(cnode);
  if (primal_fg == nullptr) {
    MS_LOG(EXCEPTION) << "should be func graph, but: " << c_node->DebugString();
    MS_LOG(EXCEPTION) << "should be func graph, but: " << cnode->DebugString();
  }
  MS_EXCEPTION_IF_NULL(bprop_fg);
  BuildBProp(c_node, op_args, out, bprop_fg);
  BuildAdjoint(cnode, op_args, out, bprop_fg);

  return true;
}

FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                                     const FuncGraphPtr &bprop_fg) {
  auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(cnode);
  if (anfnode_adjoint_iter != anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "CNode should be unique, but: " << cnode->DebugString();
  }
  // Book-keeping last cnode, as dout of this node will be given from outside;
  last_node_ = cnode;
  auto cnode_adjoint = std::make_shared<Adjoint>(cnode, NewValueNode(out), tape_);
  auto cnode_pynative_adjoint = std::make_shared<PynativeAdjoint>(cnode_adjoint, op_args, out, bprop_fg);
  anfnode_to_adjoin_.insert(std::make_pair(cnode, cnode_pynative_adjoint));

  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    auto inp_i = cnode->input(i);
    auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(inp_i);
    if (anfnode_adjoint_iter == anfnode_to_adjoin_.end()) {
      if (inp_i->isa<CNode>()) {
        MS_LOG(EXCEPTION) << "cannot find adjoint for anfnode: " << inp_i->DebugString();
      } else {
        auto inp_i_adjoint = std::make_shared<Adjoint>(inp_i, NewValueNode(op_args[i - 1]), tape_);
        auto inp_i_pynative_adjoint =
            std::make_shared<PynativeAdjoint>(inp_i_adjoint, ValuePtrList{}, nullptr, nullptr);
        anfnode_to_adjoin_.insert(std::make_pair(inp_i, inp_i_pynative_adjoint));
        inp_i_pynative_adjoint->users().push_back(cnode);
      }
    } else {
      anfnode_adjoint_iter->second->users().push_back(cnode);
    }
  }

  return true;
}

FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &cnode, const ValuePtrList &op_args,
                                    const ValuePtr &out) {
  auto optimized_bprop_fg =
      pipeline::PrimBpropOptimizer::GetPrimBpropOptimizerInst().OptimizeBPropFuncGraph(bprop_fg, c_node, op_args, out);
@@ -192,42 +258,78 @@ bool KPynativeCellImpl::BackPropagate(const CNodePtr &cnode_primal, const CNodeP
  return true;
}

bool KPynativeCellImpl::BuildBProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
                                   const FuncGraphPtr &bprop_fg) {
  auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(c_node);
  if (anfnode_adjoint_iter != anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "CNode should be unique, but: " << c_node->DebugString();
bool KPynativeCellImpl::BackPropagate() {
  for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) {
    if (!iter->first->isa<CNode>()) {
      continue;
    }
    auto cnode = iter->first->cast<CNodePtr>();
    if (cnode->stop_gradient()) {
      MS_LOG(DEBUG) << "Bypass backpropagate for cnode with stop_gradient flag: " << cnode->ToString();
      continue;
    }
    auto bprop_fg = iter->second->bprop_fg();
    if (bprop_fg == nullptr) {
      auto prim = GetCNodePrimitive(cnode);
      if (prim == nullptr) {
        MS_LOG(EXCEPTION) << "should be primitive, but: " << cnode->DebugString();
      }
      bprop_fg = g_k_prims.GetBprop(prim);
      MS_EXCEPTION_IF_NULL(bprop_fg);
    }
    // Optimize the bprop_fg based on value.
    auto optimized_bprop_fg = OptimizeBPropFuncGraph(bprop_fg, cnode, iter->second->op_args(), iter->second->out());
    AnfNodePtrList node_list{NewValueNode(optimized_bprop_fg)};
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      auto inp_i = cnode->input(i);
      node_list.push_back(inp_i);
    }
    node_list.push_back(NewValueNode(iter->second->out()));
    node_list.push_back(iter->second->RealDout());

    auto bprop_app = tape_->NewCNode(node_list);
    BackPropagate(cnode, bprop_app);
  }
  // Book-keeping last cnode, as dout of this node will be given from outside;
  last_node_ = c_node;
  auto cnode_adjoint = std::make_shared<Adjoint>(c_node, NewValueNode(out), tape_);
  anfnode_to_adjoin_.emplace(c_node, cnode_adjoint);
  return true;
}

  // Optimize the bprop_fg based on value.
  auto optimized_bprop_fg = OptimizeBPropFuncGraph(bprop_fg, c_node, op_args, out);
  AnfNodePtrList node_list{NewValueNode(optimized_bprop_fg)};
bool KPynativeCellImpl::AllReferencesStopped(const CNodePtr &curr_cnode) {
  // If all CNodes that use curr_cnode have the stop_gradient_ flag, then curr_cnode can also set that flag.
  auto iter = anfnode_to_adjoin_.find(curr_cnode);
  if (iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "Cannot adjoint for cnode: " << curr_cnode->DebugString();
  }
  auto users = iter->second->users();
  if (users.empty()) {
    return false;
  }
  auto all_users_have_stopped = std::all_of(users.cbegin(), users.cend(), [](const AnfNodePtr &user) {
    if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
      return false;
    }
    return true;
  });
  return all_users_have_stopped;
}

  for (size_t i = 1; i < c_node->inputs().size(); ++i) {
    auto inp_i = c_node->input(i);
    auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(inp_i);
    if (anfnode_adjoint_iter == anfnode_to_adjoin_.end()) {
      if (inp_i->isa<CNode>()) {
        MS_LOG(EXCEPTION) << "cannot find adjoint for anfnode: " << inp_i->DebugString();
      } else {
        auto inp_i_adjoint = std::make_shared<Adjoint>(inp_i, NewValueNode(op_args[i - 1]), tape_);
        anfnode_to_adjoin_.emplace(inp_i, inp_i_adjoint);
void KPynativeCellImpl::PropagateStopGradient() {
  // propagate need_stop_gradient_ to cnode before back propagate;
  if (need_propagate_stop_gradient_) {
    for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) {
      const auto &node = iter->first;
      if (node->isa<CNode>()) {
        auto cnode = node->cast<CNodePtr>();
        if (!cnode->stop_gradient()) {
          // Cut off the cnode only when it's not referred any more
          if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
              AllReferencesStopped(cnode)) {
            MS_LOG(DEBUG) << "Set stop_gradient flag for " << cnode->ToString();
            cnode->set_stop_gradient(true);
          }
        }
      }
    }
      node_list.push_back(inp_i);
  }
  node_list.push_back(NewValueNode(out));
  node_list.push_back(cnode_adjoint->dout());

  auto bprop_app = tape_->NewCNode(node_list);
  cnode_adjoint->RegisterDoutUser(bprop_app, node_list.size() - 1);
  BackPropagate(c_node, bprop_app);

  return true;
}
} // namespace ad
} // namespace mindspore
@@ -55,6 +55,28 @@ class TestKPynative : public UT::Common {
    return g;
  }

  // a = x * y
  // b = stop_gradient(a)
  // c = b * y
  // d = a * c
  // return d
  FuncGraphPtr BuildStopGradient(const std::string &testCase) {
    auto g = std::make_shared<FuncGraph>();
    auto x = g->add_parameter();
    auto y = g->add_parameter();
    x->set_abstract(BuildArg());
    y->set_abstract(BuildArg());
    auto a_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
    a_node->set_abstract(BuildArg());
    auto b_node = g->NewCNode({NewValueNode(prim::kPrimStopGradient), a_node});
    b_node->set_abstract(BuildArg());
    auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y});
    c_node->set_abstract(BuildArg());
    auto d_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
    d_node->set_abstract(BuildArg());
    g->set_output(d_node);
    return g;
  }

  FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
    auto k_pynative_cell = GradPynativeCellBegin(primal_fg->parameters());
    auto node_list = TopoSort(primal_fg->output());

@@ -74,7 +96,6 @@ class TestKPynative : public UT::Common {
  }
};

TEST_F(TestKPynative, test_simple_add) {
  auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
  resource->manager()->KeepRoots({primal_fg});

@@ -85,5 +106,16 @@ TEST_F(TestKPynative, test_simple_add) {

  ExportIR(bprop_fg->ToString() + ".dat", "", bprop_fg);
}

TEST_F(TestKPynative, test_stop_gradient) {
  auto primal_fg = BuildStopGradient("test_stop_gradient");
  resource->manager()->KeepRoots({primal_fg});
  ExportIR(primal_fg->ToString() + ".dat", "", primal_fg);

  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  resource->manager()->KeepRoots({bprop_fg});

  ExportIR(bprop_fg->ToString() + ".dat", "", bprop_fg);
}
} // namespace ad
} // namespace mindspore
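
A brief worked reading of the test graph above, assuming the propagation logic in this diff (not an assertion from the commit itself): for a = x * y, b = stop_gradient(a), c = b * y, d = a * c, PropagateStopGradient marks only the StopGradient cnode b. The output d has no users, so AllReferencesStopped(d) is false, and a is still used by d, so a is not marked either. BackPropagate then bypasses b: the c = b * y branch still delivers dout to y and to b, but nothing crosses b back into a, while a receives its dout through the a * c input of d, so x and y keep only the gradient contributions that flow through that path.
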