forked from mindspore-Ecosystem/mindspore
bporp
This commit is contained in:
parent
d40d2467ca
commit
9ccbb87918
|
@ -753,7 +753,7 @@ void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_i
|
|||
}
|
||||
|
||||
// Save cnode info and build grad graph
|
||||
if (grad()->need_construct_graph()) {
|
||||
if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) {
|
||||
grad()->SaveOutputNodeMap(obj_id, out_real, cnode);
|
||||
grad()->DoOpGrad(op_exec_info, cnode, out_real);
|
||||
}
|
||||
|
@ -1778,6 +1778,10 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
|
|||
void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
|
||||
auto cell_id = GetCellId(cell, args);
|
||||
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
|
||||
// When the cell has custom bprop, in_custom_bprop_cell is lager than 0
|
||||
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||
custom_bprop_cell_count_ += 1;
|
||||
}
|
||||
if (cell_stack_.empty() && top_cell_ != nullptr) {
|
||||
// non-first step
|
||||
if (!top_cell()->IsSubCell(cell_id) && already_run_top_cell_.find(cell_id) != already_run_top_cell_.end()) {
|
||||
|
@ -1910,6 +1914,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
|
|||
MS_LOG(DEBUG) << "Brop cell no need construct graph";
|
||||
return;
|
||||
}
|
||||
DoGradForCustomBprop(cell, out, args);
|
||||
if ((cell_stack_.size() > 1 && !IsNestedGrad()) || (IsNestedGrad() && cell_stack_.size() != cell_nums())) {
|
||||
PopCellStack();
|
||||
MS_LOG(DEBUG) << "Sub cell no need construct graph";
|
||||
|
@ -1927,6 +1932,46 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
|
|||
}
|
||||
}
|
||||
|
||||
void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args) {
|
||||
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||
return;
|
||||
}
|
||||
custom_bprop_cell_count_ -= 1;
|
||||
if (custom_bprop_cell_count_ != 0) {
|
||||
return;
|
||||
}
|
||||
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
|
||||
auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name(), py::object());
|
||||
fake_prim->set_hook(bprop_func);
|
||||
const auto &cell_id = GetCellId(cell, args);
|
||||
(void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
|
||||
(void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
|
||||
|
||||
py::object code_obj = py::getattr(bprop_func, "__code__");
|
||||
// Three parameters self, out and dout need to be excluded
|
||||
const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
|
||||
if (inputs_num > args.size()) {
|
||||
MS_LOG(EXCEPTION) << "Size of bprop func inputs[" << inputs_num << "] is larger than size of cell inputs["
|
||||
<< args.size() << "]";
|
||||
}
|
||||
|
||||
py::list cell_inputs;
|
||||
for (size_t i = 0; i < inputs_num; i += 1) {
|
||||
cell_inputs.append(args[i]);
|
||||
}
|
||||
OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
|
||||
op_exec_info->op_name = fake_prim->name();
|
||||
op_exec_info->py_primitive = fake_prim;
|
||||
op_exec_info->op_inputs = cell_inputs;
|
||||
|
||||
abstract::AbstractBasePtrList args_spec_list;
|
||||
std::vector<int64_t> op_masks;
|
||||
auto cnode = forward()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
|
||||
DoOpGrad(op_exec_info, cnode, out);
|
||||
const std::string out_obj_id = GetId(out);
|
||||
SaveOutputNodeMap(out_obj_id, out, cnode);
|
||||
}
|
||||
|
||||
void GradExecutor::UpdateBpropCellGraph(const py::object &cell, const std::string &cell_id) {
|
||||
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||
return;
|
||||
|
|
|
@ -175,6 +175,7 @@ class GradExecutor {
|
|||
bool grad_flag() const { return grad_flag_; }
|
||||
void set_grad_flag(bool flag) { grad_flag_ = flag; }
|
||||
bool in_grad_process() const { return in_grad_process_; }
|
||||
bool in_cell_with_custom_bprop_() const {return custom_bprop_cell_count_ > 0;}
|
||||
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
|
||||
std::string GetCellId(const py::object &obj, const py::args &args);
|
||||
std::stack<std::string> &cell_stack() { return cell_stack_; }
|
||||
|
@ -249,6 +250,7 @@ class GradExecutor {
|
|||
const std::vector<int64_t> &index) {
|
||||
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
|
||||
}
|
||||
void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args);
|
||||
|
||||
private:
|
||||
size_t grad_order_{0};
|
||||
|
@ -256,6 +258,7 @@ class GradExecutor {
|
|||
bool grad_flag_{false};
|
||||
bool in_bprop_process_{false};
|
||||
bool in_grad_process_{false};
|
||||
int custom_bprop_cell_count_{0};
|
||||
bool grad_is_running_{false};
|
||||
|
||||
FuncGraphPtr curr_g_{nullptr};
|
||||
|
@ -287,6 +290,8 @@ class ForwardExecutor {
|
|||
void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
|
||||
void ClearRes();
|
||||
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
||||
abstract::AbstractBasePtrList *args_spec_list);
|
||||
|
||||
private:
|
||||
GradExecutorPtr grad() const;
|
||||
|
@ -296,8 +301,6 @@ class ForwardExecutor {
|
|||
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
|
||||
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
|
||||
PynativeStatusCode *status);
|
||||
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
||||
abstract::AbstractBasePtrList *args_spec_list);
|
||||
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
|
||||
abstract::AbstractBasePtrList *args_spec_list);
|
||||
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
|
||||
|
|
|
@ -288,6 +288,9 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
|
|||
auto primitive_py = primitive->cast<PrimitivePyPtr>();
|
||||
MS_EXCEPTION_IF_NULL(primitive_py);
|
||||
this->set_hook(primitive_py->hook());
|
||||
if (primitive_py->HasAttr(kBpropAttrName)) {
|
||||
this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
|
||||
}
|
||||
}
|
||||
|
||||
BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
|
||||
|
|
Loading…
Reference in New Issue