forked from mindspore-Ecosystem/mindspore
Add prim name to error message for other operators left
This commit is contained in:
parent 72f42fc37c
commit 6770c66ed9
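The changes below replace ad-hoc checks (check_bool, check_int, ParamValidator) with the Validator helpers and pass self.name into every check, so raised errors are prefixed with the primitive's name. A minimal, self-contained sketch of that error-message pattern follows; SimpleValidator and FakeOp are illustrative stand-ins, not MindSpore's actual _checkparam API.

# Sketch only: mimics the "For 'OpName' ..." error-message pattern this commit applies.
class SimpleValidator:
    @staticmethod
    def check_value_type(arg_name, arg_value, valid_types, prim_name):
        """Return arg_value unchanged if its type is valid, else raise with the op name."""
        if not isinstance(arg_value, valid_types):
            raise TypeError(f"For '{prim_name}', the type of '{arg_name}' should be "
                            f"{valid_types}, but got {type(arg_value).__name__}.")
        return arg_value

class FakeOp:
    """Toy primitive showing how the op name ends up in the error text."""
    name = "FakeOp"

    def __init__(self, ema=False):
        self.ema = SimpleValidator.check_value_type('ema', ema, (bool,), self.name)

FakeOp(ema=True)     # passes
# FakeOp(ema="yes")  # raises: For 'FakeOp', the type of 'ema' should be ...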
@@ -15,8 +15,8 @@
 """Operators for quantization."""

-from ..._checkparam import ParamValidator as validator
-from ..._checkparam import Rel, check_bool, check_int_positive, check_int
+from ..._checkparam import Validator as validator
+from ..._checkparam import Rel
 from ..primitive import PrimitiveWithInfer, prim_attr_register
 from ...common import dtype as mstype
@@ -69,36 +69,31 @@ class FakeQuantWithMinMax(PrimitiveWithInfer):
                  training=True):
         """init FakeQuantWithMinMax OP"""
         if num_bits not in self.support_quant_bit:
-            raise ValueError("Attr \'num_bits\' is not support.")
+            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
         if ema and not ema_decay:
-            raise ValueError(
-                "Attr \'ema\' and \'ema_decay\' should set together.")
+            raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")

-        self.ema = check_bool(ema)
-        self.symmetric = check_bool(symmetric)
-        self.narrow_range = check_bool(narrow_range)
-        self.training = check_bool(training)
-        self.ema_decay = validator.check_number_range(
-            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH)
-        self.num_bits = check_int_positive(num_bits)
-        self.quant_delay = check_int(quant_delay)
+        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
+        self.training = validator.check_value_type('training', training, (bool,), self.name)
+        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                 outputs=['out'])

     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x shape", len(x_shape), 1, Rel.GT)
-        validator.check("min shape", min_shape, "max shape", max_shape)
-        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ)
-        validator.check_integer("max shape", len(min_shape), 1, Rel.EQ)
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_type, min_type, max_type):
-        validator.check_typename(
-            "x type", x_type, (mstype.float16, mstype.float32))
-        validator.check_typename("min type", min_type,
-                                 (mstype.float16, mstype.float32))
-        validator.check_typename("max type", max_type,
-                                 (mstype.float16, mstype.float32))
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
         return x_type
@@ -109,29 +104,24 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, num_bits=8, quant_delay=0):
         if num_bits not in self.support_quant_bit:
-            raise ValueError("Attr \'num_bits\' is not support.")
+            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")

-        self.quant_delay = check_int(quant_delay)
-        self.num_bits = check_int_positive(num_bits)
-        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'],
-                                outputs=['dx'])
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
-        validator.check("dout shape", dout_shape, "x shape", x_shape)
-        validator.check("min shape", min_shape, "max shape", max_shape)
-        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ)
-        validator.check_integer("max shape", len(min_shape), 1, Rel.EQ)
+        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name)
         return dout_shape

     def infer_dtype(self, dout_type, x_type, min_type, max_type):
-        validator.check_typename(
-            "dout type", dout_type, (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "x type", x_type, (mstype.float16, mstype.float32))
-        validator.check_typename("min type", min_type,
-                                 (mstype.float16, mstype.float32))
-        validator.check_typename("max type", max_type,
-                                 (mstype.float16, mstype.float32))
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
         return dout_type
@@ -172,37 +162,30 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
                  training=True):
         """init FakeQuantWithMinMaxPerChannel OP"""
         if num_bits not in self.support_quant_bit:
-            raise ValueError("Attr \'num_bits\' is not support.")
+            raise ValueError(f"For '{self.name}' Attr \'num_bits\' is not support.")
         if ema and not ema_decay:
-            raise ValueError(
-                "Attr \'ema\' and \'ema_decay\' should set together.")
+            raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")

-        self.ema = check_bool(ema)
-        self.symmetric = check_bool(symmetric)
-        self.narrow_range = check_bool(narrow_range)
-        self.training = check_bool(training)
-        self.ema_decay = validator.check_number_range(
-            'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH)
-        self.num_bits = check_int_positive(num_bits)
-        self.quant_delay = check_int(quant_delay)
-        self.init_prim_io_names(inputs=['x', 'min', 'max'],
-                                outputs=['out'])
+        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
+        self.training = validator.check_value_type('training', training, (bool,), self.name)
+        self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
+        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])

     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x shape", len(x_shape), 1, Rel.GT)
-        validator.check_integer(
-            "min len", min_shape[0], x_shape[self.channel_idx], Rel.EQ)
-        validator.check_integer(
-            "max len", max_shape[0], x_shape[self.channel_idx], Rel.EQ)
+        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
+        validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name)
+        validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_type, min_type, max_type):
-        validator.check_typename(
-            "x type", x_type, (mstype.float16, mstype.float32))
-        validator.check_typename("min type", min_type,
-                                 (mstype.float16, mstype.float32))
-        validator.check_typename("max type", max_type,
-                                 (mstype.float16, mstype.float32))
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
         return x_type
@@ -214,12 +197,11 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
     def __init__(self, num_bits=8, quant_delay=0):
         """init FakeQuantWithMinMaxPerChannel Fill"""
         if num_bits not in self.support_quant_bit:
-            raise ValueError("Attr \'num_bits\' is not support.")
+            raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")

-        self.quant_delay = check_int(quant_delay)
-        self.num_bits = check_int_positive(num_bits)
-        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'],
-                                outputs=['dx'])
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
         validator.check("dout shape", dout_shape, "x shape", x_shape)
@@ -227,13 +209,11 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
         return dout_shape

     def infer_dtype(self, dout_type, x_type, min_type, max_type):
-        validator.check_typename(
-            "dout", dout_type, (mstype.float16, mstype.float32))
-        validator.check_typename("x", x_type, (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "min", min_type, (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "max", max_type, (mstype.float16, mstype.float32))
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"min": min_type}, valid_types, self.name)
+        validator.check_tensor_type_same({"max": max_type}, valid_types, self.name)
         return dout_type
@@ -269,31 +249,26 @@ class BatchNormFold(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0):
         """init batch norm fold layer"""
-        self.momentum = validator.check_number_range(
-            'momentum', momentum, 0, 1, Rel.INC_BOTH)
-        self.epsilon = validator.check_float_positive('epsilon', epsilon)
-        self.is_training = check_bool(is_training)
-        self.freeze_bn = check_int(freeze_bn)
+        self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
+        self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
+        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
+        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)

         self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
                                 outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])

     def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
-        validator.check("mean shape", mean_shape,
-                        "gamma_shape", variance_shape)
-        validator.check("mean_shape size",
-                        mean_shape[0], "input channel", x_shape[self.channel])
-        validator.check_integer("global_step shape",
-                                len(global_step_shape), 1, Rel.EQ)
+        validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
+        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, self.name)
+        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
         return mean_shape, mean_shape, mean_shape, mean_shape

     def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
-        validator.check("input type", x_type, "mean type", mean_type)
-        validator.check("input type", x_type, "variance type", variance_type)
-        validator.check_typename("input type", x_type,
-                                 (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "global_step type", global_step_type, (mstype.int32,))
+        args = {"x": x_type, "mean": mean_type, "variance": variance_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
         return x_type, x_type, x_type, x_type
@@ -304,39 +279,31 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0):
         """init BatchNormGrad layer"""
-        self.is_training = check_bool(is_training)
-        self.freeze_bn = check_int(freeze_bn)
-        self.epsilon = validator.check_float_positive('epsilon', epsilon)
+        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
+        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
+        self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
         self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
                                 outputs=['dx'])

     def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
                     global_step_shape):
         validator.check("d_batch_mean shape", d_batch_mean_shape,
-                        "d_batch_std shape", d_batch_std_shape)
+                        "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
         validator.check("d_batch_mean shape", d_batch_mean_shape,
-                        "batch_mean shape", batch_mean_shape)
+                        "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
         validator.check("d_batch_mean shape", d_batch_mean_shape,
-                        "batch_std shape", batch_std_shape)
-        validator.check(
-            "x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[self.channel])
-        validator.check_integer("global_step shape",
-                                len(global_step_shape), 1, Rel.EQ)
+                        "batch_std shape", batch_std_shape, Rel.EQ, self.name)
+        validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ,
+                        self.name)
+        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
                     global_step_type):
-        validator.check("input type", x_type,
-                        "d_batch_mean type", d_batch_mean_type)
-        validator.check("input type", x_type,
-                        "d_batch_std type", d_batch_std_type)
-        validator.check("input type", x_type,
-                        "batch_mean type", batch_mean_type)
-        validator.check("input type", x_type, "batch_std type", batch_std_type)
-        validator.check_typename("input type", x_type,
-                                 (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "global_step type", global_step_type, (mstype.int32,))
+        args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
+                "batch_mean": batch_mean_type, "batch_std": batch_std_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
         return x_type
@@ -364,18 +331,14 @@ class CorrectionMul(PrimitiveWithInfer):
                                 outputs=['out'])

     def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
-        validator.check("batch_std shape", batch_std_shape,
-                        "running_std shape", running_std_shape)
-        validator.check(
-            "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel])
+        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
+        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel],
+                        Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_type, batch_std_type, running_std_type):
-        validator.check("batch_std type", batch_std_type,
-                        "running_std type", running_std_type)
-        validator.check("batch_std_type", batch_std_type, "x_type", x_type)
-        validator.check_typename(
-            "batch_std type", batch_std_type, (mstype.float16, mstype.float32))
+        args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
         return x_type
@@ -390,20 +353,16 @@ class CorrectionMulGrad(PrimitiveWithInfer):
                                 outputs=['dx', 'd_gamma'])

     def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
-        validator.check("dout shape", dout_shape, "x_shape x", x_shape)
-        validator.check(
-            "gamma size", gamma_shape[0], "dout channel size", dout_shape[self.channel])
-        validator.check(
-            "running_std size", running_std_shape[0], "dout channel size", dout_shape[self.channel])
+        validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
+        validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel],
+                        Rel.EQ, self.name)
+        validator.check("running_std_shape[0]", running_std_shape[0], "dout channel size", dout_shape[self.channel],
+                        Rel.EQ, self.name)
         return x_shape, gamma_shape

     def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
-        validator.check("x type", x_type, "dout type", dout_type)
-        validator.check("gamma type", gamma_type, "dout type", dout_type)
-        validator.check("running_std type", running_std_type,
-                        "dout type", dout_type)
-        validator.check_typename(
-            "dout type", dout_type, (mstype.float16, mstype.float32))
+        args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
         return x_type, x_type
@@ -432,46 +391,29 @@ class BatchNormFold2(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, freeze_bn=0):
         """init conv2d fold layer"""
-        self.freeze_bn = check_int(freeze_bn)
+        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
         self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
                                         'running_std', 'running_mean', 'global_step'],
                                 outputs=['y'])

     def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
                     running_mean_shape, global_step_shape):
-        validator.check("batch_std shape", batch_std_shape,
-                        "running_std shape", running_std_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "batch_mean shape", batch_mean_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "beta shape", beta_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "running_mean shape", running_mean_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "batch_mean shape", gamma_shape)
-        validator.check(
-            "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel])
-        validator.check_integer("global_step shape",
-                                len(global_step_shape), 1, Rel.EQ)
+        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
+        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel],
+                        Rel.EQ, self.name)
+        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
                     running_mean_type, global_step_type):
-        validator.check("batch_std type", batch_std_type,
-                        "running_std type", running_std_type)
-        validator.check("batch_std type", batch_std_type,
-                        "batch_mean type", batch_mean_type)
-        validator.check("batch_std type", batch_std_type,
-                        "beta type", beta_type)
-        validator.check("batch_std type", batch_std_type,
-                        "running_mean type", running_mean_type)
-        validator.check("batch_std type", batch_std_type,
-                        "gamma type", gamma_type)
-        validator.check("x_type", x_type, "batch_std type", batch_std_type)
-        validator.check_typename(
-            "batch_std type", batch_std_type, (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "global_step type", global_step_type, (mstype.int32,))
+        args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
+                "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
         return x_type
@@ -491,18 +433,13 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
     def infer_shape(self, dout_shape, x_shape, gamma_shape,
                     batch_std_shape, batch_mean_shape,
                     running_std_shape, running_mean_shape, global_step_shape):
-        validator.check("batch_std shape", batch_std_shape,
-                        "batch_mean shape", batch_mean_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "running_std shape", running_std_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "running_mean shape", running_mean_shape)
-        validator.check("batch_std shape", batch_std_shape,
-                        "gamma shape", gamma_shape)
-        validator.check(
-            "batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel])
-        validator.check_integer("global_step shape",
-                                len(global_step_shape), 1, Rel.EQ)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
+        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel],
+                        Rel.EQ, self.name)
+        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
         return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

     def infer_dtype(self, dout_type, x_type, gamma_type,
@@ -518,8 +455,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
                         "running_mean type", running_mean_type)
         validator.check("batch_std_type", batch_std_type,
                         "dout type", dout_type)
-        validator.check_typename(
-            "batch_std type", batch_std_type, (mstype.float16, mstype.float32))
-        validator.check_typename(
-            "global_step type", global_step_type, (mstype.int32,))
+        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
+                "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name)
         return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
@@ -15,7 +15,8 @@
 """comm_ops"""

-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
+from ..._checkparam import Rel
 from ...communication.management import get_rank, get_group_size, GlobalComm, get_group
 from ...common import dtype as mstype
 from ..primitive import PrimitiveWithInfer, prim_attr_register
@@ -148,12 +149,10 @@ class AllGather(PrimitiveWithInfer):

     @prim_attr_register
     def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
-        if not isinstance(get_group(group), str):
-            raise TypeError("The group of AllGather should be str.")
+        validator.check_value_type('group', get_group(group), (str,), self.name)
         self.rank = get_rank(get_group(group))
         self.rank_size = get_group_size(get_group(group))
-        if self.rank >= self.rank_size:
-            raise ValueError("The rank of AllGather should be less than the rank_size.")
+        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
         self.add_prim_attr('rank_size', self.rank_size)
         self.add_prim_attr('group', get_group(group))
@@ -163,7 +162,7 @@ class AllGather(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype):
         if x_dtype == mstype.bool_:
-            raise TypeError("AllGather does not support 'Bool' as the dtype of input!")
+            raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
         return x_dtype

     def __call__(self, tensor):
@@ -205,10 +204,8 @@ class ReduceScatter(PrimitiveWithInfer):

     @prim_attr_register
     def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
-        if not isinstance(op, type(ReduceOp.SUM)):
-            raise TypeError("The operation of ReduceScatter should be {}.".format(type(ReduceOp.SUM)))
-        if not isinstance(get_group(group), str):
-            raise TypeError("The group of ReduceScatter should be str.")
+        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
+        validator.check_value_type('group', get_group(group), (str,), self.name)
         self.op = op
         self.rank_size = get_group_size(get_group(group))
         self.add_prim_attr('rank_size', self.rank_size)
@@ -216,13 +213,13 @@ class ReduceScatter(PrimitiveWithInfer):

     def infer_shape(self, x_shape):
         if x_shape[0] % self.rank_size != 0:
-            raise ValueError("The first dimension of x should be divided by rank_size.")
+            raise ValueError(f"For '{self.name}' the first dimension of x should be divided by rank_size.")
         x_shape[0] = int(x_shape[0]/self.rank_size)
         return x_shape

     def infer_dtype(self, x_dtype):
         if x_dtype == mstype.bool_:
-            raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!")
+            raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
         return x_dtype

     def __call__(self, tensor):
@@ -270,10 +267,8 @@ class Broadcast(PrimitiveWithInfer):

     @prim_attr_register
     def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
-        if not isinstance(root_rank, int):
-            raise TypeError("The root_rank of Broadcast should be int.")
-        if not isinstance(get_group(group), str):
-            raise TypeError("The group of Broadcast should be str.")
+        validator.check_value_type('root_rank', root_rank, (int,), self.name)
+        validator.check_value_type('group', get_group(group), (str,), self.name)
         self.add_prim_attr('group', get_group(group))

     def infer_shape(self, x_shape):
@@ -281,7 +276,7 @@ class Broadcast(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype):
         if x_dtype == mstype.bool_:
-            raise TypeError("Broadcast does not support 'Bool' as the dtype of input!")
+            raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
         return x_dtype
@@ -311,8 +306,7 @@ class _AlltoAll(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP):
         """init AlltoAll"""
-        if not isinstance(get_group(group), str):
-            raise TypeError("The group of AllGather should be str.")
+        validator.check_value_type('group', get_group(group), (str,), self.name)
         self.split_count = split_count
         self.split_dim = split_dim
         self.concat_dim = concat_dim
@@ -325,7 +319,7 @@ class _AlltoAll(PrimitiveWithInfer):

     def infer_dtype(self, x_dtype):
         if x_dtype == mstype.bool_:
-            raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!")
+            raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
         return x_dtype

     def __call__(self, tensor):
@@ -420,6 +414,6 @@ class _GetTensorSlice(PrimitiveWithInfer):

     def infer_value(self, x, dev_mat, tensor_map):
         from mindspore.parallel._tensor import _load_tensor
-        validator.check_type("dev_mat", dev_mat, [tuple])
-        validator.check_type("tensor_map", tensor_map, [tuple])
+        validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
+        validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
         return _load_tensor(x, dev_mat, tensor_map)
@@ -16,7 +16,8 @@
 """control_ops"""

 from ...common import dtype as mstype
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
+from ..._checkparam import Rel
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
@@ -123,11 +124,11 @@ class GeSwitch(PrimitiveWithInfer):
         raise NotImplementedError

     def infer_shape(self, data, pred):
-        validator.check_scalar_shape_input("pred", pred)
+        validator.check_integer("pred rank", len(pred), 0, Rel.EQ, self.name)
         return (data, data)

     def infer_dtype(self, data_type, pred_type):
-        validator.check_type("pred", pred_type, [type(mstype.bool_)])
+        validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
         return (data_type, data_type)
@@ -14,7 +14,7 @@
 # ============================================================================

 """debug_ops"""
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
 from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer
@@ -219,5 +219,5 @@ class Print(PrimitiveWithInfer):

     def infer_dtype(self, *inputs):
         for dtype in inputs:
-            validator.check_subclass("input", dtype, (mstype.tensor, mstype.string))
+            validator.check_subclass("input", dtype, (mstype.tensor, mstype.string), self.name)
         return mstype.int32
@@ -16,7 +16,7 @@
 """Other operators."""
 from ..._c_expression import signature_rw as sig_rw
 from ..._c_expression import signature_kind as sig_kind
-from ..._checkparam import ParamValidator as validator, Rel
+from ..._checkparam import Validator as validator, Rel
 from ...common import dtype as mstype
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
@@ -82,22 +82,21 @@ class BoundingBoxEncode(PrimitiveWithInfer):

     @prim_attr_register
     def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
-        validator.check_type('means', means, [tuple])
-        validator.check_type('stds', stds, [tuple])
-        validator.check("means len", len(means), '', 4)
-        validator.check("stds len", len(stds), '', 4)
+        validator.check_value_type('means', means, [tuple], self.name)
+        validator.check_value_type('stds', stds, [tuple], self.name)
+        validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
+        validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)

     def infer_shape(self, anchor_box, groundtruth_box):
-        validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0])
-        validator.check('anchor_box shape[1]', anchor_box[1], '', 4)
-        validator.check('groundtruth_box shape[1]', groundtruth_box[1], '', 4)
+        validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ,
+                        self.name)
+        validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
+        validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name)
         return anchor_box

     def infer_dtype(self, anchor_box, groundtruth_box):
-        args = {"anchor_box": anchor_box,
-                "groundtruth_box": groundtruth_box
-                }
-        validator.check_type_same(args, mstype.number_type)
+        args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
         return anchor_box
@@ -126,26 +125,24 @@ class BoundingBoxDecode(PrimitiveWithInfer):

     @prim_attr_register
     def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016):
-        validator.check_type('means', means, [tuple])
-        validator.check_type('stds', stds, [tuple])
-        validator.check_type('wh_ratio_clip', wh_ratio_clip, [float])
-        validator.check("means", len(means), '', 4)
-        validator.check("stds", len(stds), '', 4)
+        validator.check_value_type('means', means, [tuple], self.name)
+        validator.check_value_type('stds', stds, [tuple], self.name)
+        validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name)
+        validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
+        validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
         if max_shape is not None:
-            validator.check_type('max_shape', max_shape, [tuple])
-            validator.check("max_shape", len(max_shape), '', 2)
+            validator.check_value_type('max_shape', max_shape, [tuple], self.name)
+            validator.check_integer("max_shape len", len(max_shape), 2, Rel.EQ, self.name)

     def infer_shape(self, anchor_box, deltas):
-        validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0])
-        validator.check('anchor_box shape[1]', anchor_box[1], '', 4)
-        validator.check('deltas shape[1]', deltas[1], '', 4)
+        validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name)
+        validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name)
+        validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name)
         return anchor_box

     def infer_dtype(self, anchor_box, deltas):
-        args = {"anchor_box": anchor_box,
-                "deltas": deltas
-                }
-        validator.check_type_same(args, mstype.number_type)
+        args = {"anchor_box": anchor_box, "deltas": deltas}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
         return anchor_box
@@ -168,10 +165,10 @@ class CheckValid(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output'])

     def infer_shape(self, bboxes_shape, metas_shape):
-        validator.check_shape_length("bboxes shape length", len(bboxes_shape), 2, Rel.EQ)
-        validator.check("bboxes_shape[-1]", bboxes_shape[-1], "", 4, Rel.EQ)
-        validator.check_shape_length("img_metas shape length", len(metas_shape), 1, Rel.EQ)
-        validator.check("img_metas shape[0]", metas_shape[0], "", 3, Rel.EQ)
+        validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("bboxes_shape[-1]", bboxes_shape[-1], 4, Rel.EQ, self.name)
+        validator.check_integer("img_metas rank", len(metas_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("img_metas shape[0]", metas_shape[0], 3, Rel.EQ, self.name)
         return bboxes_shape[:-1]

     def infer_dtype(self, bboxes_type, metas_type):
@@ -221,18 +218,16 @@ class IOU(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])

     def infer_shape(self, anchor_boxes, gt_boxes):
-        validator.check('gt_boxes shape[1]', gt_boxes[1], '', 4)
-        validator.check('anchor_boxes shape[1]', anchor_boxes[1], '', 4)
-        validator.check('anchor_boxes rank', len(anchor_boxes), '', 2)
-        validator.check('gt_boxes rank', len(gt_boxes), '', 2)
+        validator.check_integer('gt_boxes shape[1]', gt_boxes[1], 4, Rel.EQ, self.name)
+        validator.check_integer('anchor_boxes shape[1]', anchor_boxes[1], 4, Rel.EQ, self.name)
+        validator.check_integer('anchor_boxes rank', len(anchor_boxes), 2, Rel.EQ, self.name)
+        validator.check_integer('gt_boxes rank', len(gt_boxes), 2, Rel.EQ, self.name)
         iou = [gt_boxes[0], anchor_boxes[0]]
         return iou

     def infer_dtype(self, anchor_boxes, gt_boxes):
-        validator.check_subclass("anchor_boxes", anchor_boxes, mstype.tensor)
-        validator.check_subclass("gt_boxes", gt_boxes, mstype.tensor)
         args = {"anchor_boxes": anchor_boxes, "gt_boxes": gt_boxes}
-        validator.check_type_same(args, (mstype.float16,))
+        validator.check_tensor_type_same(args, (mstype.float16,), self.name)
         return anchor_boxes
@@ -270,7 +265,7 @@ class MakeRefKey(Primitive):

     @prim_attr_register
     def __init__(self, tag):
-        validator.check_type('tag', tag, (str,))
+        validator.check_value_type('tag', tag, (str,), self.name)

     def __call__(self):
         pass
@@ -15,7 +15,7 @@

 """Operators for random."""

-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ..primitive import PrimitiveWithInfer, prim_attr_register
@@ -52,16 +52,15 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, count=256, seed=0, seed2=0):
         """Init RandomChoiceWithMask"""
-        validator.check_type("count", count, [int])
-        validator.check_integer("count", count, 0, Rel.GT)
-        validator.check_type('seed', seed, [int])
-        validator.check_type('seed2', seed2, [int])
+        validator.check_value_type("count", count, [int], self.name)
+        validator.check_integer("count", count, 0, Rel.GT, self.name)
+        validator.check_value_type('seed', seed, [int], self.name)
+        validator.check_value_type('seed2', seed2, [int], self.name)

     def infer_shape(self, x_shape):
-        validator.check_shape_length("input_x shape", len(x_shape), 1, Rel.GE)
+        validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
         return ([self.count, len(x_shape)], [self.count])

     def infer_dtype(self, x_dtype):
-        validator.check_subclass('x_dtype', x_dtype, mstype.tensor)
-        validator.check_typename('x_dtype', x_dtype, [mstype.bool_])
+        validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
         return (mstype.int32, mstype.bool_)