forked from mindspore-Ecosystem/mindspore
bporp
This commit is contained in:
parent
d40d2467ca
commit
9ccbb87918
|
@ -753,7 +753,7 @@ void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_i
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save cnode info and build grad graph
|
// Save cnode info and build grad graph
|
||||||
if (grad()->need_construct_graph()) {
|
if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) {
|
||||||
grad()->SaveOutputNodeMap(obj_id, out_real, cnode);
|
grad()->SaveOutputNodeMap(obj_id, out_real, cnode);
|
||||||
grad()->DoOpGrad(op_exec_info, cnode, out_real);
|
grad()->DoOpGrad(op_exec_info, cnode, out_real);
|
||||||
}
|
}
|
||||||
|
@ -1778,6 +1778,10 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
|
||||||
void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
|
void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
|
||||||
auto cell_id = GetCellId(cell, args);
|
auto cell_id = GetCellId(cell, args);
|
||||||
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
|
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
|
||||||
|
// When the cell has custom bprop, in_custom_bprop_cell is lager than 0
|
||||||
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||||
|
custom_bprop_cell_count_ += 1;
|
||||||
|
}
|
||||||
if (cell_stack_.empty() && top_cell_ != nullptr) {
|
if (cell_stack_.empty() && top_cell_ != nullptr) {
|
||||||
// non-first step
|
// non-first step
|
||||||
if (!top_cell()->IsSubCell(cell_id) && already_run_top_cell_.find(cell_id) != already_run_top_cell_.end()) {
|
if (!top_cell()->IsSubCell(cell_id) && already_run_top_cell_.find(cell_id) != already_run_top_cell_.end()) {
|
||||||
|
@ -1910,6 +1914,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
|
||||||
MS_LOG(DEBUG) << "Brop cell no need construct graph";
|
MS_LOG(DEBUG) << "Brop cell no need construct graph";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
DoGradForCustomBprop(cell, out, args);
|
||||||
if ((cell_stack_.size() > 1 && !IsNestedGrad()) || (IsNestedGrad() && cell_stack_.size() != cell_nums())) {
|
if ((cell_stack_.size() > 1 && !IsNestedGrad()) || (IsNestedGrad() && cell_stack_.size() != cell_nums())) {
|
||||||
PopCellStack();
|
PopCellStack();
|
||||||
MS_LOG(DEBUG) << "Sub cell no need construct graph";
|
MS_LOG(DEBUG) << "Sub cell no need construct graph";
|
||||||
|
@ -1927,6 +1932,46 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args) {
|
||||||
|
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
custom_bprop_cell_count_ -= 1;
|
||||||
|
if (custom_bprop_cell_count_ != 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
|
||||||
|
auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name(), py::object());
|
||||||
|
fake_prim->set_hook(bprop_func);
|
||||||
|
const auto &cell_id = GetCellId(cell, args);
|
||||||
|
(void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
|
||||||
|
(void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
|
||||||
|
|
||||||
|
py::object code_obj = py::getattr(bprop_func, "__code__");
|
||||||
|
// Three parameters self, out and dout need to be excluded
|
||||||
|
const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
|
||||||
|
if (inputs_num > args.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Size of bprop func inputs[" << inputs_num << "] is larger than size of cell inputs["
|
||||||
|
<< args.size() << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
py::list cell_inputs;
|
||||||
|
for (size_t i = 0; i < inputs_num; i += 1) {
|
||||||
|
cell_inputs.append(args[i]);
|
||||||
|
}
|
||||||
|
OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
|
||||||
|
op_exec_info->op_name = fake_prim->name();
|
||||||
|
op_exec_info->py_primitive = fake_prim;
|
||||||
|
op_exec_info->op_inputs = cell_inputs;
|
||||||
|
|
||||||
|
abstract::AbstractBasePtrList args_spec_list;
|
||||||
|
std::vector<int64_t> op_masks;
|
||||||
|
auto cnode = forward()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
|
||||||
|
DoOpGrad(op_exec_info, cnode, out);
|
||||||
|
const std::string out_obj_id = GetId(out);
|
||||||
|
SaveOutputNodeMap(out_obj_id, out, cnode);
|
||||||
|
}
|
||||||
|
|
||||||
void GradExecutor::UpdateBpropCellGraph(const py::object &cell, const std::string &cell_id) {
|
void GradExecutor::UpdateBpropCellGraph(const py::object &cell, const std::string &cell_id) {
|
||||||
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -175,6 +175,7 @@ class GradExecutor {
|
||||||
bool grad_flag() const { return grad_flag_; }
|
bool grad_flag() const { return grad_flag_; }
|
||||||
void set_grad_flag(bool flag) { grad_flag_ = flag; }
|
void set_grad_flag(bool flag) { grad_flag_ = flag; }
|
||||||
bool in_grad_process() const { return in_grad_process_; }
|
bool in_grad_process() const { return in_grad_process_; }
|
||||||
|
bool in_cell_with_custom_bprop_() const {return custom_bprop_cell_count_ > 0;}
|
||||||
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
|
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
|
||||||
std::string GetCellId(const py::object &obj, const py::args &args);
|
std::string GetCellId(const py::object &obj, const py::args &args);
|
||||||
std::stack<std::string> &cell_stack() { return cell_stack_; }
|
std::stack<std::string> &cell_stack() { return cell_stack_; }
|
||||||
|
@ -249,6 +250,7 @@ class GradExecutor {
|
||||||
const std::vector<int64_t> &index) {
|
const std::vector<int64_t> &index) {
|
||||||
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
|
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
|
||||||
}
|
}
|
||||||
|
void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
size_t grad_order_{0};
|
size_t grad_order_{0};
|
||||||
|
@ -256,6 +258,7 @@ class GradExecutor {
|
||||||
bool grad_flag_{false};
|
bool grad_flag_{false};
|
||||||
bool in_bprop_process_{false};
|
bool in_bprop_process_{false};
|
||||||
bool in_grad_process_{false};
|
bool in_grad_process_{false};
|
||||||
|
int custom_bprop_cell_count_{0};
|
||||||
bool grad_is_running_{false};
|
bool grad_is_running_{false};
|
||||||
|
|
||||||
FuncGraphPtr curr_g_{nullptr};
|
FuncGraphPtr curr_g_{nullptr};
|
||||||
|
@ -287,6 +290,8 @@ class ForwardExecutor {
|
||||||
void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
|
void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
|
||||||
std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
|
std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
|
||||||
void ClearRes();
|
void ClearRes();
|
||||||
|
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
||||||
|
abstract::AbstractBasePtrList *args_spec_list);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GradExecutorPtr grad() const;
|
GradExecutorPtr grad() const;
|
||||||
|
@ -296,8 +301,6 @@ class ForwardExecutor {
|
||||||
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
|
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
|
||||||
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
|
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
|
||||||
PynativeStatusCode *status);
|
PynativeStatusCode *status);
|
||||||
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
|
||||||
abstract::AbstractBasePtrList *args_spec_list);
|
|
||||||
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
|
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
|
||||||
abstract::AbstractBasePtrList *args_spec_list);
|
abstract::AbstractBasePtrList *args_spec_list);
|
||||||
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
|
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
|
||||||
|
|
|
@ -288,6 +288,9 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
|
||||||
auto primitive_py = primitive->cast<PrimitivePyPtr>();
|
auto primitive_py = primitive->cast<PrimitivePyPtr>();
|
||||||
MS_EXCEPTION_IF_NULL(primitive_py);
|
MS_EXCEPTION_IF_NULL(primitive_py);
|
||||||
this->set_hook(primitive_py->hook());
|
this->set_hook(primitive_py->hook());
|
||||||
|
if (primitive_py->HasAttr(kBpropAttrName)) {
|
||||||
|
this->AddAttr(kBpropAttrName, primitive_py->GetAttr(kBpropAttrName));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
|
BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
|
||||||
|
|
Loading…
Reference in New Issue