forked from mindspore-Ecosystem/mindspore
!48828 Fix some bug For PyNative
Merge pull request !48828 from zjun/fix_1302_bug
This commit is contained in:
commit
a8a1d47af2
|
@ -43,10 +43,12 @@ enum class SpecialType { kZerosLikeType = 0, kOnesLikeType = 1 };
|
|||
const std::map<SpecialType, std::shared_ptr<Primitive>> kValueType{{SpecialType::kZerosLikeType, prim::kPrimZerosLike},
|
||||
{SpecialType::kOnesLikeType, prim::kPrimOnesLike}};
|
||||
const size_t kContainerRatio = 2;
|
||||
const std::vector<PrimitivePtr> kGradBlackList{
|
||||
const PrimitiveSet kGradBlackList{
|
||||
prim::kPrimMakeTuple, prim::kPrimTupleGetItem, prim::kPrimStopGradient, prim::kPrimUpdateState,
|
||||
prim::kPrimNPUAllocFloatStatus, prim::kPrimNPUGetFloatStatus, prim::kPrimNPUClearFloatStatus};
|
||||
|
||||
const PrimitiveSet kMonadPrim = {prim::kPrimLoad, prim::kPrimDepend, prim::kPrimUpdateState};
|
||||
|
||||
mindspore::HashMap<std::string, FuncGraphPtr> pass_grad_graph_;
|
||||
|
||||
void ClearDeviceAddress(const ValuePtr &value) {
|
||||
|
@ -58,14 +60,7 @@ void ClearDeviceAddress(const ValuePtr &value) {
|
|||
}
|
||||
}
|
||||
|
||||
bool IsPrimNeedGrad(const PrimitivePtr &prim) {
|
||||
for (const auto &no_need_grad_prim : kGradBlackList) {
|
||||
if (IsPrimitiveEquals(prim, no_need_grad_prim)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// Tells whether `prim` needs a gradient: true exactly when `prim` is not found in kGradBlackList.
inline bool IsPrimNeedGrad(const PrimitivePtr &prim) {
  const auto black_iter = kGradBlackList.find(prim);
  return black_iter == kGradBlackList.end();
}
|
||||
|
||||
ValueNodePtr CreateValueNodeByClonedValue(const ValuePtr &v) {
|
||||
MS_EXCEPTION_IF_NULL(v);
|
||||
|
@ -159,8 +154,10 @@ AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value
|
|||
std::transform(dic_v.begin(), dic_v.end(), std::back_inserter(v_list),
|
||||
[](const std::pair<ValuePtr, ValuePtr> &elem) { return elem.second; });
|
||||
return BuildSpecialLikeValue(tape, std::make_shared<ValueTuple>(v_list), type);
|
||||
} else if (value->isa<None>()) {
|
||||
return BuildSpecialLikeValue(tape, MakeValue(0), type);
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For value" << value->ToString() << ", the type is not tensor or sequence";
|
||||
MS_EXCEPTION(TypeError) << "For value " << value->ToString() << ", the type is not tensor or sequence";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -241,6 +238,28 @@ bool IsValidTensorInput(const abstract::AbstractBasePtr &abs) {
|
|||
MS_EXCEPTION_IF_NULL(abs);
|
||||
return abs->isa<abstract::AbstractTensor>() || abs->isa<abstract::AbstractSparseTensor>();
|
||||
}
|
||||
|
||||
bool IsMonadPrim(const PrimitivePtr &prim, const CNodePtr &cnode, const GradParamPtr &grad_param) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (kMonadPrim.find(prim) != kMonadPrim.end()) {
|
||||
MS_LOG(DEBUG) << "Get monad cnode " << cnode->DebugString();
|
||||
return true;
|
||||
}
|
||||
if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) {
|
||||
std::vector<AnfNodePtr> inputs{cnode->inputs().begin(), cnode->inputs().end() - 1};
|
||||
cnode->set_inputs(inputs);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(grad_param);
|
||||
// Ms function graph contain monad op
|
||||
if (grad_param->is_ms_function_graph) {
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
cnode->set_input(i, common::AnfAlgo::VisitKernelWithReturnType(cnode->input(i), 0, false,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
|
||||
.first);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr FunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
|
||||
|
@ -405,7 +424,7 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
|
|||
}
|
||||
// Valuenode, cnode
|
||||
auto v_node = NewValueNode(grad_param->op_args[i]);
|
||||
v_node->set_abstract(input_node->abstract());
|
||||
v_node->set_abstract(input_node->abstract()->Clone());
|
||||
(void)args_node_list.emplace_back(v_node);
|
||||
}
|
||||
bprop_cnode = GetBpropGraphCNode(grad_param, args_node_list, &dout);
|
||||
|
@ -459,6 +478,7 @@ CNodePtr AutoGradCellImpl::GetBPropFromExpander(const GradParamPtr &grad_param,
|
|||
(void)bprop_inputs.emplace_back(*tape_dout);
|
||||
(void)bprop_inputs.insert(bprop_inputs.cbegin(), NewValueNode(ad_graph));
|
||||
auto get_bprop = ad_param()->tape_->NewCNode(bprop_inputs);
|
||||
get_bprop->set_abstract(ad_graph->output()->abstract());
|
||||
// tape_dout is set by next op
|
||||
AddUser(*tape_dout, get_bprop, bprop_inputs.size() - 1);
|
||||
return get_bprop;
|
||||
|
@ -500,6 +520,7 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const GradParamPtr &grad_param, con
|
|||
(void)bprop_builder_inputs.emplace_back(*tape_dout);
|
||||
(void)bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
|
||||
get_bprop = ad_param()->tape_->NewCNode(bprop_builder_inputs);
|
||||
get_bprop->set_abstract(after_opt_fg->output()->abstract());
|
||||
// tape_dout is set by next op
|
||||
AddUser(*tape_dout, get_bprop, bprop_builder_inputs.size() - 1);
|
||||
return get_bprop;
|
||||
|
@ -527,19 +548,25 @@ FuncGraphPtr AutoGradCellImpl::GradFuncGraph(const GradParamPtr &grad_param) {
|
|||
GradGraphByExpander(grad_param);
|
||||
|
||||
// Set dout parameter
|
||||
auto output_node = grad_param->fg->output();
|
||||
auto ad_graph_dout = ad_param()->tape_->add_parameter();
|
||||
ad_graph_dout->set_abstract(output_node->abstract());
|
||||
if (ad_param()->last_node_ != nullptr) {
|
||||
ad_param()->anfnode_to_variable_adjoint_.at(output_node)->fn()->UpdateAccumulativeDout(ad_graph_dout);
|
||||
(void)BackPropagate();
|
||||
if (kMonadPrim.find(GetCNodePrimitive(ad_param()->last_node_)) != kMonadPrim.end()) {
|
||||
ad_param()->last_node_ = common::AnfAlgo::VisitKernelWithReturnType(ad_param()->last_node_, 0, false,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
|
||||
.first;
|
||||
}
|
||||
auto ad_graph_dout = ad_param()->tape_->add_parameter();
|
||||
ad_graph_dout->set_abstract(ad_param()->last_node_->abstract());
|
||||
ad_param()->anfnode_to_variable_adjoint_.at(ad_param()->last_node_)->fn()->UpdateAccumulativeDout(ad_graph_dout);
|
||||
|
||||
std::vector<AnfNodePtr> outputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
std::transform(
|
||||
grad_param->fg->parameters().begin(), grad_param->fg->parameters().end(), std::back_inserter(outputs),
|
||||
[this](const AnfNodePtr &param) { return ad_param()->anfnode_to_variable_adjoint_.at(param)->RealDout(); });
|
||||
(void)BackPropagate();
|
||||
|
||||
AnfNodePtrList outputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
abstract::AbstractBasePtrList out_abs_list;
|
||||
for (const auto &node : grad_param->fg->parameters()) {
|
||||
(void)outputs.emplace_back(ad_param()->anfnode_to_variable_adjoint_.at(node)->RealDout());
|
||||
(void)out_abs_list.emplace_back(outputs.back()->abstract());
|
||||
}
|
||||
auto ad_graph_out = ad_param()->tape_->NewCNode(outputs);
|
||||
ad_graph_out->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
|
||||
ad_param()->tape_->set_output(ad_graph_out);
|
||||
auto ad_graph = ad_param()->tape_;
|
||||
pynative::PyNativeAlgo::Common::DumpGraphIR("ad_output_graph.ir", ad_graph);
|
||||
|
@ -569,13 +596,16 @@ void AutoGradCellImpl::GradGraphByExpander(const GradParamPtr &grad_param) {
|
|||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
ad_param()->last_node_ = node;
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
|
||||
}
|
||||
ad_param()->last_node_ = node;
|
||||
if (IsMonadPrim(prim, cnode, grad_param)) {
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Get cnode " << cnode->DebugString() << ", " << cnode->fullname_with_scope();
|
||||
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
|
||||
(void)BuildKNodeForMakeTuple(cnode);
|
||||
|
@ -584,10 +614,7 @@ void AutoGradCellImpl::GradGraphByExpander(const GradParamPtr &grad_param) {
|
|||
(void)BuildKNodeForTupleGetItem(cnode);
|
||||
continue;
|
||||
}
|
||||
if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) {
|
||||
std::vector<AnfNodePtr> inputs{cnode->inputs().begin(), cnode->inputs().end() - 1};
|
||||
cnode->set_inputs(inputs);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> cnode_inputs{std::make_shared<ValueNode>(prim)};
|
||||
auto op_args = GetInputArgs(grad_param, cnode, &cnode_inputs);
|
||||
auto k_node = ad_param()->tape_->NewCNode(cnode_inputs);
|
||||
|
@ -614,7 +641,14 @@ void AutoGradCellImpl::GradGraphByExpander(const GradParamPtr &grad_param) {
|
|||
std::vector<CNodePtr> outputs;
|
||||
auto ret = BpropExpander(&outputs, &ad_param()->users_).Run(input_node);
|
||||
if (!ret || outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Expander has no bprop of this node: " << input_node->DebugString();
|
||||
MS_LOG(DEBUG) << "Expander has no bprop of this node: " << input_node->DebugString();
|
||||
BuildCustomBpropCNode(input_node, prim, &outputs);
|
||||
}
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(DEBUG) << "Build fake bprop for this node: " << input_node->DebugString();
|
||||
BuildFakeBpropCNode(input_node, &outputs);
|
||||
variable_adjoint->set_is_fake_bprop(true);
|
||||
variable_adjoint->set_fake_prim_name(prim->name());
|
||||
}
|
||||
// Create current op node din edge
|
||||
UpdateNextEdges(variable_adjoint, cnode, outputs, op_args);
|
||||
|
@ -656,7 +690,7 @@ ValuePtrList AutoGradCellImpl::GetInputArgs(const GradParamPtr &grad_param, cons
|
|||
MS_EXCEPTION_IF_NULL(cnode_inputs);
|
||||
ValuePtrList op_args;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto input_node = cnode->input(i);
|
||||
const auto &input_node = cnode->input(i);
|
||||
const auto it = ad_param()->anfnode_to_variable_adjoint_.find(input_node);
|
||||
if (it != ad_param()->anfnode_to_variable_adjoint_.end()) {
|
||||
(void)cnode_inputs->emplace_back(it->second->k_node());
|
||||
|
@ -711,7 +745,8 @@ CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_par
|
|||
continue;
|
||||
}
|
||||
auto v_node = NewValueNode(grad_param->op_args[i]);
|
||||
v_node->set_abstract(node->abstract());
|
||||
// The node's abstract object may be freed later, so store a clone; otherwise v_node's abstract could become invalid
|
||||
v_node->set_abstract(node->abstract()->Clone());
|
||||
(void)node_list.emplace_back(v_node);
|
||||
}
|
||||
// Set out
|
||||
|
@ -1034,9 +1069,7 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
|
|||
}
|
||||
}
|
||||
auto new_din = ad_param()->tape_->NewCNode(inputs);
|
||||
auto out_value_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(out_value->ToAbstract());
|
||||
MS_EXCEPTION_IF_NULL(out_value_abs);
|
||||
new_din->set_abstract(out_value_abs);
|
||||
new_din->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(out_value->ToAbstract()));
|
||||
if (index != -1) {
|
||||
AddUser(fn->fake_dout(), new_din, index);
|
||||
}
|
||||
|
@ -1092,7 +1125,7 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, const Primitive
|
|||
void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, const PrimitivePtr &prim,
|
||||
std::vector<CNodePtr> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
MS_LOG(DEBUG) << "Build custom bprop: " << prim->name();
|
||||
MS_LOG(DEBUG) << "Try build custom bprop: " << prim->name();
|
||||
auto prim_py = prim->cast<PrimitivePyPtr>();
|
||||
{
|
||||
py::gil_scoped_acquire gil;
|
||||
|
|
|
@ -397,6 +397,18 @@ ForwardExecutorPtr GradExecutor::forward() const {
|
|||
return forward_executor;
|
||||
}
|
||||
|
||||
// One-time initialization of the grad executor. Once init_ is set, later calls do nothing.
void GradExecutor::Init() {
  if (!init_) {
#ifdef _MSC_VER
    // MSVC builds only: create the function-local static register object and call DoNothing() on it.
    // NOTE(review): presumably this keeps the bprop expander registration from being dropped - confirm.
    static WinBpropRegister reg;
    reg.DoNothing();
    MS_LOG(DEBUG) << "Do windows bprop expander register";
#endif
    init_ = true;
  }
}
|
||||
|
||||
TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
|
||||
if (high_order_stack_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Stack high_order_stack_ is empty";
|
||||
|
@ -1069,9 +1081,6 @@ void GradExecutor::CheckParamShapeAndType(const ParameterPtr &param_node, const
|
|||
<< param_node->DebugString();
|
||||
}
|
||||
}
|
||||
if (param_node->debug_info()->name() == "sens" && ir_shape != input_shape) {
|
||||
need_renormalize_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void GradExecutor::UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args, const FuncGraphPtr &bprop_graph) {
|
||||
|
@ -1668,10 +1677,6 @@ bool GradExecutor::FreeUselessTensors(const CNodePtr &cnode, const ValuePtrList
|
|||
void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode,
|
||||
const ValuePtr &op_out) const {
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
#ifdef _MSC_VER
|
||||
static WinBpropRegister reg;
|
||||
reg.DoNothing();
|
||||
#endif
|
||||
if (grad_is_running_ && !bprop_grad_stack_.top().second) {
|
||||
MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
|
||||
return;
|
||||
|
|
|
@ -56,6 +56,7 @@ class GradExecutor {
|
|||
ms_function_(std::make_shared<MsFunction>()),
|
||||
async_executor_(std::make_shared<AsyncQueue>()) {}
|
||||
|
||||
void Init();
|
||||
std::function<void(const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2) {
|
||||
NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
|
||||
};
|
||||
|
@ -220,6 +221,7 @@ class GradExecutor {
|
|||
bool IsGraphDynamic(const AnfNodePtr &anf_node, size_t node_idx, bool is_ms_function_node,
|
||||
const std::string &graph_phase) const;
|
||||
|
||||
bool init_{false};
|
||||
bool grad_flag_{false};
|
||||
bool grad_is_running_{false};
|
||||
bool need_renormalize_{false};
|
||||
|
|
|
@ -49,7 +49,15 @@ void SplitString(const std::string &str, std::vector<std::string> *id_vec) {
|
|||
end = sub_str.find_first_of(colon_delim, begin);
|
||||
paren_pos = sub_str.find_first_of(angle_bracket_left_delim, begin);
|
||||
if (paren_pos < end) {
|
||||
end = sub_str.find_last_of(angle_bracket_right_delim) + 1;
|
||||
const auto &s = sub_str.substr(begin, end - begin);
|
||||
auto num = std::count(s.begin(), s.end(), angle_bracket_left_delim);
|
||||
end = begin;
|
||||
while (num--) {
|
||||
end = sub_str.find_first_of(angle_bracket_right_delim, end) + 1;
|
||||
}
|
||||
if (end >= sub_str.size()) {
|
||||
end = std::string::npos;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -173,6 +173,7 @@ void PyNativeExecutor::Init() {
|
|||
forward_executor_ = std::make_shared<ForwardExecutor>();
|
||||
forward_executor_->Init();
|
||||
grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
|
||||
grad_executor_->Init();
|
||||
forward_executor_->set_grad_executor(grad_executor_);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue