forked from mindspore-Ecosystem/mindspore
add API to build bprop funcgraph which can be used for higher grad
This commit is contained in:
parent
c759effa94
commit
4ff19b7082
|
@ -211,6 +211,9 @@ class PynativeAdjoint {
|
|||
dout_ = dout_factor;
|
||||
}
|
||||
|
||||
AnfNodePtr k_node() const { return k_node_; }
|
||||
void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
|
||||
|
||||
private:
|
||||
const FuncGraphPtr tape_;
|
||||
AnfNodePtr dout_{nullptr};
|
||||
|
@ -221,6 +224,8 @@ class PynativeAdjoint {
|
|||
const ValuePtr out_;
|
||||
// bprop_fg passed from ad caller, it may be user defined back propagate funcgragh.
|
||||
const FuncGraphPtr bprop_fg_;
|
||||
// k mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
|
||||
AnfNodePtr k_node_;
|
||||
};
|
||||
using PynativeAdjointPtr = std::shared_ptr<PynativeAdjoint>;
|
||||
|
||||
|
@ -243,7 +248,10 @@ class KPynativeCellImpl : public KPynativeCell {
|
|||
bool KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
|
||||
const FuncGraphPtr &bprop_fg);
|
||||
void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) override;
|
||||
// Build a back propagate funcgraph with each cnode in primal funcgraph is replaced by value node;
|
||||
FuncGraphPtr Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights, bool has_sens_arg);
|
||||
// Build a back propagate funcgraph with formal cnode;
|
||||
FuncGraphPtr BuildFormalBProp(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights);
|
||||
|
||||
private:
|
||||
FuncGraphPtr tape_;
|
||||
|
@ -263,13 +271,22 @@ class KPynativeCellImpl : public KPynativeCell {
|
|||
void PropagateStopGradient();
|
||||
bool AllReferencesStopped(const CNodePtr &curr_cnode);
|
||||
// Back propagate for all node;
|
||||
bool BackPropagate();
|
||||
// if by_value is true, in bprop_app cnode, every input is value node;
|
||||
// if by_value is false, in bprop_app cnode, input is the k mapped node, so it can be grad again.
|
||||
bool BackPropagate(bool by_value);
|
||||
bool BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app);
|
||||
FuncGraphPtr BuildBpropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode);
|
||||
FuncGraphPtr BuildBPropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode);
|
||||
// Back propagate for MakeList or MakeTuple is generated from MetaFuncGraph.
|
||||
FuncGraphPtr BuildMakeSequenceBprop(const PrimitivePtr &prim, const CNodePtr &cnode);
|
||||
// Replace input or weights parameter from primal funcgraph to parameters of tape_;
|
||||
void ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg);
|
||||
// Set return node according to grad flag
|
||||
void SetOutput(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights);
|
||||
|
||||
// for higher order gradient;
|
||||
// Build k mapped node owned by tape_ for each cnode in primal funcgraph, so these node can be
|
||||
// used in tape_ to keep tracking the cnode dependency.
|
||||
bool BuildKNode();
|
||||
};
|
||||
using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>;
|
||||
|
||||
|
@ -323,26 +340,11 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_
|
|||
}
|
||||
|
||||
// BackPropagate sensitivity;
|
||||
BackPropagate();
|
||||
BackPropagate(true);
|
||||
// Return the gradient;
|
||||
SetOutput(weights, grad_inputs, grad_weights);
|
||||
// Replace AnfNode with parameter of tape_;
|
||||
auto mng = MakeManager({tape_}, false);
|
||||
auto tr = mng->Transact();
|
||||
const auto ¶meters = tape_->parameters();
|
||||
auto cell_inputs_size = cell_inputs_.size();
|
||||
for (size_t i = 0; i < cell_inputs_size; ++i) {
|
||||
tr.Replace(cell_inputs_[i], parameters[i]);
|
||||
}
|
||||
// (Inputs, sens, weights) or (Inputs, weights)
|
||||
size_t weight_offset = cell_inputs_size;
|
||||
if (has_sens_arg) {
|
||||
weight_offset = weight_offset + 1;
|
||||
}
|
||||
for (size_t i = 0; i < weights.size(); ++i) {
|
||||
tr.Replace(weights[i], parameters[weight_offset + i]);
|
||||
}
|
||||
tr.Commit();
|
||||
// Replace Parameter of primal funcgraph with parameter of tape_;
|
||||
ReplacePrimalParameter(weights, has_sens_arg);
|
||||
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
DumpIR("before_final_opt.ir", tape_);
|
||||
|
@ -350,6 +352,46 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_
|
|||
return tape_;
|
||||
}
|
||||
|
||||
FuncGraphPtr GradPynativeCellBuildFormalBProp(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights,
|
||||
bool grad_inputs, bool grad_weights) {
|
||||
auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
|
||||
return k_cell_impl->BuildFormalBProp(weights, grad_inputs, grad_weights);
|
||||
}
|
||||
|
||||
FuncGraphPtr KPynativeCellImpl::BuildFormalBProp(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights) {
|
||||
// propagate stop_gradient flag to cnode before back propagate;
|
||||
PropagateStopGradient();
|
||||
MS_LOG(DEBUG) << "Last node info " << last_node_->DebugString();
|
||||
auto last_node_adjoint_iter = anfnode_to_adjoin_.find(last_node_);
|
||||
if (last_node_adjoint_iter == anfnode_to_adjoin_.end()) {
|
||||
MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->ToString();
|
||||
}
|
||||
// Build forward CNode;
|
||||
BuildKNode();
|
||||
// Add weights parameter
|
||||
for (const auto &weight : weights) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(weight->debug_info()));
|
||||
auto p = tape_->add_parameter();
|
||||
auto input_w = weight->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_w);
|
||||
p->set_default_param(input_w->default_param());
|
||||
}
|
||||
auto sens_node = BuildOnesLikeValue(tape_, last_node_adjoint_iter->second->out());
|
||||
last_node_adjoint_iter->second->AccumulateDout(sens_node);
|
||||
|
||||
// BackPropagate sensitivity;
|
||||
BackPropagate(false);
|
||||
// Return the gradient;
|
||||
SetOutput(weights, grad_inputs, grad_weights);
|
||||
// Replace Parameter of primal funcgraph with parameter of tape_;
|
||||
ReplacePrimalParameter(weights, false);
|
||||
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
DumpIR("formal_bprop_before_final_opt.ir", tape_);
|
||||
}
|
||||
return tape_;
|
||||
}
|
||||
|
||||
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
|
||||
const ValuePtr &out) {
|
||||
auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
|
||||
|
@ -368,7 +410,7 @@ bool KPynativeCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &o
|
|||
|
||||
FuncGraphPtr bprop_fg = nullptr;
|
||||
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward)) {
|
||||
bprop_fg = BuildBpropCutFuncGraph(prim, cnode);
|
||||
bprop_fg = BuildBPropCutFuncGraph(prim, cnode);
|
||||
} else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
|
||||
bprop_fg = BuildMakeSequenceBprop(prim, cnode);
|
||||
} else {
|
||||
|
@ -638,7 +680,7 @@ bool KPynativeCellImpl::BackPropagate(const CNodePtr &cnode_primal, const CNodeP
|
|||
return true;
|
||||
}
|
||||
|
||||
bool KPynativeCellImpl::BackPropagate() {
|
||||
bool KPynativeCellImpl::BackPropagate(bool by_value) {
|
||||
for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) {
|
||||
if (!iter->first->isa<CNode>()) {
|
||||
continue;
|
||||
|
@ -653,18 +695,34 @@ bool KPynativeCellImpl::BackPropagate() {
|
|||
MS_EXCEPTION_IF_NULL(bprop_fg);
|
||||
|
||||
AnfNodePtrList node_list{NewValueNode(bprop_fg)};
|
||||
std::transform(iter->second->op_args().begin(), iter->second->op_args().end(), std::back_inserter(node_list),
|
||||
[](const ValuePtr &value) { return NewValueNode(value); });
|
||||
node_list.push_back(NewValueNode(iter->second->out()));
|
||||
node_list.push_back(iter->second->RealDout());
|
||||
// Update abstract info of valuenode with its value
|
||||
for (size_t i = 1; i < node_list.size() - 1; ++i) {
|
||||
auto v_node = node_list[i]->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(v_node);
|
||||
auto value = v_node->value();
|
||||
if (v_node->abstract() == nullptr && value != nullptr && value->ToAbstract() != nullptr) {
|
||||
v_node->set_abstract(value->ToAbstract()->Broaden());
|
||||
}
|
||||
if (by_value) {
|
||||
std::transform(iter->second->op_args().begin(), iter->second->op_args().end(), std::back_inserter(node_list),
|
||||
[](const ValuePtr &value) {
|
||||
auto v_node = NewValueNode(value);
|
||||
v_node->set_abstract(value->ToAbstract()->Broaden());
|
||||
return v_node;
|
||||
});
|
||||
auto out_node = NewValueNode(iter->second->out());
|
||||
out_node->set_abstract(iter->second->out()->ToAbstract()->Broaden());
|
||||
node_list.push_back(out_node);
|
||||
node_list.push_back(iter->second->RealDout());
|
||||
} else {
|
||||
std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(node_list),
|
||||
[this](const AnfNodePtr &inp) {
|
||||
if (inp->isa<CNode>()) {
|
||||
auto inp_iter = anfnode_to_adjoin_.find(inp);
|
||||
if (inp_iter == anfnode_to_adjoin_.end()) {
|
||||
MS_LOG(EXCEPTION) << "cannot find inp in adjoint map, inp: " << inp->DebugString();
|
||||
}
|
||||
return inp_iter->second->k_node();
|
||||
} else {
|
||||
return inp;
|
||||
}
|
||||
});
|
||||
// out;
|
||||
node_list.push_back(iter->second->k_node());
|
||||
// dout;
|
||||
node_list.push_back(iter->second->RealDout());
|
||||
}
|
||||
// Back propagate process
|
||||
auto bprop_fg_output_abs = bprop_fg->output()->abstract();
|
||||
|
@ -722,7 +780,7 @@ void KPynativeCellImpl::PropagateStopGradient() {
|
|||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr KPynativeCellImpl::BuildBpropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode) {
|
||||
FuncGraphPtr KPynativeCellImpl::BuildBPropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode) {
|
||||
auto inputs_num = cnode->size() - 1;
|
||||
|
||||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
|
@ -870,5 +928,49 @@ void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, bool grad_input
|
|||
}
|
||||
tape_->set_output(tape_output);
|
||||
}
|
||||
|
||||
bool KPynativeCellImpl::BuildKNode() {
|
||||
for (auto iter = anfnode_to_adjoin_.cbegin(); iter != anfnode_to_adjoin_.cend(); ++iter) {
|
||||
if (!iter->first->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = iter->first->cast<CNodePtr>();
|
||||
AnfNodePtrList node_list;
|
||||
// Update abstract info of valuenode with its value
|
||||
for (const auto &inp : cnode->inputs()) {
|
||||
if (inp->isa<CNode>()) {
|
||||
auto inp_iter = anfnode_to_adjoin_.find(inp);
|
||||
if (inp_iter == anfnode_to_adjoin_.end()) {
|
||||
MS_LOG(EXCEPTION) << "cannot find inp in adjoint map, inp: " << inp->DebugString();
|
||||
}
|
||||
node_list.push_back(inp_iter->second->k_node());
|
||||
} else {
|
||||
node_list.push_back(inp);
|
||||
}
|
||||
}
|
||||
auto k_node = tape_->NewCNode(node_list);
|
||||
iter->second->set_k_node(k_node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void KPynativeCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
|
||||
auto mng = MakeManager({tape_}, false);
|
||||
auto tr = mng->Transact();
|
||||
const auto ¶meters = tape_->parameters();
|
||||
auto cell_inputs_size = cell_inputs_.size();
|
||||
for (size_t i = 0; i < cell_inputs_size; ++i) {
|
||||
tr.Replace(cell_inputs_[i], parameters[i]);
|
||||
}
|
||||
// (Inputs, sens, weights) or (Inputs, weights)
|
||||
size_t weight_offset = cell_inputs_size;
|
||||
if (has_sens_arg) {
|
||||
weight_offset = weight_offset + 1;
|
||||
}
|
||||
for (size_t i = 0; i < weights.size(); ++i) {
|
||||
tr.Replace(weights[i], parameters[weight_offset + i]);
|
||||
}
|
||||
tr.Commit();
|
||||
}
|
||||
} // namespace ad
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,7 +47,7 @@ FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr
|
|||
KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
||||
const std::vector<ValuePtr> &input_param_values);
|
||||
|
||||
// Return the back propagate funcgraph for this cell.
|
||||
// Return the back propagate funcgraph for this cell, each cnode in primal funcgraph is replaced by value node.
|
||||
// weights: weights parameters used in this cell.
|
||||
// grad_inputs: return sensitivity for input parameters;
|
||||
// grad_weights: return sensitivity for weights;
|
||||
|
@ -61,6 +61,16 @@ KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
|||
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs,
|
||||
bool grad_weights, bool has_sens_arg = false);
|
||||
|
||||
// Return the back propagate funcgraph for this cell, each cnode in primal funcgraph is replaced by formal cnode.
|
||||
// so the return bprop funcgraph can be grad again.
|
||||
// weights: weights parameters used in this cell.
|
||||
// grad_inputs: return sensitivity for input parameters;
|
||||
// grad_weights: return sensitivity for weights;
|
||||
// return: the returned funcgraph will have prototype:
|
||||
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...)
|
||||
FuncGraphPtr GradPynativeCellBuildFormalBProp(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights,
|
||||
bool grad_inputs, bool grad_weights);
|
||||
|
||||
// Grad for each operation.
|
||||
// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.
|
||||
// op_args: the arguments list of each input parameters.
|
||||
|
|
|
@ -98,7 +98,7 @@ class TestKPynative : public UT::Common {
|
|||
GradPynativeOp(k_pynative_cell, c_node, args, out);
|
||||
}
|
||||
}
|
||||
auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, true, false);
|
||||
auto bprop_fg = GradPynativeCellBuildFormalBProp(k_pynative_cell, AnfNodePtrList{}, true, false);
|
||||
return bprop_fg;
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue