[ME] change `check_integer` to format `check_positive_int` and `check_integeter`
This commit is contained in:
parent
d4e8e94981
commit
d471d32e87
|
@ -92,6 +92,25 @@ rel_strs = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None):
|
||||||
|
"""
|
||||||
|
Check argument integer.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0
|
||||||
|
"""
|
||||||
|
rel_fn = Rel.get_fns(rel)
|
||||||
|
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
||||||
|
type_except = TypeError if type_mismatch else ValueError
|
||||||
|
if type_mismatch or not rel_fn(arg_value, value):
|
||||||
|
rel_str = Rel.get_strs(rel).format(value)
|
||||||
|
arg_name = arg_name if arg_name else "parameter"
|
||||||
|
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
||||||
|
raise type_except(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
|
||||||
|
f' with type `{type(arg_value).__name__}`.')
|
||||||
|
return arg_value
|
||||||
|
|
||||||
|
|
||||||
class Validator:
|
class Validator:
|
||||||
"""validator for checking input parameters"""
|
"""validator for checking input parameters"""
|
||||||
|
|
||||||
|
@ -121,6 +140,49 @@ class Validator:
|
||||||
f' with type `{type(arg_value).__name__}`.')
|
f' with type `{type(arg_value).__name__}`.')
|
||||||
return arg_value
|
return arg_value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_positive_int(arg_value, arg_name=None, prim_name=None):
|
||||||
|
"""
|
||||||
|
Check argument is positive integer, which mean arg_value > 0.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- number = check_positive_int(number)
|
||||||
|
- number = check_positive_int(number, "bias")
|
||||||
|
"""
|
||||||
|
return _check_integer(arg_value, 0, Rel.GT, arg_name, prim_name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_negative_int(arg_value, arg_name=None, prim_name=None):
|
||||||
|
"""
|
||||||
|
Check argument is negative integer, which mean arg_value < 0.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- number = check_negative_int(number)
|
||||||
|
- number = check_negative_int(number, "bias")
|
||||||
|
"""
|
||||||
|
return _check_integer(arg_value, 0, Rel.LT, arg_name, prim_name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
|
||||||
|
"""
|
||||||
|
Check argument is non-negative integer, which mean arg_value <= 0.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- number = check_non_positive_int(number)
|
||||||
|
- number = check_non_positive_int(number, "bias")
|
||||||
|
"""
|
||||||
|
return _check_integer(arg_value, 0, Rel.LE, arg_name, prim_name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
|
||||||
|
"""
|
||||||
|
Check argument is non-negative integer, which mean arg_value >= 0.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- number = check_non_negative_int(number)
|
||||||
|
- number = check_non_negative_int(number, "bias")
|
||||||
|
"""
|
||||||
|
return _check_integer(arg_value, 0, Rel.GE, arg_name, prim_name)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_number(arg_name, arg_value, value, rel, prim_name):
|
def check_number(arg_name, arg_value, value, rel, prim_name):
|
||||||
|
@ -140,7 +202,13 @@ class Validator:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_bool(arg_value, arg_name=None):
|
def check_bool(arg_value, arg_name=None):
|
||||||
"""Check argument is instance of bool"""
|
"""
|
||||||
|
Check argument is instance of bool.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- has_bias = check_bool(has_bias)
|
||||||
|
- has_bias = check_bool(has_bias, "has_bias")
|
||||||
|
"""
|
||||||
if not isinstance(arg_value, bool):
|
if not isinstance(arg_value, bool):
|
||||||
arg_name = arg_name if arg_name else "Parameter"
|
arg_name = arg_name if arg_name else "Parameter"
|
||||||
raise TypeError(f'`{arg_name}` should be isinstance of bool, but got `{arg_value}`.')
|
raise TypeError(f'`{arg_name}` should be isinstance of bool, but got `{arg_value}`.')
|
||||||
|
@ -169,7 +237,12 @@ class Validator:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
|
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
|
||||||
"""Checks whether a string is in some value list"""
|
"""
|
||||||
|
Check whether string is in some value list.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- method = check_string(method, ["string1", "string2", "string3"], "method")
|
||||||
|
"""
|
||||||
if isinstance(arg_value, str) and arg_value in valid_values:
|
if isinstance(arg_value, str) and arg_value in valid_values:
|
||||||
return arg_value
|
return arg_value
|
||||||
arg_name = arg_name if arg_name else "Parameter"
|
arg_name = arg_name if arg_name else "Parameter"
|
||||||
|
@ -372,28 +445,6 @@ def check_int(input_param):
|
||||||
raise TypeError("Input type must be int!")
|
raise TypeError("Input type must be int!")
|
||||||
|
|
||||||
|
|
||||||
def check_int_positive(input_param):
|
|
||||||
"""Int type judgment."""
|
|
||||||
if isinstance(input_param, bool):
|
|
||||||
raise TypeError("Input type must be int cannot be bool!")
|
|
||||||
if isinstance(input_param, int):
|
|
||||||
if input_param > 0:
|
|
||||||
return input_param
|
|
||||||
raise ValueError("The input_param must be positive, but got input_param {}.".format(input_param))
|
|
||||||
raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
|
|
||||||
|
|
||||||
|
|
||||||
def check_int_non_negative(input_param):
|
|
||||||
"""Non_negative type judgment."""
|
|
||||||
if isinstance(input_param, bool):
|
|
||||||
raise TypeError("Input type must be int cannot be bool!")
|
|
||||||
if isinstance(input_param, int):
|
|
||||||
if input_param >= 0:
|
|
||||||
return input_param
|
|
||||||
raise ValueError("The input_param must be non_negative, but got input_param {}.".format(input_param))
|
|
||||||
raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
|
|
||||||
|
|
||||||
|
|
||||||
def check_int_zero_one(input_param):
|
def check_int_zero_one(input_param):
|
||||||
"""Judge whether it is 0 or 1."""
|
"""Judge whether it is 0 or 1."""
|
||||||
if input_param in (0, 1):
|
if input_param in (0, 1):
|
||||||
|
|
|
@ -52,7 +52,7 @@ def piecewise_constant_lr(milestone, learning_rates):
|
||||||
lr = []
|
lr = []
|
||||||
last_item = 0
|
last_item = 0
|
||||||
for i, item in enumerate(milestone):
|
for i, item in enumerate(milestone):
|
||||||
validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None)
|
validator.check_positive_int(item, f'milestone[{i}]')
|
||||||
validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None)
|
validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None)
|
||||||
if item < last_item:
|
if item < last_item:
|
||||||
raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
|
raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
|
||||||
|
@ -63,9 +63,9 @@ def piecewise_constant_lr(milestone, learning_rates):
|
||||||
|
|
||||||
|
|
||||||
def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
|
def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
|
||||||
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
|
validator.check_positive_int(total_step, 'total_step')
|
||||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
|
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
|
||||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
|
validator.check_positive_int(decay_epoch, 'decay_epoch')
|
||||||
validator.check_float_positive('learning_rate', learning_rate, None)
|
validator.check_float_positive('learning_rate', learning_rate, None)
|
||||||
validator.check_float_legal_value('learning_rate', learning_rate, None)
|
validator.check_float_legal_value('learning_rate', learning_rate, None)
|
||||||
validator.check_float_positive('decay_rate', decay_rate, None)
|
validator.check_float_positive('decay_rate', decay_rate, None)
|
||||||
|
@ -236,9 +236,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
|
||||||
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None)
|
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None)
|
||||||
validator.check_float_positive('max_lr', max_lr, None)
|
validator.check_float_positive('max_lr', max_lr, None)
|
||||||
validator.check_float_legal_value('max_lr', max_lr, None)
|
validator.check_float_legal_value('max_lr', max_lr, None)
|
||||||
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
|
validator.check_positive_int(total_step, 'total_step')
|
||||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
|
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
|
||||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
|
validator.check_positive_int(decay_epoch, 'decay_epoch')
|
||||||
if min_lr >= max_lr:
|
if min_lr >= max_lr:
|
||||||
raise ValueError('`max_lr` should be greater than `min_lr`.')
|
raise ValueError('`max_lr` should be greater than `min_lr`.')
|
||||||
|
|
||||||
|
@ -306,9 +306,9 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
|
||||||
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
|
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
|
||||||
validator.check_float_positive('power', power, None)
|
validator.check_float_positive('power', power, None)
|
||||||
validator.check_float_legal_value('power', power, None)
|
validator.check_float_legal_value('power', power, None)
|
||||||
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
|
validator.check_positive_int(total_step, 'total_step')
|
||||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
|
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
|
||||||
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
|
validator.check_positive_int(decay_epoch, 'decay_epoch')
|
||||||
validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)
|
validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)
|
||||||
|
|
||||||
origin_decay_epoch = decay_epoch
|
origin_decay_epoch = decay_epoch
|
||||||
|
@ -357,9 +357,9 @@ def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch):
|
||||||
if not isinstance(learning_rate, float):
|
if not isinstance(learning_rate, float):
|
||||||
raise TypeError("learning_rate must be float.")
|
raise TypeError("learning_rate must be float.")
|
||||||
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
|
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None)
|
||||||
validator.check_integer('warmup_epoch', warmup_epoch, 0, Rel.GT, None)
|
validator.check_positive_int(warmup_epoch, 'warmup_epoch')
|
||||||
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
|
validator.check_positive_int(total_step, 'total_step')
|
||||||
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
|
validator.check_positive_int(step_per_epoch, 'step_per_epoch')
|
||||||
|
|
||||||
function = lambda x, y: (x, min(x, y))
|
function = lambda x, y: (x, min(x, y))
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ from mindspore.ops.operations import _inner_ops as inner
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from mindspore._checkparam import Rel, Validator, check_int_positive
|
from mindspore._checkparam import Rel, Validator
|
||||||
from mindspore.common.api import ms_function
|
from mindspore.common.api import ms_function
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from ..cell import Cell
|
from ..cell import Cell
|
||||||
|
@ -203,8 +203,8 @@ class Dense(Cell):
|
||||||
has_bias=True,
|
has_bias=True,
|
||||||
activation=None):
|
activation=None):
|
||||||
super(Dense, self).__init__()
|
super(Dense, self).__init__()
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = Validator.check_bool(has_bias)
|
self.has_bias = Validator.check_bool(has_bias)
|
||||||
|
|
||||||
if isinstance(weight_init, Tensor):
|
if isinstance(weight_init, Tensor):
|
||||||
|
|
|
@ -21,7 +21,7 @@ from mindspore.ops.primitive import constexpr
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common.initializer import initializer, Initializer
|
from mindspore.common.initializer import initializer, Initializer
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore._checkparam import Validator, Rel, twice, check_int_positive
|
from mindspore._checkparam import Validator, Rel, twice
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from ..cell import Cell
|
from ..cell import Cell
|
||||||
|
|
||||||
|
@ -47,8 +47,8 @@ class _Conv(Cell):
|
||||||
bias_init,
|
bias_init,
|
||||||
transposed=False):
|
transposed=False):
|
||||||
super(_Conv, self).__init__()
|
super(_Conv, self).__init__()
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.pad_mode = pad_mode
|
self.pad_mode = pad_mode
|
||||||
|
@ -65,7 +65,7 @@ class _Conv(Cell):
|
||||||
raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))
|
raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))
|
||||||
|
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
self.group = check_int_positive(group)
|
self.group = Validator.check_positive_int(group)
|
||||||
self.has_bias = has_bias
|
self.has_bias = has_bias
|
||||||
if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
|
if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
|
||||||
isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \
|
isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \
|
||||||
|
|
|
@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer
|
||||||
from mindspore.communication.management import get_group_size
|
from mindspore.communication.management import get_group_size
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.parallel._utils import _get_parallel_mode
|
from mindspore.parallel._utils import _get_parallel_mode
|
||||||
from mindspore._checkparam import Rel, Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from ..cell import Cell
|
from ..cell import Cell
|
||||||
|
|
||||||
__all__ = ['Embedding', 'EmbeddingLookup']
|
__all__ = ['Embedding', 'EmbeddingLookup']
|
||||||
|
@ -170,7 +170,7 @@ class EmbeddingLookup(Cell):
|
||||||
if not isinstance(manual_shapes, tuple):
|
if not isinstance(manual_shapes, tuple):
|
||||||
raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
|
raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
|
||||||
for dim in manual_shapes:
|
for dim in manual_shapes:
|
||||||
validator.check_integer('manul shape dim', dim, 0, Rel.GT, self.cls_name)
|
validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
|
||||||
self.gatherv2.add_prim_attr("manual_split", manual_shapes)
|
self.gatherv2.add_prim_attr("manual_split", manual_shapes)
|
||||||
self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
|
self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
|
||||||
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
|
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"""lstm"""
|
"""lstm"""
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore._checkparam import Rel, Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
@ -103,8 +103,8 @@ class LSTM(Cell):
|
||||||
bidirectional=False):
|
bidirectional=False):
|
||||||
super(LSTM, self).__init__()
|
super(LSTM, self).__init__()
|
||||||
validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
|
validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
|
||||||
validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name)
|
validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
|
||||||
validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name)
|
validator.check_positive_int(num_layers, "num_layers", self.cls_name)
|
||||||
|
|
||||||
self.batch_first = batch_first
|
self.batch_first = batch_first
|
||||||
self.transpose = P.Transpose()
|
self.transpose = P.Transpose()
|
||||||
|
|
|
@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
from ..cell import Cell
|
from ..cell import Cell
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ..._checkparam import Rel, Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul']
|
__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul']
|
||||||
|
@ -156,7 +156,7 @@ class LinSpace(Cell):
|
||||||
validator.check_value_type("start", start, [int, float], self.cls_name)
|
validator.check_value_type("start", start, [int, float], self.cls_name)
|
||||||
validator.check_value_type("stop", stop, [int, float], self.cls_name)
|
validator.check_value_type("stop", stop, [int, float], self.cls_name)
|
||||||
validator.check_value_type("num", num, [int], self.cls_name)
|
validator.check_value_type("num", num, [int], self.cls_name)
|
||||||
validator.check_integer("num", num, 0, Rel.GT, self.cls_name)
|
validator.check_positive_int(num, "num", self.cls_name)
|
||||||
|
|
||||||
self.is_single = bool(num == 1)
|
self.is_single = bool(num == 1)
|
||||||
self.lin_space = inner.LinSpace()
|
self.lin_space = inner.LinSpace()
|
||||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
from mindspore._checkparam import Validator, check_typename, check_int_positive
|
from mindspore._checkparam import Validator, check_typename
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from mindspore.communication.management import get_group_size, get_rank
|
from mindspore.communication.management import get_group_size, get_rank
|
||||||
from mindspore.communication import management
|
from mindspore.communication import management
|
||||||
|
@ -64,7 +64,7 @@ class _BatchNorm(Cell):
|
||||||
gamma_init, num_features), name="gamma", requires_grad=affine)
|
gamma_init, num_features), name="gamma", requires_grad=affine)
|
||||||
self.beta = Parameter(initializer(
|
self.beta = Parameter(initializer(
|
||||||
beta_init, num_features), name="beta", requires_grad=affine)
|
beta_init, num_features), name="beta", requires_grad=affine)
|
||||||
self.group = check_int_positive(device_num_each_group)
|
self.group = Validator.check_positive_int(device_num_each_group)
|
||||||
self.is_global = False
|
self.is_global = False
|
||||||
if self.group != 1:
|
if self.group != 1:
|
||||||
self.rank_id = get_rank()
|
self.rank_id = get_rank()
|
||||||
|
@ -464,7 +464,7 @@ class GlobalBatchNorm(_BatchNorm):
|
||||||
use_batch_statistics,
|
use_batch_statistics,
|
||||||
device_num_each_group,
|
device_num_each_group,
|
||||||
input_dims='both')
|
input_dims='both')
|
||||||
self.group = check_int_positive(device_num_each_group)
|
self.group = Validator.check_positive_int(device_num_each_group)
|
||||||
if self.group <= 1:
|
if self.group <= 1:
|
||||||
raise ValueError("the number of group must be greater than 1.")
|
raise ValueError("the number of group must be greater than 1.")
|
||||||
|
|
||||||
|
@ -599,8 +599,8 @@ class GroupNorm(Cell):
|
||||||
|
|
||||||
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
|
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
|
||||||
super(GroupNorm, self).__init__()
|
super(GroupNorm, self).__init__()
|
||||||
self.num_groups = check_int_positive(num_groups)
|
self.num_groups = Validator.check_positive_int(num_groups)
|
||||||
self.num_channels = check_int_positive(num_channels)
|
self.num_channels = Validator.check_positive_int(num_channels)
|
||||||
if num_channels % num_groups != 0:
|
if num_channels % num_groups != 0:
|
||||||
raise ValueError("num_channels should be divided by num_groups")
|
raise ValueError("num_channels should be divided by num_groups")
|
||||||
self.eps = check_typename('eps', eps, (float,))
|
self.eps = check_typename('eps', eps, (float,))
|
||||||
|
|
|
@ -23,7 +23,7 @@ from mindspore.ops import functional as F
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore._checkparam import Validator, Rel, check_int_positive, twice
|
from mindspore._checkparam import Validator, Rel, twice
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
from .normalization import BatchNorm2d, BatchNorm1d
|
from .normalization import BatchNorm2d, BatchNorm1d
|
||||||
from .activation import get_activation, ReLU, LeakyReLU
|
from .activation import get_activation, ReLU, LeakyReLU
|
||||||
|
@ -657,8 +657,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
||||||
self.kernel_size = (kernel_size, kernel_size)
|
self.kernel_size = (kernel_size, kernel_size)
|
||||||
else:
|
else:
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = has_bias
|
self.has_bias = has_bias
|
||||||
self.stride = twice(stride)
|
self.stride = twice(stride)
|
||||||
self.dilation = twice(dilation)
|
self.dilation = twice(dilation)
|
||||||
|
@ -785,8 +785,8 @@ class Conv2dQuant(Cell):
|
||||||
self.kernel_size = (kernel_size, kernel_size)
|
self.kernel_size = (kernel_size, kernel_size)
|
||||||
else:
|
else:
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = has_bias
|
self.has_bias = has_bias
|
||||||
self.stride = twice(stride)
|
self.stride = twice(stride)
|
||||||
self.dilation = twice(dilation)
|
self.dilation = twice(dilation)
|
||||||
|
@ -886,8 +886,8 @@ class DenseQuant(Cell):
|
||||||
narrow_range=False,
|
narrow_range=False,
|
||||||
quant_delay=0):
|
quant_delay=0):
|
||||||
super(DenseQuant, self).__init__()
|
super(DenseQuant, self).__init__()
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = Validator.check_bool(has_bias)
|
self.has_bias = Validator.check_bool(has_bias)
|
||||||
|
|
||||||
if isinstance(weight_init, Tensor):
|
if isinstance(weight_init, Tensor):
|
||||||
|
|
|
@ -44,7 +44,7 @@ class LearningRateSchedule(Cell):
|
||||||
|
|
||||||
|
|
||||||
def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name):
|
def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name):
|
||||||
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, cls_name)
|
validator.check_positive_int(decay_steps, 'decay_steps', cls_name)
|
||||||
validator.check_float_positive('learning_rate', learning_rate, cls_name)
|
validator.check_float_positive('learning_rate', learning_rate, cls_name)
|
||||||
validator.check_float_legal_value('learning_rate', learning_rate, cls_name)
|
validator.check_float_legal_value('learning_rate', learning_rate, cls_name)
|
||||||
validator.check_float_positive('decay_rate', decay_rate, cls_name)
|
validator.check_float_positive('decay_rate', decay_rate, cls_name)
|
||||||
|
@ -257,7 +257,7 @@ class CosineDecayLR(LearningRateSchedule):
|
||||||
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||||
validator.check_float_positive('max_lr', max_lr, self.cls_name)
|
validator.check_float_positive('max_lr', max_lr, self.cls_name)
|
||||||
validator.check_float_legal_value('max_lr', max_lr, self.cls_name)
|
validator.check_float_legal_value('max_lr', max_lr, self.cls_name)
|
||||||
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name)
|
validator.check_positive_int(decay_steps, "decay_steps", self.cls_name)
|
||||||
if min_lr >= max_lr:
|
if min_lr >= max_lr:
|
||||||
raise ValueError('`max_lr` should be greater than `min_lr`.')
|
raise ValueError('`max_lr` should be greater than `min_lr`.')
|
||||||
self.min_lr = min_lr
|
self.min_lr = min_lr
|
||||||
|
@ -324,7 +324,7 @@ class PolynomialDecayLR(LearningRateSchedule):
|
||||||
raise TypeError("end_learning_rate must be float.")
|
raise TypeError("end_learning_rate must be float.")
|
||||||
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
|
validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT,
|
||||||
self.cls_name)
|
self.cls_name)
|
||||||
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, self.cls_name)
|
validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name)
|
||||||
validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name)
|
validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name)
|
||||||
validator.check_float_positive('power', power, self.cls_name)
|
validator.check_float_positive('power', power, self.cls_name)
|
||||||
validator.check_float_legal_value('power', power, self.cls_name)
|
validator.check_float_legal_value('power', power, self.cls_name)
|
||||||
|
@ -388,7 +388,7 @@ class WarmUpLR(LearningRateSchedule):
|
||||||
if not isinstance(learning_rate, float):
|
if not isinstance(learning_rate, float):
|
||||||
raise TypeError("learning_rate must be float.")
|
raise TypeError("learning_rate must be float.")
|
||||||
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||||
validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GT, self.cls_name)
|
validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name)
|
||||||
self.warmup_steps = warmup_steps
|
self.warmup_steps = warmup_steps
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.min = P.Minimum()
|
self.min = P.Minimum()
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"""dense_variational"""
|
"""dense_variational"""
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore._checkparam import check_int_positive, Validator
|
from mindspore._checkparam import Validator
|
||||||
from ...cell import Cell
|
from ...cell import Cell
|
||||||
from ...layer.activation import get_activation
|
from ...layer.activation import get_activation
|
||||||
from .layer_distribution import NormalPrior, NormalPosterior
|
from .layer_distribution import NormalPrior, NormalPosterior
|
||||||
|
@ -39,8 +39,8 @@ class _DenseVariational(Cell):
|
||||||
bias_prior_fn=NormalPrior,
|
bias_prior_fn=NormalPrior,
|
||||||
bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
|
bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
|
||||||
super(_DenseVariational, self).__init__()
|
super(_DenseVariational, self).__init__()
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = Validator.check_bool(has_bias)
|
self.has_bias = Validator.check_bool(has_bias)
|
||||||
|
|
||||||
if isinstance(weight_prior_fn, Cell):
|
if isinstance(weight_prior_fn, Cell):
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"""Conditional Variational auto-encoder (CVAE)."""
|
"""Conditional Variational auto-encoder (CVAE)."""
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore._checkparam import check_int_positive
|
from mindspore._checkparam import Validator
|
||||||
from ....cell import Cell
|
from ....cell import Cell
|
||||||
from ....layer.basic import Dense, OneHot
|
from ....layer.basic import Dense, OneHot
|
||||||
|
|
||||||
|
@ -57,11 +57,11 @@ class ConditionalVAE(Cell):
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
|
if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
|
||||||
raise TypeError('The encoder and decoder should be Cell type.')
|
raise TypeError('The encoder and decoder should be Cell type.')
|
||||||
self.hidden_size = check_int_positive(hidden_size)
|
self.hidden_size = Validator.check_positive_int(hidden_size)
|
||||||
self.latent_size = check_int_positive(latent_size)
|
self.latent_size = Validator.check_positive_int(latent_size)
|
||||||
if hidden_size < latent_size:
|
if hidden_size < latent_size:
|
||||||
raise ValueError('The latent_size should be less than or equal to the hidden_size.')
|
raise ValueError('The latent_size should be less than or equal to the hidden_size.')
|
||||||
self.num_classes = check_int_positive(num_classes)
|
self.num_classes = Validator.check_positive_int(num_classes)
|
||||||
self.normal = C.normal
|
self.normal = C.normal
|
||||||
self.exp = P.Exp()
|
self.exp = P.Exp()
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
|
@ -108,7 +108,7 @@ class ConditionalVAE(Cell):
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, the generated samples.
|
Tensor, the generated samples.
|
||||||
"""
|
"""
|
||||||
generate_nums = check_int_positive(generate_nums)
|
generate_nums = Validator.check_positive_int(generate_nums)
|
||||||
if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
|
if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
|
||||||
raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
|
raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
|
||||||
sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
|
sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"""Variational auto-encoder (VAE)"""
|
"""Variational auto-encoder (VAE)"""
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore._checkparam import check_int_positive
|
from mindspore._checkparam import Validator
|
||||||
from ....cell import Cell
|
from ....cell import Cell
|
||||||
from ....layer.basic import Dense
|
from ....layer.basic import Dense
|
||||||
|
|
||||||
|
@ -52,8 +52,8 @@ class VAE(Cell):
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
|
if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
|
||||||
raise TypeError('The encoder and decoder should be Cell type.')
|
raise TypeError('The encoder and decoder should be Cell type.')
|
||||||
self.hidden_size = check_int_positive(hidden_size)
|
self.hidden_size = Validator.check_positive_int(hidden_size)
|
||||||
self.latent_size = check_int_positive(latent_size)
|
self.latent_size = Validator.check_positive_int(latent_size)
|
||||||
if hidden_size < latent_size:
|
if hidden_size < latent_size:
|
||||||
raise ValueError('The latent_size should be less than or equal to the hidden_size.')
|
raise ValueError('The latent_size should be less than or equal to the hidden_size.')
|
||||||
self.normal = C.normal
|
self.normal = C.normal
|
||||||
|
@ -94,7 +94,7 @@ class VAE(Cell):
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, the generated samples.
|
Tensor, the generated samples.
|
||||||
"""
|
"""
|
||||||
generate_nums = check_int_positive(generate_nums)
|
generate_nums = Validator.check_positive_int(generate_nums)
|
||||||
if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
|
if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
|
||||||
raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
|
raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
|
||||||
sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
|
sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"""Stochastic Variational Inference(SVI)."""
|
"""Stochastic Variational Inference(SVI)."""
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore._checkparam import check_int_positive
|
from mindspore._checkparam import Validator
|
||||||
from ....cell import Cell
|
from ....cell import Cell
|
||||||
from ....wrap.cell_wrapper import TrainOneStepCell
|
from ....wrap.cell_wrapper import TrainOneStepCell
|
||||||
from .elbo import ELBO
|
from .elbo import ELBO
|
||||||
|
@ -57,7 +57,7 @@ class SVI:
|
||||||
Outputs:
|
Outputs:
|
||||||
Cell, the trained probability network.
|
Cell, the trained probability network.
|
||||||
"""
|
"""
|
||||||
epochs = check_int_positive(epochs)
|
epochs = Validator.check_positive_int(epochs)
|
||||||
train_net = TrainOneStepCell(self.net_with_loss, self.optimizer)
|
train_net = TrainOneStepCell(self.net_with_loss, self.optimizer)
|
||||||
train_net.set_train()
|
train_net.set_train()
|
||||||
for _ in range(1, epochs+1):
|
for _ in range(1, epochs+1):
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore._checkparam import check_int_positive, Validator
|
from mindspore._checkparam import Validator
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
|
@ -81,7 +81,7 @@ class UncertaintyEvaluation:
|
||||||
self.epi_train_dataset = train_dataset
|
self.epi_train_dataset = train_dataset
|
||||||
self.ale_train_dataset = deepcopy(train_dataset)
|
self.ale_train_dataset = deepcopy(train_dataset)
|
||||||
self.task_type = task_type
|
self.task_type = task_type
|
||||||
self.epochs = check_int_positive(epochs)
|
self.epochs = Validator.check_positive_int(epochs)
|
||||||
self.epi_uncer_model_path = epi_uncer_model_path
|
self.epi_uncer_model_path = epi_uncer_model_path
|
||||||
self.ale_uncer_model_path = ale_uncer_model_path
|
self.ale_uncer_model_path = ale_uncer_model_path
|
||||||
self.save_model = Validator.check_bool(save_model)
|
self.save_model = Validator.check_bool(save_model)
|
||||||
|
@ -95,7 +95,7 @@ class UncertaintyEvaluation:
|
||||||
if task_type not in ('regression', 'classification'):
|
if task_type not in ('regression', 'classification'):
|
||||||
raise ValueError('The task should be regression or classification.')
|
raise ValueError('The task should be regression or classification.')
|
||||||
if task_type == 'classification':
|
if task_type == 'classification':
|
||||||
self.num_classes = check_int_positive(num_classes)
|
self.num_classes = Validator.check_positive_int(num_classes)
|
||||||
else:
|
else:
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
if save_model:
|
if save_model:
|
||||||
|
|
|
@ -65,9 +65,9 @@ def get_broadcast_shape(x_shape, y_shape, prim_name):
|
||||||
def get_concat_offset(x_shp, x_type, axis, prim_name):
|
def get_concat_offset(x_shp, x_type, axis, prim_name):
|
||||||
"""for concat and concatoffset check args and compute offset"""
|
"""for concat and concatoffset check args and compute offset"""
|
||||||
validator.check_value_type("shape", x_shp, [tuple], prim_name)
|
validator.check_value_type("shape", x_shp, [tuple], prim_name)
|
||||||
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
|
validator.check_positive_int(len(x_shp), "input_x rank", prim_name)
|
||||||
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
|
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
|
||||||
validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name)
|
validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name)
|
||||||
rank_base = len(x_shp[0])
|
rank_base = len(x_shp[0])
|
||||||
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
|
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
|
||||||
if axis < 0:
|
if axis < 0:
|
||||||
|
|
|
@ -12,8 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""constexpr util"""
|
"""constexpr util"""
|
||||||
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -60,30 +60,6 @@ def check_equal(param1, param2, msg="{},{}"):
|
||||||
return param1
|
return param1
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_int_positive(arg_name, arg_value, op_name):
|
|
||||||
"""Int type judgment."""
|
|
||||||
if isinstance(arg_value, bool):
|
|
||||||
raise TypeError("For \'{}\' the `{}` must be int, cannot be bool.".format(op_name, arg_name))
|
|
||||||
if isinstance(arg_value, int):
|
|
||||||
if arg_value > 0:
|
|
||||||
return arg_value
|
|
||||||
raise ValueError("For \'{}\' the `{}` must be positive, but got {}.".format(op_name, arg_name, arg_value))
|
|
||||||
raise TypeError("For \'{}\' the `{}` must be int, cannot be {}.".format(op_name, arg_name, type(arg_value)))
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_int_non_negative(arg_name, arg_value, op_name):
|
|
||||||
"""Int type judgment."""
|
|
||||||
if isinstance(arg_value, bool):
|
|
||||||
raise TypeError("For \'{}\' the `{}` must be int, cannot be bool.".format(op_name, arg_name))
|
|
||||||
if isinstance(arg_value, int):
|
|
||||||
if arg_value >= 0:
|
|
||||||
return arg_value
|
|
||||||
raise ValueError("For \'{}\' the `{}` must be non_negative, but got {}.".format(op_name, arg_name, arg_value))
|
|
||||||
raise TypeError("For \'{}\' the `{}` must be int, cannot be {}.".format(op_name, arg_name, type(arg_value)))
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
|
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
|
||||||
"""Checks the shape and size of the sensor and value."""
|
"""Checks the shape and size of the sensor and value."""
|
||||||
|
|
|
@ -12,9 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""Operations for random number generators."""
|
"""Operations for random number generators."""
|
||||||
|
|
||||||
|
from mindspore._checkparam import Validator
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from ..primitive import constexpr
|
from ..primitive import constexpr
|
||||||
|
@ -54,7 +54,7 @@ def get_seed(op_seed, kernel_name):
|
||||||
if op_seed is None:
|
if op_seed is None:
|
||||||
temp_seed = _get_op_seed(0, kernel_name)
|
temp_seed = _get_op_seed(0, kernel_name)
|
||||||
else:
|
else:
|
||||||
const_utils.check_int_non_negative("seed", op_seed, kernel_name)
|
Validator.check_non_negative_int(op_seed, "seed", kernel_name)
|
||||||
temp_seed = _get_op_seed(op_seed, kernel_name)
|
temp_seed = _get_op_seed(op_seed, kernel_name)
|
||||||
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
|
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
|
||||||
_update_seeds(op_seed, kernel_name)
|
_update_seeds(op_seed, kernel_name)
|
||||||
|
|
|
@ -915,9 +915,9 @@ class LSTMGradData(PrimitiveWithInfer):
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
||||||
self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
|
self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
|
||||||
self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
|
self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
|
||||||
self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
|
self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
|
||||||
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
||||||
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
||||||
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
||||||
|
@ -964,9 +964,9 @@ class LSTMGradWeight(PrimitiveWithInfer):
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
||||||
self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
|
self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
|
||||||
self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
|
self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
|
||||||
self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
|
self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
|
||||||
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
||||||
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
||||||
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
||||||
|
@ -999,9 +999,9 @@ class LSTMGrad(PrimitiveWithInfer):
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
||||||
self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
|
self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
|
||||||
self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
|
self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
|
||||||
self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
|
self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
|
||||||
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
|
||||||
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
|
||||||
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
||||||
|
|
|
@ -701,7 +701,7 @@ class Padding(PrimitiveWithInfer):
|
||||||
def __init__(self, pad_dim_size=8):
|
def __init__(self, pad_dim_size=8):
|
||||||
"""Initialize padding"""
|
"""Initialize padding"""
|
||||||
validator.check_value_type("pad_dim_size", pad_dim_size, [int], self.name)
|
validator.check_value_type("pad_dim_size", pad_dim_size, [int], self.name)
|
||||||
validator.check_integer("pad_dim_size", pad_dim_size, 0, Rel.GT, self.name)
|
validator.check_positive_int(pad_dim_size, "pad_dim_size", self.name)
|
||||||
self.pad_dim_size = pad_dim_size
|
self.pad_dim_size = pad_dim_size
|
||||||
|
|
||||||
def __infer__(self, x):
|
def __infer__(self, x):
|
||||||
|
@ -911,8 +911,8 @@ class Fill(PrimitiveWithInfer):
|
||||||
def __infer__(self, dtype, dims, x):
|
def __infer__(self, dtype, dims, x):
|
||||||
validator.check_value_type("shape", dims['value'], [tuple], self.name)
|
validator.check_value_type("shape", dims['value'], [tuple], self.name)
|
||||||
validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
|
validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
|
||||||
for idx, item in enumerate(dims['value']):
|
for i, item in enumerate(dims['value']):
|
||||||
validator.check_integer("dims[%d]" % idx, item, 0, Rel.GT, self.name)
|
validator.check_positive_int(item, f'dims[{i}]', self.name)
|
||||||
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
||||||
mstype.uint8, mstype.uint32, mstype.uint64,
|
mstype.uint8, mstype.uint32, mstype.uint64,
|
||||||
mstype.float16, mstype.float32, mstype.float64]
|
mstype.float16, mstype.float32, mstype.float64]
|
||||||
|
@ -1482,20 +1482,20 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
||||||
validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
|
validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
|
||||||
validator.check_value_type("x_shape", x_shp, [list], self.name)
|
validator.check_value_type("x_shape", x_shp, [list], self.name)
|
||||||
x_shp_len = len(x_shp)
|
x_shp_len = len(x_shp)
|
||||||
validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT, self.name)
|
validator.check_positive_int(x_shp_len, "rank of input_x", self.name)
|
||||||
segment_ids_shp = segment_ids['shape']
|
segment_ids_shp = segment_ids['shape']
|
||||||
segment_ids_type = segment_ids['dtype']
|
segment_ids_type = segment_ids['dtype']
|
||||||
validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name)
|
validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name)
|
||||||
validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name)
|
validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name)
|
||||||
segment_ids_shp_len = len(segment_ids_shp)
|
segment_ids_shp_len = len(segment_ids_shp)
|
||||||
validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT, self.name)
|
validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name)
|
||||||
validator.check(f'rank of input_x', len(x_shp),
|
validator.check(f'rank of input_x', len(x_shp),
|
||||||
'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
|
'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
|
||||||
for i, value in enumerate(segment_ids_shp):
|
for i, value in enumerate(segment_ids_shp):
|
||||||
validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
|
validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
|
||||||
num_segments_v = num_segments['value']
|
num_segments_v = num_segments['value']
|
||||||
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
|
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
|
||||||
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
|
validator.check_positive_int(num_segments_v, "num_segments", self.name)
|
||||||
shp = [num_segments_v]
|
shp = [num_segments_v]
|
||||||
shp += x_shp[segment_ids_shp_len:]
|
shp += x_shp[segment_ids_shp_len:]
|
||||||
out = {'shape': shp,
|
out = {'shape': shp,
|
||||||
|
@ -1544,7 +1544,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
|
||||||
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
|
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
|
||||||
num_segments_v = num_segments['value']
|
num_segments_v = num_segments['value']
|
||||||
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
|
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
|
||||||
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
|
validator.check_positive_int(num_segments_v, "num_segments", self.name)
|
||||||
segment_ids_shape_len = len(segment_ids_shape)
|
segment_ids_shape_len = len(segment_ids_shape)
|
||||||
out_shape = [num_segments_v]
|
out_shape = [num_segments_v]
|
||||||
out_shape += x_shape[segment_ids_shape_len:]
|
out_shape += x_shape[segment_ids_shape_len:]
|
||||||
|
@ -1597,7 +1597,7 @@ class UnsortedSegmentProd(PrimitiveWithInfer):
|
||||||
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
|
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
|
||||||
num_segments_v = num_segments['value']
|
num_segments_v = num_segments['value']
|
||||||
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
|
validator.check_value_type('num_segments', num_segments_v, [int], self.name)
|
||||||
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
|
validator.check_positive_int(num_segments_v, "num_segments", self.name)
|
||||||
segment_ids_shape_len = len(segment_ids_shape)
|
segment_ids_shape_len = len(segment_ids_shape)
|
||||||
out_shape = [num_segments_v]
|
out_shape = [num_segments_v]
|
||||||
out_shape += x_shape[segment_ids_shape_len:]
|
out_shape += x_shape[segment_ids_shape_len:]
|
||||||
|
@ -1832,7 +1832,7 @@ class Unpack(PrimitiveWithInfer):
|
||||||
self.axis = self.axis + dim
|
self.axis = self.axis + dim
|
||||||
output_num = x_shape[self.axis]
|
output_num = x_shape[self.axis]
|
||||||
validator.check_value_type("num", output_num, [int], self.name)
|
validator.check_value_type("num", output_num, [int], self.name)
|
||||||
validator.check_integer("output_num", output_num, 0, Rel.GT, self.name)
|
validator.check_positive_int(output_num, "output_num", self.name)
|
||||||
self.add_prim_attr('num', output_num)
|
self.add_prim_attr('num', output_num)
|
||||||
output_valid_check = x_shape[self.axis] - output_num
|
output_valid_check = x_shape[self.axis] - output_num
|
||||||
validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ,
|
validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ,
|
||||||
|
@ -2401,8 +2401,8 @@ class Eye(PrimitiveWithInfer):
|
||||||
"""Initialize Eye"""
|
"""Initialize Eye"""
|
||||||
|
|
||||||
def infer_value(self, n, m, t):
|
def infer_value(self, n, m, t):
|
||||||
validator.check_integer("n", n, 0, Rel.GT, self.name)
|
validator.check_positive_int(n, "n", self.name)
|
||||||
validator.check_integer("m", m, 0, Rel.GT, self.name)
|
validator.check_positive_int(m, "m", self.name)
|
||||||
args = {"dtype": t}
|
args = {"dtype": t}
|
||||||
validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
|
validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
|
||||||
np_type = mstype.dtype_to_nptype(t)
|
np_type = mstype.dtype_to_nptype(t)
|
||||||
|
@ -2443,7 +2443,7 @@ class ScatterNd(PrimitiveWithInfer):
|
||||||
validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
|
validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
|
||||||
validator.check_value_type("shape", shp, [tuple], self.name)
|
validator.check_value_type("shape", shp, [tuple], self.name)
|
||||||
for i, x in enumerate(shp):
|
for i, x in enumerate(shp):
|
||||||
validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
|
validator.check_positive_int(x, f'shape[{i}]', self.name)
|
||||||
|
|
||||||
indices_shape, update_shape = indices["shape"], update["shape"]
|
indices_shape, update_shape = indices["shape"], update["shape"]
|
||||||
if indices_shape[0] != update_shape[0]:
|
if indices_shape[0] != update_shape[0]:
|
||||||
|
@ -3469,7 +3469,7 @@ class BroadcastTo(PrimitiveWithInfer):
|
||||||
validator.check_value_type("shape", shape, (tuple), self.name)
|
validator.check_value_type("shape", shape, (tuple), self.name)
|
||||||
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
|
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
|
||||||
for i in shape:
|
for i in shape:
|
||||||
validator.check_integer("shape element", i, 0, Rel.GT, self.name)
|
validator.check_positive_int(i, "shape element", self.name)
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
|
|
|
@ -160,7 +160,7 @@ class AllGather(PrimitiveWithInfer):
|
||||||
self.add_prim_attr('group', _get_group(group))
|
self.add_prim_attr('group', _get_group(group))
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name)
|
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
||||||
x_shape[0] = x_shape[0] * self.rank_size
|
x_shape[0] = x_shape[0] * self.rank_size
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
|
@ -210,7 +210,7 @@ class _HostAllGather(PrimitiveWithInfer):
|
||||||
self.add_prim_attr('group', group)
|
self.add_prim_attr('group', group)
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name)
|
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
||||||
x_shape[0] = x_shape[0] * self.group_size
|
x_shape[0] = x_shape[0] * self.group_size
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
|
|
|
@ -1005,8 +1005,8 @@ class Conv2D(PrimitiveWithInfer):
|
||||||
|
|
||||||
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
|
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
|
||||||
self.add_prim_attr('data_format', "NCHW")
|
self.add_prim_attr('data_format', "NCHW")
|
||||||
self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name)
|
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
|
||||||
self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)
|
self.group = validator.check_positive_int(group, 'group', self.name)
|
||||||
self.add_prim_attr('offset_a', 0)
|
self.add_prim_attr('offset_a', 0)
|
||||||
|
|
||||||
def infer_shape(self, x_shape, w_shape, b_shape=None):
|
def infer_shape(self, x_shape, w_shape, b_shape=None):
|
||||||
|
@ -1142,9 +1142,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
|
||||||
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
|
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
|
||||||
self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name)
|
self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name)
|
||||||
self.add_prim_attr('data_format', "NCHW")
|
self.add_prim_attr('data_format', "NCHW")
|
||||||
self.channel_multiplier = validator.check_integer("channel_multiplier", channel_multiplier, 0, Rel.GT,
|
self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name)
|
||||||
self.name)
|
self.group = validator.check_positive_int(group, "group", self.name)
|
||||||
self.group = validator.check_integer("group", group, 0, Rel.GT, self.name)
|
|
||||||
self.add_prim_attr('offset_a', 0)
|
self.add_prim_attr('offset_a', 0)
|
||||||
|
|
||||||
def infer_shape(self, x_shape, w_shape, b_shape=None):
|
def infer_shape(self, x_shape, w_shape, b_shape=None):
|
||||||
|
@ -1508,7 +1507,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
||||||
group=1):
|
group=1):
|
||||||
"""Initialize Conv2DBackpropInput"""
|
"""Initialize Conv2DBackpropInput"""
|
||||||
self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
|
self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
|
||||||
self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name)
|
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
|
||||||
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
|
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
|
||||||
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False)
|
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False)
|
||||||
self.add_prim_attr('stride', self.stride)
|
self.add_prim_attr('stride', self.stride)
|
||||||
|
@ -1531,7 +1530,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
||||||
pad_mode = pad_mode.upper()
|
pad_mode = pad_mode.upper()
|
||||||
self.add_prim_attr('pad_mode', pad_mode)
|
self.add_prim_attr('pad_mode', pad_mode)
|
||||||
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
|
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
|
||||||
self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)
|
self.group = validator.check_positive_int(group, 'group', self.name)
|
||||||
self.add_prim_attr('data_format', "NCHW")
|
self.add_prim_attr('data_format', "NCHW")
|
||||||
if pad_list:
|
if pad_list:
|
||||||
for x in pad_list:
|
for x in pad_list:
|
||||||
|
@ -2062,10 +2061,10 @@ class SGD(PrimitiveWithInfer):
|
||||||
|
|
||||||
def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
|
def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
|
||||||
accum_shape, momentum_shape, stat_shape):
|
accum_shape, momentum_shape, stat_shape):
|
||||||
validator.check_integer(f'parameters rank', len(parameters_shape), 0, Rel.GT, self.name)
|
validator.check_positive_int(len(parameters_shape), "parameters rank", self.name)
|
||||||
validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name)
|
validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name)
|
||||||
validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name)
|
validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name)
|
||||||
validator.check_integer(f'accumulation rank', len(accum_shape), 0, Rel.GT, self.name)
|
validator.check_positive_int(len(accum_shape), "accumulation rank", self.name)
|
||||||
validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name)
|
validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name)
|
||||||
validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name)
|
validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name)
|
||||||
validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)
|
validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)
|
||||||
|
@ -2748,9 +2747,9 @@ class LSTM(PrimitiveWithInfer):
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
|
||||||
self.input_size = validator.check_integer("input_size", input_size, 0, Rel.GT, self.name)
|
self.input_size = validator.check_positive_int(input_size, "input_size", self.name)
|
||||||
self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.name)
|
self.hidden_size = validator.check_positive_int(hidden_size, "hidden_size", self.name)
|
||||||
self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.name)
|
self.num_layers = validator.check_positive_int(num_layers, "num_layers", self.name)
|
||||||
self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name)
|
self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name)
|
||||||
self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name)
|
self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name)
|
||||||
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
|
||||||
|
|
|
@ -12,11 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""Operators for random."""
|
"""Operators for random."""
|
||||||
|
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator, Rel
|
||||||
from ..._checkparam import Rel
|
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||||
from .._utils import get_broadcast_shape
|
from .._utils import get_broadcast_shape
|
||||||
|
@ -46,16 +44,16 @@ class StandardNormal(PrimitiveWithInfer):
|
||||||
def __init__(self, seed=0, seed2=0):
|
def __init__(self, seed=0, seed2=0):
|
||||||
"""Initialize StandardNormal"""
|
"""Initialize StandardNormal"""
|
||||||
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
|
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
|
||||||
validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
||||||
validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
|
Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
|
||||||
|
|
||||||
def __infer__(self, shape):
|
def __infer__(self, shape):
|
||||||
shape_v = shape["value"]
|
shape_v = shape["value"]
|
||||||
if shape_v is None:
|
if shape_v is None:
|
||||||
raise ValueError(f"For {self.name}, shape must be const.")
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
Validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
for i, shape_i in enumerate(shape_v):
|
for i, shape_i in enumerate(shape_v):
|
||||||
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
|
||||||
out = {
|
out = {
|
||||||
'shape': shape_v,
|
'shape': shape_v,
|
||||||
'dtype': mstype.float32,
|
'dtype': mstype.float32,
|
||||||
|
@ -91,16 +89,16 @@ class StandardLaplace(PrimitiveWithInfer):
|
||||||
def __init__(self, seed=0, seed2=0):
|
def __init__(self, seed=0, seed2=0):
|
||||||
"""Initialize StandardLaplace"""
|
"""Initialize StandardLaplace"""
|
||||||
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
|
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
|
||||||
validator.check_value_type('seed', seed, [int], self.name)
|
Validator.check_value_type('seed', seed, [int], self.name)
|
||||||
validator.check_value_type('seed2', seed2, [int], self.name)
|
Validator.check_value_type('seed2', seed2, [int], self.name)
|
||||||
|
|
||||||
def __infer__(self, shape):
|
def __infer__(self, shape):
|
||||||
shape_v = shape["value"]
|
shape_v = shape["value"]
|
||||||
if shape_v is None:
|
if shape_v is None:
|
||||||
raise ValueError(f"For {self.name}, shape must be const.")
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
Validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
for i, shape_i in enumerate(shape_v):
|
for i, shape_i in enumerate(shape_v):
|
||||||
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
|
||||||
out = {
|
out = {
|
||||||
'shape': shape_v,
|
'shape': shape_v,
|
||||||
'dtype': mstype.float32,
|
'dtype': mstype.float32,
|
||||||
|
@ -143,18 +141,18 @@ class Gamma(PrimitiveWithInfer):
|
||||||
def __init__(self, seed=0, seed2=0):
|
def __init__(self, seed=0, seed2=0):
|
||||||
"""Initialize Gamma"""
|
"""Initialize Gamma"""
|
||||||
self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
|
self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
|
||||||
validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
||||||
validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
|
Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
|
||||||
|
|
||||||
def __infer__(self, shape, alpha, beta):
|
def __infer__(self, shape, alpha, beta):
|
||||||
shape_v = shape["value"]
|
shape_v = shape["value"]
|
||||||
if shape_v is None:
|
if shape_v is None:
|
||||||
raise ValueError(f"For {self.name}, shape must be const.")
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
Validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
for i, shape_i in enumerate(shape_v):
|
for i, shape_i in enumerate(shape_v):
|
||||||
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
|
||||||
validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
|
Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
|
||||||
validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
|
Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
|
||||||
broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
|
broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
|
||||||
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
|
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
|
||||||
out = {
|
out = {
|
||||||
|
@ -195,17 +193,17 @@ class Poisson(PrimitiveWithInfer):
|
||||||
def __init__(self, seed=0, seed2=0):
|
def __init__(self, seed=0, seed2=0):
|
||||||
"""Initialize Poisson"""
|
"""Initialize Poisson"""
|
||||||
self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
|
self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
|
||||||
validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
||||||
validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
|
Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
|
||||||
|
|
||||||
def __infer__(self, shape, mean):
|
def __infer__(self, shape, mean):
|
||||||
shape_v = shape["value"]
|
shape_v = shape["value"]
|
||||||
if shape_v is None:
|
if shape_v is None:
|
||||||
raise ValueError(f"For {self.name}, shape must be const.")
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
Validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
for i, shape_i in enumerate(shape_v):
|
for i, shape_i in enumerate(shape_v):
|
||||||
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
|
||||||
validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
|
Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
|
||||||
broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
|
broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
|
||||||
out = {
|
out = {
|
||||||
'shape': broadcast_shape,
|
'shape': broadcast_shape,
|
||||||
|
@ -251,22 +249,22 @@ class UniformInt(PrimitiveWithInfer):
|
||||||
def __init__(self, seed=0, seed2=0):
|
def __init__(self, seed=0, seed2=0):
|
||||||
"""Initialize UniformInt"""
|
"""Initialize UniformInt"""
|
||||||
self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
|
self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
|
||||||
validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
||||||
validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
|
Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
|
||||||
|
|
||||||
def __infer__(self, shape, minval, maxval):
|
def __infer__(self, shape, minval, maxval):
|
||||||
shape_v = shape["value"]
|
shape_v = shape["value"]
|
||||||
if shape_v is None:
|
if shape_v is None:
|
||||||
raise ValueError(f"For {self.name}, shape must be const.")
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
Validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
for i, shape_i in enumerate(shape_v):
|
for i, shape_i in enumerate(shape_v):
|
||||||
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
|
||||||
validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
|
Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name)
|
||||||
validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
|
Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name)
|
||||||
minval_shape = minval['shape']
|
minval_shape = minval['shape']
|
||||||
maxval_shape = maxval['shape']
|
maxval_shape = maxval['shape']
|
||||||
validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
|
Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
|
||||||
validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name)
|
Validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name)
|
||||||
out = {
|
out = {
|
||||||
'shape': shape_v,
|
'shape': shape_v,
|
||||||
'dtype': mstype.int32,
|
'dtype': mstype.int32,
|
||||||
|
@ -298,16 +296,16 @@ class UniformReal(PrimitiveWithInfer):
|
||||||
def __init__(self, seed=0, seed2=0):
|
def __init__(self, seed=0, seed2=0):
|
||||||
"""Initialize UniformReal"""
|
"""Initialize UniformReal"""
|
||||||
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
|
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
|
||||||
validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
||||||
validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
|
Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name)
|
||||||
|
|
||||||
def __infer__(self, shape):
|
def __infer__(self, shape):
|
||||||
shape_v = shape["value"]
|
shape_v = shape["value"]
|
||||||
if shape_v is None:
|
if shape_v is None:
|
||||||
raise ValueError(f"For {self.name}, shape must be const.")
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
Validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
for i, shape_i in enumerate(shape_v):
|
for i, shape_i in enumerate(shape_v):
|
||||||
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
|
||||||
out = {
|
out = {
|
||||||
'shape': shape_v,
|
'shape': shape_v,
|
||||||
'dtype': mstype.float32,
|
'dtype': mstype.float32,
|
||||||
|
@ -348,18 +346,18 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, count=256, seed=0, seed2=0):
|
def __init__(self, count=256, seed=0, seed2=0):
|
||||||
"""Initialize RandomChoiceWithMask"""
|
"""Initialize RandomChoiceWithMask"""
|
||||||
validator.check_value_type("count", count, [int], self.name)
|
Validator.check_value_type("count", count, [int], self.name)
|
||||||
validator.check_integer("count", count, 0, Rel.GT, self.name)
|
Validator.check_positive_int(count, "count", self.name)
|
||||||
validator.check_value_type('seed', seed, [int], self.name)
|
Validator.check_value_type('seed', seed, [int], self.name)
|
||||||
validator.check_value_type('seed2', seed2, [int], self.name)
|
Validator.check_value_type('seed2', seed2, [int], self.name)
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
|
Validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
|
||||||
validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
|
Validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
|
||||||
return ([self.count, len(x_shape)], [self.count])
|
return ([self.count, len(x_shape)], [self.count])
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
|
Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
|
||||||
return (mstype.int32, mstype.bool_)
|
return (mstype.int32, mstype.bool_)
|
||||||
|
|
||||||
|
|
||||||
|
@ -399,19 +397,19 @@ class RandomCategorical(PrimitiveWithInfer):
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
valid_values = (mstype.int32, mstype.int16, mstype.int64)
|
valid_values = (mstype.int32, mstype.int16, mstype.int64)
|
||||||
validator.check_type_name("dtype", dtype, valid_values, self.name)
|
Validator.check_type_name("dtype", dtype, valid_values, self.name)
|
||||||
self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
|
self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
|
||||||
outputs=['output'])
|
outputs=['output'])
|
||||||
|
|
||||||
def __infer__(self, logits, num_samples, seed):
|
def __infer__(self, logits, num_samples, seed):
|
||||||
logits_dtype = logits['dtype']
|
logits_dtype = logits['dtype']
|
||||||
valid_types = (mstype.float32, mstype.float16, mstype.float64)
|
valid_types = (mstype.float32, mstype.float16, mstype.float64)
|
||||||
validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
|
Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
|
||||||
num_samples_v = num_samples['value']
|
num_samples_v = num_samples['value']
|
||||||
seed_v = seed['value']
|
seed_v = seed['value']
|
||||||
validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
|
Validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
|
||||||
validator.check_value_type('seed', seed_v, (int,), self.name)
|
Validator.check_value_type('seed', seed_v, (int,), self.name)
|
||||||
validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name)
|
Validator.check_positive_int(num_samples_v, "num_samples", self.name)
|
||||||
x_shape = list(logits['shape'])
|
x_shape = list(logits['shape'])
|
||||||
if len(x_shape) != 2:
|
if len(x_shape) != 2:
|
||||||
raise ValueError("RandomCategorical shape should be 2-dimension.")
|
raise ValueError("RandomCategorical shape should be 2-dimension.")
|
||||||
|
@ -450,20 +448,20 @@ class Multinomial(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, seed=0):
|
def __init__(self, seed=0):
|
||||||
"""init"""
|
"""init"""
|
||||||
validator.check_value_type("seed", seed, [int], self.name)
|
Validator.check_value_type("seed", seed, [int], self.name)
|
||||||
validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
Validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
||||||
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
|
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
|
||||||
|
|
||||||
def __infer__(self, inputs, num_samples):
|
def __infer__(self, inputs, num_samples):
|
||||||
input_shape = inputs["shape"]
|
input_shape = inputs["shape"]
|
||||||
if len(input_shape) != 1 and len(input_shape) != 2:
|
if len(input_shape) != 1 and len(input_shape) != 2:
|
||||||
raise ValueError("input dim must be 1 or 2")
|
raise ValueError("input dim must be 1 or 2")
|
||||||
validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
|
Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name)
|
||||||
num_samples_value = num_samples["value"]
|
num_samples_value = num_samples["value"]
|
||||||
if num_samples_value is None:
|
if num_samples_value is None:
|
||||||
raise ValueError(f"For {self.name}, shape nust be const")
|
raise ValueError(f"For {self.name}, shape nust be const")
|
||||||
validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
|
Validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
|
||||||
validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None)
|
Validator.check_positive_int(num_samples_value, "num_samples")
|
||||||
y_shape = (num_samples_value,)
|
y_shape = (num_samples_value,)
|
||||||
if len(input_shape) == 2:
|
if len(input_shape) == 2:
|
||||||
y_shape = (input_shape[0], num_samples_value)
|
y_shape = (input_shape[0], num_samples_value)
|
||||||
|
|
|
@ -21,7 +21,7 @@ import time
|
||||||
import threading
|
import threading
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore._checkparam import Validator, check_int_non_negative
|
from mindspore._checkparam import Validator
|
||||||
from mindspore.train._utils import _make_directory
|
from mindspore.train._utils import _make_directory
|
||||||
from mindspore.train.serialization import save_checkpoint, _save_graph
|
from mindspore.train.serialization import save_checkpoint, _save_graph
|
||||||
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
|
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
|
||||||
|
@ -107,13 +107,13 @@ class CheckpointConfig:
|
||||||
async_save=False):
|
async_save=False):
|
||||||
|
|
||||||
if save_checkpoint_steps is not None:
|
if save_checkpoint_steps is not None:
|
||||||
save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
|
save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
|
||||||
if save_checkpoint_seconds is not None:
|
if save_checkpoint_seconds is not None:
|
||||||
save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
|
save_checkpoint_seconds = Validator.check_non_negative_int(save_checkpoint_seconds)
|
||||||
if keep_checkpoint_max is not None:
|
if keep_checkpoint_max is not None:
|
||||||
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
|
keep_checkpoint_max = Validator.check_non_negative_int(keep_checkpoint_max)
|
||||||
if keep_checkpoint_per_n_minutes is not None:
|
if keep_checkpoint_per_n_minutes is not None:
|
||||||
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
|
keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)
|
||||||
|
|
||||||
if not save_checkpoint_steps and not save_checkpoint_seconds and \
|
if not save_checkpoint_steps and not save_checkpoint_seconds and \
|
||||||
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
|
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
|
||||||
|
|
|
@ -13,8 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Loss scale manager abstract class."""
|
"""Loss scale manager abstract class."""
|
||||||
|
|
||||||
from .._checkparam import Validator as validator
|
from .._checkparam import Validator as validator
|
||||||
from .._checkparam import Rel
|
|
||||||
from .. import nn
|
from .. import nn
|
||||||
|
|
||||||
__all__ = ["LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager"]
|
__all__ = ["LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager"]
|
||||||
|
@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager):
|
||||||
if init_loss_scale < 1.0:
|
if init_loss_scale < 1.0:
|
||||||
raise ValueError("Loss scale value should be > 1")
|
raise ValueError("Loss scale value should be > 1")
|
||||||
self.loss_scale = init_loss_scale
|
self.loss_scale = init_loss_scale
|
||||||
validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__)
|
validator.check_positive_int(scale_window, "scale_window", self.__class__.__name__)
|
||||||
self.scale_window = scale_window
|
self.scale_window = scale_window
|
||||||
if scale_factor <= 0:
|
if scale_factor <= 0:
|
||||||
raise ValueError("Scale factor should be > 1")
|
raise ValueError("Scale factor should be > 1")
|
||||||
|
|
|
@ -22,7 +22,7 @@ import numpy as np
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from ..common.tensor import Tensor
|
from ..common.tensor import Tensor
|
||||||
from ..nn.metrics import get_metrics
|
from ..nn.metrics import get_metrics
|
||||||
from .._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int
|
from .._checkparam import check_input_data, check_output_data, Validator, check_int
|
||||||
from .callback import _InternalCallbackParam, RunContext, _CallbackManager
|
from .callback import _InternalCallbackParam, RunContext, _CallbackManager
|
||||||
from .. import context
|
from .. import context
|
||||||
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||||
|
@ -339,7 +339,7 @@ class Model:
|
||||||
dataset not sink.
|
dataset not sink.
|
||||||
sink_size (int): Control the amount of data in each sink. Default: -1.
|
sink_size (int): Control the amount of data in each sink. Default: -1.
|
||||||
"""
|
"""
|
||||||
epoch = check_int_positive(epoch)
|
epoch = Validator.check_positive_int(epoch)
|
||||||
if self._parameter_broadcast:
|
if self._parameter_broadcast:
|
||||||
self._train_network.set_broadcast_flag()
|
self._train_network.set_broadcast_flag()
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore._checkparam import Validator, twice, check_int_positive
|
from mindspore._checkparam import Validator, twice
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
|
@ -292,8 +292,8 @@ class Dense_Thor_GPU(Cell):
|
||||||
has_bias=True,
|
has_bias=True,
|
||||||
activation=None):
|
activation=None):
|
||||||
super(Dense_Thor_GPU, self).__init__()
|
super(Dense_Thor_GPU, self).__init__()
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = Validator.check_bool(has_bias)
|
self.has_bias = Validator.check_bool(has_bias)
|
||||||
self.thor = True
|
self.thor = True
|
||||||
if isinstance(weight_init, Tensor):
|
if isinstance(weight_init, Tensor):
|
||||||
|
@ -641,8 +641,8 @@ class Dense_Thor(Cell):
|
||||||
has_bias=True,
|
has_bias=True,
|
||||||
activation=None):
|
activation=None):
|
||||||
super(Dense_Thor, self).__init__()
|
super(Dense_Thor, self).__init__()
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = Validator.check_bool(has_bias)
|
self.has_bias = Validator.check_bool(has_bias)
|
||||||
self.thor = True
|
self.thor = True
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
|
@ -19,7 +19,7 @@ from mindspore.ops import functional as F
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from mindspore import Tensor, Parameter
|
from mindspore import Tensor, Parameter
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore._checkparam import check_int_positive, Validator
|
from mindspore._checkparam import Validator
|
||||||
from mindspore.nn.layer.activation import get_activation
|
from mindspore.nn.layer.activation import get_activation
|
||||||
|
|
||||||
|
|
||||||
|
@ -72,8 +72,8 @@ class GNNFeatureTransform(nn.Cell):
|
||||||
bias_init='zeros',
|
bias_init='zeros',
|
||||||
has_bias=True):
|
has_bias=True):
|
||||||
super(GNNFeatureTransform, self).__init__()
|
super(GNNFeatureTransform, self).__init__()
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = Validator.check_bool(has_bias)
|
self.has_bias = Validator.check_bool(has_bias)
|
||||||
|
|
||||||
if isinstance(weight_init, Tensor):
|
if isinstance(weight_init, Tensor):
|
||||||
|
@ -259,8 +259,8 @@ class AttentionHead(nn.Cell):
|
||||||
coef_activation=nn.LeakyReLU(),
|
coef_activation=nn.LeakyReLU(),
|
||||||
activation=nn.ELU()):
|
activation=nn.ELU()):
|
||||||
super(AttentionHead, self).__init__()
|
super(AttentionHead, self).__init__()
|
||||||
self.in_channel = check_int_positive(in_channel)
|
self.in_channel = Validator.check_positive_int(in_channel)
|
||||||
self.out_channel = check_int_positive(out_channel)
|
self.out_channel = Validator.check_positive_int(out_channel)
|
||||||
self.in_drop_ratio = in_drop_ratio
|
self.in_drop_ratio = in_drop_ratio
|
||||||
self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
|
self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
|
||||||
self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
|
self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
|
||||||
|
@ -450,9 +450,9 @@ class GAT(nn.Cell):
|
||||||
super(GAT, self).__init__()
|
super(GAT, self).__init__()
|
||||||
self.features = Tensor(features)
|
self.features = Tensor(features)
|
||||||
self.biases = Tensor(biases)
|
self.biases = Tensor(biases)
|
||||||
self.ftr_dims = check_int_positive(ftr_dims)
|
self.ftr_dims = Validator.check_positive_int(ftr_dims)
|
||||||
self.num_class = check_int_positive(num_class)
|
self.num_class = Validator.check_positive_int(num_class)
|
||||||
self.num_nodes = check_int_positive(num_nodes)
|
self.num_nodes = Validator.check_positive_int(num_nodes)
|
||||||
self.hidden_units = hidden_units
|
self.hidden_units = hidden_units
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.attn_drop = attn_drop
|
self.attn_drop = attn_drop
|
||||||
|
|
|
@ -22,7 +22,7 @@ from mindspore._c_expression import init_exec_dataset
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator, check_int
|
from mindspore._checkparam import check_input_data, check_output_data, Validator, check_int
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.common.dtype import pytype_to_dtype
|
from mindspore.common.dtype import pytype_to_dtype
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
@ -374,7 +374,7 @@ class Model:
|
||||||
dataset not sink.
|
dataset not sink.
|
||||||
sink_size (int): Control the amount of data each sink. Default: -1.
|
sink_size (int): Control the amount of data each sink. Default: -1.
|
||||||
"""
|
"""
|
||||||
epoch = check_int_positive(epoch)
|
epoch = Validator.check_positive_int(epoch)
|
||||||
self._train_network.set_train()
|
self._train_network.set_train()
|
||||||
|
|
||||||
if self._parameter_broadcast:
|
if self._parameter_broadcast:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"""thor_layer"""
|
"""thor_layer"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore._checkparam import Validator, check_int_positive
|
from mindspore._checkparam import Validator
|
||||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
@ -160,8 +160,8 @@ class Dense_Thor(Cell):
|
||||||
activation=None,
|
activation=None,
|
||||||
batch_size=12):
|
batch_size=12):
|
||||||
super(Dense_Thor, self).__init__()
|
super(Dense_Thor, self).__init__()
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = Validator.check_bool(has_bias)
|
self.has_bias = Validator.check_bool(has_bias)
|
||||||
self.thor = True
|
self.thor = True
|
||||||
if isinstance(weight_init, Tensor):
|
if isinstance(weight_init, Tensor):
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"""Aggregator."""
|
"""Aggregator."""
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Tensor, Parameter
|
from mindspore import Tensor, Parameter
|
||||||
from mindspore._checkparam import check_int_positive, Validator
|
from mindspore._checkparam import Validator
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.nn.layer.activation import get_activation
|
from mindspore.nn.layer.activation import get_activation
|
||||||
|
@ -73,8 +73,8 @@ class GNNFeatureTransform(nn.Cell):
|
||||||
bias_init='zeros',
|
bias_init='zeros',
|
||||||
has_bias=True):
|
has_bias=True):
|
||||||
super(GNNFeatureTransform, self).__init__()
|
super(GNNFeatureTransform, self).__init__()
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = Validator.check_bool(has_bias)
|
self.has_bias = Validator.check_bool(has_bias)
|
||||||
|
|
||||||
if isinstance(weight_init, Tensor):
|
if isinstance(weight_init, Tensor):
|
||||||
|
@ -262,8 +262,8 @@ class AttentionHead(nn.Cell):
|
||||||
coef_activation=nn.LeakyReLU(),
|
coef_activation=nn.LeakyReLU(),
|
||||||
activation=nn.ELU()):
|
activation=nn.ELU()):
|
||||||
super(AttentionHead, self).__init__()
|
super(AttentionHead, self).__init__()
|
||||||
self.in_channel = check_int_positive(in_channel)
|
self.in_channel = Validator.check_positive_int(in_channel)
|
||||||
self.out_channel = check_int_positive(out_channel)
|
self.out_channel = Validator.check_positive_int(out_channel)
|
||||||
self.in_drop_ratio = in_drop_ratio
|
self.in_drop_ratio = in_drop_ratio
|
||||||
self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
|
self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
|
||||||
self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
|
self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Graph Attention Networks."""
|
"""Graph Attention Networks."""
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore._checkparam import Validator, check_int_positive
|
from mindspore._checkparam import Validator
|
||||||
|
|
||||||
from aggregator import AttentionAggregator
|
from aggregator import AttentionAggregator
|
||||||
|
|
||||||
|
@ -71,9 +71,9 @@ class GAT(nn.Cell):
|
||||||
activation=nn.ELU(),
|
activation=nn.ELU(),
|
||||||
residual=False):
|
residual=False):
|
||||||
super(GAT, self).__init__()
|
super(GAT, self).__init__()
|
||||||
self.ftr_dims = check_int_positive(ftr_dims)
|
self.ftr_dims = Validator.check_positive_int(ftr_dims)
|
||||||
self.num_class = check_int_positive(num_class)
|
self.num_class = Validator.check_positive_int(num_class)
|
||||||
self.num_nodes = check_int_positive(num_nodes)
|
self.num_nodes = Validator.check_positive_int(num_nodes)
|
||||||
self.hidden_units = hidden_units
|
self.hidden_units = hidden_units
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.attn_drop = attn_drop
|
self.attn_drop = attn_drop
|
||||||
|
|
|
@ -19,7 +19,7 @@ from mindspore import context
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore._c_expression import init_exec_dataset
|
from mindspore._c_expression import init_exec_dataset
|
||||||
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, Validator
|
from mindspore._checkparam import check_input_data, check_output_data, Validator
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.common.dtype import pytype_to_dtype
|
from mindspore.common.dtype import pytype_to_dtype
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
@ -377,7 +377,7 @@ class Model:
|
||||||
Configure pynative mode, the training process will be performed with
|
Configure pynative mode, the training process will be performed with
|
||||||
dataset not sink.
|
dataset not sink.
|
||||||
"""
|
"""
|
||||||
epoch = check_int_positive(epoch)
|
epoch = Validator.check_positive_int(epoch)
|
||||||
self._train_network.set_train()
|
self._train_network.set_train()
|
||||||
|
|
||||||
if self._parameter_broadcast:
|
if self._parameter_broadcast:
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore._checkparam import Validator, twice, check_int_positive
|
from mindspore._checkparam import Validator, twice
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
|
@ -337,8 +337,8 @@ class Dense_Thor(Cell):
|
||||||
has_bias=True,
|
has_bias=True,
|
||||||
activation=None):
|
activation=None):
|
||||||
super(Dense_Thor, self).__init__()
|
super(Dense_Thor, self).__init__()
|
||||||
self.in_channels = check_int_positive(in_channels)
|
self.in_channels = Validator.check_positive_int(in_channels)
|
||||||
self.out_channels = check_int_positive(out_channels)
|
self.out_channels = Validator.check_positive_int(out_channels)
|
||||||
self.has_bias = Validator.check_bool(has_bias)
|
self.has_bias = Validator.check_bool(has_bias)
|
||||||
self.thor = True
|
self.thor = True
|
||||||
if isinstance(weight_init, Tensor):
|
if isinstance(weight_init, Tensor):
|
||||||
|
|
|
@ -15,8 +15,7 @@
|
||||||
""" test checkparameter """
|
""" test checkparameter """
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mindspore._checkparam import check_int, check_int_positive, \
|
from mindspore._checkparam import check_int, check_input_format, Validator, twice
|
||||||
check_input_format, Validator, twice
|
|
||||||
|
|
||||||
kernel_size = 5
|
kernel_size = 5
|
||||||
kernel_size1 = twice(kernel_size)
|
kernel_size1 = twice(kernel_size)
|
||||||
|
@ -29,7 +28,7 @@ def test_check_int_1():
|
||||||
|
|
||||||
def check_int_positive_1():
|
def check_int_positive_1():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
check_int_positive(-1)
|
Validator.check_positive_int(-1)
|
||||||
|
|
||||||
|
|
||||||
def test_NCHW1():
|
def test_NCHW1():
|
||||||
|
|
|
@ -15,8 +15,7 @@
|
||||||
""" test_checkparameter """
|
""" test_checkparameter """
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mindspore._checkparam import check_int, check_int_positive, \
|
from mindspore._checkparam import check_int, Validator, check_input_format, _expand_tuple
|
||||||
Validator, check_input_format, _expand_tuple
|
|
||||||
|
|
||||||
once = _expand_tuple(1)
|
once = _expand_tuple(1)
|
||||||
twice = _expand_tuple(2)
|
twice = _expand_tuple(2)
|
||||||
|
@ -32,7 +31,7 @@ def test_check_int_1():
|
||||||
|
|
||||||
def check_int_positive_1():
|
def check_int_positive_1():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
check_int_positive(-1)
|
Validator.check_positive_int(-1)
|
||||||
|
|
||||||
|
|
||||||
def test_NCHW1():
|
def test_NCHW1():
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
"""VM implementations based on numpy."""
|
"""VM implementations based on numpy."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore._checkparam import Rel
|
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,7 +31,7 @@ def avg_pooling(x, pool_h, pool_w, stride):
|
||||||
Returns:
|
Returns:
|
||||||
numpy.ndarray, an output array after applying average pooling on input array.
|
numpy.ndarray, an output array after applying average pooling on input array.
|
||||||
"""
|
"""
|
||||||
validator.check_integer("stride", stride, 0, Rel.GT, None)
|
validator.check_positive_int(stride, "stride")
|
||||||
num, channel, height, width = x.shape
|
num, channel, height, width = x.shape
|
||||||
out_h = (height - pool_h) // stride + 1
|
out_h = (height - pool_h) // stride + 1
|
||||||
out_w = (width - pool_w) // stride + 1
|
out_w = (width - pool_w) // stride + 1
|
||||||
|
@ -423,7 +421,7 @@ def matmul(x, w, b=None):
|
||||||
|
|
||||||
def max_pooling(x, pool_h, pool_w, stride):
|
def max_pooling(x, pool_h, pool_w, stride):
|
||||||
"""Max pooling."""
|
"""Max pooling."""
|
||||||
validator.check_integer("stride", stride, 0, Rel.GT, None)
|
validator.check_positive_int(stride, "stride")
|
||||||
num, channel, height, width = x.shape
|
num, channel, height, width = x.shape
|
||||||
out_h = (height - pool_h) // stride + 1
|
out_h = (height - pool_h) // stride + 1
|
||||||
out_w = (width - pool_w) // stride + 1
|
out_w = (width - pool_w) // stride + 1
|
||||||
|
@ -466,7 +464,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
|
||||||
|
|
||||||
def max_pool_with_argmax(x, pool_h, pool_w, stride):
|
def max_pool_with_argmax(x, pool_h, pool_w, stride):
|
||||||
"""Max pooling with argmax."""
|
"""Max pooling with argmax."""
|
||||||
validator.check_integer("stride", stride, 0, Rel.GT, None)
|
validator.check_positive_int(stride, "stride")
|
||||||
num, channel, height, width = x.shape
|
num, channel, height, width = x.shape
|
||||||
out_h = (height - pool_h) // stride + 1
|
out_h = (height - pool_h) // stride + 1
|
||||||
out_w = (width - pool_w) // stride + 1
|
out_w = (width - pool_w) // stride + 1
|
||||||
|
|
Loading…
Reference in New Issue