Add an API to build a bprop funcgraph that can be used for higher-order gradients

This commit is contained in:
zhousiyi 2021-04-09 02:18:13 +00:00 committed by chujinjin
parent c759effa94
commit 4ff19b7082
3 changed files with 149 additions and 37 deletions

View File

@ -211,6 +211,9 @@ class PynativeAdjoint {
dout_ = dout_factor;
}
AnfNodePtr k_node() const { return k_node_; }
void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
private:
const FuncGraphPtr tape_;
AnfNodePtr dout_{nullptr};
@ -221,6 +224,8 @@ class PynativeAdjoint {
const ValuePtr out_;
// bprop_fg passed from the ad caller; it may be a user-defined back propagate funcgraph.
const FuncGraphPtr bprop_fg_;
// k mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
AnfNodePtr k_node_;
};
using PynativeAdjointPtr = std::shared_ptr<PynativeAdjoint>;
@ -243,7 +248,10 @@ class KPynativeCellImpl : public KPynativeCell {
bool KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &bprop_fg);
void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) override;
// Build a back propagate funcgraph in which each cnode of the primal funcgraph is replaced by a value node;
FuncGraphPtr Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights, bool has_sens_arg);
// Build a back propagate funcgraph with formal cnode;
FuncGraphPtr BuildFormalBProp(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights);
private:
FuncGraphPtr tape_;
@ -263,13 +271,22 @@ class KPynativeCellImpl : public KPynativeCell {
void PropagateStopGradient();
bool AllReferencesStopped(const CNodePtr &curr_cnode);
// Back propagate for all node;
bool BackPropagate();
// if by_value is true, in bprop_app cnode, every input is value node;
// if by_value is false, in bprop_app cnode, input is the k mapped node, so it can be grad again.
bool BackPropagate(bool by_value);
bool BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app);
FuncGraphPtr BuildBpropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode);
FuncGraphPtr BuildBPropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode);
// Back propagate for MakeList or MakeTuple is generated from MetaFuncGraph.
FuncGraphPtr BuildMakeSequenceBprop(const PrimitivePtr &prim, const CNodePtr &cnode);
// Replace input or weights parameter from primal funcgraph to parameters of tape_;
void ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg);
// Set return node according to grad flag
void SetOutput(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights);
// for higher order gradient;
// Build k mapped node owned by tape_ for each cnode in primal funcgraph, so these node can be
// used in tape_ to keep tracking the cnode dependency.
bool BuildKNode();
};
using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>;
@ -323,26 +340,11 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_
}
// BackPropagate sensitivity;
BackPropagate();
BackPropagate(true);
// Return the gradient;
SetOutput(weights, grad_inputs, grad_weights);
// Replace AnfNode with parameter of tape_;
auto mng = MakeManager({tape_}, false);
auto tr = mng->Transact();
const auto &parameters = tape_->parameters();
auto cell_inputs_size = cell_inputs_.size();
for (size_t i = 0; i < cell_inputs_size; ++i) {
tr.Replace(cell_inputs_[i], parameters[i]);
}
// (Inputs, sens, weights) or (Inputs, weights)
size_t weight_offset = cell_inputs_size;
if (has_sens_arg) {
weight_offset = weight_offset + 1;
}
for (size_t i = 0; i < weights.size(); ++i) {
tr.Replace(weights[i], parameters[weight_offset + i]);
}
tr.Commit();
// Replace Parameter of primal funcgraph with parameter of tape_;
ReplacePrimalParameter(weights, has_sens_arg);
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("before_final_opt.ir", tape_);
@ -350,6 +352,46 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_
return tape_;
}
// Free-function entry point that builds the formal back propagate funcgraph of a cell.
// k_cell: cell handle; it is down-cast to KPynativeCellImpl.
// weights, grad_inputs, grad_weights: forwarded unchanged to KPynativeCellImpl::BuildFormalBProp.
// return: the funcgraph returned by KPynativeCellImpl::BuildFormalBProp.
FuncGraphPtr GradPynativeCellBuildFormalBProp(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights,
                                              bool grad_inputs, bool grad_weights) {
  auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
  // dynamic_pointer_cast yields an empty pointer when k_cell is null or is not a KPynativeCellImpl;
  // raise a diagnosable exception instead of dereferencing it.
  MS_EXCEPTION_IF_NULL(k_cell_impl);
  return k_cell_impl->BuildFormalBProp(weights, grad_inputs, grad_weights);
}
// Build a back propagate funcgraph on tape_ whose bprop applications take the k mapped nodes as inputs
// instead of value nodes (BackPropagate is invoked with by_value == false).
// weights: weight parameters of the primal funcgraph; each must be a Parameter. One parameter is appended
//          to tape_ per weight and receives the same default_param.
// grad_inputs, grad_weights: forwarded to SetOutput to select what the funcgraph returns.
// return: tape_.
// Raises an exception when last_node_ is null, when last_node_ has no adjoint, or when a weight is null
// or is not a Parameter.
FuncGraphPtr KPynativeCellImpl::BuildFormalBProp(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights) {
  // last_node_ is dereferenced below; reject a missing output node before touching any state.
  MS_EXCEPTION_IF_NULL(last_node_);
  // propagate stop_gradient flag to cnode before back propagate;
  PropagateStopGradient();

  MS_LOG(DEBUG) << "Last node info " << last_node_->DebugString();
  auto last_node_adjoint_iter = anfnode_to_adjoin_.find(last_node_);
  if (last_node_adjoint_iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->ToString();
  }

  // Build forward CNode;
  BuildKNode();

  // Add weights parameter
  for (const auto &weight : weights) {
    // weight is dereferenced for its debug info before the Parameter cast is checked.
    MS_EXCEPTION_IF_NULL(weight);
    TraceGuard trace_guard(std::make_shared<TraceCopy>(weight->debug_info()));
    auto p = tape_->add_parameter();
    auto input_w = weight->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(input_w);
    p->set_default_param(input_w->default_param());
  }

  // No sens argument is taken here: the sensitivity of the last node is seeded from its forward output
  // through BuildOnesLikeValue.
  auto sens_node = BuildOnesLikeValue(tape_, last_node_adjoint_iter->second->out());
  last_node_adjoint_iter->second->AccumulateDout(sens_node);

  // BackPropagate sensitivity;
  BackPropagate(false);
  // Return the gradient;
  SetOutput(weights, grad_inputs, grad_weights);
  // Replace Parameter of primal funcgraph with parameter of tape_;
  ReplacePrimalParameter(weights, false);

  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
    DumpIR("formal_bprop_before_final_opt.ir", tape_);
  }
  return tape_;
}
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out) {
auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
@ -368,7 +410,7 @@ bool KPynativeCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &o
FuncGraphPtr bprop_fg = nullptr;
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward)) {
bprop_fg = BuildBpropCutFuncGraph(prim, cnode);
bprop_fg = BuildBPropCutFuncGraph(prim, cnode);
} else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
bprop_fg = BuildMakeSequenceBprop(prim, cnode);
} else {
@ -638,7 +680,7 @@ bool KPynativeCellImpl::BackPropagate(const CNodePtr &cnode_primal, const CNodeP
return true;
}
bool KPynativeCellImpl::BackPropagate() {
bool KPynativeCellImpl::BackPropagate(bool by_value) {
for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) {
if (!iter->first->isa<CNode>()) {
continue;
@ -653,18 +695,34 @@ bool KPynativeCellImpl::BackPropagate() {
MS_EXCEPTION_IF_NULL(bprop_fg);
AnfNodePtrList node_list{NewValueNode(bprop_fg)};
std::transform(iter->second->op_args().begin(), iter->second->op_args().end(), std::back_inserter(node_list),
[](const ValuePtr &value) { return NewValueNode(value); });
node_list.push_back(NewValueNode(iter->second->out()));
node_list.push_back(iter->second->RealDout());
// Update abstract info of valuenode with its value
for (size_t i = 1; i < node_list.size() - 1; ++i) {
auto v_node = node_list[i]->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(v_node);
auto value = v_node->value();
if (v_node->abstract() == nullptr && value != nullptr && value->ToAbstract() != nullptr) {
v_node->set_abstract(value->ToAbstract()->Broaden());
}
if (by_value) {
std::transform(iter->second->op_args().begin(), iter->second->op_args().end(), std::back_inserter(node_list),
[](const ValuePtr &value) {
auto v_node = NewValueNode(value);
v_node->set_abstract(value->ToAbstract()->Broaden());
return v_node;
});
auto out_node = NewValueNode(iter->second->out());
out_node->set_abstract(iter->second->out()->ToAbstract()->Broaden());
node_list.push_back(out_node);
node_list.push_back(iter->second->RealDout());
} else {
std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(node_list),
[this](const AnfNodePtr &inp) {
if (inp->isa<CNode>()) {
auto inp_iter = anfnode_to_adjoin_.find(inp);
if (inp_iter == anfnode_to_adjoin_.end()) {
MS_LOG(EXCEPTION) << "cannot find inp in adjoint map, inp: " << inp->DebugString();
}
return inp_iter->second->k_node();
} else {
return inp;
}
});
// out;
node_list.push_back(iter->second->k_node());
// dout;
node_list.push_back(iter->second->RealDout());
}
// Back propagate process
auto bprop_fg_output_abs = bprop_fg->output()->abstract();
@ -722,7 +780,7 @@ void KPynativeCellImpl::PropagateStopGradient() {
}
}
FuncGraphPtr KPynativeCellImpl::BuildBpropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode) {
FuncGraphPtr KPynativeCellImpl::BuildBPropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode) {
auto inputs_num = cnode->size() - 1;
auto func_graph = std::make_shared<FuncGraph>();
@ -870,5 +928,49 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, bool grad_input
}
tape_->set_output(tape_output);
}
bool KPynativeCellImpl::BuildKNode() {
for (auto iter = anfnode_to_adjoin_.cbegin(); iter != anfnode_to_adjoin_.cend(); ++iter) {
if (!iter->first->isa<CNode>()) {
continue;
}
auto cnode = iter->first->cast<CNodePtr>();
AnfNodePtrList node_list;
// Update abstract info of valuenode with its value
for (const auto &inp : cnode->inputs()) {
if (inp->isa<CNode>()) {
auto inp_iter = anfnode_to_adjoin_.find(inp);
if (inp_iter == anfnode_to_adjoin_.end()) {
MS_LOG(EXCEPTION) << "cannot find inp in adjoint map, inp: " << inp->DebugString();
}
node_list.push_back(inp_iter->second->k_node());
} else {
node_list.push_back(inp);
}
}
auto k_node = tape_->NewCNode(node_list);
iter->second->set_k_node(k_node);
}
return true;
}
void KPynativeCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
auto mng = MakeManager({tape_}, false);
auto tr = mng->Transact();
const auto &parameters = tape_->parameters();
auto cell_inputs_size = cell_inputs_.size();
for (size_t i = 0; i < cell_inputs_size; ++i) {
tr.Replace(cell_inputs_[i], parameters[i]);
}
// (Inputs, sens, weights) or (Inputs, weights)
size_t weight_offset = cell_inputs_size;
if (has_sens_arg) {
weight_offset = weight_offset + 1;
}
for (size_t i = 0; i < weights.size(); ++i) {
tr.Replace(weights[i], parameters[weight_offset + i]);
}
tr.Commit();
}
} // namespace ad
} // namespace mindspore

View File

@ -47,7 +47,7 @@ FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr
KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
const std::vector<ValuePtr> &input_param_values);
// Return the back propagate funcgraph for this cell.
// Return the back propagate funcgraph for this cell, in which each cnode of the primal funcgraph is replaced by a value node.
// weights: weights parameters used in this cell.
// grad_inputs: return sensitivity for input parameters;
// grad_weights: return sensitivity for weights;
@ -61,6 +61,16 @@ KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs,
bool grad_weights, bool has_sens_arg = false);
// Return the back propagate funcgraph for this cell, in which each cnode of the primal funcgraph is replaced by a
// formal cnode, so the returned bprop funcgraph can be differentiated again.
// weights: weights parameters used in this cell.
// grad_inputs: return sensitivity for input parameters;
// grad_weights: return sensitivity for weights;
// return: the returned funcgraph will have prototype:
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...)
FuncGraphPtr GradPynativeCellBuildFormalBProp(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights,
bool grad_inputs, bool grad_weights);
// Grad for each operation.
// c_node: CNode which contains the prim (index 0) and the formal input parameters of that prim.
// op_args: the arguments list of each input parameters.

View File

@ -98,7 +98,7 @@ class TestKPynative : public UT::Common {
GradPynativeOp(k_pynative_cell, c_node, args, out);
}
}
auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, true, false);
auto bprop_fg = GradPynativeCellBuildFormalBProp(k_pynative_cell, AnfNodePtrList{}, true, false);
return bprop_fg;
}
};