!47543 modify constexpr

Merge pull request !47543 from huoxinyou/0105_constexpr
This commit is contained in:
i-robot 2023-01-10 02:32:33 +00:00 committed by Gitee
commit 386c33298d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
28 changed files with 1161 additions and 824 deletions

View File

@ -214,6 +214,100 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg
return arg_value return arg_value
def check_reshape_shp(shp):
"""Check the shape argument for tensor.reshape"""
if len(shp) == 1:
new_shape = shp[0]
if isinstance(new_shape, int):
return shp
if isinstance(new_shape, list):
new_shape = tuple(new_shape)
return new_shape
return shp
def check_swapaxes_axis(axes, ndim):
"""Check all the axes argument for tensor.swapaxes"""
if isinstance(axes, int):
return axes % ndim
if isinstance(axes, (tuple, list)):
tmp = []
for x in axes:
tmp.append((x + ndim) % ndim)
axes = tuple(tmp)
return axes
return axes
def prepare_shape_for_squeeze(shape, axes):
"""
yield squeezed shape based on the axes
"""
new_shape = []
ndim = len(shape)
if isinstance(axes, int):
axes = [axes]
elif isinstance(axes, (list, tuple)):
axes = set(axes)
for idx, s in enumerate(shape):
if s != 1 or (idx not in axes) and (idx - ndim not in axes):
new_shape.append(s)
return tuple(new_shape)
def check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""
return (axis + ndim) % ndim
def check_axis_valid(axes, ndim):
"""
check the validation of axis and return
"""
if axes is None:
axes = tuple(range(ndim))
return axes
if isinstance(axes, (tuple, list)):
tmp = []
for x in axes:
tmp.append((x + ndim) % ndim)
axes = tuple(tmp)
return axes
return (axes % ndim,)
def infer_out_shape(*shapes):
"""
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
"""
shape_out = list()
max_len = ms_max([len(it) for it in shapes])
for i in range(max_len):
items = [it[i-(max_len-len(it))] if i - (max_len - len(it))
>= 0 else 1 for it in shapes]
max_size = 0 if 0 in items else ms_max(items)
shape_out.append(max_size)
return tuple(shape_out)
def check_and_canonicalize_axes(axes, ndim):
"""Check whether the types and values of input axes are valid."""
axes = axes if isinstance(axes, tuple) else (axes,)
new_axes = ()
for ax in axes:
ax = ax if ax >= 0 else ax + ndim
new_axes += (ax,)
return new_axes
def get_log2_size(size):
"""Get log2 size"""
log2_res = F.log2(F.cast(Tensor(size), mstype.float32))
ceil_res = F.ceil(log2_res)
cast_res = F.cast(ceil_res, mstype.int64)
return cast_res
class Validator: class Validator:
"""validator for checking input parameters""" """validator for checking input parameters"""

View File

@ -26,7 +26,8 @@ from mindspore.ops.composite.base import _append, _insert, _pop, _list_clear, _r
_extend, _dict_clear, _haskey, _update, _fromkeys _extend, _dict_clear, _haskey, _update, _fromkeys
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ..._checkparam import check_is_number from ..._checkparam import check_is_number, check_reshape_shp, check_swapaxes_axis, prepare_shape_for_squeeze, \
check_axis_in_range, check_axis_valid, infer_out_shape, check_and_canonicalize_axes, get_log2_size
from ...ops import functional as F from ...ops import functional as F
from ...ops import operations as P from ...ops import operations as P
from ...ops.composite import tail, MultitypeFuncGraph, env_get, hyper_add, \ from ...ops.composite import tail, MultitypeFuncGraph, env_get, hyper_add, \
@ -41,7 +42,8 @@ from ...ops.primitive import constexpr
from ...common import dtype as mstype from ...common import dtype as mstype
from ...ops.operations._sequence_ops import ListAppend from ...ops.operations._sequence_ops import ListAppend
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like'] __all__ = ['MultitypeFuncGraph', 'env_get',
'hyper_add', 'zeros_like', 'ones_like']
shape_ = P.Shape() shape_ = P.Shape()
dtype_ = P.DType() dtype_ = P.DType()
@ -519,7 +521,7 @@ def reshape(x, *shape):
[ 3.6 0.4] [ 3.6 0.4]
[ 0.5 -3.2]] [ 0.5 -3.2]]
""" """
new_shape = check_reshape_shp_const(shape) new_shape = check_reshape_shp(shape)
return F.reshape(x, new_shape) return F.reshape(x, new_shape)
@ -709,7 +711,7 @@ def swapaxes(x, axis1, axis2):
>>> print(output.shape) >>> print(output.shape)
(4,3,2) (4,3,2)
""" """
axis1, axis2 = check_swapaxes_axis_const((axis1, axis2), x.ndim) axis1, axis2 = check_swapaxes_axis((axis1, axis2), x.ndim)
if axis1 == axis2: if axis1 == axis2:
return x return x
@ -720,10 +722,10 @@ def swapaxes(x, axis1, axis2):
new_perm = None new_perm = None
if axis2 + 1 < x.ndim: if axis2 + 1 < x.ndim:
new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:] perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
else: else:
new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]
return F.transpose(x, new_perm) return F.transpose(x, new_perm)
@ -757,7 +759,7 @@ def squeeze(x, axis=None):
if axis is None: if axis is None:
return F.squeeze(x) return F.squeeze(x)
# yield squeezed shape based on the axes # yield squeezed shape based on the axes
new_shape = prepare_shape_for_squeeze_const(shape, axis) new_shape = prepare_shape_for_squeeze(shape, axis)
return F.reshape(x, new_shape) return F.reshape(x, new_shape)
@ -869,7 +871,7 @@ def argmin(x, axis=None, keepdims=False):
axis = 0 axis = 0
is_axis_none = True is_axis_none = True
else: else:
axis = check_axis_in_range_const(axis, F.rank(x)) axis = check_axis_in_range(axis, F.rank(x))
out = P.Argmin(axis)(x) out = P.Argmin(axis)(x)
if keepdims and not is_axis_none: if keepdims and not is_axis_none:
out = expand_dims(out, axis) out = expand_dims(out, axis)
@ -894,7 +896,7 @@ def median(x, global_median, axis=0, keep_dims=False):
When attr `global_median` is True, the second output Tensor value is meaningless. When attr `global_median` is True, the second output Tensor value is meaningless.
""" """
check_axis_in_range_const(axis, x.ndim) check_axis_in_range(axis, x.ndim)
median_ = Median(global_median, axis, keep_dims) median_ = Median(global_median, axis, keep_dims)
return median_(x) return median_(x)
@ -968,7 +970,7 @@ def cumsum(x, axis=None, dtype=None):
if axis is None: if axis is None:
x = x.ravel() x = x.ravel()
axis = 0 axis = 0
check_axis_in_range_const(axis, x.ndim) check_axis_in_range(axis, x.ndim)
if dtype is not None: if dtype is not None:
dtype = check_astype_dtype_const(dtype) dtype = check_astype_dtype_const(dtype)
if original_dtype != dtype: if original_dtype != dtype:
@ -1350,7 +1352,8 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
""" """
ndim = x.ndim ndim = x.ndim
if ndim < 2: if ndim < 2:
const_utils.raise_value_error('diagonal requires an array of at least two dimensions') const_utils.raise_value_error(
'diagonal requires an array of at least two dimensions')
dtype = x.dtype dtype = x.dtype
axes = check_axis_valid((axis1, axis2), ndim) axes = check_axis_valid((axis1, axis2), ndim)
@ -1602,14 +1605,15 @@ def take(x, indices, axis=None, mode='clip'):
[4 3 6] [4 3 6]
""" """
if mode not in ('raise', 'wrap', 'clip'): if mode not in ('raise', 'wrap', 'clip'):
const_utils.raise_value_error('raise should be one of "raise", "wrap", or "clip"') const_utils.raise_value_error(
'raise should be one of "raise", "wrap", or "clip"')
if axis is None: if axis is None:
a = x.ravel() a = x.ravel()
axis = 0 axis = 0
else: else:
a = x a = x
ndim = a.ndim ndim = a.ndim
axis = check_axis_in_range_const(axis, ndim) axis = check_axis_in_range(axis, ndim)
shape_a = a.shape shape_a = a.shape
shape_indices = indices.shape shape_indices = indices.shape
@ -1690,12 +1694,14 @@ def choose(x, choices, mode='clip'):
# adjusts dtype for F.tensor_mul and F.gather_nd # adjusts dtype for F.tensor_mul and F.gather_nd
a = a.astype(mstype.int32) a = a.astype(mstype.int32)
choices = choices.astype(mstype.int32) choices = choices.astype(mstype.int32)
a = compile_utils.check_indices(choices.shape[0], a, mode, allow_negative_index=False) a = compile_utils.check_indices(
choices.shape[0], a, mode, allow_negative_index=False)
grids = [] grids = []
ndim = len(a.shape) ndim = len(a.shape)
for i in range(ndim): for i in range(ndim):
dim_grid = const_utils.make_tensor(F.make_range(a.shape[i]), mstype.int32) dim_grid = const_utils.make_tensor(
F.make_range(a.shape[i]), mstype.int32)
dim_shape = expanded_shape(ndim, a.shape[i], i) dim_shape = expanded_shape(ndim, a.shape[i], i)
dim_grid = P.BroadcastTo(a.shape)(dim_grid.reshape(dim_shape)) dim_grid = P.BroadcastTo(a.shape)(dim_grid.reshape(dim_shape))
grids.append(dim_grid) grids.append(dim_grid)
@ -1740,7 +1746,8 @@ def searchsorted(x, v, side='left', sorter=None):
shape = v.shape shape = v.shape
if sorter is not None: if sorter is not None:
if sorter.ndim != 1 or sorter.size != a.size: if sorter.ndim != 1 or sorter.size != a.size:
const_utils.raise_value_error('sorter must be 1-D array with the same size as `a`') const_utils.raise_value_error(
'sorter must be 1-D array with the same size as `a`')
sorter = const_utils.make_tensor(sorter) sorter = const_utils.make_tensor(sorter)
sorter = sorter.reshape(sorter.shape + (1,)) sorter = sorter.reshape(sorter.shape + (1,))
a = F.gather_nd(a, sorter) a = F.gather_nd(a, sorter)
@ -1748,12 +1755,14 @@ def searchsorted(x, v, side='left', sorter=None):
i = F.fill(mstype.int32, shape, 0) i = F.fill(mstype.int32, shape, 0)
j = F.fill(mstype.int32, shape, a.size) j = F.fill(mstype.int32, shape, a.size)
sort_range = F.make_range(get_log2_size(F.shape_mul(a.shape) + 1)) loop_num = get_log2_size(F.shape_mul(a.shape) + 1)
for _ in sort_range: index = Tensor([0])
while index < loop_num:
mid = (i - F.neg_tensor(j)) // 2 mid = (i - F.neg_tensor(j)) // 2
mask = less_op(v, F.gather_nd(a, mid.reshape(mid.shape + (1,)))) mask = less_op(v, F.gather_nd(a, mid.reshape(mid.shape + (1,))))
i = F.select(mask, i, mid) i = F.select(mask, i, mid)
j = F.select(mask, mid, j) j = F.select(mask, mid, j)
index += 1
return j return j
@ -1788,7 +1797,8 @@ def fill(x, value):
""" """
if value is None: if value is None:
if x.dtype not in (mstype.float16, mstype.float32, mstype.float64): if x.dtype not in (mstype.float16, mstype.float32, mstype.float64):
const_utils.raise_type_error("If None is used as value, the original Tensor's dtype must be float.") const_utils.raise_type_error(
"If None is used as value, the original Tensor's dtype must be float.")
value = nan_tensor value = nan_tensor
return F.tile(value, x.shape).astype(x.dtype) return F.tile(value, x.shape).astype(x.dtype)
if not isinstance(value, (int, float, bool)): if not isinstance(value, (int, float, bool)):
@ -1838,7 +1848,6 @@ def ptp(x, axis=None, keepdims=False):
if axis is None: if axis is None:
axis = () axis = ()
else: else:
check_axis_type(axis, True, True, False)
axis = check_axis_valid(axis, x.ndim) axis = check_axis_valid(axis, x.ndim)
return x.max(axis, keepdims) - x.min(axis, keepdims) return x.max(axis, keepdims) - x.min(axis, keepdims)
@ -2102,7 +2111,7 @@ def repeat(x, repeats, axis=None):
axis = 0 axis = 0
if not isinstance(axis, int): if not isinstance(axis, int):
const_utils.raise_type_error('axes should be integers') const_utils.raise_type_error('axes should be integers')
check_axis_in_range_const(axis, x.ndim) check_axis_in_range(axis, x.ndim)
axis = axis + x.ndim if axis < 0 else axis axis = axis + x.ndim if axis < 0 else axis
if len(repeats) == 1: if len(repeats) == 1:
@ -2112,7 +2121,8 @@ def repeat(x, repeats, axis=None):
return repeat_elements(x, repeats, axis) return repeat_elements(x, repeats, axis)
size = x.shape[axis] size = x.shape[axis]
if len(repeats) != size: if len(repeats) != size:
const_utils.raise_value_error('operands could not be broadcast together') const_utils.raise_value_error(
'operands could not be broadcast together')
subs = P.Split(axis, size)(x) subs = P.Split(axis, size)(x)
repeated_subs = [] repeated_subs = []
for sub_item, rep in zip(subs, repeats): for sub_item, rep in zip(subs, repeats):
@ -2224,8 +2234,6 @@ def hasnext(it):
@constexpr @constexpr
def constant_abs(x): def constant_abs(x):
"""Returns the absolute value of the constant.""" """Returns the absolute value of the constant."""
if x is None:
raise ValueError("For abs(), the input should be a constant or Tensor type.")
return abs(x) return abs(x)
@ -2241,7 +2249,8 @@ def constant_round(*data):
"""Returns the rounded value of the constant.""" """Returns the rounded value of the constant."""
for x in data: for x in data:
if x is None: if x is None:
raise ValueError("For round(), the input should be a Tensor or 1-2 constants.") raise ValueError(
"For round(), the input should be a Tensor or 1-2 constants.")
return round(*data) return round(*data)
@ -2256,7 +2265,8 @@ def ms_round(*data):
return round_(x) return round_(x)
return constant_round(x) return constant_round(x)
if isinstance(data[0], Tensor) or isinstance(data[1], Tensor): if isinstance(data[0], Tensor) or isinstance(data[1], Tensor):
const_utils.raise_type_error("When applying round() to tensor, only one tensor is supported as input.") const_utils.raise_type_error(
"When applying round() to tensor, only one tensor is supported as input.")
return constant_round(*data) return constant_round(*data)
@ -2274,9 +2284,11 @@ def str_func(*data):
return '' return ''
data = data[0] data = data[0]
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)): if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
const_utils.raise_type_error("str() does not support sparse tensor input.") const_utils.raise_type_error(
"str() does not support sparse tensor input.")
if not F.isconstant(data): if not F.isconstant(data):
const_utils.raise_type_error("str() does not support non-constant input.") const_utils.raise_type_error(
"str() does not support non-constant input.")
return cast_to_str(data) return cast_to_str(data)
@ -2294,13 +2306,15 @@ def bool_func(*data):
return False return False
data = data[0] data = data[0]
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)): if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
const_utils.raise_type_error("bool() does not support sparse tensor input.") const_utils.raise_type_error(
"bool() does not support sparse tensor input.")
if isinstance(data, (Tensor, Tensor_)): if isinstance(data, (Tensor, Tensor_)):
tensor_shape = F.shape(data) tensor_shape = F.shape(data)
tensor_shape_len = len(tensor_shape) tensor_shape_len = len(tensor_shape)
if tensor_shape_len == 0 or (tensor_shape_len == 1 and tensor_shape[0] == 1): if tensor_shape_len == 0 or (tensor_shape_len == 1 and tensor_shape[0] == 1):
return data != 0 return data != 0
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.") const_utils.raise_value_error(
"The truth value of an array with more than one element is ambiguous.")
if not F.isconstant(data): if not F.isconstant(data):
if hasattr(data, "__bool__"): if hasattr(data, "__bool__"):
return data.__bool__() return data.__bool__()
@ -2329,9 +2343,11 @@ def int_func(*data):
return 0 return 0
target = data[0] target = data[0]
if not F.isconstant(target): if not F.isconstant(target):
const_utils.raise_type_error("int() does not support non-constant input.") const_utils.raise_type_error(
"int() does not support non-constant input.")
if isinstance(target, (CSRTensor, COOTensor, RowTensorInner)): if isinstance(target, (CSRTensor, COOTensor, RowTensorInner)):
const_utils.raise_type_error("int() does not support sparse tensor input.") const_utils.raise_type_error(
"int() does not support sparse tensor input.")
return cast_to_int(*data) return cast_to_int(*data)
@ -2351,9 +2367,11 @@ def float_func(*data):
return 0.0 return 0.0
data = data[0] data = data[0]
if not F.isconstant(data): if not F.isconstant(data):
const_utils.raise_type_error("float() does not support non-constant input.") const_utils.raise_type_error(
"float() does not support non-constant input.")
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)): if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
const_utils.raise_type_error("float() does not support sparse tensor input.") const_utils.raise_type_error(
"float() does not support sparse tensor input.")
return cast_to_float(data) return cast_to_float(data)
@ -2366,10 +2384,12 @@ def list_func(*data):
return F.make_list() return F.make_list()
data = data[0] data = data[0]
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)): if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
const_utils.raise_type_error("list() does not support single sparse tensor input.") const_utils.raise_type_error(
"list() does not support single sparse tensor input.")
if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"): if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"):
data_type = F.typeof(data) data_type = F.typeof(data)
const_utils.raise_type_error(str(data_type) + " object is not iterable.") const_utils.raise_type_error(
str(data_type) + " object is not iterable.")
if isinstance(data, dict): if isinstance(data, dict):
data = data.keys() data = data.keys()
ret = F.make_list() ret = F.make_list()
@ -2387,10 +2407,12 @@ def tuple_func(*data):
return F.make_tuple() return F.make_tuple()
data = data[0] data = data[0]
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)): if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
const_utils.raise_type_error("tuple() does not support single sparse tensor input.") const_utils.raise_type_error(
"tuple() does not support single sparse tensor input.")
if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"): if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"):
data_type = F.typeof(data) data_type = F.typeof(data)
const_utils.raise_type_error(str(data_type) + " object is not iterable.") const_utils.raise_type_error(
str(data_type) + " object is not iterable.")
if isinstance(data, dict): if isinstance(data, dict):
data = data.keys() data = data.keys()
ret = F.make_tuple() ret = F.make_tuple()
@ -2419,7 +2441,8 @@ def get_max_min_data_len(*data):
if isinstance(data, (dict, list, tuple)): if isinstance(data, (dict, list, tuple)):
len_data = len(data) len_data = len(data)
else: else:
const_utils.raise_type_error("max() or min() does not support the data type.") const_utils.raise_type_error(
"max() or min() does not support the data type.")
return len_data return len_data
@ -2431,7 +2454,8 @@ def get_tensor_num(data):
tensor_shape = F.shape(input_data) tensor_shape = F.shape(input_data)
tensor_shape_len = len(tensor_shape) tensor_shape_len = len(tensor_shape)
if tensor_shape_len != 0 and not (tensor_shape_len == 1 and tensor_shape[0] == 1): if tensor_shape_len != 0 and not (tensor_shape_len == 1 and tensor_shape[0] == 1):
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.") const_utils.raise_value_error(
"The truth value of an array with more than one element is ambiguous.")
tensor_num = tensor_num + 1 tensor_num = tensor_num + 1
return tensor_num return tensor_num
@ -2453,9 +2477,11 @@ def ms_max_one_element(x):
tensor_shape = F.shape(x) tensor_shape = F.shape(x)
tensor_shape_len = len(tensor_shape) tensor_shape_len = len(tensor_shape)
if tensor_shape_len == 0: if tensor_shape_len == 0:
const_utils.raise_type_error("Cannot iterate over a scalar tensor.") const_utils.raise_type_error(
"Cannot iterate over a scalar tensor.")
if tensor_shape_len >= 2: if tensor_shape_len >= 2:
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.") const_utils.raise_value_error(
"The truth value of an array with more than one element is ambiguous.")
return x.max() return x.max()
# Deal with Tensor in tuple or list # Deal with Tensor in tuple or list
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
@ -2465,9 +2491,11 @@ def ms_max_one_element(x):
if tensor_num == len(x): if tensor_num == len(x):
return max_tensor(x) return max_tensor(x)
if tensor_num != 0: if tensor_num != 0:
const_utils.raise_type_error("max() cannot contain both tensor and non-tensor type.") const_utils.raise_type_error(
"max() cannot contain both tensor and non-tensor type.")
if exist_tensor(x): if exist_tensor(x):
const_utils.raise_type_error("max() cannot support tensor in list or tuple nested now.") const_utils.raise_type_error(
"max() cannot support tensor in list or tuple nested now.")
return max_(x) return max_(x)
@ -2485,10 +2513,12 @@ def ms_max(*data):
if tensor_num == len_data: if tensor_num == len_data:
return max_tensor(*data) return max_tensor(*data)
if tensor_num != 0: if tensor_num != 0:
const_utils.raise_type_error("max() cannot contain both tensor and non-tensor type.") const_utils.raise_type_error(
"max() cannot contain both tensor and non-tensor type.")
# exist tensor in list/tuple # exist tensor in list/tuple
if exist_tensor(data): if exist_tensor(data):
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.") const_utils.raise_value_error(
"The truth value of an array with more than one element is ambiguous.")
return max_(*data) return max_(*data)
@ -2525,9 +2555,11 @@ def ms_min_one_element(x):
tensor_shape = F.shape(x) tensor_shape = F.shape(x)
tensor_shape_len = len(tensor_shape) tensor_shape_len = len(tensor_shape)
if tensor_shape_len == 0: if tensor_shape_len == 0:
const_utils.raise_type_error("Cannot iterate over a scalar tensor.") const_utils.raise_type_error(
"Cannot iterate over a scalar tensor.")
if tensor_shape_len >= 2: if tensor_shape_len >= 2:
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.") const_utils.raise_value_error(
"The truth value of an array with more than one element is ambiguous.")
return x.min() return x.min()
# Deal with Tensor in tuple or list # Deal with Tensor in tuple or list
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
@ -2537,9 +2569,11 @@ def ms_min_one_element(x):
if tensor_num == len(x): if tensor_num == len(x):
return min_tensor(x) return min_tensor(x)
if tensor_num != 0: if tensor_num != 0:
const_utils.raise_type_error("min() cannot contain both tensor and non-tensor type.") const_utils.raise_type_error(
"min() cannot contain both tensor and non-tensor type.")
if exist_tensor(x): if exist_tensor(x):
const_utils.raise_type_error("min() cannot support tensor in list or tuple nested now.") const_utils.raise_type_error(
"min() cannot support tensor in list or tuple nested now.")
return min_(x) return min_(x)
@ -2557,10 +2591,12 @@ def ms_min(*data):
if tensor_num == len_data: if tensor_num == len_data:
return min_tensor(*data) return min_tensor(*data)
if tensor_num != 0: if tensor_num != 0:
const_utils.raise_type_error("min() cannot contain both tensor and non-tensor type.") const_utils.raise_type_error(
"min() cannot contain both tensor and non-tensor type.")
# exist tensor in list/tuple # exist tensor in list/tuple
if exist_tensor(data): if exist_tensor(data):
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.") const_utils.raise_value_error(
"The truth value of an array with more than one element is ambiguous.")
return min_(*data) return min_(*data)
@ -2572,11 +2608,13 @@ def ms_sum(*data):
x = data[0] x = data[0]
if not isinstance(x, Tensor) and not hasattr(x, "__ms_iter__"): if not isinstance(x, Tensor) and not hasattr(x, "__ms_iter__"):
data_type = F.typeof(x) data_type = F.typeof(x)
const_utils.raise_type_error(str(data_type) + " object is not iterable.") const_utils.raise_type_error(
str(data_type) + " object is not iterable.")
if isinstance(x, Tensor): if isinstance(x, Tensor):
tensor_shape = F.shape(x) tensor_shape = F.shape(x)
if len(tensor_shape) == 0: if len(tensor_shape) == 0:
const_utils.raise_type_error("Cannot iterate over a scalar tensor.") const_utils.raise_type_error(
"Cannot iterate over a scalar tensor.")
if isinstance(x, dict): if isinstance(x, dict):
x = x.keys() x = x.keys()
result = 0 result = 0
@ -2607,7 +2645,8 @@ def ms_len(data):
def python_len_with_check(data): def python_len_with_check(data):
"""Return the result of python built-in len function with iterable check""" """Return the result of python built-in len function with iterable check"""
if not hasattr(data, "__iter__"): if not hasattr(data, "__iter__"):
raise TypeError(str(type(data)) + " object is not iterable in graph mode.") raise TypeError(str(type(data)) +
" object is not iterable in graph mode.")
return len(data) return len(data)
@ -2617,7 +2656,8 @@ def ms_len_with_iterable_check(data):
return python_len_with_check(data) return python_len_with_check(data)
if not hasattr(data, "__len__"): if not hasattr(data, "__len__"):
type_str = str(F.typeof(data)) type_str = str(F.typeof(data))
const_utils.raise_type_error(type_str + " object is not iterable in graph mode.") const_utils.raise_type_error(
type_str + " object is not iterable in graph mode.")
return data.__len__() return data.__len__()
@ -2680,7 +2720,6 @@ def expand_dims(x, axis):
""" """
Insert a dimension of shape 1 at the specified axis of Tensor. Insert a dimension of shape 1 at the specified axis of Tensor.
""" """
check_is_int(axis, 'axis')
return P.ExpandDims()(x, axis) return P.ExpandDims()(x, axis)
@ -2688,7 +2727,6 @@ def unsqueeze(x, dim):
""" """
Insert a dimension of shape 1 at the specified axis of Tensor. Insert a dimension of shape 1 at the specified axis of Tensor.
""" """
check_is_int(dim, 'dim')
return P.ExpandDims()(x, dim) return P.ExpandDims()(x, dim)
@ -2740,7 +2778,8 @@ def check_select_condition(cond_type):
""" """
if isinstance(cond_type, mstype.tensor_type): if isinstance(cond_type, mstype.tensor_type):
return return
raise TypeError(f"For select, the argument condition should be Tensor, but got {cond_type}.") raise TypeError(
f"For select, the argument condition should be Tensor, but got {cond_type}.")
@constexpr @constexpr
@ -2893,9 +2932,7 @@ def ge(x, y):
def while_cond(x): def while_cond(x):
"""For while condition, if the condition is a tensor, the loop will not be unrolled""" """For while condition, if the condition is a tensor, the loop will not be unrolled"""
if issubclass_(F.typeof(x), F.typeof(mstype.tensor)): if issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
is_cond = check_is_tensor_bool_cond(F.shape(x)) return F.cast(x, mstype.bool_)
if is_cond:
return F.cast(x, mstype.bool_)
return x return x
@ -3046,14 +3083,16 @@ def coo_to_dense(x):
def coo_coalesce(x): def coo_coalesce(x):
"""Returns the coalesced sparse tensor of the input.""" """Returns the coalesced sparse tensor of the input."""
shape = const_utils.make_tensor(x.shape) shape = const_utils.make_tensor(x.shape)
res_indices, res_values, _ = P.Coalesce()(x.indices.transpose(), x.values, shape) res_indices, res_values, _ = P.Coalesce()(
x.indices.transpose(), x.values, shape)
return COOTensor(res_indices.transpose(), res_values, x.shape) return COOTensor(res_indices.transpose(), res_values, x.shape)
def csr_to_coo(x): def csr_to_coo(x):
"""convert csr to coo.""" """convert csr to coo."""
if x.ndim != 2: if x.ndim != 2:
const_utils.raise_value_error("Currently only support 2-D CSRTensor when converting to COOTensor.") const_utils.raise_value_error(
"Currently only support 2-D CSRTensor when converting to COOTensor.")
row_indices = F.csr2coo(x.indptr, x.values.shape[0]) row_indices = F.csr2coo(x.indptr, x.values.shape[0])
coo_indices = P.Stack(1)((row_indices, x.indices)) coo_indices = P.Stack(1)((row_indices, x.indices))
return COOTensor(coo_indices, x.values, x.shape) return COOTensor(coo_indices, x.values, x.shape)
@ -3069,8 +3108,6 @@ def random_categorical_(x, num_sample, seed=0, dtype=mstype.int64):
Generates random samples from a given categorical distribution tensor. Generates random samples from a given categorical distribution tensor.
Refer to :func:`mindspore.ops.random_categorical` for more detail. Refer to :func:`mindspore.ops.random_categorical` for more detail.
""" """
validator.check_is_int(num_sample, 'num_sample')
validator.check_is_int(seed, 'seed')
return F.random_categorical(x, num_sample, seed, dtype) return F.random_categorical(x, num_sample, seed, dtype)
@ -3099,33 +3136,31 @@ def check_is_tuple_or_list_or_tensor(x, op_name, arg_name):
"""check whether x is list or tuple or tensor.""" """check whether x is list or tuple or tensor."""
if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)): if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)):
return True return True
raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.") raise TypeError(
f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.")
@constexpr @constexpr
def check_is_const_int(x, op_name, arg_name): def check_is_const_int(x, op_name, arg_name):
"""check whether x is const int.""" """check whether x is const int."""
if x is None: if x is None:
raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.") raise TypeError(
f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.")
if not isinstance(x, int): if not isinstance(x, int):
raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.") raise TypeError(
f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.")
return True return True
@constexpr
def check_is_tensor_bool_cond(shp):
"""check if tensor is a bool condition"""
if shp in ((), (1,)):
return True
if None in shp:
raise ValueError(f"Only tensor which shape is () or (1,) can be converted to bool, but got tensor shape is "
f"None")
raise ValueError(f"Only tensor which shape is () or (1,) can be converted to bool, but got tensor shape is {shp}")
@constexpr @constexpr
def const_tensor_to_bool(x): def const_tensor_to_bool(x):
"""convert bool tensor to bool condition""" """convert bool tensor to bool condition
def const_tensor_to_bool(x):
convert bool tensor to bool condition
if x.shape == (1,):
return bool(x[0])
return bool(x)
"""
if x is None: if x is None:
raise ValueError("Only tensor which shape is () or (1,) can be converted to bool, but got None") raise ValueError("Only tensor which shape is () or (1,) can be converted to bool, but got None")
x = x.asnumpy() x = x.asnumpy()
@ -3153,36 +3188,22 @@ def check_view_shape(x):
return x return x
# convert normal param_check functions to constexpr functions
check_astype_dtype_const = constexpr(validator.check_astype_dtype) check_astype_dtype_const = constexpr(validator.check_astype_dtype)
check_transpose_axis_const = constexpr(validator.check_transpose_axis) check_transpose_axis_const = constexpr(validator.check_transpose_axis)
check_reshape_shp_const = constexpr(validator.check_reshape_shp)
check_flatten_order_const = constexpr(validator.check_flatten_order) check_flatten_order_const = constexpr(validator.check_flatten_order)
check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis)
prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze)
check_axis_in_range_const = constexpr(validator.check_axis_in_range)
check_axis_valid = constexpr(validator.check_axis_valid)
max_ = constexpr(validator.max_) max_ = constexpr(validator.max_)
min_ = constexpr(validator.min_) min_ = constexpr(validator.min_)
expanded_shape = constexpr(validator.expanded_shape) expanded_shape = validator.expanded_shape
tuple_slice = constexpr(validator.tuple_slice) tuple_slice = validator.tuple_slice
infer_out_shape = constexpr(validator.infer_out_shape) empty_compile = validator.empty_compile
get_log2_size = constexpr(validator.get_log2_size)
check_axis_type = constexpr(validator.check_axis_type)
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
empty_compile = constexpr(validator.empty_compile)
check_type_support = constexpr(validator.check_type_support) check_type_support = constexpr(validator.check_type_support)
check_is_int = constexpr(validator.check_is_int)
check_type_name = constexpr(validator.check_type_name) check_type_name = constexpr(validator.check_type_name)
check_value_type = constexpr(validator.check_value_type) check_value_type = constexpr(validator.check_value_type)
check_int = constexpr(validator.check_int)
check_bool = constexpr(validator.check_bool)
def tensor_bool(x): def tensor_bool(x):
"""tensor as condition, if is constant, return immediate bool value""" """tensor as condition, if is constant, return immediate bool value"""
is_cond = check_is_tensor_bool_cond(F.shape(x)) if F.isconstant(x):
if is_cond and F.isconstant(x):
return const_tensor_to_bool(x) return const_tensor_to_bool(x)
return F.cast(x, mstype.bool_) return F.cast(x, mstype.bool_)
@ -3340,8 +3361,6 @@ def top_k(input_x, k, sorted=True):
""" """
Finds values and indices of the `k` largest entries along the last dimension. Finds values and indices of the `k` largest entries along the last dimension.
""" """
check_is_int(k, 'k')
check_bool(sorted, 'sorted')
return F.top_k(input_x, k, sorted) return F.top_k(input_x, k, sorted)
@ -3588,7 +3607,6 @@ def bernoulli(x, p=0.5, seed=-1):
""" """
Randomly draws binary numbers from a Bernoulli distribution. Randomly draws binary numbers from a Bernoulli distribution.
""" """
check_is_int(seed, 'bernoulli', 'seed')
return F.bernoulli(x, p, seed) return F.bernoulli(x, p, seed)

View File

@ -23,7 +23,6 @@ import mindspore.ops as ops
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops.composite as C import mindspore.ops.composite as C
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr
from mindspore import log as logger from mindspore import log as logger
@ -96,17 +95,6 @@ def apply_offload_iterators(data, offload_model):
return data return data
@constexpr
def check_input_dims(x_shape, required_dim, offload_op_name):
"""
Check if input has the required number of dimensions for the operation.
"""
input_dim = len(x_shape)
if input_dim is not required_dim:
raise ValueError("For %s offload operation, the dimension of input should be %d, but got %d." %
(offload_op_name, required_dim, input_dim))
def assign_min_max_params(in_params, center=1): def assign_min_max_params(in_params, center=1):
""" """
Adjust input parameters for ops. Adjust input parameters for ops.
@ -175,7 +163,6 @@ class RandomHorizontalFlip(nn.Cell):
x = self.cast(x, mstype.float32) x = self.cast(x, mstype.float32)
x_shape = self.shape(x) x_shape = self.shape(x)
check_input_dims(x_shape, 4, 'RandomHorizontalFlip')
bs, h, w, c = x_shape bs, h, w, c = x_shape
flip_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32) flip_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
@ -208,7 +195,6 @@ class RandomVerticalFlip(nn.Cell):
x = self.cast(x, mstype.float32) x = self.cast(x, mstype.float32)
x_shape = self.shape(x) x_shape = self.shape(x)
check_input_dims(x_shape, 4, 'RandomVerticalFlip')
bs, h, w, c = x_shape bs, h, w, c = x_shape
flip_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32) flip_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
@ -293,7 +279,6 @@ class RandomColorAdjust(nn.Cell):
x = self.cast(x, mstype.float32) x = self.cast(x, mstype.float32)
x_shape = self.shape(x) x_shape = self.shape(x)
check_input_dims(x_shape, 4, 'RandomColorAdjust')
bs, h, w, c = x_shape bs, h, w, c = x_shape
br_rand_factor = self.generate_rand_batch(self.br_min, self.br_max, self.check_rand_br, x_shape) br_rand_factor = self.generate_rand_batch(self.br_min, self.br_max, self.check_rand_br, x_shape)
@ -402,7 +387,6 @@ class RandomSharpness(nn.Cell):
def construct(self, x): def construct(self, x):
x = self.cast(x, mstype.float32) x = self.cast(x, mstype.float32)
x_shape = self.shape(x) x_shape = self.shape(x)
check_input_dims(x_shape, 4, 'RandomSharpness')
bs, h, w, c = x_shape bs, h, w, c = x_shape
degree_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32) degree_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
@ -449,11 +433,8 @@ class HwcToChw(nn.Cell):
def __init__(self): def __init__(self):
super(HwcToChw, self).__init__() super(HwcToChw, self).__init__()
self.trans = P.Transpose() self.trans = P.Transpose()
self.shape = P.Shape()
def construct(self, x): def construct(self, x):
x_shape = self.shape(x)
check_input_dims(x_shape, 4, 'HwcToChw')
return self.trans(x, (0, 3, 1, 2)) return self.trans(x, (0, 3, 1, 2))

View File

@ -25,7 +25,6 @@ from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import nn_ops as NN_OPS from mindspore.ops.operations import nn_ops as NN_OPS
from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore import ops from mindspore import ops
@ -195,18 +194,9 @@ class Softmax2d(Cell):
"""Initialize Softmax2d.""" """Initialize Softmax2d."""
super(Softmax2d, self).__init__() super(Softmax2d, self).__init__()
self.softmax = P.Softmax(axis=-3) self.softmax = P.Softmax(axis=-3)
self.shape = P.Shape()
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim not in (3, 4):
raise ValueError(f"For '{cls_name}', the in_shape must have 3 or 4 dims, but got {dim}.")
def construct(self, x): def construct(self, x):
x_shape = self.shape(x)
self._check_input_dim(x_shape, self.cls_name)
return self.softmax(x) return self.softmax(x)

View File

@ -87,15 +87,18 @@ class L1Regularizer(Cell):
super(L1Regularizer, self).__init__() super(L1Regularizer, self).__init__()
Validator.check_value_type("scale", scale, [int, float], self.cls_name) Validator.check_value_type("scale", scale, [int, float], self.cls_name)
if scale <= 0: if scale <= 0:
raise ValueError(f"For '{self.cls_name}', the 'scale' must be greater than 0, but got {scale}.") raise ValueError(
f"For '{self.cls_name}', the 'scale' must be greater than 0, but got {scale}.")
if math.isinf(scale) or math.isnan(scale): if math.isinf(scale) or math.isnan(scale):
raise ValueError(f"For '{self.cls_name}', the 'scale' can not be INF or NAN, but got {scale}.") raise ValueError(
f"For '{self.cls_name}', the 'scale' can not be INF or NAN, but got {scale}.")
self.abs = P.Abs() self.abs = P.Abs()
self.reduce_sum = P.ReduceSum() self.reduce_sum = P.ReduceSum()
self.scale = Tensor(scale, dtype=mstype.float32) self.scale = Tensor(scale, dtype=mstype.float32)
def construct(self, weights): def construct(self, weights):
const_utils.check_type_valid(F.dtype(weights), mstype.number_type, 'weights') const_utils.check_type_valid(
F.dtype(weights), mstype.number_type, 'weights')
l1_regularization = self.scale * self.reduce_sum(self.abs(weights)) l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
return l1_regularization return l1_regularization
@ -155,13 +158,16 @@ class Dropout(Cell):
def __init__(self, keep_prob=0.5, dtype=mstype.float32): def __init__(self, keep_prob=0.5, dtype=mstype.float32):
"""Initialize Dropout.""" """Initialize Dropout."""
super(Dropout, self).__init__() super(Dropout, self).__init__()
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name) Validator.check_value_type('keep_prob', keep_prob, [
float], self.cls_name)
if keep_prob <= 0 or keep_prob > 1: if keep_prob <= 0 or keep_prob > 1:
raise ValueError(f"For '{self.cls_name}', the 'keep_prob' must be a number in range (0, 1], " raise ValueError(f"For '{self.cls_name}', the 'keep_prob' must be a number in range (0, 1], "
f"but got {keep_prob}.") f"but got {keep_prob}.")
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) Validator.check_subclass(
"dtype", dtype, mstype.number_type, self.cls_name)
if dtype != mstype.float32: if dtype != mstype.float32:
logger.info("This parameter `dtype` will be deleted or invisible in the future. Please don't use it.") logger.info(
"This parameter `dtype` will be deleted or invisible in the future. Please don't use it.")
self.keep_prob = keep_prob self.keep_prob = keep_prob
seed0, seed1 = _get_graph_seed(0, "dropout") seed0, seed1 = _get_graph_seed(0, "dropout")
self.seed0 = seed0 self.seed0 = seed0
@ -623,13 +629,6 @@ class Flatten(Cell):
return F.reshape(x, (F.shape(x)[0], -1)) return F.reshape(x, (F.shape(x)[0], -1))
@constexpr
def check_dense_input_shape(x, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if len(x) < 2:
raise ValueError(f"{msg_prefix} dimension of 'x' should not be less than 2, but got {len(x)}.")
class Identity(Cell): class Identity(Cell):
""" """
Returns a Tensor with the same shape and contents as input. Returns a Tensor with the same shape and contents as input.
@ -727,9 +726,12 @@ class Dense(Cell):
activation=None): activation=None):
"""Initialize Dense.""" """Initialize Dense."""
super(Dense, self).__init__() super(Dense, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name) self.in_channels = Validator.check_positive_int(
self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name) in_channels, "in_channels", self.cls_name)
self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name) self.out_channels = Validator.check_positive_int(
out_channels, "out_channels", self.cls_name)
self.has_bias = Validator.check_bool(
has_bias, "has_bias", self.cls_name)
self.reshape = P.Reshape() self.reshape = P.Reshape()
self.shape_op = P.Shape() self.shape_op = P.Shape()
@ -740,7 +742,8 @@ class Dense(Cell):
f"be equal to 2, and the first dim must be equal to 'out_channels', and the " f"be equal to 2, and the first dim must be equal to 'out_channels', and the "
f"second dim must be equal to 'in_channels'. But got 'weight_init': {weight_init}, " f"second dim must be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
f"'out_channels': {out_channels}, 'in_channels': {in_channels}.") f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") self.weight = Parameter(initializer(
weight_init, [out_channels, in_channels]), name="weight")
self.bias = None self.bias = None
if self.has_bias: if self.has_bias:
@ -749,11 +752,13 @@ class Dense(Cell):
raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must " raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
f"be equal to 1, and the first dim must be equal to 'out_channels'. But got " f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
f"'bias_init': {bias_init}, 'out_channels': {out_channels}.") f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") self.bias = Parameter(initializer(
bias_init, [out_channels]), name="bias")
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
self.matmul = P.MatMul(transpose_b=True) self.matmul = P.MatMul(transpose_b=True)
self.activation = get_activation(activation) if isinstance(activation, str) else activation self.activation = get_activation(activation) if isinstance(
activation, str) else activation
if activation is not None and not isinstance(self.activation, (Cell, Primitive)): if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got " raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got "
f"{type(activation).__name__}.") f"{type(activation).__name__}.")
@ -761,7 +766,6 @@ class Dense(Cell):
def construct(self, x): def construct(self, x):
x_shape = self.shape_op(x) x_shape = self.shape_op(x)
check_dense_input_shape(x_shape, self.cls_name)
if len(x_shape) != 2: if len(x_shape) != 2:
x = self.reshape(x, (-1, x_shape[-1])) x = self.reshape(x, (-1, x_shape[-1]))
x = self.matmul(x, self.weight) x = self.matmul(x, self.weight)
@ -775,7 +779,8 @@ class Dense(Cell):
return x return x
def extend_repr(self): def extend_repr(self):
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels) s = 'input_channels={}, output_channels={}'.format(
self.in_channels, self.out_channels)
if self.has_bias: if self.has_bias:
s += ', has_bias={}'.format(self.has_bias) s += ', has_bias={}'.format(self.has_bias)
if self.activation_flag: if self.activation_flag:
@ -787,14 +792,15 @@ class Dense(Cell):
def _is_equal_one(x): def _is_equal_one(x):
if x is None: if x is None:
return False return False
return bool(x.asnumpy().mean() == 1.0) return F.equal(F.reduce_mean(x), 1.0)
@constexpr @constexpr
def _dtype_check(x_dtype, prim_name=None): def _dtype_check(x_dtype, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The" msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if x_dtype not in [mstype.float32, mstype.float16]: if x_dtype not in [mstype.float32, mstype.float16]:
raise TypeError(f"{msg_prefix} x_dtype must be float32 or float16, but got {x_dtype}.") raise TypeError(
f"{msg_prefix} x_dtype must be float32 or float16, but got {x_dtype}.")
@constexpr @constexpr
@ -923,7 +929,8 @@ class Norm(Cell):
def __init__(self, axis=(), keep_dims=False): def __init__(self, axis=(), keep_dims=False):
"""Initialize Norm.""" """Initialize Norm."""
super(Norm, self).__init__() super(Norm, self).__init__()
Validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name) Validator.check_value_type(
"keep_dims", keep_dims, [bool], self.cls_name)
self.axis = axis self.axis = axis
self.keep_dims = keep_dims self.keep_dims = keep_dims
self.reduce_sum = P.ReduceSum(True) self.reduce_sum = P.ReduceSum(True)
@ -1209,7 +1216,8 @@ class Pad(Cell):
super(Pad, self).__init__() super(Pad, self).__init__()
self.mode = mode self.mode = mode
self.paddings = paddings self.paddings = paddings
Validator.check_string(self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name) Validator.check_string(
self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
if not isinstance(paddings, tuple): if not isinstance(paddings, tuple):
raise TypeError(f"For '{self.cls_name}', the type of 'paddings' must be tuple, " raise TypeError(f"For '{self.cls_name}', the type of 'paddings' must be tuple, "
f"but got {type(paddings).__name__}.") f"but got {type(paddings).__name__}.")
@ -1239,20 +1247,17 @@ def bilinear(shape, size, scale, align_corners, prim_name=None):
"""Check input and calculate shape""" """Check input and calculate shape"""
msg_prefix = f"For '{prim_name}', the" if prim_name else "The" msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if not isinstance(align_corners, bool): if not isinstance(align_corners, bool):
raise TypeError(f"{msg_prefix} type of 'align_corners' must be boolean, " raise TypeError(
f"but got {type(align_corners).__name__}.") f"{msg_prefix} type of 'align_corners' must be bool, but got {type(align_corners).__name__}.")
if size is None and scale is None: if size is None and scale is None:
raise ValueError(f"{msg_prefix} 'size' and 'scale' both none.") raise ValueError(f"{msg_prefix} 'size' and 'scale' both none.")
if size is not None and scale is not None: if size is not None and scale is not None:
raise ValueError(f"{msg_prefix} 'size' and 'scale' both not none.") raise ValueError(f"{msg_prefix} 'size' and 'scale' both not none.")
if size is not None: if size is not None:
if not isinstance(size, (tuple, list)): if not isinstance(size, (tuple, list)):
raise ValueError(f"{msg_prefix} 'size' must be tuple or list or None, but got {type(size).__name__}.") raise ValueError(
Validator.check_int(len(size), 2, Rel.EQ, "size", "bilinear") f"{msg_prefix} 'size' must be tuple or list or None, but got {type(size).__name__}.")
Validator.check_int(size[0], 1, Rel.GE, "size[0]", "bilinear")
Validator.check_int(size[1], 1, Rel.GE, "size[1]", "bilinear")
return size return size
Validator.check_int(scale, 1, Rel.GE, "scale factor", "bilinear")
ret = (scale * shape[2], scale * shape[3]) ret = (scale * shape[2], scale * shape[3])
return ret return ret
@ -1323,8 +1328,10 @@ class ResizeBilinear(Cell):
self.half_pixel_centers = half_pixel_centers self.half_pixel_centers = half_pixel_centers
def construct(self, x, size=None, scale_factor=None, align_corners=False): def construct(self, x, size=None, scale_factor=None, align_corners=False):
shape = bilinear(x.shape, size, scale_factor, align_corners, self.cls_name) shape = bilinear(x.shape, size, scale_factor,
resize_bilinear = P.ResizeBilinear(shape, align_corners, self.half_pixel_centers) align_corners, self.cls_name)
resize_bilinear = P.ResizeBilinear(
shape, align_corners, self.half_pixel_centers)
return resize_bilinear(x) return resize_bilinear(x)
@ -1390,7 +1397,8 @@ class Unfold(Cell):
super(Unfold, self).__init__() super(Unfold, self).__init__()
def _check_tuple_or_list(arg_name, arg_val, prim_name): def _check_tuple_or_list(arg_name, arg_val, prim_name):
Validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.cls_name) Validator.check_value_type(f"{arg_name}s", ksizes, [
tuple, list], self.cls_name)
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
raise ValueError(f"For '{prim_name}' the format of '{arg_name}s' must be [1, {arg_name}_row, " raise ValueError(f"For '{prim_name}' the format of '{arg_name}s' must be [1, {arg_name}_row, "
f"{arg_name}_col, 1], but got {arg_val}.") f"{arg_name}_col, 1], but got {arg_val}.")
@ -1405,19 +1413,17 @@ class Unfold(Cell):
ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2] ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2]
strides = strides[0], strides[3], strides[1], strides[2] strides = strides[0], strides[3], strides[1], strides[2]
rates = rates[0], rates[3], rates[1], rates[2] rates = rates[0], rates[3], rates[1], rates[2]
self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding) self.extract_image_patches = inner.ExtractImagePatches(
ksizes, strides, rates, padding)
def construct(self, input_x): def construct(self, input_x):
result = self.extract_image_patches(input_x) result = self.extract_image_patches(input_x)
return result return result
@constexpr
def tril(x_shape, x_dtype, k): def tril(x_shape, x_dtype, k):
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "tril") value = F.cast(P.Tril(diagonal=k)(F.ones(x_shape, x_dtype)), x_dtype)
Validator.check_is_int(k, "k value", "tril") return value
mask = np.tril(np.ones(x_shape), k)
return Tensor(mask, x_dtype)
class Tril(Cell): class Tril(Cell):
@ -1510,16 +1516,14 @@ class Tril(Cell):
def construct(self, x, k=0): def construct(self, x, k=0):
assist = tril(x.shape, self.dtype(x), k) assist = tril(x.shape, self.dtype(x), k)
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32)) result = self.mul(self.cast(x, mstype.float32),
self.cast(assist, mstype.float32))
return self.cast(result, self.dtype(x)) return self.cast(result, self.dtype(x))
@constexpr
def triu(x_shape, x_dtype, k): def triu(x_shape, x_dtype, k):
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "triu") value = F.cast(P.Triu(k)(F.ones(x_shape, x_dtype)), x_dtype)
Validator.check_is_int(k, "k value", "triu") return value
mask = np.triu(np.ones(x_shape), k)
return Tensor(mask, x_dtype)
class Triu(Cell): class Triu(Cell):
@ -1603,24 +1607,34 @@ class Triu(Cell):
def construct(self, x, k=0): def construct(self, x, k=0):
assist = triu(x.shape, self.dtype(x), k) assist = triu(x.shape, self.dtype(x), k)
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32)) result = self.mul(self.cast(x, mstype.float32),
self.cast(assist, mstype.float32))
return self.cast(result, self.dtype(x)) return self.cast(result, self.dtype(x))
@constexpr
def _get_matrix_diag_assist(x_shape, x_dtype): def _get_matrix_diag_assist(x_shape, x_dtype):
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist") """Get matrix diag assist"""
base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1) base_eye = F.reshape(
assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],)) F.eye(x_shape[-1], x_shape[-1], x_dtype), (x_shape[-1] * x_shape[-1],))
return Tensor(assist, x_dtype) if len(x_shape) == 1:
assist = F.reshape(base_eye, x_shape + (x_shape[-1],))
else:
assist = F.reshape(
F.tile(base_eye, x_shape[:-1]), x_shape + (x_shape[-1],))
value = F.cast(assist, x_dtype)
return value
@constexpr
def _get_matrix_diag_part_assist(x_shape, x_dtype): def _get_matrix_diag_part_assist(x_shape, x_dtype):
Validator.check_int(len(x_shape), 2, Rel.GE, "x rank", "_get_matrix_diag_part_assist") """Get matrix diag part assist"""
base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1) base_eye = F.reshape(
assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape) F.eye(x_shape[-2], x_shape[-1], x_dtype), (x_shape[-2] * x_shape[-1],))
return Tensor(assist, x_dtype) if len(x_shape) <= 2:
assist = F.reshape(base_eye, x_shape)
else:
assist = F.reshape(F.tile(base_eye, x_shape[:-2]), x_shape)
value = F.cast(assist, x_dtype)
return value
class MatrixDiag(Cell): class MatrixDiag(Cell):
@ -1867,8 +1881,10 @@ class Roll(Cell):
def __init__(self, shift, axis): def __init__(self, shift, axis):
"""Initialize Roll""" """Initialize Roll"""
super(Roll, self).__init__() super(Roll, self).__init__()
Validator.check_value_type("shift", shift, [int, tuple, list], self.cls_name) Validator.check_value_type(
Validator.check_value_type("axis", axis, [int, tuple, list], self.cls_name) "shift", shift, [int, tuple, list], self.cls_name)
Validator.check_value_type(
"axis", axis, [int, tuple, list], self.cls_name)
self.shape_op = P.Shape() self.shape_op = P.Shape()
self.shift = shift self.shift = shift
self.axis = axis self.axis = axis
@ -1894,14 +1910,16 @@ class Roll(Cell):
f"and the length of 'axis' {len(self.axis)}.") f"and the length of 'axis' {len(self.axis)}.")
else: else:
if not isinstance(self.axis, (list, tuple)): if not isinstance(self.axis, (list, tuple)):
self.op_list.append((P.Roll(shift=self.shift, axis=0), self.axis)) self.op_list.append(
(P.Roll(shift=self.shift, axis=0), self.axis))
else: else:
if len(self.shift) != len(self.axis): if len(self.shift) != len(self.axis):
raise ValueError(f"For '{self.cls_name}', the shape of 'shift' and the shape of 'axis' must be " raise ValueError(f"For '{self.cls_name}', the shape of 'shift' and the shape of 'axis' must be "
f"the same, but got the length of 'shift' {len(self.shift)} " f"the same, but got the length of 'shift' {len(self.shift)} "
f"and the length of 'axis' {len(self.axis)}.") f"and the length of 'axis' {len(self.axis)}.")
for idx, _ in enumerate(self.axis): for idx, _ in enumerate(self.axis):
self.op_list.append((P.Roll(shift=self.shift[idx], axis=0), self.axis[idx])) self.op_list.append(
(P.Roll(shift=self.shift[idx], axis=0), self.axis[idx]))
def construct(self, input_x): def construct(self, input_x):
dim = len(self.shape_op(input_x)) dim = len(self.shape_op(input_x))
@ -1965,7 +1983,8 @@ class Unflatten(Cell):
self.shape = P.Shape() self.shape = P.Shape()
self.reshape = P.Reshape() self.reshape = P.Reshape()
Validator.check_is_int(axis, 'axis', 'Unflatten') Validator.check_is_int(axis, 'axis', 'Unflatten')
Validator.check_value_type('unflattended_size', unflattened_size, (list, tuple), 'Unflatten') Validator.check_value_type(
'unflattended_size', unflattened_size, (list, tuple), 'Unflatten')
self.axis = axis self.axis = axis
if isinstance(unflattened_size, list): if isinstance(unflattened_size, list):
unflattened_size = tuple(unflattened_size) unflattened_size = tuple(unflattened_size)

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""channel shuffle""" """channel shuffle"""
from mindspore.ops.primitive import constexpr
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
@ -82,21 +81,9 @@ class ChannelShuffle(Cell):
self.reshape = P.Reshape() self.reshape = P.Reshape()
self.transpose = P.Transpose() self.transpose = P.Transpose()
@staticmethod
@constexpr
def _check_input_dim(shape, channels, groups, cls_name):
dim = len(shape)
if dim < 3:
raise ValueError(f"For {cls_name}, the in_shape must have more than 2 dims, but got {dim}.")
if channels % groups != 0:
raise ValueError(f"For {cls_name}, number of channels must be divisible by groups, "
f"but got {channels} channels and {groups} groups.")
def construct(self, x): def construct(self, x):
x_shape = self.shape(x) x_shape = self.shape(x)
n, c = x_shape[0], x_shape[1] n, c = x_shape[0], x_shape[1]
self._check_input_dim(x_shape, c, self.groups, self.cls_name)
out = self.reshape(x, (n, self.groups, c // self.groups, -1)) out = self.reshape(x, (n, self.groups, c // self.groups, -1))
out = self.transpose(out, (0, 2, 1, 3)) out = self.transpose(out, (0, 2, 1, 3))
return self.reshape(out, x_shape) return self.reshape(out, x_shape)

View File

@ -315,12 +315,6 @@ class Conv2d(_Conv):
return output return output
@constexpr
def _check_input_3d(input_shape, op_name):
if len(input_shape) != 3:
raise ValueError(f"For '{op_name}', the dimension of input must be 3d, but got {len(input_shape)}.")
class Conv1d(_Conv): class Conv1d(_Conv):
r""" r"""
Calculates the 1D convolution on the input tensor. The input is typically of shape :math:`(N, C_{in}, L_{in})`, Calculates the 1D convolution on the input tensor. The input is typically of shape :math:`(N, C_{in}, L_{in})`,
@ -482,11 +476,8 @@ class Conv1d(_Conv):
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
self.expand_dims = P.ExpandDims() self.expand_dims = P.ExpandDims()
self.squeeze = P.Squeeze(2) self.squeeze = P.Squeeze(2)
self.shape = P.Shape()
def construct(self, x): def construct(self, x):
x_shape = self.shape(x)
_check_input_3d(x_shape, self.cls_name)
x = self.expand_dims(x, 2) x = self.expand_dims(x, 2)
output = self.conv2d(x, self.weight) output = self.conv2d(x, self.weight)
if self.has_bias: if self.has_bias:
@ -1289,8 +1280,6 @@ class Conv1dTranspose(_Conv):
return self return self
def construct(self, x): def construct(self, x):
x_shape = self.shape(x)
_check_input_3d(x_shape, self.cls_name)
x = self.expand_dims(x, 2) x = self.expand_dims(x, 2)
n, _, h, w = self.shape(x) n, _, h, w = self.shape(x)

View File

@ -30,14 +30,6 @@ from mindspore.nn.cell import Cell
__all__ = ['BiDense'] __all__ = ['BiDense']
@constexpr
def check_dense_inputs_same_shape(input1, input2, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if input1[:-1] != input2[:-1]:
raise ValueError(f"{msg_prefix} dimensions except the last of 'input1' must be same as 'input2', but got "
f"{input1} of 'input1' and {input2} of 'input2'")
@constexpr(check=False) @constexpr(check=False)
def _check_is_tensor(param_name, input_data, cls_name): def _check_is_tensor(param_name, input_data, cls_name):
"""Internal function, used to check whether the input data is Tensor.""" """Internal function, used to check whether the input data is Tensor."""
@ -46,14 +38,6 @@ def _check_is_tensor(param_name, input_data, cls_name):
f"but got '{P.typeof(input_data)}'") f"but got '{P.typeof(input_data)}'")
@constexpr
def check_last_dimension(input_dim, input_channels, input_name, input_channels_name, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if input_dim != input_channels:
raise ValueError(f"{msg_prefix} last dimension of '{input_name}' must be same as '{input_channels_name}',"
f" but got {input_dim} of '{input_name}' and {input_channels} of '{input_channels_name}'")
class BiDense(Cell): class BiDense(Cell):
r""" r"""
The bilinear dense connected layer. The bilinear dense connected layer.
@ -171,9 +155,6 @@ class BiDense(Cell):
_check_is_tensor("input2", input2, self.cls_name) _check_is_tensor("input2", input2, self.cls_name)
input1_shape = input1.shape input1_shape = input1.shape
input2_shape = input2.shape input2_shape = input2.shape
check_last_dimension(input1_shape[-1], self.in1_channels, "input1", "in1_channels", self.cls_name)
check_last_dimension(input2_shape[-1], self.in2_channels, "input2", "in2_channels", self.cls_name)
check_dense_inputs_same_shape(input1_shape, input2_shape, self.cls_name)
if len(input1_shape) != 2: if len(input1_shape) != 2:
input1 = input1.reshape((-1, input1_shape[-1])) input1 = input1.reshape((-1, input1_shape[-1]))
input2 = input2.reshape((-1, input2_shape[-1])) input2 = input2.reshape((-1, input2_shape[-1]))

View File

@ -39,13 +39,6 @@ from mindspore.nn.cell import Cell
__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup'] __all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']
@constexpr
def _check_input_2d(input_shape, param_name, func_name):
if len(input_shape) != 2:
raise ValueError(f"For '{func_name}', the dimension of '{param_name}' must be 2d, but got {len(input_shape)}")
return True
@constexpr @constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name): def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name) validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
@ -623,10 +616,6 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
self.negative_inf_value = -3.402823466E+38 self.negative_inf_value = -3.402823466E+38
def construct(self, input_indices, input_values, field_ids): def construct(self, input_indices, input_values, field_ids):
_check_input_2d(F.shape(input_indices), "input_indices", self.cls_name)
_check_input_2d(F.shape(input_values), "input_values", self.cls_name)
_check_input_2d(F.shape(field_ids), "field_ids", self.cls_name)
_check_input_dtype(F.dtype(input_indices), "input_indices", [mstype.int32, mstype.int64], self.cls_name) _check_input_dtype(F.dtype(input_indices), "input_indices", [mstype.int32, mstype.int64], self.cls_name)
_check_input_dtype(F.dtype(input_values), "input_values", [mstype.float32], self.cls_name) _check_input_dtype(F.dtype(input_values), "input_values", [mstype.float32], self.cls_name)
_check_input_dtype(F.dtype(field_ids), "field_ids", [mstype.int32], self.cls_name) _check_input_dtype(F.dtype(field_ids), "field_ids", [mstype.int32], self.cls_name)

View File

@ -78,7 +78,6 @@ class ImageGradients(Cell):
super(ImageGradients, self).__init__() super(ImageGradients, self).__init__()
def construct(self, images): def construct(self, images):
check = _check_input_4d(F.shape(images), "images", self.cls_name)
images = F.depend(images, check) images = F.depend(images, check)
batch_size, depth, height, width = P.Shape()(images) batch_size, depth, height, width = P.Shape()(images)
if height == 1: if height == 1:
@ -120,21 +119,6 @@ def _get_dtype_max(dtype):
return dtype_max return dtype_max
@constexpr
def _check_input_4d(input_shape, param_name, func_name):
if len(input_shape) != 4:
raise ValueError(f"For '{func_name}', the dimension of '{param_name}' must be 4d, "
f"but got {len(input_shape)}.")
return True
@constexpr
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
_check_input_4d(input_shape, param_name, func_name)
validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name)
validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name)
@constexpr @constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name): def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name) validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
@ -281,7 +265,6 @@ class SSIM(Cell):
def construct(self, img1, img2): def construct(self, img1, img2):
_check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name) _check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
_check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
inner.SameTypeShape()(img1, img2) inner.SameTypeShape()(img1, img2)
dtype_max_val = _get_dtype_max(F.dtype(img1)) dtype_max_val = _get_dtype_max(F.dtype(img1))
max_val = F.scalar_cast(self.max_val, F.dtype(img1)) max_val = F.scalar_cast(self.max_val, F.dtype(img1))
@ -387,8 +370,6 @@ class MSSSIM(Cell):
self.concat = P.Concat(axis=1) self.concat = P.Concat(axis=1)
def construct(self, img1, img2): def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", self.cls_name)
_check_input_4d(F.shape(img2), "img2", self.cls_name)
valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8] valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8]
_check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name) _check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name)
inner.SameTypeShape()(img1, img2) inner.SameTypeShape()(img1, img2)
@ -466,8 +447,6 @@ class PSNR(Cell):
self.max_val = max_val self.max_val = max_val
def construct(self, img1, img2): def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", self.cls_name)
_check_input_4d(F.shape(img2), "img2", self.cls_name)
inner.SameTypeShape()(img1, img2) inner.SameTypeShape()(img1, img2)
dtype_max_val = _get_dtype_max(F.dtype(img1)) dtype_max_val = _get_dtype_max(F.dtype(img1))
max_val = F.scalar_cast(self.max_val, F.dtype(img1)) max_val = F.scalar_cast(self.max_val, F.dtype(img1))
@ -481,22 +460,17 @@ class PSNR(Cell):
return psnr return psnr
@constexpr
def _raise_dims_rank_error(input_shape, param_name, func_name):
"""raise error if input is not 3d or 4d"""
raise ValueError(f"{func_name} {param_name} must be 3d or 4d, but got shape {input_shape}")
@constexpr @constexpr
def _get_bbox(rank, shape, central_fraction): def _get_bbox(rank, shape, central_fraction):
"""get bbox start and size for slice""" """get bbox start and size for slice"""
n, c, h, w = -1, -1, -1, -1
if rank == 3: if rank == 3:
c, h, w = shape c, h, w = shape
else: else:
n, c, h, w = shape n, c, h, w = shape
bbox_h_start = int((float(h) - np.float32(h * central_fraction)) / 2) bbox_h_start = int((float(h) - float(h * central_fraction)) / 2)
bbox_w_start = int((float(w) - np.float32(w * central_fraction)) / 2) bbox_w_start = int((float(w) - float(w * central_fraction)) / 2)
bbox_h_size = h - bbox_h_start * 2 bbox_h_size = h - bbox_h_start * 2
bbox_w_size = w - bbox_w_start * 2 bbox_w_size = w - bbox_w_start * 2
@ -548,8 +522,6 @@ class CentralCrop(Cell):
def construct(self, image): def construct(self, image):
image_shape = F.shape(image) image_shape = F.shape(image)
rank = len(image_shape) rank = len(image_shape)
if rank not in (3, 4):
return _raise_dims_rank_error(image_shape, "image", self.cls_name)
if self.central_fraction == 1.0: if self.central_fraction == 1.0:
return image return image

View File

@ -763,11 +763,6 @@ class LBeta(Cell):
@constexpr @constexpr
def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None): def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
"""get broadcast_matmul shape""" """get broadcast_matmul shape"""
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if (len(x_shape) < 2) or (len(y_shape) < 2):
raise ValueError(f"{msg_prefix} length of 'x_shape' and 'y_shape' must be equal to or greater than 2, "
f"but got the length of 'x_shape': {len(x_shape)} and the length of 'y_shape': "
f"{len(y_shape)}.")
x_shape_batch = x_shape[:-2] x_shape_batch = x_shape[:-2]
y_shape_batch = y_shape[:-2] y_shape_batch = y_shape[:-2]
if x_shape_batch == y_shape_batch: if x_shape_batch == y_shape_batch:
@ -783,10 +778,6 @@ def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
broadcast_shape_back.append(x_shape[i]) broadcast_shape_back.append(x_shape[i])
elif x_shape[i] == y_shape[i]: elif x_shape[i] == y_shape[i]:
broadcast_shape_back.append(x_shape[i]) broadcast_shape_back.append(x_shape[i])
else:
raise ValueError(f"{msg_prefix} 'x_shape[{i}]' must be equal to 1, or the 'y_shape[{i}]' must be equal "
f"to 1, or the 'x_shape[{i}]' must be equal to 'y_shape[{i}]', but got "
f"'x_shape[{i}]': {x_shape[i]}, 'y_shape[{i}]': {y_shape[i]}.")
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
x_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + x_shape[-2:] x_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + x_shape[-2:]
@ -794,25 +785,6 @@ def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
return x_broadcast_shape, y_broadcast_shape return x_broadcast_shape, y_broadcast_shape
@constexpr
def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2, prim_name=None):
"""check col and row equal"""
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if len(x1_shape) == 1:
transpose_x1 = False
x1_shape = (1,) + x1_shape
if len(x2_shape) == 1:
transpose_x2 = False
x2_shape = x2_shape + (1,)
x1_last = x1_shape[-2:]
x2_last = x2_shape[-2:]
x1_col = x1_last[not transpose_x1] # x1_col = x1_last[1] if (not transpose_a) else x1_last[0]
x2_row = x2_last[transpose_x2] # x2_row = x2_last[0] if (not transpose_b) else x2_last[1]
if x1_col != x2_row:
raise ValueError(f"{msg_prefix} column of matrix dimensions of 'x1' must be equal to "
f"the row of matrix dimensions of 'x2', but got 'x1_col' {x1_col} and 'x2_row' {x2_row}.")
def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2): def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2):
"""select matmul op""" """select matmul op"""
x1_dim, x2_dim = len(x1_shape), len(x2_shape) x1_dim, x2_dim = len(x1_shape), len(x2_shape)
@ -857,7 +829,6 @@ class MatMul(Cell):
def construct(self, x1, x2): def construct(self, x1, x2):
x1_shape = self.shape_op(x1) x1_shape = self.shape_op(x1)
x2_shape = self.shape_op(x2) x2_shape = self.shape_op(x2)
check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2, self.cls_name)
matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2) matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
x1_dim, x2_dim = len(x1_shape), len(x2_shape) x1_dim, x2_dim = len(x1_shape), len(x2_shape)

View File

@ -121,13 +121,8 @@ class _BatchNorm(Cell):
self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy) self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy) self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
raise NotImplementedError
def construct(self, x): def construct(self, x):
self._check_input_dim(self.shape(x), self.cls_name)
if self.use_batch_statistics is None: if self.use_batch_statistics is None:
if self.training: if self.training:
return self.bn_train(x, return self.bn_train(x,
@ -227,13 +222,6 @@ class BatchNorm1d(_BatchNorm):
[ 0.4999975 0.399998 0.59999704 0.89999545 ]] [ 0.4999975 0.399998 0.59999704 0.89999545 ]]
""" """
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 2:
raise ValueError(f"For '{cls_name}', the in_shape must have 2 dims, but got {dim}.")
class BatchNorm2d(_BatchNorm): class BatchNorm2d(_BatchNorm):
r""" r"""
@ -319,13 +307,6 @@ class BatchNorm2d(_BatchNorm):
[ 0.999995 0.999995 ]]]] [ 0.999995 0.999995 ]]]]
""" """
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 4:
raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
class BatchNorm3d(Cell): class BatchNorm3d(Cell):
r""" r"""
@ -413,16 +394,9 @@ class BatchNorm3d(Cell):
self.shape = P.Shape() self.shape = P.Shape()
self.reshape = P.Reshape() self.reshape = P.Reshape()
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 5:
raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")
def construct(self, x): def construct(self, x):
x_shape = self.shape(x) x_shape = self.shape(x)
self._check_input_dim(x_shape, self.cls_name)
x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4])) x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
bn2d_out = self.bn2d(x) bn2d_out = self.bn2d(x)
bn3d_out = self.reshape(bn2d_out, x_shape) bn3d_out = self.reshape(bn2d_out, x_shape)
@ -586,12 +560,6 @@ class SyncBatchNorm(_BatchNorm):
SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i]) management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim not in (2, 4):
raise ValueError(f"For '{cls_name}', the must have 2 dims or 4 dims, but got {dim}.")
def _check_rank_ids(self, process_groups, rank_size): def _check_rank_ids(self, process_groups, rank_size):
seen = set() seen = set()
@ -729,7 +697,6 @@ class _InstanceNorm(Cell):
self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum) self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)
def construct(self, x): def construct(self, x):
self._check_input_dim(self.shape(x), self.cls_name)
return self.instance_bn(x, return self.instance_bn(x,
self.gamma, self.gamma,
self.beta, self.beta,
@ -822,13 +789,6 @@ class InstanceNorm1d(_InstanceNorm):
(2, 3, 5) (2, 3, 5)
""" """
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 3:
raise ValueError(f"For '{cls_name}', the in_shape must have 3 dims, but got {dim}.")
class InstanceNorm2d(_InstanceNorm): class InstanceNorm2d(_InstanceNorm):
r""" r"""
@ -901,13 +861,6 @@ class InstanceNorm2d(_InstanceNorm):
(2, 3, 2, 2) (2, 3, 2, 2)
""" """
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 4:
raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
class InstanceNorm3d(_InstanceNorm): class InstanceNorm3d(_InstanceNorm):
r""" r"""
@ -979,12 +932,6 @@ class InstanceNorm3d(_InstanceNorm):
>>> print(output.shape) >>> print(output.shape)
(2, 3, 5, 2, 2) (2, 3, 5, 2, 2)
""" """
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 5:
raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")
class GroupNorm(Cell): class GroupNorm(Cell):
@ -1068,7 +1015,6 @@ class GroupNorm(Cell):
def _cal_output(self, x): def _cal_output(self, x):
"""calculate groupnorm output""" """calculate groupnorm output"""
batch, channel, height, width = self.shape(x) batch, channel, height, width = self.shape(x)
self._channel_check(channel, self.num_channels, self.cls_name)
x = self.reshape(x, (batch, self.num_groups, -1)) x = self.reshape(x, (batch, self.num_groups, -1))
mean = self.reduce_mean(x, 2) mean = self.reduce_mean(x, 2)
var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups) var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)
@ -1078,21 +1024,6 @@ class GroupNorm(Cell):
output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1)) output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1))
return output return output
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim != 4:
raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
@staticmethod
@constexpr
def _channel_check(channel, num_channel, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if channel != num_channel:
raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to num_channels, "
f"but got channel: {channel}, num_channels: {num_channel}.")
@staticmethod @staticmethod
@constexpr @constexpr
def _check_dtype(dtype, valid_dtypes, prim_name=None): def _check_dtype(dtype, valid_dtypes, prim_name=None):
@ -1102,7 +1033,6 @@ class GroupNorm(Cell):
return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels) return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
def construct(self, x): def construct(self, x):
self._check_input_dim(self.shape(x), self.cls_name)
self._check_dtype(x.dtype, [mstype.float16, mstype.float32], self.cls_name) self._check_dtype(x.dtype, [mstype.float16, mstype.float32], self.cls_name)
output = self._cal_output(x) output = self._cal_output(x)
return output return output

View File

@ -31,10 +31,7 @@ def _check_padding_dimension(dimension, padding):
Validate the input padding and add placeholders if needed. Validate the input padding and add placeholders if needed.
Note: the input 'padding' in this function is already converted to list of lists to match MirrorPad Note: the input 'padding' in this function is already converted to list of lists to match MirrorPad
""" """
if dimension < len(padding): # add place holders
raise ValueError(f"For padding with length {len(padding) * 2}, the dimension of the tensor should be at least "
f"{len(padding)}, but got {dimension}")
# add placeholders
if dimension > len(padding): if dimension > len(padding):
padding = [(0, 0) for _ in range(dimension - len(padding))] + [x for x in padding] padding = [(0, 0) for _ in range(dimension - len(padding))] + [x for x in padding]
return padding return padding
@ -56,45 +53,16 @@ def _swap_to_ms_padding_order(padding):
@constexpr @constexpr
def _check(input_shape, padding, name): def _check(input_shape, padding):
""" """
Check relationship between input shape and padding to make sure after negative dimension padding the out is Check relationship between input shape and padding to make sure after negative dimension padding the out is
positive. positive.
""" """
if len(input_shape) < len(padding):
msg = "For '{}', the dimension of input must more than or equal to len(padding)/2, " \
"but got {}".format(name, len(input_shape))
raise ValueError(msg)
if len(input_shape) > len(padding): if len(input_shape) > len(padding):
if len(padding) == 2 and isinstance(padding[0], int): if len(padding) == 2 and isinstance(padding[0], int):
padding = [(0, 0) for i in range(len(input_shape) - 1)] + [padding] padding = [(0, 0) for i in range(len(input_shape) - 1)] + [padding]
else: else:
padding = [(0, 0) for i in range(len(input_shape) - len(padding))] + [x for x in padding] padding = [(0, 0) for i in range(len(input_shape) - len(padding))] + [x for x in padding]
for index, item in enumerate(padding):
if index == 0:
dim_name = '1st'
elif index == 1:
dim_name = '2nd'
elif index == 2:
dim_name = '3rd'
else:
dim_name = str(index + 1) + 'th'
if item[0] < -input_shape[index]:
msg = "For '{}', the shape of input after padding must be positive, the input shape is {}, " \
"value of parameter 'padding' applied to the {} dimension of input must " \
"no less than -{}, but got {}".format(name, input_shape, dim_name, input_shape[index], item[0])
raise ValueError(msg)
if item[1] < -input_shape[index]:
msg = "For '{}', the shape of input after padding must be positive, the input shape is {}, " \
"value of parameter 'padding' applied to the {} dimension of input must " \
"no less than -{}, but got {}".format(name, input_shape, dim_name, input_shape[index], item[1])
raise ValueError(msg)
if input_shape[index] + item[0] + item[1] <= 0:
msg = "For '{}', the shape of input after padding must be positive, the input shape is {}, " \
"but the {} dimension of input shape {} plus padding {} and {} resulted in a non-positive output " \
"shape.".format(name, input_shape, dim_name, input_shape[index], item[0], item[1])
raise ValueError(msg)
return padding return padding
@ -199,7 +167,7 @@ class _ConstantPadNd(Cell):
"""Construct the pad net.""" """Construct the pad net."""
input_shape = x.shape input_shape = x.shape
input_type = x.dtype input_type = x.dtype
padding = _check(input_shape, self.padding, self._name) padding = _check(input_shape, self.padding)
new_padding, start, end = _get_new_padding(padding) new_padding, start, end = _get_new_padding(padding)
mask = ops.Ones()(input_shape, input_type) mask = ops.Ones()(input_shape, input_type)
output = ops.Pad(new_padding)(x) output = ops.Pad(new_padding)(x)
@ -671,10 +639,6 @@ class _ReplicationPadNd(Cell):
self.padding = padding self.padding = padding
self.padv3 = nn_ops.PadV3(mode="edge") self.padv3 = nn_ops.PadV3(mode="edge")
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
raise NotImplementedError
@staticmethod @staticmethod
@constexpr @constexpr
@ -682,7 +646,6 @@ class _ReplicationPadNd(Cell):
raise NotImplementedError raise NotImplementedError
def construct(self, x): def construct(self, x):
self._check_input_dim(x.shape, self.name)
need_expend_dims = self._need_expend_dim(x) need_expend_dims = self._need_expend_dim(x)
if need_expend_dims: if need_expend_dims:
x = x.expand_dims(0) x = x.expand_dims(0)
@ -743,12 +706,6 @@ class ReplicationPad1d(_ReplicationPadNd):
padding = (padding, padding) padding = (padding, padding)
super(ReplicationPad1d, self).__init__(padding, name="ReplicationPad1d") super(ReplicationPad1d, self).__init__(padding, name="ReplicationPad1d")
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim not in (2, 3):
raise ValueError(f"For '{cls_name}', the in_shape must have 2 or 3 dims, but got {dim}.")
def _need_expend_dim(self, x): def _need_expend_dim(self, x):
input_shape = x.shape input_shape = x.shape
@ -814,12 +771,6 @@ class ReplicationPad2d(_ReplicationPadNd):
padding = (padding, padding, padding, padding) padding = (padding, padding, padding, padding)
super(ReplicationPad2d, self).__init__(padding, name="ReplicationPad2d") super(ReplicationPad2d, self).__init__(padding, name="ReplicationPad2d")
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim not in (3, 4):
raise ValueError(f"For '{cls_name}', the in_shape must have 3 or 4 dims, but got {dim}.")
def _need_expend_dim(self, x): def _need_expend_dim(self, x):
input_shape = x.shape input_shape = x.shape
@ -885,12 +836,6 @@ class ReplicationPad3d(_ReplicationPadNd):
padding = (padding, padding, padding, padding, padding, padding) padding = (padding, padding, padding, padding, padding, padding)
super(ReplicationPad3d, self).__init__(padding, name="ReplicationPad3d") super(ReplicationPad3d, self).__init__(padding, name="ReplicationPad3d")
@staticmethod
@constexpr
def _check_input_dim(shape, cls_name):
dim = len(shape)
if dim not in (4, 5):
raise ValueError(f"For '{cls_name}', the in_shape must have 4 or 5 dims, but got {dim}.")
def _need_expend_dim(self, x): def _need_expend_dim(self, x):
input_shape = x.shape input_shape = x.shape

View File

@ -73,13 +73,6 @@ class _PoolNd(Cell):
return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__) return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
@constexpr
def _shape_check(in_shape, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if len(in_shape) != 3:
raise ValueError(f"{msg_prefix} input must has 3 dim, but got {len(in_shape)}")
class LPPool1d(Cell): class LPPool1d(Cell):
r""" r"""
Applies a 1D power lp pooling over an input signal composed of several input planes. Applies a 1D power lp pooling over an input signal composed of several input planes.
@ -489,7 +482,6 @@ class MaxPool1d(_PoolNd):
self.squeeze = P.Squeeze(2) self.squeeze = P.Squeeze(2)
def construct(self, x): def construct(self, x):
_shape_check(self.shape(x), self.cls_name)
x = self.expand(x, 2) x = self.expand(x, 2)
output = self.max_pool(x) output = self.max_pool(x)
output = self.squeeze(output) output = self.squeeze(output)
@ -743,7 +735,6 @@ class AvgPool1d(_PoolNd):
self.squeeze = P.Squeeze(2) self.squeeze = P.Squeeze(2)
def construct(self, x): def construct(self, x):
x = F.depend(x, _shape_check(self.shape(x), self.cls_name))
batch, channel, width = self.shape(x) batch, channel, width = self.shape(x)
if width == self.kernel_size[1]: if width == self.kernel_size[1]:
x = self.reduce_mean(x, 2) x = self.reduce_mean(x, 2)
@ -757,20 +748,6 @@ class AvgPool1d(_PoolNd):
return x return x
@constexpr
def _adaptive_shape_check(in_shape, output_size, prim_name):
"""Check shape."""
msg_prefix = "For {}, the".format(prim_name)
if len(in_shape) != 3:
raise ValueError("{} input must has 3 dim, but got {}.".format(msg_prefix, len(in_shape)))
if in_shape[2] < output_size:
raise ValueError("{} input's last dimension must be greater or equal to "
"output size {}, but got {}.".format(msg_prefix, output_size, in_shape[2]))
if in_shape[2] % output_size != 0:
raise ValueError("{} input's last dimension must be divisible by "
"output size {}, but got {}.".format(msg_prefix, output_size, in_shape[2]))
@constexpr @constexpr
def _adaptive_dtype_check(x_dtype, prim_name): def _adaptive_dtype_check(x_dtype, prim_name):
"""Check dtype.""" """Check dtype."""
@ -837,7 +814,6 @@ class AdaptiveAvgPool1d(Cell):
self.dtype = P.DType() self.dtype = P.DType()
def construct(self, x): def construct(self, x):
_adaptive_shape_check(self.shape(x), self.output_size, self.cls_name)
_adaptive_dtype_check(self.dtype(x), self.cls_name) _adaptive_dtype_check(self.dtype(x), self.cls_name)
_, _, width = self.shape(x) _, _, width = self.shape(x)
@ -1052,7 +1028,6 @@ class AdaptiveMaxPool1d(Cell):
self.dtype = P.DType() self.dtype = P.DType()
def construct(self, x): def construct(self, x):
_adaptive_shape_check(self.shape(x), self.output_size, self.cls_name)
_adaptive_dtype_check(self.dtype(x), self.cls_name) _adaptive_dtype_check(self.dtype(x), self.cls_name)
_, _, width = self.shape(x) _, _, width = self.shape(x)

View File

@ -22,6 +22,7 @@ import mindspore.nn as nn
import mindspore.ops as P import mindspore.ops as P
import mindspore.context as context import mindspore.context as context
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common.parameter import ParameterTuple, Parameter from mindspore.common.parameter import ParameterTuple, Parameter
@ -446,6 +447,17 @@ class _RNNBase(Cell):
self.b_ih_list = ParameterTuple(self.b_ih_list) self.b_ih_list = ParameterTuple(self.b_ih_list)
self.b_hh_list = ParameterTuple(self.b_hh_list) self.b_hh_list = ParameterTuple(self.b_hh_list)
# TODO: remove this func
def _shape_dynamic(self, shape):
"""use this func for dynamic. del it when ShapeOp is supported"""
x = []
for i in shape:
if not F.isconstant(i):
x.append(-1)
else:
x.append(i)
return tuple(x)
def _stacked_bi_dynamic_rnn(self, x, h, seq_length): def _stacked_bi_dynamic_rnn(self, x, h, seq_length):
"""stacked bidirectional dynamic_rnn""" """stacked bidirectional dynamic_rnn"""
pre_layer = x pre_layer = x
@ -491,9 +503,11 @@ class _RNNBase(Cell):
if self.is_lstm: if self.is_lstm:
h_n = P.Concat(0)(h_n) h_n = P.Concat(0)(h_n)
c_n = P.Concat(0)(c_n) c_n = P.Concat(0)(c_n)
h_n = h_n.view(h[0].shape) h0_shape = self._shape_dynamic(h[0].shape)
c_n = c_n.view(h[1].shape) h1_shape = self._shape_dynamic(h[1].shape)
return output, (h_n.view(h[0].shape), c_n.view(h[1].shape)) h_n = h_n.view(h0_shape)
c_n = c_n.view(h1_shape)
return output, (h_n.view(h0_shape), c_n.view(h1_shape))
h_n = P.Concat(0)(h_n) h_n = P.Concat(0)(h_n)
return output, h_n.view(h.shape) return output, h_n.view(h.shape)
@ -523,9 +537,11 @@ class _RNNBase(Cell):
if self.is_lstm: if self.is_lstm:
h_n = P.Concat(0)(h_n) h_n = P.Concat(0)(h_n)
c_n = P.Concat(0)(c_n) c_n = P.Concat(0)(c_n)
h_n = h_n.view(h[0].shape) h0_shape = self._shape_dynamic(h[0].shape)
c_n = c_n.view(h[1].shape) h1_shape = self._shape_dynamic(h[1].shape)
return output, (h_n.view(h[0].shape), c_n.view(h[1].shape)) h_n = h_n.view(h0_shape)
c_n = c_n.view(h1_shape)
return output, (h_n.view(h0_shape), c_n.view(h1_shape))
h_n = P.Concat(0)(h_n) h_n = P.Concat(0)(h_n)
return output, h_n.view(h.shape) return output, h_n.view(h.shape)

View File

@ -33,9 +33,9 @@ from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context, \
_set_rank_id, _insert_hash_table_size, _set_cache_enable _set_rank_id, _insert_hash_table_size, _set_cache_enable
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.nn.layer.basic import ClipByNorm from mindspore.nn.layer.basic import ClipByNorm
from mindspore.ops.primitive import constexpr
__all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor', 'EmbeddingLookupThor'] __all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor', 'EmbeddingLookupThor']

View File

@ -24,25 +24,6 @@ from mindspore.nn.cell import Cell
__all__ = ['TimeDistributed'] __all__ = ['TimeDistributed']
@constexpr
def _check_reshape_pos(reshape_pos, inputs_shape, outputs_shape, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if reshape_pos >= len(outputs_shape) or inputs_shape[reshape_pos] != outputs_shape[reshape_pos]:
raise ValueError(f"{msg_prefix} 'reshape_with_axis' is invalid in the input and output. "
f"The 'reshape_pos' must be less than the length of 'outputs_shape', and the "
f"'inputs_shape[reshape_pos]' must be equal to 'outputs_shape[reshape_pos]', but got "
f"'reshape_pos': {reshape_pos}, 'inputs_shape': {inputs_shape}, 'outputs_shape': "
f"{outputs_shape}. You may try pass parameters without 'reshape_with_axis'.")
@constexpr
def _check_expand_dims_axis(time_axis, ndim, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if time_axis > ndim:
raise ValueError(f"{msg_prefix} value of 'time_axis' must be in range of [{-ndim - 1}, {ndim}], "
f"but got {time_axis}.")
@constexpr @constexpr
def _generate_perm(axis_a, axis_b, length): def _generate_perm(axis_a, axis_b, length):
perm = tuple(range(length)) perm = tuple(range(length))
@ -57,13 +38,6 @@ def _check_data(flag, prim_name=None):
raise TypeError(f"{msg_prefix} inputs and outputs must be a Tensor.") raise TypeError(f"{msg_prefix} inputs and outputs must be a Tensor.")
@constexpr
def _check_inputs_dim(shape, prim_name=None):
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
if len(shape) < 3:
raise ValueError(f"{msg_prefix} inputs shape must be at least 3D, but got {len(shape)}.")
class TimeDistributed(Cell): class TimeDistributed(Cell):
r""" r"""
The time distributed layer. The time distributed layer.
@ -119,7 +93,6 @@ class TimeDistributed(Cell):
def construct(self, inputs): def construct(self, inputs):
_check_data(isinstance(inputs, Tensor), self.cls_name) _check_data(isinstance(inputs, Tensor), self.cls_name)
_check_inputs_dim(inputs.shape, self.cls_name)
time_axis = self.time_axis % len(inputs.shape) time_axis = self.time_axis % len(inputs.shape)
if self.reshape_with_axis is not None: if self.reshape_with_axis is not None:
reshape_with_axis = self.reshape_with_axis % len(inputs.shape) reshape_with_axis = self.reshape_with_axis % len(inputs.shape)
@ -134,7 +107,6 @@ class TimeDistributed(Cell):
inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:]) inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:])
outputs = self.layer(inputs) outputs = self.layer(inputs)
_check_data(isinstance(outputs, Tensor), self.cls_name) _check_data(isinstance(outputs, Tensor), self.cls_name)
_check_reshape_pos(reshape_pos, inputs.shape, outputs.shape, self.cls_name)
outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2] outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2]
if reshape_pos + 1 < len(outputs.shape): if reshape_pos + 1 < len(outputs.shape):
outputs_shape_new += outputs.shape[reshape_pos + 1:] outputs_shape_new += outputs.shape[reshape_pos + 1:]
@ -147,7 +119,6 @@ class TimeDistributed(Cell):
for item in inputs: for item in inputs:
outputs = self.layer(item) outputs = self.layer(item)
_check_data(isinstance(outputs, Tensor), self.cls_name) _check_data(isinstance(outputs, Tensor), self.cls_name)
_check_expand_dims_axis(time_axis, outputs.ndim, self.cls_name)
y += (outputs,) y += (outputs,)
y = Stack(time_axis)(y) y = Stack(time_axis)(y)
return y return y

View File

@ -809,7 +809,6 @@ class DiceLoss(LossBase):
def construct(self, logits, label): def construct(self, logits, label):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', label, self.cls_name) _check_is_tensor('labels', label, self.cls_name)
_check_shape(logits.shape, label.shape, self.cls_name)
if logits.dtype == mstype.uint8: if logits.dtype == mstype.uint8:
raise TypeError(f"For '{self.cls_name}', the dtype of 'logits' can not be uint8.") raise TypeError(f"For '{self.cls_name}', the dtype of 'logits' can not be uint8.")
if label.dtype == mstype.uint8: if label.dtype == mstype.uint8:
@ -824,31 +823,6 @@ class DiceLoss(LossBase):
return dice_loss return dice_loss
@constexpr
def _check_shape(logits_shape, label_shape, prim_name=None):
"""Internal function, used to check whether the shape of logits and labels meets the requirements."""
validator.check('logits_shape', logits_shape, 'label_shape', label_shape, prim_name=prim_name)
@constexpr
def _check_ndim_multi(logits_dim, label_dim, prim_name=None):
"""Internal function, used to check whether the dimension of logits and label meets the requirements."""
msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
if logits_dim < 2:
raise ValueError(f"{msg_prefix} 'logits' dimension must be greater than 1, but got {logits_dim}.")
if label_dim < 2:
raise ValueError(f"{msg_prefix} 'labels' dimension must be greater than 1, but got {label_dim}.")
@constexpr
def _check_weights(weight_shape, label_shape, prim_name=None):
"""Internal function, used to check whether the reduced shape meets the requirements."""
msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
if weight_shape != label_shape:
raise ValueError(f"{msg_prefix} weight_shape[0] must be equal to label_shape[1], "
f"but got weight_shape[0]: {weight_shape} and label_shape[1]: {label_shape}.")
class MultiClassDiceLoss(LossBase): class MultiClassDiceLoss(LossBase):
r""" r"""
When there are multiple classifications, label is transformed into multiple binary classifications by one hot. When there are multiple classifications, label is transformed into multiple binary classifications by one hot.
@ -919,8 +893,6 @@ class MultiClassDiceLoss(LossBase):
def construct(self, logits, label): def construct(self, logits, label):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', label, self.cls_name) _check_is_tensor('labels', label, self.cls_name)
_check_shape(logits.shape, label.shape, self.cls_name)
_check_ndim_multi(logits.ndim, label.ndim, self.cls_name)
total_loss = 0 total_loss = 0
if self.activation is not None: if self.activation is not None:
@ -930,7 +902,6 @@ class MultiClassDiceLoss(LossBase):
if i != self.ignore_indiex: if i != self.ignore_indiex:
dice_loss = self.binarydiceloss(logits[:, i], label[:, i]) dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
if self.weights is not None: if self.weights is not None:
_check_weights(self.weights.shape[0], label.shape[1], self.cls_name)
dice_loss *= self.weights[i] dice_loss *= self.weights[i]
total_loss += dice_loss total_loss += dice_loss
@ -1462,12 +1433,6 @@ class BCELoss(LossBase):
return loss return loss
@constexpr
def _check_reduced_shape_valid(ori_shape, reduced_shape, axis, cls_name, arg_name1, arg_name2):
"""Internal function, used to check whether the reduced shape meets the requirements."""
validator.check_reduce_shape(ori_shape, reduced_shape, axis, cls_name, arg_name1, arg_name2)
class CosineEmbeddingLoss(LossBase): class CosineEmbeddingLoss(LossBase):
r""" r"""
CosineEmbeddingLoss creates a criterion to measure the similarity between two tensors using cosine distance. CosineEmbeddingLoss creates a criterion to measure the similarity between two tensors using cosine distance.
@ -1527,7 +1492,6 @@ class CosineEmbeddingLoss(LossBase):
_check_is_tensor('logits_x2', logits_x2, self.cls_name) _check_is_tensor('logits_x2', logits_x2, self.cls_name)
_check_is_tensor('labels', labels, self.cls_name) _check_is_tensor('labels', labels, self.cls_name)
inner.same_type_shape_(logits_x1, logits_x2) inner.same_type_shape_(logits_x1, logits_x2)
_check_reduced_shape_valid(F.shape(logits_x1), F.shape(labels), (1,), self.cls_name, "logits_x1", "labels")
# if labels > 0, 1-cosine(logits_x1, logits_x2) # if labels > 0, 1-cosine(logits_x1, logits_x2)
# else, max(0, cosine(logits_x1, logits_x2)-margin) # else, max(0, cosine(logits_x1, logits_x2)-margin)
prod_sum = self.reduce_sum(logits_x1 * logits_x2, (1,)) prod_sum = self.reduce_sum(logits_x1 * logits_x2, (1,))
@ -1705,32 +1669,6 @@ class BCEWithLogitsLoss(LossBase):
return loss return loss
@constexpr
def _check_ndim(logits_nidm, labels_ndim, prime_name=None):
'''Internal function, used to check whether the dimension of logits and labels meets the requirements.'''
msg_prefix = f'For \'{prime_name}\', the' if prime_name else "The"
if logits_nidm < 2 or logits_nidm > 4:
raise ValueError(f"{msg_prefix} dimensions of 'logits' must be in [2, 4], but got "
f"dimension of 'logits' {logits_nidm}.")
if labels_ndim < 2 or labels_ndim > 4:
raise ValueError(f"{msg_prefix} dimensions of 'labels' must be in [2, 4], but got "
f"dimension of 'labels' {labels_ndim}.")
if logits_nidm != labels_ndim:
raise ValueError(f"{msg_prefix} dimensions of 'logits' and 'labels' must be equal, but got "
f"dimension of 'logits' {logits_nidm} and dimension of 'labels' {labels_ndim}.")
@constexpr
def _check_channel_and_shape(logits, labels, prime_name=None):
'''Internal function, used to check whether the channels or shape of logits and labels meets the requirements.'''
msg_prefix = f'For \'{prime_name}\', the' if prime_name else "The"
if logits == 1:
raise ValueError(f"{msg_prefix} 'logits'.shape[1] cannot be one, but got {logits}.")
if labels not in (1, logits):
raise ValueError(f"{msg_prefix} 'labels'.shape[1] must be one or equal to 'logits'.shape[1]: {logits}, "
f"but got {labels}.")
@constexpr @constexpr
def _check_input_dtype(labels_dtype, cls_name): def _check_input_dtype(labels_dtype, cls_name):
"""Internal function, used to check whether the data type of labels meets the requirements.""" """Internal function, used to check whether the data type of labels meets the requirements."""
@ -1814,8 +1752,6 @@ class FocalLoss(LossBase):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', labels, self.cls_name) _check_is_tensor('labels', labels, self.cls_name)
labelss = labels labelss = labels
_check_ndim(logits.ndim, labelss.ndim, self.cls_name)
_check_channel_and_shape(logits.shape[1], labelss.shape[1], self.cls_name)
_check_input_dtype(self.dtype(labelss), self.cls_name) _check_input_dtype(self.dtype(labelss), self.cls_name)
if logits.ndim > 2: if logits.ndim > 2:

View File

@ -150,7 +150,8 @@ def test_cummin():
def construct(self, a): def construct(self, a):
return self.func(a, 0) return self.func(a, 0)
a = Tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220], ms.float32) a = Tensor([-0.2284, -0.6628, 0.0975, 0.2680,
-1.3298, -0.4220], ms.float32)
expect = Tensor(np.array( expect = Tensor(np.array(
[-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298]), ms.float32) [-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298]), ms.float32)
net = Net() net = Net()
@ -235,6 +236,32 @@ def test_calculate_expert_capacity():
assert net(10.1, 2.0, 3.3, 4) == 17 assert net(10.1, 2.0, 3.3, 4) == 17
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unsqueeze():
"""
Feature: unsqueeze func
Description: Verify the result of unsqueeze
Expectation: success
"""
from mindspore._extends.parse.standard_method import unsqueeze
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = unsqueeze
def construct(self, x, dim):
return self.func(x, dim)
x = Tensor([[4.0, 9.0, 2.0, 10.0]]).astype("float32")
net = Net()
x2 = net(x, 0)
assert x2.shape == (1, 1, 4)
@pytest.mark.level1 @pytest.mark.level1
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@ -248,6 +275,7 @@ def test_infer_out_shape():
""" """
from mindspore.numpy.utils_const import _infer_out_shape from mindspore.numpy.utils_const import _infer_out_shape
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
@ -259,6 +287,32 @@ def test_infer_out_shape():
assert net((5,), (6, 1), (7, 1, 5), (8, 1, 6, 1)) == (8, 7, 6, 5) assert net((5,), (6, 1), (7, 1, 5), (8, 1, 6, 1)) == (8, 7, 6, 5)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tensor_bool():
"""
Feature: tensor_bool func
Description: Verify the result of tensor_bool
Expectation: success
"""
from mindspore._extends.parse.standard_method import tensor_bool
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = tensor_bool
def construct(self, x):
return self.func(x)
x = Tensor([4.0]).astype("float32")
net = Net()
x2 = net(x)
assert bool(x2) is True
@pytest.mark.level1 @pytest.mark.level1
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@ -281,3 +335,543 @@ def test_canonicalize_axis():
return self.func(axis, ndim) return self.func(axis, ndim)
net = Net() net = Net()
assert net(0, 2) == 0 assert net(0, 2) == 0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_top_k():
"""
Feature: top_k func
Description: Verify the result of top_k
Expectation: success
"""
from mindspore._extends.parse.standard_method import top_k
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = top_k
def construct(self, x, k):
return self.func(x, k)
x = Tensor([4.0, 9.0, 2.0, 10.0]).astype("float32")
net = Net()
output = net(x, 3)
expect = ([10.0, 9.0, 4.0], [3, 1, 0])
assert np.allclose(output[0].asnumpy(), expect[0])
assert np.allclose(output[1].asnumpy(), expect[1])
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bernoulli():
"""
Feature: bernoulli func
Description: Verify the result of bernoulli
Expectation: success
"""
from mindspore._extends.parse.standard_method import bernoulli
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = bernoulli
def construct(self, x):
return self.func(x)
x = Tensor(4).astype("float32")
print(x)
net = Net()
output = net(x)
print(output)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_view():
"""
Feature: view func
Description: Verify the result of view
Expectation: success
"""
from mindspore._extends.parse.standard_method import view
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = view
def construct(self, x, y):
return self.func(x, y)
x = Tensor(np.array([[1, 2, 3], [2, 3, 4]], dtype=np.float32))
net = Net()
output = net(x, (3, 2))
expect = [[1., 2.], [3., 2.], [3., 4.]]
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_reshape():
"""
Feature: reshape func
Description: Verify the result of reshape
Expectation: success
"""
from mindspore._extends.parse.standard_method import reshape
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = reshape
def construct(self, x, y):
return self.func(x, y)
x = Tensor([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=ms.float32)
expect = [[-0.1, 0.3], [3.6, 0.4], [0.5, -3.2]]
net = Net()
output = net(x, (3, 2))
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_swapaxes():
"""
Feature: swapaxes func
Description: Verify the result of swapaxes
Expectation: success
"""
from mindspore._extends.parse.standard_method import swapaxes
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = swapaxes
def construct(self, x, y, z):
return self.func(x, y, z)
x = Tensor(np.ones((2, 3, 4), dtype=np.float32))
expect = [[[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]]]
net = Net()
output = net(x, 0, 2)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_squeeze():
"""
Feature: squeeze func
Description: Verify the result of squeeze
Expectation: success
"""
from mindspore._extends.parse.standard_method import squeeze
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = squeeze
def construct(self, x, y):
return self.func(x, y)
x = Tensor(np.ones((1, 2, 2, 1), dtype=np.float32))
expect = [[1., 1.],
[1., 1.]]
net = Net()
output = net(x, (0, 3))
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_argmax():
"""
Feature: argmax func
Description: Verify the result of argmax
Expectation: success
"""
from mindspore._extends.parse.standard_method import argmax
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = argmax
def construct(self, x, y):
return self.func(x, y)
a = Tensor(np.arange(10, 16).reshape(2, 3).astype("float32"))
net = Net()
output = net(a, None)
expect = 5
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_diagonal():
"""
Feature: diagonal func
Description: Verify the result of diagonal
Expectation: success
"""
from mindspore._extends.parse.standard_method import diagonal
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = diagonal
def construct(self, x):
return self.func(x)
a = Tensor(np.arange(4).reshape(2, 2))
expect = [0, 3]
net = Net()
output = net(a)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_take():
"""
Feature: take func
Description: Verify the result of take
Expectation: success
"""
from mindspore._extends.parse.standard_method import take
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = take
def construct(self, x, y):
return self.func(x, y)
a = Tensor(np.array([4, 3, 5, 7, 6, 8]))
indices = Tensor(np.array([0, 1, 4]))
expect = [4, 3, 6]
net = Net()
output = net(a, indices)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_choose():
"""
Feature: choose func
Description: Verify the result of choose
Expectation: success
"""
from mindspore._extends.parse.standard_method import choose
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = choose
def construct(self, x, y):
return self.func(x, y)
a = Tensor(np.array([2, 3, 1, 0]))
choices = [[0, 1, 2, 3], [10, 11, 12, 13],
[20, 21, 22, 23], [30, 31, 32, 33]]
expect = [20, 31, 12, 3]
net = Net()
output = net(a, choices)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_var():
"""
Feature: var func
Description: Verify the result of var
Expectation: success
"""
from mindspore._extends.parse.standard_method import var
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = var
def construct(self, x,):
return self.func(x,)
a = Tensor(np.array([1., 2., 3., 4.]))
expect = 1.25
net = Net()
output = net(a)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_searchsorted():
"""
Feature: searchsorted func
Description: Verify the result of searchsorted
Expectation: success
"""
from mindspore._extends.parse.standard_method import searchsorted
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = searchsorted
def construct(self, x, y):
return self.func(x, y)
a = Tensor(np.array([1., 2., 3., 4., 5.]))
expect = 2
net = Net()
output = net(a, 3)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_is_equal_one():
"""
Feature: _is_equal_one func
Description: Verify the result of _is_equal_one
Expectation: success
"""
from mindspore.nn.layer.basic import _is_equal_one
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = _is_equal_one
def construct(self, x):
return self.func(x)
a = Tensor(np.array([1., 2., 3., 4., 5.]))
expect = False
net = Net()
output = net(a)
assert output == expect
x = Tensor(np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[10, 11, 12, 13],
[14, 15, 16, 17]]))
expect = [[0, 0, 0, 0],
[5, 0, 0, 0],
[10, 11, 0, 0],
[14, 15, 16, 0]]
net = nn.Tril()
output = net(x, -1)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_triu():
"""
Feature: Triu func
Description: Verify the result of Triu
Expectation: success
"""
x = Tensor(np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[10, 11, 12, 13],
[14, 15, 16, 17]]))
expect = [[0, 0, 3, 4],
[0, 0, 0, 8],
[0, 0, 0, 0],
[0, 0, 0, 0]]
net = nn.Triu()
output = net(x, 2)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resizebilinear():
"""
Feature: ResizeBilinear func
Description: Verify the result of ResizeBilinear
Expectation: success
"""
x = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], ms.float32)
expect = [[[[1., 1.8, 2.6, 3.4, 4.],
[2.6, 3.4, 4.2, 5., 5.6],
[4.2, 5., 5.8, 6.6000004, 7.2],
[5., 5.8, 6.6, 7.4, 8.],
[5., 5.8, 6.6, 7.4, 8.]]]]
resize_bilinear = nn.ResizeBilinear()
output = resize_bilinear(x, size=(5, 5))
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_matrixdiag():
"""
Feature: MatrixDiag func
Description: Verify the result of MatrixDiag
Expectation: success
"""
x = Tensor(np.array([[1, -1, 1], [1, -1, 1]]), ms.float32)
matrix_diag = nn.MatrixDiag()
output = matrix_diag(x)
print(output)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_embeddinglookup():
"""
Feature: EmbeddingLookup func
Description: Verify the result of EmbeddingLookup
Expectation: no error
"""
# _check_input_2d, _check_input_dtype
# mindspore/python/mindspore/nn/layer/embedding.py
input_indices = Tensor(np.array([[1, 0], [3, 2]]), ms.int32)
output = nn.EmbeddingLookup(4, 2, max_norm=0.002)(input_indices)
assert output is not None
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_centralcrop():
"""
Feature: CentralCrop func
Description: Verify the result of CentralCrop
Expectation: success
"""
net = nn.CentralCrop(central_fraction=0.5)
image = Tensor(np.random.random((4, 3, 4, 4)), ms.float32)
expect = (4, 3, 2, 2)
output = net(image)
assert np.allclose(output.shape, expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_matmul_2():
"""
Feature: MatMul func
Description: Verify the result of MatMul
Expectation: success
"""
net = nn.MatMul()
a = Tensor(np.arange(1, 17).reshape((4, 4)), ms.float32)
b = Tensor(np.arange(1, 17).reshape((4, 4)), ms.float32)
expect = [[90., 100., 110., 120.],
[202., 228., 254., 280.],
[314., 356., 398., 440.],
[426., 484., 542., 600.]]
output = net(a, b)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_reflectionpad1d():
"""
Feature: ReflectionPad1d func
Description: Verify the result of ReflectionPad1d
Expectation: success
"""
from mindspore.nn import ReflectionPad1d
x = Tensor(np.array([[[0, 1, 2, 3], [4, 5, 6, 7]]]).astype(np.float32))
padding = (3, 1)
pad1d = ReflectionPad1d(padding)
expect = [[[3., 2., 1., 0., 1., 2., 3., 2.],
[7., 6., 5., 4., 5., 6., 7., 6.]]]
output = pad1d(x)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_constantpad1d():
"""
Feature: ConstantPad1d func
Description: Verify the result of ConstantPad1d
Expectation: success
"""
from mindspore.nn import ConstantPad1d
x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32)
x = Tensor(x)
expect = [[[[1., 1., 1., 1., 0.5],
[1., 1., 1., 1., 0.5],
[1., 1., 1., 1., 0.5]],
[[1., 1., 1., 1., 0.5],
[1., 1., 1., 1., 0.5],
[1., 1., 1., 1., 0.5]]]]
padding = (0, 1)
value = 0.5
pad1d = ConstantPad1d(padding, value)
output = pad1d(x)
assert np.allclose(output.asnumpy(), expect)

View File

@ -0,0 +1,54 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test control ops """
import numpy as np
import pytest
from mindspore import Tensor
from mindspore import nn
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.common.parameter import Parameter, ParameterTuple
grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_switch_layer_with_single_prim():
"""
Feature: SwitchLayer
Description: run switch layer case
Expectation: success.
"""
class SwitchLayerCell(nn.Cell):
def __init__(self):
super(SwitchLayerCell, self).__init__()
self.layers = (nn.ReLU(), nn.ReLU())
self.z3 = Parameter(
Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
def construct(self, index, x):
ret = self.layers[index](x) * self.z3
return ret
index = Tensor(0, dtype=mstype.int32)
net = SwitchLayerCell()
net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))

View File

@ -16,7 +16,7 @@ import numpy as np
import pytest import pytest
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor, context
class Net(nn.Cell): class Net(nn.Cell):
@ -24,6 +24,16 @@ class Net(nn.Cell):
return x.tril(diagonal) return x.tril(diagonal)
class TrilNet(nn.Cell):
def __init__(self):
super(TrilNet, self).__init__()
self.tril = nn.Tril()
def construct(self, value, k):
return self.tril(value, k)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu @pytest.mark.platform_arm_cpu
@ -44,3 +54,99 @@ def test_tril(mode):
output = net(x) output = net(x)
expect_output = np.array([[-1.8297, 0., 0.], [-1.2167, 0.5574, 0.], [-0.6702, 0.2276, 1.2421]], dtype=np.float32) expect_output = np.array([[-1.8297, 0., 0.], [-1.2167, 0.5574, 0.], [-0.6702, 0.2276, 1.2421]], dtype=np.float32)
assert np.allclose(output.asnumpy(), expect_output) assert np.allclose(output.asnumpy(), expect_output)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
def test_tril_0(mode):
"""
Feature: test_tril
Description: Verify the result of test_tril
Expectation: success
"""
value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
net = TrilNet()
out = net(value, 0)
assert np.sum(out.asnumpy()) == 34
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
def test_tril_1(mode):
"""
Feature: test_tril_1
Description: Verify the result of test_tril_1
Expectation: success
"""
value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
net = TrilNet()
out = net(value, 1)
assert np.sum(out.asnumpy()) == 42
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
def test_tril_2(mode):
"""
Feature: test_tril_2
Description: Verify the result of test_tril_2
Expectation: success
"""
value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
net = TrilNet()
out = net(value, -1)
assert np.sum(out.asnumpy()) == 19
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
def test_tril_parameter(mode):
"""
Feature: test_tril_parameter
Description: Verify the result of test_tril_parameter
Expectation: success
"""
net = TrilNet()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 0)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
def test_tril_parameter_1(mode):
"""
Feature: test_tril_parameter_1
Description: Verify the result of test_tril_parameter_1
Expectation: success
"""
net = TrilNet()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 0)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
def test_tril_parameter_2(mode):
"""
Feature: test_tril_parameter_2
Description: Verify the result of test_tril_parameter_2
Expectation: success
"""
net = TrilNet()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 0)

View File

@ -1,3 +1,17 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np import numpy as np
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
@ -6,8 +20,8 @@ from mindspore import context
class Net(nn.Cell): class Net(nn.Cell):
def construct(self, x): def construct(self, x, k):
return x.triu(diagonal=1) return x.triu(diagonal=k)
@pytest.mark.level0 @pytest.mark.level0
@ -16,10 +30,10 @@ class Net(nn.Cell):
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) @pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_subtract(mode): def test_triu_0(mode):
""" """
Feature: tensor.subtract() Feature: test_triu_0
Description: Verify the result of tensor.subtract Description: Verify the result of test_triu_0
Expectation: success Expectation: success
""" """
context.set_context(mode=mode) context.set_context(mode=mode)
@ -28,9 +42,66 @@ def test_subtract(mode):
[5, 6, 7, 8], [5, 6, 7, 8],
[10, 11, 12, 13], [10, 11, 12, 13],
[14, 15, 16, 17]])) [14, 15, 16, 17]]))
output = net(x) output = net(x, 1)
expected = np.array([[0, 2, 3, 4], expected = np.array([[0, 2, 3, 4],
[0, 0, 7, 8], [0, 0, 7, 8],
[0, 0, 0, 13], [0, 0, 0, 13],
[0, 0, 0, 0]]) [0, 0, 0, 0]])
assert np.array_equal(output.asnumpy(), expected) assert np.array_equal(output.asnumpy(), expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_triu_1(mode):
"""
Feature: test_triu_1
Description: test_triu_1
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
output = net(x, 0)
assert np.sum(output.asnumpy()) == 26
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_triu_2(mode):
"""
Feature: test_triu_2
Description: test_triu_2
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
output = net(x, 1)
assert np.sum(output.asnumpy()) == 11
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_triu_3(mode):
"""
Feature: test_triu_3
Description: test_triu_3
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
output = net(x, -1)
assert np.sum(output.asnumpy()) == 38

View File

@ -384,8 +384,7 @@ def test_offload_dim_check():
dataset = dataset.map(operations=[C.Decode()], input_columns="image") dataset = dataset.map(operations=[C.Decode()], input_columns="image")
dataset = dataset.map(operations=[C.HWC2CHW()], input_columns="image", offload=True) dataset = dataset.map(operations=[C.HWC2CHW()], input_columns="image", offload=True)
error_msg = "For HwcToChw offload operation, the dimension of input should be 4, but got 3." with pytest.raises(ValueError):
with pytest.raises(ValueError, match=error_msg):
for (_, _) in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True): for (_, _) in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
continue continue

View File

@ -82,14 +82,14 @@ def test_check_multifield_embedding_false_type_field_id():
@non_graph_engine @non_graph_engine
def test_check_multifield_embedding_false_input_shape(): def test_check_multifield_embedding_false_input_shape():
with pytest.raises(ValueError): with pytest.raises(TypeError):
compile_multi_field_embedding((8,), (8, 200), (8, 200), compile_multi_field_embedding((8,), (8, 200), (8, 200),
dtype.int16, dtype.float32, dtype.int16) dtype.int16, dtype.float32, dtype.int16)
@non_graph_engine @non_graph_engine
def test_check_multifield_embedding_false_value_shape(): def test_check_multifield_embedding_false_value_shape():
with pytest.raises(ValueError): with pytest.raises(TypeError):
compile_multi_field_embedding((8, 200), (8,), (8, 200), compile_multi_field_embedding((8, 200), (8,), (8, 200),
dtype.int16, dtype.float32, dtype.int16) dtype.int16, dtype.float32, dtype.int16)

View File

@ -85,21 +85,3 @@ def test_psnr_different_dtype():
net = PSNRNet() net = PSNRNet()
with pytest.raises(TypeError): with pytest.raises(TypeError):
_cell_graph_executor.compile(net, img1, img2) _cell_graph_executor.compile(net, img1, img2)
def test_psnr_invalid_5d_input():
shape_1 = (8, 3, 16, 16)
shape_2 = (8, 3, 8, 8)
invalid_shape = (8, 3, 16, 16, 1)
img1 = Tensor(np.random.random(shape_1))
invalid_img1 = Tensor(np.random.random(invalid_shape))
img2 = Tensor(np.random.random(shape_2))
invalid_img2 = Tensor(np.random.random(invalid_shape))
net = PSNRNet()
with pytest.raises(ValueError):
_cell_graph_executor.compile(net, invalid_img1, img2)
with pytest.raises(ValueError):
_cell_graph_executor.compile(net, img1, invalid_img2)
with pytest.raises(ValueError):
_cell_graph_executor.compile(net, invalid_img1, invalid_img2)

View File

@ -1,108 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test nn.Tril()
"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_tril():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
tril = nn.Tril()
return tril(self.value, 0)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 34
def test_tril_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
tril = nn.Tril()
return tril(self.value, 1)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 42
def test_tril_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
tril = nn.Tril()
return tril(self.value, -1)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 19
def test_tril_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
tril = nn.Tril()
return tril(x, 0)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
def test_tril_parameter_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
tril = nn.Tril()
return tril(x, 1)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
def test_tril_parameter_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
tril = nn.Tril()
return tril(x, -1)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))

View File

@ -1,102 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test nn.Triu()
"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
class TriuNet(nn.Cell):
def __init__(self):
super(TriuNet, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
triu = nn.Triu()
return triu(self.value, 0)
def test_triu():
"""
Feature: None
Description: test TriuNet with vm backend
Expectation: None
"""
net = TriuNet()
out = net()
assert np.sum(out.asnumpy()) == 26
def test_triu_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
triu = nn.Triu()
return triu(self.value, 1)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 11
def test_triu_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
triu = nn.Triu()
return triu(self.value, -1)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 38
def test_triu_parameter():
class Net(nn.Cell):
def construct(self, x):
triu = nn.Triu()
return triu(x, 0)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
def test_triu_parameter_1():
class Net(nn.Cell):
def construct(self, x):
triu = nn.Triu()
return triu(x, 1)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
def test_triu_parameter_2():
class Net(nn.Cell):
def construct(self, x):
triu = nn.Triu()
return triu(x, -1)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))

View File

@ -533,26 +533,6 @@ def test_parser_switch_layer_inputs_tuple():
back_out = back_net(i, input1, input2, grad) back_out = back_net(i, input1, input2, grad)
def test_switch_layer_with_single_prim():
class SwitchLayerCell(nn.Cell):
def __init__(self):
super(SwitchLayerCell, self).__init__()
self.layers = (nn.ReLU(), nn.ReLU())
self.z3 = Parameter(
Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
def construct(self, index, x):
ret = self.layers[index](x) * self.z3
return ret
index = Tensor(0, dtype=mstype.int32)
net = SwitchLayerCell()
net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
def test_switch_layer_env_eliminate(): def test_switch_layer_env_eliminate():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
@ -1003,7 +983,6 @@ def test_recursive_call():
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.fc = nn.Dense(10, 10) # padding=0 self.fc = nn.Dense(10, 10) # padding=0
# self.net2 = Net2()
def construct(self, x): def construct(self, x):
net2 = Net2() net2 = Net2()
@ -1037,8 +1016,6 @@ def test_recursive_call():
# grad for Tensor(Bool) input and eliminate AddN(MakeTuple(Xs, zeros_like(Bool))) # grad for Tensor(Bool) input and eliminate AddN(MakeTuple(Xs, zeros_like(Bool)))
def test_grad_tensor_bool(): def test_grad_tensor_bool():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x, y, z): def construct(self, x, y, z):
out = z out = z