Add support for the stop_gradient primitive, and defer back propagation to the End API

This commit is contained in:
zhousiyi 2021-02-25 02:24:13 +00:00 committed by chujinjin
parent 76bc0734d5
commit b40816e66a
3 changed files with 199 additions and 60 deletions

View File

@ -64,7 +64,12 @@ AnfNodePtr Adjoint::primal() { return primal_; }
// Returns the dout hole node (dout_hole_) held by this adjoint.
AnfNodePtr Adjoint::dout() { return dout_hole_; }
AnfNodePtr Adjoint::RealDout() { return dout_; }
// Returns dout_ when it is set; while dout_ is still null the dout hole is returned instead.
AnfNodePtr Adjoint::RealDout() {
  if (dout_ == nullptr) {
    return dout_hole_;
  }
  return dout_;
}
void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) {
dout_user_.emplace_back(std::make_pair(user, index));

View File

@ -25,9 +25,6 @@
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/operator/ops.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "utils/ms_context.h"
#include "utils/info.h"
#include "debug/anf_ir_dump.h"
#include "debug/trace.h"
@ -36,6 +33,32 @@ namespace mindspore {
namespace ad {
extern KPrim g_k_prims;
// Bundles an Adjoint with the data the ad caller passed for the same node (op arguments, output value
// and bprop funcgraph) plus a list of user nodes that callers fill in through users().
class PynativeAdjoint {
 public:
  // All arguments are stored as given; op_args, out and bprop_fg are never modified afterwards.
  PynativeAdjoint(const AdjointPtr &adjoint, const ValuePtrList &op_args, const ValuePtr &out,
                  const FuncGraphPtr &bprop_fg)
      : adjoint_(adjoint), op_args_(op_args), out_(out), bprop_fg_(bprop_fg) {}

  // Mutable access to the user list, so callers can append to it.
  AnfNodePtrList &users() { return users_; }
  // Mutable access to the wrapped adjoint pointer.
  AdjointPtr &adjoint() { return adjoint_; }

  // Read-only views of the cached caller data; const-qualified so they also work on const objects.
  const ValuePtrList &op_args() const { return op_args_; }
  const ValuePtr &out() const { return out_; }
  const FuncGraphPtr &bprop_fg() const { return bprop_fg_; }

  // The three calls below forward to the wrapped Adjoint.
  void ReplaceDoutHole() { adjoint_->CallDoutHole(); }
  AnfNodePtr RealDout() { return adjoint_->RealDout(); }
  void AccumulateDout(const AnfNodePtr &dout_factor) { adjoint_->AccumulateDout(dout_factor); }

 private:
  AnfNodePtrList users_;
  AdjointPtr adjoint_;
  // cache these arguments from ad caller.
  const ValuePtrList op_args_;
  const ValuePtr out_;
  // bprop_fg passed from ad caller, it may be a user defined back propagation funcgraph.
  const FuncGraphPtr bprop_fg_;
};
using PynativeAdjointPtr = std::shared_ptr<PynativeAdjoint>;
class KPynativeCellImpl : public KPynativeCell {
public:
explicit KPynativeCellImpl(const AnfNodePtrList &cell_inputs) : cell_inputs_(cell_inputs) {
@ -45,21 +68,26 @@ class KPynativeCellImpl : public KPynativeCell {
}
}
~KPynativeCellImpl() override = default;
bool KPynativeOp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out);
bool KPynativeWithBProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
bool KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out);
bool KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &bprop_fg);
FuncGraphPtr Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights);
private:
FuncGraphPtr tape_;
std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
OrderedMap<AnfNodePtr, PynativeAdjointPtr> anfnode_to_adjoin_;
AnfNodePtrList cell_inputs_;
// Last cnode of this Cell, may be a primitive op or cell with user defined bprop.
AnfNodePtr last_node_;
AnfNodePtr last_node_{nullptr};
bool need_propagate_stop_gradient_{false};
bool BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &bprop_fg);
void PropagateStopGradient();
bool AllReferencesStopped(const CNodePtr &curr_cnode);
// Back propagate for all node;
bool BackPropagate();
bool BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app);
bool BuildBProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &bprop_fg);
};
using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>;
@ -74,6 +102,9 @@ FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePt
}
FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights) {
// propagate stop_gradient flag to cnode before back propagate;
PropagateStopGradient();
for (size_t i = 0; i < weights.size(); ++i) {
tape_->add_parameter();
}
@ -86,10 +117,8 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_
// Set dout of last node to sens;
last_node_adjoint_iter->second->AccumulateDout(sens_param);
// Replace dout hole of all adjoint.
for (auto &adjoint_iter : anfnode_to_adjoin_) {
adjoint_iter.second->CallDoutHole();
}
// BackPropagate sensitivity;
BackPropagate();
// Return the gradient;
AnfNodePtrList node_list{NewValueNode(prim::kPrimMakeTuple)};
@ -133,45 +162,82 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_
return tape_;
}
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
// Forwards the record of one op to KPynativeCellImpl::KPynativeOp on the implementation behind k_cell.
// NOTE(review): the result of dynamic_pointer_cast is dereferenced without a null check; confirm that k_cell
// is always a KPynativeCellImpl.
// NOTE(review): the return statement using `c_node` names a variable that is not declared in this signature;
// it looks like the pre-rename form of the statement after it. Confirm against the actual file.
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out) {
auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
return k_cell_impl->KPynativeOp(c_node, op_args, out);
return k_cell_impl->KPynativeOp(cnode, op_args, out);
}
bool KPynativeCellImpl::KPynativeOp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out) {
MS_EXCEPTION_IF_NULL(c_node);
auto prim = GetCNodePrimitive(c_node);
// Records one primitive cnode together with its argument values and output value.
// A cnode that does not hold a primitive is reported through MS_LOG(EXCEPTION). Seeing a StopGradient or
// UpdateState primitive sets need_propagate_stop_gradient_. The bprop funcgraph is looked up in g_k_prims.
// Returns true.
// NOTE(review): the statements using `c_node` refer to a name that is not declared in this signature; they
// look like pre-rename forms of the statements that follow them. Confirm against the actual file.
bool KPynativeCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out) {
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "should be primitive, but: " << c_node->DebugString();
MS_LOG(EXCEPTION) << "should be primitive, but: " << cnode->DebugString();
}
// Remember that a stop-gradient style primitive was recorded.
if (IsPrimitiveEquals(prim, prim::kPrimStopGradient) || IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
need_propagate_stop_gradient_ = true;
}
auto bprop_fg = g_k_prims.GetBprop(prim);
MS_EXCEPTION_IF_NULL(bprop_fg);
BuildBProp(c_node, op_args, out, bprop_fg);
BuildAdjoint(cnode, op_args, out, bprop_fg);
return true;
}
bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
// Forwards the record of one cnode with a caller-supplied bprop funcgraph to
// KPynativeCellImpl::KPynativeWithBProp on the implementation behind k_cell.
// NOTE(review): the result of dynamic_pointer_cast is dereferenced without a null check; confirm that k_cell
// is always a KPynativeCellImpl.
// NOTE(review): the return statement using `c_node` names a variable that is not declared in this signature;
// it looks like the pre-rename form of the statement after it. Confirm against the actual file.
bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out, const FuncGraphPtr &bprop_fg) {
auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
return k_cell_impl->KPynativeWithBProp(c_node, op_args, out, bprop_fg);
return k_cell_impl->KPynativeWithBProp(cnode, op_args, out, bprop_fg);
}
bool KPynativeCellImpl::KPynativeWithBProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
// Records one cnode whose first input is a funcgraph, using the bprop funcgraph passed by the caller.
// A cnode that does not hold a funcgraph is reported through MS_LOG(EXCEPTION); bprop_fg must not be null.
// Returns true.
// NOTE(review): the statements using `c_node` refer to a name that is not declared in this signature; they
// look like pre-rename forms of the statements that follow them. Confirm against the actual file.
bool KPynativeCellImpl::KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &bprop_fg) {
MS_EXCEPTION_IF_NULL(c_node);
auto primal_fg = GetCNodeFuncGraph(c_node);
MS_EXCEPTION_IF_NULL(cnode);
auto primal_fg = GetCNodeFuncGraph(cnode);
if (primal_fg == nullptr) {
MS_LOG(EXCEPTION) << "should be func graph, but: " << c_node->DebugString();
MS_LOG(EXCEPTION) << "should be func graph, but: " << cnode->DebugString();
}
MS_EXCEPTION_IF_NULL(bprop_fg);
BuildBProp(c_node, op_args, out, bprop_fg);
BuildAdjoint(cnode, op_args, out, bprop_fg);
return true;
}
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &bprop_fg) {
auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(cnode);
if (anfnode_adjoint_iter != anfnode_to_adjoin_.end()) {
MS_LOG(EXCEPTION) << "CNode should be unique, but: " << cnode->DebugString();
}
// Book-keeping last cnode, as dout of this node will be given from outside;
last_node_ = cnode;
auto cnode_adjoint = std::make_shared<Adjoint>(cnode, NewValueNode(out), tape_);
auto cnode_pynative_adjoint = std::make_shared<PynativeAdjoint>(cnode_adjoint, op_args, out, bprop_fg);
anfnode_to_adjoin_.insert(std::make_pair(cnode, cnode_pynative_adjoint));
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto inp_i = cnode->input(i);
auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(inp_i);
if (anfnode_adjoint_iter == anfnode_to_adjoin_.end()) {
if (inp_i->isa<CNode>()) {
MS_LOG(EXCEPTION) << "cannot find adjoint for anfnode: " << inp_i->DebugString();
} else {
auto inp_i_adjoint = std::make_shared<Adjoint>(inp_i, NewValueNode(op_args[i - 1]), tape_);
auto inp_i_pynative_adjoint =
std::make_shared<PynativeAdjoint>(inp_i_adjoint, ValuePtrList{}, nullptr, nullptr);
anfnode_to_adjoin_.insert(std::make_pair(inp_i, inp_i_pynative_adjoint));
inp_i_pynative_adjoint->users().push_back(cnode);
}
} else {
anfnode_adjoint_iter->second->users().push_back(cnode);
}
}
return true;
}
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out) {
auto optimized_bprop_fg =
pipeline::PrimBpropOptimizer::GetPrimBpropOptimizerInst().OptimizeBPropFuncGraph(bprop_fg, c_node, op_args, out);
@ -192,42 +258,78 @@ bool KPynativeCellImpl::BackPropagate(const CNodePtr &cnode_primal, const CNodeP
return true;
}
bool KPynativeCellImpl::BuildBProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &bprop_fg) {
auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(c_node);
if (anfnode_adjoint_iter != anfnode_to_adjoin_.end()) {
MS_LOG(EXCEPTION) << "CNode should be unique, but: " << c_node->DebugString();
// Back propagates through all recorded cnodes.
// anfnode_to_adjoin_ is walked in reverse. Every CNode entry that does not carry the stop_gradient flag gets a
// call to its (optimized) bprop funcgraph created on tape_, which is then handed to
// BackPropagate(cnode, bprop_app). Returns true.
bool KPynativeCellImpl::BackPropagate() {
for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) {
// Entries whose key is not a CNode are skipped.
if (!iter->first->isa<CNode>()) {
continue;
}
auto cnode = iter->first->cast<CNodePtr>();
if (cnode->stop_gradient()) {
MS_LOG(DEBUG) << "Bypass backpropagate for cnode with stop_gradient flag: " << cnode->ToString();
continue;
}
auto bprop_fg = iter->second->bprop_fg();
// No bprop funcgraph cached for this entry: look it up in g_k_prims by the cnode's primitive.
if (bprop_fg == nullptr) {
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "should be primitive, but: " << cnode->DebugString();
}
bprop_fg = g_k_prims.GetBprop(prim);
MS_EXCEPTION_IF_NULL(bprop_fg);
}
// Optimize the bprop_fg based on value.
auto optimized_bprop_fg = OptimizeBPropFuncGraph(bprop_fg, cnode, iter->second->op_args(), iter->second->out());
// Call layout: optimized bprop funcgraph, the cnode's inputs (input 0 skipped), the output value, the dout.
AnfNodePtrList node_list{NewValueNode(optimized_bprop_fg)};
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto inp_i = cnode->input(i);
node_list.push_back(inp_i);
}
node_list.push_back(NewValueNode(iter->second->out()));
node_list.push_back(iter->second->RealDout());
auto bprop_app = tape_->NewCNode(node_list);
BackPropagate(cnode, bprop_app);
}
// NOTE(review): the statements below, up to `return true`, use `c_node` and `out`, which are not declared in
// this function as shown; they look like leftovers of the removed BuildBProp body. Confirm against the file.
// Book-keeping last cnode, as dout of this node will be given from outside;
last_node_ = c_node;
auto cnode_adjoint = std::make_shared<Adjoint>(c_node, NewValueNode(out), tape_);
anfnode_to_adjoin_.emplace(c_node, cnode_adjoint);
return true;
}
// Optimize the bprop_fg based on value.
auto optimized_bprop_fg = OptimizeBPropFuncGraph(bprop_fg, c_node, op_args, out);
AnfNodePtrList node_list{NewValueNode(optimized_bprop_fg)};
// Tells whether every recorded user of curr_cnode is a CNode that already carries the stop_gradient flag;
// if so, curr_cnode can be given that flag as well.
// Returns false when curr_cnode has no recorded users. A curr_cnode without an entry in anfnode_to_adjoin_
// is reported through MS_LOG(EXCEPTION).
bool KPynativeCellImpl::AllReferencesStopped(const CNodePtr &curr_cnode) {
  auto iter = anfnode_to_adjoin_.find(curr_cnode);
  if (iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "Cannot adjoint for cnode: " << curr_cnode->DebugString();
  }
  // Bind by reference: the user list is only read here, so there is no need to copy it.
  const auto &users = iter->second->users();
  if (users.empty()) {
    return false;
  }
  return std::all_of(users.cbegin(), users.cend(), [](const AnfNodePtr &user) {
    return user->isa<CNode>() && user->cast<CNodePtr>()->stop_gradient();
  });
}
for (size_t i = 1; i < c_node->inputs().size(); ++i) {
auto inp_i = c_node->input(i);
auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(inp_i);
if (anfnode_adjoint_iter == anfnode_to_adjoin_.end()) {
if (inp_i->isa<CNode>()) {
MS_LOG(EXCEPTION) << "cannot find adjoint for anfnode: " << inp_i->DebugString();
} else {
auto inp_i_adjoint = std::make_shared<Adjoint>(inp_i, NewValueNode(op_args[i - 1]), tape_);
anfnode_to_adjoin_.emplace(inp_i, inp_i_adjoint);
// Sets the stop_gradient flag on recorded cnodes before back propagation.
// Nothing happens unless need_propagate_stop_gradient_ is set. Otherwise anfnode_to_adjoin_ is walked in
// reverse, and a cnode that is not flagged yet gets the flag when it is a StopGradient or UpdateState
// primitive cnode, or when AllReferencesStopped(cnode) returns true.
void KPynativeCellImpl::PropagateStopGradient() {
// propagate need_stop_gradient_ to cnode before back propagate;
if (need_propagate_stop_gradient_) {
for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) {
const auto &node = iter->first;
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (!cnode->stop_gradient()) {
// Cut off the cnode only when it's not referred any more
if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
AllReferencesStopped(cnode)) {
MS_LOG(DEBUG) << "Set stop_gradient flag for " << cnode->ToString();
cnode->set_stop_gradient(true);
}
}
}
}
// NOTE(review): the remaining statements (node_list, inp_i, out, cnode_adjoint, c_node, and `return true`
// in a void function) use names not declared in this function as shown; they look like leftovers of the
// removed BuildBProp body. Confirm against the actual file.
node_list.push_back(inp_i);
}
node_list.push_back(NewValueNode(out));
node_list.push_back(cnode_adjoint->dout());
auto bprop_app = tape_->NewCNode(node_list);
cnode_adjoint->RegisterDoutUser(bprop_app, node_list.size() - 1);
BackPropagate(c_node, bprop_app);
return true;
}
} // namespace ad
} // namespace mindspore

View File

@ -55,6 +55,28 @@ class TestKPynative : public UT::Common {
return g;
}
// Builds the primal graph for the stop_gradient test:
//   a = x * y
//   b = stop_gradient(a)
//   c = b * y
//   d = a * c
//   return d
// The testCase argument is not used by the body.
FuncGraphPtr BuildStopGradient(const std::string &testCase) {
  auto graph = std::make_shared<FuncGraph>();
  auto x = graph->add_parameter();
  auto y = graph->add_parameter();
  x->set_abstract(BuildArg());
  y->set_abstract(BuildArg());

  // Creates `lhs * rhs` with the python tensor_mul op and attaches an abstract to the new node.
  auto new_mul = [&](const auto &lhs, const auto &rhs) {
    auto mul = graph->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), lhs, rhs});
    mul->set_abstract(BuildArg());
    return mul;
  };

  auto a = new_mul(x, y);
  auto b = graph->NewCNode({NewValueNode(prim::kPrimStopGradient), a});
  b->set_abstract(BuildArg());
  auto c = new_mul(b, y);
  auto d = new_mul(a, c);

  graph->set_output(d);
  return graph;
}
FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
auto k_pynative_cell = GradPynativeCellBegin(primal_fg->parameters());
auto node_list = TopoSort(primal_fg->output());
@ -74,7 +96,6 @@ class TestKPynative : public UT::Common {
}
};
TEST_F(TestKPynative, test_simple_add) {
auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
resource->manager()->KeepRoots({primal_fg});
@ -85,5 +106,16 @@ TEST_F(TestKPynative, test_simple_add) {
ExportIR(bprop_fg->ToString() + ".dat", "", bprop_fg);
}
// Builds the stop_gradient primal graph, derives its bprop graph, and exports both graphs as IR files
// named after each graph's ToString() with a ".dat" suffix.
TEST_F(TestKPynative, test_stop_gradient) {
  // Primal graph.
  auto primal_graph = BuildStopGradient("test_stop_gradient");
  resource->manager()->KeepRoots({primal_graph});
  const auto primal_ir_file = primal_graph->ToString() + ".dat";
  ExportIR(primal_ir_file, "", primal_graph);

  // Gradient graph derived from the primal graph.
  auto grad_graph = BuildBpropFuncGraph(primal_graph);
  resource->manager()->KeepRoots({grad_graph});
  const auto grad_ir_file = grad_graph->ToString() + ".dat";
  ExportIR(grad_ir_file, "", grad_graph);
}
} // namespace ad
} // namespace mindspore