code clean python 220916
parent f0ddbe1e01
commit e958a8628c
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """parallel serialization"""
+from __future__ import absolute_import
+
 import os
 import numpy as np
 import mindspore as ms

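Note on the added import: `from __future__ import absolute_import` is a no-op on Python 3 but changes import resolution on Python 2, which is presumably why this cleanup adds it uniformly across modules. A minimal sketch of the semantics (the shadowing scenario described in the comment is hypothetical):

from __future__ import absolute_import  # must precede all other statements

import os

# On Python 2 without the __future__ flag, "import os" inside a package that
# also contained a sibling os.py would resolve to that sibling (implicit
# relative import). With the flag, and always on Python 3, it resolves to the
# standard library module.
print(os.sep)
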
@@ -116,8 +118,7 @@ def _restore_group_info_list(group_info_file_name):
     if not restore_list:
         raise ValueError("For 'restore_group_info_list', the group information file has no restore rank list.")

-    restore_rank_list = [rank for rank in restore_list.dim]
-    return restore_rank_list
+    return [rank for rank in restore_list.dim]


 def _get_device_num_from_strategy(strategy_file=None):

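For context, the simplified helper now returns the comprehension directly instead of binding it to the temporary restore_rank_list. A runnable sketch, assuming restore_list behaves like a parsed message whose repeated dim field lists rank ids (the mock object below is hypothetical):

from types import SimpleNamespace

def restore_group_info_list_sketch(restore_list):
    # Mirrors the cleaned-up body: return the comprehension directly instead
    # of binding it to a temporary variable first.
    if not restore_list:
        raise ValueError("For 'restore_group_info_list', the group information "
                         "file has no restore rank list.")
    return [rank for rank in restore_list.dim]

mock = SimpleNamespace(dim=[0, 1, 4, 5])  # hypothetical stand-in for the parsed group info
print(restore_group_info_list_sketch(mock))  # [0, 1, 4, 5]
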
@@ -144,7 +145,7 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_file=None
     dst_strategy_list = _convert_to_list(dst_strategy)
     result_list = set()
     handled_layout = []
-    for param_name, src_strategy in src_strategy_list.items():
+    for param_name, _ in src_strategy_list.items():
         if dst_strategy_file is not None and param_name not in dst_strategy_list:
             continue
         from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size = _extract_layout_item(

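The only change here renames the unused loop value src_strategy to _, the conventional way to silence unused-variable lint warnings without changing behavior. A trivial sketch with made-up data:

strategies = {"conv.weight": "layout_a", "dense.weight": "layout_b"}  # made-up data

# Binding the unused value to _ documents intent and satisfies linters
# (pylint's unused-variable check); iteration order and behavior are unchanged.
for param_name, _ in strategies.items():
    print(param_name)
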
@@ -247,23 +248,11 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
             raise ValueError("The checkpoint of rank {} is missing.".format(rank_id % device_num))
         param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                             device_list, rank_id)
-        for param_rank, _ in param_rank_map.items():
-            if from_opt_shard_size > 0:
-                from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
-                from_slice_tensor_shape = ()
-                for i, item in enumerate(from_full_tensor_shape):
-                    from_slice_tensor_shape += (item // from_tensor_strategy[i],)
-                param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
-            if to_opt_shard_size > 0:
-                to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
-                to_slice_tensor_shape = ()
-                for i, item in enumerate(origin_tensor_shape):
-                    if i == 0 and to_opt_shard_size > 0:
-                        to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
-                        continue
-                    to_slice_tensor_shape += (item // to_tensor_strategy[i],)
-                param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))

+        from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
+        to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
+        _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
         transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
         _apply_tensor_transform_operators(transform_operator_stack, param_total_dict[param_name], device_num)
         transform_tensor = ms.Tensor(param_total_dict[param_name][rank_id % device_num])

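The block removed above (and re-added as _insert_opt_shard_reshape below) computes each rank's slice shape by dividing every dimension by its shard count, with optimizer sharding further splitting dimension 0. A worked sketch of that arithmetic with hypothetical numbers:

# Hypothetical numbers: a (96, 32) parameter, model-parallel strategy (2, 4),
# and an optimizer shard size of 3 applied on top of dimension 0.
full_shape = (96, 32)
tensor_strategy = [2, 4]  # shards per dimension from the tensor layout
opt_shard_size = 3        # extra split of dimension 0 by the optimizer

slice_shape = ()
for i, item in enumerate(full_shape):
    if i == 0 and opt_shard_size > 0:
        slice_shape += (item // (tensor_strategy[i] * opt_shard_size),)
        continue
    slice_shape += (item // tensor_strategy[i],)

print(slice_shape)  # (16, 8): 96 // (2 * 3) and 32 // 4
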
@@ -320,3 +309,31 @@ def _make_dir(path, arg_name):
     finally:
         pass
     return real_path
+
+
+def _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple):
+    """insert opt_shard op reshape"""
+    from_opt_shard_size = from_info_tuple[0]
+    from_dev_matrix = from_info_tuple[1]
+    from_tensor_map = from_info_tuple[2]
+    from_full_tensor_shape = from_info_tuple[3]
+    to_opt_shard_size = to_info_tuple[0]
+    to_dev_matrix_origin = to_info_tuple[1]
+    to_tensor_map_origin = to_info_tuple[2]
+    origin_tensor_shape = to_info_tuple[3]
+    for param_rank, _ in param_rank_map.items():
+        if from_opt_shard_size > 0:
+            from_tensor_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+            from_slice_tensor_shape = ()
+            for i, item in enumerate(from_full_tensor_shape):
+                from_slice_tensor_shape += (item // from_tensor_strategy[i],)
+            param_rank_map.get(param_rank).insert(0, ('Reshape', list(from_slice_tensor_shape)))
+        if to_opt_shard_size > 0:
+            to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
+            to_slice_tensor_shape = ()
+            for i, item in enumerate(origin_tensor_shape):
+                if i == 0 and to_opt_shard_size > 0:
+                    to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
+                    continue
+                to_slice_tensor_shape += (item // to_tensor_strategy[i],)
+            param_rank_map.get(param_rank).append(('Reshape', list(to_slice_tensor_shape)))

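The extracted helper brackets each rank's transform pipeline with Reshape operators: the source-side Reshape is prepended so it runs first, and the destination-side Reshape is appended so it runs last. A minimal sketch of that list mechanics (the middle operator entries are made-up placeholders; real ones come from _get_needed_rank_transform_operator_map_by_layouts):

# Made-up pipeline for one rank.
pipeline = [('AllConcat', [0, 1]), ('StridedSlice', [0])]

pipeline.insert(0, ('Reshape', [4, 4]))  # source slice shape, applied first
pipeline.append(('Reshape', [2, 4]))     # destination slice shape, applied last

print([name for name, _ in pipeline])
# ['Reshape', 'AllConcat', 'StridedSlice', 'Reshape']
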
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """Transform distributed checkpoint"""
+from __future__ import absolute_import
+
 import os
 import glob
 import copy

@@ -203,7 +205,7 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
         param_total_dict_copy = copy.deepcopy(param_total_dict)
         transform_param_list = _transform_parallel_checkpoint(transform_rank, param_total_dict_copy,
                                                               param_attr_dict, src_strategy_file, dst_strategy_file)
-        save_checkpoint_file = ckpt_prefix + str(transform_rank) + ".ckpt"
+        save_checkpoint_file = os.path.join(ckpt_prefix, str(transform_rank), ".ckpt")
         save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank))
         if not os.path.exists(save_checkpoint_file_dir):
             _make_dir(save_checkpoint_file_dir, "path")

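Worth noting: the removed and added expressions for save_checkpoint_file are not equivalent. Plain concatenation builds a single file name, while os.path.join treats each argument as a path segment, as this snippet shows (prefix and rank values are hypothetical):

import os

ckpt_prefix = "checkpoint_"  # hypothetical prefix
transform_rank = 0           # hypothetical rank

# Removed form: concatenation yields a single file name.
print(ckpt_prefix + str(transform_rank) + ".ckpt")              # checkpoint_0.ckpt

# Added form: os.path.join inserts separators between segments.
print(os.path.join(ckpt_prefix, str(transform_rank), ".ckpt"))  # checkpoint_/0/.ckpt on POSIX
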