Fix bprop dynamic

Signed-off-by: zjun <zhangjun0@huawei.com>
This commit is contained in:
zjun 2020-12-19 14:43:18 +08:00
parent 1f9ff76371
commit 942b6928ab
2 changed files with 33 additions and 55 deletions

View File

@ -1467,13 +1467,16 @@ bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad)
});
}
void PynativeExecutor::ClearResidualRes() {
void PynativeExecutor::ClearResidualRes(const std::string &cell_id) {
if (top_cell_list_.empty() && !graph_stack_.empty()) {
graph_id_ = 0;
graph_info_map_.clear();
cell_sw_map_.clear();
cell_graph_list_.clear();
top_cell_list_.clear();
std::stack<FuncGraphPtr>().swap(graph_stack_);
}
if (dynamic_cell_) {
VectorClear<std::vector<TopCellInfo>>(&top_cell_list_, cell_id);
}
}
@ -1486,8 +1489,8 @@ FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
if (grad_order_ == 0 || grad_order_ == 1) {
return top_cell_list_.back().df_builder;
}
if (top_cell_list_.size() < grad_order_) {
MS_LOG(EXCEPTION) << "Get wrong grad order";
if (top_cell_list_.size() < 2) {
MS_LOG(EXCEPTION) << "Top cell list size must greater than 2";
}
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size();
// Grad order greater than 2
@ -1517,8 +1520,8 @@ ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) {
if (grad_order_ == 0 || grad_order_ == 1) {
return top_cell_list_.back().resource;
}
if (top_cell_list_.size() < grad_order_) {
MS_LOG(EXCEPTION) << "Get wrong grad order";
if (top_cell_list_.size() < 2) {
MS_LOG(EXCEPTION) << "Top cell list size must greater than 2";
}
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size();
// Grad order greater than 2
@ -1718,7 +1721,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
// init resource for constructing forward graph and grad graph
auto g = std::make_shared<FuncGraph>();
curr_g_ = g;
ClearResidualRes();
ClearResidualRes(cell_id);
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) {
MakeNewTopGraph(cell_id, args, g);
}
@ -2030,10 +2033,6 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
// Set all params(input+weights)
SetGradGraphParams(df_builder, resource, size);
// Clone df_builder and resource at first time
if (CloneDfbuiler(cell_id, df_builder, resource)) {
df_builder = GetDfbuilder(cell_id);
}
// Get params(weights) require derivative
auto w_args = GetWeightsArgs(weights, df_builder);
// Get the parameters items and add the value to args_spec
@ -2239,26 +2238,6 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
return args_spec;
}
bool PynativeExecutor::CloneDfbuiler(const std::string &cell_id, const FuncGraphPtr &df_builder,
const ResourcePtr &resource) {
bool is_cloned = false;
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
if (it == top_cell_list_.end()) {
MS_LOG(EXCEPTION) << "Get top cell failed";
}
if (it->bg == nullptr) {
auto cloned_df_newfg = BasicClone(resource->func_graph());
it->bg = cloned_df_newfg;
MS_LOG(DEBUG) << "Cloned df newfg";
is_cloned = false;
} else {
resource->set_func_graph(it->bg);
MS_LOG(DEBUG) << "Used cloned df newfg";
}
return is_cloned;
}
void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op,
const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) {
FuncGraphPtr top_g = nullptr;
@ -2433,28 +2412,6 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
}
template <typename T>
void MapClear(T *map, const std::string &cell_id) {
for (auto it = map->begin(); it != map->end();) {
if (it->first.find(cell_id) != std::string::npos) {
it = map->erase(it);
} else {
it++;
}
}
}
template <typename T>
void VectorClear(T *vec, const std::string &cell_id) {
for (auto it = vec->begin(); it != vec->end();) {
if (it->cell_id.find(cell_id) != std::string::npos) {
it = vec->erase(it);
} else {
it++;
}
}
}
void PynativeExecutor::Clear(const std::string &cell_id) {
if (cell_id.empty()) {
Clean();

View File

@ -128,6 +128,28 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
private:
PynativeExecutor() = default;
template <typename T>
void MapClear(T *map, const std::string &cell_id) {
for (auto it = map->begin(); it != map->end();) {
if (it->first.find(cell_id) != std::string::npos) {
it = map->erase(it);
} else {
it++;
}
}
}
template <typename T>
void VectorClear(T *vec, const std::string &cell_id) {
for (auto it = vec->begin(); it != vec->end();) {
if (it->cell_id.find(cell_id) != std::string::npos) {
it = vec->erase(it);
} else {
it++;
}
}
}
// Check cell struct
bool IsDynamicCell(const py::object &cell);
std::string GetCellInfo(const py::object &cell);
@ -187,7 +209,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned = false, bool is_grad = false);
void ClearResidualRes();
void ClearResidualRes(const std::string &cell_id);
void NewGraphInner(const py::object &cell, const py::args &args);
void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
@ -203,7 +225,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std::string GetCellId(const py::object &obj, const py::args &args);
std::pair<bool, bool> CheckCellChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
bool CloneDfbuiler(const std::string &cell_id, const FuncGraphPtr &df_builder, const ResourcePtr &resource);
void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
size_t arg_size, const std::string &cell_id);
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);