!46233 Fix bug for PyNative high grad

Merge pull request !46233 from zjun/fix_ms_function
i-robot 2022-12-05 03:21:24 +00:00 committed by Gitee
commit 2c538003fa
15 changed files with 387 additions and 345 deletions

View File

@ -33,6 +33,7 @@
#include "utils/profile.h"
#include "include/common/utils/primitive_utils.h"
#include "pipeline/jit/pass.h"
namespace mindspore {
namespace ad {
namespace {
@ -234,6 +235,7 @@ bool IsZerosLikeNode(const AnfNodePtr &node) {
}
FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph) {
pynative::PyNativeAlgo::Common::DumpGraphIR("bprop_builder_before_opt.ir", bprop_func_graph);
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
resource->set_func_graph(bprop_func_graph);
auto manager = resource->manager();
@ -243,6 +245,34 @@ FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph) {
pynative::PyNativeAlgo::Common::DumpGraphIR("bprop_builder_after_opt.ir", after_opt_bg);
return after_opt_bg;
}
bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) {
if (!inputs_grad->isa<CNode>() || !weights_grad->isa<CNode>()) {
return false;
}
auto inputs_grad_cnode = inputs_grad->cast<CNodePtr>();
auto weights_grad_cnode = weights_grad->cast<CNodePtr>();
if (!IsPrimitiveCNode(inputs_grad_cnode, prim::kPrimMakeTuple) ||
!IsPrimitiveCNode(weights_grad_cnode, prim::kPrimMakeTuple)) {
return false;
}
constexpr int kEmptyTupeSize = 1;
if (inputs_grad_cnode->size() != kEmptyTupeSize || weights_grad_cnode->size() != kEmptyTupeSize) {
return false;
}
return true;
}
AnfNodePtr GenerateEmptyTupleValue() {
std::vector<ValuePtr> value_list;
auto inputs_value = std::make_shared<ValueTuple>(value_list);
auto weights_value = std::make_shared<ValueTuple>(value_list);
std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
auto tuple_value = std::make_shared<ValueTuple>(tuple_list);
auto tuple_value_node = NewValueNode(tuple_value);
tuple_value_node->set_abstract(tuple_value->ToAbstract());
return tuple_value_node;
}
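IsOutputBothEmpty and GenerateEmptyTupleValue are now free helpers in the anonymous namespace (their former AutoGradCellImpl member versions are removed further down in this diff). A minimal sketch of how SetOutput presumably combines them, assuming inputs_grad and weights_grad are the MakeTuple nodes returned by GetInputGrad and GetWeightGrad:
// Sketch only: tape_, inputs_grad and weights_grad are assumed local names.
AnfNodePtr tape_output;
if (IsOutputBothEmpty(inputs_grad, weights_grad)) {
  // Neither input grads nor weight grads exist: return ((), ()) as a value node.
  tape_output = GenerateEmptyTupleValue();
} else {
  // Otherwise pack both grad tuples into the tape output.
  tape_output = tape_->NewCNode({NewValueNode(prim::kPrimMakeTuple), inputs_grad, weights_grad});
}
tape_->set_output(tape_output);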
} // namespace
AnfNodePtr FunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
@ -302,16 +332,15 @@ void FunctionNode::ReplaceEdges() {
AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values)
: tape_(std::make_shared<FuncGraph>()), cell_inputs_(cell_inputs) {
tape_->debug_info()->set_name("grad_top");
MS_LOG(DEBUG) << "Start AutoGradCellImpl: "
<< "cell_inputs size: " << cell_inputs.size();
MS_LOG(DEBUG) << "Start AutoGradCellImpl, cell_inputs size: " << cell_inputs.size();
for (size_t i = 0; i < cell_inputs.size(); ++i) {
TraceGuard trace_guard(std::make_shared<TraceCopy>(cell_inputs[i]->debug_info()));
auto parameter = tape_->add_parameter();
parameter->set_abstract(input_param_values[i]->ToAbstract()->Broaden());
auto zeros_like_dout = BuildZerosLikeNode(tape_, input_param_values[i]);
auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout);
auto input_adjoint = std::make_shared<VariableNode>(func_node, input_param_values[i]);
anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, input_param_values[i]);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
}
}
@ -332,13 +361,14 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
ClearDeviceAddress(cloned_value);
AnfNodePtr dout = BuildSpecialLikeValue(tape_, cloned_value, SpecialType::kZerosLikeType);
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableNode>(fn, cloned_value);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, cloned_value);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
need_do_manager_replace_ = true;
}
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout);
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout, variable_adjoint);
MS_LOG(DEBUG) << "Construct input cnode: " << input_node->DebugString();
// Gradient outputs
std::vector<CNodePtr> outputs;
#ifndef ENABLE_TEST
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {
@ -364,7 +394,7 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
variable_adjoint->set_is_fake_bprop(true);
variable_adjoint->set_fake_prim_name(prim->name());
}
anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
// Record last_node for backpropagation
last_node_ = grad_param->cnode;
return true;
@ -372,9 +402,9 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
MS_EXCEPTION_IF_NULL(grad_param);
AnfNodePtrList args_node_list;
CNodePtr bprop_cnode = nullptr;
AnfNodePtr k_node = nullptr;
AnfNodePtr dout = nullptr;
if (grad_param->grad_by_value) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
@ -395,7 +425,11 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &args_node_list);
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
}
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, grad_param->out);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
}
std::vector<CNodePtr> outputs;
for (size_t i = 1; i < grad_param->cnode->size(); ++i) {
// bprop_app[0] env
@ -403,11 +437,8 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
din->set_abstract(grad_param->op_args[i - 1]->ToAbstract()->Broaden());
(void)outputs.emplace_back(din);
}
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableNode>(fn, grad_param->out);
variable_adjoint->set_k_node(k_node);
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
need_do_manager_replace_ = true;
return true;
}
@ -430,7 +461,7 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const
auto get_bprop =
bprop_builder->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fprop_app, NewValueNode(static_cast<int64_t>(1))});
// Get graph after optimize
// Get bprop from fprop_fg; it is the second output of fprop_fg
AnfNodePtrList node_list{get_bprop};
auto dout = bprop_builder->add_parameter();
MS_EXCEPTION_IF_NULL(out);
@ -438,12 +469,15 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const
(void)node_list.emplace_back(dout);
auto call_bprop = bprop_builder->NewCNode(node_list);
bprop_builder->set_output(call_bprop);
// Run optimization passes (such as inline) on the graph
auto after_opt_fg = OptimizeBpropBuilder(bprop_builder);
// Call by tape_
MS_EXCEPTION_IF_NULL(tape_dout);
*tape_dout = BuildZerosLikeNode(tape_, out);
(void)bprop_builder_inputs.emplace_back(*tape_dout);
bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
(void)bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
get_bprop = tape_->NewCNode(bprop_builder_inputs);
// tape_dout is set by next op
AddUser(*tape_dout, get_bprop, bprop_builder_inputs.size() - 1);
@ -467,45 +501,46 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::
if (!last_node_->isa<ValueNode>() && !last_node_->isa<Parameter>()) {
(void)BackPropagate();
}
// Return the gradient;
if (grad_attr.get_by_position && grad_position.empty()) {
MS_LOG(EXCEPTION) << "grad_position should not be empty when grad by position!";
}
SetOutput(weights, grad_position, grad_attr);
// Replace Parameter of primal funcgraph with parameter of tape_;
// Replace Parameter of primal funcgraph with parameter of tape_;
ReplacePrimalParameter(weights, grad_attr.has_sens);
pynative::PyNativeAlgo::Common::DumpGraphIR("before_final_opt.ir", tape_);
return tape_;
}
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout) {
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
const VariableAdjointPtr &variable_adjoint) {
MS_EXCEPTION_IF_NULL(grad_param);
std::vector<AnfNodePtr> node_list;
(void)node_list.emplace_back(grad_param->cnode->input(0));
auto out_abs = grad_param->out->ToAbstract()->Broaden();
if (grad_param->grad_by_value) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
const auto &v = grad_param->op_args[i];
auto node = grad_param->cnode->input(i + 1);
if (node->isa<Parameter>()) {
node_list.emplace_back(node);
node->set_abstract(v->ToAbstract());
(void)node_list.emplace_back(node);
node->set_abstract(v->ToAbstract()->Broaden());
continue;
}
auto v_node = NewValueNode(grad_param->op_args[i]);
v_node->set_abstract(grad_param->op_args[i]->ToAbstract());
node_list.emplace_back(v_node);
v_node->set_abstract(grad_param->op_args[i]->ToAbstract()->Broaden());
(void)node_list.emplace_back(v_node);
}
// Set out
auto out_node = NewValueNode(grad_param->out);
out_node->set_abstract(out_abs);
(void)node_list.emplace_back(out_node);
} else {
// Input is a Parameter or cnode, not a value node
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &node_list);
// Set out
MS_EXCEPTION_IF_NULL(variable_adjoint);
(void)node_list.emplace_back(variable_adjoint->k_node());
}
auto out_node = NewValueNode(grad_param->out);
auto out_abs = grad_param->out->ToAbstract()->Broaden();
out_node->set_abstract(out_abs);
// set out
node_list.emplace_back(out_node);
// set dout
node_list.emplace_back(dout);
// Set dout
(void)node_list.emplace_back(dout);
auto input_node = tape_->NewCNode(node_list);
input_node->set_abstract(out_abs);
return input_node;
@ -515,7 +550,7 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
std::vector<AnfNodePtr> *const node_list) {
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
MS_LOG(DEBUG) << "Find input knode of node " << cnode->input(i)->DebugString();
MS_LOG(DEBUG) << "Get knode for node " << cnode->input(i)->DebugString();
if (cnode->input(i)->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
@ -524,13 +559,13 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
MS_EXCEPTION_IF_NULL(input_adjoint_iter->second->k_node());
(void)node_list->emplace_back(input_adjoint_iter->second->k_node());
} else {
cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract());
cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract()->Broaden());
(void)node_list->emplace_back(cnode->input(i));
}
}
}
void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode) {
void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint) {
MS_EXCEPTION_IF_NULL(grad_param);
AnfNodePtrList node_list;
for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) {
@ -538,7 +573,8 @@ void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const Variable
}
auto k_node = tape_->NewCNode(node_list);
k_node->set_abstract(grad_param->out->ToAbstract()->Broaden());
VariableNode->set_k_node(k_node);
variable_adjoint->set_k_node(k_node);
MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
}
AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) {
@ -565,8 +601,9 @@ void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr
MS_LOG(EXCEPTION) << "The size of dins is not same as op_args";
}
for (size_t i = 0; i < op_args.size(); ++i) {
auto node = cnode->input(i + 1);
auto din = dins[i];
const auto &node = cnode->input(i + 1);
const auto &din = dins[i];
MS_LOG(DEBUG) << "Node " << node->DebugString() << ", din " << din->DebugString();
UpdateNextEdges(fn, node, din, op_args[i]);
}
}
@ -617,29 +654,27 @@ void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const AnfNodeP
AddParameterNode(param, tensor);
UpdateNextEdges(fn, node, din, op_arg);
} else {
MS_LOG(DEBUG) << "It is not a cnode: " << node->DebugString();
MS_LOG(DEBUG) << "It is not a cnode or parameter: " << node->DebugString();
return;
}
}
void AutoGradCellImpl::BuildForwardLastNode() {
MS_EXCEPTION_IF_NULL(last_node_);
if (last_node_->isa<ValueNode>() ||
anfnode_to_variable_adjoint_.find(last_node_) != anfnode_to_variable_adjoint_.end()) {
return;
}
if (anfnode_to_variable_adjoint_.find(last_node_) == anfnode_to_variable_adjoint_.end()) {
auto zeros_like_node = BuildZerosLikeNode(tape_, sens_value_);
auto fn = std::make_shared<FunctionNode>(tape_, zeros_like_node);
// If last_node is a maketuple or tuplegetitem, need update next edges,
// if last_node is parameter, not need to update next edges.
if (last_node_->isa<CNode>()) {
UpdateNextEdges(fn, last_node_, zeros_like_node, sens_value_);
}
auto input_adjoint = std::make_shared<VariableNode>(fn, sens_value_);
anfnode_to_variable_adjoint_.insert(std::make_pair(last_node_, input_adjoint));
} else {
MS_LOG(EXCEPTION) << "Unprocessed node: " << last_node_->DebugString();
MS_LOG(DEBUG) << "Process last node info " << last_node_->DebugString();
auto zeros_like_node = BuildZerosLikeNode(tape_, sens_value_);
auto fn = std::make_shared<FunctionNode>(tape_, zeros_like_node);
// If last_node is a MakeTuple or TupleGetItem, the next edges need to be updated;
// if last_node is a Parameter, there is no need to update next edges.
if (last_node_->isa<CNode>()) {
UpdateNextEdges(fn, last_node_, zeros_like_node, sens_value_);
}
auto input_adjoint = std::make_shared<VariableAdjoint>(fn, sens_value_);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(last_node_, input_adjoint));
}
void AutoGradCellImpl::AddParameterNode(const AnfNodePtr &parameter, const ValuePtr &tensor) {
@ -647,9 +682,9 @@ void AutoGradCellImpl::AddParameterNode(const AnfNodePtr &parameter, const Value
MS_EXCEPTION_IF_NULL(tensor);
auto zeros_like_dout = BuildZerosLikeNode(tape_, tensor);
auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout);
auto input_adjoint = std::make_shared<VariableNode>(func_node, tensor);
anfnode_to_variable_adjoint_.insert(std::make_pair(parameter, input_adjoint));
weights_.push_back(parameter);
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, tensor);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(parameter, input_adjoint));
(void)weights_used_in_graph_.emplace_back(parameter);
}
AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePtr &out_value, const ValuePtr &sub_value,
@ -657,8 +692,8 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
MS_EXCEPTION_IF_NULL(out_value);
MS_EXCEPTION_IF_NULL(sub_value);
MS_EXCEPTION_IF_NULL(din);
std::string out_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(out_value);
std::string sub_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(sub_value);
const auto &out_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(out_value);
const auto &sub_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(sub_value);
if (out_value_id == sub_value_id) {
return din;
} else if (out_value->isa<tensor::Tensor>()) {
@ -672,13 +707,13 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
}
auto value_seq = out_value->cast<ValueSequencePtr>();
int index = -1;
for (auto value : value_seq->value()) {
for (const auto &value : value_seq->value()) {
auto real_din = GetRealDin(fn, value, sub_value, din);
(void)inputs.emplace_back(real_din);
// If din equals fake_dout, record its index in the user vector
if (din == fn->fake_dout() && real_din == din) {
index = inputs.size() - 1;
index = static_cast<int>(inputs.size()) - 1;
}
}
auto new_din = tape_->NewCNode(inputs);
@ -704,7 +739,7 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
prim_py->AddBpropCutPrim(bprop_cut);
if (prim->HasAttr("cell_id")) {
auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
if (cell_id != "") {
if (!cell_id.empty()) {
(void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
(void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
}
@ -723,12 +758,11 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
if (i < args_size) {
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(SizeToLong(i - 1))});
din->set_abstract(cnode->input(i)->abstract()->Broaden());
outputs->emplace_back(din);
(void)outputs->emplace_back(din);
(void)abs.emplace_back(din->abstract());
}
}
output->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
return;
}
void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
@ -760,11 +794,7 @@ void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<
}
void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg) {
MS_EXCEPTION_IF_NULL(last_node_);
MS_LOG(DEBUG) << "Last node info " << last_node_->DebugString();
BuildForwardLastNode();
// Add sens parameter
ParameterPtr sens_param = nullptr;
if (has_sens_arg) {
@ -774,8 +804,9 @@ void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has
}
// Update dout of the last node
MS_EXCEPTION_IF_NULL(last_node_);
if (anfnode_to_variable_adjoint_.find(last_node_) != anfnode_to_variable_adjoint_.end()) {
auto variable = anfnode_to_variable_adjoint_.at(last_node_);
const auto &variable = anfnode_to_variable_adjoint_.at(last_node_);
if (has_sens_arg && sens_param != nullptr) {
variable->fn()->UpdateAccumulativeDout(sens_param);
} else {
@ -787,19 +818,19 @@ void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has
need_grad_weights_.clear();
for (const auto &weight : weights) {
TraceGuard trace_guard(std::make_shared<TraceCopy>(weight->debug_info()));
auto p = tape_->add_parameter();
(void)need_grad_weights_.emplace(weight);
auto input_w = weight->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(input_w);
// Use name to match weight parameter in high order
auto default_param = input_w->default_param();
p->set_name(input_w->name());
p->set_default_param(default_param);
p->set_abstract(default_param->ToAbstract()->Broaden());
auto t = pynative::PyNativeAlgo::Common::GetTensorFromParam(weight);
(void)need_grad_weights_.emplace(t->id());
auto p = tape_->add_parameter();
auto param = weight->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
p->set_name(param->name());
p->set_default_param(t);
p->set_abstract(t->ToAbstract()->Broaden());
}
}
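need_grad_weights_ is now keyed by tensor id instead of by AnfNodePtr (see the header change below), and the tape parameter reuses the weight's tensor directly. A minimal sketch, assembled from the two sides of this change, of how the id set built above is consumed later in ReplacePrimalParameter to turn unused weights into value nodes:
// Producer (SetSensAndWeights): record the tensor id of every weight that needs grad.
auto t = pynative::PyNativeAlgo::Common::GetTensorFromParam(weight);
(void)need_grad_weights_.emplace(t->id());
// Consumer (ReplacePrimalParameter): weights used in the graph but not in the set
// are frozen into value nodes instead of being kept as parameters.
for (auto &w : weights_used_in_graph_) {
  auto tensor = pynative::PyNativeAlgo::Common::GetTensorFromParam(w);
  if (need_grad_weights_.find(tensor->id()) == need_grad_weights_.end()) {
    auto value_node = NewValueNode(tensor);
    value_node->set_abstract(tensor->ToAbstract()->Broaden());
    Replace(w, value_node);
  }
}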
OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetLastNodeReverseIter() {
OrderedMap<AnfNodePtr, VariableAdjointPtr>::reverse_iterator AutoGradCellImpl::GetLastNodeReverseIter() {
for (auto iter = anfnode_to_variable_adjoint_.rbegin(); iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
if (!iter->first->isa<CNode>()) {
continue;
@ -815,25 +846,30 @@ OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetL
void AutoGradCellImpl::BackPropagate() {
const auto &last_node_reverse_iter = GetLastNodeReverseIter();
bool has_primc = false;
for (auto iter = last_node_reverse_iter; iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
MS_LOG(DEBUG) << "BackPropagate cnode: " << iter->first->DebugString();
auto variable = iter->second;
const auto &variable = iter->second;
if (!variable->is_need_propagate()) {
MS_LOG(DEBUG) << "No need grad";
continue;
}
if (variable->is_need_propagate() && variable->is_fake_bprop()) {
if (variable->is_fake_bprop()) {
MS_LOG(EXCEPTION) << variable->fake_prim_name() << " op has not corresponding bprop!";
}
auto fn = variable->fn();
if (!has_primc && iter->first->isa<CNode>() && GetCNodePrimitive(iter->first) != nullptr) {
has_primc = true;
}
const auto &fn = variable->fn();
// replace real dout to fake dout
Replace(fn->fake_dout(), fn->RealDout());
// replace edges which exist fake dout
fn->ReplaceEdges();
auto &next_edges = fn->next_edges();
const auto &next_edges = fn->next_edges();
for (const auto &next_edge : next_edges) {
auto node = next_edge.first;
auto din = next_edge.second;
const auto &node = next_edge.first;
const auto &din = next_edge.second;
if (anfnode_to_variable_adjoint_.find(node) == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "current node not find corresponding node";
}
@ -842,6 +878,7 @@ void AutoGradCellImpl::BackPropagate() {
last_variable->set_is_need_propagate(true);
}
}
tape_->set_flag(kPrimCPrimPyMixed, has_primc && need_do_manager_replace_);
}
AnfNodePtr AutoGradCellImpl::GetGradNodeByIndex(const AnfNodePtrList &node_list, size_t index) {
@ -925,34 +962,6 @@ AnfNodePtr AutoGradCellImpl::GetWeightGrad(bool grad_weights, const AnfNodePtrLi
}
}
bool AutoGradCellImpl::IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) const {
if (!inputs_grad->isa<CNode>() || !weights_grad->isa<CNode>()) {
return false;
}
auto inputs_grad_cnode = inputs_grad->cast<CNodePtr>();
auto weights_grad_cnode = weights_grad->cast<CNodePtr>();
if (!IsPrimitiveCNode(inputs_grad_cnode, prim::kPrimMakeTuple) ||
!IsPrimitiveCNode(weights_grad_cnode, prim::kPrimMakeTuple)) {
return false;
}
constexpr int kEmptyTupeSize = 1;
if (inputs_grad_cnode->size() != kEmptyTupeSize || weights_grad_cnode->size() != kEmptyTupeSize) {
return false;
}
return true;
}
AnfNodePtr AutoGradCellImpl::GenerateEmptyTupleValue() {
std::vector<ValuePtr> value_list;
auto inputs_value = std::make_shared<ValueTuple>(value_list);
auto weights_value = std::make_shared<ValueTuple>(value_list);
std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
auto tuple_value = std::make_shared<ValueTuple>(tuple_list);
auto tuple_value_node = NewValueNode(tuple_value);
tuple_value_node->set_abstract(tuple_value->ToAbstract());
return tuple_value_node;
}
void AutoGradCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
const GradAttr &grad_attr) {
auto inputs_grad_ret = GetInputGrad(grad_attr.grad_all_inputs, grad_attr.get_by_position, grad_position);
@ -1017,8 +1026,8 @@ void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new
}
void AutoGradCellImpl::ElimateTupleGetItem() {
for (auto iter = users_.begin(); iter != users_.end(); iter++) {
auto old_node = iter->first;
for (auto &user : users_) {
auto old_node = user.first;
if (!old_node->isa<CNode>()) {
continue;
}
@ -1039,6 +1048,7 @@ void AutoGradCellImpl::ElimateTupleGetItem() {
void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
const auto &parameters = tape_->parameters();
auto cell_inputs_size = cell_inputs_.size();
pynative::PyNativeAlgo::Common::DumpGraphIR("replace_param.ir", tape_);
if (need_do_manager_replace_) {
MS_LOG(DEBUG) << "Do parameter replace by manager";
auto mng = MakeManager({tape_}, false);
@ -1070,13 +1080,12 @@ void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, boo
}
}
for (auto &weight : weights_) {
if (need_grad_weights_.find(weight) == need_grad_weights_.end()) {
auto parameter = weight->cast<ParameterPtr>();
const auto &input_value = parameter->default_param();
MS_EXCEPTION_IF_NULL(input_value);
auto value_node = NewValueNode(input_value);
value_node->set_abstract(input_value->ToAbstract()->Broaden());
for (auto &weight : weights_used_in_graph_) {
auto t = pynative::PyNativeAlgo::Common::GetTensorFromParam(weight);
if (need_grad_weights_.find(t->id()) == need_grad_weights_.end()) {
MS_LOG(DEBUG) << "Convert " << weight->DebugString() << " to value node";
auto value_node = NewValueNode(t);
value_node->set_abstract(t->ToAbstract()->Broaden());
Replace(weight, value_node);
}
}

View File

@ -86,9 +86,9 @@ class FunctionNode {
};
using FunctionNodePtr = std::shared_ptr<FunctionNode>;
class VariableNode {
class VariableAdjoint {
public:
VariableNode(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}
VariableAdjoint(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}
ValuePtr out_value() const { return out_value_; }
FunctionNodePtr fn() const { return fn_; }
@ -114,7 +114,7 @@ class VariableNode {
// K mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
AnfNodePtr k_node_{nullptr};
};
using VariableNodePtr = std::shared_ptr<VariableNode>;
using VariableAdjointPtr = std::shared_ptr<VariableAdjoint>;
class AutoGradCellImpl {
public:
@ -143,10 +143,10 @@ class AutoGradCellImpl {
// Top cell inputs
AnfNodePtrList cell_inputs_;
// These weights need to calculate gradient.
mindspore::HashSet<AnfNodePtr> need_grad_weights_;
mindspore::HashSet<std::string> need_grad_weights_;
// Bprop dins of each variable or middle out
OrderedMap<AnfNodePtr, VariableNodePtr> anfnode_to_variable_adjoint_;
AnfNodePtrList weights_;
OrderedMap<AnfNodePtr, VariableAdjointPtr> anfnode_to_variable_adjoint_;
AnfNodePtrList weights_used_in_graph_;
// Record cnode's input map for tape_
UserType users_;
// Flag for ms_function and high order
@ -156,7 +156,8 @@ class AutoGradCellImpl {
std::vector<bool> GetNeedGradFlags(const CNodePtr &cnode);
// construct input as cnode for expander
CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout);
CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
const VariableAdjointPtr &variable_adjoint);
// Back propagate for one node;
void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
const ValuePtrList &op_args);
@ -176,7 +177,7 @@ class AutoGradCellImpl {
// Set sens and weights parameter nodes by user input info
void SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg);
// get last reverse iterator
OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator GetLastNodeReverseIter();
OrderedMap<AnfNodePtr, VariableAdjointPtr>::reverse_iterator GetLastNodeReverseIter();
void BackPropagate();
// Set return node according to grad flag
@ -184,14 +185,12 @@ class AutoGradCellImpl {
AnfNodePtr GetGradNodeByIndex(const AnfNodePtrList &node_list, size_t index);
AnfNodePtr GetInputGrad(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position);
AnfNodePtr GetWeightGrad(bool grad_weights, const AnfNodePtrList &weights, bool weight_param_is_tuple);
bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) const;
AnfNodePtr GenerateEmptyTupleValue();
void AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
void Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void ElimateTupleGetItem();
// Fbprop
void BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode);
void BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint);
void BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args,
std::vector<AnfNodePtr> *const node_list);
AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input_node);

View File

@ -29,40 +29,32 @@ using mindspore::tensor::Tensor;
namespace mindspore {
namespace parallel {
std::string GetOpPythonPath(const OperatorName &op_name) {
// almost all ops are defined in two main paths
const std::string ops_module = OP_PATH;
const std::string inner_ops_module = INNER_OP_PATH;
const std::string grad_ops_module = GRAD_OP_PATH;
const std::string functional_op_module = FUNCTIONAL_OP_PATH;
py::module mod = py::module::import(common::SafeCStr(ops_module));
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
py::module grad_mod = py::module::import(common::SafeCStr(grad_ops_module));
py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module));
const char *GetOpPythonPath(const char *op_name) {
static py::module inner_mod = py::module::import(INNER_OP_PATH);
if (py::hasattr(inner_mod, op_name)) {
return INNER_OP_PATH;
}
if (py::hasattr(inner_mod, common::SafeCStr(op_name))) {
return inner_ops_module;
static py::module mod = py::module::import(OP_PATH);
if (py::hasattr(mod, op_name)) {
return OP_PATH;
}
if (py::hasattr(mod, common::SafeCStr(op_name))) {
return ops_module;
static py::module grad_mod = py::module::import(GRAD_OP_PATH);
if (py::hasattr(grad_mod, op_name)) {
return GRAD_OP_PATH;
}
if (py::hasattr(grad_mod, common::SafeCStr(op_name))) {
return grad_ops_module;
static py::module functional_mod = py::module::import(FUNCTIONAL_OP_PATH);
if (!py::hasattr(functional_mod, op_name)) {
MS_LOG(EXCEPTION) << OP_PATH << " and " << INNER_OP_PATH << " and " << GRAD_OP_PATH << " and " << FUNCTIONAL_OP_PATH
<< " don't have op:" << op_name;
}
if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) {
MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << grad_ops_module << " and "
<< functional_op_module << " don't have op:" << op_name;
}
return functional_op_module;
return FUNCTIONAL_OP_PATH;
}
ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
std::string op_path = GetOpPythonPath(op_name);
py::module mod = py::module::import(common::SafeCStr(op_path));
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
MS_LOG(ERROR) << "Failure: op_path:" << op_path << " don't have attr " << op_name;
return nullptr;
}
const auto op_path = GetOpPythonPath(op_name.c_str());
std::vector<py::object> arg_list;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
[](const Attr &attr) { return ValueToPyData(attr.second); });
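A minimal usage sketch of the reworked GetOpPythonPath, which now returns the module path as a const char * after probing the modules in the order INNER_OP_PATH, OP_PATH, GRAD_OP_PATH, FUNCTIONAL_OP_PATH (each cached in a static py::module) and throws if none of them defines the op; "SomeOp" is a placeholder name, not an operator from this change:
// Sketch only: "SomeOp" is a hypothetical operator name.
const char *op_path = GetOpPythonPath("SomeOp");
// The caller can then import the resolved module and fetch the operator class.
py::object op_class = py::module::import(op_path).attr("SomeOp");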

View File

@ -31,7 +31,7 @@ namespace mindspore {
namespace parallel {
const char USING_HASH_NAME[] = "USING_HASH_NAME";
// Get the path of the module where the operator has been defined
std::string GetOpPythonPath(const OperatorName &op_name);
const char *GetOpPythonPath(const char *op_name);
// Init python operator Instance
ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);

View File

@ -911,6 +911,7 @@ constexpr auto kFlagPyNativeRunInGraph = "pynative_run_in_graph";
constexpr auto kFlagNeedRenormalize = "need_renormalize";
constexpr auto kFlagEnableZeroCopyInGraph = "enable_zero_copy_in_graph";
constexpr auto kFlagUseDynamicShapeProcess = "use_dynamic_shape_process";
constexpr auto kPrimCPrimPyMixed = "primc_primpy_mixed";
// TODO(dsj): for ms_function running in graph_mode. Should be deleted later
constexpr auto kAttrMSFunction = "ms_function_graph";

View File

@ -88,12 +88,16 @@ struct InputArgsInfo {
bool has_custom_bprop;
size_t input_size;
std::string obj_id;
bool has_sens{false};
bool use_dynamic_shape_process = false;
PrimitivePyPtr custom_bprp_prim{nullptr};
ValuePtr out_value{nullptr};
std::string cell_id;
std::string already_run_cell_id;
std::string input_args_id;
// Cell unique id, cell_id + cell_order;
std::string obj_order_id;
size_t custom_bprop_cell_count = 0;
size_t grad_order = 0;
std::vector<std::string> input_arg_id_vec;

View File

@ -20,6 +20,7 @@
#include "pipeline/pynative/pynative_utils.h"
#include "pipeline/jit/pipeline.h"
#include "ir/cell.h"
#include "ir/func_graph_cloner.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/debug/trace.h"
#include "backend/common/optimizer/helper.h"
@ -71,8 +72,8 @@ std::string GetFnInfoByPyObj(const py::object &obj) {
return (module_name + "_" + fn_name + "_" + filename + "_" + code_lineno);
}
InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::args &args, bool is_grad_top_cell,
bool is_high_order_top_cell) {
InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::args &args, size_t obj_order,
bool is_grad_top_cell, bool is_high_order_top_cell) {
bool has_custom_bprop = py::hasattr(obj, parse::CUSTOM_BPROP_NAME);
std::string obj_id;
if (!py::isinstance<Cell>(obj) && (is_grad_top_cell || is_high_order_top_cell)) {
@ -97,7 +98,9 @@ InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::arg
pipeline::CheckArgsValid(obj, args);
}
input_args_info->cell_id = GetCellId(obj, args, input_args_info);
MS_LOG(DEBUG) << "cell_id is " << obj_id << ", is grad top cell " << (is_grad_top_cell || is_high_order_top_cell);
input_args_info->obj_order_id = input_args_info->cell_id + '_' + std::to_string(obj_order);
MS_LOG(DEBUG) << "Cell_id is " << input_args_info->cell_id << ", is grad top cell "
<< (is_grad_top_cell || is_high_order_top_cell);
return input_args_info;
}
@ -194,8 +197,8 @@ bool IsParamInfoEqual(const ParamInfoPtr &p1, const ParamInfoPtr &p2) {
return p1->key() == p2->key();
}
bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, const std::vector<AnfNodePtr> &new_anf_inputs,
const TopCellInfoPtr &top_cell) {
bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, size_t node_index,
const std::vector<AnfNodePtr> &new_anf_inputs, const TopCellInfoPtr &top_cell) {
MS_EXCEPTION_IF_NULL(old_node_info);
auto old_input_size = old_node_info->input_cnode_info.size() + old_node_info->input_values.size() +
old_node_info->input_param_infos.size();
@ -237,9 +240,9 @@ bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, const s
// Compare cnode edge.
MS_EXCEPTION_IF_NULL(top_cell);
if (old_op_index != top_cell->get_op_index_by_cnode_hash(new_anf_input->hash())) {
if (old_op_index != top_cell->get_op_index_by_cnode_hash(new_anf_input->hash(), node_index)) {
MS_LOG(DEBUG) << "The " << i << "th input, op_index is different, old op_index: " << old_op_index
<< " new op_index: " << top_cell->get_op_index_by_cnode_hash(new_anf_input->hash());
<< " new op_index: " << node_index;
return true;
}
} else {
@ -267,23 +270,11 @@ bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, const s
return false;
}
bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode,
bool is_ms_function_node, const std::string &graph_phase,
const TopCellInfoPtr &top_cell) {
bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, size_t node_index,
const CNodePtr &new_cnode, const TopCellInfoPtr &top_cell) {
MS_EXCEPTION_IF_NULL(old_node_info);
// 1.Detect ms_function phase
if (is_ms_function_node) {
if (!old_node_info->is_graph_node || graph_phase != old_node_info->graph_phase) {
MS_LOG(DEBUG) << "Graph is dynamic, old is_graph_node: " << old_node_info->is_graph_node
<< " new is_graph_node: " << is_ms_function_node << " old graph_phase "
<< old_node_info->graph_phase << " new graph_phase: " << graph_phase;
return true;
}
return false;
}
// 2.Detect cnode prim
MS_EXCEPTION_IF_NULL(new_cnode);
auto new_prim = GetCNodePrimitive(new_cnode);
if (!common::IsEqual(new_prim, old_node_info->prim)) {
MS_LOG(DEBUG) << "Graph is dynamic, old prim: "
@ -299,10 +290,18 @@ bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info
}
// 4.Detect inputs
return IsCnodeInputsDynamic(old_node_info, new_cnode->inputs(), top_cell);
return IsCnodeInputsDynamic(old_node_info, node_index, new_cnode->inputs(), top_cell);
}
FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph) {
FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph, bool need_renormalize) {
MS_LOG(DEBUG) << "Do bporp graph final opt";
MS_EXCEPTION_IF_NULL(bprop_graph);
if (need_renormalize && bprop_graph->has_flag(kPrimCPrimPyMixed)) {
MS_LOG(DEBUG) << "Convert PrimitiveC to PrimitivePy";
if (!opt::ConvertPrimToPrimPy(bprop_graph)) {
MS_LOG(EXCEPTION) << "Convert PrimitiveC to PrimitivePy failed";
}
}
auto resource = std::make_shared<pipeline::Resource>();
resource->set_func_graph(bprop_graph);
auto manager = resource->manager();
@ -349,6 +348,24 @@ void SetGraphInputArgs(const std::vector<ValuePtr> &input_vec, const pipeline::R
}
}
void SetSensValue(const prim::GradOperationPtr &grad, const InputArgsInfoPtr &input_args_info, const py::args &args) {
MS_EXCEPTION_IF_NULL(grad);
if (!grad->sens_param()) {
return;
}
MS_LOG(DEBUG) << "Get sens param";
size_t forward_args_size = args.size() - 1;
auto sens_v = PyNativeAlgo::DataConvert::PyObjToValue(args[forward_args_size]);
const auto &sens_tensor = ConvertOutputValueToTensor(sens_v);
// The sens may already exist and may need to be updated
MS_EXCEPTION_IF_NULL(input_args_info);
if (input_args_info->input_arg_value_vec.size() == args.size()) {
input_args_info->input_arg_value_vec.pop_back();
}
(void)input_args_info->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_tensor));
input_args_info->has_sens = true;
}
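SetSensValue is the sens handling factored out of GradNetInner (the corresponding inline block is removed further down in this diff). A minimal call-site sketch, assuming the last positional arg carries the sens value when grad->sens_param() is true:
// Sketch of the new call site in GradExecutor::GradNetInner, copied from this diff.
SetSensValue(grad, top_input_args_info_, args);
// Afterwards input_arg_value_vec ends with a shallow copy of the converted sens tensor
// and input_args_info->has_sens is true.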
AbstractBasePtr GetGradGraphOutputAbstract(const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(fg->output());
@ -362,13 +379,6 @@ ForwardExecutorPtr GradExecutor::forward() const {
return forward_executor;
}
std::string GradExecutor::GetCurCellOrder() const {
if (cur_cell_id_.empty()) {
MS_LOG(EXCEPTION) << "The cur_cell_id_ is empty!";
}
return cur_cell_id_ + "_" + std::to_string(cell_order_);
}
TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
if (high_order_stack_.empty()) {
MS_LOG(EXCEPTION) << "Stack high_order_stack_ is empty";
@ -383,7 +393,6 @@ TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
void GradExecutor::PushInputArgsInfoStack(const InputArgsInfoPtr &input_args_info) {
input_args_info_stack_.push(input_args_info);
// ++cell_order_;
}
void GradExecutor::PopInputArgsInfoStack() {
@ -468,39 +477,45 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i
}
}
void GradExecutor::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const {
top_cell()->set_need_compile_graph(need_compile_graph);
top_cell()->set_forward_already_run(forward_already_run);
}
void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
const auto input_args_info = GetInputArgsInfo(obj, args);
PushInputArgsInfoStack(input_args_info);
if (input_args_info->has_custom_bprop) {
custom_bprop_cell_count_ += 1;
input_args_info->custom_bprop_cell_count = custom_bprop_cell_count_;
}
if (grad_order_ == 0) {
IncreaseGradOrder();
}
input_args_info->grad_order = grad_order_;
// Maybe this can be made async here
NewGraphImpl(input_args_info);
}
InputArgsInfoPtr GradExecutor::GetInputArgsInfo(const py::object &obj, const py::args &args) {
auto input_args_info =
ParsePyArgsToInputArgsInfo(obj, args, input_args_info_stack_.empty(), is_high_order_top_cell());
ParsePyArgsToInputArgsInfo(obj, args, obj_order_++, input_args_info_stack_.empty(), is_high_order_top_cell());
if (input_args_info->has_custom_bprop) {
custom_bprop_cell_count_ += 1;
input_args_info->custom_bprop_cell_count = custom_bprop_cell_count_;
}
// CheckAlreadyRun runs first and grad_order_ is increased by 1 (high-order scenario).
// If NetA.set_grad() is set, we come here first and CheckAlreadyRun runs later, so grad_order_ needs to be increased by 1.
if (input_args_info->is_grad_topest_cell || input_args_info->is_high_order_top_cell) {
if (grad_order_ == 0) {
IncreaseGradOrder();
}
// Both set grad: NetA.set_grad(); NetB.set_grad();
// Run forward: NetA(); NetB(); then the grad order is 2.
// Grad(NetA()); Grad(NetB()). NetA's grad order is now 2. Forward runs again, so grad_order is disordered and needs to be reset.
if (input_args_info->is_grad_topest_cell && input_args_info->grad_order > 1) {
input_args_info->grad_order--;
}
input_args_info->already_run_cell_id = GetAlreadyRunCellId(input_args_info->obj_id);
}
input_args_info->grad_order = grad_order_;
input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_;
// top_input_args_info_ indicate current running cell info
top_input_args_info_ = input_args_info;
return input_args_info;
}
void GradExecutor::NewGraphImpl(const InputArgsInfoPtr &input_args_info) {
MS_EXCEPTION_IF_NULL(input_args_info);
++cell_order_;
const auto &cell_id = input_args_info->cell_id;
cur_cell_id_ = cell_id;
MS_LOG(DEBUG) << "NewGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
<< ", input args info ptr " << input_args_info.get();
// Make top graph and init resource
@ -514,34 +529,21 @@ void GradExecutor::AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info) {
}
void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
MS_EXCEPTION_IF_NULL(input_args_info);
// CheckAlready run first, grad_order_ will increase 1(highorder scenario)
// If NetA.set_grad(), so come here first, CheckAlready run later, so grad_order_ need increase 1
if (input_args_info->grad_order == 0) {
input_args_info->grad_order++;
}
// Both set grad: NetA.set_grad(); NetB.set_grad();
// Run forward: NetA(); NetB();
// Grad(NetA()); Grad(NetB()). grad_order_ is disordered, so need reset.
if (input_args_info->is_grad_topest_cell && input_args_info->grad_order > 1) {
input_args_info->grad_order--;
}
auto fg = std::make_shared<FuncGraph>();
fg->debug_info()->set_name("pynative_forward_graph");
auto resource = std::make_shared<pipeline::Resource>();
const auto &already_run_cell_id = GetAlreadyRunCellId(input_args_info->obj_id);
top_cell_ =
std::make_shared<TopCellInfo>(input_args_info->is_high_order_top_cell, input_args_info->grad_order,
input_args_info->obj_id, input_args_info->cell_id, already_run_cell_id, resource, fg);
MS_EXCEPTION_IF_NULL(input_args_info);
const auto &obj_id_with_grad_order = input_args_info->obj_id + "_" + std::to_string(input_args_info->grad_order);
top_cell_ = std::make_shared<TopCellInfo>(input_args_info->is_high_order_top_cell, input_args_info->grad_order,
obj_id_with_grad_order, input_args_info->cell_id,
input_args_info->already_run_cell_id, resource, fg);
top_cell_->set_forward_already_run(true);
top_cell_->set_input_args_id(input_args_info->input_args_id);
PushHighOrderGraphStack(top_cell_);
(void)top_cell_list_.emplace_back(top_cell_);
const auto &cell_id = input_args_info->obj_id.append("_").append(std::to_string(grad_order_));
is_cell_id_in_dynamic_detect_nodes_map_ =
(cell_id_with_dynamic_detect_nodes_.find(cell_id) != cell_id_with_dynamic_detect_nodes_.end());
(cell_id_with_dynamic_detect_nodes_.find(obj_id_with_grad_order) != cell_id_with_dynamic_detect_nodes_.end());
MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << " resource ptr " << resource.get();
}
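obj_id_with_grad_order replaces the old c_cell_id()-based key throughout this change: it identifies the TopCellInfo, indexes cell_id_with_dynamic_detect_nodes_, and is what SetGradOrder matches against. A minimal sketch of the key construction and its uses, assembled from the fragments of this diff:
// Key construction (from MakeNewTopGraph above): object id plus current grad order.
const auto &obj_id_with_grad_order =
  input_args_info->obj_id + "_" + std::to_string(input_args_info->grad_order);
// The same key looks up previously saved dynamic-detect nodes ...
is_cell_id_in_dynamic_detect_nodes_map_ =
  (cell_id_with_dynamic_detect_nodes_.find(obj_id_with_grad_order) != cell_id_with_dynamic_detect_nodes_.end());
// ... and decides whether a grad call targets the current top cell (SetGradOrder).
if (top_cell_ == nullptr || top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos) {
  IncreaseGradOrder();
}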
@ -578,25 +580,34 @@ void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, c
}
const auto input_args_info = input_args_info_stack_.top();
MS_EXCEPTION_IF_NULL(input_args_info);
if (input_args_info->has_custom_bprop) {
GetCustomBpropPrim(obj, args, out, input_args_info);
}
input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out);
input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_;
UpdateInputArgsInfo(input_args_info, obj, out, args);
PopInputArgsInfoStack();
if (input_args_info->is_grad_topest_cell) {
set_grad_flag(false);
}
// Maybe this can be made async here
EndGraphImpl(input_args_info);
}
void GradExecutor::UpdateInputArgsInfo(const InputArgsInfoPtr &input_args_info, const py::object &obj,
const py::object &out, const py::args &args) {
MS_EXCEPTION_IF_NULL(input_args_info);
if (input_args_info->has_custom_bprop) {
GetCustomBpropPrim(obj, args, out, input_args_info);
}
// Used on the master thread, so change it on the master thread
if (input_args_info->is_grad_topest_cell) {
grad_flag_ = false;
obj_order_ = 0;
}
input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out);
// If use_dynamic_shape_process_ is updated while running ops, the change can be sensed here immediately
input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_;
}
void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
MS_EXCEPTION_IF_NULL(input_args_info);
const auto &cell_id = input_args_info->cell_id;
MS_LOG(DEBUG) << "EndGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
MS_LOG(DEBUG) << "EndGraphInner start " << input_args_info->input_size << ", cell_id " << input_args_info->cell_id
<< ", input args info ptr " << input_args_info.get();
bool is_top_cell_end = (cell_id == top_cell()->cell_id());
bool is_top_cell_end = (input_args_info->cell_id == top_cell()->cell_id());
if (is_top_cell_end) {
input_args_info->out_value = ConvertOutputValueToTensor(input_args_info->out_value);
}
@ -619,14 +630,12 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
}
// Reset grad flag and update output node of the outermost cell
if (input_args_info->is_grad_topest_cell && is_top_cell_end) {
MS_LOG(DEBUG) << "Cur top last cell " << cell_id;
MS_LOG(DEBUG) << "Cur top last cell " << input_args_info->cell_id;
(void)PopHighOrderGraphStack();
SetForwardLastNodeInfo(input_args_info->out_value, out_id);
top_cell()->ClearCellHookOp();
cell_order_ = 0;
// set_grad_flag(false);
}
// Checkout whether need to compile graph when each top cell has ran finished
// Check whether the graph needs to be compiled when each top cell has finished running
if (is_top_cell_end) {
// In high grad cases, the output of the internal graph may be a tuple, and node needs to be created in the getobj
if (!input_args_info->is_grad_topest_cell) {
@ -731,7 +740,6 @@ void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info
if (already_run_top_cell_.find(already_top_cell_id) == already_run_top_cell_.end()) {
MS_LOG(DEBUG) << "Cell " << already_top_cell_id << " has never been ran, need compile graph";
already_run_top_cell_[already_top_cell_id] = new_top_cell;
pre_top_cell_ = top_cell();
return;
}
@ -742,30 +750,27 @@ void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info
if (input_args_info->use_dynamic_shape_process) {
// The function needs to be compiled every time.
MS_LOG(DEBUG) << "The graph is dynamic, need to compile graph again";
EraseTopCellFromTopCellList(pre_top_cell);
{
py::gil_scoped_acquire acquire;
pre_top_cell->Clear();
EraseTopCellFromTopCellList(pre_top_cell);
}
already_run_top_cell_[already_top_cell_id] = new_top_cell;
pre_top_cell_ = nullptr;
} else {
MS_LOG(DEBUG) << "no need to compile graph again";
MS_LOG(DEBUG) << "No need to compile graph again";
pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
// In high order situations, the internal top cell remains unchanged, but the external top cell has changed. Then
// the graph info of the internal top cell needs to be updated so that the external top cell can perceive it.
if (!input_args_info->is_grad_topest_cell) {
pre_top_cell->SetGraphInfoMap(pre_top_cell->fg(), new_top_cell->graph_info_map().at(new_top_cell->fg()));
}
pre_top_cell_ = pre_top_cell;
pre_top_cell->set_forward_already_run(true);
}
}
void GradExecutor::EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell) {
MS_EXCEPTION_IF_NULL(top_cell);
auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
[&](const TopCellInfoPtr &elem) { return elem.get() == top_cell.get(); });
const auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
[&](const TopCellInfoPtr &elem) { return elem.get() == top_cell.get(); });
if (iter == top_cell_list_.end()) {
MS_LOG(WARNING) << "Can not find top cell " << top_cell.get() << " cell id " << top_cell->cell_id()
<< " from top cell list";
@ -783,27 +788,14 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
MS_EXCEPTION_IF_NULL(top_input_args_info_);
MS_LOG(DEBUG) << "GradNetInner start " << args.size() << ", cell_id " << top_input_args_info_->cell_id
<< ", input args info ptr " << top_input_args_info_.get();
MS_EXCEPTION_IF_NULL(grad);
if (grad->sens_param()) {
MS_LOG(DEBUG) << "Get sens param";
size_t forward_args_size = args.size() - 1;
auto sens_v = PyNativeAlgo::DataConvert::PyObjToValue(args[forward_args_size]);
const auto &sens_tensor = ConvertOutputValueToTensor(sens_v);
// Sens have already exist, which may be need update
if (top_input_args_info_->input_arg_value_vec.size() == args.size()) {
top_input_args_info_->input_arg_value_vec.pop_back();
}
(void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_tensor));
top_input_args_info_->has_sens = true;
}
// For async, top can not be change when run SetForwardLastNodeInfo
if (pre_top_cell_ != nullptr) {
set_top_cell(pre_top_cell_);
}
SetSensValue(grad, top_input_args_info_, args);
// For async, the top cell can not be changed while SetForwardLastNodeInfo is running; change the top cell after sync
set_top_cell(already_run_top_cell_.at(top_cell()->already_run_cell_id()));
if (!top_cell()->need_compile_graph()) {
MS_LOG(DEBUG) << "No need compile graph";
top_cell_list_.pop_back();
UpdateTopCellInfo(false, false);
top_cell()->UpdateTopCellInfo(false, false, false);
return;
}
MS_LOG(DEBUG) << "Need compile graph";
@ -829,12 +821,6 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector
const std::vector<size_t> &p_args) {
// Get bprop graph of top cell
auto bprop_graph = GetBpropGraph(grad_attr, w_args, p_args);
MS_EXCEPTION_IF_NULL(bprop_graph);
bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
bool use_dynamic_shape_process = !(forward()->device_target() == kAscendDevice) && use_dynamic_shape_process_;
bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
MS_EXCEPTION_IF_NULL(top_input_args_info_);
bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
auto resource = top_cell()->resource();
MS_EXCEPTION_IF_NULL(resource);
resource->set_func_graph(bprop_graph);
@ -848,7 +834,7 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector
(void)TaskEmitAction(resource);
MS_LOG(DEBUG) << "Start execute action";
(void)ExecuteAction(resource);
UpdateTopCellInfo(false, false);
top_cell()->UpdateTopCellInfo(false, false, true);
resource->Clean();
}
@ -870,7 +856,7 @@ std::vector<AnfNodePtr> GradExecutor::GetWeightsArgs(const py::object &weights,
(void)w_args.emplace_back(fn(weights_tuple[i]));
}
} else {
MS_LOG(DEBUG) << "No parameter tuple get, add weights params by input weight";
MS_LOG(DEBUG) << "No parameter tuple get, try get weights params by input weight";
if (py::isinstance<py::tuple>(weights) || py::isinstance<py::list>(weights)) {
auto weights_tuple = py::cast<py::tuple>(weights);
for (size_t i = 0; i < weights_tuple.size(); ++i) {
@ -894,6 +880,10 @@ std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_pos
const auto &tuple = grad_position.cast<py::tuple>();
(void)std::transform(tuple.begin(), tuple.end(), std::back_inserter(pos_args),
[](const py::handle &elem) { return elem.cast<int64_t>(); });
// Return the gradient;
if (pos_args.empty()) {
MS_LOG(EXCEPTION) << "grad_position should not be empty when grad by position!";
}
return pos_args;
}
MS_LOG(EXCEPTION) << "Grad position only support tuple when grad_by_position is set True.";
@ -982,30 +972,33 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
const vector<size_t> &p_args) {
MS_EXCEPTION_IF_NULL(top_input_args_info_);
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(auto_grad_cell_ptr, w_args, p_args, grad_attr);
MS_EXCEPTION_IF_NULL(bprop_graph);
MS_LOG(DEBUG) << "Top graph input params size " << top_input_args_info_->input_arg_value_vec.size();
std::ostringstream ss;
ss << "grad{" << top_input_args_info_->input_arg_value_vec.size() << "}";
bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
MS_EXCEPTION_IF_NULL(bprop_graph);
bprop_graph->debug_info()->set_name(ss.str());
UpdateParamAbsByArgs(top_input_args_info_->input_arg_value_vec, bprop_graph, grad_attr.has_sens);
if (top_cell()->need_do_final_opt()) {
bprop_graph = BpropGraphFinalOpt(bprop_graph);
}
if (top_input_args_info_->is_grad_topest_cell) {
need_renormalize_ = false;
bprop_graph = BpropGraphFinalOpt(bprop_graph, need_renormalize_);
MS_EXCEPTION_IF_NULL(bprop_graph);
}
need_renormalize_ = false;
bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
bool use_dynamic_shape_process = !(forward()->device_target() == kAscendDevice) && use_dynamic_shape_process_;
bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
return bprop_graph;
}
void GradExecutor::SetGradOrder(const std::string &cell_id) {
void GradExecutor::SetGradOrder(const std::string &obj_id) {
// top_cell_ == nullptr means call by grad first
// Args of CheckAlreadyRun may be have sens arg, so cell_id is include top cell id,
// If cell_id.find(top_cell_->cell_id()) == std::string::npos, means current cell is not top cell, may be high order
if (top_cell_ == nullptr || cell_id.find(top_cell_->c_cell_id()) == std::string::npos) {
// top_cell_->obj_id_with_grad_order() includes obj_id and grad_order.
// If top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos, the current cell is not the top cell,
// so high-order grad comes in
if (top_cell_ == nullptr || top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos) {
IncreaseGradOrder();
}
if (!grad_is_running_) {
@ -1016,10 +1009,10 @@ void GradExecutor::SetGradOrder(const std::string &cell_id) {
py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj,
const py::object &grad_hash_id, const py::args &args) {
auto cell_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
// Check current cell grad order and erase it if in current top cell list
SetGradOrder(cell_id);
SetGradOrder(obj_id);
// Include weight param size and required grad flag
std::string grad_hash_id_str;
if (!py::isinstance<py::none>(grad_hash_id)) {
@ -1036,7 +1029,7 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
// check whether need to run forward process
bool forward_run = false;
if (input_args_info_stack_.empty() && top_cell_ != nullptr) {
const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
const auto &check_already_run_cell_id = GetAlreadyRunCellId(obj_id);
auto find_top_cell = GetTopCell(check_already_run_cell_id);
if (find_top_cell != nullptr) {
MS_LOG(DEBUG) << "Find already run top cell";
@ -1048,7 +1041,7 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
}
}
}
MS_LOG(DEBUG) << "Graph have already ran " << forward_run << " top cell id " << cell_id;
MS_LOG(DEBUG) << "Graph have already ran " << forward_run << " top cell id " << obj_id;
return BaseRefToPyData(forward_run);
}
@ -1087,12 +1080,20 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
ClearGradRes();
return;
}
FuncGraphPtr first_grad_fg = cur_run_bprop_graph;
// High grad hit cache
if (!top_cell()->vm_compile()) {
SwitchTopCell();
return;
}
auto first_grad_fg = cur_run_bprop_graph;
if (has_custom_bprop) {
first_grad_fg = curr_g();
MS_LOG(DEBUG) << "Bprop nested";
}
MS_EXCEPTION_IF_NULL(first_grad_fg);
// Because ConvertPrimToPrimPy will change first_grad_fg, when the bprop graph cache is hit,
// resource->func_graph() will be changed and its abstract may be nullptr.
first_grad_fg = BasicClone(first_grad_fg);
std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
ValuePtrList weights_args;
DoParameterReplace(first_grad_fg, forward_args, &inputs, &weights_args);
@ -1162,18 +1163,18 @@ void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const s
// Replace weights param
MS_EXCEPTION_IF_NULL(weights_args);
mindspore::HashSet<std::string> graph_weights_set;
// Weight in graph
mindspore::HashSet<std::string> inner_graph_used_weights_set;
// Weight in inner graph
const auto &fir_graph_parameters = first_grad_fg->parameters();
for (const auto &param : fir_graph_parameters) {
auto weight_tensor = PyNativeAlgo::Common::GetTensorFromParam(param);
if (weight_tensor != nullptr) {
(void)graph_weights_set.emplace(weight_tensor->id());
(void)inner_graph_used_weights_set.emplace(weight_tensor->id());
}
}
for (const auto &weight : inner_graph_info->weight_params) {
// If weight used in graph, but not need get grad by gradnet, so will not process in outer graph
if (graph_weights_set.find(weight.first) == graph_weights_set.end()) {
// If a weight is used in the graph but its grad is not required by the grad net, it will be a value node, so no replacement is needed
if (inner_graph_used_weights_set.find(weight.first) == inner_graph_used_weights_set.end()) {
continue;
}
const auto it = outer_graph_info->weight_params.find(weight.first);
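
For readability, a self-contained sketch of the weight filtering above; the weight ids, names, and container layout are invented for the example.

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

int main() {
  // Hypothetical ids of the weights that actually appear as parameters of the inner graph.
  const std::unordered_set<std::string> inner_graph_used_weights{"w1", "w3"};
  // Hypothetical weight id -> weight name map recorded in the inner graph info.
  const std::unordered_map<std::string, std::string> inner_weight_params{
      {"w1", "conv.weight"}, {"w2", "bn.gamma"}, {"w3", "fc.bias"}};

  // Mirrors the loop above: a weight the inner graph does not use as a parameter (here "w2")
  // stays a value node and is skipped; the others would be replaced by the outer graph's parameter.
  for (const auto &weight : inner_weight_params) {
    if (inner_graph_used_weights.find(weight.first) == inner_graph_used_weights.end()) {
      continue;
    }
    std::cout << "replace weight " << weight.second << " (id " << weight.first << ")\n";
  }
  return 0;
}
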
@ -1351,8 +1352,6 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
auto cnode = curr_g()->NewCNode(inputs);
MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false);
CheckGraphDynamic(cnode, top_cell()->op_index());
top_cell()->IncreaseOpIndex();
return cnode;
}
@ -1505,7 +1504,6 @@ void GradExecutor::AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_
void GradExecutor::AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &cloned_value) const {
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
const auto fn = [auto_grad_cell_ptr, output_node, cloned_value]() {
auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value);
};
@ -1722,7 +1720,7 @@ CNodePtr GradExecutor::ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_
}
const auto &cnode = curr_g()->NewCNodeInOrder(inputs);
if (IsPrimitiveCNode(cnode, prim::kPrimCellBackwardHook)) {
top_cell()->RecordCellBackwardHookOp(GetCurCellOrder(), cnode);
top_cell()->RecordCellBackwardHookOp(top_input_args_info()->obj_order_id, cnode);
}
MS_LOG(DEBUG) << "Make CNode for " << op_run_info->base_op_run_info.op_name << ", new cnode is "
@ -1745,7 +1743,7 @@ void GradExecutor::SetBpropGraphJitLevel(const py::object &obj) const {
graph_executor->SetJitConfig(jit_config_dict);
}
void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx,
void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t node_idx,
bool is_ms_function_node,
const std::string &graph_phase) const {
MS_EXCEPTION_IF_NULL(cnode);
@ -1760,7 +1758,7 @@ void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, c
node_info->input_values[i] = GetValueNode(input_node);
} else if (input_node->isa<CNode>()) {
const auto &node_abs = input_node->abstract();
auto op_index = top_cell()->get_op_index_by_cnode_hash(input_node->hash());
auto op_index = top_cell()->get_op_index_by_cnode_hash(input_node->hash(), node_idx);
node_info->input_cnode_info[i] = std::make_pair(op_index, node_abs);
} else {
if (!input_node->isa<Parameter>()) {
@ -1778,11 +1776,12 @@ void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, c
node_info->graph_phase = graph_phase;
}
top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx);
const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order());
(void)cell_id_with_dynamic_detect_nodes_[cell_id].emplace_back(node_info);
(void)cell_id_with_dynamic_detect_nodes_[top_cell()->obj_id_with_grad_order()].emplace_back(node_info);
MS_LOG(DEBUG) << "Save node " << cnode->DebugString() << " firstly, node_idx: " << node_idx
<< ", is_ms_function_node: " << is_ms_function_node << ", graph_phase:" << graph_phase;
}
bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const {
MS_EXCEPTION_IF_NULL(cnode);
if (!is_cell_id_in_dynamic_detect_nodes_map_) {
@ -1791,15 +1790,28 @@ bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx,
return false;
}
const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order());
const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[cell_id];
MS_LOG(DEBUG) << "Check node " << cnode->DebugString() << " node_idx: " << node_idx
<< ", is_ms_function_node: " << is_ms_function_node << ", graph_phase:" << graph_phase;
const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[top_cell()->obj_id_with_grad_order()];
if (node_idx >= dynamic_nodes.size()) {
MS_LOG(DEBUG) << "Old dynamic_nodes size: " << dynamic_nodes.size() << " cur node_idx is: " << node_idx
MS_LOG(DEBUG) << "Old dynamic_nodes size: " << dynamic_nodes.size() << ", cur node_idx is: " << node_idx
<< ", graph is dynamic.";
return true;
}
if (IsDynamicDetectNodeInfoChange(dynamic_nodes[node_idx], cnode, is_ms_function_node, graph_phase, top_cell())) {
// 1.Detect ms_function phase
const DynamicDetectNodeInfoPtr &old_node_info = dynamic_nodes[node_idx];
if (is_ms_function_node) {
if (!old_node_info->is_graph_node || graph_phase != old_node_info->graph_phase) {
MS_LOG(DEBUG) << "Graph is dynamic, old is_graph_node: " << old_node_info->is_graph_node
<< " new is_graph_node: " << is_ms_function_node << " old graph_phase "
<< old_node_info->graph_phase << " new graph_phase: " << graph_phase;
return true;
}
return false;
}
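
A minimal sketch of the ms_function branch above, with FakeNodeInfo standing in for the saved node info (hypothetical names, not the real node-info type): the structure counts as dynamic when the node was not an ms_function graph node before, or its compile phase changed.

#include <iostream>
#include <string>

// Hypothetical, trimmed-down record of what was saved for an ms_function node
// the first time the graph ran.
struct FakeNodeInfo {
  bool is_graph_node;
  std::string graph_phase;
};

// The graph is considered dynamic when the node was not an ms_function graph node before,
// or when its compile phase differs from the recorded one.
bool MsFunctionNodeChanged(const FakeNodeInfo &old_info, const std::string &new_graph_phase) {
  return !old_info.is_graph_node || new_graph_phase != old_info.graph_phase;
}

int main() {
  FakeNodeInfo old_info{true, "phase_0"};
  std::cout << MsFunctionNodeChanged(old_info, "phase_0") << "\n";  // 0: same graph, static
  std::cout << MsFunctionNodeChanged(old_info, "phase_1") << "\n";  // 1: phase changed, dynamic
  return 0;
}
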
if (IsDynamicDetectNodeInfoChange(old_node_info, node_idx, cnode, top_cell())) {
MS_LOG(DEBUG) << "Graph is dynamic, node_idx: " << node_idx
<< " is different, cnode: " << cnode->fullname_with_scope();
return true;
@ -1808,7 +1820,7 @@ bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx,
return false;
}
void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const {
if (use_dynamic_shape_process_) {
return;
@ -1816,9 +1828,7 @@ void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_i
use_dynamic_shape_process_ = IsGraphDynamic(cnode, node_idx, is_ms_function_node, graph_phase);
if (use_dynamic_shape_process_) {
MS_LOG(DEBUG) << "Cnode: " << cnode->fullname_with_scope() << ", node_idx: " << node_idx
<< ", is_ms_function_node: " << is_ms_function_node << ", graph_phase:" << graph_phase
<< ", use_dynamic_shape_process_: " << use_dynamic_shape_process_;
MS_LOG(DEBUG) << "Set use_dynamic_shape_process_: " << use_dynamic_shape_process_;
cell_id_with_dynamic_detect_nodes_.clear();
}
}

View File

@ -82,6 +82,10 @@ class GradExecutor {
inline void set_use_dynamic_shape_process(bool use_dynamic_shape_process) {
use_dynamic_shape_process_ = use_dynamic_shape_process;
}
inline InputArgsInfoPtr top_input_args_info() const {
MS_EXCEPTION_IF_NULL(top_input_args_info_);
return top_input_args_info_;
}
inline bool need_renormalize() const { return need_renormalize_; }
inline void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
@ -110,15 +114,14 @@ class GradExecutor {
const std::vector<tensor::TensorPtr> &pre_tensors) const;
void ClearRes();
void WorkerJoin() { async_executor_->WorkerJoin(); }
void CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node = false,
void CheckGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node = false,
const std::string &graph_phase = "") const;
private:
ForwardExecutorPtr forward() const;
inline FuncGraphPtr curr_g() const { return top_cell()->fg(); }
inline void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell) { high_order_stack_.push(top_cell); }
std::string GetCurCellOrder() const;
void SetGradOrder(const std::string &cell_id);
void SetGradOrder(const std::string &obj_id);
void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
const CNodePtr &cnode) const;
void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode, const ValuePtr &op_out) const;
@ -154,7 +157,6 @@ class GradExecutor {
void HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info, bool is_bprop_top) const;
void InitResourceAndDfBuilder(const InputArgsInfoPtr &cell_info);
void MakeNewTopGraph(const InputArgsInfoPtr &input_args_info);
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const;
// Manage resource when run grad process.
bool IsBpropGraph(const std::string &cell_id) const;
@ -163,6 +165,8 @@ class GradExecutor {
void NewGraphImpl(const InputArgsInfoPtr &input_args_info);
void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info);
void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
void UpdateInputArgsInfo(const InputArgsInfoPtr &input_args_info, const py::object &obj, const py::object &out,
const py::args &args);
void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
void AsyncEndGraphImpl(const InputArgsInfoPtr &input_args_info);
void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const;
@ -188,9 +192,9 @@ class GradExecutor {
AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;
void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const;
bool IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
bool IsGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const;
bool grad_flag_{false};
@ -200,16 +204,11 @@ class GradExecutor {
mutable bool use_dynamic_shape_process_{false};
mutable bool is_cell_id_in_dynamic_detect_nodes_map_{false};
int custom_bprop_cell_count_{0};
// Used in sub thread
size_t cell_order_{0};
std::string cur_cell_id_;
size_t obj_order_{0};
// If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ...
size_t grad_order_{0};
std::string grad_operation_;
TopCellInfoPtr top_cell_{nullptr};
TopCellInfoPtr pre_top_cell_{nullptr};
InputArgsInfoPtr top_input_args_info_{nullptr};
// Records every cell's info for sharing, regardless of whether a grad graph needs to be constructed
std::stack<InputArgsInfoPtr> input_args_info_stack_;

View File

@ -230,11 +230,11 @@ void MsFunction::GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const G
} else {
top_cell->fg()->add_parameter(param);
param->debug_info()->set_name(param->name());
top_cell->SetParamNodeMapInGraphInfoMap(tensor_value->id(), param, true);
}
(void)new_params.emplace_back(param);
(void)input_nodes->emplace_back(param);
(void)op_run_info->input_value.emplace_back(tensor_value);
top_cell->SetParamNodeMapInGraphInfoMap(tensor_value->id(), param, true);
MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
<< tensor_value->ToString() << ". Its name is: " << param->name();
}

View File

@ -90,6 +90,12 @@ void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info) {
++op_index_;
}
void TopCellInfo::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile) {
need_compile_graph_ = need_compile_graph;
forward_already_run_ = forward_already_run;
vm_compile_ = vm_compile;
}
void TopCellInfo::ClearDeviceMemory() const {
MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
auto ms_context = MsContext::GetInstance();
@ -132,6 +138,7 @@ void TopCellInfo::Clear() {
is_init_kpynative_ = false;
need_compile_graph_ = false;
forward_already_run_ = false;
vm_compile_ = false;
op_index_ = 0;
resource_ = nullptr;
fg_ = nullptr;

View File

@ -58,11 +58,11 @@ using GraphInfoPtr = std::shared_ptr<GraphInfo>;
class TopCellInfo {
public:
~TopCellInfo() = default;
TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string c_cell_id, std::string cellid,
TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string obj_id_with_grad_order, std::string cellid,
std::string already_run_cell_id, pipeline::ResourcePtr r, FuncGraphPtr fg)
: is_high_order_top_cell_(is_high_order_top_cell),
grad_order_(grad_order),
c_cell_id_(std::move(c_cell_id)),
obj_id_with_grad_order_(std::move(obj_id_with_grad_order)),
cell_id_(std::move(cellid)),
already_run_cell_id_(std::move(already_run_cell_id)),
resource_(std::move(r)),
@ -81,6 +81,7 @@ class TopCellInfo {
inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
inline bool need_compile_graph() const { return need_compile_graph_; }
inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
inline bool vm_compile() const { return vm_compile_; }
inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
inline bool need_do_final_opt() const { return need_do_final_opt_; }
@ -91,7 +92,7 @@ class TopCellInfo {
}
inline void set_fg(const FuncGraphPtr &fg) { fg_ = fg; }
inline const std::string &cell_id() const { return cell_id_; }
inline const std::string &c_cell_id() const { return c_cell_id_; }
inline const std::string &obj_id_with_grad_order() const { return obj_id_with_grad_order_; }
inline const std::string &already_run_cell_id() const { return already_run_cell_id_; }
inline void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
inline const std::string &input_args_id() const { return input_args_id_; }
@ -124,10 +125,11 @@ class TopCellInfo {
inline void set_cnode_hash_with_op_index(const size_t &node_hash, const size_t &op_index) {
cnode_hash_with_op_index_[node_hash] = op_index;
}
inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) const {
inline size_t get_op_index_by_cnode_hash(const size_t node_hash, const size_t node_idx) const {
const auto iter = cnode_hash_with_op_index_.find(node_hash);
if (iter == cnode_hash_with_op_index_.end()) {
MS_LOG(EXCEPTION) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
MS_LOG(DEBUG) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
return node_idx;
}
return iter->second;
}
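
A standalone sketch of the lookup above, assuming only that cnode_hash_with_op_index_ behaves like a std::unordered_map; HashToOpIndex and GetOpIndexByHash are illustrative names.

#include <cstddef>
#include <iostream>
#include <unordered_map>

// Simplified stand-in for cnode_hash_with_op_index_: node hash -> op index recorded
// when the graph was built for the first time.
using HashToOpIndex = std::unordered_map<std::size_t, std::size_t>;

// Mirrors the behaviour above: on a miss, fall back to the caller-supplied node_idx
// instead of throwing, so an unseen node no longer aborts the dynamic-graph check.
std::size_t GetOpIndexByHash(const HashToOpIndex &map, std::size_t node_hash, std::size_t node_idx) {
  const auto iter = map.find(node_hash);
  if (iter == map.end()) {
    return node_idx;
  }
  return iter->second;
}

int main() {
  const HashToOpIndex recorded{{0x1234, 7}};
  std::cout << GetOpIndexByHash(recorded, 0x1234, 3) << "\n";  // 7: hash was recorded earlier
  std::cout << GetOpIndexByHash(recorded, 0x9999, 3) << "\n";  // 3: miss, fall back to node_idx
  return 0;
}
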
@ -136,6 +138,7 @@ class TopCellInfo {
void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const;
void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
bool need_save_sub_id = true) const;
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile);
void ClearDeviceMemory() const;
void Clear();
@ -150,11 +153,13 @@ class TopCellInfo {
bool is_init_kpynative_{false};
bool forward_already_run_{false};
bool need_compile_graph_{false};
size_t op_index_{0};
bool vm_compile_{false};
bool is_run_cell_{false};
bool is_high_order_top_cell_{false};
bool need_do_final_opt_{false};
size_t grad_order_{0};
std::string c_cell_id_;
size_t op_index_{0};
std::string obj_id_with_grad_order_;
std::string cell_id_;
std::string already_run_cell_id_;
std::string input_args_id_;

View File

@ -29,6 +29,18 @@
namespace mindspore {
namespace pynative {
namespace PyNativeAlgo {
namespace {
void ClonePrim(const FrontendOpRunInfoPtr &op_run_info) {
// Clone a new prim
MS_EXCEPTION_IF_NULL(op_run_info);
op_run_info->op_prim = std::make_shared<PrimitivePy>(*(op_run_info->op_prim));
MS_EXCEPTION_IF_NULL(op_run_info->op_prim->adapter());
if (op_run_info->op_prim->adapter()->attached_primitive() == nullptr) {
op_run_info->op_prim->adapter()->set_attached_primitive(op_run_info->op_prim);
}
}
} // namespace
std::string Common::GetIdByValue(const ValuePtr &v) {
MS_EXCEPTION_IF_NULL(v);
if (v->isa<tensor::Tensor>()) {
@ -517,12 +529,8 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
bool need_convert_input_to_attr = NeedConvertConstInputToAttr(op_run_info, device_target, &input_to_attr);
MS_LOG(DEBUG) << "Need convert input to addr " << need_convert_input_to_attr;
if (need_convert_input_to_attr) {
// Clone a new prim
op_run_info->op_prim = std::make_shared<PrimitivePy>(*(op_run_info->op_prim));
MS_EXCEPTION_IF_NULL(op_run_info->op_prim->adapter());
if (op_run_info->op_prim->adapter()->attached_primitive() == nullptr) {
op_run_info->op_prim->adapter()->set_attached_primitive(op_run_info->op_prim);
}
// The prim's attrs may be changed below, so work on a clone
ClonePrim(op_run_info);
}
const auto &op_prim = op_run_info->op_prim;
@ -538,10 +546,17 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
// Mark tensors, common tensor data : 0, weight param: 1, valuenode(float_, int_): 2
ConvertValueToTensor(op_run_info, input_object, index, op_prim);
// -1 indicates input_object is not a dynInput
if (op_prim->HasAttr(kAttrDynInputSizes) && !input_object->isa<ValueSequence>()) {
auto dyn_v = GetValue<const std::vector<int64_t>>(op_prim->GetAttr(kAttrDynInputSizes));
(void)dyn_v.emplace_back(-1);
op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
if (op_prim->HasAttr(kAttrDynInputSizes)) {
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
// For ops like AddN the prim is defined in Python and the number of inputs can vary, so the value of
// kAttrDynInputSizes changes too. In async mode, DoOpGrad may not have completed yet, so clone before changing it.
ClonePrim(op_run_info);
}
if (!input_object->isa<ValueSequence>()) {
auto dyn_v = GetValue<const std::vector<int64_t>>(op_prim->GetAttr(kAttrDynInputSizes));
(void)dyn_v.emplace_back(-1);
op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
}
}
}
op_prim->EndRecordAddAttr();
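
A simplified sketch of the kAttrDynInputSizes handling above, under the assumption that each entry records the length of a dynamic (tuple) input and -1 marks a plain input; FakePrim and RecordPlainInput are invented for the example.

#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Toy primitive: only the dyn-input-sizes attribute matters here.
struct FakePrim {
  std::vector<std::int64_t> dyn_input_sizes;
};

// Mirrors the logic above: when running asynchronously the shared prim may still be in use
// by a pending grad task, so mutate a private clone; then record the next input as
// non-dynamic by appending -1.
std::shared_ptr<FakePrim> RecordPlainInput(std::shared_ptr<FakePrim> prim, bool async_mode) {
  if (async_mode) {
    prim = std::make_shared<FakePrim>(*prim);  // copy-on-write clone
  }
  prim->dyn_input_sizes.push_back(-1);
  return prim;
}

int main() {
  auto shared_prim = std::make_shared<FakePrim>(FakePrim{{3}});  // e.g. a 3-tensor tuple input already recorded
  auto used_prim = RecordPlainInput(shared_prim, /*async_mode=*/true);
  std::cout << "shared attr size: " << shared_prim->dyn_input_sizes.size()  // 1, untouched
            << ", used attr size: " << used_prim->dyn_input_sizes.size()    // 2, has the -1
            << "\n";
  return 0;
}
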

View File

@ -229,5 +229,5 @@ def bprop_scalar_not(x, out, dout):
@bprops.register("TensorMove")
def bprop_tensor_move(x, out, dout):
"""Backpropagator for primitive `mutable`."""
"""Backpropagator for primitive `TensorMove`."""
return (dout,)

View File

@ -995,6 +995,7 @@ class RandomShuffle(Primitive):
def __init__(self, seed=0, seed2=0):
"""Initialize RandomShuffle"""
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
self.add_prim_attr("side_effect_hidden", True)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

View File

@ -213,9 +213,9 @@ FuncGraphManagerPtr Make_Manager(int64_t condition = 0) {
/// Description:
/// Expectation: the python path is right
TEST_F(TestStepParallel, GetPythonPath1) {
OperatorName operator_name = "AllReduce";
const char *operator_name = "AllReduce";
const std::string expect = "mindspore.ops.operations";
auto temp = parallel::GetOpPythonPath(operator_name);
std::string temp = parallel::GetOpPythonPath(operator_name);
ASSERT_EQ(temp, expect);
}
@ -223,9 +223,9 @@ TEST_F(TestStepParallel, GetPythonPath1) {
/// Description:
/// Expectation: the python path is right
TEST_F(TestStepParallel, GetPythonPath2) {
OperatorName operator_name = "Add";
const char *operator_name = "Add";
const std::string expect = "mindspore.ops.operations";
auto temp = parallel::GetOpPythonPath(operator_name);
std::string temp = parallel::GetOpPythonPath(operator_name);
ASSERT_EQ(temp, expect);
}