forked from mindspore-Ecosystem/mindspore

commit c688f0a490 (parent 95e5475f58)
me master 0325
@@ -376,14 +376,14 @@ class Validator:
         rel_fn = Rel.get_fns(rel)
         if not rel_fn(arg_value, value):
             rel_str = Rel.get_strs(rel).format(value)
-            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
+            raise ValueError(f'For \'{prim_name}\' the argument `{arg_name}` must {rel_str}, but got {arg_value}.')
         return arg_value

     @staticmethod
     def check_isinstance(arg_name, arg_value, classes):
         """Check arg isinstance of classes"""
         if not isinstance(arg_value, classes):
-            raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
+            raise ValueError(f'The argument `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
         return arg_value

     @staticmethod
@@ -501,7 +501,7 @@ class Validator:
     def check_valid_input(arg_name, arg_value, prim_name):
         """Checks valid value."""
         if arg_value is None:
-            raise ValueError(f"For \'{prim_name}\', the '{arg_name}' can not be None, but got {arg_value}.")
+            raise ValueError(f"For \'{prim_name}\', the argument '{arg_name}' can not be None, but got {arg_value}.")
         return arg_value

     @staticmethod
@@ -627,7 +627,8 @@ class Validator:
         axis = axis if isinstance(axis, Iterable) else (axis,)
         exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
         if list(shape) != exp_shape:
-            raise ValueError(f"For '{prim_name}', the '{arg_name1}'.shape reduce on 'axis': {axis_origin} should "
+            raise ValueError(f"For '{prim_name}', "
+                             f"the argument '{arg_name1}'.shape reduce on 'axis': {axis_origin} should "
                              f"be equal to '{arg_name2}'.shape: {shape}, but got {ori_shape}.")

     @staticmethod
@@ -657,7 +658,8 @@
             perm = tuple(perm)
         else:
             if not isinstance(perm, tuple):
-                raise TypeError(f"The `axes` should be a tuple/list, or series of int, but got {type(axes[0])}")
+                raise TypeError(f"The argument `axes` should be a tuple/list, "
+                                f"or series of int, but got {type(axes[0])}")
         return perm

     # if multiple arguments provided, it must be `ndim` number of ints
@@ -679,7 +681,8 @@
     else:
         if not isinstance(new_shape, tuple):
             raise TypeError(
-                f"The `shape` should be an int, or tuple/list, or series of int, but got {type(shp[0])}")
+                f"The argument `shape` should be an int, or tuple/list, "
+                f"or series of int, but got {type(shp[0])}")
         return new_shape

     return shp
@@ -706,7 +709,7 @@
                 Validator.check_axis_in_range(axis, ndim)
             axes = tuple(map(lambda x: x % ndim, axes))
             return axes
-        raise TypeError(f"The axes should be integer, list or tuple for check, but got {type(axes)}.")
+        raise TypeError(f"The argument 'axes' should be integer, list or tuple for check, but got {type(axes)}.")

     @staticmethod
     def prepare_shape_for_squeeze(shape, axes):
@@ -1078,7 +1081,7 @@ def args_unreset_check(*unreset_args, **unreset_kwargs):
             argument_dict = argument_dict["kwargs"]
         for name, value in argument_dict.items():
             if name in _set_record.keys():
-                raise TypeError('Argument {} is non-renewable parameter {}.'.format(name, bound_unreset[name]))
+                raise TypeError('The argument {} is non-renewable parameter {}.'.format(name, bound_unreset[name]))
             if name in bound_unreset:
                 _set_record[name] = value
         return func(*args, **kwargs)
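
A quick way to see the reworded Validator messages is a small sketch like the one below. It assumes the internal module path mindspore._checkparam, which is not public API and may move between releases.

    from mindspore._checkparam import Validator

    try:
        # Wrong type on purpose: `classes` is float, the value is a str.
        Validator.check_isinstance("gamma", "0.9", float)
    except ValueError as err:
        print(err)  # The argument `gamma` should be isinstance of <class 'float'>, but got 0.9.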
@@ -185,8 +185,8 @@ def pytype_to_dtype(obj):
    if isinstance(obj, typing.Type):
        return obj
    if not isinstance(obj, type):
-        raise TypeError("The argument 'obj' must be a python type object, such as int, float, str, etc."
-                        "But got type {}.".format(type(obj)))
+        raise TypeError("For 'pytype_to_dtype', the argument 'obj' must be a python type object, "
+                        "such as int, float, str, etc. But got type {}.".format(type(obj)))
    elif obj in _simple_types:
        return _simple_types[obj]
    raise NotImplementedError(f"The python type {obj} cannot be converted to MindSpore type.")
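
The pytype_to_dtype change follows the same pattern. A hedged usage sketch, assuming the function is importable from mindspore.common.dtype:

    from mindspore.common.dtype import pytype_to_dtype

    print(pytype_to_dtype(int))    # a Python type object maps to a MindSpore dtype
    try:
        pytype_to_dtype(42)        # an instance, not a type object
    except TypeError as err:
        print(err)  # For 'pytype_to_dtype', the argument 'obj' must be a python type object, ...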
@@ -189,12 +189,14 @@ def _calculate_gain(nonlinearity, param=None):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
-            raise ValueError("'negative_slope' {} is not a valid number. When 'nonlinearity' has been set to "
+            raise ValueError("For 'HeUniform', 'negative_slope' {} is not a valid number. "
+                             "When 'nonlinearity' has been set to "
                             "'leaky_relu', 'negative_slope' should be int or float type, but got "
                             "{}.".format(param, type(param)))
        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
-        raise ValueError("The argument 'nonlinearity' should be one of ['sigmoid', 'tanh', 'relu' or 'leaky_relu'], "
+        raise ValueError("For 'HeUniform', the argument 'nonlinearity' should be one of "
+                         "['sigmoid', 'tanh', 'relu' or 'leaky_relu'], "
                         "but got {}.".format(nonlinearity))
    return res

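For reference, the leaky_relu branch above computes the He gain as sqrt(2 / (1 + negative_slope**2)); a standalone check of that arithmetic:

    import math

    def leaky_relu_gain(negative_slope):
        # Same formula as the `res = math.sqrt(...)` line in _calculate_gain above.
        return math.sqrt(2.0 / (1 + negative_slope ** 2))

    print(leaky_relu_gain(0.01))  # ~1.4141, close to sqrt(2) for plain relu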
@@ -469,8 +471,8 @@ class Dirac(Initializer):
         shapes = arr.shape
         if shapes[0] % self.groups != 0:
             raise ValueError("For Dirac initializer, the first dimension of "
-                             "the initialized tensor must be divisible by group, "
-                             "but got {}/{}.".format(shapes[0], self.groups))
+                             "the initialized tensor must be divisible by groups, "
+                             "but got first dimension {}, groups {}.".format(shapes[0], self.groups))

         out_channel_per_group = shapes[0] // self.groups
         min_dim = min(out_channel_per_group, shapes[1])
@@ -564,15 +566,16 @@ class VarianceScaling(Initializer):
     def __init__(self, scale=1.0, mode='fan_in', distribution='truncated_normal'):
         super(VarianceScaling, self).__init__(scale=scale, mode=mode, distribution=distribution)
         if scale <= 0.:
-            raise ValueError("For VarianceScaling initializer, scale must be greater than 0, but got {}.".format(scale))
+            raise ValueError("For VarianceScaling initializer, "
+                             "the argument 'scale' must be greater than 0, but got {}.".format(scale))

         if mode not in ['fan_in', 'fan_out', 'fan_avg']:
-            raise ValueError('For VarianceScaling initializer, mode must be fan_in, '
-                             'fan_out or fan_avg, but got {}.'.format(mode))
+            raise ValueError("For VarianceScaling initializer, the argument 'mode' must be fan_in, "
+                             "fan_out or fan_avg, but got {}.".format(mode))

         if distribution not in ['uniform', 'truncated_normal', 'untruncated_normal']:
-            raise ValueError('For VarianceScaling initializer, distribution must be uniform, '
-                             'truncated_norm or untruncated_norm, but got {}.'.format(distribution))
+            raise ValueError("For VarianceScaling initializer, the argument 'distribution' must be uniform, "
+                             "truncated_norm or untruncated_norm, but got {}.".format(distribution))

         self.scale = scale
         self.mode = mode
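
A sketch exercising the VarianceScaling checks above, using the public class from mindspore.common.initializer:

    from mindspore.common.initializer import VarianceScaling

    init = VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal')
    try:
        VarianceScaling(scale=-1.0)  # scale must be greater than 0
    except ValueError as err:
        print(err)  # For VarianceScaling initializer, the argument 'scale' must be greater than 0, ...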
@@ -720,14 +723,15 @@ def initializer(init, shape=None, dtype=mstype.float32):
         >>> tensor4 = initializer(0, [1, 2, 3], mindspore.float32)
     """
     if not isinstance(init, (Tensor, numbers.Number, str, Initializer)):
-        raise TypeError("The type of the 'init' argument should be 'Tensor', 'number', 'string' "
+        raise TypeError("For 'initializer', the type of the 'init' argument should be 'Tensor', 'number', 'string' "
                         "or 'initializer', but got {}.".format(type(init)))

     if isinstance(init, Tensor):
         init_shape = init.shape
         shape = shape if isinstance(shape, (tuple, list)) else [shape]
         if shape is not None and init_shape != tuple(shape):
-            raise ValueError("The shape of the 'init' argument should be same as the argument 'shape', but got the "
+            raise ValueError("For 'initializer', the shape of the 'init' argument should be same as "
+                             "the argument 'shape', but got the "
                              "'init' shape {} and the 'shape' {}.".format(list(init.shape), shape))
         return init

@@ -738,7 +742,8 @@ def initializer(init, shape=None, dtype=mstype.float32):

     for value in shape if shape is not None else ():
         if not isinstance(value, int) or value <= 0:
-            raise ValueError(f"The argument 'shape' is invalid, the value of 'shape' must be positive integer, "
+            raise ValueError(f"For 'initializer', the argument 'shape' is invalid, the value of 'shape' "
+                             f"must be positive integer, "
                              f"but got {shape}")

     if isinstance(init, str):
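
And the initializer() shape validation, sketched with the public API (the 'ones' string initializer is taken from the docstring examples):

    import mindspore
    from mindspore.common.initializer import initializer

    t = initializer('ones', [1, 2, 3], mindspore.float32)    # valid
    try:
        initializer('ones', [1, -2, 3], mindspore.float32)   # negative dimension
    except ValueError as err:
        print(err)  # For 'initializer', the argument 'shape' is invalid, ...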
@@ -251,7 +251,8 @@ class _Context:

     def set_variable_memory_max_size(self, variable_memory_max_size):
         """set values of variable_memory_max_size and graph_memory_max_size"""
-        logger.warning("The parameter 'variable_memory_max_size' is deprecated, and will be removed in a future "
+        logger.warning("For 'context.set_context', the parameter 'variable_memory_max_size' is deprecated, "
+                       "and will be removed in a future "
                        "version. Please use parameter 'max_device_memory' instead.")
         if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern):
             raise ValueError("For 'context.set_context', the argument 'variable_memory_max_size' should be in correct"
@@ -281,13 +282,15 @@ class _Context:
     def set_mempool_block_size(self, mempool_block_size):
         """Set the block size of memory pool."""
         if _get_mode() == GRAPH_MODE:
-            logger.warning("Graph mode not support mempool_block_size context currently")
+            logger.warning("Graph mode doesn't support to set parameter 'mempool_block_size' of context currently")
             return
         if not Validator.check_str_by_regular(mempool_block_size, _re_pattern):
-            raise ValueError("Context param mempool_block_size should be in correct format! Such as \"10GB\"")
+            raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be in "
+                             "correct format! Such as \"10GB\"")
         mempool_block_size_value = float(mempool_block_size[:-2])
         if mempool_block_size_value < 1.0:
-            raise ValueError("Context param mempool_block_size should be greater or equal to \"1GB\"")
+            raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be "
+                             "greater or equal to \"1GB\"")
         self.set_param(ms_ctx_param.mempool_block_size, mempool_block_size_value)

     def set_print_file_path(self, file_path):
@@ -626,8 +629,10 @@ def _check_target_specific_cfgs(device, arg_key):
     supported_devices = device_cfgs[arg_key]
     if device in supported_devices:
         return True
-    logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'"
-                   ", ignore it.")
+    logger.warning(f"For 'context.set_context', "
+                   f"the argument 'device_target' only supports devices in {supported_devices}, "
+                   f"but got '{arg_key}', current device is '{device}'"
+                   f", ignore it.")
     return False

@@ -894,7 +899,7 @@ def set_context(**kwargs):
     for key, value in kwargs.items():
         if key in ('enable_profiling', 'profiling_options', 'enable_auto_mixed_precision',
                    'enable_dump', 'save_dump_path'):
-            logger.warning(f" '{key}' parameters will be deprecated."
+            logger.warning(f"For 'context.set_context', '{key}' parameters will be deprecated."
                            "For details, please see the interface parameter API comments")
             continue
         if not _check_target_specific_cfgs(device, key):
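
To surface the reworded context warning, a sketch with the public set_context API; 'enable_profiling' is one of the deprecated keys listed in the hunk above:

    import mindspore.context as context

    # Logs: "For 'context.set_context', 'enable_profiling' parameters will be
    # deprecated. For details, please see the interface parameter API comments"
    context.set_context(enable_profiling=False)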
@@ -276,7 +276,7 @@ def _get_env_config():
 def _check_directory_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
     """Check whether directory is legitimate."""
     if not isinstance(target, str):
-        raise ValueError("Args directory {} must be string, please check it".format(target))
+        raise ValueError("The directory {} must be string, please check it".format(target))
     if reg is None:
         reg = r"^[\/0-9a-zA-Z\_\-\.\:\\]+$"
     if re.match(reg, target, flag) is None:

@@ -434,7 +434,7 @@ class Cell(Cell_):
     def _check_construct_args(self, *inputs, **kwargs):
         """Check the args needed by the function construct"""
         if kwargs:
-            raise ValueError(f"Expect no kwargs here. Did you pass wrong arguments? args: {inputs}, kwargs: {kwargs}")
+            raise ValueError(f"Expect no kwargs here. Maybe you pass wrong arguments, args: {inputs}, kwargs: {kwargs}")
         positional_args = 0
         default_args = 0
         for value in inspect.signature(self.construct).parameters.values():
@@ -665,7 +665,7 @@ class Cell(Cell_):
                 continue
             exist_objs.add(item)
             if item.name == PARAMETER_NAME_DEFAULT:
-                logger.warning("The parameter definition is deprecated.\n"
+                logger.warning("For 'Cell', the parameter definition is deprecated.\n"
                                "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
                 item.name = item.name + "$" + str(self._id)
                 self._id += 1
@@ -2206,7 +2206,8 @@ class GraphCell(Cell):
             raise TypeError(f"The 'params_init' must be a dict, but got {type(params_init)}.")
         for name, value in params_init.items():
             if not isinstance(name, str) or not isinstance(value, Tensor):
-                raise TypeError("The key of the 'params_init' must be str, and the value must be Tensor or Parameter, "
+                raise TypeError("For 'GraphCell', the key of the 'params_init' must be str, "
+                                "and the value must be Tensor or Parameter, "
                                 f"but got the key type: {type(name)}, and the value type: {type(value)}")

         params_dict = update_func_graph_hyper_params(self.graph, params_init)
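
The GraphCell check above, restated as a standalone sketch (illustrative only, not the MindSpore implementation; the helper name is made up):

    from mindspore import Tensor

    def check_params_init(params_init):
        # Mirrors the validation in GraphCell shown above.
        if not isinstance(params_init, dict):
            raise TypeError(f"The 'params_init' must be a dict, but got {type(params_init)}.")
        for name, value in params_init.items():
            if not isinstance(name, str) or not isinstance(value, Tensor):
                raise TypeError("For 'GraphCell', the key of the 'params_init' must be str, "
                                "and the value must be Tensor or Parameter, "
                                f"but got the key type: {type(name)}, and the value type: {type(value)}")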
@@ -411,7 +411,7 @@ class Optimizer(Cell):
             self.dynamic_weight_decay = True
             weight_decay = _WrappedWeightDecay(weight_decay, self.loss_scale)
         else:
-            raise TypeError("Weight decay should be int, float or Cell.")
+            raise TypeError("For 'Optimizer', the argument 'weight_decay' should be int, float or Cell.")
         return weight_decay

     def _preprocess_single_lr(self, learning_rate):
@@ -431,12 +431,13 @@ class Optimizer(Cell):
                 raise ValueError(f"For 'Optimizer', if 'learning_rate' is Tensor type, then the dimension of it should "
                                  f"be 0 or 1, but got {learning_rate.ndim}.")
             if learning_rate.ndim == 1 and learning_rate.size < 2:
-                logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number "
+                logger.warning("For 'Optimizer', if use `Tensor` type dynamic learning rate, "
+                               "please make sure that the number "
                                "of elements in the tensor is greater than 1.")
             return learning_rate
         if isinstance(learning_rate, LearningRateSchedule):
             return learning_rate
-        raise TypeError("For 'Optimizer', 'learning_rate' should be int, float, Tensor, Iterable or "
+        raise TypeError("For 'Optimizer', the argument 'learning_rate' should be int, float, Tensor, Iterable or "
                         "LearningRateSchedule, but got {}.".format(type(learning_rate)))

     def _build_single_lr(self, learning_rate, name):
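
The final type check in _preprocess_single_lr, restated as a standalone sketch (the helper name check_learning_rate is illustrative):

    from mindspore import Tensor

    def check_learning_rate(learning_rate):
        # Accepts the scalar and Tensor cases named in the new message above.
        if isinstance(learning_rate, (int, float, Tensor)):
            return learning_rate
        raise TypeError("For 'Optimizer', the argument 'learning_rate' should be int, float, Tensor, "
                        "Iterable or LearningRateSchedule, but got {}.".format(type(learning_rate)))

    check_learning_rate(0.01)  # a plain float is fine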
@@ -167,7 +167,8 @@ class CheckpointConfig:

         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
-            raise ValueError("The input arguments 'save_checkpoint_steps', 'save_checkpoint_seconds', "
+            raise ValueError("For 'CheckpointConfig', the input arguments 'save_checkpoint_steps', "
+                             "'save_checkpoint_seconds', "
                              "'keep_checkpoint_max' and 'keep_checkpoint_per_n_minutes' can't be all None or 0.")
         Validator.check_bool(exception_save)
         self.exception_save = exception_save
@@ -308,7 +309,8 @@ class CheckpointConfig:
         if append_info is None or append_info == []:
             return None
         if not isinstance(append_info, list):
-            raise TypeError(f"The type of 'append_info' must be list, but got {str(type(append_info))}.")
+            raise TypeError(f"For 'CheckpointConfig', the type of 'append_info' must be list, "
+                            f"but got {str(type(append_info))}.")
         handle_append_info = {}
         if "epoch_num" in append_info:
             handle_append_info["epoch_num"] = 0
@@ -317,20 +319,22 @@ class CheckpointConfig:
         dict_num = 0
         for element in append_info:
             if not isinstance(element, str) and not isinstance(element, dict):
-                raise TypeError(f"The type of 'append_info' element must be str or dict, "
+                raise TypeError(f"For 'CheckpointConfig', the type of 'append_info' element must be str or dict, "
                                 f"but got {str(type(element))}.")
             if isinstance(element, str) and element not in _info_list:
-                raise ValueError(f"The value of element in the argument 'append_info' must be in {_info_list}, "
+                raise ValueError(f"For 'CheckpointConfig', the value of element in the argument 'append_info' "
+                                 f"must be in {_info_list}, "
                                  f"but got {element}.")
             if isinstance(element, dict):
                 dict_num += 1
                 if dict_num > 1:
-                    raise TypeError(f"The element of 'append_info' must has only one dict.")
+                    raise TypeError(f"For 'CheckpointConfig', the element of 'append_info' must have only one dict.")
                 for key, value in element.items():
                     if isinstance(key, str) and isinstance(value, (int, float, bool)):
                         handle_append_info[key] = value
                     else:
-                        raise TypeError(f"The type of dict in 'append_info' must be key: string, value: int or float, "
+                        raise TypeError(f"For 'CheckpointConfig', the type of dict in 'append_info' "
+                                        f"must be key: string, value: int or float, "
                                         f"but got key: {type(key)}, value: {type(value)}")

         return handle_append_info
@@ -369,7 +373,8 @@ class ModelCheckpoint(Callback):
         self._last_triggered_step = 0

         if not isinstance(prefix, str) or prefix.find('/') >= 0:
-            raise ValueError("The argument 'prefix' for checkpoint file name is invalid, 'prefix' must be "
+            raise ValueError("For 'ModelCheckpoint', the argument 'prefix' "
+                             "for checkpoint file name is invalid, 'prefix' must be "
                              "string and does not contain '/', but got {}.".format(prefix))
         self._prefix = prefix
         self._exception_prefix = prefix
@@ -386,7 +391,8 @@ class ModelCheckpoint(Callback):
             self._config = CheckpointConfig()
         else:
             if not isinstance(config, CheckpointConfig):
-                raise TypeError("The argument 'config' should be 'CheckpointConfig' type, "
+                raise TypeError("For 'ModelCheckpoint', the argument 'config' should be "
+                                "'CheckpointConfig' type, "
                                 "but got {}.".format(type(config)))
             self._config = config

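
A sketch of the CheckpointConfig guard above, using the public class (importable from mindspore.train.callback):

    from mindspore.train.callback import CheckpointConfig

    cfg = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)  # valid
    try:
        CheckpointConfig(save_checkpoint_steps=None, save_checkpoint_seconds=None,
                         keep_checkpoint_max=None, keep_checkpoint_per_n_minutes=None)
    except ValueError as err:
        print(err)  # For 'CheckpointConfig', the input arguments ... can't be all None or 0.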
@@ -879,8 +879,8 @@ class Model:
             raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")

         if hasattr(train_dataset, '_warmup_epoch') and train_dataset._warmup_epoch != epoch:
-            raise ValueError("Use Model.build to initialize model, but the value of parameter `epoch` in Model.build "
-                             "is not equal to value in Model.train, got {} and {} separately."
+            raise ValueError("When using Model.build to initialize model, the value of parameter `epoch` in Model.build "
+                             "should be equal to value in Model.train, but got {} and {} separately."
                              .format(train_dataset._warmup_epoch, epoch))

         if dataset_sink_mode and _is_ps_mode():
@@ -893,7 +893,8 @@ class Model:
         if sink_size == -1:
             sink_size = dataset_size
         if sink_size < -1 or sink_size == 0:
-            raise ValueError("The argument 'sink_size' must be -1 or positive, but got {}.".format(sink_size))
+            raise ValueError("For 'Model.train', the argument 'sink_size' must be -1 or positive, "
+                             "but got {}.".format(sink_size))

         _device_number_check(self._parallel_mode, self._device_number)

@@ -1143,7 +1144,8 @@ class Model:
         if sink_size == -1:
             sink_size = dataset_size
         if sink_size < -1 or sink_size == 0:
-            raise ValueError("The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size))
+            raise ValueError("For 'infer_train_layout', the argument 'sink_size' must be -1 or positive, "
+                             "but got sink_size {}.".format(sink_size))

     def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
         """
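
The sink_size rule above in a standalone form; check_sink_size here is illustrative, not a MindSpore helper:

    def check_sink_size(sink_size, dataset_size):
        # -1 means "sink the whole dataset"; 0 and values below -1 are rejected.
        if sink_size == -1:
            sink_size = dataset_size
        if sink_size < -1 or sink_size == 0:
            raise ValueError("For 'Model.train', the argument 'sink_size' must be -1 or positive, "
                             "but got {}.".format(sink_size))
        return sink_size

    print(check_sink_size(-1, 1875))  # 1875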
@@ -276,11 +276,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         raise TypeError("For 'save_checkpoint', the argument 'save_obj' should be nn.Cell or list, "
                         "but got {}.".format(type(save_obj)))
     if not isinstance(ckpt_file_name, str):
-        raise TypeError("The argument {} for checkpoint file name is invalid, 'ckpt_file_name' must be "
+        raise TypeError("For 'save_checkpoint', the argument {} for checkpoint file name is invalid, "
+                        "'ckpt_file_name' must be "
                         "string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
     ckpt_file_name = os.path.realpath(ckpt_file_name)
     if os.path.isdir(ckpt_file_name):
-        raise IsADirectoryError("The argument `ckpt_file_name`: {} is a directory, "
+        raise IsADirectoryError("For 'save_checkpoint', the argument `ckpt_file_name`: {} is a directory, "
                                 "it should be a file name.".format(ckpt_file_name))
     if not ckpt_file_name.endswith('.ckpt'):
         ckpt_file_name += ".ckpt"
@@ -501,7 +502,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
             continue
         data = element.tensor.tensor_content
         data_type = element.tensor.tensor_type
-        np_type = tensor_to_np_type[data_type]
+        np_type = tensor_to_np_type.get(data_type)
         ms_type = tensor_to_ms_type[data_type]
         element_data = np.frombuffer(data, np_type)
         param_data_list.append(element_data)
@@ -529,7 +530,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N

     except BaseException as e:
         logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
-        raise ValueError(e.__str__() + "\nFailed to load the checkpoint file {}.".format(ckpt_file_name))
+        raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
+                         "failed to load the checkpoint file {}.".format(ckpt_file_name))

     if not parameter_dict:
         raise ValueError(f"The loaded parameter dict is empty after filtering, please check whether "
@@ -564,7 +566,7 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
     if isinstance(filter_prefix, str):
         filter_prefix = (filter_prefix,)
     if not filter_prefix:
-        raise ValueError("For 'load_checkpoint', the 'filter_prefix' can't be empty when "
+        raise ValueError("For 'load_checkpoint', the argument 'filter_prefix' can't be empty when "
                          "'filter_prefix' is list or tuple.")
     for index, prefix in enumerate(filter_prefix):
         if not isinstance(prefix, str):
@@ -865,7 +867,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
     check_input_data(*inputs, data_class=Tensor)

     if file_format == 'GEIR':
-        logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
+        logger.warning(f"For 'export', format 'GEIR' is deprecated, "
+                       f"it would be removed in future release, use 'AIR' instead.")
         file_format = 'AIR'

     supported_formats = ['AIR', 'ONNX', 'MINDIR']
@@ -1363,13 +1366,14 @@ def restore_group_info_list(group_info_file_name):
         >>> restore_list = restore_group_info_list("./group_info.pb")
     """
     if not isinstance(group_info_file_name, str):
-        raise TypeError(f"The group_info_file_name should be str, but got {type(group_info_file_name)}.")
+        raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
+                        f"but got {type(group_info_file_name)}.")

     if not os.path.isfile(group_info_file_name):
-        raise ValueError(f"No such group info file: {group_info_file_name}.")
+        raise ValueError(f"No such group information file: {group_info_file_name}.")

     if os.path.getsize(group_info_file_name) == 0:
-        raise ValueError("The group info file should not be empty.")
+        raise ValueError("The group information file should not be empty.")

     parallel_group_map = ParallelGroupMap()

@@ -1379,7 +1383,7 @@ def restore_group_info_list(group_info_file_name):

     restore_list = parallel_group_map.ckpt_restore_rank_list
     if not restore_list:
-        raise ValueError("The group info file has no restore rank list.")
+        raise ValueError("The group information file has no restore rank list.")

     restore_rank_list = [rank for rank in restore_list.dim]
     return restore_rank_list
@@ -1405,7 +1409,7 @@ def build_searched_strategy(strategy_filename):
         >>> strategy = build_searched_strategy("./strategy_train.ckpt")
     """
     if not isinstance(strategy_filename, str):
-        raise TypeError(f"For 'build_searched_strategy', the 'strategy_filename' should be string, "
+        raise TypeError(f"For 'build_searched_strategy', the argument 'strategy_filename' should be string, "
                         f"but got {type(strategy_filename)}.")

     if not os.path.isfile(strategy_filename):
@@ -1474,14 +1478,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
         Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
     """
     if not isinstance(sliced_parameters, list):
-        raise TypeError(f"For 'merge_sliced_parameter', the 'sliced_parameters' should be list, "
+        raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
                         f"but got {type(sliced_parameters)}.")

     if not sliced_parameters:
-        raise ValueError("For 'merge_sliced_parameter', the 'sliced_parameters' should not be empty.")
+        raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")

     if strategy and not isinstance(strategy, dict):
-        raise TypeError(f"For 'merge_sliced_parameter', the 'strategy' should be dict, but got {type(strategy)}.")
+        raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
+                        f"but got {type(strategy)}.")

     try:
         parameter_name = sliced_parameters[0].name
@@ -1576,7 +1581,8 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
             train_dev_count *= dim
     if train_dev_count != ckpt_file_len:
         raise ValueError(f"For 'Load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
-                         f"equal to the device count of training process. But the length of 'checkpoint_filenames'"
+                         f"equal to the device count of training process. "
+                         f"But got the length of 'checkpoint_filenames'"
                          f" is {ckpt_file_len} and the device count is {train_dev_count}.")
     rank_list = _infer_rank_list(train_strategy, predict_strategy)

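
A sketch of the merge_sliced_parameter checks, assuming the function is importable from mindspore.train.serialization (it is also re-exported at the package top level in recent releases):

    from mindspore.train.serialization import merge_sliced_parameter

    try:
        merge_sliced_parameter("not-a-list")  # must be a list of Parameter slices
    except TypeError as err:
        print(err)  # For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, ...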
@@ -90,7 +90,8 @@ def test_init_graph_cell_parameters_with_wrong_value_type():
         load_net = nn.GraphCell(graph, params_init=new_params)
         load_net(input_a, input_b)

-    assert "The key of the 'params_init' must be str, and the value must be Tensor or Parameter" in str(err.value)
+    assert "For 'GraphCell', the key of the 'params_init' must be str, " \
+           "and the value must be Tensor or Parameter" in str(err.value)
     remove_generated_file(mindir_name)
