forked from mindspore-Ecosystem/mindspore
[ME]delete ParamValidator and change all to Validator
This commit is contained in:
parent
d5ae6fdd84
commit
6c9b6d491d
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Check parameters."""
|
||||
|
||||
import re
|
||||
import inspect
|
||||
import math
|
||||
|
@ -20,10 +21,9 @@ from enum import Enum
|
|||
from functools import reduce, wraps
|
||||
from itertools import repeat
|
||||
from collections.abc import Iterable
|
||||
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
from .common import dtype as mstype
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
# Named string regular expression
|
||||
|
@ -103,18 +103,17 @@ class Validator:
|
|||
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
|
||||
"""
|
||||
Method for judging relation between two int values or list/tuple made up of ints.
|
||||
|
||||
This method is not suitable for judging relation between floats, since it does not consider float error.
|
||||
"""
|
||||
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
if not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
|
||||
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
|
||||
raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_integer(arg_name, arg_value, value, rel, prim_name):
|
||||
def check_integer(arg_name, arg_value, value, rel, prim_name=None):
|
||||
"""Integer value judgment."""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
||||
|
@ -135,6 +134,20 @@ class Validator:
|
|||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_isinstance(arg_name, arg_value, classes):
|
||||
"""Check arg isinstance of classes"""
|
||||
if not isinstance(arg_value, classes):
|
||||
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_bool(arg_name, arg_value):
|
||||
"""Check arg isinstance of bool"""
|
||||
if not isinstance(arg_value, bool):
|
||||
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
|
||||
"""Method for checking whether an int value is in some range."""
|
||||
|
@ -208,6 +221,27 @@ class Validator:
|
|||
"""Checks valid value."""
|
||||
if arg_value is None:
|
||||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_type(arg_name, arg_value, valid_types):
|
||||
"""Type checking."""
|
||||
def raise_error_msg():
|
||||
"""func for raising error message when check failed"""
|
||||
type_names = [t.__name__ for t in valid_types]
|
||||
num_types = len(valid_types)
|
||||
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
||||
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
||||
|
||||
if isinstance(arg_value, type(mstype.tensor)):
|
||||
arg_value = arg_value.element_type()
|
||||
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
|
||||
# `check_type('x', True, [bool, int])` will check pass
|
||||
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
|
||||
raise_error_msg()
|
||||
if isinstance(arg_value, tuple(valid_types)):
|
||||
return arg_value
|
||||
raise_error_msg()
|
||||
|
||||
@staticmethod
|
||||
def check_type_same(args, valid_values, prim_name):
|
||||
|
@ -239,7 +273,6 @@ class Validator:
|
|||
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
|
||||
"""
|
||||
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
|
||||
|
||||
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
|
||||
"""
|
||||
|
||||
|
@ -335,63 +368,6 @@ class Validator:
|
|||
f'{tuple(exp_shape)}, but got {shape}.')
|
||||
|
||||
|
||||
class ParamValidator:
|
||||
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
|
||||
|
||||
@staticmethod
|
||||
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
|
||||
"""This method is only used for check int values, since when compare float values,
|
||||
we need consider float error."""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
if not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
|
||||
raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
||||
|
||||
@staticmethod
|
||||
def check_integer(arg_name, arg_value, value, rel):
|
||||
"""Integer value judgment."""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
||||
if type_mismatch or not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(value)
|
||||
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_isinstance(arg_name, arg_value, classes):
|
||||
"""Check arg isinstance of classes"""
|
||||
if not isinstance(arg_value, classes):
|
||||
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_bool(arg_name, arg_value):
|
||||
"""Check arg isinstance of bool"""
|
||||
if not isinstance(arg_value, bool):
|
||||
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_type(arg_name, arg_value, valid_types):
|
||||
"""Type checking."""
|
||||
def raise_error_msg():
|
||||
"""func for raising error message when check failed"""
|
||||
type_names = [t.__name__ for t in valid_types]
|
||||
num_types = len(valid_types)
|
||||
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
|
||||
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
||||
|
||||
if isinstance(arg_value, type(mstype.tensor)):
|
||||
arg_value = arg_value.element_type()
|
||||
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
|
||||
# `check_type('x', True, [bool, int])` will check pass
|
||||
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
|
||||
raise_error_msg()
|
||||
if isinstance(arg_value, tuple(valid_types)):
|
||||
return arg_value
|
||||
raise_error_msg()
|
||||
|
||||
|
||||
def check_int(input_param):
|
||||
"""Int type judgment."""
|
||||
if isinstance(input_param, int) and not isinstance(input_param, bool):
|
||||
|
@ -638,7 +614,6 @@ def args_type_check(*type_args, **type_kwargs):
|
|||
if value is not None and not isinstance(value, bound_types[name]):
|
||||
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return type_check
|
||||
|
|
|
@ -21,7 +21,7 @@ from ...ops import operations as P
|
|||
from ...ops.primitive import PrimitiveWithInfer, prim_attr_register
|
||||
from ...ops.composite import multitype_ops as C
|
||||
from ...ops.operations import _grad_ops as G
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
from ..._checkparam import Validator
|
||||
from ..cell import Cell, GraphKernel
|
||||
|
||||
|
||||
|
@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel):
|
|||
use_locking=False,
|
||||
gradient_scale=1.0):
|
||||
super(ApplyMomentum, self).__init__()
|
||||
self.gradient_scale = validator.check_type('gradient_scale', gradient_scale, [float])
|
||||
self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float])
|
||||
self.fake_output_assign_1 = InplaceAssign()
|
||||
self.fake_output_assign_1.add_prim_attr("fake_output", True)
|
||||
self.fake_output_assign_2 = InplaceAssign()
|
||||
|
@ -334,7 +334,7 @@ class ReduceMean(GraphKernel):
|
|||
|
||||
def __init__(self, keep_dims=True):
|
||||
super(ReduceMean, self).__init__()
|
||||
self.keep_dims = validator.check_type('keep_dims', keep_dims, [bool])
|
||||
self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool])
|
||||
self.sum = P.ReduceSum(self.keep_dims)
|
||||
|
||||
def construct(self, x, axis):
|
||||
|
@ -431,8 +431,8 @@ class LayerNormForward(GraphKernel):
|
|||
""" Forward function of the LayerNorm operator. """
|
||||
def __init__(self, begin_norm_axis=1, begin_params_axis=1):
|
||||
super(LayerNormForward, self).__init__()
|
||||
self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int])
|
||||
self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int])
|
||||
self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int])
|
||||
self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int])
|
||||
self.mul = P.Mul()
|
||||
self.sum_keep_dims = P.ReduceSum(keep_dims=True)
|
||||
self.sub = P.Sub()
|
||||
|
@ -686,7 +686,7 @@ class LogSoftmax(GraphKernel):
|
|||
|
||||
def __init__(self, axis=-1):
|
||||
super(LogSoftmax, self).__init__()
|
||||
self.axis = validator.check_type('axis', axis, [int])
|
||||
self.axis = Validator.check_type('axis', axis, [int])
|
||||
self.max_keep_dims = P.ReduceMax(keep_dims=True)
|
||||
self.sub = P.Sub()
|
||||
self.exp = P.Exp()
|
||||
|
@ -952,13 +952,13 @@ class Softmax(GraphKernel):
|
|||
|
||||
def __init__(self, axis):
|
||||
super(Softmax, self).__init__()
|
||||
validator.check_type("axis", axis, [int, tuple])
|
||||
Validator.check_type("axis", axis, [int, tuple])
|
||||
if isinstance(axis, int):
|
||||
self.axis = (axis,)
|
||||
else:
|
||||
self.axis = axis
|
||||
for item in self.axis:
|
||||
validator.check_type("item of axis", item, [int])
|
||||
Validator.check_type("item of axis", item, [int])
|
||||
self.max = P.ReduceMax(keep_dims=True)
|
||||
self.sub = P.Sub()
|
||||
self.exp = P.Exp()
|
||||
|
|
|
@ -21,8 +21,7 @@ from mindspore.ops.primitive import constexpr
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer, Initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import ParamValidator as validator, Rel
|
||||
from mindspore._checkparam import check_bool, twice, check_int_positive, Validator
|
||||
from mindspore._checkparam import Validator, Rel, check_bool, twice, check_int_positive
|
||||
from mindspore._extends import cell_attr_register
|
||||
from ..cell import Cell
|
||||
|
||||
|
@ -240,8 +239,8 @@ class Conv2d(_Conv):
|
|||
"""Initialize depthwise conv2d op"""
|
||||
if context.get_context("device_target") == "Ascend" and self.group > 1:
|
||||
self.dilation = self._dilation
|
||||
validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
|
||||
validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
|
||||
Validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
|
||||
Validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
|
||||
self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
|
||||
kernel_size=self.kernel_size,
|
||||
pad_mode=self.pad_mode,
|
||||
|
|
|
@ -23,7 +23,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, ParamValidator as validator
|
||||
from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, Validator
|
||||
import mindspore.context as context
|
||||
from .normalization import BatchNorm2d, BatchNorm1d
|
||||
from .activation import get_activation, ReLU, LeakyReLU
|
||||
|
@ -133,7 +133,7 @@ class Conv2dBnAct(Cell):
|
|||
has_bias=has_bias,
|
||||
weight_init=weight_init,
|
||||
bias_init=bias_init)
|
||||
self.has_bn = validator.check_bool("has_bn", has_bn)
|
||||
self.has_bn = Validator.check_bool("has_bn", has_bn)
|
||||
self.has_act = activation is not None
|
||||
self.after_fake = after_fake
|
||||
if has_bn:
|
||||
|
@ -201,7 +201,7 @@ class DenseBnAct(Cell):
|
|||
weight_init,
|
||||
bias_init,
|
||||
has_bias)
|
||||
self.has_bn = validator.check_bool("has_bn", has_bn)
|
||||
self.has_bn = Validator.check_bool("has_bn", has_bn)
|
||||
self.has_act = activation is not None
|
||||
self.after_fake = after_fake
|
||||
if has_bn:
|
||||
|
@ -320,10 +320,10 @@ class FakeQuantWithMinMax(Cell):
|
|||
quant_delay=0):
|
||||
"""Initialize FakeQuantWithMinMax layer"""
|
||||
super(FakeQuantWithMinMax, self).__init__()
|
||||
validator.check_type("min_init", min_init, [int, float])
|
||||
validator.check_type("max_init", max_init, [int, float])
|
||||
validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
|
||||
validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
|
||||
Validator.check_type("min_init", min_init, [int, float])
|
||||
Validator.check_type("max_init", max_init, [int, float])
|
||||
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
|
||||
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
|
||||
self.min_init = min_init
|
||||
self.max_init = max_init
|
||||
self.num_bits = num_bits
|
||||
|
@ -489,8 +489,8 @@ class Conv2dBnFoldQuant(Cell):
|
|||
|
||||
# initialize convolution op and Parameter
|
||||
if context.get_context('device_target') == "Ascend" and group > 1:
|
||||
validator.check_integer('group', group, in_channels, Rel.EQ)
|
||||
validator.check_integer('group', group, out_channels, Rel.EQ)
|
||||
Validator.check_integer('group', group, in_channels, Rel.EQ)
|
||||
Validator.check_integer('group', group, out_channels, Rel.EQ)
|
||||
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
|
||||
kernel_size=self.kernel_size,
|
||||
pad_mode=pad_mode,
|
||||
|
@ -674,8 +674,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
|||
self.bias = None
|
||||
# initialize convolution op and Parameter
|
||||
if context.get_context('device_target') == "Ascend" and group > 1:
|
||||
validator.check_integer('group', group, in_channels, Rel.EQ)
|
||||
validator.check_integer('group', group, out_channels, Rel.EQ)
|
||||
Validator.check_integer('group', group, in_channels, Rel.EQ)
|
||||
Validator.check_integer('group', group, out_channels, Rel.EQ)
|
||||
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
|
||||
kernel_size=self.kernel_size,
|
||||
pad_mode=pad_mode,
|
||||
|
|
|
@ -22,7 +22,7 @@ import mindspore.context as context
|
|||
|
||||
from ... import log as logger
|
||||
from ... import nn, ops
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
from ..._checkparam import Validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import Tensor
|
||||
from ...common import dtype as mstype
|
||||
|
@ -89,19 +89,19 @@ class ConvertToQuantNetwork:
|
|||
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
|
||||
self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
|
||||
self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
|
||||
self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"])
|
||||
self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
|
||||
self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
|
||||
self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
|
||||
self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0])
|
||||
self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1])
|
||||
self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0])
|
||||
self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1])
|
||||
self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0])
|
||||
self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1])
|
||||
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
|
||||
self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
|
||||
self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
|
||||
self.bn_fold = Validator.check_bool("bn fold", kwargs["bn_fold"])
|
||||
self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
|
||||
self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
|
||||
self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
|
||||
self.weight_channel = Validator.check_bool("per channel", kwargs["per_channel"][0])
|
||||
self.act_channel = Validator.check_bool("per channel", kwargs["per_channel"][-1])
|
||||
self.weight_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][0])
|
||||
self.act_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][-1])
|
||||
self.weight_range = Validator.check_bool("narrow range", kwargs["narrow_range"][0])
|
||||
self.act_range = Validator.check_bool("narrow range", kwargs["narrow_range"][-1])
|
||||
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
|
||||
quant.DenseBnAct: self._convert_dense}
|
||||
|
||||
|
@ -316,7 +316,7 @@ class ExportToQuantInferNetwork:
|
|||
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
||||
|
||||
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
|
||||
network = validator.check_isinstance('network', network, (nn.Cell,))
|
||||
network = Validator.check_isinstance('network', network, (nn.Cell,))
|
||||
self.input_scale = 1 / std_dev
|
||||
self.input_zero_point = round(mean)
|
||||
self.data_type = mstype.int8
|
||||
|
@ -510,8 +510,8 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
|
|||
supported_device = ["Ascend", "GPU"]
|
||||
supported_formats = ['AIR', 'MINDIR']
|
||||
|
||||
mean = validator.check_type("mean", mean, (int, float))
|
||||
std_dev = validator.check_type("std_dev", std_dev, (int, float))
|
||||
mean = Validator.check_type("mean", mean, (int, float))
|
||||
std_dev = Validator.check_type("std_dev", std_dev, (int, float))
|
||||
|
||||
if context.get_context('device_target') not in supported_device:
|
||||
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
|
||||
|
|
Loading…
Reference in New Issue