forked from mindspore-Ecosystem/mindspore
!7583 [ME] change some format code.
Merge pull request !7583 from chenzhongming/zomi_master
This commit is contained in:
commit
8defa06a5d
|
@ -109,7 +109,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
|
||||||
raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.')
|
raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.')
|
||||||
if type_mismatch or not rel_fn(arg_value, value):
|
if type_mismatch or not rel_fn(arg_value, value):
|
||||||
rel_str = Rel.get_strs(rel).format(value)
|
rel_str = Rel.get_strs(rel).format(value)
|
||||||
raise type_except(f'{arg_name} {prim_name} should be an {type(arg_type).__name__} and must {rel_str}, '
|
raise type_except(f'{arg_name} {prim_name} should be an {arg_type.__name__} and must {rel_str}, '
|
||||||
f'but got `{arg_value}` with type `{type(arg_value).__name__}`.')
|
f'but got `{arg_value}` with type `{type(arg_value).__name__}`.')
|
||||||
|
|
||||||
return arg_value
|
return arg_value
|
||||||
|
@ -130,7 +130,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
|
||||||
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
|
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
|
||||||
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
|
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')
|
||||||
return arg_value
|
return arg_value
|
||||||
raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`')
|
raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
|
||||||
|
|
||||||
|
|
||||||
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
|
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
|
||||||
|
@ -146,7 +146,8 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg
|
||||||
arg_name = f'`{arg_name}`' if arg_name else ''
|
arg_name = f'`{arg_name}`' if arg_name else ''
|
||||||
type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
|
type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
|
||||||
if type_mismatch:
|
if type_mismatch:
|
||||||
raise TypeError(f'{arg_name} {prim_name} must be `{value_type}`, but got `{type(arg_value).__name__}`.')
|
raise TypeError("{} {} must be `{}`, but got `{}`.".format(
|
||||||
|
arg_name, prim_name, value_type.__name__, type(arg_value).__name__))
|
||||||
if not rel_fn(arg_value, lower_limit, upper_limit):
|
if not rel_fn(arg_value, lower_limit, upper_limit):
|
||||||
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
||||||
raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format(
|
raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format(
|
||||||
|
|
|
@ -135,7 +135,7 @@ class _MindSporeFunction:
|
||||||
_exec_init_graph(self.obj, init_phase)
|
_exec_init_graph(self.obj, init_phase)
|
||||||
|
|
||||||
def compile(self, arguments_dict, method_name):
|
def compile(self, arguments_dict, method_name):
|
||||||
"""Returns pipline for the given args."""
|
"""Returns pipeline for the given args."""
|
||||||
args_list = tuple(arguments_dict.values())
|
args_list = tuple(arguments_dict.values())
|
||||||
arg_names = tuple(arguments_dict.keys())
|
arg_names = tuple(arguments_dict.keys())
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ from ..ops.functional import cast
|
||||||
from ..parallel._tensor import _load_tensor_by_layout
|
from ..parallel._tensor import _load_tensor_by_layout
|
||||||
from ..common.tensor import Tensor
|
from ..common.tensor import Tensor
|
||||||
|
|
||||||
|
|
||||||
class Cell(Cell_):
|
class Cell(Cell_):
|
||||||
"""
|
"""
|
||||||
Base class for all neural networks.
|
Base class for all neural networks.
|
||||||
|
@ -579,7 +580,7 @@ class Cell(Cell_):
|
||||||
|
|
||||||
def cast_param(self, param):
|
def cast_param(self, param):
|
||||||
"""
|
"""
|
||||||
Cast parameter according to auto mix precison level in pynative mode.
|
Cast parameter according to auto mix precision level in pynative mode.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
param (Parameter): The parameter to cast.
|
param (Parameter): The parameter to cast.
|
||||||
|
@ -594,15 +595,13 @@ class Cell(Cell_):
|
||||||
param.set_cast_dtype()
|
param.set_cast_dtype()
|
||||||
return param
|
return param
|
||||||
|
|
||||||
def insert_child_to_cell(self, child_name, child):
|
def insert_child_to_cell(self, child_name, child_cell):
|
||||||
"""
|
"""
|
||||||
Adds a child cell to the current cell.
|
Adds a child cell to the current cell with a given name.
|
||||||
|
|
||||||
Inserts a subcell with a given name to the current cell.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
child_name (str): Name of the child cell.
|
child_name (str): Name of the child cell.
|
||||||
child (Cell): The child cell to be inserted.
|
child_cell (Cell): The child cell to be inserted.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
KeyError: Child Cell's name is incorrect or duplicated with the other child name.
|
KeyError: Child Cell's name is incorrect or duplicated with the other child name.
|
||||||
|
@ -612,15 +611,13 @@ class Cell(Cell_):
|
||||||
raise KeyError("Child cell name is incorrect.")
|
raise KeyError("Child cell name is incorrect.")
|
||||||
if hasattr(self, child_name) and child_name not in self._cells:
|
if hasattr(self, child_name) and child_name not in self._cells:
|
||||||
raise KeyError("Duplicate child name '{}'.".format(child_name))
|
raise KeyError("Duplicate child name '{}'.".format(child_name))
|
||||||
if not isinstance(child, Cell) and child is not None:
|
if not isinstance(child_cell, Cell) and child_cell is not None:
|
||||||
raise TypeError("Child cell type is incorrect.")
|
raise TypeError("Child cell type is incorrect.")
|
||||||
self._cells[child_name] = child
|
self._cells[child_name] = child_cell
|
||||||
|
|
||||||
def construct(self, *inputs, **kwargs):
|
def construct(self, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Defines the computation to be performed.
|
Defines the computation to be performed. This method must be overridden by all subclasses.
|
||||||
|
|
||||||
This method must be overridden by all subclasses.
|
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
The inputs of the top cell only allow Tensor.
|
The inputs of the top cell only allow Tensor.
|
||||||
|
|
|
@ -478,7 +478,7 @@ class Model:
|
||||||
len_element = len(next_element)
|
len_element = len(next_element)
|
||||||
next_element = _transfer_tensor_to_tuple(next_element)
|
next_element = _transfer_tensor_to_tuple(next_element)
|
||||||
if self._loss_fn and len_element != 2:
|
if self._loss_fn and len_element != 2:
|
||||||
raise ValueError("when loss_fn is not None, train_dataset should"
|
raise ValueError("when loss_fn is not None, train_dataset should "
|
||||||
"return two elements, but got {}".format(len_element))
|
"return two elements, but got {}".format(len_element))
|
||||||
cb_params.cur_step_num += 1
|
cb_params.cur_step_num += 1
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""
|
"""
|
||||||
export quantization aware training network to infer `AIR` backend.
|
export network to infer `AIR` backend.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -27,14 +27,17 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net,
|
||||||
from src.config import mnist_cfg as cfg
|
from src.config import mnist_cfg as cfg
|
||||||
from src.lenet import LeNet5
|
from src.lenet import LeNet5
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||||
|
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||||
|
choices=['Ascend', 'GPU'],
|
||||||
|
help='device where the code will be implemented (default: Ascend)')
|
||||||
|
parser.add_argument('--ckpt_path', type=str, default="",
|
||||||
|
help='if mode is test, must provide path where the trained ckpt file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
|
||||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
|
||||||
choices=['Ascend', 'GPU'],
|
|
||||||
help='device where the code will be implemented (default: Ascend)')
|
|
||||||
parser.add_argument('--ckpt_path', type=str, default="",
|
|
||||||
help='if mode is test, must provide path where the trained ckpt file')
|
|
||||||
args = parser.parse_args()
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
|
|
||||||
# define fusion network
|
# define fusion network
|
||||||
|
|
|
@ -30,23 +30,21 @@ from mindspore.train import Model
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
from mindspore.common import set_seed
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
||||||
|
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
|
||||||
|
help='device where the code will be implemented (default: Ascend)')
|
||||||
|
parser.add_argument('--data_path', type=str, default="./Data",
|
||||||
|
help='path where the dataset is saved')
|
||||||
|
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
|
||||||
|
path where the trained ckpt file')
|
||||||
|
args = parser.parse_args()
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
|
||||||
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
|
|
||||||
help='device where the code will be implemented (default: Ascend)')
|
|
||||||
parser.add_argument('--data_path', type=str, default="./Data",
|
|
||||||
help='path where the dataset is saved')
|
|
||||||
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
|
|
||||||
path where the trained ckpt file')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
ds_train = create_dataset(os.path.join(args.data_path, "train"),
|
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size)
|
||||||
cfg.batch_size)
|
|
||||||
if ds_train.get_dataset_size() == 0:
|
if ds_train.get_dataset_size() == 0:
|
||||||
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue