!46609 Fix function grad resut is not right

Merge pull request !46609 from zjun/fix_bug4_alpha
This commit is contained in:
i-robot 2022-12-09 03:15:35 +00:00 committed by Gitee
commit 260f492b02
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 14 additions and 2 deletions

View File

@ -740,9 +740,12 @@ void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info
auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id); auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id);
MS_EXCEPTION_IF_NULL(pre_top_cell); MS_EXCEPTION_IF_NULL(pre_top_cell);
if (input_args_info->use_dynamic_shape_process) { // In high order situations, the internal top cell has changed, but outer top cell remains unchanged. Then outer
// bprop graph need compile again
if (input_args_info->use_dynamic_shape_process || new_top_cell->force_top_cell_compile()) {
// Function need compile every time. // Function need compile every time.
MS_LOG(DEBUG) << "The graph is dynamic, need to compile graph again"; input_args_info->use_dynamic_shape_process ? MS_LOG(DEBUG) << "The graph is dynamic, need to compile graph again"
: MS_LOG(DEBUG) << "Force outer graph compile graph";
{ {
py::gil_scoped_acquire acquire; py::gil_scoped_acquire acquire;
EraseTopCellFromTopCellList(pre_top_cell); EraseTopCellFromTopCellList(pre_top_cell);
@ -1219,6 +1222,10 @@ void GradExecutor::SwitchTopCell() {
// Get outer top cell // Get outer top cell
auto outer_top_cell = PopHighOrderGraphStack(); auto outer_top_cell = PopHighOrderGraphStack();
MS_EXCEPTION_IF_NULL(outer_top_cell); MS_EXCEPTION_IF_NULL(outer_top_cell);
// If inner graph compile graph, outer must be compile
if (top_cell()->vm_compile()) {
outer_top_cell->set_force_top_cell_compile(true);
}
set_top_cell(outer_top_cell); set_top_cell(outer_top_cell);
} }

View File

@ -82,6 +82,10 @@ class TopCellInfo {
inline bool need_compile_graph() const { return need_compile_graph_; } inline bool need_compile_graph() const { return need_compile_graph_; }
inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; } inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
inline bool vm_compile() const { return vm_compile_; } inline bool vm_compile() const { return vm_compile_; }
inline void set_force_top_cell_compile(bool force_top_cell_compile) {
force_top_cell_compile_ = force_top_cell_compile;
}
inline bool force_top_cell_compile() const { return force_top_cell_compile_; }
inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; } inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; } inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
inline bool need_do_final_opt() const { return need_do_final_opt_; } inline bool need_do_final_opt() const { return need_do_final_opt_; }
@ -155,6 +159,7 @@ class TopCellInfo {
bool forward_already_run_{false}; bool forward_already_run_{false};
bool need_compile_graph_{false}; bool need_compile_graph_{false};
bool vm_compile_{false}; bool vm_compile_{false};
bool force_top_cell_compile_{false};
bool is_high_order_top_cell_{false}; bool is_high_order_top_cell_{false};
bool need_do_final_opt_{false}; bool need_do_final_opt_{false};
size_t op_index_{0}; size_t op_index_{0};