me master 0325

This commit is contained in:
dingpeifei 2022-03-25 18:35:51 +08:00
parent 95e5475f58
commit c688f0a490
11 changed files with 94 additions and 64 deletions

View File

@ -376,14 +376,14 @@ class Validator:
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
raise ValueError(f'For \'{prim_name}\' the argument `{arg_name}` must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isinstance of classes"""
if not isinstance(arg_value, classes):
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
raise ValueError(f'The argument `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value
@staticmethod
@ -501,7 +501,7 @@ class Validator:
def check_valid_input(arg_name, arg_value, prim_name):
"""Checks valid value."""
if arg_value is None:
raise ValueError(f"For \'{prim_name}\', the '{arg_name}' can not be None, but got {arg_value}.")
raise ValueError(f"For \'{prim_name}\', the argument '{arg_name}' can not be None, but got {arg_value}.")
return arg_value
@staticmethod
@ -627,7 +627,8 @@ class Validator:
axis = axis if isinstance(axis, Iterable) else (axis,)
exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
if list(shape) != exp_shape:
raise ValueError(f"For '{prim_name}', the '{arg_name1}'.shape reduce on 'axis': {axis_origin} should "
raise ValueError(f"For '{prim_name}', "
f"the argument '{arg_name1}'.shape reduce on 'axis': {axis_origin} should "
f"be equal to '{arg_name2}'.shape: {shape}, but got {ori_shape}.")
@staticmethod
@ -657,7 +658,8 @@ class Validator:
perm = tuple(perm)
else:
if not isinstance(perm, tuple):
raise TypeError(f"The `axes` should be a tuple/list, or series of int, but got {type(axes[0])}")
raise TypeError(f"The argument `axes` should be a tuple/list, "
f"or series of int, but got {type(axes[0])}")
return perm
# if multiple arguments provided, it must be `ndim` number of ints
@ -679,7 +681,8 @@ class Validator:
else:
if not isinstance(new_shape, tuple):
raise TypeError(
f"The `shape` should be an int, or tuple/list, or series of int, but got {type(shp[0])}")
f"The argument `shape` should be an int, or tuple/list, "
f"or series of int, but got {type(shp[0])}")
return new_shape
return shp
@ -706,7 +709,7 @@ class Validator:
Validator.check_axis_in_range(axis, ndim)
axes = tuple(map(lambda x: x % ndim, axes))
return axes
raise TypeError(f"The axes should be integer, list or tuple for check, but got {type(axes)}.")
raise TypeError(f"The argument 'axes' should be integer, list or tuple for check, but got {type(axes)}.")
@staticmethod
def prepare_shape_for_squeeze(shape, axes):
@ -1078,7 +1081,7 @@ def args_unreset_check(*unreset_args, **unreset_kwargs):
argument_dict = argument_dict["kwargs"]
for name, value in argument_dict.items():
if name in _set_record.keys():
raise TypeError('Argument {} is non-renewable parameter {}.'.format(name, bound_unreset[name]))
raise TypeError('The argument {} is non-renewable parameter {}.'.format(name, bound_unreset[name]))
if name in bound_unreset:
_set_record[name] = value
return func(*args, **kwargs)

View File

@ -185,8 +185,8 @@ def pytype_to_dtype(obj):
if isinstance(obj, typing.Type):
return obj
if not isinstance(obj, type):
raise TypeError("The argument 'obj' must be a python type object, such as int, float, str, etc."
"But got type {}.".format(type(obj)))
raise TypeError("For 'pytype_to_dtype', the argument 'obj' must be a python type object,"
"such as int, float, str, etc. But got type {}.".format(type(obj)))
elif obj in _simple_types:
return _simple_types[obj]
raise NotImplementedError(f"The python type {obj} cannot be converted to MindSpore type.")

View File

@ -189,12 +189,14 @@ def _calculate_gain(nonlinearity, param=None):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("'negative_slope' {} is not a valid number. When 'nonlinearity' has been set to "
raise ValueError("For 'HeUniform', 'negative_slope' {} is not a valid number."
"When 'nonlinearity' has been set to "
"'leaky_relu', 'negative_slope' should be int or float type, but got "
"{}.".format(param, type(param)))
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("The argument 'nonlinearity' should be one of ['sigmoid', 'tanh', 'relu' or 'leaky_relu'], "
raise ValueError("For 'HeUniform', the argument 'nonlinearity' should be one of "
"['sigmoid', 'tanh', 'relu' or 'leaky_relu'], "
"but got {}.".format(nonlinearity))
return res
@ -469,8 +471,8 @@ class Dirac(Initializer):
shapes = arr.shape
if shapes[0] % self.groups != 0:
raise ValueError("For Dirac initializer, the first dimension of"
"the initialized tensor must be divisible by group, "
"but got {}/{}.".format(shapes[0], self.groups))
"the initialized tensor must be divisible by groups, "
"but got first dimension{}, groups{}.".format(shapes[0], self.groups))
out_channel_per_group = shapes[0] // self.groups
min_dim = min(out_channel_per_group, shapes[1])
@ -564,15 +566,16 @@ class VarianceScaling(Initializer):
def __init__(self, scale=1.0, mode='fan_in', distribution='truncated_normal'):
super(VarianceScaling, self).__init__(scale=scale, mode=mode, distribution=distribution)
if scale <= 0.:
raise ValueError("For VarianceScaling initializer, scale must be greater than 0, but got {}.".format(scale))
raise ValueError("For VarianceScaling initializer, "
"the argument 'scale' must be greater than 0, but got {}.".format(scale))
if mode not in ['fan_in', 'fan_out', 'fan_avg']:
raise ValueError('For VarianceScaling initializer, mode must be fan_in, '
'fan_out or fan_avg, but got {}.'.format(mode))
raise ValueError("For VarianceScaling initializer, the argument 'mode' must be fan_in, "
"fan_out or fan_avg, but got {}.".format(mode))
if distribution not in ['uniform', 'truncated_normal', 'untruncated_normal']:
raise ValueError('For VarianceScaling initializer, distribution must be uniform, '
'truncated_norm or untruncated_norm, but got {}.'.format(distribution))
raise ValueError("For VarianceScaling initializer, the argument 'distribution' must be uniform, "
"truncated_norm or untruncated_norm, but got {}.".format(distribution))
self.scale = scale
self.mode = mode
@ -720,14 +723,15 @@ def initializer(init, shape=None, dtype=mstype.float32):
>>> tensor4 = initializer(0, [1, 2, 3], mindspore.float32)
"""
if not isinstance(init, (Tensor, numbers.Number, str, Initializer)):
raise TypeError("The type of the 'init' argument should be 'Tensor', 'number', 'string' "
raise TypeError("For 'initializer', the type of the 'init' argument should be 'Tensor', 'number', 'string' "
"or 'initializer', but got {}.".format(type(init)))
if isinstance(init, Tensor):
init_shape = init.shape
shape = shape if isinstance(shape, (tuple, list)) else [shape]
if shape is not None and init_shape != tuple(shape):
raise ValueError("The shape of the 'init' argument should be same as the argument 'shape', but got the "
raise ValueError("For 'initializer', the shape of the 'init' argument should be same as "
"the argument 'shape', but got the "
"'init' shape {} and the 'shape' {}.".format(list(init.shape), shape))
return init
@ -738,7 +742,8 @@ def initializer(init, shape=None, dtype=mstype.float32):
for value in shape if shape is not None else ():
if not isinstance(value, int) or value <= 0:
raise ValueError(f"The argument 'shape' is invalid, the value of 'shape' must be positive integer, "
raise ValueError(f"For 'initializer', the argument 'shape' is invalid, the value of 'shape' "
f"must be positive integer, "
f"but got {shape}")
if isinstance(init, str):

View File

@ -251,7 +251,8 @@ class _Context:
def set_variable_memory_max_size(self, variable_memory_max_size):
"""set values of variable_memory_max_size and graph_memory_max_size"""
logger.warning("The parameter 'variable_memory_max_size' is deprecated, and will be removed in a future "
logger.warning("For 'context.set_context', the parameter 'variable_memory_max_size' is deprecated, "
"and will be removed in a future "
"version. Please use parameter 'max_device_memory' instead.")
if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern):
raise ValueError("For 'context.set_context', the argument 'variable_memory_max_size' should be in correct"
@ -281,13 +282,15 @@ class _Context:
def set_mempool_block_size(self, mempool_block_size):
"""Set the block size of memory pool."""
if _get_mode() == GRAPH_MODE:
logger.warning("Graph mode not support mempool_block_size context currently")
logger.warning("Graph mode doesn't supportto set parameter 'mempool_block_size' of context currently")
return
if not Validator.check_str_by_regular(mempool_block_size, _re_pattern):
raise ValueError("Context param mempool_block_size should be in correct format! Such as \"10GB\"")
raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be in "
"correct format! Such as \"10GB\"")
mempool_block_size_value = float(mempool_block_size[:-2])
if mempool_block_size_value < 1.0:
raise ValueError("Context param mempool_block_size should be greater or equal to \"1GB\"")
raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be "
"greater or equal to \"1GB\"")
self.set_param(ms_ctx_param.mempool_block_size, mempool_block_size_value)
def set_print_file_path(self, file_path):
@ -626,8 +629,10 @@ def _check_target_specific_cfgs(device, arg_key):
supported_devices = device_cfgs[arg_key]
if device in supported_devices:
return True
logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'"
", ignore it.")
logger.warning(f"For 'context.set_context', "
f"the argument 'device_target' only supports devices in {supported_devices}, "
f"but got '{arg_key}', current device is '{device}'"
f", ignore it.")
return False
@ -894,7 +899,7 @@ def set_context(**kwargs):
for key, value in kwargs.items():
if key in ('enable_profiling', 'profiling_options', 'enable_auto_mixed_precision',
'enable_dump', 'save_dump_path'):
logger.warning(f" '{key}' parameters will be deprecated."
logger.warning(f"For 'context.set_context', '{key}' parameters will be deprecated."
"For details, please see the interface parameter API comments")
continue
if not _check_target_specific_cfgs(device, key):

View File

@ -276,7 +276,7 @@ def _get_env_config():
def _check_directory_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
"""Check whether directory is legitimate."""
if not isinstance(target, str):
raise ValueError("Args directory {} must be string, please check it".format(target))
raise ValueError("The directory {} must be string, please check it".format(target))
if reg is None:
reg = r"^[\/0-9a-zA-Z\_\-\.\:\\]+$"
if re.match(reg, target, flag) is None:

View File

@ -434,7 +434,7 @@ class Cell(Cell_):
def _check_construct_args(self, *inputs, **kwargs):
"""Check the args needed by the function construct"""
if kwargs:
raise ValueError(f"Expect no kwargs here. Did you pass wrong arguments? args: {inputs}, kwargs: {kwargs}")
raise ValueError(f"Expect no kwargs here. maybe you pass wrong arguments, rgs: {inputs}, kwargs: {kwargs}")
positional_args = 0
default_args = 0
for value in inspect.signature(self.construct).parameters.values():
@ -665,7 +665,7 @@ class Cell(Cell_):
continue
exist_objs.add(item)
if item.name == PARAMETER_NAME_DEFAULT:
logger.warning("The parameter definition is deprecated.\n"
logger.warning("For 'Cell', the parameter definition is deprecated.\n"
"Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
item.name = item.name + "$" + str(self._id)
self._id += 1
@ -2206,7 +2206,8 @@ class GraphCell(Cell):
raise TypeError(f"The 'params_init' must be a dict, but got {type(params_init)}.")
for name, value in params_init.items():
if not isinstance(name, str) or not isinstance(value, Tensor):
raise TypeError("The key of the 'params_init' must be str, and the value must be Tensor or Parameter, "
raise TypeError("For 'GraphCell', the key of the 'params_init' must be str, "
"and the value must be Tensor or Parameter, "
f"but got the key type: {type(name)}, and the value type: {type(value)}")
params_dict = update_func_graph_hyper_params(self.graph, params_init)

View File

@ -411,7 +411,7 @@ class Optimizer(Cell):
self.dynamic_weight_decay = True
weight_decay = _WrappedWeightDecay(weight_decay, self.loss_scale)
else:
raise TypeError("Weight decay should be int, float or Cell.")
raise TypeError("For 'Optimizer', the argument 'Weight_decay' should be int, float or Cell.")
return weight_decay
def _preprocess_single_lr(self, learning_rate):
@ -431,12 +431,13 @@ class Optimizer(Cell):
raise ValueError(f"For 'Optimizer', if 'learning_rate' is Tensor type, then the dimension of it should "
f"be 0 or 1, but got {learning_rate.ndim}.")
if learning_rate.ndim == 1 and learning_rate.size < 2:
logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number "
logger.warning("For 'Optimizer', if use `Tensor` type dynamic learning rate, "
"please make sure that the number "
"of elements in the tensor is greater than 1.")
return learning_rate
if isinstance(learning_rate, LearningRateSchedule):
return learning_rate
raise TypeError("For 'Optimizer', 'learning_rate' should be int, float, Tensor, Iterable or "
raise TypeError("For 'Optimizer', the argument 'learning_rate' should be int, float, Tensor, Iterable or "
"LearningRateSchedule, but got {}.".format(type(learning_rate)))
def _build_single_lr(self, learning_rate, name):

View File

@ -167,7 +167,8 @@ class CheckpointConfig:
if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
raise ValueError("The input arguments 'save_checkpoint_steps', 'save_checkpoint_seconds', "
raise ValueError("For 'CheckpointConfig', the input arguments 'save_checkpoint_steps', "
"'save_checkpoint_seconds', "
"'keep_checkpoint_max' and 'keep_checkpoint_per_n_minutes' can't be all None or 0.")
Validator.check_bool(exception_save)
self.exception_save = exception_save
@ -308,7 +309,8 @@ class CheckpointConfig:
if append_info is None or append_info == []:
return None
if not isinstance(append_info, list):
raise TypeError(f"The type of 'append_info' must be list, but got {str(type(append_info))}.")
raise TypeError(f"For 'CheckpointConfig', the type of 'append_info' must be list,"
f"but got {str(type(append_info))}.")
handle_append_info = {}
if "epoch_num" in append_info:
handle_append_info["epoch_num"] = 0
@ -317,20 +319,22 @@ class CheckpointConfig:
dict_num = 0
for element in append_info:
if not isinstance(element, str) and not isinstance(element, dict):
raise TypeError(f"The type of 'append_info' element must be str or dict, "
raise TypeError(f"For 'CheckpointConfig', the type of 'append_info' element must be str or dict,"
f"but got {str(type(element))}.")
if isinstance(element, str) and element not in _info_list:
raise ValueError(f"The value of element in the argument 'append_info' must be in {_info_list}, "
raise ValueError(f"For 'CheckpointConfig', the value of element in the argument 'append_info' "
f"must be in {_info_list}, "
f"but got {element}.")
if isinstance(element, dict):
dict_num += 1
if dict_num > 1:
raise TypeError(f"The element of 'append_info' must has only one dict.")
raise TypeError(f"For 'CheckpointConfig', the element of 'append_info' must has only one dict.")
for key, value in element.items():
if isinstance(key, str) and isinstance(value, (int, float, bool)):
handle_append_info[key] = value
else:
raise TypeError(f"The type of dict in 'append_info' must be key: string, value: int or float, "
raise TypeError(f"For 'CheckpointConfig', the type of dict in 'append_info' "
f"must be key: string, value: int or float, "
f"but got key: {type(key)}, value: {type(value)}")
return handle_append_info
@ -369,7 +373,8 @@ class ModelCheckpoint(Callback):
self._last_triggered_step = 0
if not isinstance(prefix, str) or prefix.find('/') >= 0:
raise ValueError("The argument 'prefix' for checkpoint file name is invalid, 'prefix' must be "
raise ValueError("For 'ModelCheckpoint', the argument 'prefix' "
"for checkpoint file name is invalid, 'prefix' must be "
"string and does not contain '/', but got {}.".format(prefix))
self._prefix = prefix
self._exception_prefix = prefix
@ -386,7 +391,8 @@ class ModelCheckpoint(Callback):
self._config = CheckpointConfig()
else:
if not isinstance(config, CheckpointConfig):
raise TypeError("The argument 'config' should be 'CheckpointConfig' type, "
raise TypeError("For 'ModelCheckpoint', the argument 'config' should be "
"'CheckpointConfig' type, "
"but got {}.".format(type(config)))
self._config = config

View File

@ -879,8 +879,8 @@ class Model:
raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
if hasattr(train_dataset, '_warmup_epoch') and train_dataset._warmup_epoch != epoch:
raise ValueError("Use Model.build to initialize model, but the value of parameter `epoch` in Model.build "
"is not equal to value in Model.train, got {} and {} separately."
raise ValueError("when use Model.build to initialize model, the value of parameter `epoch` in Model.build "
"should be equal to value in Model.train, but got {} and {} separately."
.format(train_dataset._warmup_epoch, epoch))
if dataset_sink_mode and _is_ps_mode():
@ -893,7 +893,8 @@ class Model:
if sink_size == -1:
sink_size = dataset_size
if sink_size < -1 or sink_size == 0:
raise ValueError("The argument 'sink_size' must be -1 or positive, but got {}.".format(sink_size))
raise ValueError("For 'Model.train', The argument 'sink_size' must be -1 or positive, "
"but got {}.".format(sink_size))
_device_number_check(self._parallel_mode, self._device_number)
@ -1143,7 +1144,8 @@ class Model:
if sink_size == -1:
sink_size = dataset_size
if sink_size < -1 or sink_size == 0:
raise ValueError("The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size))
raise ValueError("For 'infer_train_layout', the argument 'sink_size' must be -1 or positive, "
"but got sink_size {}.".format(sink_size))
def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
"""

View File

@ -276,11 +276,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
raise TypeError("For 'save_checkpoint', the argument 'save_obj' should be nn.Cell or list, "
"but got {}.".format(type(save_obj)))
if not isinstance(ckpt_file_name, str):
raise TypeError("The argument {} for checkpoint file name is invalid, 'ckpt_file_name' must be "
raise TypeError("For 'save_checkpoint', the argument {} for checkpoint file name is invalid,"
"'ckpt_file_name' must be "
"string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
ckpt_file_name = os.path.realpath(ckpt_file_name)
if os.path.isdir(ckpt_file_name):
raise IsADirectoryError("The argument `ckpt_file_name`: {} is a directory, "
raise IsADirectoryError("For 'save_checkpoint', the argument `ckpt_file_name`: {} is a directory, "
"it should be a file name.".format(ckpt_file_name))
if not ckpt_file_name.endswith('.ckpt'):
ckpt_file_name += ".ckpt"
@ -501,7 +502,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
continue
data = element.tensor.tensor_content
data_type = element.tensor.tensor_type
np_type = tensor_to_np_type[data_type]
np_type = tensor_to_np_type.get(data_type)
ms_type = tensor_to_ms_type[data_type]
element_data = np.frombuffer(data, np_type)
param_data_list.append(element_data)
@ -529,7 +530,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
except BaseException as e:
logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
raise ValueError(e.__str__() + "\nFailed to load the checkpoint file {}.".format(ckpt_file_name))
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
"failed to load the checkpoint file {}.".format(ckpt_file_name))
if not parameter_dict:
raise ValueError(f"The loaded parameter dict is empty after filtering, please check whether "
@ -564,7 +566,7 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
if isinstance(filter_prefix, str):
filter_prefix = (filter_prefix,)
if not filter_prefix:
raise ValueError("For 'load_checkpoint', the 'filter_prefix' can't be empty when "
raise ValueError("For 'load_checkpoint', the argument 'filter_prefix' can't be empty when "
"'filter_prefix' is list or tuple.")
for index, prefix in enumerate(filter_prefix):
if not isinstance(prefix, str):
@ -865,7 +867,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
check_input_data(*inputs, data_class=Tensor)
if file_format == 'GEIR':
logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
logger.warning(f"For 'export', format 'GEIR' is deprecated, "
f"it would be removed in future release, use 'AIR' instead.")
file_format = 'AIR'
supported_formats = ['AIR', 'ONNX', 'MINDIR']
@ -1363,13 +1366,14 @@ def restore_group_info_list(group_info_file_name):
>>> restore_list = restore_group_info_list("./group_info.pb")
"""
if not isinstance(group_info_file_name, str):
raise TypeError(f"The group_info_file_name should be str, but got {type(group_info_file_name)}.")
raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
f"but got {type(group_info_file_name)}.")
if not os.path.isfile(group_info_file_name):
raise ValueError(f"No such group info file: {group_info_file_name}.")
raise ValueError(f"No such group information file: {group_info_file_name}.")
if os.path.getsize(group_info_file_name) == 0:
raise ValueError("The group info file should not be empty.")
raise ValueError("The group information file should not be empty.")
parallel_group_map = ParallelGroupMap()
@ -1379,7 +1383,7 @@ def restore_group_info_list(group_info_file_name):
restore_list = parallel_group_map.ckpt_restore_rank_list
if not restore_list:
raise ValueError("The group info file has no restore rank list.")
raise ValueError("The group information file has no restore rank list.")
restore_rank_list = [rank for rank in restore_list.dim]
return restore_rank_list
@ -1405,7 +1409,7 @@ def build_searched_strategy(strategy_filename):
>>> strategy = build_searched_strategy("./strategy_train.ckpt")
"""
if not isinstance(strategy_filename, str):
raise TypeError(f"For 'build_searched_strategy', the 'strategy_filename' should be string, "
raise TypeError(f"For 'build_searched_strategy', the argument 'strategy_filename' should be string, "
f"but got {type(strategy_filename)}.")
if not os.path.isfile(strategy_filename):
@ -1474,14 +1478,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
"""
if not isinstance(sliced_parameters, list):
raise TypeError(f"For 'merge_sliced_parameter', the 'sliced_parameters' should be list, "
raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
f"but got {type(sliced_parameters)}.")
if not sliced_parameters:
raise ValueError("For 'merge_sliced_parameter', the 'sliced_parameters' should not be empty.")
raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")
if strategy and not isinstance(strategy, dict):
raise TypeError(f"For 'merge_sliced_parameter', the 'strategy' should be dict, but got {type(strategy)}.")
raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
f"but got {type(strategy)}.")
try:
parameter_name = sliced_parameters[0].name
@ -1576,7 +1581,8 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
train_dev_count *= dim
if train_dev_count != ckpt_file_len:
raise ValueError(f"For 'Load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
f"equal to the device count of training process. But the length of 'checkpoint_filenames'"
f"equal to the device count of training process. "
f"But got the length of 'checkpoint_filenames'"
f" is {ckpt_file_len} and the device count is {train_dev_count}.")
rank_list = _infer_rank_list(train_strategy, predict_strategy)

View File

@ -90,7 +90,8 @@ def test_init_graph_cell_parameters_with_wrong_value_type():
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)
assert "The key of the 'params_init' must be str, and the value must be Tensor or Parameter" in str(err.value)
assert "For 'GraphCell', the key of the 'params_init' must be str, " \
"and the value must be Tensor or Parameter" in str(err.value)
remove_generated_file(mindir_name)