!17596 Reset op info in op_exec_Info when it be cached

From: @joylvliang
Reviewed-by: @ginfung,@chujinjin,@ginfung
Signed-off-by: @chujinjin
This commit is contained in:
mindspore-ci-bot 2021-06-04 16:06:37 +08:00 committed by Gitee
commit 4db1ad903b
2 changed files with 4 additions and 5 deletions

View File

@ -1293,6 +1293,7 @@ void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info) {
input_args_info += "d";
}
// Record op name and index
op_exec_info->op_info.clear();
const auto &curr_op_num = top_cell()->op_num();
op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
top_cell()->all_op_info() += "_" + op_exec_info->op_info;
@ -1955,11 +1956,11 @@ void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const
if (!pre_top_cell->is_dynamic()) {
MS_LOG(DEBUG) << "Top cell " << cell_id << " is not dynamic, no need to run NewGraphInner again";
ResetTopCellInfo(pre_top_cell, args);
PushHighOrderGraphStack(pre_top_cell);
set_top_cell(pre_top_cell);
cached_top_cell_forward_running_ = true;
return;
}
} else if (top_cell()->IsSubCell(cell_id) || cached_top_cell_forward_running_) {
} else if (top_cell()->IsSubCell(cell_id) || GetHighOrderStackSize() >= 1) {
// Sub cell (may be a temporary cell) forward run in cache process.
MS_LOG(DEBUG) << "No need to run NewGraphInner again";
return;
@ -2092,8 +2093,8 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
if (cell_stack_.empty()) {
MS_LOG(DEBUG) << "Current cell " << cell_id << " no need to run EndGraphInner again";
if (top_cell()->is_topest() && cell_id == top_cell()->cell_id()) {
PopHighOrderGraphStack();
set_grad_flag(false);
cached_top_cell_forward_running_ = false;
}
return;
}
@ -2664,7 +2665,6 @@ void GradExecutor::ClearRes() {
grad_flag_ = false;
need_renormalize_ = false;
grad_is_running_ = false;
cached_top_cell_forward_running_ = false;
top_cell_ = nullptr;
curr_g_ = nullptr;
bprop_cell_list_.clear();

View File

@ -256,7 +256,6 @@ class GradExecutor {
bool grad_flag_{false};
bool need_renormalize_{false};
bool grad_is_running_{false};
bool cached_top_cell_forward_running_{false};
int custom_bprop_cell_count_{0};
size_t grad_order_{0};