!21639 modify code check for R1.3

Merge pull request !21639 from lilei/modify_code_check_R1.3
i-robot 2021-08-26 12:17:55 +00:00 committed by Gitee
commit 24a80c1047
5 changed files with 25 additions and 15 deletions


@@ -530,7 +530,7 @@ class Parameter(Tensor_):
         Initialize the parameter's data.

         Args:
-            layout (Union[None, list(list(int))]): Parameter slice
+            layout (Union[None, tuple(list(int))]): Parameter slice
                 layout [dev_mat, tensor_map, slice_shape]. Default: None.

                 - dev_mat (list(int)): Device matrix.
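For context, the three elements the docstring refers to might look as follows. This is a minimal sketch with hypothetical values; only the (dev_mat, tensor_map, slice_shape) structure comes from the docstring above.

    # Hypothetical layout tuple: (dev_mat, tensor_map, slice_shape)
    dev_mat = [2, 4]        # device matrix: 8 devices arranged as 2 x 4
    tensor_map = [1, -1]    # maps tensor dimensions onto the device matrix; -1 = not split
    slice_shape = [16, 32]  # shape of the slice stored on each device
    layout = (dev_mat, tensor_map, slice_shape)   # a tuple, matching the corrected docstring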


@@ -162,6 +162,8 @@ class _AutoParallelContext:
         Args:
             loss_repeated_mean (bool): The loss_repeated_mean flag.
         """
+        if not isinstance(loss_repeated_mean, bool):
+            raise TypeError(f"The type of loss_repeated_mean must be bool, but got {type(loss_repeated_mean)}.")
         self.check_context_handle()
         self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
@@ -229,7 +231,7 @@ class _AutoParallelContext:
         Set strategy checkpoint load path.

         Args:
-            strategy_ckpt_load_file (bool): Path to load parallel strategy checkpoint.
+            strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint.
         """
         self.check_context_handle()
         self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
@@ -323,8 +325,8 @@ class _AutoParallelContext:
         if isinstance(indices, (list)):
             for index in indices:
-                if not isinstance(index, int):
-                    raise TypeError('indices has invalid value')
+                if not isinstance(index, int) or isinstance(index, bool):
+                    raise TypeError(f"The type of index must be int, but got {type(index)}.")
         else:
             raise TypeError('indices must be a python list')
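The extra isinstance(index, bool) clause looks redundant at first glance, but bool is a subclass of int in Python, so the old check silently accepted True and False as indices. A quick sketch of the difference:

    isinstance(True, int)                                 # True: bool subclasses int
    not isinstance(True, int)                             # False: old check lets True through
    not isinstance(True, int) or isinstance(True, bool)   # True: new check raises TypeError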
@@ -372,8 +374,8 @@ class _AutoParallelContext:
         self.check_context_handle()
         if isinstance(sizes, (list)):
             for size in sizes:
-                if not isinstance(size, int):
-                    raise TypeError('sizes has invalid value')
+                if not isinstance(size, int) or isinstance(size, bool):
+                    raise TypeError(f"The type of size must be int, but got {type(size)}.")
         else:
             raise TypeError('sizes must be a python list')
@@ -453,6 +455,9 @@ class _AutoParallelContext:
         Raises:
             ValueError: If parallel mode is not supported.
         """
+        if not isinstance(communi_parallel_mode, str):
+            raise TypeError(f"The type of communi_parallel_mode must be str, \
+                but got {type(communi_parallel_mode)}.")
         self.check_context_handle()
         ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
         if ret is False:
@@ -472,8 +477,9 @@ class _AutoParallelContext:
             optimizer across devices.
         """
         self.check_context_handle()
-        if not isinstance(optimizer_weight_shard_size, int):
-            raise TypeError('optimizer_weight_shard_size is invalid type')
+        if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, bool):
+            raise TypeError(f"The type of optimizer_weight_shard_size must be int, \
+                but got {type(optimizer_weight_shard_size)}.")
         if optimizer_weight_shard_size <= 1:
             logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
                            "Please use the integer larger than 1.")


@@ -194,7 +194,7 @@ class _CostModelContext:
         Set costmodel communication bias.

         Args:
-            bias (float): A parameter used in adjusting communication calculation for practice.
+            communi_bias (float): A parameter used in adjusting communication calculation for practice.

         Raises:
             ValueError: If context handle is none.
@@ -249,6 +249,8 @@ class _CostModelContext:
         Raises:
             ValueError: If context handle is none, or phase is not in {0, 1}.
         """
+        if not isinstance(phase, int) or isinstance(phase, bool):
+            raise TypeError(f"The type of communi_const must be int, but got {type(phase)}.")
         if self._context_handle is None:
             raise ValueError("Context handle is none in context!!!")
         if phase not in (0, 1):
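Without the added guard, a boolean would slip through the membership test above, because True == 1 in Python; the new check rejects it before the range check runs:

    True in (0, 1)                                        # True: phase=True passed the old test
    not isinstance(True, int) or isinstance(True, bool)   # True: the added guard raises first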
@@ -276,6 +278,8 @@ class _CostModelContext:
         Raises:
             ValueError: If context handle is none.
         """
+        if not isinstance(single_loop, bool):
+            raise TypeError(f"The type of single_loop must be bool, but got {type(single_loop)}.")
         if self._context_handle is None:
             raise ValueError("Context handle is none in context!!!")
         self._context_handle.set_dp_algo_single_loop(single_loop)


@@ -72,14 +72,14 @@ def _set_fusion_strategy_by_idx(idx_list, group="hccl_world_group"):
         return
     finally:
         pass
-    if isinstance(group, (str)):
+    if isinstance(group, str):
         group_len = len(group)
         if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
             raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
     else:
         raise TypeError('Group must be a python str')
-    if isinstance(idx_list, (list)):
+    if isinstance(idx_list, list):
         idx_len = len(idx_list)
         if idx_len == 0:
             raise ValueError('idx_list length is 0')
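The (str) -> str and (list) -> list changes here are purely cosmetic: parentheses around a single name do not create a tuple, so both spellings behave identically. For illustration:

    (str) is str                                      # True: no comma, no tuple
    isinstance("x", (str)) == isinstance("x", str)    # True: the two checks are equivalent
    isinstance("x", (str,))                           # True as well: a real one-element tuple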
@@ -87,7 +87,7 @@ def _set_fusion_strategy_by_idx(idx_list, group="hccl_world_group"):
         raise TypeError('idx_list must be a python list')
     for idx in idx_list:
-        if isinstance(idx, (int)):
+        if isinstance(idx, int):
             if idx < 0:
                 raise ValueError('Idx < 0')
         else:
@@ -133,13 +133,13 @@ def _set_fusion_strategy_by_size(data_size_list, group="hccl_world_group"):
     finally:
         pass
-    if isinstance(group, (str)):
+    if isinstance(group, str):
         group_len = len(group)
         if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
             raise ValueError('Group name is out of range {_MAX_GROUP_NAME_LEN}')
     else:
         raise TypeError('Group must be a python str')
-    if isinstance(data_size_list, (list)):
+    if isinstance(data_size_list, list):
         len_data_size = len(data_size_list)
         if len_data_size == 0:
             raise ValueError('data_size_list length is 0')


@@ -21,7 +21,7 @@ from mindspore._checkparam import args_type_check

 __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]


-class _AlgoParameterConfig():
+class _AlgoParameterConfig:
     """
     _AlgoParameterConfig is the configuration of setting parameters used in the algorithm.
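Likewise, dropping the empty parentheses from the class statement changes nothing at runtime; class C: and class C(): both create an implicit subclass of object. For illustration:

    class A: ...
    class B(): ...
    print(A.__bases__ == B.__bases__ == (object,))    # True: both implicitly subclass object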