forked from mindspore-Ecosystem/mindspore
!21740 Add the cache tag in the case of dynamic images
Merge pull request !21740 from zjun/Add_Dynamic_Graph_Flag
This commit is contained in:
commit
9490f835c9
|
@ -81,6 +81,7 @@ std::mutex PynativeExecutor::instance_lock_;
|
|||
namespace {
|
||||
const size_t PTR_LEN = 15;
|
||||
const size_t ARG_SIZE = 2;
|
||||
const size_t MAX_TOP_CELL_COUNTS = 20;
|
||||
|
||||
// primitive unable to infer value for constant input in PyNative mode
|
||||
const std::set<std::string> kVmOperators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
|
||||
|
@ -701,6 +702,12 @@ py::object GetDstType(const TypeId &type_id) {
|
|||
MS_EXCEPTION_IF_NULL(value);
|
||||
return py::cast(value);
|
||||
}
|
||||
|
||||
void EnableGraphCache(bool flag) {
|
||||
const auto inst = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(inst);
|
||||
inst->set_param<bool>(MS_CTX_ENABLE_GRAD_CACHE, flag);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
py::object RealRunOp(const py::args &args) {
|
||||
|
@ -984,6 +991,8 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
|
|||
|
||||
if (shape->IsDynamic()) {
|
||||
op_exec_info->is_dynamic_shape = true;
|
||||
// Dynamic shape operator in the current top cell, disable backend cache
|
||||
EnableGraphCache(false);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2057,6 +2066,17 @@ void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args,
|
|||
if (grad_order_ == 0) {
|
||||
++grad_order_;
|
||||
}
|
||||
// The number of top cell exceeds MAX_TOP_CELL_COUNTS, delete the last one to keep the maximum length of the list,
|
||||
// disable backend cache
|
||||
if (top_cell_list_.size() >= MAX_TOP_CELL_COUNTS) {
|
||||
EnableGraphCache(false);
|
||||
const auto last_top_cell = top_cell_list_.back();
|
||||
top_cell_list_.pop_back();
|
||||
last_top_cell->Clear();
|
||||
if (already_run_top_cell_.find(last_top_cell->cell_id()) != already_run_top_cell_.end()) {
|
||||
(void)already_run_top_cell_.erase(last_top_cell->cell_id());
|
||||
}
|
||||
}
|
||||
// Create top cell
|
||||
curr_g_ = std::make_shared<FuncGraph>();
|
||||
auto df_builder = std::make_shared<FuncGraph>();
|
||||
|
@ -2531,6 +2551,13 @@ void GradExecutor::CheckNeedCompileGraph() {
|
|||
MS_LOG(DEBUG) << "New all op info : " << new_all_op_info;
|
||||
if (pre_all_op_info != new_all_op_info) {
|
||||
MS_LOG(DEBUG) << "The op info has been changed, need to compile graph again";
|
||||
// The top cell switches exceeds MAX_TOP_CELL_COUNTS under the control flow, disable backend cache
|
||||
if (top_cell_switch_counts_ >= MAX_TOP_CELL_COUNTS) {
|
||||
EnableGraphCache(false);
|
||||
} else {
|
||||
// Increase top cell switches counts
|
||||
++top_cell_switch_counts_;
|
||||
}
|
||||
EraseTopCellFromTopCellList(pre_top_cell);
|
||||
pre_top_cell->Clear();
|
||||
already_run_top_cell_[top_cell_id] = new_top_cell;
|
||||
|
@ -2782,6 +2809,7 @@ void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) {
|
|||
void GradExecutor::ClearRes() {
|
||||
MS_LOG(DEBUG) << "Clear grad res";
|
||||
grad_order_ = 0;
|
||||
top_cell_switch_counts_ = 0;
|
||||
grad_flag_ = false;
|
||||
need_renormalize_ = false;
|
||||
grad_is_running_ = false;
|
||||
|
|
|
@ -261,6 +261,7 @@ class GradExecutor {
|
|||
bool grad_is_running_{false};
|
||||
int custom_bprop_cell_count_{0};
|
||||
size_t grad_order_{0};
|
||||
size_t top_cell_switch_counts_{0};
|
||||
|
||||
// The graph phase is used to obtain backend graph that is complied by ms_function
|
||||
std::string graph_phase_;
|
||||
|
|
Loading…
Reference in New Issue