code clean python 220916

yao_yf 2022-09-16 11:36:58 +08:00
parent f0ddbe1e01
commit e958a8628c
2 changed files with 39 additions and 20 deletions

View File

@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""parallel serialization"""
from __future__ import absolute_import
import os
import numpy as np
import mindspore as ms
@@ -116,8 +118,7 @@ def _restore_group_info_list(group_info_file_name):
    if not restore_list:
        raise ValueError("For 'restore_group_info_list', the group information file has no restore rank list.")
    restore_rank_list = [rank for rank in restore_list.dim]
    return restore_rank_list
    return [rank for rank in restore_list.dim]
def _get_device_num_from_strategy(strategy_file=None):
@@ -144,7 +145,7 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_file=None
    dst_strategy_list = _convert_to_list(dst_strategy)
    result_list = set()
    handled_layout = []
    for param_name, src_strategy in src_strategy_list.items():
    for param_name, _ in src_strategy_list.items():
        if dst_strategy_file is not None and param_name not in dst_strategy_list:
            continue
        from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(
@@ -247,23 +248,11 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
raise ValueError("The checkpoint of rank {} is missing.".format(rank_id % device_num))
param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
device_list, rank_id)
for param_rank, _ in param_rank_map.items():
if from_opt_shard_size > 0:
from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
from_slice_tensor_shape = ()
for i, item in enumerate(from_full_tensor_shape):
from_slice_tensor_shape += (item // from_tensor_strategy[i],)
param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
if to_opt_shard_size > 0:
to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
to_slice_tensor_shape = ()
for i, item in enumerate(origin_tensor_shape):
if i == 0 and to_opt_shard_size > 0:
to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
continue
to_slice_tensor_shape += (item // to_tensor_strategy[i],)
param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
_insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
_apply_tensor_transform_operators(transform_operator_stack, param_total_dict[param_name], device_num)
transform_tensor = ms.Tensor(param_total_dict[param_name][rank_id % device_num])
@@ -320,3 +309,31 @@ def _make_dir(path, arg_name):
    finally:
        pass
    return real_path


def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
    """insert opt_shard op reshape"""
    from_opt_shard_size = from_info_tuple[0]
    from_dev_matrix = from_info_tuple[1]
    from_tensor_map = from_info_tuple[2]
    from_full_tensor_shape = from_info_tuple[3]
    to_opt_shard_size = to_info_tuple[0]
    to_dev_matrix_origin = to_info_tuple[1]
    to_tensor_map_origin = to_info_tuple[2]
    origin_tensor_shape = to_info_tuple[3]
    for param_rank, _ in param_rank_map.items():
        if from_opt_shard_size > 0:
            from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
            from_slice_tensor_shape = ()
            for i, item in enumerate(from_full_tensor_shape):
                from_slice_tensor_shape += (item // from_tensor_strategy[i],)
            param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
        if to_opt_shard_size > 0:
            to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
            to_slice_tensor_shape = ()
            for i, item in enumerate(origin_tensor_shape):
                if i == 0 and to_opt_shard_size > 0:
                    to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
                    continue
                to_slice_tensor_shape += (item // to_tensor_strategy[i],)
            param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))
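
The new _insert_opt_shard_reshape helper only relocates logic that previously sat inline in _transform_parallel_checkpoint; the behaviour it carries is the slice-shape arithmetic for optimizer-sharded parameters. Below is a minimal, self-contained sketch of that arithmetic; the function and variable names are illustrative only and are not part of the MindSpore API.

# Illustrative sketch only: mirrors the slice-shape arithmetic above,
# independent of MindSpore internals.
def _slice_shape(full_shape, tensor_strategy, opt_shard_size=0):
    """Divide each dimension by its shard count; when optimizer sharding
    is enabled, dimension 0 is further split across opt_shard_size groups."""
    slice_shape = ()
    for i, item in enumerate(full_shape):
        if i == 0 and opt_shard_size > 0:
            slice_shape += (item // (tensor_strategy[i] * opt_shard_size),)
            continue
        slice_shape += (item // tensor_strategy[i],)
    return slice_shape

# A (4096, 1024) weight sharded 2-way on dim 0 and 4-way on dim 1,
# with an additional 2-way optimizer shard on dim 0:
print(_slice_shape((4096, 1024), (2, 4), opt_shard_size=2))  # -> (1024, 256)

In the helper itself, the source slice shape is prepended as a Reshape operator and the destination slice shape is appended as a Reshape operator for every entry in param_rank_map, so the transform operators in between always see tensors in their sliced layout.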

View File

@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""Transform distributed checkpoint"""
from __future__ import absolute_import
import os
import glob
import copy
@@ -203,7 +205,7 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
            param_total_dict_copy = copy.deepcopy(param_total_dict)
            transform_param_list = _transform_parallel_checkpoint(transform_rank, param_total_dict_copy,
                                                                  param_attr_dict, src_strategy_file, dst_strategy_file)
            save_checkpoint_file = ckpt_prefix + str(transform_rank) + ".ckpt"
            save_checkpoint_file = os.path.join(ckpt_prefix, str(transform_rank), ".ckpt")
            save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank))
            if not os.path.exists(save_checkpoint_file_dir):
                _make_dir(save_checkpoint_file_dir, "path")
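
The only functional change in this hunk is how save_checkpoint_file is built, and the two variants are not equivalent: string concatenation yields a single file name, while os.path.join yields a path containing separators. A quick comparison with purely illustrative values:

# Plain Python semantics, shown for comparison only; values are illustrative.
import os

ckpt_prefix = "checkpoint_"
transform_rank = 3

print(ckpt_prefix + str(transform_rank) + ".ckpt")
# checkpoint_3.ckpt
print(os.path.join(ckpt_prefix, str(transform_rank), ".ckpt"))
# checkpoint_/3/.ckpt  (on POSIX; backslash separators on Windows)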