!48828 Fix some bugs for PyNative

Merge pull request !48828 from zjun/fix_1302_bug
This commit is contained in:
i-robot 2023-02-16 06:11:41 +00:00 committed by Gitee
commit a8a1d47af2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 90 additions and 41 deletions

View File

@ -43,10 +43,12 @@ enum class SpecialType { kZerosLikeType = 0, kOnesLikeType = 1 };
const std::map<SpecialType, std::shared_ptr<Primitive>> kValueType{{SpecialType::kZerosLikeType, prim::kPrimZerosLike},
{SpecialType::kOnesLikeType, prim::kPrimOnesLike}};
const size_t kContainerRatio = 2;
const std::vector<PrimitivePtr> kGradBlackList{
const PrimitiveSet kGradBlackList{
prim::kPrimMakeTuple, prim::kPrimTupleGetItem, prim::kPrimStopGradient, prim::kPrimUpdateState,
prim::kPrimNPUAllocFloatStatus, prim::kPrimNPUGetFloatStatus, prim::kPrimNPUClearFloatStatus};
const PrimitiveSet kMonadPrim = {prim::kPrimLoad, prim::kPrimDepend, prim::kPrimUpdateState};
mindspore::HashMap<std::string, FuncGraphPtr> pass_grad_graph_;
void ClearDeviceAddress(const ValuePtr &value) {
@ -58,14 +60,7 @@ void ClearDeviceAddress(const ValuePtr &value) {
}
}
// Returns true when gradient computation is required for `prim`.
// Primitives listed in kGradBlackList (MakeTuple, TupleGetItem, monad and
// NPU float-status ops, ...) do not need grad and yield false.
bool IsPrimNeedGrad(const PrimitivePtr &prim) {
  bool need_grad = true;
  for (const auto &blacklisted_prim : kGradBlackList) {
    if (IsPrimitiveEquals(prim, blacklisted_prim)) {
      need_grad = false;
      break;
    }
  }
  return need_grad;
}
// Returns true when gradient computation is required for `prim`;
// primitives present in kGradBlackList do not need grad.
inline bool IsPrimNeedGrad(const PrimitivePtr &prim) {
  const auto iter = kGradBlackList.find(prim);
  return iter == kGradBlackList.end();
}
ValueNodePtr CreateValueNodeByClonedValue(const ValuePtr &v) {
MS_EXCEPTION_IF_NULL(v);
@ -159,8 +154,10 @@ AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value
std::transform(dic_v.begin(), dic_v.end(), std::back_inserter(v_list),
[](const std::pair<ValuePtr, ValuePtr> &elem) { return elem.second; });
return BuildSpecialLikeValue(tape, std::make_shared<ValueTuple>(v_list), type);
} else if (value->isa<None>()) {
return BuildSpecialLikeValue(tape, MakeValue(0), type);
} else {
MS_EXCEPTION(TypeError) << "For value" << value->ToString() << ", the type is not tensor or sequence";
MS_EXCEPTION(TypeError) << "For value " << value->ToString() << ", the type is not tensor or sequence";
}
}
@ -241,6 +238,28 @@ bool IsValidTensorInput(const abstract::AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(abs);
return abs->isa<abstract::AbstractTensor>() || abs->isa<abstract::AbstractSparseTensor>();
}
// Returns true when `prim` is a pure monad primitive (Load/Depend/UpdateState),
// in which case the caller skips the node during grad-graph expansion.
// NOTE(review): even when returning false, this function mutates `cnode`:
//  - for ops flagged with side-effect mem/io attrs, the trailing (monad) input
//    is stripped so only real data inputs remain;
//  - for ms_function graphs, each input is replaced by its real kernel node,
//    skipping TupleGetItem/MakeTuple wrappers.
bool IsMonadPrim(const PrimitivePtr &prim, const CNodePtr &cnode, const GradParamPtr &grad_param) {
  MS_EXCEPTION_IF_NULL(prim);
  if (kMonadPrim.find(prim) != kMonadPrim.end()) {
    MS_LOG(DEBUG) << "Get monad cnode " << cnode->DebugString();
    return true;
  }
  // Side-effect ops carry a monad as their last input; drop it here.
  if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) {
    std::vector<AnfNodePtr> inputs{cnode->inputs().begin(), cnode->inputs().end() - 1};
    cnode->set_inputs(inputs);
  }
  MS_EXCEPTION_IF_NULL(grad_param);
  // An ms_function graph may contain monad ops; unwrap each input to the
  // underlying real node before building bprop.
  if (grad_param->is_ms_function_graph) {
    for (size_t i = 1; i < cnode->size(); ++i) {
      cnode->set_input(i, common::AnfAlgo::VisitKernelWithReturnType(cnode->input(i), 0, false,
                                                                    {prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
                            .first);
    }
  }
  return false;
}
} // namespace
AnfNodePtr FunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
@ -405,7 +424,7 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
}
// Valuenode, cnode
auto v_node = NewValueNode(grad_param->op_args[i]);
v_node->set_abstract(input_node->abstract());
v_node->set_abstract(input_node->abstract()->Clone());
(void)args_node_list.emplace_back(v_node);
}
bprop_cnode = GetBpropGraphCNode(grad_param, args_node_list, &dout);
@ -459,6 +478,7 @@ CNodePtr AutoGradCellImpl::GetBPropFromExpander(const GradParamPtr &grad_param,
(void)bprop_inputs.emplace_back(*tape_dout);
(void)bprop_inputs.insert(bprop_inputs.cbegin(), NewValueNode(ad_graph));
auto get_bprop = ad_param()->tape_->NewCNode(bprop_inputs);
get_bprop->set_abstract(ad_graph->output()->abstract());
// tape_dout is set by next op
AddUser(*tape_dout, get_bprop, bprop_inputs.size() - 1);
return get_bprop;
@ -500,6 +520,7 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const GradParamPtr &grad_param, con
(void)bprop_builder_inputs.emplace_back(*tape_dout);
(void)bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
get_bprop = ad_param()->tape_->NewCNode(bprop_builder_inputs);
get_bprop->set_abstract(after_opt_fg->output()->abstract());
// tape_dout is set by next op
AddUser(*tape_dout, get_bprop, bprop_builder_inputs.size() - 1);
return get_bprop;
@ -527,19 +548,25 @@ FuncGraphPtr AutoGradCellImpl::GradFuncGraph(const GradParamPtr &grad_param) {
GradGraphByExpander(grad_param);
// Set dout parameter
auto output_node = grad_param->fg->output();
auto ad_graph_dout = ad_param()->tape_->add_parameter();
ad_graph_dout->set_abstract(output_node->abstract());
if (ad_param()->last_node_ != nullptr) {
ad_param()->anfnode_to_variable_adjoint_.at(output_node)->fn()->UpdateAccumulativeDout(ad_graph_dout);
(void)BackPropagate();
if (kMonadPrim.find(GetCNodePrimitive(ad_param()->last_node_)) != kMonadPrim.end()) {
ad_param()->last_node_ = common::AnfAlgo::VisitKernelWithReturnType(ad_param()->last_node_, 0, false,
{prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
.first;
}
auto ad_graph_dout = ad_param()->tape_->add_parameter();
ad_graph_dout->set_abstract(ad_param()->last_node_->abstract());
ad_param()->anfnode_to_variable_adjoint_.at(ad_param()->last_node_)->fn()->UpdateAccumulativeDout(ad_graph_dout);
std::vector<AnfNodePtr> outputs{NewValueNode(prim::kPrimMakeTuple)};
std::transform(
grad_param->fg->parameters().begin(), grad_param->fg->parameters().end(), std::back_inserter(outputs),
[this](const AnfNodePtr &param) { return ad_param()->anfnode_to_variable_adjoint_.at(param)->RealDout(); });
(void)BackPropagate();
AnfNodePtrList outputs{NewValueNode(prim::kPrimMakeTuple)};
abstract::AbstractBasePtrList out_abs_list;
for (const auto &node : grad_param->fg->parameters()) {
(void)outputs.emplace_back(ad_param()->anfnode_to_variable_adjoint_.at(node)->RealDout());
(void)out_abs_list.emplace_back(outputs.back()->abstract());
}
auto ad_graph_out = ad_param()->tape_->NewCNode(outputs);
ad_graph_out->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
ad_param()->tape_->set_output(ad_graph_out);
auto ad_graph = ad_param()->tape_;
pynative::PyNativeAlgo::Common::DumpGraphIR("ad_output_graph.ir", ad_graph);
@ -569,13 +596,16 @@ void AutoGradCellImpl::GradGraphByExpander(const GradParamPtr &grad_param) {
if (node == nullptr || !node->isa<CNode>()) {
continue;
}
ad_param()->last_node_ = node;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
}
ad_param()->last_node_ = node;
if (IsMonadPrim(prim, cnode, grad_param)) {
continue;
}
MS_LOG(DEBUG) << "Get cnode " << cnode->DebugString() << ", " << cnode->fullname_with_scope();
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
(void)BuildKNodeForMakeTuple(cnode);
@ -584,10 +614,7 @@ void AutoGradCellImpl::GradGraphByExpander(const GradParamPtr &grad_param) {
(void)BuildKNodeForTupleGetItem(cnode);
continue;
}
if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) {
std::vector<AnfNodePtr> inputs{cnode->inputs().begin(), cnode->inputs().end() - 1};
cnode->set_inputs(inputs);
}
std::vector<AnfNodePtr> cnode_inputs{std::make_shared<ValueNode>(prim)};
auto op_args = GetInputArgs(grad_param, cnode, &cnode_inputs);
auto k_node = ad_param()->tape_->NewCNode(cnode_inputs);
@ -614,7 +641,14 @@ void AutoGradCellImpl::GradGraphByExpander(const GradParamPtr &grad_param) {
std::vector<CNodePtr> outputs;
auto ret = BpropExpander(&outputs, &ad_param()->users_).Run(input_node);
if (!ret || outputs.empty()) {
MS_LOG(EXCEPTION) << "Expander has no bprop of this node: " << input_node->DebugString();
MS_LOG(DEBUG) << "Expander has no bprop of this node: " << input_node->DebugString();
BuildCustomBpropCNode(input_node, prim, &outputs);
}
if (outputs.empty()) {
MS_LOG(DEBUG) << "Build fake bprop for this node: " << input_node->DebugString();
BuildFakeBpropCNode(input_node, &outputs);
variable_adjoint->set_is_fake_bprop(true);
variable_adjoint->set_fake_prim_name(prim->name());
}
// Create current op node din edge
UpdateNextEdges(variable_adjoint, cnode, outputs, op_args);
@ -656,7 +690,7 @@ ValuePtrList AutoGradCellImpl::GetInputArgs(const GradParamPtr &grad_param, cons
MS_EXCEPTION_IF_NULL(cnode_inputs);
ValuePtrList op_args;
for (size_t i = 1; i < cnode->size(); ++i) {
auto input_node = cnode->input(i);
const auto &input_node = cnode->input(i);
const auto it = ad_param()->anfnode_to_variable_adjoint_.find(input_node);
if (it != ad_param()->anfnode_to_variable_adjoint_.end()) {
(void)cnode_inputs->emplace_back(it->second->k_node());
@ -711,7 +745,8 @@ CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_par
continue;
}
auto v_node = NewValueNode(grad_param->op_args[i]);
v_node->set_abstract(node->abstract());
// The node's abstract object may be freed later, so the value node's abstract
// would become invalid unless we clone it here.
v_node->set_abstract(node->abstract()->Clone());
(void)node_list.emplace_back(v_node);
}
// Set out
@ -1034,9 +1069,7 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
}
}
auto new_din = ad_param()->tape_->NewCNode(inputs);
auto out_value_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(out_value->ToAbstract());
MS_EXCEPTION_IF_NULL(out_value_abs);
new_din->set_abstract(out_value_abs);
new_din->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(out_value->ToAbstract()));
if (index != -1) {
AddUser(fn->fake_dout(), new_din, index);
}
@ -1092,7 +1125,7 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, const Primitive
void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, const PrimitivePtr &prim,
std::vector<CNodePtr> *outputs) {
MS_EXCEPTION_IF_NULL(prim);
MS_LOG(DEBUG) << "Build custom bprop: " << prim->name();
MS_LOG(DEBUG) << "Try build custom bprop: " << prim->name();
auto prim_py = prim->cast<PrimitivePyPtr>();
{
py::gil_scoped_acquire gil;

View File

@ -397,6 +397,18 @@ ForwardExecutorPtr GradExecutor::forward() const {
return forward_executor;
}
// One-time initialization of the grad executor; subsequent calls are no-ops.
// On MSVC builds it additionally performs the Windows bprop expander
// registration exactly once via a function-local static.
void GradExecutor::Init() {
  if (!init_) {
#ifdef _MSC_VER
    static WinBpropRegister win_bprop_register;
    win_bprop_register.DoNothing();
    MS_LOG(DEBUG) << "Do windows bprop expander register";
#endif
    init_ = true;
  }
}
TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
if (high_order_stack_.empty()) {
MS_LOG(EXCEPTION) << "Stack high_order_stack_ is empty";
@ -1069,9 +1081,6 @@ void GradExecutor::CheckParamShapeAndType(const ParameterPtr &param_node, const
<< param_node->DebugString();
}
}
if (param_node->debug_info()->name() == "sens" && ir_shape != input_shape) {
need_renormalize_ = true;
}
}
void GradExecutor::UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args, const FuncGraphPtr &bprop_graph) {
@ -1668,10 +1677,6 @@ bool GradExecutor::FreeUselessTensors(const CNodePtr &cnode, const ValuePtrList
void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode,
const ValuePtr &op_out) const {
MS_EXCEPTION_IF_NULL(op_run_info);
#ifdef _MSC_VER
static WinBpropRegister reg;
reg.DoNothing();
#endif
if (grad_is_running_ && !bprop_grad_stack_.top().second) {
MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
return;

View File

@ -56,6 +56,7 @@ class GradExecutor {
ms_function_(std::make_shared<MsFunction>()),
async_executor_(std::make_shared<AsyncQueue>()) {}
void Init();
std::function<void(const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2) {
NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
};
@ -220,6 +221,7 @@ class GradExecutor {
bool IsGraphDynamic(const AnfNodePtr &anf_node, size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const;
bool init_{false};
bool grad_flag_{false};
bool grad_is_running_{false};
bool need_renormalize_{false};

View File

@ -49,7 +49,15 @@ void SplitString(const std::string &str, std::vector<std::string> *id_vec) {
end = sub_str.find_first_of(colon_delim, begin);
paren_pos = sub_str.find_first_of(angle_bracket_left_delim, begin);
if (paren_pos < end) {
end = sub_str.find_last_of(angle_bracket_right_delim) + 1;
const auto &s = sub_str.substr(begin, end - begin);
auto num = std::count(s.begin(), s.end(), angle_bracket_left_delim);
end = begin;
while (num--) {
end = sub_str.find_first_of(angle_bracket_right_delim, end) + 1;
}
if (end >= sub_str.size()) {
end = std::string::npos;
}
}
}
}

View File

@ -173,6 +173,7 @@ void PyNativeExecutor::Init() {
forward_executor_ = std::make_shared<ForwardExecutor>();
forward_executor_->Init();
grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
grad_executor_->Init();
forward_executor_->set_grad_executor(grad_executor_);
}