[ME]delete ParamValidator and change all to Validator

This commit is contained in:
chenzomi 2020-10-09 10:52:50 +08:00
parent d5ae6fdd84
commit 6c9b6d491d
5 changed files with 78 additions and 104 deletions

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Check parameters."""
import re
import inspect
import math
@ -20,10 +21,9 @@ from enum import Enum
from functools import reduce, wraps
from itertools import repeat
from collections.abc import Iterable
import numpy as np
from mindspore import log as logger
from .common import dtype as mstype
from mindspore.common import dtype as mstype
# Named string regular expression
@ -103,18 +103,17 @@ class Validator:
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
"""
Method for judging relation between two int values or list/tuple made up of ints.
This method is not suitable for judging relation between floats, since it does not consider float error.
"""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name):
def check_integer(arg_name, arg_value, value, rel, prim_name=None):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
@ -135,6 +134,20 @@ class Validator:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isinstance of classes"""
if not isinstance(arg_value, classes):
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value
@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isinstance of bool"""
if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
return arg_value
@staticmethod
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
"""Method for checking whether an int value is in some range."""
@ -208,6 +221,27 @@ class Validator:
"""Checks valid value."""
if arg_value is None:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
return arg_value
@staticmethod
def check_type(arg_name, arg_value, valid_types):
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
@staticmethod
def check_type_same(args, valid_values, prim_name):
@ -239,7 +273,6 @@ class Validator:
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
"""
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
"""
@ -335,63 +368,6 @@ class Validator:
f'{tuple(exp_shape)}, but got {shape}.')
class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
"""This method is only used for check int values, since when compare float values,
we need consider float error."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
@staticmethod
def check_integer(arg_name, arg_value, value, rel):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isinstance of classes"""
if not isinstance(arg_value, classes):
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value
@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isinstance of bool"""
if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
return arg_value
@staticmethod
def check_type(arg_name, arg_value, valid_types):
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
def check_int(input_param):
"""Int type judgment."""
if isinstance(input_param, int) and not isinstance(input_param, bool):
@ -638,7 +614,6 @@ def args_type_check(*type_args, **type_kwargs):
if value is not None and not isinstance(value, bound_types[name]):
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
return func(*args, **kwargs)
return wrapper
return type_check

View File

@ -21,7 +21,7 @@ from ...ops import operations as P
from ...ops.primitive import PrimitiveWithInfer, prim_attr_register
from ...ops.composite import multitype_ops as C
from ...ops.operations import _grad_ops as G
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Validator
from ..cell import Cell, GraphKernel
@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel):
use_locking=False,
gradient_scale=1.0):
super(ApplyMomentum, self).__init__()
self.gradient_scale = validator.check_type('gradient_scale', gradient_scale, [float])
self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float])
self.fake_output_assign_1 = InplaceAssign()
self.fake_output_assign_1.add_prim_attr("fake_output", True)
self.fake_output_assign_2 = InplaceAssign()
@ -334,7 +334,7 @@ class ReduceMean(GraphKernel):
def __init__(self, keep_dims=True):
super(ReduceMean, self).__init__()
self.keep_dims = validator.check_type('keep_dims', keep_dims, [bool])
self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool])
self.sum = P.ReduceSum(self.keep_dims)
def construct(self, x, axis):
@ -431,8 +431,8 @@ class LayerNormForward(GraphKernel):
""" Forward function of the LayerNorm operator. """
def __init__(self, begin_norm_axis=1, begin_params_axis=1):
super(LayerNormForward, self).__init__()
self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int])
self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int])
self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int])
self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int])
self.mul = P.Mul()
self.sum_keep_dims = P.ReduceSum(keep_dims=True)
self.sub = P.Sub()
@ -686,7 +686,7 @@ class LogSoftmax(GraphKernel):
def __init__(self, axis=-1):
super(LogSoftmax, self).__init__()
self.axis = validator.check_type('axis', axis, [int])
self.axis = Validator.check_type('axis', axis, [int])
self.max_keep_dims = P.ReduceMax(keep_dims=True)
self.sub = P.Sub()
self.exp = P.Exp()
@ -952,13 +952,13 @@ class Softmax(GraphKernel):
def __init__(self, axis):
super(Softmax, self).__init__()
validator.check_type("axis", axis, [int, tuple])
Validator.check_type("axis", axis, [int, tuple])
if isinstance(axis, int):
self.axis = (axis,)
else:
self.axis = axis
for item in self.axis:
validator.check_type("item of axis", item, [int])
Validator.check_type("item of axis", item, [int])
self.max = P.ReduceMax(keep_dims=True)
self.sub = P.Sub()
self.exp = P.Exp()

View File

@ -21,8 +21,7 @@ from mindspore.ops.primitive import constexpr
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import ParamValidator as validator, Rel
from mindspore._checkparam import check_bool, twice, check_int_positive, Validator
from mindspore._checkparam import Validator, Rel, check_bool, twice, check_int_positive
from mindspore._extends import cell_attr_register
from ..cell import Cell
@ -240,8 +239,8 @@ class Conv2d(_Conv):
"""Initialize depthwise conv2d op"""
if context.get_context("device_target") == "Ascend" and self.group > 1:
self.dilation = self._dilation
validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
Validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
Validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=self.pad_mode,

View File

@ -23,7 +23,7 @@ from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, ParamValidator as validator
from mindspore._checkparam import Rel, check_int_positive, check_bool, twice, Validator
import mindspore.context as context
from .normalization import BatchNorm2d, BatchNorm1d
from .activation import get_activation, ReLU, LeakyReLU
@ -133,7 +133,7 @@ class Conv2dBnAct(Cell):
has_bias=has_bias,
weight_init=weight_init,
bias_init=bias_init)
self.has_bn = validator.check_bool("has_bn", has_bn)
self.has_bn = Validator.check_bool("has_bn", has_bn)
self.has_act = activation is not None
self.after_fake = after_fake
if has_bn:
@ -201,7 +201,7 @@ class DenseBnAct(Cell):
weight_init,
bias_init,
has_bias)
self.has_bn = validator.check_bool("has_bn", has_bn)
self.has_bn = Validator.check_bool("has_bn", has_bn)
self.has_act = activation is not None
self.after_fake = after_fake
if has_bn:
@ -320,10 +320,10 @@ class FakeQuantWithMinMax(Cell):
quant_delay=0):
"""Initialize FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
validator.check_type("min_init", min_init, [int, float])
validator.check_type("max_init", max_init, [int, float])
validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
self.min_init = min_init
self.max_init = max_init
self.num_bits = num_bits
@ -489,8 +489,8 @@ class Conv2dBnFoldQuant(Cell):
# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
validator.check_integer('group', group, in_channels, Rel.EQ)
validator.check_integer('group', group, out_channels, Rel.EQ)
Validator.check_integer('group', group, in_channels, Rel.EQ)
Validator.check_integer('group', group, out_channels, Rel.EQ)
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
@ -674,8 +674,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
self.bias = None
# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
validator.check_integer('group', group, in_channels, Rel.EQ)
validator.check_integer('group', group, out_channels, Rel.EQ)
Validator.check_integer('group', group, in_channels, Rel.EQ)
Validator.check_integer('group', group, out_channels, Rel.EQ)
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,

View File

@ -22,7 +22,7 @@ import mindspore.context as context
from ... import log as logger
from ... import nn, ops
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Validator
from ..._checkparam import Rel
from ...common import Tensor
from ...common import dtype as mstype
@ -89,19 +89,19 @@ class ConvertToQuantNetwork:
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
def __init__(self, **kwargs):
self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"])
self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0])
self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1])
self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0])
self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1])
self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0])
self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1])
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
self.weight_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
self.act_qdelay = Validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
self.bn_fold = Validator.check_bool("bn fold", kwargs["bn_fold"])
self.freeze_bn = Validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
self.weight_bits = Validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
self.act_bits = Validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
self.weight_channel = Validator.check_bool("per channel", kwargs["per_channel"][0])
self.act_channel = Validator.check_bool("per channel", kwargs["per_channel"][-1])
self.weight_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][0])
self.act_symmetric = Validator.check_bool("symmetric", kwargs["symmetric"][-1])
self.weight_range = Validator.check_bool("narrow range", kwargs["narrow_range"][0])
self.act_range = Validator.check_bool("narrow range", kwargs["narrow_range"][-1])
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
quant.DenseBnAct: self._convert_dense}
@ -316,7 +316,7 @@ class ExportToQuantInferNetwork:
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
network = validator.check_isinstance('network', network, (nn.Cell,))
network = Validator.check_isinstance('network', network, (nn.Cell,))
self.input_scale = 1 / std_dev
self.input_zero_point = round(mean)
self.data_type = mstype.int8
@ -510,8 +510,8 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
supported_device = ["Ascend", "GPU"]
supported_formats = ['AIR', 'MINDIR']
mean = validator.check_type("mean", mean, (int, float))
std_dev = validator.check_type("std_dev", std_dev, (int, float))
mean = Validator.check_type("mean", mean, (int, float))
std_dev = Validator.check_type("std_dev", std_dev, (int, float))
if context.get_context('device_target') not in supported_device:
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))