This commit is contained in:
tanghuikang 2021-03-13 17:08:34 +08:00 committed by chujinjin
parent d40d2467ca
commit 9ccbb87918
3 changed files with 54 additions and 3 deletions

View File

@ -753,7 +753,7 @@ void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_i
}
// Save cnode info and build grad graph
if (grad()->need_construct_graph()) {
if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) {
grad()->SaveOutputNodeMap(obj_id, out_real, cnode);
grad()->DoOpGrad(op_exec_info, cnode, out_real);
}
@ -1778,6 +1778,10 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
auto cell_id = GetCellId(cell, args);
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
// When the cell has custom bprop, in_custom_bprop_cell is lager than 0
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
custom_bprop_cell_count_ += 1;
}
if (cell_stack_.empty() && top_cell_ != nullptr) {
// non-first step
if (!top_cell()->IsSubCell(cell_id) && already_run_top_cell_.find(cell_id) != already_run_top_cell_.end()) {
@ -1910,6 +1914,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
MS_LOG(DEBUG) << "Brop cell no need construct graph";
return;
}
DoGradForCustomBprop(cell, out, args);
if ((cell_stack_.size() > 1 && !IsNestedGrad()) || (IsNestedGrad() && cell_stack_.size() != cell_nums())) {
PopCellStack();
MS_LOG(DEBUG) << "Sub cell no need construct graph";
@ -1927,6 +1932,46 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
}
}
void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args) {
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
return;
}
custom_bprop_cell_count_ -= 1;
if (custom_bprop_cell_count_ != 0) {
return;
}
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name(), py::object());
fake_prim->set_hook(bprop_func);
const auto &cell_id = GetCellId(cell, args);
(void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
(void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
py::object code_obj = py::getattr(bprop_func, "__code__");
// Three parameters self, out and dout need to be excluded
const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
if (inputs_num > args.size()) {
MS_LOG(EXCEPTION) << "Size of bprop func inputs[" << inputs_num << "] is larger than size of cell inputs["
<< args.size() << "]";
}
py::list cell_inputs;
for (size_t i = 0; i < inputs_num; i += 1) {
cell_inputs.append(args[i]);
}
OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
op_exec_info->op_name = fake_prim->name();
op_exec_info->py_primitive = fake_prim;
op_exec_info->op_inputs = cell_inputs;
abstract::AbstractBasePtrList args_spec_list;
std::vector<int64_t> op_masks;
auto cnode = forward()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
DoOpGrad(op_exec_info, cnode, out);
const std::string out_obj_id = GetId(out);
SaveOutputNodeMap(out_obj_id, out, cnode);
}
void GradExecutor::UpdateBpropCellGraph(const py::object &cell, const std::string &cell_id) {
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
return;

View File

@ -175,6 +175,7 @@ class GradExecutor {
bool grad_flag() const { return grad_flag_; }
void set_grad_flag(bool flag) { grad_flag_ = flag; }
bool in_grad_process() const { return in_grad_process_; }
bool in_cell_with_custom_bprop_() const {return custom_bprop_cell_count_ > 0;}
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
std::string GetCellId(const py::object &obj, const py::args &args);
std::stack<std::string> &cell_stack() { return cell_stack_; }
@ -249,6 +250,7 @@ class GradExecutor {
const std::vector<int64_t> &index) {
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
}
void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args);
private:
size_t grad_order_{0};
@ -256,6 +258,7 @@ class GradExecutor {
bool grad_flag_{false};
bool in_bprop_process_{false};
bool in_grad_process_{false};
int custom_bprop_cell_count_{0};
bool grad_is_running_{false};
FuncGraphPtr curr_g_{nullptr};
@ -287,6 +290,8 @@ class ForwardExecutor {
void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
void ClearRes();
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
private:
GradExecutorPtr grad() const;
@ -296,8 +301,6 @@ class ForwardExecutor {
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
PynativeStatusCode *status);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,

View File

@ -288,6 +288,9 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
auto primitive_py = primitive->cast<PrimitivePyPtr>();
MS_EXCEPTION_IF_NULL(primitive_py);
this->set_hook(primitive_py->hook());
if (primitive_py->HasAttr(kBpropAttrName)) {
this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
}
}
BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {