Add primitive name to param error message for math_ops.py

fary86 2020-04-02 07:50:13 +08:00
parent cc0ba93d17
commit 69ed72f10d
11 changed files with 1003 additions and 129 deletions


@@ -15,6 +15,7 @@
 """Check parameters."""
 import re
 from enum import Enum
+from functools import reduce
 from itertools import repeat
 from collections import Iterable
@@ -93,8 +94,131 @@ rel_strs = {
 }


+class Validator:
+    """validator for checking input parameters"""
+
+    @staticmethod
+    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
+        """
+        Method for judging relation between two int values or list/tuple made up of ints.
+
+        This method is not suitable for judging relation between floats, since it does not consider float error.
+        """
+        rel_fn = Rel.get_fns(rel)
+        if not rel_fn(arg_value, value):
+            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
+            msg_prefix = f'For {prim_name} the' if prim_name else "The"
+            raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
+
+    @staticmethod
+    def check_integer(arg_name, arg_value, value, rel, prim_name):
+        """Integer value judgment."""
+        rel_fn = Rel.get_fns(rel)
+        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
+        if type_mismatch or not rel_fn(arg_value, value):
+            rel_str = Rel.get_strs(rel).format(value)
+            raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
+                             f' but got {arg_value}.')
+        return arg_value
+
+    @staticmethod
+    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
+        """Method for checking whether an int value is in some range."""
+        rel_fn = Rel.get_fns(rel)
+        type_mismatch = not isinstance(arg_value, int)
+        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
+            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
+            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
+                             f' but got {arg_value}.')
+        return arg_value
+
+    @staticmethod
+    def check_subclass(arg_name, type_, template_type, prim_name):
+        """Check whether some type is subclass of another type"""
+        if not isinstance(template_type, Iterable):
+            template_type = (template_type,)
+        if not any([mstype.issubclass_(type_, x) for x in template_type]):
+            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
+            raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
+                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
+
+    @staticmethod
+    def check_tensor_type_same(args, valid_values, prim_name):
+        """check whether the element types of input tensors are the same."""
+        def _check_tensor_type(arg):
+            arg_key, arg_val = arg
+            Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
+            elem_type = arg_val.element_type()
+            if not elem_type in valid_values:
+                raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
+                                f' but `{arg_key}` is {elem_type}.')
+            return (arg_key, elem_type)
+
+        def _check_types_same(arg1, arg2):
+            arg1_name, arg1_type = arg1
+            arg2_name, arg2_type = arg2
+            if arg1_type != arg2_type:
+                raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,'
+                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
+            return arg1
+
+        elem_types = map(_check_tensor_type, args.items())
+        reduce(_check_types_same, elem_types)
+
+    @staticmethod
+    def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
+        """check whether the types of inputs are the same. if the input args are tensors, check their element types"""
+        def _check_argument_type(arg):
+            arg_key, arg_val = arg
+            if isinstance(arg_val, type(mstype.tensor)):
+                arg_val = arg_val.element_type()
+            if not arg_val in valid_values:
+                raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
+                                f' but `{arg_key}` is {arg_val}.')
+            return arg
+
+        def _check_types_same(arg1, arg2):
+            arg1_name, arg1_type = arg1
+            arg2_name, arg2_type = arg2
+            excp_flag = False
+            if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
+                arg1_type = arg1_type.element_type()
+                arg2_type = arg2_type.element_type()
+            elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
+                pass
+            else:
+                excp_flag = True
+            if excp_flag or arg1_type != arg2_type:
+                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
+                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
+            return arg1
+
+        reduce(_check_types_same, map(_check_argument_type, args.items()))
+
+    @staticmethod
+    def check_value_type(arg_name, arg_value, valid_types, prim_name):
+        """Check whether a value is instance of some types."""
+        def raise_error_msg():
+            """func for raising error message when check failed"""
+            type_names = [t.__name__ for t in valid_types]
+            num_types = len(valid_types)
+            raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
+                            f'{"one of " if num_types > 1 else ""}'
+                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
+
+        # Notice: bool is a subclass of int, so `check_value_type('x', True, [int])` fails,
+        # while `check_value_type('x', True, [bool, int])` passes.
+        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
+            raise_error_msg()
+        if isinstance(arg_value, tuple(valid_types)):
+            return arg_value
+        raise_error_msg()
+
+
 class ParamValidator:
-    """Parameter validator."""
+    """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""

     @staticmethod
     def equal(arg_name, arg_value, cond_str, cond):

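The practical effect of the new `Validator` is that every failure message now names the primitive that triggered it. Below is a minimal standalone sketch of that message format (plain Python, independent of MindSpore; `check_positive` is a hypothetical helper, not part of this commit):

def check_positive(arg_name, arg_value, prim_name=None):
    # Mirrors the `msg_prefix` pattern used by Validator.check above.
    if arg_value <= 0:
        msg_prefix = f'For {prim_name} the' if prim_name else 'The'
        raise ValueError(f'{msg_prefix} `{arg_name}` should be > 0, but got {arg_value}.')

check_positive('axis', 3, prim_name='ReduceSum')   # passes silently
try:
    check_positive('axis', -1, prim_name='ReduceSum')
except ValueError as err:
    print(err)  # -> For ReduceSum the `axis` should be > 0, but got -1.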

@@ -16,13 +16,14 @@
 """broadcast"""

-def _get_broadcast_shape(x_shape, y_shape):
+def _get_broadcast_shape(x_shape, y_shape, prim_name):
     """
     Doing broadcast between tensor x and tensor y.

     Args:
         x_shape (list): The shape of tensor x.
         y_shape (list): The shape of tensor y.
+        prim_name (str): Primitive name.

     Returns:
         List, the shape that broadcast between tensor x and tensor y.
@@ -50,7 +51,8 @@ def _get_broadcast_shape(x_shape, y_shape):
         elif x_shape[i] == y_shape[i]:
             broadcast_shape_back.append(x_shape[i])
         else:
-            raise ValueError("The x_shape {} and y_shape {} can not broadcast.".format(x_shape, y_shape))
+            raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format(
+                prim_name, x_shape, y_shape))
     broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
     broadcast_shape = broadcast_shape_front + broadcast_shape_back

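The rule `_get_broadcast_shape` implements is the usual NumPy-style one: shapes are aligned from the trailing dimension, a dimension broadcasts when the two sizes are equal or one of them is 1, and anything else raises. A short sketch of that rule (an illustration only, not the MindSpore implementation):

def broadcast_shape(x_shape, y_shape, prim_name='Op'):
    """NumPy-style broadcast of two shapes; sketch of the rule only."""
    result = []
    # Walk both shapes from the trailing dimension.
    for x_dim, y_dim in zip(reversed(x_shape), reversed(y_shape)):
        if x_dim == 1:
            result.append(y_dim)
        elif y_dim == 1 or x_dim == y_dim:
            result.append(x_dim)
        else:
            raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format(
                prim_name, x_shape, y_shape))
    # The longer shape contributes its leading dimensions unchanged.
    longer = x_shape if len(x_shape) >= len(y_shape) else y_shape
    return list(longer[:len(longer) - len(result)]) + list(reversed(result))

assert broadcast_shape([3, 1, 5], [4, 5]) == [3, 4, 5]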

@@ -28,9 +28,16 @@ from ..._checkparam import ParamValidator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ...common.tensor import Tensor
-from ..operations.math_ops import _check_infer_attr_reduce, _infer_shape_reduce
+from ..operations.math_ops import _infer_shape_reduce
 from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register


+def _check_infer_attr_reduce(axis, keep_dims):
+    validator.check_type('keep_dims', keep_dims, [bool])
+    validator.check_type('axis', axis, [int, tuple])
+    if isinstance(axis, tuple):
+        for index, value in enumerate(axis):
+            validator.check_type('axis[%d]' % index, value, [int])
+
+
 class ExpandDims(PrimitiveWithInfer):
     """
@@ -1091,7 +1098,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
         axis = self.axis
         x_rank = len(x_shape)
         validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
-        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims)
+        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
         return ouput_shape, ouput_shape

     def infer_dtype(self, x_dtype):
@@ -1137,7 +1144,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
         axis = self.axis
         x_rank = len(x_shape)
         validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
-        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims)
+        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
         return ouput_shape, ouput_shape

     def infer_dtype(self, x_dtype):


@@ -19,7 +19,7 @@ import numpy as np
 from ..._c_expression import signature_rw as sig_rw
 from ..._c_expression import signature_kind as sig_kind
 from ..._c_expression import signature_dtype as sig_dtype
-from ..._checkparam import ParamValidator as validator
+from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ...common.tensor import Tensor
@@ -27,16 +27,16 @@ from .._utils import _get_broadcast_shape
 from ..primitive import PrimitiveWithInfer, prim_attr_register


-def _infer_shape_reduce(x, axis, keep_dims):
+def _infer_shape_reduce(x, axis, keep_dims, prim_name):
     """Common infer for reduce operator"""

     def reduce_one_axis(one_axis):
-        validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT)
+        validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT, prim_name)
         if one_axis < 0:
             one_axis += dim
         axis_reduce.add(one_axis)

-    validator.check_type('axis', axis, [int, tuple, list])
+    validator.check_value_type('axis', axis, [int, tuple, list], prim_name)
     dim = len(x)
     axis_reduce = set()
@@ -48,7 +48,7 @@ def _infer_shape_reduce(x, axis, keep_dims):
             return [1] * dim
         return []
     for index, one_axis in enumerate(axis):
-        validator.check_type('axis[%d]' % index, one_axis, [int])
+        validator.check_value_type('axis[%d]' % index, one_axis, [int], prim_name)
         reduce_one_axis(one_axis)

     out_shape = []
@@ -61,14 +61,6 @@ def _infer_shape_reduce(x, axis, keep_dims):
     return out_shape

-
-def _check_infer_attr_reduce(axis, keep_dims):
-    validator.check_type('keep_dims', keep_dims, [bool])
-    validator.check_type('axis', axis, [int, tuple])
-    if isinstance(axis, tuple):
-        for index, value in enumerate(axis):
-            validator.check_type('axis[%d]' % index, value, [int])
-

 class _BinaryOp(PrimitiveWithInfer):
     """
     Define binary operators.
@@ -82,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

     def infer_shape(self, x_shape, y_shape):
-        return _get_broadcast_shape(x_shape, y_shape)
+        return _get_broadcast_shape(x_shape, y_shape, self.prim_name())
 class _MathBinaryOp(_BinaryOp):
@@ -91,15 +83,13 @@ class _MathBinaryOp(_BinaryOp):
     """

     @staticmethod
-    def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type):
+    def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None):
         args_type = {"x": x_dtype, "y": y_dtype}
-        validator.check_args_tensor(args_type)
-        args_dtype = {"x_dtype": x_dtype, "y_dtype": y_dtype}
-        validator.check_type_same(args_dtype, valid_dtype)
+        validator.check_tensor_type_same(args_type, valid_dtype, prim_name)
         return x_dtype

     def infer_dtype(self, x_dtype, y_dtype):
-        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype)
+        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name())


 class TensorAdd(_MathBinaryOp):
@@ -166,7 +156,7 @@ class AssignAdd(PrimitiveWithInfer):
     def infer_dtype(self, variable, value):
         args = {"value": value}
-        validator.check_type_same(args, mstype.number_type)
+        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
         return value

@@ -207,7 +197,7 @@ class AssignSub(PrimitiveWithInfer):
     def infer_dtype(self, variable, value):
         args = {"value": value}
-        validator.check_type_same(args, mstype.number_type)
+        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
         return value
@@ -228,15 +218,16 @@ class _Reduce(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, keep_dims=False):
         """init Reduce"""
-        validator.check_type('keep_dims', keep_dims, [bool])
+        validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name())
         self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])

     def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
         axis_v = axis['value']
         input_shp = input_x['shape']
-        validator.check_subclass('input_x', input_x['dtype'], mstype.tensor)
-        validator.check_typename('input_x', input_x['dtype'], valid_dtype)
-        input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims)
+        args = {'input_x': input_x['dtype']}
+        validator.check_tensor_type_same(args, valid_dtype, self.prim_name())
+        input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name())
         return {'shape': input_shp,
                 'dtype': input_x['dtype'],
                 'value': None}
@@ -471,16 +462,17 @@ class CumProd(PrimitiveWithInfer):
     """
     @prim_attr_register
     def __init__(self, exclusive=False, reverse=False):
-        self.exclusive = validator.check_type("exclusive", exclusive, [bool])
-        self.reverse = validator.check_type("reverse", reverse, [bool])
+        cls_name = self.prim_name()
+        self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
+        self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)

     def infer_shape(self, x_shape, axis_shape):
         return x_shape

     def infer_dtype(self, x_type, axis_type):
-        validator.check_subclass('x_type', x_type, mstype.tensor)
-        validator.check_typename('x_type', x_type, mstype.number_type)
-        validator.check_subclass("axis_type", axis_type, mstype.int_)
+        cls_name = self.prim_name()
+        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
+        validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
         return x_type
@@ -514,8 +506,9 @@ class MatMul(PrimitiveWithInfer):
     def __init__(self, transpose_a=False, transpose_b=False):
         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
         self.__setattr_flag__ = True
-        validator.check_type("transpose_a", transpose_a, [bool])
-        validator.check_type("transpose_b", transpose_b, [bool])
+        cls_name = self.prim_name()
+        validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
+        validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)

     def check_shape_size(self, x, y):
         if len(x) != 2 or len(y) != 2:
@@ -524,11 +517,11 @@ class MatMul(PrimitiveWithInfer):
     def infer_shape(self, x, y):
         self.check_shape_size(x, y)
-        cls_name = self.__class__.__name__
+        cls_name = self.prim_name()
         # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
         for i in range(len(x) - 2):
             if x[i] != y[i]:
-                raise ValueError(f'{cls_name} shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}')
+                raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}')

         # validate whether last two dims satisfying matrix multiply
         x_last = x[-2:]
@@ -537,8 +530,8 @@ class MatMul(PrimitiveWithInfer):
         x_col = x_last[not self.transpose_a]  # x_col = x_last[1] if (not transpose_a) else x_last[0]
         y_row = y_last[self.transpose_b]  # y_row = y_last[0] if (not transpose_b) else y_last[1]
         if x_col != y_row:
-            raise ValueError(f'{cls_name} evaluator shapes of inputs can not do this operator, got {x_col} and {y_row}'
-                             + f' for {cls_name}, with x shape {x}(transpose_a={self.transpose_a})'
+            raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
+                             + f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})'
                              + f', y shape {y}(transpose_b={self.transpose_b}).')
         # set attribute
         self.add_prim_attr('transpose_x1', self.transpose_a)
@@ -548,10 +541,8 @@ class MatMul(PrimitiveWithInfer):
         return ret_dims

     def infer_dtype(self, x, y):
-        validator.check_subclass("x", x, mstype.tensor)
-        validator.check_subclass("y", y, mstype.tensor)
-        args = {"x dtype": x, "y dtype": y}
-        validator.check_type_same(args, mstype.float_type + mstype.int_type)
+        args = {"x": x, "y": y}
+        validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name())
         return x
@@ -595,12 +586,13 @@ class BatchMatMul(MatMul):
     def __init__(self, transpose_a=False, transpose_b=False):
         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
         self.__setattr_flag__ = True
-        validator.check_type("transpose_a", transpose_a, [bool])
-        validator.check_type("transpose_b", transpose_b, [bool])
+        cls_name = self.prim_name()
+        validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
+        validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)

     def check_shape_size(self, x, y):
         if len(x) != len(y) or len(x) < 3:
-            raise ValueError('BatchMatMul input x, y should be the same dimension size and should be '
+            raise ValueError('For \'BatchMatMul\' input x, y should be the same dimension size and should be '
                              'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}')
@@ -632,18 +624,17 @@ class CumSum(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, exclusive=False, reverse=False):
         """init cumsum"""
-        self.exclusive = validator.check_type('exclusive', exclusive, [bool])
-        self.add_prim_attr("exclusive", self.exclusive)
-        self.reverse = validator.check_type('reverse', reverse, [bool])
-        self.add_prim_attr("reverse", self.reverse)
+        cls_name = self.prim_name()
+        validator.check_value_type('exclusive', exclusive, [bool], cls_name)
+        validator.check_value_type('reverse', reverse, [bool], cls_name)
         self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])

     def __infer__(self, x, axis):
+        cls_name = self.prim_name()
         x_shp = x['shape']
-        validator.check_type('axis', axis['value'], [int])
-        validator.check_subclass('x', x['dtype'], mstype.tensor)
-        validator.check_typename('x', x['dtype'], [mstype.uint8, mstype.int8,
-                                                   mstype.int32, mstype.float16, mstype.float32])
+        validator.check_value_type('axis', axis['value'], [int], cls_name)
+        valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
+        validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name)
         return {'shape': x_shp,
                 'dtype': x['dtype'],
                 'value': None}
@@ -684,21 +675,22 @@ class AddN(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])

     def infer_shape(self, inputs):
-        validator.check_integer("inputs", len(inputs), 1, Rel.GE)
+        cls_name = self.prim_name()
+        validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
         self.add_prim_attr('n', len(inputs))
         shp0 = inputs[0]
         for i, shp in enumerate(inputs):
-            validator.check(f"shape of inputs[{i}]", shp, 'shape of inputs[0]', shp0)
+            validator.check(f"shape of inputs[{i}]", shp, 'shape of inputs[0]', shp0, Rel.EQ, cls_name)
         return shp0

     def infer_dtype(self, inputs):
-        validator.check_type("inputs", inputs, [tuple, list])
-        validator.check_integer("inputs", len(inputs), 1, Rel.GE)
+        cls_name = self.prim_name()
+        validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
+        validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
         args = {}
         for i, dtype in enumerate(inputs):
-            validator.check_subclass(f"inputs[{i}]", dtype, mstype.tensor)
             args[f"inputs[{i}]"] = dtype
-        validator.check_type_same(args, mstype.number_type + (mstype.bool_,))
+        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
         return inputs[0]
@@ -722,8 +714,7 @@ class Neg(PrimitiveWithInfer):
         return input_x

     def infer_dtype(self, input_x):
-        validator.check_subclass("input_x", input_x, mstype.tensor)
-        validator.check_typename("input_x", input_x, mstype.number_type)
+        validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name())
         return input_x
@@ -806,8 +797,7 @@ class Square(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_subclass("x", x_type, mstype.tensor)
-        validator.check_typename("x_dtype", x_type, mstype.number_type)
+        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
         return x_type
@@ -836,8 +826,7 @@ class Rsqrt(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_subclass("x", x_type, mstype.tensor)
-        validator.check_typename("x_dtype", x_type, mstype.number_type)
+        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
         return x_type
@@ -866,8 +855,7 @@ class Sqrt(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_subclass("x", x_type, mstype.tensor)
-        validator.check_typename("x_dtype", x_type, mstype.number_type)
+        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
         return x_type
@@ -897,7 +885,7 @@ class Reciprocal(PrimitiveWithInfer):
         return x

     def infer_dtype(self, x):
-        validator.check_subclass("x", x, mstype.tensor)
+        validator.check_subclass("x", x, mstype.tensor, self.prim_name())
         return x
@@ -935,8 +923,7 @@ class Pow(PrimitiveWithInfer):
         return x

     def infer_dtype(self, x, power):
-        validator.check_subclass("x", x, mstype.tensor)
-        validator.check_typename("power", power, mstype.number_type)
+        validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name())
         return x
@@ -966,7 +953,7 @@ class Exp(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_subclass("x", x_type, mstype.tensor)
+        validator.check_subclass("x", x_type, mstype.tensor, self.prim_name())
         return x_type
@@ -995,7 +982,7 @@ class Log(PrimitiveWithInfer):
         return x

     def infer_dtype(self, x):
-        validator.check_subclass("x", x, mstype.tensor)
+        validator.check_subclass("x", x, mstype.tensor, self.prim_name())
         return x
@@ -1178,8 +1165,7 @@ class Floor(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor)
-        validator.check_typename("x_dtype", x_dtype, mstype.float_type)
+        validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name())
         return x_dtype
@@ -1234,8 +1220,7 @@ class Acosh(PrimitiveWithInfer):
         return x

     def infer_dtype(self, x):
-        validator.check_subclass("x_dtype", x, mstype.tensor)
-        validator.check_typename('x_dtype', x, mstype.number_type)
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
         return x
@@ -1245,15 +1230,13 @@ class _LogicBinaryOp(_BinaryOp):
     """

     @staticmethod
-    def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type):
-        args_type = {"x": x_dtype, "y": y_dtype}
-        validator.check_args_tensor(args_type)
-        args_dtype = {"x_dtype": x_dtype, "y_dtype": y_dtype}
-        validator.check_type_same(args_dtype, valid_type)
+    def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None):
+        args_dtype = {"x": x_dtype, "y": y_dtype}
+        validator.check_tensor_type_same(args_dtype, valid_type, prim_name)
         return mstype.tensor_type(mstype.bool_)

     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype)
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name())
 class Equal(_LogicBinaryOp):
@@ -1289,7 +1272,7 @@ class Equal(_LogicBinaryOp):
     """

     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,))
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())


 class EqualCount(PrimitiveWithInfer):
@@ -1318,11 +1301,13 @@ class EqualCount(PrimitiveWithInfer):
         """init EqualCount"""
         self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

-    def infer_shape(self, x_shape, w_shape):
+    def infer_shape(self, x_shape, y_shape):
         output_shape = (1,)
         return output_shape

-    def infer_dtype(self, x_dtype, w_dtype):
+    def infer_dtype(self, x_dtype, y_dtype):
+        args = {'x': x_dtype, 'y': y_dtype}
+        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name())
         return x_dtype
@@ -1359,7 +1344,7 @@ class NotEqual(_LogicBinaryOp):
     """

     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,))
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())


 class Greater(_LogicBinaryOp):
@@ -1495,8 +1480,7 @@ class LogicalNot(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_subclass("x", x_dtype, mstype.tensor)
-        validator.check_typename("x_dtype", x_dtype, [mstype.bool_])
+        validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name())
         return mstype.tensor_type(mstype.bool_)
@@ -1526,7 +1510,7 @@ class LogicalAnd(_LogicBinaryOp):
     """

     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,))
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())


 class LogicalOr(_LogicBinaryOp):
@@ -1555,7 +1539,7 @@ class LogicalOr(_LogicBinaryOp):
     """

     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,))
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
 class NPUAllocFloatStatus(PrimitiveWithInfer):
@@ -1616,13 +1600,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
         self.add_prim_attr("_side_effect_flag", True)

     def infer_shape(self, x_shape):
-        validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ)
-        validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ)
+        cls_name = self.prim_name()
+        validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
+        validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
         return [8]

     def infer_dtype(self, x_dtype):
-        args = {"x_dtype": x_dtype}
-        validator.check_type_same(args, [mstype.float32])
+        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
         return mstype.float32
@@ -1658,13 +1642,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
         self.add_prim_attr("_side_effect_flag", True)

     def infer_shape(self, x_shape):
-        validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ)
-        validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ)
+        cls_name = self.prim_name()
+        validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
+        validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
         return [8]

     def infer_dtype(self, x_dtype):
-        args = {"x_dtype": x_dtype}
-        validator.check_type_same(args, [mstype.float32])
+        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
         return mstype.float32
@@ -1692,8 +1676,7 @@ class Cos(PrimitiveWithInfer):
         return x

     def infer_dtype(self, x):
-        validator.check_subclass("x_dtype", x, mstype.tensor)
-        validator.check_typename('x_dtype', x, mstype.number_type)
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
         return x
@@ -1721,8 +1704,7 @@ class ACos(PrimitiveWithInfer):
         return x

     def infer_dtype(self, x):
-        validator.check_subclass("x_dtype", x, mstype.tensor)
-        validator.check_typename('x_dtype', x, mstype.number_type)
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
         return x
@@ -1750,8 +1732,7 @@ class Sin(PrimitiveWithInfer):
         return x

     def infer_dtype(self, x):
-        validator.check_subclass("x_dtype", x, mstype.tensor)
-        validator.check_typename('x_dtype', x, mstype.number_type)
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
         return x
@@ -1796,19 +1777,19 @@ class NMSWithMask(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, iou_threshold=0.5):
         """Init NMSWithMask"""
-        validator.check_type("iou_threshold", iou_threshold, [float])
+        validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name())
         self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask'])

     def infer_shape(self, bboxes_shape):
-        validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ)
-        validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT)
-        validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ)
+        cls_name = self.prim_name()
+        validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
+        validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
+        validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
         num = bboxes_shape[0]
         return (bboxes_shape, (num,), (num,))

     def infer_dtype(self, bboxes_dtype):
-        validator.check_subclass("bboxes_dtype", bboxes_dtype, mstype.tensor)
-        validator.check_typename("bboxes_dtype", bboxes_dtype, [mstype.float16, mstype.float32])
+        validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name())
         return (bboxes_dtype, mstype.int32, mstype.bool_)
@@ -1837,8 +1818,7 @@ class Abs(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_subclass("x_dtype", x_type, mstype.tensor)
-        validator.check_typename('x_dtype', x_type, mstype.number_type)
+        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
         return x_type

     def infer_value(self, x):
@@ -1880,8 +1860,7 @@ class Sign(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype):
-        validator.check_subclass('x', x_dtype, mstype.tensor)
-        validator.check_typename('x_dtype', x_dtype, mstype.number_type)
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name())
         return x_dtype
@@ -1910,8 +1889,7 @@ class Round(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_type):
-        validator.check_subclass("x_dtype", x_type, mstype.tensor)
-        validator.check_typename('x_dtype', x_type, mstype.number_type)
+        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
         return x_type

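With these changes, a failed dtype check on a binary math op names the offending primitive directly. A hedged usage sketch (assumes a MindSpore build with these changes applied; the exact message wording may differ):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.ones([3, 4]).astype(np.int32))
y = Tensor(np.ones([3, 4]).astype(np.float32))
try:
    P.TensorAdd()(x, y)  # element types differ, so infer_dtype should reject this
except TypeError as err:
    # expected to read roughly: For 'TensorAdd' element type of `y` should be same as `x`, ...
    print(err)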

@@ -194,6 +194,9 @@ class PrimitiveWithInfer(Primitive):
         Primitive.__init__(self, name)
         self.set_prim_type(prim_type.py_infer_shape)

+    def prim_name(self):
+        return self.__class__.__name__
+
     def _clone(self):
         """
         Deeply clones the primitive object.

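`prim_name` reports the concrete class name, so every subclass of `PrimitiveWithInfer` automatically contributes the right prefix without storing extra state. A standalone sketch of the pattern (plain Python, not the full MindSpore class hierarchy):

class PrimitiveWithInfer:
    def prim_name(self):
        # Resolved on the instance, so a subclass reports its own name.
        return self.__class__.__name__

class MatMul(PrimitiveWithInfer):
    pass

assert MatMul().prim_name() == 'MatMul'
assert PrimitiveWithInfer().prim_name() == 'PrimitiveWithInfer'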

@@ -23,20 +23,25 @@ from ...utils import keyword
 class CheckExceptionsEC(IExectorComponent):
     """
-    Check if the function raises the expected Exception.
+    Check if the function raises the expected Exception and the error message contains the specified keywords (if given).

     Examples:
         {
             'block': f,
-            'exception': Exception
+            'exception': Exception,
+            'error_keywords': ['TensorAdd', 'shape']
         }
     """
     def run_function(self, function, inputs, verification_set):
         f = function[keyword.block]
         args = inputs[keyword.desc_inputs]
         e = function.get(keyword.exception, Exception)
+        error_kws = function.get(keyword.error_keywords, None)
         try:
-            with pytest.raises(e):
+            with pytest.raises(e) as exec_info:
                 f(*args)
         except:
             raise Exception(f"Expect {e}, but got {sys.exc_info()[0]}")
+        if error_kws and any(keyword not in str(exec_info.value) for keyword in error_kws):
+            raise ValueError('Error message `{}` does not contain all keywords `{}`'.format(
+                str(exec_info.value), error_kws))
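The added keyword check is plain substring matching over the stringified exception. A minimal sketch of that logic outside the test framework (the message and keyword list here are made-up values):

error_kws = ['TensorAdd', 'x_shape']
message = "For 'TensorAdd' the x_shape [3, 5] and y_shape [3, 4] can not broadcast."
if error_kws and any(kw not in message for kw in error_kws):
    raise ValueError('Error message `{}` does not contain all keywords `{}`'.format(message, error_kws))
# No exception here: both 'TensorAdd' and 'x_shape' occur in the message.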


@@ -87,8 +87,9 @@ def get_function_config(function):
     init_param_with = function.get(keyword.init_param_with, None)
     split_outputs = function.get(keyword.split_outputs, True)
     exception = function.get(keyword.exception, Exception)
+    error_keywords = function.get(keyword.error_keywords, None)
     return delta, max_error, input_selector, output_selector, sampling_times, \
-           reduce_output, init_param_with, split_outputs, exception
+           reduce_output, init_param_with, split_outputs, exception, error_keywords


 def get_grad_checking_options(function, inputs):
     """
@@ -104,6 +105,6 @@ def get_grad_checking_options(function, inputs):
     """
     f = function[keyword.block]
     args = inputs[keyword.desc_inputs]
-    delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _ = \
+    delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _, _ = \
         get_function_config(function)
     return f, args, delta, max_error, input_selector, output_selector, sampling_times, reduce_output


@@ -54,11 +54,12 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
     block = block_config
     delta, max_error, input_selector, output_selector, \
-        sampling_times, reduce_output, init_param_with, split_outputs, exception = get_function_config({})
+        sampling_times, reduce_output, init_param_with, split_outputs, exception, error_keywords = get_function_config({})
     if isinstance(block_config, tuple) and isinstance(block_config[-1], dict):
         block = block_config[0]
         delta, max_error, input_selector, output_selector, \
-            sampling_times, reduce_output, init_param_with, split_outputs, exception = get_function_config(block_config[-1])
+            sampling_times, reduce_output, init_param_with, \
+            split_outputs, exception, error_keywords = get_function_config(block_config[-1])

     if block:
         func_list.append({
@@ -78,7 +79,8 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
             keyword.const_first: const_first,
             keyword.add_fake_input: add_fake_input,
             keyword.split_outputs: split_outputs,
-            keyword.exception: exception
+            keyword.exception: exception,
+            keyword.error_keywords: error_keywords
         })

     if desc_inputs or desc_const:


@@ -73,5 +73,6 @@ keyword.const_first = "const_first"
 keyword.add_fake_input = "add_fake_input"
 keyword.fake_input_type = "fake_input_type"
 keyword.exception = "exception"
+keyword.error_keywords = "error_keywords"

 sys.modules[__name__] = keyword


@@ -234,7 +234,7 @@ raise_set = [
         'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': ValueError}),
         'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
     ('ReduceSum_Error', {
-        'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': ValueError}),
+        'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}),
         'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
 ]


@@ -0,0 +1,751 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ops """
import functools
import numpy as np
from mindspore import ops
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.composite as C
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from ..ut_filter import non_graph_engine
from mindspore.common.api import _executor
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward\
    import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config,
            pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
from ....mindspore_test_framework.pipeline.gradient.compile_gradient\
    import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
class AssignAddNet(nn.Cell):
    def __init__(self,):
        super(AssignAddNet, self).__init__()
        self.op = P.AssignAdd()
        self.inputdata = Parameter(Tensor(np.zeros([1]).astype(np.bool_), mstype.bool_), name="assign_add1")

    def construct(self, x):
        self.op(self.inputdata, x)
        return self.inputdata


class AssignSubNet(nn.Cell):
    def __init__(self,):
        super(AssignSubNet, self).__init__()
        self.op = P.AssignSub()
        self.inputdata = Parameter(Tensor(np.zeros([1]).astype(np.bool_), mstype.bool_), name="assign_sub1")

    def construct(self, x):
        self.op(self.inputdata, x)
        return self.inputdata


class ReduceNet(nn.Cell):
    def __init__(self, op_class, keep_dims, axis):
        super(ReduceNet, self).__init__()
        self.axis = axis
        self.op = op_class(keep_dims=keep_dims)

    def construct(self, x):
        return self.op(x, self.axis)


class CumProdNet(nn.Cell):
    def __init__(self):
        super(CumProdNet, self).__init__()
        self.op = P.CumProd()

    def construct(self, x, axis):
        return self.op(x, axis)


class CumSumNet(nn.Cell):
    def __init__(self, axis):
        super(CumSumNet, self).__init__()
        self.axis = axis
        self.op = P.CumSum()

    def construct(self, x):
        return self.op(x, self.axis)
raise_set = [
# one input is scalar, and another is Tensor(float32)
('TensorAdd0', {
'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('TensorAdd1', {
'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, their shapes do not match
('TensorAdd2', {
'block': (P.TensorAdd(), {'exception': ValueError, 'error_keywords': ['TensorAdd']}),
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# check input Tensor(bool_)
('AssignAdd', {
'block': (AssignAddNet(), {'exception': TypeError, 'error_keywords': ['AssignAdd']}),
'desc_inputs': [Tensor(np.ones([1]).astype(np.bool_), mstype.bool_)],
'skip': ['backward']}),
# check input Tensor(bool_)
('AssignSub', {
'block': (AssignSubNet(), {'exception': TypeError, 'error_keywords': ['AssignSub']}),
'desc_inputs': [Tensor(np.ones([1]).astype(np.bool_), mstype.bool_)],
'skip': ['backward']}),
# type of axis is float, not int
('ReduceMean1', {
'block': (ReduceNet(P.ReduceMean, keep_dims=True, axis=5.0),
{'exception': TypeError, 'error_keywords': ['ReduceMean']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
'skip': ['backward']}),
# axis is out of range
('ReduceMean2', {
'block': (ReduceNet(P.ReduceMean, keep_dims=True, axis=5),
{'exception': ValueError, 'error_keywords': ['ReduceMean']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
'skip': ['backward']}),
# type of axis is float, not int
('ReduceSum1', {
'block': (ReduceNet(P.ReduceSum, keep_dims=True, axis=5.0),
{'exception': TypeError, 'error_keywords': ['ReduceSum']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
'skip': ['backward']}),
# axis is out of range
('ReduceSum2', {
'block': (ReduceNet(P.ReduceSum, keep_dims=True, axis=5),
{'exception': ValueError, 'error_keywords': ['ReduceSum']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
'skip': ['backward']}),
# type of axis is float, not int
('ReduceAll1', {
'block': (ReduceNet(P.ReduceAll, keep_dims=True, axis=5.0),
{'exception': TypeError, 'error_keywords': ['ReduceAll']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool_))],
'skip': ['backward']}),
# axis is out of range
('ReduceAll2', {
'block': (ReduceNet(P.ReduceAll, keep_dims=True, axis=5),
{'exception': ValueError, 'error_keywords': ['ReduceAll']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool_))],
'skip': ['backward']}),
# type of axis is float, not int
('ReduceMax1', {
'block': (ReduceNet(P.ReduceMax, keep_dims=True, axis=5.0),
{'exception': TypeError, 'error_keywords': ['ReduceMax']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
'skip': ['backward']}),
# axis is out of range
('ReduceMax2', {
'block': (ReduceNet(P.ReduceMax, keep_dims=True, axis=5),
{'exception': ValueError, 'error_keywords': ['ReduceMax']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
'skip': ['backward']}),
# type of axis is float, not int
('ReduceMin1', {
'block': (ReduceNet(P.ReduceMin, keep_dims=True, axis=5.0),
{'exception': TypeError, 'error_keywords': ['ReduceMin']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
'skip': ['backward']}),
# axis is out of range
('ReduceMin2', {
'block': (ReduceNet(P.ReduceMin, keep_dims=True, axis=5),
{'exception': ValueError, 'error_keywords': ['ReduceMin']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
'skip': ['backward']}),
# type of axis is float, not int
('ReduceProd1', {
'block': (ReduceNet(P.ReduceProd, keep_dims=True, axis=5.0),
{'exception': TypeError, 'error_keywords': ['ReduceProd']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
'skip': ['backward']}),
# axis is out of range
('ReduceProd2', {
'block': (ReduceNet(P.ReduceProd, keep_dims=True, axis=5),
{'exception': ValueError, 'error_keywords': ['ReduceProd']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
'skip': ['backward']}),
# type of x is Tensor(bool)
('CumProd1', {
'block': (CumProdNet(),
{'exception': TypeError, 'error_keywords': ['CumProd']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool)), 1],
'skip': ['backward']}),
# type of axis is float, not int
('CumProd2', {
'block': (CumProdNet(),
{'exception': TypeError, 'error_keywords': ['CumProd']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32)), 5.0],
'skip': ['backward']}),
# type of x and y are Tensor(uint32)
('MatMul1', {
'block': (P.MatMul(),
{'exception': TypeError, 'error_keywords': ['MatMul']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.uint32)), Tensor(np.ones([3, 2]).astype(np.uint32))],
'skip': ['backward']}),
# type of x and y not match
('MatMul2', {
'block': (P.MatMul(),
{'exception': TypeError, 'error_keywords': ['MatMul']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.int32))],
'skip': ['backward']}),
# shape of x and y not match
('MatMul3', {
'block': (P.MatMul(),
{'exception': ValueError, 'error_keywords': ['MatMul']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([2, 3]).astype(np.float32))],
'skip': ['backward']}),
# dims of x and y are less than 3
('BatchMatMul1', {
'block': (P.BatchMatMul(),
{'exception': ValueError, 'error_keywords': ['BatchMatMul']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32)), Tensor(np.ones([3, 2]).astype(np.int32))],
'skip': ['backward']}),
# type of x is Tensor(bool)
('CumSum1', {
'block': (CumSumNet(axis=1),
{'exception': TypeError, 'error_keywords': ['CumSum']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool))],
'skip': ['backward']}),
# type of axis is float, not int
('CumSum2', {
'block': (CumSumNet(axis=1.0),
{'exception': TypeError, 'error_keywords': ['CumSum']}),
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool))],
'skip': ['backward']}),
# input is not tuple or list
('AddN1', {
'block': (P.AddN(),
{'exception': TypeError, 'error_keywords': ['AddN']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.uint32))],
'skip': ['backward']}),
# type not match
('AddN2', {
'block': (P.AddN(),
{'exception': TypeError, 'error_keywords': ['AddN']}),
'desc_inputs': [(Tensor(np.ones([2, 3]).astype(np.uint32)), Tensor(np.ones([3, 2]).astype(np.int32)))],
'skip': ['backward']}),
# shape not match
('AddN3', {
'block': (P.AddN(),
{'exception': ValueError, 'error_keywords': ['AddN']}),
'desc_inputs': [(Tensor(np.ones([2, 3]).astype(np.int32)), Tensor(np.ones([3, 2]).astype(np.int32)))],
'skip': ['backward']}),
# input is Tensor(bool)
('Neg1', {
'block': (P.Neg(),
{'exception': TypeError, 'error_keywords': ['Neg']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Sub0', {
'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Sub1', {
'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, their shapes do not match
('Sub2', {
'block': (P.Sub(), {'exception': ValueError, 'error_keywords': ['Sub']}),
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Mul0', {
'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Mul1', {
'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, their shapes do not match
('Mul2', {
'block': (P.Mul(), {'exception': ValueError, 'error_keywords': ['Mul']}),
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input is Tensor(bool)
('Square1', {
'block': (P.Square(),
{'exception': TypeError, 'error_keywords': ['Square']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
'skip': ['backward']}),
# input is Tensor(bool)
('Rsqrt1', {
'block': (P.Rsqrt(),
{'exception': TypeError, 'error_keywords': ['Rsqrt']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
'skip': ['backward']}),
# input is Tensor(bool)
('Sqrt1', {
'block': (P.Sqrt(),
{'exception': TypeError, 'error_keywords': ['Sqrt']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
'skip': ['backward']}),
# input is not Tensor
('Reciprocal1', {
'block': (P.Reciprocal(),
{'exception': TypeError, 'error_keywords': ['Reciprocal']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input x is Tensor(bool)
('Pow1', {
'block': (P.Pow(),
{'exception': TypeError, 'error_keywords': ['Pow']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_)), 2.0],
'skip': ['backward']}),
# input is not Tensor
('Exp1', {
'block': (P.Exp(),
{'exception': TypeError, 'error_keywords': ['Exp']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input is not Tensor
('Log1', {
'block': (P.Log(),
{'exception': TypeError, 'error_keywords': ['Log']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Minimum0', {
'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Minimum1', {
'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, their shapes do not match
('Minimum2', {
'block': (P.Minimum(), {'exception': ValueError, 'error_keywords': ['Minimum']}),
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Maximum0', {
'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Maximum1', {
'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, their shapes do not match
('Maximum2', {
'block': (P.Maximum(), {'exception': ValueError, 'error_keywords': ['Maximum']}),
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('RealDiv0', {
'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('RealDiv1', {
'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, their shapes do not match
('RealDiv2', {
'block': (P.RealDiv(), {'exception': ValueError, 'error_keywords': ['RealDiv']}),
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Div0', {
'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Div1', {
'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# two input tensors, but their shapes do not match
('Div2', {
'block': (P.Div(), {'exception': ValueError, 'error_keywords': ['Div']}),
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is a scalar and the other is Tensor(float32)
('FloorDiv0', {
'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# two input tensors, but their element types are not the same
('FloorDiv1', {
'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# two input tensors, but their shapes do not match
('FloorDiv2', {
'block': (P.FloorDiv(), {'exception': ValueError, 'error_keywords': ['FloorDiv']}),
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input x is Tensor(int32), not Tensor(float)
('Floor1', {
'block': (P.Floor(),
{'exception': TypeError, 'error_keywords': ['Floor']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))],
'skip': ['backward']}),
# one input is a scalar and the other is Tensor(float32)
('FloorMod0', {
'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# two input tensors, but their element types are not the same
('FloorMod1', {
'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# two input tensors, but their shapes do not match
('FloorMod2', {
'block': (P.FloorMod(), {'exception': ValueError, 'error_keywords': ['FloorMod']}),
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input x is Tensor(bool), not Tensor(float)
('Acosh1', {
'block': (P.Acosh(),
{'exception': TypeError, 'error_keywords': ['Acosh']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
'skip': ['backward']}),
# one input is not a Tensor
('Equal0', {
'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# types of x and y do not match
('Equal1', {
'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# shapes of x and y do not match
('Equal2', {
'block': (P.Equal(), {'exception': ValueError, 'error_keywords': ['Equal']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# one input is not a Tensor
('EqualCount0', {
'block': (P.EqualCount(), {'exception': TypeError, 'error_keywords': ['EqualCount']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# types of x and y do not match
('EqualCount1', {
'block': (P.EqualCount(), {'exception': TypeError, 'error_keywords': ['EqualCount']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is not a Tensor
('NotEqual0', {
'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# types of x and y do not match
('NotEqual1', {
'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# shapes of x and y do not match
('NotEqual2', {
'block': (P.NotEqual(), {'exception': ValueError, 'error_keywords': ['NotEqual']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# one input is not a Tensor
('Greater0', {
'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# types of x and y do not match
('Greater1', {
'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# shapes of x and y do not match
('Greater2', {
'block': (P.Greater(), {'exception': ValueError, 'error_keywords': ['Greater']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# one input is not a Tensor
('GreaterEqual0', {
'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# types of x and y do not match
('GreaterEqual1', {
'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# shapes of x and y do not match
('GreaterEqual2', {
'block': (P.GreaterEqual(), {'exception': ValueError, 'error_keywords': ['GreaterEqual']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# one input is not a Tensor
('Less0', {
'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# types of x and y do not match
('Less1', {
'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# shapes of x and y do not match
('Less2', {
'block': (P.Less(), {'exception': ValueError, 'error_keywords': ['Less']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# one input is not a Tensor
('LessEqual0', {
'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# types of x and y do not match
('LessEqual1', {
'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# shapes of x and y do not match
('LessEqual2', {
'block': (P.LessEqual(), {'exception': ValueError, 'error_keywords': ['LessEqual']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# input x is not Tensor(bool)
('LogicalNot1', {
'block': (P.LogicalNot(),
{'exception': TypeError, 'error_keywords': ['LogicalNot']}),
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))],
'skip': ['backward']}),
# types of x and y do not match
('LogicalAnd1', {
'block': (P.LogicalAnd(), {'exception': TypeError, 'error_keywords': ['LogicalAnd']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.bool_))],
'skip': ['backward']}),
# shapes of x and y do not match
('LogicalAnd2', {
'block': (P.LogicalAnd(), {'exception': ValueError, 'error_keywords': ['LogicalAnd']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_)), Tensor(np.ones([3, 2]).astype(np.bool_))],
'skip': ['backward']}),
# types of x and y do not match
('LogicalOr1', {
'block': (P.LogicalOr(), {'exception': TypeError, 'error_keywords': ['LogicalOr']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.bool_))],
'skip': ['backward']}),
# shapes of x and y do not match
('LogicalOr2', {
'block': (P.LogicalOr(), {'exception': ValueError, 'error_keywords': ['LogicalOr']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_)), Tensor(np.ones([3, 2]).astype(np.bool_))],
'skip': ['backward']}),
# input is not a Tensor
('NPUGetFloatStatus0', {
'block': (P.NPUGetFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUGetFloatStatus']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input is Tensor(int32), not Tensor(float32)
('NPUGetFloatStatus1', {
'block': (P.NPUGetFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUGetFloatStatus']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))],
'skip': ['backward']}),
# number of dims is not 1
('NPUGetFloatStatus2', {
'block': (P.NPUGetFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUGetFloatStatus']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# shape[0] is not 8
('NPUGetFloatStatus3', {
'block': (P.NPUGetFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUGetFloatStatus']}),
'desc_inputs': [Tensor(np.ones([3]).astype(np.float32))],
'skip': ['backward']}),
# input is not a Tensor
('NPUClearFloatStatus0', {
'block': (P.NPUClearFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUClearFloatStatus']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input is Tensor(int32), not Tensor(float32)
('NPUClearFloatStatus1', {
'block': (P.NPUClearFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUClearFloatStatus']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))],
'skip': ['backward']}),
# number of dims is not 1
('NPUClearFloatStatus2', {
'block': (P.NPUClearFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUClearFloatStatus']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# shape[0] is not 8
('NPUClearFloatStatus3', {
'block': (P.NPUClearFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUClearFloatStatus']}),
'desc_inputs': [Tensor(np.ones([3]).astype(np.float32))],
'skip': ['backward']}),
# input is not a Tensor
('Cos0', {
'block': (P.Cos(), {'exception': TypeError, 'error_keywords': ['Cos']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input is Tensor(bool)
('Cos1', {
'block': (P.Cos(), {'exception': TypeError, 'error_keywords': ['Cos']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
'skip': ['backward']}),
# input is not a Tensor
('ACos0', {
'block': (P.ACos(), {'exception': TypeError, 'error_keywords': ['ACos']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input is Tensor(bool)
('ACos1', {
'block': (P.ACos(), {'exception': TypeError, 'error_keywords': ['ACos']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
'skip': ['backward']}),
# input is not a Tensor
('Sin0', {
'block': (P.Sin(), {'exception': TypeError, 'error_keywords': ['Sin']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input is Tensor(bool)
('Sin1', {
'block': (P.Sin(), {'exception': TypeError, 'error_keywords': ['Sin']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
'skip': ['backward']}),
# input is not a Tensor
('NMSWithMask0', {
'block': (P.NMSWithMask(), {'exception': TypeError, 'error_keywords': ['NMSWithMask']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input is not Tensor(float16) or Tensor(float32)
('NMSWithMask1', {
'block': (P.NMSWithMask(), {'exception': TypeError, 'error_keywords': ['NMSWithMask']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))],
'skip': ['backward']}),
# number of dims is not 2
('NMSWithMask2', {
'block': (P.NMSWithMask(), {'exception': ValueError, 'error_keywords': ['NMSWithMask']}),
'desc_inputs': [Tensor(np.ones([3, 4, 2]).astype(np.float32))],
'skip': ['backward']}),
# shape[1] is not 5
('NMSWithMask3', {
'block': (P.NMSWithMask(), {'exception': ValueError, 'error_keywords': ['NMSWithMask']}),
'desc_inputs': [Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# input is not a Tensor
('Abs0', {
'block': (P.Abs(), {'exception': TypeError, 'error_keywords': ['Abs']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input is Tensor(bool)
('Abs1', {
'block': (P.Abs(), {'exception': TypeError, 'error_keywords': ['Abs']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
'skip': ['backward']}),
# input is not a Tensor
('Sign0', {
'block': (P.Sign(), {'exception': TypeError, 'error_keywords': ['Sign']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input is Tensor(bool)
('Sign1', {
'block': (P.Sign(), {'exception': TypeError, 'error_keywords': ['Sign']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
'skip': ['backward']}),
# input is not a Tensor
('Round0', {
'block': (P.Round(), {'exception': TypeError, 'error_keywords': ['Round']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input is Tensor(bool)
('Round1', {
'block': (P.Round(), {'exception': TypeError, 'error_keywords': ['Round']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
'skip': ['backward']}),
# one input is a scalar and the other is Tensor(float32)
('Atan20', {
'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# two input tensors, but their element types are not the same
('Atan21', {
'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# two input tensors, but their shapes do not match
('Atan22', {
'block': (P.Atan2(), {'exception': ValueError, 'error_keywords': ['Atan2']}),
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
]


@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
def test_check_exception():
return raise_set
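

# A minimal sketch (assumed semantics, not the real mindspore_test pipeline) of
# how one `raise_set` entry can be exercised by hand: construct the primitive,
# call it on the described inputs, and assert that the configured exception is
# raised with a message mentioning every expected error keyword. The helper
# name `_check_entry_by_hand` is hypothetical and only illustrates the schema.
def _check_entry_by_hand(entry):
    name, cfg = entry
    prim, expect = cfg['block']  # (primitive instance, expectation dict)
    try:
        prim(*cfg['desc_inputs'])  # direct call; assumes PyNative-style Primitive.__call__
    except expect['exception'] as err:
        # every configured keyword must appear in the raised message
        assert all(kw in str(err) for kw in expect['error_keywords']), \
            f'{name}: {str(err)!r} is missing an expected keyword'
    else:
        raise AssertionError(f"{name} did not raise {expect['exception'].__name__}")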