forked from mindspore-Ecosystem/mindspore
!31110 Remove hook changed check in grad operation
Merge pull request !31110 from JoyLvliang/remove_hook_changed_check_in_grad_operation
This commit is contained in:
commit
4576a21f59
|
@ -950,6 +950,19 @@ bool TopCellInfo::IsSubCell(const std::string &cell_id) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void TopCellInfo::CheckSubCellHookChanged() {
|
||||
if (!hook_changed_) {
|
||||
for (const auto &sub_cell : sub_cell_list_) {
|
||||
const auto sub_cell_id = sub_cell.substr(0, sub_cell.find('_'));
|
||||
if (sub_cell_hook_changed_.find(sub_cell_id) != sub_cell_hook_changed_.end()) {
|
||||
hook_changed_ = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
sub_cell_hook_changed_.clear();
|
||||
}
|
||||
|
||||
void TopCellInfo::ClearDeviceMemory() {
|
||||
MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
|
@ -1738,6 +1751,9 @@ void GradExecutor::SetHookChanged(const py::object &cell) {
|
|||
}
|
||||
}
|
||||
}
|
||||
if (need_construct_graph() && top_cell_ != nullptr) {
|
||||
top_cell_->set_sub_cell_hook_changed(cell_id);
|
||||
}
|
||||
}
|
||||
|
||||
void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info) {
|
||||
|
@ -2591,6 +2607,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
|
|||
if (!cell_stack_.empty()) {
|
||||
(void)GetObjNode(out, GetId(out));
|
||||
}
|
||||
top_cell()->CheckSubCellHookChanged();
|
||||
CheckNeedCompileGraph();
|
||||
}
|
||||
}
|
||||
|
@ -2965,7 +2982,7 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
|
|||
auto find_top_cell = GetTopCell(check_already_run_cell_id);
|
||||
if (find_top_cell != nullptr) {
|
||||
MS_LOG(DEBUG) << "Find already run top cell";
|
||||
forward_run = find_top_cell->forward_already_run() && !find_top_cell->hook_changed();
|
||||
forward_run = find_top_cell->forward_already_run();
|
||||
auto curr_top_cell = top_cell();
|
||||
set_top_cell(find_top_cell);
|
||||
bool input_args_changed =
|
||||
|
|
|
@ -80,6 +80,7 @@ class TopCellInfo {
|
|||
void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }
|
||||
bool hook_changed() const { return hook_changed_; }
|
||||
void set_hook_changed(bool hook_changed) { hook_changed_ = hook_changed; }
|
||||
void set_sub_cell_hook_changed(const std::string &sub_cell) { sub_cell_hook_changed_.emplace(sub_cell); }
|
||||
bool vm_compiled() const { return vm_compiled_; }
|
||||
void set_vm_compiled(bool vm_compiled) { vm_compiled_ = vm_compiled; }
|
||||
bool ms_function_flag() const { return ms_function_flag_; }
|
||||
|
@ -104,6 +105,7 @@ class TopCellInfo {
|
|||
mindspore::HashSet<std::string> &sub_cell_list() { return sub_cell_list_; }
|
||||
std::set<std::string> &forward_op_output_id() { return forward_op_output_id_; }
|
||||
bool IsSubCell(const std::string &cell_id) const;
|
||||
void CheckSubCellHookChanged();
|
||||
OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; }
|
||||
OpInfoWithTensorId &op_info_with_tensor_id() { return op_info_with_tensor_id_; }
|
||||
TensorIdWithTensorObject &tensor_id_with_tensor_object() { return tensor_id_with_tensor_object_; }
|
||||
|
@ -143,6 +145,9 @@ class TopCellInfo {
|
|||
std::string grad_operation_;
|
||||
OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
|
||||
mindspore::HashSet<std::string> sub_cell_list_;
|
||||
// Record `register hook` or `remove hook` function has been called by sub cell
|
||||
// The record range between the begin and end of top cell.
|
||||
mindspore::HashSet<std::string> sub_cell_hook_changed_;
|
||||
// Record forward output tensor id
|
||||
std::set<std::string> forward_op_output_id_;
|
||||
OpInfoWithTensorId op_info_with_tensor_id_;
|
||||
|
|
Loading…
Reference in New Issue