forked from mindspore-Ecosystem/mindspore
modify dis_load_ckpt for master
This commit is contained in:
parent
d7554bbbd3
commit
422614c558
|
@ -22,6 +22,7 @@ import shutil
|
|||
import time
|
||||
import copy
|
||||
from threading import Thread, Lock
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
|
@ -1138,19 +1139,18 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|||
return merged_parameter
|
||||
|
||||
|
||||
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, dec_key=None, dec_mode='AES-GCM'):
|
||||
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None,
|
||||
train_strategy_filename=None, dec_key=None, dec_mode='AES-GCM'):
|
||||
"""
|
||||
Load checkpoint into net for distributed predication.
|
||||
|
||||
Args:
|
||||
network (Cell): Network for distributed predication.
|
||||
checkpoint_filenames (list(str)): The name of Checkpoint files
|
||||
in order of rank id.
|
||||
predict_strategy (Optional(dict)): Strategy of predication process, whose key
|
||||
is parameter name, and value is a list or a tuple that the first four
|
||||
elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
|
||||
it means that the predication process just uses single device.
|
||||
Default: None.
|
||||
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
|
||||
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or
|
||||
a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
|
||||
it means that the predication process just uses single device. Default: None.
|
||||
train_strategy_filename (str): Train strategy file. Default: None.
|
||||
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
|
||||
is not required. Default: None.
|
||||
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
|
||||
|
@ -1161,35 +1161,34 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|||
ValueError: Failed to load checkpoint into net.
|
||||
"""
|
||||
network = Validator.check_isinstance("network", network, nn.Cell)
|
||||
|
||||
for index, filename in enumerate(checkpoint_filenames):
|
||||
if not isinstance(filename, str) or not os.path.exists(filename) \
|
||||
or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
|
||||
raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")
|
||||
|
||||
if not _check_predict_strategy(predict_strategy):
|
||||
raise ValueError(f"Please make sure that the key of predict_strategy is str, "
|
||||
f"and the value is a list or a tuple that the first four elements are "
|
||||
f"dev_matrix (list[int]), tensor_map (list[int]), "
|
||||
f"param_split_shape (list[int]) and field_size (zero).")
|
||||
_check_checkpoint_file(checkpoint_filenames)
|
||||
_check_predict_strategy(predict_strategy)
|
||||
|
||||
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
||||
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
||||
|
||||
train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
|
||||
if train_strategy_filename is None:
|
||||
train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
|
||||
_train_strategy = build_searched_strategy(train_strategy_filename)
|
||||
train_strategy = _convert_to_list(_train_strategy)
|
||||
|
||||
train_dev_count = 1
|
||||
ckpt_file_len = len(checkpoint_filenames)
|
||||
for dim in train_strategy[list(train_strategy.keys())[0]][0]:
|
||||
train_dev_count *= dim
|
||||
if train_dev_count != len(checkpoint_filenames):
|
||||
if train_dev_count != ckpt_file_len:
|
||||
raise ValueError(
|
||||
f"The length of checkpoint_filenames should be equal to the device count of training process. "
|
||||
f"The length is {len(checkpoint_filenames)} but the device count is {train_dev_count}.")
|
||||
f"The length is {ckpt_file_len} but the device count is {train_dev_count}.")
|
||||
|
||||
rank_list = _infer_rank_list(train_strategy, predict_strategy)
|
||||
|
||||
param_total_dict = defaultdict(dict)
|
||||
for file_index, file_name in enumerate(checkpoint_filenames):
|
||||
ckpt_dict = load_checkpoint(file_name, dec_key, dec_mode)
|
||||
for param_name, param in ckpt_dict.items():
|
||||
param_total_dict[param_name][file_index] = param
|
||||
|
||||
param_dict = {}
|
||||
for _, param in network.parameters_and_names():
|
||||
sliced_params = []
|
||||
|
@ -1197,8 +1196,31 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|||
continue
|
||||
param_rank = rank_list[param.name][0]
|
||||
skip_merge_split = rank_list[param.name][1]
|
||||
shard_stride = train_strategy[param.name][4]
|
||||
if train_strategy[param.name][5]:
|
||||
shard_size = ckpt_file_len / shard_stride / train_strategy[param.name][5]
|
||||
else:
|
||||
shard_size = 0
|
||||
for rank in param_rank:
|
||||
sliced_param = load_checkpoint(checkpoint_filenames[rank], dec_key=dec_key, dec_mode=dec_mode)[param.name]
|
||||
param_total_list = list(range(0, ckpt_file_len))
|
||||
if shard_size > 0:
|
||||
shard_total_list = [param_total_list[i:i + shard_size] for i in
|
||||
range(0, ckpt_file_len, shard_size)]
|
||||
param_total_list = shard_total_list[rank // shard_size]
|
||||
if shard_stride > 0:
|
||||
param_stride = []
|
||||
# merge pre parameter
|
||||
param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
|
||||
param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
|
||||
param_index = list(set(param_index))
|
||||
param_index.sort()
|
||||
for rank_num in param_index:
|
||||
param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
|
||||
|
||||
sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
|
||||
else:
|
||||
sliced_param = param_total_dict[param.name][rank]
|
||||
|
||||
sliced_params.append(sliced_param)
|
||||
if skip_merge_split:
|
||||
split_param = sliced_params[0]
|
||||
|
@ -1222,19 +1244,33 @@ def _check_predict_strategy(predict_strategy):
|
|||
return True
|
||||
|
||||
if predict_strategy is None:
|
||||
return True
|
||||
return
|
||||
|
||||
flag = True
|
||||
predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
|
||||
for key in predict_strategy.keys():
|
||||
if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
|
||||
or len(predict_strategy[key]) < 4:
|
||||
return False
|
||||
flag = False
|
||||
dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
|
||||
if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
|
||||
not (_check_int_list(param_split_shape) or not param_split_shape) or \
|
||||
not (isinstance(field_size, int) and field_size == 0):
|
||||
return False
|
||||
return True
|
||||
flag = False
|
||||
|
||||
if not flag:
|
||||
raise ValueError(f"Please make sure that the key of predict_strategy is str, "
|
||||
f"and the value is a list or a tuple that the first four elements are "
|
||||
f"dev_matrix (list[int]), tensor_map (list[int]), "
|
||||
f"param_split_shape (list[int]) and field_size (zero).")
|
||||
|
||||
|
||||
def _check_checkpoint_file(checkpoint_filenames):
|
||||
"""Check checkpoint file name."""
|
||||
for index, filename in enumerate(checkpoint_filenames):
|
||||
if not isinstance(filename, str) or not os.path.exists(filename) \
|
||||
or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
|
||||
raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")
|
||||
|
||||
|
||||
def _convert_to_list(strategy):
|
||||
|
|
Loading…
Reference in New Issue