From e958a8628cf20028dee3e9012a72a2220baf1dfb Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Fri, 16 Sep 2022 11:36:58 +0800
Subject: [PATCH] code clean python 220916

---
 .../parallel/_parallel_serialization.py | 55 ++++++++++++-------
 .../parallel/checkpoint_transform.py    |  4 +-
 2 files changed, 39 insertions(+), 20 deletions(-)

diff --git a/mindspore/python/mindspore/parallel/_parallel_serialization.py b/mindspore/python/mindspore/parallel/_parallel_serialization.py
index 39af0cc2893..cdca875643d 100644
--- a/mindspore/python/mindspore/parallel/_parallel_serialization.py
+++ b/mindspore/python/mindspore/parallel/_parallel_serialization.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """parallel serialization"""
+from __future__ import absolute_import
+
 import os
 import numpy as np
 import mindspore as ms
@@ -116,8 +118,7 @@ def _restore_group_info_list(group_info_file_name):
     if not restore_list:
         raise ValueError("For 'restore_group_info_list', the group information file has no restore rank list.")
 
-    restore_rank_list = [rank for rank in restore_list.dim]
-    return restore_rank_list
+    return [rank for rank in restore_list.dim]
 
 
 def _get_device_num_from_strategy(strategy_file=None):
@@ -144,7 +145,7 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_file=None
     dst_strategy_list = _convert_to_list(dst_strategy)
     result_list = set()
     handled_layout = []
-    for param_name, src_strategy in src_strategy_list.items():
+    for param_name, _ in src_strategy_list.items():
         if dst_strategy_file is not None and param_name not in dst_strategy_list:
             continue
         from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
@@ -247,23 +248,11 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
             raise ValueError("The checkpoint of rank {} is missing.".format(rank_id % device_num))
         param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                             device_list, rank_id)
-        for param_rank, _ in param_rank_map.items():
-            if from_opt_shard_size > 0:
-                from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
-                from_slice_tensor_shape = ()
-                for i, item in enumerate(from_full_tensor_shape):
-                    from_slice_tensor_shape += (item // from_tensor_strategy[i],)
-                param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
-            if to_opt_shard_size > 0:
-                to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
-                to_slice_tensor_shape = ()
-                for i, item in enumerate(origin_tensor_shape):
-                    if i == 0 and to_opt_shard_size > 0:
-                        to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
-                        continue
-                    to_slice_tensor_shape += (item // to_tensor_strategy[i],)
-                param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
+
+        from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
+        to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
+        _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
         transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
         _apply_tensor_transform_operators(transform_operator_stack, param_total_dict[param_name], device_num)
         transform_tensor = ms.Tensor(param_total_dict[param_name][rank_id % device_num])
@@ -320,3 +309,31 @@ def _make_dir(path, arg_name):
         finally:
             pass
     return real_path
+
+
+def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
+    """insert opt_shard op reshape"""
+    from_opt_shard_size = from_info_tuple[0]
+    from_dev_matrix = from_info_tuple[1]
+    from_tensor_map = from_info_tuple[2]
+    from_full_tensor_shape = from_info_tuple[3]
+    to_opt_shard_size = to_info_tuple[0]
+    to_dev_matrix_origin = to_info_tuple[1]
+    to_tensor_map_origin = to_info_tuple[2]
+    origin_tensor_shape = to_info_tuple[3]
+    for param_rank, _ in param_rank_map.items():
+        if from_opt_shard_size > 0:
+            from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+            from_slice_tensor_shape = ()
+            for i, item in enumerate(from_full_tensor_shape):
+                from_slice_tensor_shape += (item // from_tensor_strategy[i],)
+            param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
+        if to_opt_shard_size > 0:
+            to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
+            to_slice_tensor_shape = ()
+            for i, item in enumerate(origin_tensor_shape):
+                if i == 0 and to_opt_shard_size > 0:
+                    to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
+                    continue
+                to_slice_tensor_shape += (item // to_tensor_strategy[i],)
+            param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
diff --git a/mindspore/python/mindspore/parallel/checkpoint_transform.py b/mindspore/python/mindspore/parallel/checkpoint_transform.py
index 288f7f13934..a6eaf2bd145 100644
--- a/mindspore/python/mindspore/parallel/checkpoint_transform.py
+++ b/mindspore/python/mindspore/parallel/checkpoint_transform.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """Transform distributed checkpoint"""
+from __future__ import absolute_import
+
 import os
 import glob
 import copy
@@ -203,7 +205,7 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
             param_total_dict_copy = copy.deepcopy(param_total_dict)
             transform_param_list = _transform_parallel_checkpoint(transform_rank, param_total_dict_copy,
                                                                   param_attr_dict, src_strategy_file, dst_strategy_file)
-            save_checkpoint_file = ckpt_prefix + str(transform_rank) + ".ckpt"
+            save_checkpoint_file = "{}{}.ckpt".format(ckpt_prefix, transform_rank)
             save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank))
             if not os.path.exists(save_checkpoint_file_dir):
                 _make_dir(save_checkpoint_file_dir, "path")
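
For reviewers unfamiliar with the opt_shard handling that this patch factors out into _insert_opt_shard_reshape, the sketch below shows the slice-shape arithmetic the helper performs and how the public entry point in checkpoint_transform.py is typically driven. It is not part of the patch: the shapes, strategy values, directory paths, and prefix are made up for illustration; only transform_checkpoints, its parameter names, and the module path come from the patched code.

# Minimal sketch (illustration only, not part of the patch).
from mindspore.parallel.checkpoint_transform import transform_checkpoints

# Slice-shape arithmetic mirrored from _insert_opt_shard_reshape: a (96, 32)
# parameter split by a tensor strategy of (2, 4), with optimizer sharding of
# size 2 applied on dimension 0, yields a per-rank slice shape of (24, 8),
# which is the shape the inserted 'Reshape' operator restores.
full_shape = (96, 32)
tensor_strategy = (2, 4)
opt_shard_size = 2
slice_shape = (full_shape[0] // (tensor_strategy[0] * opt_shard_size),
               full_shape[1] // tensor_strategy[1])
assert slice_shape == (24, 8)

# Hypothetical invocation: read per-rank checkpoints saved under the source
# sharding strategy and write checkpoint_<rank>.ckpt files into rank_<rank>/
# folders under the destination directory, resharded to the target strategy.
transform_checkpoints(src_checkpoints_dir="/data/src_checkpoints",
                      dst_checkpoints_dir="/data/dst_checkpoints",
                      ckpt_prefix="checkpoint_",
                      src_strategy_file="/data/src_strategy.ckpt",
                      dst_strategy_file="/data/dst_strategy.ckpt")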