From c498e175b29863660f35b7ca42aa0cb48860a34b Mon Sep 17 00:00:00 2001 From: chenzomi Date: Wed, 21 Oct 2020 19:47:39 +0800 Subject: [PATCH] [ME] format code --- mindspore/_checkparam.py | 7 ++++--- mindspore/common/api.py | 2 +- mindspore/nn/cell.py | 19 ++++++++----------- mindspore/train/model.py | 2 +- model_zoo/official/cv/lenet/export.py | 19 +++++++++++-------- model_zoo/official/cv/lenet/train.py | 24 +++++++++++------------- 6 files changed, 36 insertions(+), 37 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index fefaa2db6d3..bd9a615ea8b 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -109,7 +109,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.') if type_mismatch or not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(value) - raise type_except(f'{arg_name} {prim_name} should be an {type(arg_type).__name__} and must {rel_str}, ' + raise type_except(f'{arg_name} {prim_name} should be an {arg_type.__name__} and must {rel_str}, ' f'but got `{arg_value}` with type `{type(arg_value).__name__}`.') return arg_value @@ -130,7 +130,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None): if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') return arg_value - raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`') + raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`') def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None): @@ -146,7 +146,8 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg arg_name = f'`{arg_name}`' if arg_name else '' type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool) if type_mismatch: - raise TypeError(f'{arg_name} {prim_name} must be `{value_type}`, but got `{type(arg_value).__name__}`.') + raise TypeError("{} {} must be `{}`, but got `{}`.".format( + arg_name, prim_name, value_type.__name__, type(arg_value).__name__)) if not rel_fn(arg_value, lower_limit, upper_limit): rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format( diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 7f368faea93..66439936142 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -135,7 +135,7 @@ class _MindSporeFunction: _exec_init_graph(self.obj, init_phase) def compile(self, arguments_dict, method_name): - """Returns pipline for the given args.""" + """Returns pipeline for the given args.""" args_list = tuple(arguments_dict.values()) arg_names = tuple(arguments_dict.keys()) diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index b6191f8833c..e309d0f37de 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -32,6 +32,7 @@ from ..ops.functional import cast from ..parallel._tensor import _load_tensor_by_layout from ..common.tensor import Tensor + class Cell(Cell_): """ Base class for all neural networks. @@ -579,7 +580,7 @@ class Cell(Cell_): def cast_param(self, param): """ - Cast parameter according to auto mix precison level in pynative mode. + Cast parameter according to auto mix precision level in pynative mode. Args: param (Parameter): The parameter to cast. @@ -594,15 +595,13 @@ class Cell(Cell_): param.set_cast_dtype() return param - def insert_child_to_cell(self, child_name, child): + def insert_child_to_cell(self, child_name, child_cell): """ - Adds a child cell to the current cell. - - Inserts a subcell with a given name to the current cell. + Adds a child cell to the current cell with a given name. Args: child_name (str): Name of the child cell. - child (Cell): The child cell to be inserted. + child_cell (Cell): The child cell to be inserted. Raises: KeyError: Child Cell's name is incorrect or duplicated with the other child name. @@ -612,15 +611,13 @@ class Cell(Cell_): raise KeyError("Child cell name is incorrect.") if hasattr(self, child_name) and child_name not in self._cells: raise KeyError("Duplicate child name '{}'.".format(child_name)) - if not isinstance(child, Cell) and child is not None: + if not isinstance(child_cell, Cell) and child_cell is not None: raise TypeError("Child cell type is incorrect.") - self._cells[child_name] = child + self._cells[child_name] = child_cell def construct(self, *inputs, **kwargs): """ - Defines the computation to be performed. - - This method must be overridden by all subclasses. + Defines the computation to be performed. This method must be overridden by all subclasses. Note: The inputs of the top cell only allow Tensor. diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 0ef0a766a52..104b696664b 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -477,7 +477,7 @@ class Model: len_element = len(next_element) next_element = _transfer_tensor_to_tuple(next_element) if self._loss_fn and len_element != 2: - raise ValueError("when loss_fn is not None, train_dataset should" + raise ValueError("when loss_fn is not None, train_dataset should " "return two elements, but got {}".format(len_element)) cb_params.cur_step_num += 1 diff --git a/model_zoo/official/cv/lenet/export.py b/model_zoo/official/cv/lenet/export.py index 9cdec74ee67..c3861ac2f32 100644 --- a/model_zoo/official/cv/lenet/export.py +++ b/model_zoo/official/cv/lenet/export.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """ -export quantization aware training network to infer `AIR` backend. +export network to infer `AIR` backend. """ import argparse @@ -27,14 +27,17 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net, from src.config import mnist_cfg as cfg from src.lenet import LeNet5 + +parser = argparse.ArgumentParser(description='MindSpore MNIST Example') +parser.add_argument('--device_target', type=str, default="Ascend", + choices=['Ascend', 'GPU'], + help='device where the code will be implemented (default: Ascend)') +parser.add_argument('--ckpt_path', type=str, default="", + help='if mode is test, must provide path where the trained ckpt file') +args = parser.parse_args() + + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='MindSpore MNIST Example') - parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU'], - help='device where the code will be implemented (default: Ascend)') - parser.add_argument('--ckpt_path', type=str, default="", - help='if mode is test, must provide path where the trained ckpt file') - args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) # define fusion network diff --git a/model_zoo/official/cv/lenet/train.py b/model_zoo/official/cv/lenet/train.py index ea96a6da029..980b5e26b95 100644 --- a/model_zoo/official/cv/lenet/train.py +++ b/model_zoo/official/cv/lenet/train.py @@ -30,23 +30,21 @@ from mindspore.train import Model from mindspore.nn.metrics import Accuracy from mindspore.common import set_seed + +parser = argparse.ArgumentParser(description='MindSpore Lenet Example') +parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], + help='device where the code will be implemented (default: Ascend)') +parser.add_argument('--data_path', type=str, default="./Data", + help='path where the dataset is saved') +parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ + path where the trained ckpt file') +args = parser.parse_args() set_seed(1) + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='MindSpore Lenet Example') - parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], - help='device where the code will be implemented (default: Ascend)') - parser.add_argument('--data_path', type=str, default="./Data", - help='path where the dataset is saved') - parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ - path where the trained ckpt file') - - args = parser.parse_args() - - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) - ds_train = create_dataset(os.path.join(args.data_path, "train"), - cfg.batch_size) + ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size) if ds_train.get_dataset_size() == 0: raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")