!37089 init all the opt states in ms_funciton in case it is executed in graph mode

Merge pull request !37089 from wangjun/bug_fix_opt_ms_master
This commit is contained in:
i-robot 2022-07-05 06:09:10 +00:00 committed by Gitee
commit 7f34cdd055
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 8 additions and 0 deletions

View File

@ -250,6 +250,14 @@ class _MindsporeFunctionExecutor:
layout = obj.parameter_layout_dict[param.name]
new_tensor = _load_tensor_by_layout(state.data, layout)
state.set_data(new_tensor, True)
# set data for all optimizer states in case it is executed in graph mode
prefix_list = ["moments", "accum", "moment1", "moment2", "lamb_m", "lamb_v", "mean_grad",
"mean_square", "prev"]
for opt_param in opt_params:
prefix = opt_param.name[:opt_param.name.find(".")]
if opt_param.has_init and (prefix in prefix_list or opt_param.name == "global_step"):
opt_param.init_data()
_pynative_executor.get_top_cell().parameter_layout_dict = obj.parameter_layout_dict
def compile(self, args_list, method_name):