forked from mindspore-Ecosystem/mindspore
!7583 [ME] change some format code.
Merge pull request !7583 from chenzhongming/zomi_master
This commit is contained in:
commit
8defa06a5d
|
@ -109,7 +109,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
|
|||
raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.')
|
||||
if type_mismatch or not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(value)
|
||||
raise type_except(f'{arg_name} {prim_name} should be an {type(arg_type).__name__} and must {rel_str}, '
|
||||
raise type_except(f'{arg_name} {prim_name} should be an {arg_type.__name__} and must {rel_str}, '
|
||||
f'but got `{arg_value}` with type `{type(arg_value).__name__}`.')
|
||||
|
||||
return arg_value
|
||||
|
@ -130,7 +130,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
|
|||
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
|
||||
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
|
||||
return arg_value
|
||||
raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`')
|
||||
raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
|
||||
|
||||
|
||||
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
|
||||
|
@ -146,7 +146,8 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg
|
|||
arg_name = f'`{arg_name}`' if arg_name else ''
|
||||
type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
|
||||
if type_mismatch:
|
||||
raise TypeError(f'{arg_name} {prim_name} must be `{value_type}`, but got `{type(arg_value).__name__}`.')
|
||||
raise TypeError("{} {} must be `{}`, but got `{}`.".format(
|
||||
arg_name, prim_name, value_type.__name__, type(arg_value).__name__))
|
||||
if not rel_fn(arg_value, lower_limit, upper_limit):
|
||||
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
||||
raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format(
|
||||
|
|
|
@ -135,7 +135,7 @@ class _MindSporeFunction:
|
|||
_exec_init_graph(self.obj, init_phase)
|
||||
|
||||
def compile(self, arguments_dict, method_name):
|
||||
"""Returns pipline for the given args."""
|
||||
"""Returns pipeline for the given args."""
|
||||
args_list = tuple(arguments_dict.values())
|
||||
arg_names = tuple(arguments_dict.keys())
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ from ..ops.functional import cast
|
|||
from ..parallel._tensor import _load_tensor_by_layout
|
||||
from ..common.tensor import Tensor
|
||||
|
||||
|
||||
class Cell(Cell_):
|
||||
"""
|
||||
Base class for all neural networks.
|
||||
|
@ -579,7 +580,7 @@ class Cell(Cell_):
|
|||
|
||||
def cast_param(self, param):
|
||||
"""
|
||||
Cast parameter according to auto mix precison level in pynative mode.
|
||||
Cast parameter according to auto mix precision level in pynative mode.
|
||||
|
||||
Args:
|
||||
param (Parameter): The parameter to cast.
|
||||
|
@ -594,15 +595,13 @@ class Cell(Cell_):
|
|||
param.set_cast_dtype()
|
||||
return param
|
||||
|
||||
def insert_child_to_cell(self, child_name, child):
|
||||
def insert_child_to_cell(self, child_name, child_cell):
|
||||
"""
|
||||
Adds a child cell to the current cell.
|
||||
|
||||
Inserts a subcell with a given name to the current cell.
|
||||
Adds a child cell to the current cell with a given name.
|
||||
|
||||
Args:
|
||||
child_name (str): Name of the child cell.
|
||||
child (Cell): The child cell to be inserted.
|
||||
child_cell (Cell): The child cell to be inserted.
|
||||
|
||||
Raises:
|
||||
KeyError: Child Cell's name is incorrect or duplicated with the other child name.
|
||||
|
@ -612,15 +611,13 @@ class Cell(Cell_):
|
|||
raise KeyError("Child cell name is incorrect.")
|
||||
if hasattr(self, child_name) and child_name not in self._cells:
|
||||
raise KeyError("Duplicate child name '{}'.".format(child_name))
|
||||
if not isinstance(child, Cell) and child is not None:
|
||||
if not isinstance(child_cell, Cell) and child_cell is not None:
|
||||
raise TypeError("Child cell type is incorrect.")
|
||||
self._cells[child_name] = child
|
||||
self._cells[child_name] = child_cell
|
||||
|
||||
def construct(self, *inputs, **kwargs):
|
||||
"""
|
||||
Defines the computation to be performed.
|
||||
|
||||
This method must be overridden by all subclasses.
|
||||
Defines the computation to be performed. This method must be overridden by all subclasses.
|
||||
|
||||
Note:
|
||||
The inputs of the top cell only allow Tensor.
|
||||
|
|
|
@ -478,7 +478,7 @@ class Model:
|
|||
len_element = len(next_element)
|
||||
next_element = _transfer_tensor_to_tuple(next_element)
|
||||
if self._loss_fn and len_element != 2:
|
||||
raise ValueError("when loss_fn is not None, train_dataset should"
|
||||
raise ValueError("when loss_fn is not None, train_dataset should "
|
||||
"return two elements, but got {}".format(len_element))
|
||||
cb_params.cur_step_num += 1
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
export quantization aware training network to infer `AIR` backend.
|
||||
export network to infer `AIR` backend.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
@ -27,14 +27,17 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net,
|
|||
from src.config import mnist_cfg as cfg
|
||||
from src.lenet import LeNet5
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--ckpt_path', type=str, default="",
|
||||
help='if mode is test, must provide path where the trained ckpt file')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--ckpt_path', type=str, default="",
|
||||
help='if mode is test, must provide path where the trained ckpt file')
|
||||
args = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
|
||||
# define fusion network
|
||||
|
|
|
@ -30,23 +30,21 @@ from mindspore.train import Model
|
|||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.common import set_seed
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--data_path', type=str, default="./Data",
|
||||
help='path where the dataset is saved')
|
||||
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
|
||||
path where the trained ckpt file')
|
||||
args = parser.parse_args()
|
||||
set_seed(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--data_path', type=str, default="./Data",
|
||||
help='path where the dataset is saved')
|
||||
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
|
||||
path where the trained ckpt file')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
ds_train = create_dataset(os.path.join(args.data_path, "train"),
|
||||
cfg.batch_size)
|
||||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size)
|
||||
if ds_train.get_dataset_size() == 0:
|
||||
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
||||
|
||||
|
|
Loading…
Reference in New Issue