forked from mindspore-Ecosystem/mindspore
Fix bprop dynamic
Signed-off-by: zjun <zhangjun0@huawei.com>
This commit is contained in:
parent
1f9ff76371
commit
942b6928ab
|
@ -1467,13 +1467,16 @@ bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad)
|
|||
});
|
||||
}
|
||||
|
||||
void PynativeExecutor::ClearResidualRes() {
|
||||
void PynativeExecutor::ClearResidualRes(const std::string &cell_id) {
|
||||
if (top_cell_list_.empty() && !graph_stack_.empty()) {
|
||||
graph_id_ = 0;
|
||||
graph_info_map_.clear();
|
||||
cell_sw_map_.clear();
|
||||
cell_graph_list_.clear();
|
||||
top_cell_list_.clear();
|
||||
std::stack<FuncGraphPtr>().swap(graph_stack_);
|
||||
}
|
||||
if (dynamic_cell_) {
|
||||
VectorClear<std::vector<TopCellInfo>>(&top_cell_list_, cell_id);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1486,8 +1489,8 @@ FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
|
|||
if (grad_order_ == 0 || grad_order_ == 1) {
|
||||
return top_cell_list_.back().df_builder;
|
||||
}
|
||||
if (top_cell_list_.size() < grad_order_) {
|
||||
MS_LOG(EXCEPTION) << "Get wrong grad order";
|
||||
if (top_cell_list_.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Top cell list size must greater than 2";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size();
|
||||
// Grad order greater than 2
|
||||
|
@ -1517,8 +1520,8 @@ ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) {
|
|||
if (grad_order_ == 0 || grad_order_ == 1) {
|
||||
return top_cell_list_.back().resource;
|
||||
}
|
||||
if (top_cell_list_.size() < grad_order_) {
|
||||
MS_LOG(EXCEPTION) << "Get wrong grad order";
|
||||
if (top_cell_list_.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Top cell list size must greater than 2";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size();
|
||||
// Grad order greater than 2
|
||||
|
@ -1718,7 +1721,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
|
|||
// init resource for constructing forward graph and grad graph
|
||||
auto g = std::make_shared<FuncGraph>();
|
||||
curr_g_ = g;
|
||||
ClearResidualRes();
|
||||
ClearResidualRes(cell_id);
|
||||
if (graph_stack_.empty() && !IsBpropGraph(cell_id)) {
|
||||
MakeNewTopGraph(cell_id, args, g);
|
||||
}
|
||||
|
@ -2030,10 +2033,6 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
|
|||
|
||||
// Set all params(input+weights)
|
||||
SetGradGraphParams(df_builder, resource, size);
|
||||
// Clone df_builder and resource at first time
|
||||
if (CloneDfbuiler(cell_id, df_builder, resource)) {
|
||||
df_builder = GetDfbuilder(cell_id);
|
||||
}
|
||||
// Get params(weights) require derivative
|
||||
auto w_args = GetWeightsArgs(weights, df_builder);
|
||||
// Get the parameters items and add the value to args_spec
|
||||
|
@ -2239,26 +2238,6 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
|
|||
return args_spec;
|
||||
}
|
||||
|
||||
bool PynativeExecutor::CloneDfbuiler(const std::string &cell_id, const FuncGraphPtr &df_builder,
|
||||
const ResourcePtr &resource) {
|
||||
bool is_cloned = false;
|
||||
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
|
||||
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
|
||||
if (it == top_cell_list_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Get top cell failed";
|
||||
}
|
||||
if (it->bg == nullptr) {
|
||||
auto cloned_df_newfg = BasicClone(resource->func_graph());
|
||||
it->bg = cloned_df_newfg;
|
||||
MS_LOG(DEBUG) << "Cloned df newfg";
|
||||
is_cloned = false;
|
||||
} else {
|
||||
resource->set_func_graph(it->bg);
|
||||
MS_LOG(DEBUG) << "Used cloned df newfg";
|
||||
}
|
||||
return is_cloned;
|
||||
}
|
||||
|
||||
void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op,
|
||||
const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) {
|
||||
FuncGraphPtr top_g = nullptr;
|
||||
|
@ -2433,28 +2412,6 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
|
|||
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MapClear(T *map, const std::string &cell_id) {
|
||||
for (auto it = map->begin(); it != map->end();) {
|
||||
if (it->first.find(cell_id) != std::string::npos) {
|
||||
it = map->erase(it);
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void VectorClear(T *vec, const std::string &cell_id) {
|
||||
for (auto it = vec->begin(); it != vec->end();) {
|
||||
if (it->cell_id.find(cell_id) != std::string::npos) {
|
||||
it = vec->erase(it);
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PynativeExecutor::Clear(const std::string &cell_id) {
|
||||
if (cell_id.empty()) {
|
||||
Clean();
|
||||
|
|
|
@ -128,6 +128,28 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
private:
|
||||
PynativeExecutor() = default;
|
||||
|
||||
template <typename T>
|
||||
void MapClear(T *map, const std::string &cell_id) {
|
||||
for (auto it = map->begin(); it != map->end();) {
|
||||
if (it->first.find(cell_id) != std::string::npos) {
|
||||
it = map->erase(it);
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void VectorClear(T *vec, const std::string &cell_id) {
|
||||
for (auto it = vec->begin(); it != vec->end();) {
|
||||
if (it->cell_id.find(cell_id) != std::string::npos) {
|
||||
it = vec->erase(it);
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check cell struct
|
||||
bool IsDynamicCell(const py::object &cell);
|
||||
std::string GetCellInfo(const py::object &cell);
|
||||
|
@ -187,7 +209,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
|
||||
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
|
||||
bool need_cloned = false, bool is_grad = false);
|
||||
void ClearResidualRes();
|
||||
void ClearResidualRes(const std::string &cell_id);
|
||||
void NewGraphInner(const py::object &cell, const py::args &args);
|
||||
void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
|
||||
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
|
||||
|
@ -203,7 +225,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
std::string GetCellId(const py::object &obj, const py::args &args);
|
||||
std::pair<bool, bool> CheckCellChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
|
||||
void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
|
||||
bool CloneDfbuiler(const std::string &cell_id, const FuncGraphPtr &df_builder, const ResourcePtr &resource);
|
||||
void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
|
||||
size_t arg_size, const std::string &cell_id);
|
||||
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
|
||||
|
|
Loading…
Reference in New Issue