Add high order for pynative

Signed-off-by: zjun <zhangjun0@huawei.com>
zjun 2022-11-26 09:53:23 +08:00
parent dbf5f7ad6d
commit 2db79d9c1e
13 changed files with 189 additions and 170 deletions

View File

@ -276,23 +276,29 @@ AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std:
}
}
bool AutoGradCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(out);
bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
MS_EXCEPTION_IF_NULL(grad_param);
auto prim = GetCNodePrimitive(cnode);
auto prim = GetCNodePrimitive(grad_param->cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
MS_LOG(EXCEPTION) << "Should be primitive, but: " << grad_param->cnode->DebugString();
}
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimTupleGetItem) ||
IsPrimitiveEquals(prim, prim::kPrimStopGradient) || IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
MS_LOG(DEBUG) << "Prim " << prim->name() << " not need do op grad";
return true;
}
// anfnode_to_variable_adjoint_ holds the out value; clear its device_address so the device memory is not kept alive
auto cloned_value = ShallowCopyTensorValue(out);
auto cloned_value = ShallowCopyTensorValue(grad_param->out);
ClearDeviceAddress(cloned_value);
AnfNodePtr dout = BuildSpecialLikeValue(tape_, cloned_value, SpecialType::kZerosLikeType);
CNodePtr input_node = ConstructBpropGraphInput(cnode, op_args, out, dout);
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableNode>(fn, cloned_value);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
need_do_manager_replace_ = true;
}
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout);
std::vector<CNodePtr> outputs;
#ifndef ENABLE_TEST
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {
@ -300,7 +306,7 @@ bool AutoGradCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op
} else {
mindspore::BuildBprop(input_node, &outputs, &users_);
if (outputs.empty()) {
MS_LOG(ERROR) << "the bprop output should not be empty" << cnode->DebugString();
MS_LOG(ERROR) << "the bprop output should not be empty" << grad_param->cnode->DebugString();
BuildCustomBpropCNode(input_node, &outputs);
}
}
@ -312,15 +318,12 @@ bool AutoGradCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op
}
#endif
if (outputs.empty()) {
MS_LOG(EXCEPTION) << "the bprop output should not be empty" << cnode->DebugString();
MS_LOG(EXCEPTION) << "the bprop output should not be empty" << grad_param->cnode->DebugString();
}
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableNode>(fn, cloned_value);
UpdateNextEdges(fn, cnode, outputs, op_args);
anfnode_to_variable_adjoint_.insert(std::make_pair(cnode, variable_adjoint));
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
// Record last_node for backpropagation
last_node_ = cnode;
last_node_ = grad_param->cnode;
return true;
}
@ -346,15 +349,8 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
}
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
} else {
// Set current knode for cnode
k_node = BuildKNode(grad_param);
const auto it = anfnode_to_variable_adjoint_.find(grad_param->cnode);
if (it == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "Can not find cnode " << grad_param->cnode->DebugString();
}
// Get current cnode all inputs knode
const auto &k_node_list = BuildKNodeListFromPrimalCNode(grad_param->cnode, it->second);
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, k_node_list, grad_param->out, &dout);
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &args_node_list);
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
}
std::vector<CNodePtr> outputs;
@ -369,50 +365,10 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
variable_adjoint->set_k_node(k_node);
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
has_fbprop_ = true;
need_do_manager_replace_ = true;
return true;
}
AnfNodePtr AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param) {
AnfNodePtrList node_list;
MS_EXCEPTION_IF_NULL(grad_param);
for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) {
(void)node_list.emplace_back(BuildKNodeForCNodeInput(grad_param->op_args, grad_param->cnode->input(i), i));
}
auto k_node = tape_->NewCNode(node_list);
k_node->set_abstract(grad_param->out->ToAbstract()->Broaden());
return k_node;
}
AnfNodePtrList AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const VariableNodePtr &adjoint) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(adjoint);
AnfNodePtrList node_list;
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << cnode->input(i)->DebugString();
}
MS_EXCEPTION_IF_NULL(input_adjoint_iter->second->k_node());
(void)node_list.emplace_back(input_adjoint_iter->second->k_node());
}
return node_list;
}
AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const ValuePtrList &op_args, const AnfNodePtr &input_node,
size_t input_index) {
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(input_node);
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "cannot find input in adjoint map, inp: " << input_node->DebugString();
}
return input_adjoint_iter->second->k_node();
} else {
return input_node;
}
}
CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args,
const ValuePtr &out, AnfNodePtr *const tape_dout) {
// Wrap tuple_getitem(fprop_app, 1) in a FuncGraph and optimize it;
@ -460,7 +416,7 @@ void AutoGradCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
}
FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
const GradAttr &grad_attr, bool build_formal_param) {
const GradAttr &grad_attr) {
// Set sens node and weights node
SetSensAndWeights(weights, grad_attr.has_sens);
@ -479,41 +435,84 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::
return tape_;
}
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out, const AnfNodePtr &dout) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(out);
MS_EXCEPTION_IF_NULL(dout);
if (cnode->size() == 0) {
MS_LOG(EXCEPTION) << "cnode do not have inputs";
}
std::vector<AnfNodePtr> node_lists;
(void)node_lists.emplace_back(cnode->input(0));
for (size_t i = 0; i < op_args.size(); ++i) {
auto v = op_args[i];
auto node = cnode->input(i + 1);
if (node->isa<Parameter>()) {
node_lists.emplace_back(node);
node->set_abstract(v->ToAbstract());
continue;
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout) {
MS_EXCEPTION_IF_NULL(grad_param);
std::vector<AnfNodePtr> node_list;
(void)node_list.emplace_back(grad_param->cnode->input(0));
if (grad_param->grad_by_value) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
const auto &v = grad_param->op_args[i];
auto node = grad_param->cnode->input(i + 1);
if (node->isa<Parameter>()) {
node_list.emplace_back(node);
node->set_abstract(v->ToAbstract());
continue;
}
auto v_node = NewValueNode(grad_param->op_args[i]);
v_node->set_abstract(grad_param->op_args[i]->ToAbstract());
node_list.emplace_back(v_node);
}
auto v_node = NewValueNode(op_args[i]);
v_node->set_abstract(op_args[i]->ToAbstract());
node_lists.emplace_back(v_node);
} else {
// Input is a Parameter or cnode, not a value node
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &node_list);
}
auto out_node = NewValueNode(out);
out_node->set_abstract(out->ToAbstract());
node_lists.emplace_back(out_node);
node_lists.emplace_back(dout);
CNodePtr input_node = tape_->NewCNode(node_lists);
input_node->set_abstract(out->ToAbstract()->Broaden());
auto out_node = NewValueNode(grad_param->out);
auto out_abs = grad_param->out->ToAbstract()->Broaden();
out_node->set_abstract(out_abs);
// set out
node_list.emplace_back(out_node);
// set dout
node_list.emplace_back(dout);
auto input_node = tape_->NewCNode(node_list);
input_node->set_abstract(out_abs);
return input_node;
}
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out) {
return k_cell->KPynativeOp(cnode, op_args, out);
void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args,
std::vector<AnfNodePtr> *const node_list) {
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
MS_LOG(DEBUG) << "Find input knode of node " << cnode->input(i)->DebugString();
if (cnode->input(i)->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << cnode->input(i)->DebugString();
}
MS_EXCEPTION_IF_NULL(input_adjoint_iter->second->k_node());
(void)node_list->emplace_back(input_adjoint_iter->second->k_node());
} else {
cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract());
(void)node_list->emplace_back(cnode->input(i));
}
}
}
void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode) {
MS_EXCEPTION_IF_NULL(grad_param);
AnfNodePtrList node_list;
for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) {
(void)node_list.emplace_back(BuildKNodeForCNodeInput(grad_param->cnode->input(i)));
}
auto k_node = tape_->NewCNode(node_list);
k_node->set_abstract(grad_param->out->ToAbstract()->Broaden());
VariableNode->set_k_node(k_node);
}
AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(input_node);
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "cannot find input in adjoint map, inp: " << input_node->DebugString();
}
return input_adjoint_iter->second->k_node();
} else {
return input_node;
}
}
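// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this diff: a standalone model of the
// k-node substitution that BuildKNodeForCNodeInput above performs. ToyNode,
// ToyAdjoint and KNodeForInput are hypothetical stand-ins for AnfNode,
// VariableNode and anfnode_to_variable_adjoint_; the rule being modelled is
// "a CNode input is swapped for its recorded k-node, everything else passes
// through", which is what lets the rebuilt bprop graph be differentiated
// again for high-order grad.
#include <map>
#include <memory>
#include <stdexcept>

struct ToyNode {
  bool is_cnode = false;  // true for an op node, false for Parameter/ValueNode
};
struct ToyAdjoint {
  std::shared_ptr<ToyNode> k_node;  // node rebuilt inside the tape graph
};
using ToyAdjointMap = std::map<std::shared_ptr<ToyNode>, ToyAdjoint>;

std::shared_ptr<ToyNode> KNodeForInput(const ToyAdjointMap &adjoints,
                                       const std::shared_ptr<ToyNode> &input) {
  if (!input->is_cnode) {
    return input;  // Parameters and value nodes are used as-is.
  }
  const auto it = adjoints.find(input);
  if (it == adjoints.end()) {
    throw std::runtime_error("input CNode has no recorded adjoint");
  }
  return it->second.k_node;  // Substitute the primal CNode with its k-node.
}
// ---------------------------------------------------------------------------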
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const GradParamPtr &grad_param) {
return k_cell->KPynativeOp(grad_param);
}
void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode,
@ -987,7 +986,7 @@ void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new
return;
}
auto &old_node_users = users_[old_node];
for (auto pair_node : old_node_users) {
for (const auto &pair_node : old_node_users) {
auto cnode = pair_node.first;
size_t index = pair_node.second;
if (index >= cnode->size()) {
@ -1029,7 +1028,7 @@ void AutoGradCellImpl::ClearDeviceAddress(const ValuePtr &out) {
void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
const auto &parameters = tape_->parameters();
auto cell_inputs_size = cell_inputs_.size();
if (has_fbprop_) {
if (need_do_manager_replace_) {
MS_LOG(DEBUG) << "Do parameter replace by manager";
auto mng = MakeManager({tape_}, false);
auto tr = mng->Transact();
@ -1046,7 +1045,7 @@ void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, boo
(void)tr.Replace(weights[i], parameters[weight_offset + i]);
}
tr.Commit();
has_fbprop_ = false;
need_do_manager_replace_ = false;
} else {
for (size_t i = 0; i < cell_inputs_size; ++i) {
Replace(cell_inputs_[i], parameters[i]);
@ -1088,9 +1087,8 @@ AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
}
FuncGraphPtr GradPynativeCellEnd(const AutoGradCellImplPtr &auto_grad_cell, const AnfNodePtrList &weights,
const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
bool build_formal_param) {
return auto_grad_cell->Finish(weights, grad_position, grad_attr, build_formal_param);
const std::vector<size_t> &grad_position, const GradAttr &grad_attr) {
return auto_grad_cell->Finish(weights, grad_position, grad_attr);
}
} // namespace ad
} // namespace mindspore

View File

@ -42,8 +42,9 @@ struct GradAttr {
};
struct GradParam {
GradParam(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, FuncGraphPtr fprop_fg = nullptr)
: cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)) {}
GradParam(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, FuncGraphPtr fprop_fg,
bool grad_by_value)
: cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)), grad_by_value(grad_by_value) {}
// Primal CNode created by the op forward process
const CNodePtr &cnode;
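// Call-site sketch, collected from the DoOpGrad and MakeAdjointForMsFunction
// hunks later in this commit rather than new code: the added grad_by_value
// argument is derived from the top cell, so it is false exactly when a
// high-order top cell is being graded, which makes AutoGradCellImpl build
// k-nodes instead of plain value-node inputs.
auto op_grad_param = std::make_shared<ad::GradParam>(
    cnode, cloned_op_args, cloned_out, /*fprop_fg=*/nullptr,
    /*grad_by_value=*/!top_cell()->is_high_order_top_cell());
auto fprop_grad_param = std::make_shared<ad::GradParam>(
    ms_function_cnode, op_run_info->input_value, op_run_info->out_value,
    grad_graph, !top_cell->is_high_order_top_cell());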
@ -111,7 +112,7 @@ class AutoGradCellImpl {
AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values);
~AutoGradCellImpl() = default;
// Reverse connect bprop of op
bool KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out);
bool KPynativeOp(const GradParamPtr &grad_param);
// Reverse connect ms_function or higher order sub bprop funcgraph
bool KPynativeWithFProp(const GradParamPtr &grad_param);
CNodePtr GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args, const ValuePtr &out,
@ -121,7 +122,7 @@ class AutoGradCellImpl {
// Build a back-propagation funcgraph in which each cnode of the primal funcgraph is replaced by a value node or a
// formal cnode, so the result can be differentiated again.
FuncGraphPtr Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
const GradAttr &grad_attr, bool build_formal_param);
const GradAttr &grad_attr);
private:
// Last cnode of this Cell, may be a primitive op or cell with user defined bprop.
@ -139,14 +140,13 @@ class AutoGradCellImpl {
// Record cnode's input map for tape_
UserType users_;
// Flag for ms_function and high-order grad
bool has_fbprop_{false};
bool need_do_manager_replace_{false};
bool IsCNodeNeedGrad(const AnfNodePtr &node_ptr) const;
std::vector<bool> GetNeedGradFlags(const CNodePtr &cnode);
// Construct the input cnode for the expander
CNodePtr ConstructBpropGraphInput(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
const AnfNodePtr &dout);
CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout);
// Back propagate for one node;
void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
const ValuePtrList &op_args);
@ -182,9 +182,10 @@ class AutoGradCellImpl {
void ClearDeviceAddress(const ValuePtr &out);
// Fbprop
AnfNodePtr BuildKNode(const GradParamPtr &grad_param);
AnfNodePtrList BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const VariableNodePtr &adjoint);
AnfNodePtr BuildKNodeForCNodeInput(const ValuePtrList &op_args, const AnfNodePtr &input_node, size_t input_index);
void BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode);
void BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args,
std::vector<AnfNodePtr> *const node_list);
AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input_node);
};
using AutoGradCellImplPtr = std::shared_ptr<AutoGradCellImpl>;
@ -209,15 +210,13 @@ AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
// else:
// each cnode in primal funcgraph is replaced by value node
FuncGraphPtr GradPynativeCellEnd(const AutoGradCellImplPtr &k_cell, const AnfNodePtrList &weights,
const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
bool build_formal_param = false);
const std::vector<size_t> &grad_position, const GradAttr &grad_attr);
// Grad for each operation.
// cnode: CNode which contains the prim (index 0) and the formal input parameters of that prim.
// op_args: the argument list of each input parameter.
// out: the op result.
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out);
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const GradParamPtr &grad_param);
// Adjoint bprop from ms_function and high-order grad
void GradPynativeFBprop(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,

View File

@ -33,14 +33,17 @@
namespace mindspore {
/* namespace to support opt */
namespace opt {
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
static const std::map<std::string, std::vector<std::string>> op2attrs = {
{prim::kPrimBroadcastTo->name(), {kAttrShape}},
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
{prim::kPrimReduceSum->name(), {kAttrKeepDims}}};
namespace {
const std::map<std::string, std::vector<std::string>> op2attrs = {{prim::kPrimBroadcastTo->name(), {kAttrShape}},
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
{prim::kPrimReduceSum->name(), {kAttrKeepDims}}};
}
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto todos = TopoSort(graph->get_return());
auto mng = MakeManager({graph}, false);
for (const auto &node : todos) {
if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
continue;
@ -66,7 +69,8 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
AnfNodePtrList inputs = {NewValueNode(new_prim)};
auto cnode = dyn_cast_ptr<CNode>(node);
(void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
cnode->set_inputs(inputs);
auto new_cnode = graph->NewCNodeInOrder(inputs);
(void)mng->Replace(node, new_cnode);
}
return true;
}

View File

@ -24,6 +24,7 @@ namespace opt {
* Try Expand cnode for front end graph.
*/
AnfNodePtr TryExpandCNodeFE(const AnfNodePtr &node);
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_EXPANDER_H

View File

@ -33,9 +33,11 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
// Almost all ops are defined in a few main paths
const std::string ops_module = OP_PATH;
const std::string inner_ops_module = INNER_OP_PATH;
const std::string grad_ops_module = GRAD_OP_PATH;
const std::string functional_op_module = FUNCTIONAL_OP_PATH;
py::module mod = py::module::import(common::SafeCStr(ops_module));
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
py::module grad_mod = py::module::import(common::SafeCStr(grad_ops_module));
py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module));
if (py::hasattr(inner_mod, common::SafeCStr(op_name))) {
@ -44,9 +46,12 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
if (py::hasattr(mod, common::SafeCStr(op_name))) {
return ops_module;
}
if (py::hasattr(grad_mod, common::SafeCStr(op_name))) {
return grad_ops_module;
}
if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) {
MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << functional_op_module
<< " don't have op:" << op_name;
MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << grad_ops_module << " and "
<< functional_op_module << " don't have op:" << op_name;
}
return functional_op_module;
}
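// Standalone sketch, not part of this diff, of the lookup order the function
// above follows after this change. The module strings are the real constants
// (INNER_OP_PATH, OP_PATH, the new GRAD_OP_PATH, FUNCTIONAL_OP_PATH);
// FindOpModule and the has_op callback are hypothetical simplifications of
// the py::module::import / py::hasattr calls.
#include <functional>
#include <optional>
#include <string>
#include <vector>

std::optional<std::string> FindOpModule(
    const std::string &op_name,
    const std::function<bool(const std::string &, const std::string &)> &has_op) {
  const std::vector<std::string> modules = {
      "mindspore.ops.operations._inner_ops",  // INNER_OP_PATH, checked first
      "mindspore.ops.operations",             // OP_PATH
      "mindspore.ops.operations._grad_ops",   // GRAD_OP_PATH, added by this commit
      "mindspore.ops.functional"};            // FUNCTIONAL_OP_PATH, last resort
  for (const auto &module : modules) {
    if (has_op(module, op_name)) {
      return module;
    }
  }
  return std::nullopt;  // caller raises an exception naming all four modules
}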

View File

@ -126,6 +126,7 @@ constexpr char REDUCE_OP_ALL[] = "prod";
constexpr char REDUCE_OP_PROD[] = "prod";
constexpr char OP_PATH[] = "mindspore.ops.operations";
constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
constexpr char GRAD_OP_PATH[] = "mindspore.ops.operations._grad_ops";
constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional";
constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
constexpr char GET_OP_FUNCTION[] = "_get_python_op";

View File

@ -200,6 +200,16 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) {
irpass.depend_value_elim_,
});
OptPassGroupMap map({{"ad_final_opt", bg_final_opt}});
if (pynative::PyNativeExecutor::GetInstance()->grad_executor()->need_renormalize()) {
(void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_};
(void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate));
opt::OptPassConfig environ_eliminate = opt::OptPassConfig({
irpass.incorporate_call_,
irpass.incorporate_call_switch_,
});
(void)map.emplace_back(std::make_pair("environ_eliminate", environ_eliminate));
}
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", resource, map);
MS_EXCEPTION_IF_NULL(resource);
auto func_graph = resource->func_graph();

View File

@ -26,6 +26,7 @@
#include "backend/common/optimizer/helper.h"
#include "include/common/utils/convert_utils_py.h"
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/expander.h"
#include "pipeline/jit/pass.h"
namespace mindspore {
@ -135,7 +136,8 @@ ValuePtr ConvertOutputValueToTensor(const ValuePtr &v) {
int64_t input = v->cast<Int64ImmPtr>()->value();
return std::make_shared<tensor::Tensor>(input, kInt64);
} else {
MS_LOG(EXCEPTION) << "Output is " << v->ToString() << ", abstract " << v->ToAbstract()->Broaden();
MS_LOG(DEBUG) << "Output is " << v->ToString() << ", abstract " << v->ToAbstract()->Broaden();
return v;
}
}
@ -358,8 +360,8 @@ void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
fg->debug_info()->set_name("pynative_forward_graph");
auto resource = std::make_shared<pipeline::Resource>();
const auto &already_run_cell_id = input_args_info->cell_id + std::to_string(input_args_info->grad_order);
top_cell_ = std::make_shared<TopCellInfo>(input_args_info->grad_order, input_args_info->cell_id, already_run_cell_id,
resource, fg);
top_cell_ = std::make_shared<TopCellInfo>(input_args_info->is_high_order_top_cell, input_args_info->grad_order,
input_args_info->cell_id, already_run_cell_id, resource, fg);
top_cell_->set_forward_already_run(true);
top_cell_->set_input_args_id(input_args_info->input_args_id);
PushHighOrderGraphStack(top_cell_);
@ -385,8 +387,7 @@ void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &
// Set the last output abstract, which will be used for sens
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
auto sens_v = ConvertOutputValueToTensor(v);
auto cloned_value = ShallowCopyTensorValue(sens_v);
auto cloned_value = ShallowCopyTensorValue(v);
auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value);
}
@ -417,9 +418,11 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
const auto &cell_id = input_args_info->cell_id;
MS_LOG(DEBUG) << "EndGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
<< ", input args info ptr " << input_args_info.get();
const auto &out_value = input_args_info->out_value;
MS_EXCEPTION_IF_NULL(out_value);
const auto &out_id = PyNativeAlgo::Common::GetIdByValue(out_value);
bool is_top_cell_end = (cell_id == top_cell()->cell_id());
if (is_top_cell_end) {
input_args_info->out_value = ConvertOutputValueToTensor(input_args_info->out_value);
}
const auto &out_id = PyNativeAlgo::Common::GetIdByValue(input_args_info->out_value);
DoGradForCustomBprop(input_args_info, out_id);
// Update bprop grad stack
if (grad_is_running_ && !bprop_grad_stack_.empty()) {
@ -432,16 +435,15 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
}
}
// Only dump the last forward graph
bool is_top_cell_end = cell_id == top_cell()->cell_id();
if (is_top_cell_end && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
curr_g()->set_output(GetInput(out_value, out_id));
curr_g()->set_output(GetInput(input_args_info->out_value, out_id));
PyNativeAlgo::Common::DumpGraphIR("fg.ir", curr_g());
}
// Reset grad flag and update output node of the outermost cell
if (input_args_info->is_grad_topest_cell && is_top_cell_end) {
MS_LOG(DEBUG) << "Cur top last cell " << cell_id;
(void)PopHighOrderGraphStack();
SetForwardLastNodeInfo(out_value, out_id);
SetForwardLastNodeInfo(input_args_info->out_value, out_id);
top_cell()->ClearCellHookOp();
cell_order_ = 0;
// set_grad_flag(false);
@ -450,7 +452,7 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
if (is_top_cell_end) {
// In high-order grad cases the output of the internal graph may be a tuple, so its node needs to be created in the getobj step
if (!input_args_info->is_grad_topest_cell) {
SetForwardLastNodeInfo(out_value, out_id);
SetForwardLastNodeInfo(input_args_info->out_value, out_id);
}
top_cell()->CheckSubCellHookChanged();
top_input_args_info_ = input_args_info;
@ -555,7 +557,7 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
if (top_input_args_info_->input_arg_value_vec.size() == args.size()) {
top_input_args_info_->input_arg_value_vec.pop_back();
}
(void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_v));
(void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_tensor));
top_input_args_info_->has_sens = true;
}
@ -678,12 +680,14 @@ void GradExecutor::UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args,
MS_EXCEPTION_IF_NULL(bprop_graph);
std::vector<ValuePtr> tensor_args;
size_t input_size = has_sens ? input_args.size() - 1 : input_args.size();
// Sens may be a value tuple not a single tensor
// Sens may be a value tuple rather than a single tensor; the bprop graph has only one sens parameter, so a tuple
// sens cannot be flattened
for (size_t i = 0; i < input_size; ++i) {
if (PyNativeAlgo::Common::IsTensor(input_args[i])) {
(void)tensor_args.emplace_back(input_args[i]);
}
}
// No flatten
if (has_sens) {
(void)tensor_args.emplace_back(input_args[input_size]);
}
@ -721,15 +725,9 @@ void GradExecutor::UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args,
FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const vector<AnfNodePtr> &w_args,
const vector<size_t> &p_args) {
MS_EXCEPTION_IF_NULL(top_input_args_info_);
bool build_formal_param = false;
if (!top_input_args_info_->has_custom_bprop && !top_input_args_info_->is_grad_topest_cell && IsNestedGrad()) {
build_formal_param = true;
need_renormalize_ = true;
}
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(auto_grad_cell_ptr, w_args, p_args, grad_attr, build_formal_param);
FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(auto_grad_cell_ptr, w_args, p_args, grad_attr);
MS_EXCEPTION_IF_NULL(bprop_graph);
MS_LOG(DEBUG) << "Top graph input params size " << top_input_args_info_->input_arg_value_vec.size();
@ -738,7 +736,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
bprop_graph->debug_info()->set_name(ss.str());
UpdateParamAbsByArgs(top_input_args_info_->input_arg_value_vec, bprop_graph, grad_attr.has_sens);
if (top_cell()->ms_function_flag()) {
if (top_cell()->need_do_final_opt()) {
bprop_graph = BpropGraphFinalOpt(bprop_graph);
}
if (top_input_args_info_->is_grad_topest_cell) {
@ -830,14 +828,15 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
MS_LOG(DEBUG) << "Bprop nested";
}
MS_EXCEPTION_IF_NULL(first_grad_fg);
PyNativeAlgo::Common::DumpGraphIR("first_grad_fg.ir", first_grad_fg);
std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
ValuePtrList weights_args;
DoParameterReplace(first_grad_fg, forward_args, &inputs, &weights_args);
pipeline::ResourcePtr r = std::make_shared<pipeline::Resource>();
r->manager()->AddFuncGraph(first_grad_fg);
if (!opt::ConvertPrimToPrimPy(first_grad_fg)) {
MS_LOG(EXCEPTION) << "Convert PrimitiveC to PrimitivePy failed";
}
auto r = std::make_shared<pipeline::Resource>();
set_eliminate_forward(false);
(void)first_grad_fg->transforms().erase(kGrad);
// Do high order
@ -861,10 +860,12 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
std::vector<ValuePtr> out_v{out_value};
out_value = std::make_shared<ValueTuple>(out_v);
}
auto grad_param = std::make_shared<ad::GradParam>(cnode, input_args, out_value, second_grad_fg);
auto grad_param = std::make_shared<ad::GradParam>(cnode, input_args, out_value, second_grad_fg,
!top_cell()->is_high_order_top_cell());
if (!top_cell()->auto_grad_cell_ptr()->KPynativeWithFProp(grad_param)) {
MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph " << cnode->ToString();
}
top_cell()->set_need_do_final_opt(true);
need_renormalize_ = true;
}
@ -1175,12 +1176,9 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode
std::back_inserter(cloned_op_args),
[](const ValuePtr &value) { return ShallowCopyTensorValue(value); });
ValuePtr cloned_out = ShallowCopyTensorValue(op_out);
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(cloned_out, &tensors);
for (auto tensor : tensors) {
tensor->set_is_forward_output(true);
}
if (!ad::GradPynativeOp(top_cell()->auto_grad_cell_ptr(), cnode, cloned_op_args, cloned_out)) {
auto grad_param =
std::make_shared<ad::GradParam>(cnode, cloned_op_args, cloned_out, nullptr, !top_cell()->is_high_order_top_cell());
if (!ad::GradPynativeOp(top_cell()->auto_grad_cell_ptr(), grad_param)) {
MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_run_info->base_op_run_info.op_name;
}
}

View File

@ -240,13 +240,13 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run
// Connect grad graph of ms_function to context.
auto auto_grad_cell_ptr = top_cell->auto_grad_cell_ptr();
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
auto grad_param =
std::make_shared<ad::GradParam>(ms_function_cnode, op_run_info->input_value, op_run_info->out_value, grad_graph);
auto grad_param = std::make_shared<ad::GradParam>(ms_function_cnode, op_run_info->input_value, op_run_info->out_value,
grad_graph, !top_cell->is_high_order_top_cell());
if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
<< ms_function_cnode->DebugString();
}
top_cell->set_ms_function_flag(true);
top_cell->set_need_do_final_opt(true);
return ms_function_cnode;
}

View File

@ -55,9 +55,10 @@ using GraphInfoPtr = std::shared_ptr<GraphInfo>;
class TopCellInfo {
public:
~TopCellInfo() = default;
TopCellInfo(size_t grad_order, std::string cellid, std::string already_run_cell_id, pipeline::ResourcePtr r,
FuncGraphPtr fg)
: grad_order_(grad_order),
TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string cellid, std::string already_run_cell_id,
pipeline::ResourcePtr r, FuncGraphPtr fg)
: is_high_order_top_cell_(is_high_order_top_cell),
grad_order_(grad_order),
cell_id_(std::move(cellid)),
already_run_cell_id_(std::move(already_run_cell_id)),
resource_(std::move(r)),
@ -71,10 +72,11 @@ class TopCellInfo {
inline const CellIdWithBackwardHookOp &cell_backward_hook_op() const { return cell_backward_hook_op_; }
void RecordCellBackwardHookOp(const std::string &cell_order, const AnfNodePtr &hook_op);
inline void ClearCellHookOp() { cell_backward_hook_op_.clear(); }
inline bool ms_function_flag() const { return ms_function_flag_; }
inline void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; }
inline bool forward_already_run() const { return forward_already_run_; }
inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
inline bool need_do_final_opt() const { return need_do_final_opt_; }
inline pipeline::ResourcePtr resource() const { return resource_; }
inline FuncGraphPtr fg() const {
MS_EXCEPTION_IF_NULL(fg_);
@ -108,9 +110,10 @@ class TopCellInfo {
const std::vector<int64_t> &index) const;
bool hook_changed_{false};
bool ms_function_flag_{false};
bool is_init_kpynative_{false};
bool forward_already_run_{false};
bool is_high_order_top_cell_{false};
bool need_do_final_opt_{false};
size_t grad_order_{0};
std::string cell_id_;
std::string already_run_cell_id_;

View File

@ -37,7 +37,7 @@ def run_watch_dde_network(file_name, log_file_name):
os.remove(log_file_name)
@pytest.mark.level2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_watch_dde_error_log():

View File

@ -22,7 +22,7 @@ import numpy as np
import pytest
@pytest.mark.level2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_switch_simplify_avoid_dead_node():

View File

@ -148,7 +148,7 @@ def test_grad_multiple_inputs_multiple_outputs_cell_pynative():
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
@pytest.mark.level2
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_iteration_function_pynative():