!21740 Add the cache tag in the case of dynamic images

Merge pull request !21740 from zjun/Add_Dynamic_Graph_Flag
This commit is contained in:
i-robot 2021-08-13 08:32:06 +00:00 committed by Gitee
commit 9490f835c9
2 changed files with 29 additions and 0 deletions

View File

@ -81,6 +81,7 @@ std::mutex PynativeExecutor::instance_lock_;
namespace {
const size_t PTR_LEN = 15;
const size_t ARG_SIZE = 2;
const size_t MAX_TOP_CELL_COUNTS = 20;
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> kVmOperators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
@ -701,6 +702,12 @@ py::object GetDstType(const TypeId &type_id) {
MS_EXCEPTION_IF_NULL(value);
return py::cast(value);
}
void EnableGraphCache(bool flag) {
const auto inst = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(inst);
inst->set_param<bool>(MS_CTX_ENABLE_GRAD_CACHE, flag);
}
} // namespace
py::object RealRunOp(const py::args &args) {
@ -984,6 +991,8 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
if (shape->IsDynamic()) {
op_exec_info->is_dynamic_shape = true;
// Dynamic shape operator in the current top cell, disable backend cache
EnableGraphCache(false);
}
}
@ -2057,6 +2066,17 @@ void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args,
if (grad_order_ == 0) {
++grad_order_;
}
// The number of top cell exceeds MAX_TOP_CELL_COUNTS, delete the last one to keep the maximum length of the list,
// disable backend cache
if (top_cell_list_.size() >= MAX_TOP_CELL_COUNTS) {
EnableGraphCache(false);
const auto last_top_cell = top_cell_list_.back();
top_cell_list_.pop_back();
last_top_cell->Clear();
if (already_run_top_cell_.find(last_top_cell->cell_id()) != already_run_top_cell_.end()) {
(void)already_run_top_cell_.erase(last_top_cell->cell_id());
}
}
// Create top cell
curr_g_ = std::make_shared<FuncGraph>();
auto df_builder = std::make_shared<FuncGraph>();
@ -2531,6 +2551,13 @@ void GradExecutor::CheckNeedCompileGraph() {
MS_LOG(DEBUG) << "New all op info : " << new_all_op_info;
if (pre_all_op_info != new_all_op_info) {
MS_LOG(DEBUG) << "The op info has been changed, need to compile graph again";
// The top cell switches exceeds MAX_TOP_CELL_COUNTS under the control flow, disable backend cache
if (top_cell_switch_counts_ >= MAX_TOP_CELL_COUNTS) {
EnableGraphCache(false);
} else {
// Increase top cell switches counts
++top_cell_switch_counts_;
}
EraseTopCellFromTopCellList(pre_top_cell);
pre_top_cell->Clear();
already_run_top_cell_[top_cell_id] = new_top_cell;
@ -2782,6 +2809,7 @@ void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) {
void GradExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear grad res";
grad_order_ = 0;
top_cell_switch_counts_ = 0;
grad_flag_ = false;
need_renormalize_ = false;
grad_is_running_ = false;

View File

@ -261,6 +261,7 @@ class GradExecutor {
bool grad_is_running_{false};
int custom_bprop_cell_count_{0};
size_t grad_order_{0};
size_t top_cell_switch_counts_{0};
// The graph phase is used to obtain backend graph that is complied by ms_function
std::string graph_phase_;