!7583 [ME] change some format code.

Merge pull request !7583 from chenzhongming/zomi_master
This commit is contained in:
mindspore-ci-bot 2020-10-23 09:28:52 +08:00 committed by Gitee
commit 8defa06a5d
6 changed files with 36 additions and 37 deletions

View File

@ -109,7 +109,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.')
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise type_except(f'{arg_name} {prim_name} should be an {type(arg_type).__name__} and must {rel_str}, '
raise type_except(f'{arg_name} {prim_name} should be an {arg_type.__name__} and must {rel_str}, '
f'but got `{arg_value}` with type `{type(arg_value).__name__}`.')
return arg_value
@ -130,7 +130,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
return arg_value
raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`')
raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
@ -146,7 +146,8 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg
arg_name = f'`{arg_name}`' if arg_name else ''
type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
if type_mismatch:
raise TypeError(f'{arg_name} {prim_name} must be `{value_type}`, but got `{type(arg_value).__name__}`.')
raise TypeError("{} {} must be `{}`, but got `{}`.".format(
arg_name, prim_name, value_type.__name__, type(arg_value).__name__))
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format(

View File

@ -135,7 +135,7 @@ class _MindSporeFunction:
_exec_init_graph(self.obj, init_phase)
def compile(self, arguments_dict, method_name):
"""Returns pipline for the given args."""
"""Returns pipeline for the given args."""
args_list = tuple(arguments_dict.values())
arg_names = tuple(arguments_dict.keys())

View File

@ -32,6 +32,7 @@ from ..ops.functional import cast
from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor
class Cell(Cell_):
"""
Base class for all neural networks.
@ -579,7 +580,7 @@ class Cell(Cell_):
def cast_param(self, param):
"""
Cast parameter according to auto mix precison level in pynative mode.
Cast parameter according to auto mix precision level in pynative mode.
Args:
param (Parameter): The parameter to cast.
@ -594,15 +595,13 @@ class Cell(Cell_):
param.set_cast_dtype()
return param
def insert_child_to_cell(self, child_name, child):
def insert_child_to_cell(self, child_name, child_cell):
"""
Adds a child cell to the current cell.
Inserts a subcell with a given name to the current cell.
Adds a child cell to the current cell with a given name.
Args:
child_name (str): Name of the child cell.
child (Cell): The child cell to be inserted.
child_cell (Cell): The child cell to be inserted.
Raises:
KeyError: Child Cell's name is incorrect or duplicated with the other child name.
@ -612,15 +611,13 @@ class Cell(Cell_):
raise KeyError("Child cell name is incorrect.")
if hasattr(self, child_name) and child_name not in self._cells:
raise KeyError("Duplicate child name '{}'.".format(child_name))
if not isinstance(child, Cell) and child is not None:
if not isinstance(child_cell, Cell) and child_cell is not None:
raise TypeError("Child cell type is incorrect.")
self._cells[child_name] = child
self._cells[child_name] = child_cell
def construct(self, *inputs, **kwargs):
"""
Defines the computation to be performed.
This method must be overridden by all subclasses.
Defines the computation to be performed. This method must be overridden by all subclasses.
Note:
The inputs of the top cell only allow Tensor.

View File

@ -478,7 +478,7 @@ class Model:
len_element = len(next_element)
next_element = _transfer_tensor_to_tuple(next_element)
if self._loss_fn and len_element != 2:
raise ValueError("when loss_fn is not None, train_dataset should"
raise ValueError("when loss_fn is not None, train_dataset should "
"return two elements, but got {}".format(len_element))
cb_params.cur_step_num += 1

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
export quantization aware training network to infer `AIR` backend.
export network to infer `AIR` backend.
"""
import argparse
@ -27,14 +27,17 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net,
from src.config import mnist_cfg as cfg
from src.lenet import LeNet5
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file')
args = parser.parse_args()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
# define fusion network

View File

@ -30,23 +30,21 @@ from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.common import set_seed
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./Data",
help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
path where the trained ckpt file')
args = parser.parse_args()
set_seed(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./Data",
help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
path where the trained ckpt file')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"),
cfg.batch_size)
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size)
if ds_train.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")