!29323 Modify Error Info For 1.6

Merge pull request !29323 from liuyang/ms_1_6
This commit is contained in:
i-robot 2022-01-20 06:50:49 +00:00 committed by Gitee
commit 7d4a6f8654
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 8 additions and 9 deletions

View File

@ -1541,11 +1541,9 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
for dim in train_strategy[list(train_strategy.keys())[0]][0]:
train_dev_count *= dim
if train_dev_count != ckpt_file_len:
raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
f"the key of it must be string, and the value of it must be list or tuple that "
f"the first four elements are dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (int, which value is 0)."
f"Please check whether 'predict_strategy' is correct.")
raise ValueError(f"For 'Load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
f"equal to the device count of training process. But the length of 'checkpoint_filenames'"
f" is {ckpt_file_len} and the device count is {train_dev_count}.")
rank_list = _infer_rank_list(train_strategy, predict_strategy)
param_total_dict = defaultdict(dict)
@ -1671,10 +1669,11 @@ def _check_predict_strategy(predict_strategy):
flag = False
if not flag:
raise ValueError(f"Please make sure that the key of predict_strategy is str, "
f"and the value is a list or a tuple that the first four elements are "
f"dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (zero).")
raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
f"the key of it must be string, and the value of it must be list or tuple that "
f"the first four elements are dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (int, which value is 0)."
f"Please check whether 'predict_strategy' is correct.")
def _check_checkpoint_file(checkpoint_filenames):