forked from mindspore-Ecosystem/mindspore
Add high order for pynative
Signed-off-by: zjun <zhangjun0@huawei.com>
This commit is contained in:
parent
c189903d04
commit
9375c56259
|
@ -278,24 +278,30 @@ AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std:
|
|||
}
|
||||
}
|
||||
|
||||
bool AutoGradCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(out);
|
||||
bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
|
||||
MS_EXCEPTION_IF_NULL(grad_param);
|
||||
|
||||
MS_LOG(DEBUG) << "Forward cnode: " << cnode->DebugString();
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
MS_LOG(DEBUG) << "Forward cnode: " << grad_param->cnode->DebugString();
|
||||
auto prim = GetCNodePrimitive(grad_param->cnode);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
|
||||
MS_LOG(EXCEPTION) << "Should be primitive, but: " << grad_param->cnode->DebugString();
|
||||
}
|
||||
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimTupleGetItem) ||
|
||||
IsPrimitiveEquals(prim, prim::kPrimStopGradient) || IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
|
||||
MS_LOG(DEBUG) << "Prim " << prim->name() << " not need do op grad";
|
||||
return true;
|
||||
}
|
||||
// anfnode_to_variable_adjoint_ hold out value, to avoid device not release, clear its device_address
|
||||
auto cloned_value = ShallowCopyTensorValue(out);
|
||||
auto cloned_value = ShallowCopyTensorValue(grad_param->out);
|
||||
ClearDeviceAddress(cloned_value);
|
||||
AnfNodePtr dout = BuildSpecialLikeValue(tape_, cloned_value, SpecialType::kZerosLikeType);
|
||||
CNodePtr input_node = ConstructBpropGraphInput(cnode, op_args, out, dout);
|
||||
auto fn = std::make_shared<FunctionNode>(tape_, dout);
|
||||
auto variable_adjoint = std::make_shared<VariableNode>(fn, cloned_value);
|
||||
if (!grad_param->grad_by_value) {
|
||||
BuildKNode(grad_param, variable_adjoint);
|
||||
need_do_manager_replace_ = true;
|
||||
}
|
||||
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout);
|
||||
MS_LOG(DEBUG) << "Construct input cnode: " << input_node->DebugString();
|
||||
std::vector<CNodePtr> outputs;
|
||||
#ifndef ENABLE_TEST
|
||||
|
@ -304,7 +310,7 @@ bool AutoGradCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op
|
|||
} else {
|
||||
mindspore::BuildBprop(input_node, &outputs, &users_);
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(DEBUG) << "The bprop output should not be empty" << cnode->DebugString();
|
||||
MS_LOG(DEBUG) << "The bprop output should not be empty" << grad_param->cnode->DebugString();
|
||||
BuildCustomBpropCNode(input_node, &outputs);
|
||||
}
|
||||
}
|
||||
|
@ -316,15 +322,12 @@ bool AutoGradCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op
|
|||
}
|
||||
#endif
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "the bprop output should not be empty" << cnode->DebugString();
|
||||
MS_LOG(EXCEPTION) << "The bprop output should not be empty" << grad_param->cnode->DebugString();
|
||||
}
|
||||
auto fn = std::make_shared<FunctionNode>(tape_, dout);
|
||||
auto variable_adjoint = std::make_shared<VariableNode>(fn, cloned_value);
|
||||
UpdateNextEdges(fn, cnode, outputs, op_args);
|
||||
anfnode_to_variable_adjoint_.insert(std::make_pair(cnode, variable_adjoint));
|
||||
|
||||
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
|
||||
anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
|
||||
// record last_node for brackpropagate
|
||||
last_node_ = cnode;
|
||||
last_node_ = grad_param->cnode;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -350,15 +353,8 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
|
|||
}
|
||||
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
|
||||
} else {
|
||||
// Set current knode for cnode
|
||||
k_node = BuildKNode(grad_param);
|
||||
const auto it = anfnode_to_variable_adjoint_.find(grad_param->cnode);
|
||||
if (it == anfnode_to_variable_adjoint_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find cnode " << grad_param->cnode->DebugString();
|
||||
}
|
||||
// Get current cnode all inputs knode
|
||||
const auto &k_node_list = BuildKNodeListFromPrimalCNode(grad_param->cnode, it->second);
|
||||
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, k_node_list, grad_param->out, &dout);
|
||||
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &args_node_list);
|
||||
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> outputs;
|
||||
|
@ -373,50 +369,10 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
|
|||
variable_adjoint->set_k_node(k_node);
|
||||
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
|
||||
anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
|
||||
has_fbprop_ = true;
|
||||
need_do_manager_replace_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
AnfNodePtr AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param) {
|
||||
AnfNodePtrList node_list;
|
||||
MS_EXCEPTION_IF_NULL(grad_param);
|
||||
for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) {
|
||||
(void)node_list.emplace_back(BuildKNodeForCNodeInput(grad_param->op_args, grad_param->cnode->input(i), i));
|
||||
}
|
||||
auto k_node = tape_->NewCNode(node_list);
|
||||
k_node->set_abstract(grad_param->out->ToAbstract()->Broaden());
|
||||
return k_node;
|
||||
}
|
||||
|
||||
AnfNodePtrList AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const VariableNodePtr &adjoint) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(adjoint);
|
||||
AnfNodePtrList node_list;
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
|
||||
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << cnode->input(i)->DebugString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(input_adjoint_iter->second->k_node());
|
||||
(void)node_list.emplace_back(input_adjoint_iter->second->k_node());
|
||||
}
|
||||
return node_list;
|
||||
}
|
||||
|
||||
AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const ValuePtrList &op_args, const AnfNodePtr &input_node,
|
||||
size_t input_index) {
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<CNode>()) {
|
||||
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(input_node);
|
||||
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
|
||||
MS_LOG(EXCEPTION) << "cannot find input in adjoint map, inp: " << input_node->DebugString();
|
||||
}
|
||||
return input_adjoint_iter->second->k_node();
|
||||
} else {
|
||||
return input_node;
|
||||
}
|
||||
}
|
||||
|
||||
CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args,
|
||||
const ValuePtr &out, AnfNodePtr *const tape_dout) {
|
||||
// Wrap tuple_getitem(fprop_app, 1) in a FuncGraph and optimize it;
|
||||
|
@ -464,7 +420,7 @@ void AutoGradCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
|
|||
}
|
||||
|
||||
FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
|
||||
const GradAttr &grad_attr, bool build_formal_param) {
|
||||
const GradAttr &grad_attr) {
|
||||
// Set sens node and weights node
|
||||
SetSensAndWeights(weights, grad_attr.has_sens);
|
||||
|
||||
|
@ -483,41 +439,84 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::
|
|||
return tape_;
|
||||
}
|
||||
|
||||
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const CNodePtr &cnode, const ValuePtrList &op_args,
|
||||
const ValuePtr &out, const AnfNodePtr &dout) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(out);
|
||||
MS_EXCEPTION_IF_NULL(dout);
|
||||
|
||||
if (cnode->size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "cnode do not have inputs";
|
||||
}
|
||||
std::vector<AnfNodePtr> node_lists;
|
||||
(void)node_lists.emplace_back(cnode->input(0));
|
||||
for (size_t i = 0; i < op_args.size(); ++i) {
|
||||
auto v = op_args[i];
|
||||
auto node = cnode->input(i + 1);
|
||||
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout) {
|
||||
MS_EXCEPTION_IF_NULL(grad_param);
|
||||
std::vector<AnfNodePtr> node_list;
|
||||
(void)node_list.emplace_back(grad_param->cnode->input(0));
|
||||
if (grad_param->grad_by_value) {
|
||||
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
|
||||
const auto &v = grad_param->op_args[i];
|
||||
auto node = grad_param->cnode->input(i + 1);
|
||||
if (node->isa<Parameter>()) {
|
||||
node_lists.emplace_back(node);
|
||||
node_list.emplace_back(node);
|
||||
node->set_abstract(v->ToAbstract());
|
||||
continue;
|
||||
}
|
||||
auto v_node = NewValueNode(op_args[i]);
|
||||
v_node->set_abstract(op_args[i]->ToAbstract());
|
||||
node_lists.emplace_back(v_node);
|
||||
auto v_node = NewValueNode(grad_param->op_args[i]);
|
||||
v_node->set_abstract(grad_param->op_args[i]->ToAbstract());
|
||||
node_list.emplace_back(v_node);
|
||||
}
|
||||
auto out_node = NewValueNode(out);
|
||||
out_node->set_abstract(out->ToAbstract());
|
||||
node_lists.emplace_back(out_node);
|
||||
node_lists.emplace_back(dout);
|
||||
CNodePtr input_node = tape_->NewCNode(node_lists);
|
||||
input_node->set_abstract(out->ToAbstract()->Broaden());
|
||||
} else {
|
||||
// Input is a Parameter or cnode, not a value node
|
||||
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &node_list);
|
||||
}
|
||||
auto out_node = NewValueNode(grad_param->out);
|
||||
auto out_abs = grad_param->out->ToAbstract()->Broaden();
|
||||
out_node->set_abstract(out_abs);
|
||||
// set out
|
||||
node_list.emplace_back(out_node);
|
||||
// set dout
|
||||
node_list.emplace_back(dout);
|
||||
auto input_node = tape_->NewCNode(node_list);
|
||||
input_node->set_abstract(out_abs);
|
||||
return input_node;
|
||||
}
|
||||
|
||||
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
|
||||
const ValuePtr &out) {
|
||||
return k_cell->KPynativeOp(cnode, op_args, out);
|
||||
void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args,
|
||||
std::vector<AnfNodePtr> *const node_list) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
MS_LOG(DEBUG) << "Find input knode of node " << cnode->input(i)->DebugString();
|
||||
if (cnode->input(i)->isa<CNode>()) {
|
||||
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
|
||||
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << cnode->input(i)->DebugString();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(input_adjoint_iter->second->k_node());
|
||||
(void)node_list->emplace_back(input_adjoint_iter->second->k_node());
|
||||
} else {
|
||||
cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract());
|
||||
(void)node_list->emplace_back(cnode->input(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode) {
|
||||
MS_EXCEPTION_IF_NULL(grad_param);
|
||||
AnfNodePtrList node_list;
|
||||
for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) {
|
||||
(void)node_list.emplace_back(BuildKNodeForCNodeInput(grad_param->cnode->input(i)));
|
||||
}
|
||||
auto k_node = tape_->NewCNode(node_list);
|
||||
k_node->set_abstract(grad_param->out->ToAbstract()->Broaden());
|
||||
VariableNode->set_k_node(k_node);
|
||||
}
|
||||
|
||||
AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) {
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<CNode>()) {
|
||||
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(input_node);
|
||||
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
|
||||
MS_LOG(EXCEPTION) << "cannot find input in adjoint map, inp: " << input_node->DebugString();
|
||||
}
|
||||
return input_adjoint_iter->second->k_node();
|
||||
} else {
|
||||
return input_node;
|
||||
}
|
||||
}
|
||||
|
||||
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const GradParamPtr &grad_param) {
|
||||
return k_cell->KPynativeOp(grad_param);
|
||||
}
|
||||
|
||||
void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode,
|
||||
|
@ -992,7 +991,7 @@ void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new
|
|||
return;
|
||||
}
|
||||
auto &old_node_users = users_[old_node];
|
||||
for (auto pair_node : old_node_users) {
|
||||
for (const auto &pair_node : old_node_users) {
|
||||
auto cnode = pair_node.first;
|
||||
size_t index = pair_node.second;
|
||||
if (index >= cnode->size()) {
|
||||
|
@ -1034,7 +1033,7 @@ void AutoGradCellImpl::ClearDeviceAddress(const ValuePtr &out) {
|
|||
void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
|
||||
const auto ¶meters = tape_->parameters();
|
||||
auto cell_inputs_size = cell_inputs_.size();
|
||||
if (has_fbprop_) {
|
||||
if (need_do_manager_replace_) {
|
||||
MS_LOG(DEBUG) << "Do parameter replace by manager";
|
||||
auto mng = MakeManager({tape_}, false);
|
||||
auto tr = mng->Transact();
|
||||
|
@ -1051,7 +1050,7 @@ void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, boo
|
|||
(void)tr.Replace(weights[i], parameters[weight_offset + i]);
|
||||
}
|
||||
tr.Commit();
|
||||
has_fbprop_ = false;
|
||||
need_do_manager_replace_ = false;
|
||||
} else {
|
||||
for (size_t i = 0; i < cell_inputs_size; ++i) {
|
||||
Replace(cell_inputs_[i], parameters[i]);
|
||||
|
@ -1093,9 +1092,8 @@ AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
|||
}
|
||||
|
||||
FuncGraphPtr GradPynativeCellEnd(const AutoGradCellImplPtr &auto_grad_cell, const AnfNodePtrList &weights,
|
||||
const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
|
||||
bool build_formal_param) {
|
||||
return auto_grad_cell->Finish(weights, grad_position, grad_attr, build_formal_param);
|
||||
const std::vector<size_t> &grad_position, const GradAttr &grad_attr) {
|
||||
return auto_grad_cell->Finish(weights, grad_position, grad_attr);
|
||||
}
|
||||
} // namespace ad
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,8 +42,9 @@ struct GradAttr {
|
|||
};
|
||||
|
||||
struct GradParam {
|
||||
GradParam(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, FuncGraphPtr fprop_fg = nullptr)
|
||||
: cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)) {}
|
||||
GradParam(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, FuncGraphPtr fprop_fg,
|
||||
bool grad_by_value)
|
||||
: cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)), grad_by_value(grad_by_value) {}
|
||||
|
||||
// Primal CNode create by op forward process
|
||||
const CNodePtr cnode;
|
||||
|
@ -111,7 +112,7 @@ class AutoGradCellImpl {
|
|||
AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values);
|
||||
~AutoGradCellImpl() = default;
|
||||
// Reverse connect bprop of op
|
||||
bool KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out);
|
||||
bool KPynativeOp(const GradParamPtr &grad_param);
|
||||
// Reverse connect ms_function or higher order sub bprop funcgraph
|
||||
bool KPynativeWithFProp(const GradParamPtr &grad_param);
|
||||
CNodePtr GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args, const ValuePtr &out,
|
||||
|
@ -121,7 +122,7 @@ class AutoGradCellImpl {
|
|||
// Build a back propagate funcgraph, each cnode in primal funcgraph is replaced by value node or formal cnode, so it
|
||||
// can be grad again.
|
||||
FuncGraphPtr Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
|
||||
const GradAttr &grad_attr, bool build_formal_param);
|
||||
const GradAttr &grad_attr);
|
||||
|
||||
private:
|
||||
// Last cnode of this Cell, may be a primitive op or cell with user defined bprop.
|
||||
|
@ -139,14 +140,13 @@ class AutoGradCellImpl {
|
|||
// Record cnode's input map for tape_
|
||||
UserType users_;
|
||||
// Flag for ms_funtcion and high order
|
||||
bool has_fbprop_{false};
|
||||
bool need_do_manager_replace_{false};
|
||||
|
||||
bool IsCNodeNeedGrad(const AnfNodePtr &node_ptr) const;
|
||||
std::vector<bool> GetNeedGradFlags(const CNodePtr &cnode);
|
||||
|
||||
// construct input as cnode for expander
|
||||
CNodePtr ConstructBpropGraphInput(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
|
||||
const AnfNodePtr &dout);
|
||||
CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout);
|
||||
// Back propagate for one node;
|
||||
void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
|
||||
const ValuePtrList &op_args);
|
||||
|
@ -182,9 +182,10 @@ class AutoGradCellImpl {
|
|||
void ClearDeviceAddress(const ValuePtr &out);
|
||||
|
||||
// Fbprop
|
||||
AnfNodePtr BuildKNode(const GradParamPtr &grad_param);
|
||||
AnfNodePtrList BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const VariableNodePtr &adjoint);
|
||||
AnfNodePtr BuildKNodeForCNodeInput(const ValuePtrList &op_args, const AnfNodePtr &input_node, size_t input_index);
|
||||
void BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode);
|
||||
void BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args,
|
||||
std::vector<AnfNodePtr> *const node_list);
|
||||
AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input_node);
|
||||
};
|
||||
using AutoGradCellImplPtr = std::shared_ptr<AutoGradCellImpl>;
|
||||
|
||||
|
@ -209,15 +210,13 @@ AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
|||
// else:
|
||||
// each cnode in primal funcgraph is replaced by value node
|
||||
FuncGraphPtr GradPynativeCellEnd(const AutoGradCellImplPtr &k_cell, const AnfNodePtrList &weights,
|
||||
const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
|
||||
bool build_formal_param = false);
|
||||
const std::vector<size_t> &grad_position, const GradAttr &grad_attr);
|
||||
|
||||
// Grad for each operation.
|
||||
// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.
|
||||
// op_args: the arguments list of each input parameters.
|
||||
// out: the op result.
|
||||
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
|
||||
const ValuePtr &out);
|
||||
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const GradParamPtr &grad_param);
|
||||
|
||||
// adjoint bprop form ms_function and high grad
|
||||
void GradPynativeFBprop(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
|
||||
|
|
|
@ -33,14 +33,17 @@
|
|||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
|
||||
static const std::map<std::string, std::vector<std::string>> op2attrs = {
|
||||
{prim::kPrimBroadcastTo->name(), {kAttrShape}},
|
||||
namespace {
|
||||
const std::map<std::string, std::vector<std::string>> op2attrs = {{prim::kPrimBroadcastTo->name(), {kAttrShape}},
|
||||
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
|
||||
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
|
||||
{prim::kPrimReduceSum->name(), {kAttrKeepDims}}};
|
||||
}
|
||||
|
||||
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto todos = TopoSort(graph->get_return());
|
||||
auto mng = Manage({graph}, false);
|
||||
for (const auto &node : todos) {
|
||||
if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
|
||||
continue;
|
||||
|
@ -66,7 +69,8 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
|
|||
AnfNodePtrList inputs = {NewValueNode(new_prim)};
|
||||
auto cnode = dyn_cast_ptr<CNode>(node);
|
||||
(void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
|
||||
cnode->set_inputs(inputs);
|
||||
auto new_cnode = graph->NewCNodeInOrder(inputs);
|
||||
(void)mng->Replace(node, new_cnode);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ namespace opt {
|
|||
* Try Expand cnode for front end graph.
|
||||
*/
|
||||
AnfNodePtr TryExpandCNodeFE(const AnfNodePtr &node);
|
||||
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_EXPANDER_H
|
||||
|
|
|
@ -33,9 +33,11 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
|
|||
// almost all ops are defined in two main paths
|
||||
const std::string ops_module = OP_PATH;
|
||||
const std::string inner_ops_module = INNER_OP_PATH;
|
||||
const std::string grad_ops_module = GRAD_OP_PATH;
|
||||
const std::string functional_op_module = FUNCTIONAL_OP_PATH;
|
||||
py::module mod = py::module::import(common::SafeCStr(ops_module));
|
||||
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
|
||||
py::module grad_mod = py::module::import(common::SafeCStr(grad_ops_module));
|
||||
py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module));
|
||||
|
||||
if (py::hasattr(inner_mod, common::SafeCStr(op_name))) {
|
||||
|
@ -44,9 +46,12 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
|
|||
if (py::hasattr(mod, common::SafeCStr(op_name))) {
|
||||
return ops_module;
|
||||
}
|
||||
if (py::hasattr(grad_mod, common::SafeCStr(op_name))) {
|
||||
return grad_ops_module;
|
||||
}
|
||||
if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) {
|
||||
MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << functional_op_module
|
||||
<< " don't have op:" << op_name;
|
||||
MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << grad_ops_module << " and "
|
||||
<< functional_op_module << " don't have op:" << op_name;
|
||||
}
|
||||
return functional_op_module;
|
||||
}
|
||||
|
|
|
@ -126,6 +126,7 @@ constexpr char REDUCE_OP_ALL[] = "prod";
|
|||
constexpr char REDUCE_OP_PROD[] = "prod";
|
||||
constexpr char OP_PATH[] = "mindspore.ops.operations";
|
||||
constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
|
||||
constexpr char GRAD_OP_PATH[] = "mindspore.ops.operations._grad_ops";
|
||||
constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional";
|
||||
constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
|
||||
constexpr char GET_OP_FUNCTION[] = "_get_python_op";
|
||||
|
|
|
@ -200,6 +200,16 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) {
|
|||
irpass.depend_value_elim_,
|
||||
});
|
||||
OptPassGroupMap map({{"ad_final_opt", bg_final_opt}});
|
||||
if (pynative::PyNativeExecutor::GetInstance()->grad_executor()->need_renormalize()) {
|
||||
(void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
|
||||
opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_};
|
||||
(void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate));
|
||||
opt::OptPassConfig environ_eliminate = opt::OptPassConfig({
|
||||
irpass.incorporate_call_,
|
||||
irpass.incorporate_call_switch_,
|
||||
});
|
||||
(void)map.emplace_back(std::make_pair("environ_eliminate", environ_eliminate));
|
||||
}
|
||||
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", resource, map);
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
auto func_graph = resource->func_graph();
|
||||
|
|
|
@ -22,10 +22,10 @@
|
|||
#include "ir/cell.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/jit/debug/trace.h"
|
||||
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
#include "include/common/utils/convert_utils_py.h"
|
||||
#include "frontend/optimizer/ad/grad.h"
|
||||
#include "frontend/optimizer/expander.h"
|
||||
#include "pipeline/jit/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -136,10 +136,159 @@ ValuePtr ConvertOutputValueToTensor(const ValuePtr &v) {
|
|||
int64_t input = v->cast<Int64ImmPtr>()->value();
|
||||
return std::make_shared<tensor::Tensor>(input, kInt64);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Output is " << v->ToString() << ", abstract " << v->ToAbstract()->Broaden();
|
||||
MS_LOG(DEBUG) << "Output is " << v->ToString() << ", abstract " << v->ToAbstract()->Broaden();
|
||||
return v;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsAbsDifferent(const AbstractBasePtr &old_abs, const AbstractBasePtr &new_abs) {
|
||||
if (old_abs == new_abs) {
|
||||
return false;
|
||||
}
|
||||
if (old_abs == nullptr || new_abs == nullptr) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, old_abs is different with new_abs";
|
||||
return true;
|
||||
}
|
||||
if (!common::IsEqual(old_abs->BuildType(), new_abs->BuildType()) ||
|
||||
!common::IsEqual(old_abs->BuildShape(), new_abs->BuildShape())) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, old_abs is different with new_abs, old abs: " << old_abs->ToString()
|
||||
<< " new abs: " << new_abs->ToString();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsValuePtrEqual(const ValuePtr &v1, const ValuePtr &v2) {
|
||||
if (v1 == v2) {
|
||||
return true;
|
||||
}
|
||||
if (v1 == nullptr || v2 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (v1->isa<tensor::Tensor>() && v2->isa<tensor::Tensor>()) {
|
||||
return v1->cast<tensor::TensorPtr>()->ValueEqual(*(v2->cast<tensor::TensorPtr>()));
|
||||
}
|
||||
return *v1 == *v2;
|
||||
}
|
||||
|
||||
bool IsParamInfoEqual(const ParamInfoPtr &p1, const ParamInfoPtr &p2) {
|
||||
if (p1 == p2) {
|
||||
return true;
|
||||
}
|
||||
if (p1 == nullptr || p2 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return p1->key() == p2->key();
|
||||
}
|
||||
|
||||
bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, const std::vector<AnfNodePtr> &new_anf_inputs,
|
||||
const TopCellInfoPtr &top_cell) {
|
||||
MS_EXCEPTION_IF_NULL(old_node_info);
|
||||
auto old_input_size = old_node_info->input_cnode_info.size() + old_node_info->input_values.size() +
|
||||
old_node_info->input_param_infos.size();
|
||||
if (old_input_size != new_anf_inputs.size() - 1) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, old input size: " << old_input_size
|
||||
<< " new input_infos: " << (new_anf_inputs.size() - 1);
|
||||
return true;
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < new_anf_inputs.size(); i++) {
|
||||
const auto &new_anf_input = new_anf_inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(new_anf_input);
|
||||
if (new_anf_input->isa<ValueNode>()) {
|
||||
const auto &value_iter = old_node_info->input_values.find(i);
|
||||
if (value_iter == old_node_info->input_values.end()) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a value, old input is not a value.";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!IsValuePtrEqual(value_iter->second, GetValueNode(new_anf_input))) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, value is different.";
|
||||
return true;
|
||||
}
|
||||
} else if (new_anf_input->isa<CNode>()) {
|
||||
// Compare cnode abstract.
|
||||
const auto &node_iter = old_node_info->input_cnode_info.find(i);
|
||||
if (node_iter == old_node_info->input_cnode_info.end()) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a cnode, old input is not a cnode.";
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t old_op_index = 0;
|
||||
AbstractBasePtr old_abs = nullptr;
|
||||
std::tie(old_op_index, old_abs) = node_iter->second;
|
||||
if (IsAbsDifferent(old_abs, new_anf_input->abstract())) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, abs is different.";
|
||||
return true;
|
||||
}
|
||||
|
||||
// Compare cnode edge.
|
||||
MS_EXCEPTION_IF_NULL(top_cell);
|
||||
if (old_op_index != top_cell->get_op_index_by_cnode_hash(new_anf_input->hash())) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, op_index is different, old op_index: " << old_op_index
|
||||
<< " new op_index: " << top_cell->get_op_index_by_cnode_hash(new_anf_input->hash());
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
// Compare parameter.
|
||||
if (!new_anf_input->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "new_anf_input: " << new_anf_input->fullname_with_scope()
|
||||
<< " is none of value node, cnode and parameter.";
|
||||
}
|
||||
|
||||
const auto &node_iter = old_node_info->input_param_infos.find(i);
|
||||
if (node_iter == old_node_info->input_param_infos.end()) {
|
||||
MS_LOG(DEBUG) << "The " << i
|
||||
<< "th input is different, cur input is a parameter, old input is not a parameter.";
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto ¶m = new_anf_input->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
if (!IsParamInfoEqual(node_iter->second, param->param_info())) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, param info is different.";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode,
|
||||
bool is_ms_function_node, const std::string &graph_phase,
|
||||
const TopCellInfoPtr &top_cell) {
|
||||
MS_EXCEPTION_IF_NULL(old_node_info);
|
||||
// 1.Detect ms_function phase
|
||||
if (is_ms_function_node) {
|
||||
if (!old_node_info->is_graph_node || graph_phase != old_node_info->graph_phase) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, old is_graph_node: " << old_node_info->is_graph_node
|
||||
<< " new is_graph_node: " << is_ms_function_node << " old graph_phase "
|
||||
<< old_node_info->graph_phase << " new graph_phase: " << graph_phase;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2.Detect cnode prim
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
auto new_prim = GetCNodePrimitive(new_cnode);
|
||||
if (!common::IsEqual(new_prim, old_node_info->prim)) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, old prim: "
|
||||
<< (old_node_info->prim == nullptr ? "nullptr" : old_node_info->prim->name())
|
||||
<< " new prim: " << (new_prim == nullptr ? "nullptr" : new_prim->name());
|
||||
return true;
|
||||
}
|
||||
|
||||
// 3.Detect output abs
|
||||
if (IsAbsDifferent(old_node_info->output_abs, new_cnode->abstract())) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, output_abs is different";
|
||||
return true;
|
||||
}
|
||||
|
||||
// 4.Detect inputs
|
||||
return IsCnodeInputsDynamic(old_node_info, new_cnode->inputs(), top_cell);
|
||||
}
|
||||
|
||||
FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph) {
|
||||
auto resource = std::make_shared<pipeline::Resource>();
|
||||
resource->set_func_graph(bprop_graph);
|
||||
|
@ -243,10 +392,10 @@ void GradExecutor::HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_
|
|||
MS_EXCEPTION_IF_NULL(input_args_info);
|
||||
if (is_bprop_top) {
|
||||
// Convert input args to parameters for top cell graph in bprop.
|
||||
for (size_t i = 0; i < input_args_info->input_arg_id_vec.size(); ++i) {
|
||||
for (auto &id : input_args_info->input_arg_id_vec) {
|
||||
auto new_param = curr_g()->add_parameter();
|
||||
MS_LOG(DEBUG) << "Top bprop graph set input parameter " << input_args_info->input_arg_id_vec[i];
|
||||
top_cell()->SetParamNodeMapInGraphInfoMap(input_args_info->input_arg_id_vec[i], new_param);
|
||||
MS_LOG(DEBUG) << "Top bprop graph set input parameter " << id;
|
||||
top_cell()->SetParamNodeMapInGraphInfoMap(id, new_param);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -369,8 +518,9 @@ void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
|
|||
fg->debug_info()->set_name("pynative_forward_graph");
|
||||
auto resource = std::make_shared<pipeline::Resource>();
|
||||
const auto &already_run_cell_id = GetAlreadyRunCellId(input_args_info->cell_id);
|
||||
top_cell_ = std::make_shared<TopCellInfo>(input_args_info->grad_order, input_args_info->obj_id,
|
||||
input_args_info->cell_id, already_run_cell_id, resource, fg);
|
||||
top_cell_ =
|
||||
std::make_shared<TopCellInfo>(input_args_info->is_high_order_top_cell, input_args_info->grad_order,
|
||||
input_args_info->obj_id, input_args_info->cell_id, already_run_cell_id, resource, fg);
|
||||
top_cell_->set_forward_already_run(true);
|
||||
top_cell_->set_is_run_cell(input_args_info->is_run_cell);
|
||||
top_cell_->set_input_args_id(input_args_info->input_args_id);
|
||||
|
@ -402,8 +552,7 @@ void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &
|
|||
// Set last output abstract and will be used for sens
|
||||
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
|
||||
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
|
||||
auto sens_v = ConvertOutputValueToTensor(v);
|
||||
auto cloned_value = ShallowCopyTensorValue(sens_v);
|
||||
auto cloned_value = ShallowCopyTensorValue(v);
|
||||
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
|
||||
AsyncUpdateOutputNodeOfTopCell(output_node, cloned_value);
|
||||
} else {
|
||||
|
@ -435,9 +584,11 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
|||
const auto &cell_id = input_args_info->cell_id;
|
||||
MS_LOG(DEBUG) << "EndGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
|
||||
<< ", input args info ptr " << input_args_info.get();
|
||||
const auto &out_value = input_args_info->out_value;
|
||||
MS_EXCEPTION_IF_NULL(out_value);
|
||||
const auto &out_id = PyNativeAlgo::Common::GetIdByValue(out_value);
|
||||
bool is_top_cell_end = (cell_id == top_cell()->cell_id());
|
||||
if (is_top_cell_end) {
|
||||
input_args_info->out_value = ConvertOutputValueToTensor(input_args_info->out_value);
|
||||
}
|
||||
const auto &out_id = PyNativeAlgo::Common::GetIdByValue(input_args_info->out_value);
|
||||
DoGradForCustomBprop(input_args_info, out_id);
|
||||
// Update bprop grad stack
|
||||
if (grad_is_running_ && !bprop_grad_stack_.empty()) {
|
||||
|
@ -450,16 +601,15 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
|||
}
|
||||
}
|
||||
// Just only dump the last forward graph
|
||||
bool is_top_cell_end = cell_id == top_cell()->cell_id();
|
||||
if (is_top_cell_end && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
curr_g()->set_output(GetInput(out_value, out_id));
|
||||
curr_g()->set_output(GetInput(input_args_info->out_value, out_id));
|
||||
PyNativeAlgo::Common::DumpGraphIR("fg.ir", curr_g());
|
||||
}
|
||||
// Reset grad flag and update output node of the outermost cell
|
||||
if (input_args_info->is_grad_topest_cell && is_top_cell_end) {
|
||||
MS_LOG(DEBUG) << "Cur top last cell " << cell_id;
|
||||
(void)PopHighOrderGraphStack();
|
||||
SetForwardLastNodeInfo(out_value, out_id);
|
||||
SetForwardLastNodeInfo(input_args_info->out_value, out_id);
|
||||
top_cell()->ClearCellHookOp();
|
||||
cell_order_ = 0;
|
||||
// set_grad_flag(false);
|
||||
|
@ -468,7 +618,7 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
|||
if (is_top_cell_end) {
|
||||
// In high grad cases, the output of the internal graph may be a tuple, and node needs to be created in the getobj
|
||||
if (!input_args_info->is_grad_topest_cell) {
|
||||
SetForwardLastNodeInfo(out_value, out_id);
|
||||
SetForwardLastNodeInfo(input_args_info->out_value, out_id);
|
||||
}
|
||||
top_cell()->CheckSubCellHookChanged();
|
||||
CheckNeedCompileGraph(input_args_info);
|
||||
|
@ -476,7 +626,7 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
|||
}
|
||||
}
|
||||
|
||||
void GradExecutor::AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info) {
|
||||
void GradExecutor::AsyncEndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
||||
const auto fn = [this, input_args_info]() { this->EndGraphImpl(input_args_info); };
|
||||
auto task = std::make_shared<BpropTask>(fn);
|
||||
async_executor_->Push(task);
|
||||
|
@ -618,13 +768,12 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
|
|||
py::gil_scoped_release gil_release;
|
||||
async_executor_->Wait();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(grad);
|
||||
MS_EXCEPTION_IF_NULL(top_input_args_info_);
|
||||
MS_LOG(DEBUG) << "GradNetInner start " << args.size() << ", cell_id " << top_input_args_info_->cell_id
|
||||
<< ", input args info ptr " << top_input_args_info_.get();
|
||||
MS_EXCEPTION_IF_NULL(grad);
|
||||
if (grad->sens_param()) {
|
||||
MS_LOG(DEBUG) << "Get sens param";
|
||||
top_input_args_info_->has_sens = true;
|
||||
size_t forward_args_size = args.size() - 1;
|
||||
auto sens_v = PyNativeAlgo::DataConvert::PyObjToValue(args[forward_args_size]);
|
||||
const auto &sens_tensor = ConvertOutputValueToTensor(sens_v);
|
||||
|
@ -632,17 +781,16 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
|
|||
if (top_input_args_info_->input_arg_value_vec.size() == args.size()) {
|
||||
top_input_args_info_->input_arg_value_vec.pop_back();
|
||||
}
|
||||
(void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_v));
|
||||
(void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_tensor));
|
||||
top_input_args_info_->has_sens = true;
|
||||
}
|
||||
// For async, top can not be change when run SetForwardLastNodeInfo
|
||||
if (pre_top_cell_ != nullptr) {
|
||||
set_top_cell(pre_top_cell_);
|
||||
}
|
||||
|
||||
if (!top_cell()->need_compile_graph()) {
|
||||
MS_LOG(DEBUG) << "No need compile graph";
|
||||
top_cell_list_.pop_back();
|
||||
|
||||
UpdateTopCellInfo(false, false);
|
||||
return;
|
||||
}
|
||||
|
@ -671,7 +819,7 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector
|
|||
auto bprop_graph = GetBpropGraph(grad_attr, w_args, p_args);
|
||||
MS_EXCEPTION_IF_NULL(bprop_graph);
|
||||
bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
|
||||
bool use_dynamic_shape_process = (forward()->device_target() == kAscendDevice ? false : use_dynamic_shape_process_);
|
||||
bool use_dynamic_shape_process = !(forward()->device_target() == kAscendDevice) && use_dynamic_shape_process_;
|
||||
bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
|
||||
MS_EXCEPTION_IF_NULL(top_input_args_info_);
|
||||
bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
|
||||
|
@ -776,12 +924,14 @@ void GradExecutor::UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args,
|
|||
MS_EXCEPTION_IF_NULL(bprop_graph);
|
||||
std::vector<ValuePtr> tensor_args;
|
||||
size_t input_size = has_sens ? input_args.size() - 1 : input_args.size();
|
||||
// Sens may be a value tuple not a single tensor
|
||||
// Sens may be a value tuple not a single tensor; bprop gradph have only one ses params, so tuple sens can not be
|
||||
// flatten
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
if (PyNativeAlgo::Common::IsTensor(input_args[i])) {
|
||||
(void)tensor_args.emplace_back(input_args[i]);
|
||||
}
|
||||
}
|
||||
// No flatten
|
||||
if (has_sens) {
|
||||
(void)tensor_args.emplace_back(input_args[input_size]);
|
||||
}
|
||||
|
@ -819,15 +969,9 @@ void GradExecutor::UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args,
|
|||
FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const vector<AnfNodePtr> &w_args,
|
||||
const vector<size_t> &p_args) {
|
||||
MS_EXCEPTION_IF_NULL(top_input_args_info_);
|
||||
bool build_formal_param = false;
|
||||
if (!top_input_args_info_->has_custom_bprop && !top_input_args_info_->is_grad_topest_cell && IsNestedGrad()) {
|
||||
build_formal_param = true;
|
||||
need_renormalize_ = true;
|
||||
}
|
||||
|
||||
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
|
||||
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
|
||||
FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(auto_grad_cell_ptr, w_args, p_args, grad_attr, build_formal_param);
|
||||
FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(auto_grad_cell_ptr, w_args, p_args, grad_attr);
|
||||
MS_EXCEPTION_IF_NULL(bprop_graph);
|
||||
|
||||
MS_LOG(DEBUG) << "Top graph input params size " << top_input_args_info_->input_arg_value_vec.size();
|
||||
|
@ -836,7 +980,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
|
|||
bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
bprop_graph->debug_info()->set_name(ss.str());
|
||||
UpdateParamAbsByArgs(top_input_args_info_->input_arg_value_vec, bprop_graph, grad_attr.has_sens);
|
||||
if (top_cell()->ms_function_flag()) {
|
||||
if (top_cell()->need_do_final_opt()) {
|
||||
bprop_graph = BpropGraphFinalOpt(bprop_graph);
|
||||
}
|
||||
if (top_input_args_info_->is_grad_topest_cell) {
|
||||
|
@ -937,14 +1081,15 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
|
|||
MS_LOG(DEBUG) << "Bprop nested";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(first_grad_fg);
|
||||
PyNativeAlgo::Common::DumpGraphIR("first_grad_fg.ir", first_grad_fg);
|
||||
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
|
||||
ValuePtrList weights_args;
|
||||
DoParameterReplace(first_grad_fg, forward_args, &inputs, &weights_args);
|
||||
|
||||
pipeline::ResourcePtr r = std::make_shared<pipeline::Resource>();
|
||||
r->manager()->AddFuncGraph(first_grad_fg);
|
||||
if (!opt::ConvertPrimToPrimPy(first_grad_fg)) {
|
||||
MS_LOG(EXCEPTION) << "Convert PrimitiveC to PrimitivePy failed";
|
||||
}
|
||||
|
||||
auto r = std::make_shared<pipeline::Resource>();
|
||||
set_eliminate_forward(false);
|
||||
(void)first_grad_fg->transforms().erase(kGrad);
|
||||
// Do high order
|
||||
|
@ -968,10 +1113,12 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
|
|||
std::vector<ValuePtr> out_v{out_value};
|
||||
out_value = std::make_shared<ValueTuple>(out_v);
|
||||
}
|
||||
auto grad_param = std::make_shared<ad::GradParam>(cnode, input_args, out_value, second_grad_fg);
|
||||
auto grad_param = std::make_shared<ad::GradParam>(cnode, input_args, out_value, second_grad_fg,
|
||||
!top_cell()->is_high_order_top_cell());
|
||||
if (!top_cell()->auto_grad_cell_ptr()->KPynativeWithFProp(grad_param)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph " << cnode->ToString();
|
||||
}
|
||||
top_cell()->set_need_do_final_opt(true);
|
||||
need_renormalize_ = true;
|
||||
}
|
||||
|
||||
|
@ -986,8 +1133,8 @@ void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const s
|
|||
|
||||
// Replace inputs param
|
||||
MS_EXCEPTION_IF_NULL(inputs);
|
||||
for (size_t i = 0; i < forward_args.size(); ++i) {
|
||||
const auto &id = PyNativeAlgo::Common::GetIdByValue(forward_args[i]);
|
||||
for (const auto &forward_arg : forward_args) {
|
||||
const auto &id = PyNativeAlgo::Common::GetIdByValue(forward_arg);
|
||||
const auto it = outer_graph_info->input_params.find(id);
|
||||
if (it != outer_graph_info->input_params.end()) {
|
||||
// Can find in outer graph
|
||||
|
@ -997,7 +1144,7 @@ void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const s
|
|||
} else {
|
||||
MS_LOG(DEBUG) << "Can't find input param id " << id;
|
||||
// Inner graph input param not find in outer graph, need add to outer graph
|
||||
(void)inputs->emplace_back(GetInput(forward_args[i], id));
|
||||
(void)inputs->emplace_back(GetInput(forward_arg, id));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1055,12 +1202,10 @@ void GradExecutor::ClearGradRes() {
|
|||
if (top_cell_ != nullptr) {
|
||||
top_cell_->ClearDeviceMemory();
|
||||
}
|
||||
|
||||
if (use_dynamic_shape_process_ ||
|
||||
already_run_top_cell_.find(top_cell_->already_run_cell_id()) != already_run_top_cell_.end()) {
|
||||
top_cell_ = nullptr;
|
||||
}
|
||||
|
||||
DecreaseGradOrder();
|
||||
ClearGlobalRes();
|
||||
}
|
||||
|
@ -1320,29 +1465,26 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode
|
|||
std::back_inserter(cloned_op_args),
|
||||
[](const ValuePtr &value) { return ShallowCopyTensorValue(value); });
|
||||
ValuePtr cloned_out = ShallowCopyTensorValue(op_out);
|
||||
if (!ad::GradPynativeOp(top_cell()->auto_grad_cell_ptr(), cnode, cloned_op_args, cloned_out)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_run_info->base_op_run_info.op_name;
|
||||
}
|
||||
auto grad_param =
|
||||
std::make_shared<ad::GradParam>(cnode, cloned_op_args, cloned_out, nullptr, !top_cell()->is_high_order_top_cell());
|
||||
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
|
||||
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
|
||||
AsyncGradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out);
|
||||
AsyncGradPynativeOp(auto_grad_cell_ptr, grad_param);
|
||||
} else {
|
||||
GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out);
|
||||
GradPynativeOp(auto_grad_cell_ptr, grad_param);
|
||||
}
|
||||
}
|
||||
|
||||
void GradExecutor::GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
|
||||
const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const {
|
||||
if (!ad::GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out)) {
|
||||
void GradExecutor::GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr,
|
||||
const ad::GradParamPtr &grad_param) const {
|
||||
if (!ad::GradPynativeOp(auto_grad_cell_ptr, grad_param)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to run ad grad for op ";
|
||||
}
|
||||
}
|
||||
|
||||
void GradExecutor::AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
|
||||
const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const {
|
||||
const auto fn = [this, auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out]() {
|
||||
this->GradPynativeOp(auto_grad_cell_ptr, cnode, cloned_op_args, cloned_out);
|
||||
};
|
||||
void GradExecutor::AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr,
|
||||
const ad::GradParamPtr &grad_param) const {
|
||||
const auto fn = [this, auto_grad_cell_ptr, grad_param]() { this->GradPynativeOp(auto_grad_cell_ptr, grad_param); };
|
||||
auto task = std::make_shared<BpropTask>(fn);
|
||||
async_executor_->Push(task);
|
||||
}
|
||||
|
@ -1608,7 +1750,7 @@ void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, c
|
|||
node_info->input_cnode_info[i] = std::make_pair(op_index, node_abs);
|
||||
} else {
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "input_node:" << input_node->fullname_with_scope()
|
||||
MS_LOG(EXCEPTION) << "input_node: " << input_node->fullname_with_scope()
|
||||
<< " is none of value node, cnode and parameter.";
|
||||
}
|
||||
const auto ¶m = input_node->cast<ParameterPtr>();
|
||||
|
@ -1626,157 +1768,6 @@ void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, c
|
|||
(void)cell_id_with_dynamic_detect_nodes_[cell_id].emplace_back(node_info);
|
||||
}
|
||||
|
||||
bool IsAbsDifferent(const AbstractBasePtr &old_abs, const AbstractBasePtr &new_abs) {
|
||||
if (old_abs == new_abs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (old_abs == nullptr || new_abs == nullptr) {
|
||||
MS_LOG(DEBUG) << "graph is dynamic, old_abs is different with new_abs";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!common::IsEqual(old_abs->BuildType(), new_abs->BuildType()) ||
|
||||
!common::IsEqual(old_abs->BuildShape(), new_abs->BuildShape())) {
|
||||
MS_LOG(DEBUG) << "graph is dynamic, old_abs is different with new_abs, old abs:" << old_abs->ToString()
|
||||
<< " new abs:" << new_abs->ToString();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsValuePtrEqual(const ValuePtr &v1, const ValuePtr &v2) {
|
||||
if (v1 == v2) {
|
||||
return true;
|
||||
}
|
||||
if (v1 == nullptr || v2 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (v1->isa<tensor::Tensor>() && v2->isa<tensor::Tensor>()) {
|
||||
return v1->cast<tensor::TensorPtr>()->ValueEqual(*(v2->cast<tensor::TensorPtr>()));
|
||||
}
|
||||
return *v1 == *v2;
|
||||
}
|
||||
|
||||
bool IsParamInfoEqual(const ParamInfoPtr &p1, const ParamInfoPtr &p2) {
|
||||
if (p1 == p2) {
|
||||
return true;
|
||||
}
|
||||
if (p1 == nullptr || p2 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return p1->key() == p2->key();
|
||||
}
|
||||
|
||||
bool GradExecutor::IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info,
|
||||
const std::vector<AnfNodePtr> &new_anf_inputs) const {
|
||||
MS_EXCEPTION_IF_NULL(old_node_info);
|
||||
|
||||
auto old_input_size = old_node_info->input_cnode_info.size() + old_node_info->input_values.size() +
|
||||
old_node_info->input_param_infos.size();
|
||||
if (old_input_size != new_anf_inputs.size() - 1) {
|
||||
MS_LOG(DEBUG) << "graph is dynamic, old input size:" << old_input_size
|
||||
<< " new input_infos:" << (new_anf_inputs.size() - 1);
|
||||
return true;
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < new_anf_inputs.size(); i++) {
|
||||
const auto &new_anf_input = new_anf_inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(new_anf_input);
|
||||
if (new_anf_input->isa<ValueNode>()) {
|
||||
const auto &value_iter = old_node_info->input_values.find(i);
|
||||
if (value_iter == old_node_info->input_values.end()) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a value, old input is not a value.";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!IsValuePtrEqual(value_iter->second, GetValueNode(new_anf_input))) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, value is different.";
|
||||
return true;
|
||||
}
|
||||
} else if (new_anf_input->isa<CNode>()) {
|
||||
// Compare cnode abstract.
|
||||
const auto &node_iter = old_node_info->input_cnode_info.find(i);
|
||||
if (node_iter == old_node_info->input_cnode_info.end()) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input is different, cur input is a cnode, old input is not a cnode.";
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t old_op_index = 0;
|
||||
AbstractBasePtr old_abs = nullptr;
|
||||
std::tie(old_op_index, old_abs) = node_iter->second;
|
||||
if (IsAbsDifferent(old_abs, new_anf_input->abstract())) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, abs is different.";
|
||||
return true;
|
||||
}
|
||||
|
||||
// Compare cnode edge.
|
||||
if (old_op_index != top_cell()->get_op_index_by_cnode_hash(new_anf_input->hash())) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, op_index is different, old op_index:" << old_op_index
|
||||
<< " new op_index:" << top_cell()->get_op_index_by_cnode_hash(new_anf_input->hash());
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
// Compare parameter.
|
||||
if (!new_anf_input->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "new_anf_input:" << new_anf_input->fullname_with_scope()
|
||||
<< " is none of value node, cnode and parameter.";
|
||||
}
|
||||
|
||||
const auto &node_iter = old_node_info->input_param_infos.find(i);
|
||||
if (node_iter == old_node_info->input_param_infos.end()) {
|
||||
MS_LOG(DEBUG) << "The " << i
|
||||
<< "th input is different, cur input is a parameter, old input is not a parameter.";
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto ¶m = new_anf_input->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
if (!IsParamInfoEqual(node_iter->second, param->param_info())) {
|
||||
MS_LOG(DEBUG) << "The " << i << "th input, param info is different.";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GradExecutor::IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info,
|
||||
const CNodePtr &new_cnode, bool is_ms_function_node,
|
||||
const std::string &graph_phase) const {
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
MS_EXCEPTION_IF_NULL(old_node_info);
|
||||
|
||||
// 1.Detect ms_function phase
|
||||
if (is_ms_function_node != old_node_info->is_graph_node ||
|
||||
(is_ms_function_node && graph_phase != old_node_info->graph_phase)) {
|
||||
MS_LOG(DEBUG) << "graph is dynamic, old is_graph_node:" << old_node_info->is_graph_node
|
||||
<< " new is_graph_node:" << is_ms_function_node << " old graph_phase" << old_node_info->graph_phase
|
||||
<< " new graph_phase:" << graph_phase;
|
||||
return true;
|
||||
}
|
||||
|
||||
// 2.Detect cnode prim
|
||||
auto new_prim = GetCNodePrimitive(new_cnode);
|
||||
if (!common::IsEqual(new_prim, old_node_info->prim)) {
|
||||
MS_LOG(DEBUG) << "graph is dynamic, old prim:" << (old_node_info->prim == nullptr ? 0 : old_node_info->prim->name())
|
||||
<< " new prim:" << (new_prim == nullptr ? 0 : new_prim->name());
|
||||
return true;
|
||||
}
|
||||
|
||||
// 3.Detect output abs
|
||||
if (IsAbsDifferent(old_node_info->output_abs, new_cnode->abstract())) {
|
||||
MS_LOG(DEBUG) << "graph is dynamic, output_abs is different";
|
||||
return true;
|
||||
}
|
||||
|
||||
// 4.Detect inputs
|
||||
return IsCnodeInputsDynamic(old_node_info, new_cnode->inputs());
|
||||
}
|
||||
|
||||
bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
|
||||
const std::string &graph_phase) const {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -1789,18 +1780,17 @@ bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx,
|
|||
const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order());
|
||||
const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[cell_id];
|
||||
if (node_idx >= dynamic_nodes.size()) {
|
||||
MS_LOG(DEBUG) << "old dynamic_nodes size:" << dynamic_nodes.size() << " cur node_idx is:" << node_idx
|
||||
MS_LOG(DEBUG) << "Old dynamic_nodes size: " << dynamic_nodes.size() << " cur node_idx is: " << node_idx
|
||||
<< ", graph is dynamic.";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (IsDynamicDetectNodeInfoChange(dynamic_nodes[node_idx], cnode, is_ms_function_node, graph_phase)) {
|
||||
MS_LOG(DEBUG) << "graph is dynamic, node_idx:" << node_idx
|
||||
<< " is different, cnode:" << cnode->fullname_with_scope();
|
||||
if (IsDynamicDetectNodeInfoChange(dynamic_nodes[node_idx], cnode, is_ms_function_node, graph_phase, top_cell())) {
|
||||
MS_LOG(DEBUG) << "Graph is dynamic, node_idx: " << node_idx
|
||||
<< " is different, cnode: " << cnode->fullname_with_scope();
|
||||
return true;
|
||||
}
|
||||
top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -1812,9 +1802,9 @@ void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_i
|
|||
|
||||
use_dynamic_shape_process_ = IsGraphDynamic(cnode, node_idx, is_ms_function_node, graph_phase);
|
||||
if (use_dynamic_shape_process_) {
|
||||
MS_LOG(DEBUG) << "cnode:" << cnode->fullname_with_scope() << ",node_idx:" << node_idx
|
||||
<< ",is_ms_function_node:" << is_ms_function_node << ",graph_phase:" << graph_phase
|
||||
<< ",use_dynamic_shape_process_:" << use_dynamic_shape_process_;
|
||||
MS_LOG(DEBUG) << "Cnode: " << cnode->fullname_with_scope() << ", node_idx: " << node_idx
|
||||
<< ", is_ms_function_node: " << is_ms_function_node << ", graph_phase:" << graph_phase
|
||||
<< ", use_dynamic_shape_process_: " << use_dynamic_shape_process_;
|
||||
cell_id_with_dynamic_detect_nodes_.clear();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -103,17 +103,13 @@ class GradExecutor {
|
|||
TopCellInfoPtr GetTopCell(const std::string &already_run_cell_id);
|
||||
void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
|
||||
void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
|
||||
void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
|
||||
void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
|
||||
AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
|
||||
void AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info);
|
||||
AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
|
||||
void UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const;
|
||||
void UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
|
||||
const std::vector<tensor::TensorPtr> &pre_tensors) const;
|
||||
void ClearRes();
|
||||
void WorkerJoin() { async_executor_->WorkerJoin(); }
|
||||
|
||||
void CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node = false,
|
||||
const std::string &graph_phase = "") const;
|
||||
|
||||
|
@ -126,10 +122,8 @@ class GradExecutor {
|
|||
void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
|
||||
const CNodePtr &cnode) const;
|
||||
void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode, const ValuePtr &op_out) const;
|
||||
void GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
|
||||
const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const;
|
||||
void AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const CNodePtr &cnode,
|
||||
const ValuePtrList &cloned_op_args, const ValuePtr &cloned_out) const;
|
||||
void GradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const ad::GradParamPtr &grad_param) const;
|
||||
void AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr, const ad::GradParamPtr &grad_param) const;
|
||||
void AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &cloned_value) const;
|
||||
AnfNodePtr GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const;
|
||||
void SetBpropGraphJitLevel(const py::object &obj) const;
|
||||
|
@ -168,6 +162,9 @@ class GradExecutor {
|
|||
InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args);
|
||||
void NewGraphImpl(const InputArgsInfoPtr &input_args_info);
|
||||
void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info);
|
||||
void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
|
||||
void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
|
||||
void AsyncEndGraphImpl(const InputArgsInfoPtr &input_args_info);
|
||||
void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const;
|
||||
void GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
|
||||
const InputArgsInfoPtr &input_args_info);
|
||||
|
@ -195,10 +192,7 @@ class GradExecutor {
|
|||
const std::string &graph_phase) const;
|
||||
bool IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
|
||||
const std::string &graph_phase) const;
|
||||
bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info,
|
||||
const std::vector<AnfNodePtr> &new_anf_inputs) const;
|
||||
bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode,
|
||||
bool is_ms_function_node, const std::string &graph_phase) const;
|
||||
|
||||
bool grad_flag_{false};
|
||||
bool grad_is_running_{false};
|
||||
bool need_renormalize_{false};
|
||||
|
@ -209,7 +203,7 @@ class GradExecutor {
|
|||
|
||||
// Used in sub thread
|
||||
size_t cell_order_{0};
|
||||
std::string cur_cell_id_{""};
|
||||
std::string cur_cell_id_;
|
||||
|
||||
// If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ...
|
||||
size_t grad_order_{0};
|
||||
|
|
|
@ -272,9 +272,8 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run
|
|||
|
||||
// Connect grad graph of ms_function to context.
|
||||
auto auto_grad_cell_ptr = top_cell->auto_grad_cell_ptr();
|
||||
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
|
||||
auto grad_param =
|
||||
std::make_shared<ad::GradParam>(ms_function_cnode, op_run_info->input_value, op_run_info->out_value, grad_graph);
|
||||
auto grad_param = std::make_shared<ad::GradParam>(ms_function_cnode, op_run_info->input_value, op_run_info->out_value,
|
||||
grad_graph, !top_cell->is_high_order_top_cell());
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
grad_executor->async_executor()->Wait();
|
||||
|
@ -283,7 +282,7 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run
|
|||
MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
|
||||
<< ms_function_cnode->DebugString();
|
||||
}
|
||||
top_cell->set_ms_function_flag(true);
|
||||
top_cell->set_need_do_final_opt(true);
|
||||
return ms_function_cnode;
|
||||
}
|
||||
|
||||
|
@ -291,8 +290,7 @@ void MsFunction::AsyncKPynativeWithFProp(const GradExecutor *grad_executor,
|
|||
const ad::AutoGradCellImplPtr &auto_grad_cell_ptr,
|
||||
const ad::GradParamPtr &grad_param) const {
|
||||
MS_EXCEPTION_IF_NULL(grad_executor);
|
||||
|
||||
const auto fn = [this, grad_param, auto_grad_cell_ptr]() {
|
||||
const auto fn = [grad_param, auto_grad_cell_ptr]() {
|
||||
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
|
||||
if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode";
|
||||
|
|
|
@ -68,8 +68,7 @@ void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info) {
|
|||
// else:
|
||||
// x = x + self.p
|
||||
// return x
|
||||
for (size_t i = 0; i < op_run_info->base_op_run_info.input_tensor.size(); i++) {
|
||||
const auto &t = op_run_info->base_op_run_info.input_tensor[i];
|
||||
for (auto &t : op_run_info->base_op_run_info.input_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(t);
|
||||
if (t->is_parameter() && t->param_info() != nullptr && t->param_info()->requires_grad()) {
|
||||
input_args_info += "w";
|
||||
|
@ -127,6 +126,22 @@ void TopCellInfo::ClearDeviceMemory() const {
|
|||
}
|
||||
}
|
||||
|
||||
void TopCellInfo::Clear() {
|
||||
MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
|
||||
hook_changed_ = false;
|
||||
is_init_kpynative_ = false;
|
||||
need_compile_graph_ = false;
|
||||
forward_already_run_ = false;
|
||||
op_index_ = 0;
|
||||
resource_ = nullptr;
|
||||
fg_ = nullptr;
|
||||
graph_info_map_.clear();
|
||||
op_info_with_tensor_id_.clear();
|
||||
tensor_id_with_tensor_object_.clear();
|
||||
op_info_with_ms_func_forward_tensors_.clear();
|
||||
cnode_hash_with_op_index_.clear();
|
||||
}
|
||||
|
||||
void TopCellInfo::DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id) {
|
||||
auto &graph_info = graph_info_map().at(g);
|
||||
MS_EXCEPTION_IF_NULL(graph_info);
|
||||
|
@ -145,12 +160,12 @@ void TopCellInfo::SetParamNodeMapInGraphInfoMap(const std::string &id, const Par
|
|||
}
|
||||
|
||||
void TopCellInfo::SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index,
|
||||
bool save_flag) const {
|
||||
bool need_save_sub_id) const {
|
||||
auto &graph_info = graph_info_map().at(fg());
|
||||
MS_EXCEPTION_IF_NULL(graph_info);
|
||||
graph_info->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
|
||||
// For example, set id of ((A,B),C) = {CNode, -1}
|
||||
if (save_flag) {
|
||||
if (need_save_sub_id) {
|
||||
SetMultipleOutputToGraphInfoMap(id, node);
|
||||
}
|
||||
}
|
||||
|
@ -188,23 +203,6 @@ void TopCellInfo::SetNestedMultipleOutputToGraphInfoMap(const string &id, const
|
|||
}
|
||||
}
|
||||
|
||||
void TopCellInfo::Clear() {
|
||||
MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
|
||||
hook_changed_ = false;
|
||||
ms_function_flag_ = false;
|
||||
is_init_kpynative_ = false;
|
||||
need_compile_graph_ = false;
|
||||
forward_already_run_ = false;
|
||||
op_index_ = 0;
|
||||
resource_ = nullptr;
|
||||
fg_ = nullptr;
|
||||
graph_info_map_.clear();
|
||||
op_info_with_tensor_id_.clear();
|
||||
tensor_id_with_tensor_object_.clear();
|
||||
op_info_with_ms_func_forward_tensors_.clear();
|
||||
cnode_hash_with_op_index_.clear();
|
||||
}
|
||||
|
||||
void TopCellInfo::SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node,
|
||||
const std::vector<int64_t> &index) const {
|
||||
auto &graph_info = graph_info_map().at(fg());
|
||||
|
|
|
@ -58,9 +58,10 @@ using GraphInfoPtr = std::shared_ptr<GraphInfo>;
|
|||
class TopCellInfo {
|
||||
public:
|
||||
~TopCellInfo() = default;
|
||||
TopCellInfo(size_t grad_order, std::string c_cell_id, std::string cellid, std::string already_run_cell_id,
|
||||
pipeline::ResourcePtr r, FuncGraphPtr fg)
|
||||
: grad_order_(grad_order),
|
||||
TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string c_cell_id, std::string cellid,
|
||||
std::string already_run_cell_id, pipeline::ResourcePtr r, FuncGraphPtr fg)
|
||||
: is_high_order_top_cell_(is_high_order_top_cell),
|
||||
grad_order_(grad_order),
|
||||
c_cell_id_(std::move(c_cell_id)),
|
||||
cell_id_(std::move(cellid)),
|
||||
already_run_cell_id_(std::move(already_run_cell_id)),
|
||||
|
@ -76,12 +77,13 @@ class TopCellInfo {
|
|||
void RecordCellBackwardHookOp(const std::string &cell_order, const AnfNodePtr &hook_op);
|
||||
void GetOpInfo(const FrontendOpRunInfoPtr &op_run_info);
|
||||
inline void ClearCellHookOp() { cell_backward_hook_op_.clear(); }
|
||||
inline bool ms_function_flag() const { return ms_function_flag_; }
|
||||
inline void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; }
|
||||
inline bool forward_already_run() const { return forward_already_run_; }
|
||||
inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
|
||||
inline bool need_compile_graph() const { return need_compile_graph_; }
|
||||
inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
|
||||
inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
|
||||
inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
|
||||
inline bool need_do_final_opt() const { return need_do_final_opt_; }
|
||||
inline pipeline::ResourcePtr resource() const { return resource_; }
|
||||
inline FuncGraphPtr fg() const {
|
||||
MS_EXCEPTION_IF_NULL(fg_);
|
||||
|
@ -100,7 +102,7 @@ class TopCellInfo {
|
|||
graph_info_map_[fg] = graph_info;
|
||||
}
|
||||
inline void set_is_run_cell(bool is_run_cell) { is_run_cell_ = is_run_cell; }
|
||||
inline bool is_run_cell() { return is_run_cell_; }
|
||||
inline bool is_run_cell() const { return is_run_cell_; }
|
||||
inline const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() const { return graph_info_map_; }
|
||||
inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const {
|
||||
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr_);
|
||||
|
@ -119,26 +121,25 @@ class TopCellInfo {
|
|||
return op_info_with_ms_func_forward_tensors_;
|
||||
}
|
||||
inline size_t op_index() const { return op_index_; }
|
||||
inline void IncreaseOpIndex() { op_index_++; }
|
||||
inline void IncreaseOpIndex() { ++op_index_; }
|
||||
|
||||
inline void set_cnode_hash_with_op_index(const size_t &node_hash, const size_t &op_index) {
|
||||
cnode_hash_with_op_index_[node_hash] = op_index;
|
||||
}
|
||||
inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) {
|
||||
auto iter = cnode_hash_with_op_index_.find(node_hash);
|
||||
inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) const {
|
||||
const auto iter = cnode_hash_with_op_index_.find(node_hash);
|
||||
if (iter == cnode_hash_with_op_index_.end()) {
|
||||
MS_LOG(EXCEPTION) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
void Clear();
|
||||
|
||||
void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id);
|
||||
void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr ¶m, bool is_weight = false) const;
|
||||
void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
|
||||
bool save_flag = true) const;
|
||||
bool need_save_sub_id = true) const;
|
||||
void ClearDeviceMemory() const;
|
||||
void Clear();
|
||||
|
||||
private:
|
||||
void SetMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node) const;
|
||||
|
@ -148,12 +149,13 @@ class TopCellInfo {
|
|||
const std::vector<int64_t> &index) const;
|
||||
|
||||
bool hook_changed_{false};
|
||||
bool ms_function_flag_{false};
|
||||
bool is_init_kpynative_{false};
|
||||
bool forward_already_run_{false};
|
||||
bool need_compile_graph_{false};
|
||||
bool is_run_cell_{false};
|
||||
size_t op_index_{0};
|
||||
bool is_high_order_top_cell_{false};
|
||||
bool need_do_final_opt_{false};
|
||||
size_t grad_order_{0};
|
||||
std::string c_cell_id_;
|
||||
std::string cell_id_;
|
||||
|
|
|
@ -225,3 +225,9 @@ def bprop_scalar_calc(x, y, out, dout):
|
|||
def bprop_scalar_not(x, out, dout):
|
||||
"""Backpropagator for primitive `bool_not` and `string_not`."""
|
||||
return (C.zeros_like(x),)
|
||||
|
||||
|
||||
@bprops.register("TensorMove")
|
||||
def bprop_tensor_move(x, out, dout):
|
||||
"""Backpropagator for primitive `mutable`."""
|
||||
return (dout,)
|
||||
|
|
|
@ -148,7 +148,7 @@ def test_grad_multiple_inputs_multiple_outputs_cell_pynative():
|
|||
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level2
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_iteration_function_pynative():
|
||||
|
|
Loading…
Reference in New Issue