!32622 fix_white_list

Merge pull request !32622 from 冯一航/fix_white_list_master2
This commit is contained in:
i-robot 2022-04-09 02:34:34 +00:00 committed by Gitee
commit 92655adb3b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 62 additions and 44 deletions

View File

@ -680,8 +680,9 @@ class ParameterTuple(tuple):
names = set()
for x in data:
if not isinstance(x, Parameter):
raise TypeError(f"ParameterTuple input should be `Parameter` collection."
f"But got a {type(iterable)}, {iterable}")
raise TypeError(f"For ParameterTuple initialization, "
f"ParameterTuple input should be 'Parameter' collection, "
f"but got a {type(iterable)}. ")
if id(x) not in ids:
if x.name in names:
raise ValueError("The value {} , its name '{}' already exists. "

View File

@ -156,7 +156,7 @@ class _Context:
def __getattribute__(self, attr):
value = object.__getattribute__(self, attr)
if attr == "_context_handle" and value is None:
raise ValueError("Context handle is none in context!!!")
raise ValueError("Get {} failed, please check whether 'env_config_path' is correct.".format(attr))
return value
def get_param(self, param):
@ -181,8 +181,10 @@ class _Context:
self.set_backend_policy("vm")
parallel_mode = _get_auto_parallel_context("parallel_mode")
if parallel_mode not in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
raise ValueError(f"Pynative Only support STAND_ALONE and DATA_PARALLEL for ParallelMode,"
f"but got {parallel_mode.upper()}.")
raise ValueError(f"Got {parallel_mode}, when the user enabled SEMI_AUTO_PARALELL or AUTO_PARALLEL, "
f"pynative mode dose not support, you should set "
f"context.set_auto_parallel_context(parallel_mode='data_parallel') "
f"or context.set_auto_parallel_context(parallel_mode='stand_alone').")
self._context_switches.push(True, None)
elif mode == GRAPH_MODE:
if self.enable_debug_runtime:
@ -282,15 +284,18 @@ class _Context:
def set_mempool_block_size(self, mempool_block_size):
"""Set the block size of memory pool."""
if _get_mode() == GRAPH_MODE:
logger.warning("Graph mode doesn't supportto set parameter 'mempool_block_size' of context currently")
logger.warning("Graph mode doesn't support to set parameter 'mempool_block_size' of context currently, "
"you can use context.set_context to set pynative mode.")
return
if not Validator.check_str_by_regular(mempool_block_size, _re_pattern):
raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be in "
"correct format! Such as \"10GB\"")
"correct format! Such as \"10GB\", "
"but got {}".format(mempool_block_size))
mempool_block_size_value = float(mempool_block_size[:-2])
if mempool_block_size_value < 1.0:
raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be "
"greater or equal to \"1GB\"")
"greater or equal to \"1GB\", "
"but got {}GB".format(float(mempool_block_size[:-2])))
self.set_param(ms_ctx_param.mempool_block_size, mempool_block_size_value)
def set_print_file_path(self, file_path):
@ -630,9 +635,8 @@ def _check_target_specific_cfgs(device, arg_key):
if device in supported_devices:
return True
logger.warning(f"For 'context.set_context', "
f"the argument 'device_target' only supports devices in {supported_devices}, "
f"but got '{arg_key}', current device is '{device}'"
f", ignore it.")
f"the argument 'device_target' only supports devices in '{supported_devices}', "
f"but got '{device}', ignore it.")
return False

View File

@ -276,12 +276,12 @@ def _get_env_config():
def _check_directory_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
"""Check whether directory is legitimate."""
if not isinstance(target, str):
raise ValueError("The directory {} must be string, please check it".format(target))
raise ValueError("The directory {} must be string, but got {}, please check it".format(target, type(target)))
if reg is None:
reg = r"^[\/0-9a-zA-Z\_\-\.\:\\]+$"
if re.match(reg, target, flag) is None:
prim_name = f'in `{prim_name}`' if prim_name else ""
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flag'{}'".format(
target, prim_name, reg, flag))

View File

@ -435,7 +435,8 @@ class Cell(Cell_):
def _check_construct_args(self, *inputs, **kwargs):
"""Check the args needed by the function construct"""
if kwargs:
raise ValueError(f"Expect no kwargs here. maybe you pass wrong arguments, rgs: {inputs}, kwargs: {kwargs}")
raise ValueError(f"For 'Cell', expect no kwargs here, "
"maybe you pass wrong arguments, args: {inputs}, kwargs: {kwargs}")
positional_args = 0
default_args = 0
for value in inspect.signature(self.construct).parameters.values():
@ -569,7 +570,7 @@ class Cell(Cell_):
def __call__(self, *args, **kwargs):
if self.__class__.construct is Cell.construct:
logger.warning(f"The '{self.__class__}' does not override the method 'construct', "
f"will call the super class(Cell) 'construct'.")
f"it will call the super class(Cell) 'construct'.")
if kwargs:
bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
bound_arguments.apply_defaults()
@ -2110,7 +2111,9 @@ class Cell(Cell_):
for key, _ in kwargs.items():
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'):
raise ValueError("Recompute keyword %s is not recognized!" % key)
raise ValueError("For 'recompute', keyword '%s' is not recognized! "
"the key kwargs must be 'mp_comm_recompute', "
"'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)
def infer_param_pipeline_stage(self):
"""
@ -2232,7 +2235,7 @@ class GraphCell(Cell):
params_init = {} if params_init is None else params_init
if not isinstance(params_init, dict):
raise TypeError(f"The 'params_init' must be a dict, but got {type(params_init)}.")
raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
for name, value in params_init.items():
if not isinstance(name, str) or not isinstance(value, Tensor):
raise TypeError("For 'GraphCell', the key of the 'params_init' must be str, "

View File

@ -435,7 +435,8 @@ class Optimizer(Cell):
self.dynamic_weight_decay = True
weight_decay = _WrappedWeightDecay(weight_decay, self.loss_scale)
else:
raise TypeError("For 'Optimizer', the argument 'Weight_decay' should be int, float or Cell.")
raise TypeError("For 'Optimizer', the argument 'Weight_decay' should be int, "
"float or Cell.but got {}".format(type(weight_decay)))
return weight_decay
def _preprocess_single_lr(self, learning_rate):
@ -455,9 +456,10 @@ class Optimizer(Cell):
raise ValueError(f"For 'Optimizer', if 'learning_rate' is Tensor type, then the dimension of it should "
f"be 0 or 1, but got {learning_rate.ndim}.")
if learning_rate.ndim == 1 and learning_rate.size < 2:
logger.warning("For 'Optimizer', if use `Tensor` type dynamic learning rate, "
logger.warning("For 'Optimizer', if use 'Tensor' type dynamic learning rate, "
"please make sure that the number "
"of elements in the tensor is greater than 1.")
"of elements in the tensor is greater than 1, "
"but got {}.".format(learning_rate.size))
return learning_rate
if isinstance(learning_rate, LearningRateSchedule):
return learning_rate
@ -570,7 +572,8 @@ class Optimizer(Cell):
for key in group_param.keys():
if key not in ('params', 'lr', 'weight_decay', 'grad_centralization'):
logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups, "
f"the key should in ['params', 'lr', 'weight_decay', 'grad_centralization']")
for param in group_param['params']:
validator.check_value_type("parameter", param, [Parameter], self.cls_name)

View File

@ -455,7 +455,8 @@ class _VirtualDatasetCell(Cell):
@constexpr
def _check_shape_value_on_axis_divided_by_target_value(input_shape, dim, param_name, cls_name, target_value):
if input_shape[dim] % target_value != 0:
raise ValueError(f"{cls_name} {param_name} at {dim} shape should be divided by {target_value},"
raise ValueError(f"For MicroBatchInterleaved initialization, "
f"{cls_name} {param_name} at {dim} shape should be divided by {target_value},"
f"but got {input_shape[dim]}")
return True

View File

@ -328,14 +328,15 @@ class CheckpointConfig:
if isinstance(element, dict):
dict_num += 1
if dict_num > 1:
raise TypeError(f"For 'CheckpointConfig', the element of 'append_info' must has only one dict.")
raise TypeError(f"For 'CheckpointConfig', the element of 'append_info' must has only one dict, "
"but got {dict_num}")
for key, value in element.items():
if isinstance(key, str) and isinstance(value, (int, float, bool)):
handle_append_info[key] = value
else:
raise TypeError(f"For 'CheckpointConfig', the type of dict in 'append_info' "
f"must be key: string, value: int or float, "
f"but got key: {type(key)}, value: {type(value)}")
raise TypeError(f"For 'CheckpointConfig', the key type of the dict 'append_info' "
f"must be string, the value type must be int or float or bool, "
f"but got key type {type(key)}, value type {type(value)}")
return handle_append_info
@ -374,7 +375,7 @@ class ModelCheckpoint(Callback):
if not isinstance(prefix, str) or prefix.find('/') >= 0:
raise ValueError("For 'ModelCheckpoint', the argument 'prefix' "
"for checkpoint file name is invalid, 'prefix' must be "
"for checkpoint file name is invalid, it must be "
"string and does not contain '/', but got {}.".format(prefix))
self._prefix = prefix
self._exception_prefix = prefix
@ -391,8 +392,8 @@ class ModelCheckpoint(Callback):
self._config = CheckpointConfig()
else:
if not isinstance(config, CheckpointConfig):
raise TypeError("For 'ModelCheckpoint', the argument 'config' should be "
"'CheckpointConfig' type, "
raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be "
"'CheckpointConfig', "
"but got {}.".format(type(config)))
self._config = config

View File

@ -143,5 +143,6 @@ class DatasetGraph:
elif value is None:
message.mapStr[key] = "None"
else:
logger.warning("The parameter %r is not recorded, because its type is not supported in event package. "
"Its type is %r.", key, type(value).__name__)
logger.warning("The parameter %r is not recorded, because its type is not supported in event package, "
"Its type should be in ['str', 'bool', 'int', 'float', '(list, tuple)', 'dict', 'None'], "
"but got type is %r.", key, type(value).__name__)

View File

@ -879,8 +879,9 @@ class Model:
raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
if hasattr(train_dataset, '_warmup_epoch') and train_dataset._warmup_epoch != epoch:
raise ValueError("when use Model.build to initialize model, the value of parameter `epoch` in Model.build "
"should be equal to value in Model.train, but got {} and {} separately."
raise ValueError("when use Model.build to initialize model, the value of parameter 'epoch' in Model.build "
"should be equal to value in Model.train, but got the value of epoch in build {} and "
"the value of epoch in train {} separately."
.format(train_dataset._warmup_epoch, epoch))
if dataset_sink_mode and _is_ps_mode() and not _cache_enable():
@ -1063,7 +1064,7 @@ class Model:
_device_number_check(self._parallel_mode, self._device_number)
if not self._metric_fns:
raise ValueError("The model argument 'metrics' can not be None or empty, "
raise ValueError("For Model.eval, the model argument 'metrics' can not be None or empty, "
"you should set the argument 'metrics' for model.")
if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")

View File

@ -642,8 +642,10 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
logger.info("Loading parameters into net is finished.")
if param_not_load:
logger.warning("{} parameters in the 'net' are not loaded, because they are not in the "
"'parameter_dict'.".format(len(param_not_load)))
logger.warning("For 'load_param_into_net', "
"{} parameters in the 'net' are not loaded, because they are not in the "
"'parameter_dict', please check whether the network structure is consistent "
"when training and loading checkpoint.".format(len(param_not_load)))
for param_name in param_not_load:
logger.warning("{} is not loaded.".format(param_name))
return param_not_load
@ -1059,8 +1061,8 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
param_data = net_dict[param_name].data.asnumpy().tobytes()
param_proto.raw_data = param_data
else:
logger.critical("The parameter %s in the graph should also be defined in the network.", param_name)
raise ValueError("The parameter {} in the graph should also be defined in the "
logger.critical("The parameter '%s' in the graph should also be defined in the network.", param_name)
raise ValueError("The parameter '{}' in the graph should also be defined in the "
"network.".format(param_name))
if not file_name.endswith('.mindir'):
file_name += ".mindir"
@ -1087,7 +1089,8 @@ def _save_together(net_dict, model):
if name in net_dict.keys():
data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
else:
raise ValueError("The parameter {} in the graph should also be defined in the network."
raise ValueError("The parameter '{}' in the graph should also be defined in the network."
.format(param_proto.name))
if data_total > TOTAL_SAVE:
return False
@ -1370,10 +1373,10 @@ def restore_group_info_list(group_info_file_name):
f"but got {type(group_info_file_name)}.")
if not os.path.isfile(group_info_file_name):
raise ValueError(f"No such group information file: {group_info_file_name}.")
raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.")
if os.path.getsize(group_info_file_name) == 0:
raise ValueError("The group information file should not be empty.")
raise ValueError("For 'restore_group_info_list', the group information file should not be empty.")
parallel_group_map = ParallelGroupMap()
@ -1383,7 +1386,7 @@ def restore_group_info_list(group_info_file_name):
restore_list = parallel_group_map.ckpt_restore_rank_list
if not restore_list:
raise ValueError("The group information file has no restore rank list.")
raise ValueError("For 'restore_group_info_list', the group information file has no restore rank list.")
restore_rank_list = [rank for rank in restore_list.dim]
return restore_rank_list
@ -1710,7 +1713,7 @@ def _check_predict_strategy(predict_strategy):
if not flag:
raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
f"the key of it must be string, and the value of it must be list or tuple that "
f"the first four elements are dev_matrix (list[int]), tensor_map (list[int]), "
f"the first four elements must be dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (int, which value is 0)."
f"Please check whether 'predict_strategy' is correct.")

View File

@ -70,7 +70,7 @@ def test_init_graph_cell_parameters_with_wrong_type():
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)
assert "The 'params_init' must be a dict, but got" in str(err.value)
assert "For 'GraphCell', the argument 'params_init' must be a dict, but got" in str(err.value)
remove_generated_file(mindir_name)