forked from mindspore-Ecosystem/mindspore
Add primitive name to param error message for math_ops.py
This commit is contained in:
parent
cc0ba93d17
commit
69ed72f10d
|
@ -15,6 +15,7 @@
|
|||
"""Check parameters."""
|
||||
import re
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from itertools import repeat
|
||||
from collections import Iterable
|
||||
|
||||
|
@ -93,8 +94,131 @@ rel_strs = {
|
|||
}
|
||||
|
||||
|
||||
class Validator:
|
||||
"""validator for checking input parameters"""
|
||||
|
||||
@staticmethod
|
||||
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
|
||||
"""
|
||||
Method for judging relation between two int values or list/tuple made up of ints.
|
||||
|
||||
This method is not suitable for judging relation between floats, since it does not consider float error.
|
||||
"""
|
||||
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
if not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
|
||||
msg_prefix = f'For {prim_name} the' if prim_name else "The"
|
||||
raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
|
||||
|
||||
@staticmethod
|
||||
def check_integer(arg_name, arg_value, value, rel, prim_name):
|
||||
"""Integer value judgment."""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
|
||||
if type_mismatch or not rel_fn(arg_value, value):
|
||||
rel_str = Rel.get_strs(rel).format(value)
|
||||
raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
|
||||
f' but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
|
||||
"""Method for checking whether an int value is in some range."""
|
||||
rel_fn = Rel.get_fns(rel)
|
||||
type_mismatch = not isinstance(arg_value, int)
|
||||
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
|
||||
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
|
||||
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
|
||||
f' but got {arg_value}.')
|
||||
return arg_value
|
||||
|
||||
@staticmethod
|
||||
def check_subclass(arg_name, type_, template_type, prim_name):
|
||||
"""Check whether some type is sublcass of another type"""
|
||||
if not isinstance(template_type, Iterable):
|
||||
template_type = (template_type,)
|
||||
if not any([mstype.issubclass_(type_, x) for x in template_type]):
|
||||
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
|
||||
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
|
||||
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
|
||||
|
||||
@staticmethod
|
||||
def check_tensor_type_same(args, valid_values, prim_name):
|
||||
"""check whether the element types of input tensors are the same."""
|
||||
def _check_tensor_type(arg):
|
||||
arg_key, arg_val = arg
|
||||
Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
|
||||
elem_type = arg_val.element_type()
|
||||
if not elem_type in valid_values:
|
||||
raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
|
||||
f' but `{arg_key}` is {elem_type}.')
|
||||
return (arg_key, elem_type)
|
||||
|
||||
def _check_types_same(arg1, arg2):
|
||||
arg1_name, arg1_type = arg1
|
||||
arg2_name, arg2_type = arg2
|
||||
if arg1_type != arg2_type:
|
||||
raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,'
|
||||
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
||||
return arg1
|
||||
|
||||
elem_types = map(_check_tensor_type, args.items())
|
||||
reduce(_check_types_same, elem_types)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
|
||||
"""check whether the types of inputs are the same. if the input args are tensors, check their element types"""
|
||||
def _check_argument_type(arg):
|
||||
arg_key, arg_val = arg
|
||||
if isinstance(arg_val, type(mstype.tensor)):
|
||||
arg_val = arg_val.element_type()
|
||||
if not arg_val in valid_values:
|
||||
raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
|
||||
f' but `{arg_key}` is {arg_val}.')
|
||||
return arg
|
||||
|
||||
def _check_types_same(arg1, arg2):
|
||||
arg1_name, arg1_type = arg1
|
||||
arg2_name, arg2_type = arg2
|
||||
excp_flag = False
|
||||
if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
|
||||
arg1_type = arg1_type.element_type()
|
||||
arg2_type = arg2_type.element_type()
|
||||
elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
|
||||
pass
|
||||
else:
|
||||
excp_flag = True
|
||||
|
||||
if excp_flag or arg1_type != arg2_type:
|
||||
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
|
||||
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
|
||||
return arg1
|
||||
reduce(_check_types_same, map(_check_argument_type, args.items()))
|
||||
|
||||
@staticmethod
|
||||
def check_value_type(arg_name, arg_value, valid_types, prim_name):
|
||||
"""Check whether a values is instance of some types."""
|
||||
def raise_error_msg():
|
||||
"""func for raising error message when check failed"""
|
||||
type_names = [t.__name__ for t in valid_types]
|
||||
num_types = len(valid_types)
|
||||
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
|
||||
f'{"one of " if num_types > 1 else ""}'
|
||||
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
|
||||
|
||||
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
|
||||
# `check_value_type('x', True, [bool, int])` will check pass
|
||||
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
|
||||
raise_error_msg()
|
||||
if isinstance(arg_value, tuple(valid_types)):
|
||||
return arg_value
|
||||
raise_error_msg()
|
||||
|
||||
|
||||
class ParamValidator:
|
||||
"""Parameter validator."""
|
||||
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
|
||||
|
||||
@staticmethod
|
||||
def equal(arg_name, arg_value, cond_str, cond):
|
||||
|
|
|
@ -16,13 +16,14 @@
|
|||
"""broadcast"""
|
||||
|
||||
|
||||
def _get_broadcast_shape(x_shape, y_shape):
|
||||
def _get_broadcast_shape(x_shape, y_shape, prim_name):
|
||||
"""
|
||||
Doing broadcast between tensor x and tensor y.
|
||||
|
||||
Args:
|
||||
x_shape (list): The shape of tensor x.
|
||||
y_shape (list): The shape of tensor y.
|
||||
prim_name (str): Primitive name.
|
||||
|
||||
Returns:
|
||||
List, the shape that broadcast between tensor x and tensor y.
|
||||
|
@ -50,7 +51,8 @@ def _get_broadcast_shape(x_shape, y_shape):
|
|||
elif x_shape[i] == y_shape[i]:
|
||||
broadcast_shape_back.append(x_shape[i])
|
||||
else:
|
||||
raise ValueError("The x_shape {} and y_shape {} can not broadcast.".format(x_shape, y_shape))
|
||||
raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format(
|
||||
prim_name, x_shape, y_shape))
|
||||
|
||||
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
||||
broadcast_shape = broadcast_shape_front + broadcast_shape_back
|
||||
|
|
|
@ -28,9 +28,16 @@ from ..._checkparam import ParamValidator as validator
|
|||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ...common.tensor import Tensor
|
||||
from ..operations.math_ops import _check_infer_attr_reduce, _infer_shape_reduce
|
||||
from ..operations.math_ops import _infer_shape_reduce
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
|
||||
def _check_infer_attr_reduce(axis, keep_dims):
|
||||
validator.check_type('keep_dims', keep_dims, [bool])
|
||||
validator.check_type('axis', axis, [int, tuple])
|
||||
if isinstance(axis, tuple):
|
||||
for index, value in enumerate(axis):
|
||||
validator.check_type('axis[%d]' % index, value, [int])
|
||||
|
||||
|
||||
class ExpandDims(PrimitiveWithInfer):
|
||||
"""
|
||||
|
@ -1091,7 +1098,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
|
|||
axis = self.axis
|
||||
x_rank = len(x_shape)
|
||||
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
|
||||
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims)
|
||||
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
|
||||
return ouput_shape, ouput_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
|
@ -1137,7 +1144,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
|
|||
axis = self.axis
|
||||
x_rank = len(x_shape)
|
||||
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
|
||||
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims)
|
||||
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
|
||||
return ouput_shape, ouput_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
|
|
|
@ -19,7 +19,7 @@ import numpy as np
|
|||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ...common.tensor import Tensor
|
||||
|
@ -27,16 +27,16 @@ from .._utils import _get_broadcast_shape
|
|||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
|
||||
|
||||
def _infer_shape_reduce(x, axis, keep_dims):
|
||||
def _infer_shape_reduce(x, axis, keep_dims, prim_name):
|
||||
"""Common infer for reduce operator"""
|
||||
|
||||
def reduce_one_axis(one_axis):
|
||||
validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT)
|
||||
validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT, prim_name)
|
||||
if one_axis < 0:
|
||||
one_axis += dim
|
||||
axis_reduce.add(one_axis)
|
||||
|
||||
validator.check_type('axis', axis, [int, tuple, list])
|
||||
validator.check_value_type('axis', axis, [int, tuple, list], prim_name)
|
||||
dim = len(x)
|
||||
axis_reduce = set()
|
||||
|
||||
|
@ -48,7 +48,7 @@ def _infer_shape_reduce(x, axis, keep_dims):
|
|||
return [1] * dim
|
||||
return []
|
||||
for index, one_axis in enumerate(axis):
|
||||
validator.check_type('axis[%d]' % index, one_axis, [int])
|
||||
validator.check_value_type('axis[%d]' % index, one_axis, [int], prim_name)
|
||||
reduce_one_axis(one_axis)
|
||||
|
||||
out_shape = []
|
||||
|
@ -61,14 +61,6 @@ def _infer_shape_reduce(x, axis, keep_dims):
|
|||
return out_shape
|
||||
|
||||
|
||||
def _check_infer_attr_reduce(axis, keep_dims):
|
||||
validator.check_type('keep_dims', keep_dims, [bool])
|
||||
validator.check_type('axis', axis, [int, tuple])
|
||||
if isinstance(axis, tuple):
|
||||
for index, value in enumerate(axis):
|
||||
validator.check_type('axis[%d]' % index, value, [int])
|
||||
|
||||
|
||||
class _BinaryOp(PrimitiveWithInfer):
|
||||
"""
|
||||
Define binary operators.
|
||||
|
@ -82,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
return _get_broadcast_shape(x_shape, y_shape)
|
||||
return _get_broadcast_shape(x_shape, y_shape, self.prim_name())
|
||||
|
||||
|
||||
class _MathBinaryOp(_BinaryOp):
|
||||
|
@ -91,15 +83,13 @@ class _MathBinaryOp(_BinaryOp):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type):
|
||||
def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None):
|
||||
args_type = {"x": x_dtype, "y": y_dtype}
|
||||
validator.check_args_tensor(args_type)
|
||||
args_dtype = {"x_dtype": x_dtype, "y_dtype": y_dtype}
|
||||
validator.check_type_same(args_dtype, valid_dtype)
|
||||
validator.check_tensor_type_same(args_type, valid_dtype, prim_name)
|
||||
return x_dtype
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype)
|
||||
return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name())
|
||||
|
||||
|
||||
class TensorAdd(_MathBinaryOp):
|
||||
|
@ -166,7 +156,7 @@ class AssignAdd(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, variable, value):
|
||||
args = {"value": value}
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
|
||||
return value
|
||||
|
||||
|
||||
|
@ -207,7 +197,7 @@ class AssignSub(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, variable, value):
|
||||
args = {"value": value}
|
||||
validator.check_type_same(args, mstype.number_type)
|
||||
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
|
||||
return value
|
||||
|
||||
|
||||
|
@ -228,15 +218,16 @@ class _Reduce(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, keep_dims=False):
|
||||
"""init Reduce"""
|
||||
validator.check_type('keep_dims', keep_dims, [bool])
|
||||
validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name())
|
||||
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
|
||||
|
||||
def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
|
||||
axis_v = axis['value']
|
||||
input_shp = input_x['shape']
|
||||
validator.check_subclass('input_x', input_x['dtype'], mstype.tensor)
|
||||
validator.check_typename('input_x', input_x['dtype'], valid_dtype)
|
||||
input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims)
|
||||
args = {'input_x': input_x['dtype']}
|
||||
validator.check_tensor_type_same(args, valid_dtype, self.prim_name())
|
||||
|
||||
input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name())
|
||||
return {'shape': input_shp,
|
||||
'dtype': input_x['dtype'],
|
||||
'value': None}
|
||||
|
@ -471,16 +462,17 @@ class CumProd(PrimitiveWithInfer):
|
|||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, exclusive=False, reverse=False):
|
||||
self.exclusive = validator.check_type("exclusive", exclusive, [bool])
|
||||
self.reverse = validator.check_type("reverse", reverse, [bool])
|
||||
cls_name = self.prim_name()
|
||||
self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
|
||||
self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
|
||||
|
||||
def infer_shape(self, x_shape, axis_shape):
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type, axis_type):
|
||||
validator.check_subclass('x_type', x_type, mstype.tensor)
|
||||
validator.check_typename('x_type', x_type, mstype.number_type)
|
||||
validator.check_subclass("axis_type", axis_type, mstype.int_)
|
||||
cls_name = self.prim_name()
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
|
||||
validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
|
||||
return x_type
|
||||
|
||||
|
||||
|
@ -514,8 +506,9 @@ class MatMul(PrimitiveWithInfer):
|
|||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||
self.__setattr_flag__ = True
|
||||
validator.check_type("transpose_a", transpose_a, [bool])
|
||||
validator.check_type("transpose_b", transpose_b, [bool])
|
||||
cls_name = self.prim_name()
|
||||
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||
|
||||
def check_shape_size(self, x, y):
|
||||
if len(x) != 2 or len(y) != 2:
|
||||
|
@ -524,11 +517,11 @@ class MatMul(PrimitiveWithInfer):
|
|||
|
||||
def infer_shape(self, x, y):
|
||||
self.check_shape_size(x, y)
|
||||
cls_name = self.__class__.__name__
|
||||
cls_name = self.prim_name()
|
||||
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
|
||||
for i in range(len(x) - 2):
|
||||
if x[i] != y[i]:
|
||||
raise ValueError(f'{cls_name} shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}')
|
||||
raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}')
|
||||
|
||||
# validate whether last two dims satifing matrix multiply
|
||||
x_last = x[-2:]
|
||||
|
@ -537,8 +530,8 @@ class MatMul(PrimitiveWithInfer):
|
|||
x_col = x_last[not self.transpose_a] # x_col = x_last[1] if (not transpose_a) else x_last[0]
|
||||
y_row = y_last[self.transpose_b] # y_row = y_last[0] if (not transpose_b) else y_last[1]
|
||||
if x_col != y_row:
|
||||
raise ValueError(f'{cls_name} evaluator shapes of inputs can not do this operator, got {x_col} and {y_row}'
|
||||
+ f' for {cls_name}, with x shape {x}(transpose_a={self.transpose_a})'
|
||||
raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
|
||||
+ f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})'
|
||||
+ f', y shape {y}(transpose_b={self.transpose_b}).')
|
||||
# set attribute
|
||||
self.add_prim_attr('transpose_x1', self.transpose_a)
|
||||
|
@ -548,10 +541,8 @@ class MatMul(PrimitiveWithInfer):
|
|||
return ret_dims
|
||||
|
||||
def infer_dtype(self, x, y):
|
||||
validator.check_subclass("x", x, mstype.tensor)
|
||||
validator.check_subclass("y", y, mstype.tensor)
|
||||
args = {"x dtype": x, "y dtype": y}
|
||||
validator.check_type_same(args, mstype.float_type + mstype.int_type)
|
||||
args = {"x": x, "y": y}
|
||||
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name())
|
||||
return x
|
||||
|
||||
|
||||
|
@ -595,12 +586,13 @@ class BatchMatMul(MatMul):
|
|||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||
self.__setattr_flag__ = True
|
||||
validator.check_type("transpose_a", transpose_a, [bool])
|
||||
validator.check_type("transpose_b", transpose_b, [bool])
|
||||
cls_name = self.prim_name()
|
||||
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||
|
||||
def check_shape_size(self, x, y):
|
||||
if len(x) != len(y) or len(x) < 3:
|
||||
raise ValueError('BatchMatMul input x, y should be the same dimension size and should be '
|
||||
raise ValueError('For \'BatchMatMul\' input x, y should be the same dimension size and should be '
|
||||
'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}')
|
||||
|
||||
|
||||
|
@ -632,18 +624,17 @@ class CumSum(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, exclusive=False, reverse=False):
|
||||
"""init cumsum"""
|
||||
self.exclusive = validator.check_type('exclusive', exclusive, [bool])
|
||||
self.add_prim_attr("exclusive", self.exclusive)
|
||||
self.reverse = validator.check_type('reverse', reverse, [bool])
|
||||
self.add_prim_attr("reverse", self.reverse)
|
||||
cls_name = self.prim_name()
|
||||
validator.check_value_type('exclusive', exclusive, [bool], cls_name)
|
||||
validator.check_value_type('reverse', reverse, [bool], cls_name)
|
||||
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])
|
||||
|
||||
def __infer__(self, x, axis):
|
||||
cls_name = self.prim_name()
|
||||
x_shp = x['shape']
|
||||
validator.check_type('axis', axis['value'], [int])
|
||||
validator.check_subclass('x', x['dtype'], mstype.tensor)
|
||||
validator.check_typename('x', x['dtype'], [mstype.uint8, mstype.int8,
|
||||
mstype.int32, mstype.float16, mstype.float32])
|
||||
validator.check_value_type('axis', axis['value'], [int], cls_name)
|
||||
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
|
||||
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name)
|
||||
return {'shape': x_shp,
|
||||
'dtype': x['dtype'],
|
||||
'value': None}
|
||||
|
@ -684,21 +675,22 @@ class AddN(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
|
||||
|
||||
def infer_shape(self, inputs):
|
||||
validator.check_integer("inputs", len(inputs), 1, Rel.GE)
|
||||
cls_name = self.prim_name()
|
||||
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
|
||||
self.add_prim_attr('n', len(inputs))
|
||||
shp0 = inputs[0]
|
||||
for i, shp in enumerate(inputs):
|
||||
validator.check(f"shape of inputs[{i}]", shp, 'shape of inputs[0]', shp0)
|
||||
validator.check(f"shape of inputs[{i}]", shp, 'shape of inputs[0]', shp0, Rel.EQ, cls_name)
|
||||
return shp0
|
||||
|
||||
def infer_dtype(self, inputs):
|
||||
validator.check_type("inputs", inputs, [tuple, list])
|
||||
validator.check_integer("inputs", len(inputs), 1, Rel.GE)
|
||||
cls_name = self.prim_name()
|
||||
validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
|
||||
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
|
||||
args = {}
|
||||
for i, dtype in enumerate(inputs):
|
||||
validator.check_subclass(f"inputs[{i}]", dtype, mstype.tensor)
|
||||
args[f"inputs[{i}]"] = dtype
|
||||
validator.check_type_same(args, mstype.number_type + (mstype.bool_,))
|
||||
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name)
|
||||
return inputs[0]
|
||||
|
||||
|
||||
|
@ -722,8 +714,7 @@ class Neg(PrimitiveWithInfer):
|
|||
return input_x
|
||||
|
||||
def infer_dtype(self, input_x):
|
||||
validator.check_subclass("input_x", input_x, mstype.tensor)
|
||||
validator.check_typename("input_x", input_x, mstype.number_type)
|
||||
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name())
|
||||
return input_x
|
||||
|
||||
|
||||
|
@ -806,8 +797,7 @@ class Square(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass("x", x_type, mstype.tensor)
|
||||
validator.check_typename("x_dtype", x_type, mstype.number_type)
|
||||
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
|
||||
return x_type
|
||||
|
||||
|
||||
|
@ -836,8 +826,7 @@ class Rsqrt(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass("x", x_type, mstype.tensor)
|
||||
validator.check_typename("x_dtype", x_type, mstype.number_type)
|
||||
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
|
||||
return x_type
|
||||
|
||||
|
||||
|
@ -866,8 +855,7 @@ class Sqrt(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass("x", x_type, mstype.tensor)
|
||||
validator.check_typename("x_dtype", x_type, mstype.number_type)
|
||||
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
|
||||
return x_type
|
||||
|
||||
|
||||
|
@ -897,7 +885,7 @@ class Reciprocal(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_subclass("x", x, mstype.tensor)
|
||||
validator.check_subclass("x", x, mstype.tensor, self.prim_name())
|
||||
return x
|
||||
|
||||
|
||||
|
@ -935,8 +923,7 @@ class Pow(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x, power):
|
||||
validator.check_subclass("x", x, mstype.tensor)
|
||||
validator.check_typename("power", power, mstype.number_type)
|
||||
validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name())
|
||||
return x
|
||||
|
||||
|
||||
|
@ -966,7 +953,7 @@ class Exp(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass("x", x_type, mstype.tensor)
|
||||
validator.check_subclass("x", x_type, mstype.tensor, self.prim_name())
|
||||
return x_type
|
||||
|
||||
|
||||
|
@ -995,7 +982,7 @@ class Log(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_subclass("x", x, mstype.tensor)
|
||||
validator.check_subclass("x", x, mstype.tensor, self.prim_name())
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1178,8 +1165,7 @@ class Floor(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_subclass("x", x_dtype, mstype.tensor)
|
||||
validator.check_typename("x_dtype", x_dtype, mstype.float_type)
|
||||
validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name())
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -1234,8 +1220,7 @@ class Acosh(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_subclass("x_dtype", x, mstype.tensor)
|
||||
validator.check_typename('x_dtype', x, mstype.number_type)
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1245,15 +1230,13 @@ class _LogicBinaryOp(_BinaryOp):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type):
|
||||
args_type = {"x": x_dtype, "y": y_dtype}
|
||||
validator.check_args_tensor(args_type)
|
||||
args_dtype = {"x_dtype": x_dtype, "y_dtype": y_dtype}
|
||||
validator.check_type_same(args_dtype, valid_type)
|
||||
def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None):
|
||||
args_dtype = {"x": x_dtype, "y": y_dtype}
|
||||
validator.check_tensor_type_same(args_dtype, valid_type, prim_name)
|
||||
return mstype.tensor_type(mstype.bool_)
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype)
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name())
|
||||
|
||||
|
||||
class Equal(_LogicBinaryOp):
|
||||
|
@ -1289,7 +1272,7 @@ class Equal(_LogicBinaryOp):
|
|||
"""
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,))
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())
|
||||
|
||||
|
||||
class EqualCount(PrimitiveWithInfer):
|
||||
|
@ -1318,11 +1301,13 @@ class EqualCount(PrimitiveWithInfer):
|
|||
"""init EqualCount"""
|
||||
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
|
||||
|
||||
def infer_shape(self, x_shape, w_shape):
|
||||
def infer_shape(self, x_shape, y_shape):
|
||||
output_shape = (1,)
|
||||
return output_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, w_dtype):
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
args = {'x': x_dtype, 'y': y_dtype}
|
||||
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name())
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -1359,7 +1344,7 @@ class NotEqual(_LogicBinaryOp):
|
|||
"""
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,))
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())
|
||||
|
||||
|
||||
class Greater(_LogicBinaryOp):
|
||||
|
@ -1495,8 +1480,7 @@ class LogicalNot(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_subclass("x", x_dtype, mstype.tensor)
|
||||
validator.check_typename("x_dtype", x_dtype, [mstype.bool_])
|
||||
validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name())
|
||||
return mstype.tensor_type(mstype.bool_)
|
||||
|
||||
|
||||
|
@ -1526,7 +1510,7 @@ class LogicalAnd(_LogicBinaryOp):
|
|||
"""
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,))
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
|
||||
|
||||
|
||||
class LogicalOr(_LogicBinaryOp):
|
||||
|
@ -1555,7 +1539,7 @@ class LogicalOr(_LogicBinaryOp):
|
|||
"""
|
||||
|
||||
def infer_dtype(self, x_dtype, y_dtype):
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,))
|
||||
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
|
||||
|
||||
|
||||
class NPUAllocFloatStatus(PrimitiveWithInfer):
|
||||
|
@ -1616,13 +1600,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
|
|||
self.add_prim_attr("_side_effect_flag", True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ)
|
||||
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ)
|
||||
cls_name = self.prim_name()
|
||||
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
|
||||
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
|
||||
return [8]
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
args = {"x_dtype": x_dtype}
|
||||
validator.check_type_same(args, [mstype.float32])
|
||||
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
|
||||
return mstype.float32
|
||||
|
||||
|
||||
|
@ -1658,13 +1642,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
|
|||
self.add_prim_attr("_side_effect_flag", True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ)
|
||||
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ)
|
||||
cls_name = self.prim_name()
|
||||
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
|
||||
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
|
||||
return [8]
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
args = {"x_dtype": x_dtype}
|
||||
validator.check_type_same(args, [mstype.float32])
|
||||
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
|
||||
return mstype.float32
|
||||
|
||||
|
||||
|
@ -1692,8 +1676,7 @@ class Cos(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_subclass("x_dtype", x, mstype.tensor)
|
||||
validator.check_typename('x_dtype', x, mstype.number_type)
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1721,8 +1704,7 @@ class ACos(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_subclass("x_dtype", x, mstype.tensor)
|
||||
validator.check_typename('x_dtype', x, mstype.number_type)
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1750,8 +1732,7 @@ class Sin(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
def infer_dtype(self, x):
|
||||
validator.check_subclass("x_dtype", x, mstype.tensor)
|
||||
validator.check_typename('x_dtype', x, mstype.number_type)
|
||||
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1796,19 +1777,19 @@ class NMSWithMask(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, iou_threshold=0.5):
|
||||
"""Init NMSWithMask"""
|
||||
validator.check_type("iou_threshold", iou_threshold, [float])
|
||||
validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name())
|
||||
self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask'])
|
||||
|
||||
def infer_shape(self, bboxes_shape):
|
||||
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ)
|
||||
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT)
|
||||
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ)
|
||||
cls_name = self.prim_name()
|
||||
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
|
||||
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
|
||||
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
|
||||
num = bboxes_shape[0]
|
||||
return (bboxes_shape, (num,), (num,))
|
||||
|
||||
def infer_dtype(self, bboxes_dtype):
|
||||
validator.check_subclass("bboxes_dtype", bboxes_dtype, mstype.tensor)
|
||||
validator.check_typename("bboxes_dtype", bboxes_dtype, [mstype.float16, mstype.float32])
|
||||
validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name())
|
||||
return (bboxes_dtype, mstype.int32, mstype.bool_)
|
||||
|
||||
|
||||
|
@ -1837,8 +1818,7 @@ class Abs(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass("x_dtype", x_type, mstype.tensor)
|
||||
validator.check_typename('x_dtype', x_type, mstype.number_type)
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
|
||||
return x_type
|
||||
|
||||
def infer_value(self, x):
|
||||
|
@ -1880,8 +1860,7 @@ class Sign(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_subclass('x', x_dtype, mstype.tensor)
|
||||
validator.check_typename('x_dtype', x_dtype, mstype.number_type)
|
||||
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name())
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -1910,8 +1889,7 @@ class Round(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass("x_dtype", x_type, mstype.tensor)
|
||||
validator.check_typename('x_dtype', x_type, mstype.number_type)
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
|
||||
return x_type
|
||||
|
||||
|
||||
|
|
|
@ -194,6 +194,9 @@ class PrimitiveWithInfer(Primitive):
|
|||
Primitive.__init__(self, name)
|
||||
self.set_prim_type(prim_type.py_infer_shape)
|
||||
|
||||
def prim_name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def _clone(self):
|
||||
"""
|
||||
Deeply clones the primitive object.
|
||||
|
|
|
@ -23,20 +23,25 @@ from ...utils import keyword
|
|||
|
||||
class CheckExceptionsEC(IExectorComponent):
|
||||
"""
|
||||
Check if the function raises the expected Exception.
|
||||
Check if the function raises the expected Exception and the error message contains specified keywords if not None.
|
||||
|
||||
Examples:
|
||||
{
|
||||
'block': f,
|
||||
'exception': Exception
|
||||
'exception': Exception,
|
||||
'error_keywords': ['TensorAdd', 'shape']
|
||||
}
|
||||
"""
|
||||
def run_function(self, function, inputs, verification_set):
|
||||
f = function[keyword.block]
|
||||
args = inputs[keyword.desc_inputs]
|
||||
e = function.get(keyword.exception, Exception)
|
||||
error_kws = function.get(keyword.error_keywords, None)
|
||||
try:
|
||||
with pytest.raises(e):
|
||||
with pytest.raises(e) as exec_info:
|
||||
f(*args)
|
||||
except:
|
||||
raise Exception(f"Expect {e}, but got {sys.exc_info()[0]}")
|
||||
if error_kws and any(keyword not in str(exec_info.value) for keyword in error_kws):
|
||||
raise ValueError('Error message `{}` does not contain all keywords `{}`'.format(
|
||||
str(exec_info.value), error_kws))
|
||||
|
|
|
@ -87,8 +87,9 @@ def get_function_config(function):
|
|||
init_param_with = function.get(keyword.init_param_with, None)
|
||||
split_outputs = function.get(keyword.split_outputs, True)
|
||||
exception = function.get(keyword.exception, Exception)
|
||||
error_keywords = function.get(keyword.error_keywords, None)
|
||||
return delta, max_error, input_selector, output_selector, sampling_times, \
|
||||
reduce_output, init_param_with, split_outputs, exception
|
||||
reduce_output, init_param_with, split_outputs, exception, error_keywords
|
||||
|
||||
def get_grad_checking_options(function, inputs):
|
||||
"""
|
||||
|
@ -104,6 +105,6 @@ def get_grad_checking_options(function, inputs):
|
|||
"""
|
||||
f = function[keyword.block]
|
||||
args = inputs[keyword.desc_inputs]
|
||||
delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _ = \
|
||||
delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _, _ = \
|
||||
get_function_config(function)
|
||||
return f, args, delta, max_error, input_selector, output_selector, sampling_times, reduce_output
|
||||
|
|
|
@ -54,11 +54,12 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
|
|||
|
||||
block = block_config
|
||||
delta, max_error, input_selector, output_selector, \
|
||||
sampling_times, reduce_output, init_param_with, split_outputs, exception = get_function_config({})
|
||||
sampling_times, reduce_output, init_param_with, split_outputs, exception, error_keywords = get_function_config({})
|
||||
if isinstance(block_config, tuple) and isinstance(block_config[-1], dict):
|
||||
block = block_config[0]
|
||||
delta, max_error, input_selector, output_selector, \
|
||||
sampling_times, reduce_output, init_param_with, split_outputs, exception = get_function_config(block_config[-1])
|
||||
sampling_times, reduce_output, init_param_with, \
|
||||
split_outputs, exception, error_keywords = get_function_config(block_config[-1])
|
||||
|
||||
if block:
|
||||
func_list.append({
|
||||
|
@ -78,7 +79,8 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
|
|||
keyword.const_first: const_first,
|
||||
keyword.add_fake_input: add_fake_input,
|
||||
keyword.split_outputs: split_outputs,
|
||||
keyword.exception: exception
|
||||
keyword.exception: exception,
|
||||
keyword.error_keywords: error_keywords
|
||||
})
|
||||
|
||||
if desc_inputs or desc_const:
|
||||
|
|
|
@ -73,5 +73,6 @@ keyword.const_first = "const_first"
|
|||
keyword.add_fake_input = "add_fake_input"
|
||||
keyword.fake_input_type = "fake_input_type"
|
||||
keyword.exception = "exception"
|
||||
keyword.error_keywords = "error_keywords"
|
||||
|
||||
sys.modules[__name__] = keyword
|
||||
|
|
|
@ -234,7 +234,7 @@ raise_set = [
|
|||
'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': ValueError}),
|
||||
'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
|
||||
('ReduceSum_Error', {
|
||||
'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': ValueError}),
|
||||
'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
|
||||
]
|
||||
|
||||
|
|
|
@ -0,0 +1,751 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test ops """
|
||||
import functools
|
||||
import numpy as np
|
||||
from mindspore import ops
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
import mindspore.ops.composite as C
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
from ..ut_filter import non_graph_engine
|
||||
from mindspore.common.api import _executor
|
||||
|
||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from ....mindspore_test_framework.pipeline.forward.compile_forward\
|
||||
import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config,
|
||||
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
|
||||
from ....mindspore_test_framework.pipeline.gradient.compile_gradient\
|
||||
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
|
||||
|
||||
|
||||
class AssignAddNet(nn.Cell):
|
||||
def __init__(self,):
|
||||
super(AssignAddNet, self).__init__()
|
||||
self.op = P.AssignAdd()
|
||||
self.inputdata = Parameter(Tensor(np.zeros([1]).astype(np.bool_), mstype.bool_), name="assign_add1")
|
||||
|
||||
def construct(self, x):
|
||||
self.op(self.inputdata, x)
|
||||
return self.inputdata
|
||||
|
||||
|
||||
class AssignSubNet(nn.Cell):
|
||||
def __init__(self,):
|
||||
super(AssignSubNet, self).__init__()
|
||||
self.op = P.AssignSub()
|
||||
self.inputdata = Parameter(Tensor(np.zeros([1]).astype(np.bool_), mstype.bool_), name="assign_sub1")
|
||||
|
||||
def construct(self, x):
|
||||
self.op(self.inputdata, x)
|
||||
return self.inputdata
|
||||
|
||||
|
||||
class ReduceNet(nn.Cell):
|
||||
def __init__(self, op_class, keep_dims, axis):
|
||||
super(ReduceNet, self).__init__()
|
||||
self.axis = axis
|
||||
self.op = op_class(keep_dims=keep_dims)
|
||||
|
||||
def construct(self, x):
|
||||
return self.op(x, self.axis)
|
||||
|
||||
|
||||
class CumProdNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CumProdNet, self).__init__()
|
||||
self.op = P.CumProd()
|
||||
|
||||
def construct(self, x, axis):
|
||||
return self.op(x, axis)
|
||||
|
||||
|
||||
class CumSumNet(nn.Cell):
|
||||
def __init__(self, axis):
|
||||
super(CumSumNet, self).__init__()
|
||||
self.axis = axis
|
||||
self.op = P.CumSum()
|
||||
|
||||
def construct(self, x):
|
||||
return self.op(x, self.axis)
|
||||
|
||||
|
||||
raise_set = [
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('TensorAdd0', {
|
||||
'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('TensorAdd1', {
|
||||
'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('TensorAdd2', {
|
||||
'block': (P.TensorAdd(), {'exception': ValueError, 'error_keywords': ['TensorAdd']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# check input Tensor(bool_)
|
||||
('AssignAdd', {
|
||||
'block': (AssignAddNet(), {'exception': TypeError, 'error_keywords': ['AssignAdd']}),
|
||||
'desc_inputs': [Tensor(np.ones([1]).astype(np.bool_), mstype.bool_)],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# check input Tensor(bool_)
|
||||
('AssignSub', {
|
||||
'block': (AssignSubNet(), {'exception': TypeError, 'error_keywords': ['AssignSub']}),
|
||||
'desc_inputs': [Tensor(np.ones([1]).astype(np.bool_), mstype.bool_)],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of axis is float, not int
|
||||
('ReduceMean1', {
|
||||
'block': (ReduceNet(P.ReduceMean, keep_dims=True, axis=5.0),
|
||||
{'exception': TypeError, 'error_keywords': ['ReduceMean']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# axis is out of range
|
||||
('ReduceMean2', {
|
||||
'block': (ReduceNet(P.ReduceMean, keep_dims=True, axis=5),
|
||||
{'exception': ValueError, 'error_keywords': ['ReduceMean']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of axis is float, not int
|
||||
('ReduceSum1', {
|
||||
'block': (ReduceNet(P.ReduceSum, keep_dims=True, axis=5.0),
|
||||
{'exception': TypeError, 'error_keywords': ['ReduceSum']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# axis is out of range
|
||||
('ReduceSum2', {
|
||||
'block': (ReduceNet(P.ReduceSum, keep_dims=True, axis=5),
|
||||
{'exception': ValueError, 'error_keywords': ['ReduceSum']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of axis is float, not int
|
||||
('ReduceAll1', {
|
||||
'block': (ReduceNet(P.ReduceAll, keep_dims=True, axis=5.0),
|
||||
{'exception': TypeError, 'error_keywords': ['ReduceAll']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
# axis is out of range
|
||||
('ReduceAll2', {
|
||||
'block': (ReduceNet(P.ReduceAll, keep_dims=True, axis=5),
|
||||
{'exception': ValueError, 'error_keywords': ['ReduceAll']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of axis is float, not int
|
||||
('ReduceMax1', {
|
||||
'block': (ReduceNet(P.ReduceMax, keep_dims=True, axis=5.0),
|
||||
{'exception': TypeError, 'error_keywords': ['ReduceMax']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# axis is out of range
|
||||
('ReduceMax2', {
|
||||
'block': (ReduceNet(P.ReduceMax, keep_dims=True, axis=5),
|
||||
{'exception': ValueError, 'error_keywords': ['ReduceMax']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of axis is float, not int
|
||||
('ReduceMin1', {
|
||||
'block': (ReduceNet(P.ReduceMin, keep_dims=True, axis=5.0),
|
||||
{'exception': TypeError, 'error_keywords': ['ReduceMin']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# axis is out of range
|
||||
('ReduceMin2', {
|
||||
'block': (ReduceNet(P.ReduceMin, keep_dims=True, axis=5),
|
||||
{'exception': ValueError, 'error_keywords': ['ReduceMin']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of axis is float, not int
|
||||
('ReduceProd1', {
|
||||
'block': (ReduceNet(P.ReduceProd, keep_dims=True, axis=5.0),
|
||||
{'exception': TypeError, 'error_keywords': ['ReduceProd']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# axis is out of range
|
||||
('ReduceProd2', {
|
||||
'block': (ReduceNet(P.ReduceProd, keep_dims=True, axis=5),
|
||||
{'exception': ValueError, 'error_keywords': ['ReduceProd']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of x is Tensor(bool)
|
||||
('CumProd1', {
|
||||
'block': (CumProdNet(),
|
||||
{'exception': TypeError, 'error_keywords': ['CumProd']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool)), 1],
|
||||
'skip': ['backward']}),
|
||||
# type of axis in float, not int
|
||||
('CumProd2', {
|
||||
'block': (CumProdNet(),
|
||||
{'exception': TypeError, 'error_keywords': ['CumProd']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32)), 5.0],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of x and y are Tensor(uint32)
|
||||
('MatMul1', {
|
||||
'block': (P.MatMul(),
|
||||
{'exception': TypeError, 'error_keywords': ['MatMul']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.uint32)), Tensor(np.ones([3, 2]).astype(np.uint32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('MatMul2', {
|
||||
'block': (P.MatMul(),
|
||||
{'exception': TypeError, 'error_keywords': ['MatMul']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('MatMul3', {
|
||||
'block': (P.MatMul(),
|
||||
{'exception': ValueError, 'error_keywords': ['MatMul']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([2, 3]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# dims of x and y are less than 3
|
||||
('BatchMatMul1', {
|
||||
'block': (P.BatchMatMul(),
|
||||
{'exception': ValueError, 'error_keywords': ['BatchMatMul']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32)), Tensor(np.ones([3, 2]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of x is Tensor(bool)
|
||||
('CumSum1', {
|
||||
'block': (CumSumNet(axis=1),
|
||||
{'exception': TypeError, 'error_keywords': ['CumSum']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool))],
|
||||
'skip': ['backward']}),
|
||||
# type of axis in float, not int
|
||||
('CumSum2', {
|
||||
'block': (CumSumNet(axis=1.0),
|
||||
{'exception': TypeError, 'error_keywords': ['CumSum']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# intput is not tuple or list
|
||||
('AddN1', {
|
||||
'block': (P.AddN(),
|
||||
{'exception': TypeError, 'error_keywords': ['AddN']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.uint32))],
|
||||
'skip': ['backward']}),
|
||||
# type not match
|
||||
('AddN2', {
|
||||
'block': (P.AddN(),
|
||||
{'exception': TypeError, 'error_keywords': ['AddN']}),
|
||||
'desc_inputs': [(Tensor(np.ones([2, 3]).astype(np.uint32)), Tensor(np.ones([3, 2]).astype(np.int32)))],
|
||||
'skip': ['backward']}),
|
||||
# shape not match
|
||||
('AddN3', {
|
||||
'block': (P.AddN(),
|
||||
{'exception': ValueError, 'error_keywords': ['AddN']}),
|
||||
'desc_inputs': [(Tensor(np.ones([2, 3]).astype(np.int32)), Tensor(np.ones([3, 2]).astype(np.int32)))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is Tensor(bool)
|
||||
('Neg1', {
|
||||
'block': (P.Neg(),
|
||||
{'exception': TypeError, 'error_keywords': ['Neg']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Sub0', {
|
||||
'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Sub1', {
|
||||
'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Sub2', {
|
||||
'block': (P.Sub(), {'exception': ValueError, 'error_keywords': ['Sub']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Mul0', {
|
||||
'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Mul1', {
|
||||
'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Mul2', {
|
||||
'block': (P.Mul(), {'exception': ValueError, 'error_keywords': ['Mul']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is Tensor(bool)
|
||||
('Square1', {
|
||||
'block': (P.Square(),
|
||||
{'exception': TypeError, 'error_keywords': ['Square']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is Tensor(bool)
|
||||
('Rsqrt1', {
|
||||
'block': (P.Rsqrt(),
|
||||
{'exception': TypeError, 'error_keywords': ['Rsqrt']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is Tensor(bool)
|
||||
('Sqrt1', {
|
||||
'block': (P.Sqrt(),
|
||||
{'exception': TypeError, 'error_keywords': ['Sqrt']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not Tensor
|
||||
('Reciprocal1', {
|
||||
'block': (P.Reciprocal(),
|
||||
{'exception': TypeError, 'error_keywords': ['Reciprocal']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input x is Tensor(bool)
|
||||
('Pow1', {
|
||||
'block': (P.Pow(),
|
||||
{'exception': TypeError, 'error_keywords': ['Pow']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_)), 2.0],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not Tensor
|
||||
('Exp1', {
|
||||
'block': (P.Exp(),
|
||||
{'exception': TypeError, 'error_keywords': ['Exp']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not Tensor
|
||||
('Log1', {
|
||||
'block': (P.Log(),
|
||||
{'exception': TypeError, 'error_keywords': ['Log']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Minimum0', {
|
||||
'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Minimum1', {
|
||||
'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Minimum2', {
|
||||
'block': (P.Minimum(), {'exception': ValueError, 'error_keywords': ['Minimum']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Maximum0', {
|
||||
'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Maximum1', {
|
||||
'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Maximum2', {
|
||||
'block': (P.Maximum(), {'exception': ValueError, 'error_keywords': ['Maximum']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('RealDiv0', {
|
||||
'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('RealDiv1', {
|
||||
'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('RealDiv2', {
|
||||
'block': (P.RealDiv(), {'exception': ValueError, 'error_keywords': ['RealDiv']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Div0', {
|
||||
'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Div1', {
|
||||
'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Div2', {
|
||||
'block': (P.Div(), {'exception': ValueError, 'error_keywords': ['Div']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('FloorDiv0', {
|
||||
'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('FloorDiv1', {
|
||||
'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('FloorDiv2', {
|
||||
'block': (P.FloorDiv(), {'exception': ValueError, 'error_keywords': ['FloorDiv']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input x is Tensor(int32), not Tensor(float)
|
||||
('Floor1', {
|
||||
'block': (P.Floor(),
|
||||
{'exception': TypeError, 'error_keywords': ['Floor']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('FloorMod0', {
|
||||
'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('FloorMod1', {
|
||||
'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('FFloorMod2', {
|
||||
'block': (P.FloorMod(), {'exception': ValueError, 'error_keywords': ['FloorMod']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input x is Tensor(int32), not Tensor(float)
|
||||
('Acosh1', {
|
||||
'block': (P.Acosh(),
|
||||
{'exception': TypeError, 'error_keywords': ['Acosh']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Equal0', {
|
||||
'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('Equal1', {
|
||||
'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('Equal2', {
|
||||
'block': (P.Equal(), {'exception': ValueError, 'error_keywords': ['Equal']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('EqualCount0', {
|
||||
'block': (P.EqualCount(), {'exception': TypeError, 'error_keywords': ['EqualCount']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('EqualCount1', {
|
||||
'block': (P.EqualCount(), {'exception': TypeError, 'error_keywords': ['EqualCount']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
|
||||
# input is not tensor
|
||||
('NotEqual0', {
|
||||
'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('NotEqual1', {
|
||||
'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('NotEqual2', {
|
||||
'block': (P.NotEqual(), {'exception': ValueError, 'error_keywords': ['NotEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Greater0', {
|
||||
'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('Greater1', {
|
||||
'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('Greater2', {
|
||||
'block': (P.Greater(), {'exception': ValueError, 'error_keywords': ['Greater']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('GreaterEqual0', {
|
||||
'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('GreaterEqual1', {
|
||||
'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('GreaterEqual2', {
|
||||
'block': (P.GreaterEqual(), {'exception': ValueError, 'error_keywords': ['GreaterEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Less0', {
|
||||
'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('Less1', {
|
||||
'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('Less2', {
|
||||
'block': (P.Less(), {'exception': ValueError, 'error_keywords': ['Less']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('LessEqual0', {
|
||||
'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# type of x and y not match
|
||||
('LessEqual1', {
|
||||
'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('LessEqual2', {
|
||||
'block': (P.LessEqual(), {'exception': ValueError, 'error_keywords': ['LessEqual']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input x is not Tensor(bool)
|
||||
('LogicalNot1', {
|
||||
'block': (P.LogicalNot(),
|
||||
{'exception': TypeError, 'error_keywords': ['LogicalNot']}),
|
||||
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of x and y not match
|
||||
('LogicalAnd1', {
|
||||
'block': (P.LogicalAnd(), {'exception': TypeError, 'error_keywords': ['LogicalAnd']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('LogicalAnd2', {
|
||||
'block': (P.LogicalAnd(), {'exception': ValueError, 'error_keywords': ['LogicalAnd']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_)), Tensor(np.ones([3, 2]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# type of x and y not match
|
||||
('LogicalOr1', {
|
||||
'block': (P.LogicalOr(), {'exception': TypeError, 'error_keywords': ['LogicalOr']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
# shape of x and y not match
|
||||
('LogicalOr2', {
|
||||
'block': (P.LogicalOr(), {'exception': ValueError, 'error_keywords': ['LogicalOr']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_)), Tensor(np.ones([3, 2]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('NPUGetFloatStatus0', {
|
||||
'block': (P.NPUGetFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUGetFloatStatus']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
# input is Tensor(int32), not Tensor(float32)
|
||||
('NPUGetFloatStatus1', {
|
||||
'block': (P.NPUGetFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUGetFloatStatus']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
# dims is not 1
|
||||
('NPUGetFloatStatus2', {
|
||||
'block': (P.NPUGetFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUGetFloatStatus']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape[0] is not 8
|
||||
('NPUGetFloatStatus3', {
|
||||
'block': (P.NPUGetFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUGetFloatStatus']}),
|
||||
'desc_inputs': [Tensor(np.ones([3]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('NPUClearFloatStatus0', {
|
||||
'block': (P.NPUClearFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUClearFloatStatus']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
# input is Tensor(int32), not Tensor(float32)
|
||||
('NPUClearFloatStatus1', {
|
||||
'block': (P.NPUClearFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUClearFloatStatus']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
# dims is not 1
|
||||
('NPUClearFloatStatus2', {
|
||||
'block': (P.NPUClearFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUClearFloatStatus']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape[0] is not 8
|
||||
('NPUClearFloatStatus3', {
|
||||
'block': (P.NPUClearFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUClearFloatStatus']}),
|
||||
'desc_inputs': [Tensor(np.ones([3]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Cos0', {
|
||||
'block': (P.Cos(), {'exception': TypeError, 'error_keywords': ['Cos']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
# input is Tensor(bool)
|
||||
('Cos1', {
|
||||
'block': (P.Cos(), {'exception': TypeError, 'error_keywords': ['Cos']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('ACos0', {
|
||||
'block': (P.ACos(), {'exception': TypeError, 'error_keywords': ['ACos']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
# input is Tensor(bool)
|
||||
('ACos1', {
|
||||
'block': (P.ACos(), {'exception': TypeError, 'error_keywords': ['ACos']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Sin0', {
|
||||
'block': (P.Sin(), {'exception': TypeError, 'error_keywords': ['Sin']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
# input is Tensor(bool)
|
||||
('Sin1', {
|
||||
'block': (P.Sin(), {'exception': TypeError, 'error_keywords': ['Sin']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('NMSWithMask0', {
|
||||
'block': (P.NMSWithMask(), {'exception': TypeError, 'error_keywords': ['NMSWithMask']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
# input is not Tensor(float16) or Tensor(float32)
|
||||
('NMSWithMask1', {
|
||||
'block': (P.NMSWithMask(), {'exception': TypeError, 'error_keywords': ['NMSWithMask']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
# dims is not 2
|
||||
('NMSWithMask2', {
|
||||
'block': (P.NMSWithMask(), {'exception': ValueError, 'error_keywords': ['NMSWithMask']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# shape[1] is not 5
|
||||
('NMSWithMask3', {
|
||||
'block': (P.NMSWithMask(), {'exception': ValueError, 'error_keywords': ['NMSWithMask']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Abs0', {
|
||||
'block': (P.Abs(), {'exception': TypeError, 'error_keywords': ['Abs']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
# input is Tensor(bool)
|
||||
('Abs1', {
|
||||
'block': (P.Abs(), {'exception': TypeError, 'error_keywords': ['Abs']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Sign0', {
|
||||
'block': (P.Sign(), {'exception': TypeError, 'error_keywords': ['Sign']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
# input is Tensor(bool)
|
||||
('Sign1', {
|
||||
'block': (P.Sign(), {'exception': TypeError, 'error_keywords': ['Sign']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# input is not tensor
|
||||
('Round0', {
|
||||
'block': (P.Round(), {'exception': TypeError, 'error_keywords': ['Round']}),
|
||||
'desc_inputs': [5.0],
|
||||
'skip': ['backward']}),
|
||||
# input is Tensor(bool)
|
||||
('Round1', {
|
||||
'block': (P.Round(), {'exception': TypeError, 'error_keywords': ['Round']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
|
||||
'skip': ['backward']}),
|
||||
|
||||
# one input is scalar, and another is Tensor(float32)
|
||||
('Atan20', {
|
||||
'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
|
||||
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, but element types are not same
|
||||
('Atan21', {
|
||||
'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
# input two tensors, their shapes do not match
|
||||
('Atan22', {
|
||||
'block': (P.Atan2(), {'exception': ValueError, 'error_keywords': ['Atan2']}),
|
||||
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
]
|
||||
|
||||
|
||||
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
|
||||
def test_check_exception():
|
||||
return raise_set
|
Loading…
Reference in New Issue