forked from mindspore-Ecosystem/mindspore
!47543 modify constexpr
Merge pull request !47543 from huoxinyou/0105_constexpr
This commit is contained in:
commit
386c33298d
|
@ -214,6 +214,100 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg
|
||||||
return arg_value
|
return arg_value
|
||||||
|
|
||||||
|
|
||||||
|
def check_reshape_shp(shp):
|
||||||
|
"""Check the shape argument for tensor.reshape"""
|
||||||
|
if len(shp) == 1:
|
||||||
|
new_shape = shp[0]
|
||||||
|
if isinstance(new_shape, int):
|
||||||
|
return shp
|
||||||
|
if isinstance(new_shape, list):
|
||||||
|
new_shape = tuple(new_shape)
|
||||||
|
return new_shape
|
||||||
|
return shp
|
||||||
|
|
||||||
|
|
||||||
|
def check_swapaxes_axis(axes, ndim):
|
||||||
|
"""Check all the axes argument for tensor.swapaxes"""
|
||||||
|
if isinstance(axes, int):
|
||||||
|
return axes % ndim
|
||||||
|
if isinstance(axes, (tuple, list)):
|
||||||
|
tmp = []
|
||||||
|
for x in axes:
|
||||||
|
tmp.append((x + ndim) % ndim)
|
||||||
|
axes = tuple(tmp)
|
||||||
|
return axes
|
||||||
|
return axes
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_shape_for_squeeze(shape, axes):
|
||||||
|
"""
|
||||||
|
yield squeezed shape based on the axes
|
||||||
|
"""
|
||||||
|
new_shape = []
|
||||||
|
ndim = len(shape)
|
||||||
|
if isinstance(axes, int):
|
||||||
|
axes = [axes]
|
||||||
|
elif isinstance(axes, (list, tuple)):
|
||||||
|
axes = set(axes)
|
||||||
|
for idx, s in enumerate(shape):
|
||||||
|
if s != 1 or (idx not in axes) and (idx - ndim not in axes):
|
||||||
|
new_shape.append(s)
|
||||||
|
return tuple(new_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def check_axis_in_range(axis, ndim):
|
||||||
|
"""Checks axes are with the bounds of ndim"""
|
||||||
|
return (axis + ndim) % ndim
|
||||||
|
|
||||||
|
|
||||||
|
def check_axis_valid(axes, ndim):
|
||||||
|
"""
|
||||||
|
check the validation of axis and return
|
||||||
|
"""
|
||||||
|
if axes is None:
|
||||||
|
axes = tuple(range(ndim))
|
||||||
|
return axes
|
||||||
|
if isinstance(axes, (tuple, list)):
|
||||||
|
tmp = []
|
||||||
|
for x in axes:
|
||||||
|
tmp.append((x + ndim) % ndim)
|
||||||
|
axes = tuple(tmp)
|
||||||
|
return axes
|
||||||
|
return (axes % ndim,)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_out_shape(*shapes):
|
||||||
|
"""
|
||||||
|
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
|
||||||
|
"""
|
||||||
|
shape_out = list()
|
||||||
|
max_len = ms_max([len(it) for it in shapes])
|
||||||
|
for i in range(max_len):
|
||||||
|
items = [it[i-(max_len-len(it))] if i - (max_len - len(it))
|
||||||
|
>= 0 else 1 for it in shapes]
|
||||||
|
max_size = 0 if 0 in items else ms_max(items)
|
||||||
|
shape_out.append(max_size)
|
||||||
|
return tuple(shape_out)
|
||||||
|
|
||||||
|
|
||||||
|
def check_and_canonicalize_axes(axes, ndim):
|
||||||
|
"""Check whether the types and values of input axes are valid."""
|
||||||
|
axes = axes if isinstance(axes, tuple) else (axes,)
|
||||||
|
new_axes = ()
|
||||||
|
for ax in axes:
|
||||||
|
ax = ax if ax >= 0 else ax + ndim
|
||||||
|
new_axes += (ax,)
|
||||||
|
return new_axes
|
||||||
|
|
||||||
|
|
||||||
|
def get_log2_size(size):
|
||||||
|
"""Get log2 size"""
|
||||||
|
log2_res = F.log2(F.cast(Tensor(size), mstype.float32))
|
||||||
|
ceil_res = F.ceil(log2_res)
|
||||||
|
cast_res = F.cast(ceil_res, mstype.int64)
|
||||||
|
return cast_res
|
||||||
|
|
||||||
|
|
||||||
class Validator:
|
class Validator:
|
||||||
"""validator for checking input parameters"""
|
"""validator for checking input parameters"""
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,8 @@ from mindspore.ops.composite.base import _append, _insert, _pop, _list_clear, _r
|
||||||
_extend, _dict_clear, _haskey, _update, _fromkeys
|
_extend, _dict_clear, _haskey, _update, _fromkeys
|
||||||
|
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
from ..._checkparam import check_is_number
|
from ..._checkparam import check_is_number, check_reshape_shp, check_swapaxes_axis, prepare_shape_for_squeeze, \
|
||||||
|
check_axis_in_range, check_axis_valid, infer_out_shape, check_and_canonicalize_axes, get_log2_size
|
||||||
from ...ops import functional as F
|
from ...ops import functional as F
|
||||||
from ...ops import operations as P
|
from ...ops import operations as P
|
||||||
from ...ops.composite import tail, MultitypeFuncGraph, env_get, hyper_add, \
|
from ...ops.composite import tail, MultitypeFuncGraph, env_get, hyper_add, \
|
||||||
|
@ -41,7 +42,8 @@ from ...ops.primitive import constexpr
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ...ops.operations._sequence_ops import ListAppend
|
from ...ops.operations._sequence_ops import ListAppend
|
||||||
|
|
||||||
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
|
__all__ = ['MultitypeFuncGraph', 'env_get',
|
||||||
|
'hyper_add', 'zeros_like', 'ones_like']
|
||||||
|
|
||||||
shape_ = P.Shape()
|
shape_ = P.Shape()
|
||||||
dtype_ = P.DType()
|
dtype_ = P.DType()
|
||||||
|
@ -519,7 +521,7 @@ def reshape(x, *shape):
|
||||||
[ 3.6 0.4]
|
[ 3.6 0.4]
|
||||||
[ 0.5 -3.2]]
|
[ 0.5 -3.2]]
|
||||||
"""
|
"""
|
||||||
new_shape = check_reshape_shp_const(shape)
|
new_shape = check_reshape_shp(shape)
|
||||||
return F.reshape(x, new_shape)
|
return F.reshape(x, new_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@ -709,7 +711,7 @@ def swapaxes(x, axis1, axis2):
|
||||||
>>> print(output.shape)
|
>>> print(output.shape)
|
||||||
(4,3,2)
|
(4,3,2)
|
||||||
"""
|
"""
|
||||||
axis1, axis2 = check_swapaxes_axis_const((axis1, axis2), x.ndim)
|
axis1, axis2 = check_swapaxes_axis((axis1, axis2), x.ndim)
|
||||||
|
|
||||||
if axis1 == axis2:
|
if axis1 == axis2:
|
||||||
return x
|
return x
|
||||||
|
@ -757,7 +759,7 @@ def squeeze(x, axis=None):
|
||||||
if axis is None:
|
if axis is None:
|
||||||
return F.squeeze(x)
|
return F.squeeze(x)
|
||||||
# yield squeezed shape based on the axes
|
# yield squeezed shape based on the axes
|
||||||
new_shape = prepare_shape_for_squeeze_const(shape, axis)
|
new_shape = prepare_shape_for_squeeze(shape, axis)
|
||||||
return F.reshape(x, new_shape)
|
return F.reshape(x, new_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@ -869,7 +871,7 @@ def argmin(x, axis=None, keepdims=False):
|
||||||
axis = 0
|
axis = 0
|
||||||
is_axis_none = True
|
is_axis_none = True
|
||||||
else:
|
else:
|
||||||
axis = check_axis_in_range_const(axis, F.rank(x))
|
axis = check_axis_in_range(axis, F.rank(x))
|
||||||
out = P.Argmin(axis)(x)
|
out = P.Argmin(axis)(x)
|
||||||
if keepdims and not is_axis_none:
|
if keepdims and not is_axis_none:
|
||||||
out = expand_dims(out, axis)
|
out = expand_dims(out, axis)
|
||||||
|
@ -894,7 +896,7 @@ def median(x, global_median, axis=0, keep_dims=False):
|
||||||
When attr `global_median` is True, the second output Tensor value is meaningless.
|
When attr `global_median` is True, the second output Tensor value is meaningless.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
check_axis_in_range_const(axis, x.ndim)
|
check_axis_in_range(axis, x.ndim)
|
||||||
median_ = Median(global_median, axis, keep_dims)
|
median_ = Median(global_median, axis, keep_dims)
|
||||||
return median_(x)
|
return median_(x)
|
||||||
|
|
||||||
|
@ -968,7 +970,7 @@ def cumsum(x, axis=None, dtype=None):
|
||||||
if axis is None:
|
if axis is None:
|
||||||
x = x.ravel()
|
x = x.ravel()
|
||||||
axis = 0
|
axis = 0
|
||||||
check_axis_in_range_const(axis, x.ndim)
|
check_axis_in_range(axis, x.ndim)
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
dtype = check_astype_dtype_const(dtype)
|
dtype = check_astype_dtype_const(dtype)
|
||||||
if original_dtype != dtype:
|
if original_dtype != dtype:
|
||||||
|
@ -1350,7 +1352,8 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
|
||||||
"""
|
"""
|
||||||
ndim = x.ndim
|
ndim = x.ndim
|
||||||
if ndim < 2:
|
if ndim < 2:
|
||||||
const_utils.raise_value_error('diagonal requires an array of at least two dimensions')
|
const_utils.raise_value_error(
|
||||||
|
'diagonal requires an array of at least two dimensions')
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
|
|
||||||
axes = check_axis_valid((axis1, axis2), ndim)
|
axes = check_axis_valid((axis1, axis2), ndim)
|
||||||
|
@ -1602,14 +1605,15 @@ def take(x, indices, axis=None, mode='clip'):
|
||||||
[4 3 6]
|
[4 3 6]
|
||||||
"""
|
"""
|
||||||
if mode not in ('raise', 'wrap', 'clip'):
|
if mode not in ('raise', 'wrap', 'clip'):
|
||||||
const_utils.raise_value_error('raise should be one of "raise", "wrap", or "clip"')
|
const_utils.raise_value_error(
|
||||||
|
'raise should be one of "raise", "wrap", or "clip"')
|
||||||
if axis is None:
|
if axis is None:
|
||||||
a = x.ravel()
|
a = x.ravel()
|
||||||
axis = 0
|
axis = 0
|
||||||
else:
|
else:
|
||||||
a = x
|
a = x
|
||||||
ndim = a.ndim
|
ndim = a.ndim
|
||||||
axis = check_axis_in_range_const(axis, ndim)
|
axis = check_axis_in_range(axis, ndim)
|
||||||
|
|
||||||
shape_a = a.shape
|
shape_a = a.shape
|
||||||
shape_indices = indices.shape
|
shape_indices = indices.shape
|
||||||
|
@ -1690,12 +1694,14 @@ def choose(x, choices, mode='clip'):
|
||||||
# adjusts dtype for F.tensor_mul and F.gather_nd
|
# adjusts dtype for F.tensor_mul and F.gather_nd
|
||||||
a = a.astype(mstype.int32)
|
a = a.astype(mstype.int32)
|
||||||
choices = choices.astype(mstype.int32)
|
choices = choices.astype(mstype.int32)
|
||||||
a = compile_utils.check_indices(choices.shape[0], a, mode, allow_negative_index=False)
|
a = compile_utils.check_indices(
|
||||||
|
choices.shape[0], a, mode, allow_negative_index=False)
|
||||||
|
|
||||||
grids = []
|
grids = []
|
||||||
ndim = len(a.shape)
|
ndim = len(a.shape)
|
||||||
for i in range(ndim):
|
for i in range(ndim):
|
||||||
dim_grid = const_utils.make_tensor(F.make_range(a.shape[i]), mstype.int32)
|
dim_grid = const_utils.make_tensor(
|
||||||
|
F.make_range(a.shape[i]), mstype.int32)
|
||||||
dim_shape = expanded_shape(ndim, a.shape[i], i)
|
dim_shape = expanded_shape(ndim, a.shape[i], i)
|
||||||
dim_grid = P.BroadcastTo(a.shape)(dim_grid.reshape(dim_shape))
|
dim_grid = P.BroadcastTo(a.shape)(dim_grid.reshape(dim_shape))
|
||||||
grids.append(dim_grid)
|
grids.append(dim_grid)
|
||||||
|
@ -1740,7 +1746,8 @@ def searchsorted(x, v, side='left', sorter=None):
|
||||||
shape = v.shape
|
shape = v.shape
|
||||||
if sorter is not None:
|
if sorter is not None:
|
||||||
if sorter.ndim != 1 or sorter.size != a.size:
|
if sorter.ndim != 1 or sorter.size != a.size:
|
||||||
const_utils.raise_value_error('sorter must be 1-D array with the same size as `a`')
|
const_utils.raise_value_error(
|
||||||
|
'sorter must be 1-D array with the same size as `a`')
|
||||||
sorter = const_utils.make_tensor(sorter)
|
sorter = const_utils.make_tensor(sorter)
|
||||||
sorter = sorter.reshape(sorter.shape + (1,))
|
sorter = sorter.reshape(sorter.shape + (1,))
|
||||||
a = F.gather_nd(a, sorter)
|
a = F.gather_nd(a, sorter)
|
||||||
|
@ -1748,12 +1755,14 @@ def searchsorted(x, v, side='left', sorter=None):
|
||||||
i = F.fill(mstype.int32, shape, 0)
|
i = F.fill(mstype.int32, shape, 0)
|
||||||
j = F.fill(mstype.int32, shape, a.size)
|
j = F.fill(mstype.int32, shape, a.size)
|
||||||
|
|
||||||
sort_range = F.make_range(get_log2_size(F.shape_mul(a.shape) + 1))
|
loop_num = get_log2_size(F.shape_mul(a.shape) + 1)
|
||||||
for _ in sort_range:
|
index = Tensor([0])
|
||||||
|
while index < loop_num:
|
||||||
mid = (i - F.neg_tensor(j)) // 2
|
mid = (i - F.neg_tensor(j)) // 2
|
||||||
mask = less_op(v, F.gather_nd(a, mid.reshape(mid.shape + (1,))))
|
mask = less_op(v, F.gather_nd(a, mid.reshape(mid.shape + (1,))))
|
||||||
i = F.select(mask, i, mid)
|
i = F.select(mask, i, mid)
|
||||||
j = F.select(mask, mid, j)
|
j = F.select(mask, mid, j)
|
||||||
|
index += 1
|
||||||
return j
|
return j
|
||||||
|
|
||||||
|
|
||||||
|
@ -1788,7 +1797,8 @@ def fill(x, value):
|
||||||
"""
|
"""
|
||||||
if value is None:
|
if value is None:
|
||||||
if x.dtype not in (mstype.float16, mstype.float32, mstype.float64):
|
if x.dtype not in (mstype.float16, mstype.float32, mstype.float64):
|
||||||
const_utils.raise_type_error("If None is used as value, the original Tensor's dtype must be float.")
|
const_utils.raise_type_error(
|
||||||
|
"If None is used as value, the original Tensor's dtype must be float.")
|
||||||
value = nan_tensor
|
value = nan_tensor
|
||||||
return F.tile(value, x.shape).astype(x.dtype)
|
return F.tile(value, x.shape).astype(x.dtype)
|
||||||
if not isinstance(value, (int, float, bool)):
|
if not isinstance(value, (int, float, bool)):
|
||||||
|
@ -1838,7 +1848,6 @@ def ptp(x, axis=None, keepdims=False):
|
||||||
if axis is None:
|
if axis is None:
|
||||||
axis = ()
|
axis = ()
|
||||||
else:
|
else:
|
||||||
check_axis_type(axis, True, True, False)
|
|
||||||
axis = check_axis_valid(axis, x.ndim)
|
axis = check_axis_valid(axis, x.ndim)
|
||||||
|
|
||||||
return x.max(axis, keepdims) - x.min(axis, keepdims)
|
return x.max(axis, keepdims) - x.min(axis, keepdims)
|
||||||
|
@ -2102,7 +2111,7 @@ def repeat(x, repeats, axis=None):
|
||||||
axis = 0
|
axis = 0
|
||||||
if not isinstance(axis, int):
|
if not isinstance(axis, int):
|
||||||
const_utils.raise_type_error('axes should be integers')
|
const_utils.raise_type_error('axes should be integers')
|
||||||
check_axis_in_range_const(axis, x.ndim)
|
check_axis_in_range(axis, x.ndim)
|
||||||
axis = axis + x.ndim if axis < 0 else axis
|
axis = axis + x.ndim if axis < 0 else axis
|
||||||
|
|
||||||
if len(repeats) == 1:
|
if len(repeats) == 1:
|
||||||
|
@ -2112,7 +2121,8 @@ def repeat(x, repeats, axis=None):
|
||||||
return repeat_elements(x, repeats, axis)
|
return repeat_elements(x, repeats, axis)
|
||||||
size = x.shape[axis]
|
size = x.shape[axis]
|
||||||
if len(repeats) != size:
|
if len(repeats) != size:
|
||||||
const_utils.raise_value_error('operands could not be broadcast together')
|
const_utils.raise_value_error(
|
||||||
|
'operands could not be broadcast together')
|
||||||
subs = P.Split(axis, size)(x)
|
subs = P.Split(axis, size)(x)
|
||||||
repeated_subs = []
|
repeated_subs = []
|
||||||
for sub_item, rep in zip(subs, repeats):
|
for sub_item, rep in zip(subs, repeats):
|
||||||
|
@ -2224,8 +2234,6 @@ def hasnext(it):
|
||||||
@constexpr
|
@constexpr
|
||||||
def constant_abs(x):
|
def constant_abs(x):
|
||||||
"""Returns the absolute value of the constant."""
|
"""Returns the absolute value of the constant."""
|
||||||
if x is None:
|
|
||||||
raise ValueError("For abs(), the input should be a constant or Tensor type.")
|
|
||||||
return abs(x)
|
return abs(x)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2241,7 +2249,8 @@ def constant_round(*data):
|
||||||
"""Returns the rounded value of the constant."""
|
"""Returns the rounded value of the constant."""
|
||||||
for x in data:
|
for x in data:
|
||||||
if x is None:
|
if x is None:
|
||||||
raise ValueError("For round(), the input should be a Tensor or 1-2 constants.")
|
raise ValueError(
|
||||||
|
"For round(), the input should be a Tensor or 1-2 constants.")
|
||||||
return round(*data)
|
return round(*data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2256,7 +2265,8 @@ def ms_round(*data):
|
||||||
return round_(x)
|
return round_(x)
|
||||||
return constant_round(x)
|
return constant_round(x)
|
||||||
if isinstance(data[0], Tensor) or isinstance(data[1], Tensor):
|
if isinstance(data[0], Tensor) or isinstance(data[1], Tensor):
|
||||||
const_utils.raise_type_error("When applying round() to tensor, only one tensor is supported as input.")
|
const_utils.raise_type_error(
|
||||||
|
"When applying round() to tensor, only one tensor is supported as input.")
|
||||||
return constant_round(*data)
|
return constant_round(*data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2274,9 +2284,11 @@ def str_func(*data):
|
||||||
return ''
|
return ''
|
||||||
data = data[0]
|
data = data[0]
|
||||||
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
|
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
|
||||||
const_utils.raise_type_error("str() does not support sparse tensor input.")
|
const_utils.raise_type_error(
|
||||||
|
"str() does not support sparse tensor input.")
|
||||||
if not F.isconstant(data):
|
if not F.isconstant(data):
|
||||||
const_utils.raise_type_error("str() does not support non-constant input.")
|
const_utils.raise_type_error(
|
||||||
|
"str() does not support non-constant input.")
|
||||||
return cast_to_str(data)
|
return cast_to_str(data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2294,13 +2306,15 @@ def bool_func(*data):
|
||||||
return False
|
return False
|
||||||
data = data[0]
|
data = data[0]
|
||||||
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
|
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
|
||||||
const_utils.raise_type_error("bool() does not support sparse tensor input.")
|
const_utils.raise_type_error(
|
||||||
|
"bool() does not support sparse tensor input.")
|
||||||
if isinstance(data, (Tensor, Tensor_)):
|
if isinstance(data, (Tensor, Tensor_)):
|
||||||
tensor_shape = F.shape(data)
|
tensor_shape = F.shape(data)
|
||||||
tensor_shape_len = len(tensor_shape)
|
tensor_shape_len = len(tensor_shape)
|
||||||
if tensor_shape_len == 0 or (tensor_shape_len == 1 and tensor_shape[0] == 1):
|
if tensor_shape_len == 0 or (tensor_shape_len == 1 and tensor_shape[0] == 1):
|
||||||
return data != 0
|
return data != 0
|
||||||
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
|
const_utils.raise_value_error(
|
||||||
|
"The truth value of an array with more than one element is ambiguous.")
|
||||||
if not F.isconstant(data):
|
if not F.isconstant(data):
|
||||||
if hasattr(data, "__bool__"):
|
if hasattr(data, "__bool__"):
|
||||||
return data.__bool__()
|
return data.__bool__()
|
||||||
|
@ -2329,9 +2343,11 @@ def int_func(*data):
|
||||||
return 0
|
return 0
|
||||||
target = data[0]
|
target = data[0]
|
||||||
if not F.isconstant(target):
|
if not F.isconstant(target):
|
||||||
const_utils.raise_type_error("int() does not support non-constant input.")
|
const_utils.raise_type_error(
|
||||||
|
"int() does not support non-constant input.")
|
||||||
if isinstance(target, (CSRTensor, COOTensor, RowTensorInner)):
|
if isinstance(target, (CSRTensor, COOTensor, RowTensorInner)):
|
||||||
const_utils.raise_type_error("int() does not support sparse tensor input.")
|
const_utils.raise_type_error(
|
||||||
|
"int() does not support sparse tensor input.")
|
||||||
return cast_to_int(*data)
|
return cast_to_int(*data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2351,9 +2367,11 @@ def float_func(*data):
|
||||||
return 0.0
|
return 0.0
|
||||||
data = data[0]
|
data = data[0]
|
||||||
if not F.isconstant(data):
|
if not F.isconstant(data):
|
||||||
const_utils.raise_type_error("float() does not support non-constant input.")
|
const_utils.raise_type_error(
|
||||||
|
"float() does not support non-constant input.")
|
||||||
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
|
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
|
||||||
const_utils.raise_type_error("float() does not support sparse tensor input.")
|
const_utils.raise_type_error(
|
||||||
|
"float() does not support sparse tensor input.")
|
||||||
return cast_to_float(data)
|
return cast_to_float(data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2366,10 +2384,12 @@ def list_func(*data):
|
||||||
return F.make_list()
|
return F.make_list()
|
||||||
data = data[0]
|
data = data[0]
|
||||||
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
|
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
|
||||||
const_utils.raise_type_error("list() does not support single sparse tensor input.")
|
const_utils.raise_type_error(
|
||||||
|
"list() does not support single sparse tensor input.")
|
||||||
if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"):
|
if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"):
|
||||||
data_type = F.typeof(data)
|
data_type = F.typeof(data)
|
||||||
const_utils.raise_type_error(str(data_type) + " object is not iterable.")
|
const_utils.raise_type_error(
|
||||||
|
str(data_type) + " object is not iterable.")
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
data = data.keys()
|
data = data.keys()
|
||||||
ret = F.make_list()
|
ret = F.make_list()
|
||||||
|
@ -2387,10 +2407,12 @@ def tuple_func(*data):
|
||||||
return F.make_tuple()
|
return F.make_tuple()
|
||||||
data = data[0]
|
data = data[0]
|
||||||
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
|
if isinstance(data, (CSRTensor, COOTensor, RowTensorInner)):
|
||||||
const_utils.raise_type_error("tuple() does not support single sparse tensor input.")
|
const_utils.raise_type_error(
|
||||||
|
"tuple() does not support single sparse tensor input.")
|
||||||
if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"):
|
if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"):
|
||||||
data_type = F.typeof(data)
|
data_type = F.typeof(data)
|
||||||
const_utils.raise_type_error(str(data_type) + " object is not iterable.")
|
const_utils.raise_type_error(
|
||||||
|
str(data_type) + " object is not iterable.")
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
data = data.keys()
|
data = data.keys()
|
||||||
ret = F.make_tuple()
|
ret = F.make_tuple()
|
||||||
|
@ -2419,7 +2441,8 @@ def get_max_min_data_len(*data):
|
||||||
if isinstance(data, (dict, list, tuple)):
|
if isinstance(data, (dict, list, tuple)):
|
||||||
len_data = len(data)
|
len_data = len(data)
|
||||||
else:
|
else:
|
||||||
const_utils.raise_type_error("max() or min() does not support the data type.")
|
const_utils.raise_type_error(
|
||||||
|
"max() or min() does not support the data type.")
|
||||||
return len_data
|
return len_data
|
||||||
|
|
||||||
|
|
||||||
|
@ -2431,7 +2454,8 @@ def get_tensor_num(data):
|
||||||
tensor_shape = F.shape(input_data)
|
tensor_shape = F.shape(input_data)
|
||||||
tensor_shape_len = len(tensor_shape)
|
tensor_shape_len = len(tensor_shape)
|
||||||
if tensor_shape_len != 0 and not (tensor_shape_len == 1 and tensor_shape[0] == 1):
|
if tensor_shape_len != 0 and not (tensor_shape_len == 1 and tensor_shape[0] == 1):
|
||||||
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
|
const_utils.raise_value_error(
|
||||||
|
"The truth value of an array with more than one element is ambiguous.")
|
||||||
tensor_num = tensor_num + 1
|
tensor_num = tensor_num + 1
|
||||||
return tensor_num
|
return tensor_num
|
||||||
|
|
||||||
|
@ -2453,9 +2477,11 @@ def ms_max_one_element(x):
|
||||||
tensor_shape = F.shape(x)
|
tensor_shape = F.shape(x)
|
||||||
tensor_shape_len = len(tensor_shape)
|
tensor_shape_len = len(tensor_shape)
|
||||||
if tensor_shape_len == 0:
|
if tensor_shape_len == 0:
|
||||||
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
const_utils.raise_type_error(
|
||||||
|
"Cannot iterate over a scalar tensor.")
|
||||||
if tensor_shape_len >= 2:
|
if tensor_shape_len >= 2:
|
||||||
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
|
const_utils.raise_value_error(
|
||||||
|
"The truth value of an array with more than one element is ambiguous.")
|
||||||
return x.max()
|
return x.max()
|
||||||
# Deal with Tensor in tuple or list
|
# Deal with Tensor in tuple or list
|
||||||
if isinstance(x, (list, tuple)):
|
if isinstance(x, (list, tuple)):
|
||||||
|
@ -2465,9 +2491,11 @@ def ms_max_one_element(x):
|
||||||
if tensor_num == len(x):
|
if tensor_num == len(x):
|
||||||
return max_tensor(x)
|
return max_tensor(x)
|
||||||
if tensor_num != 0:
|
if tensor_num != 0:
|
||||||
const_utils.raise_type_error("max() cannot contain both tensor and non-tensor type.")
|
const_utils.raise_type_error(
|
||||||
|
"max() cannot contain both tensor and non-tensor type.")
|
||||||
if exist_tensor(x):
|
if exist_tensor(x):
|
||||||
const_utils.raise_type_error("max() cannot support tensor in list or tuple nested now.")
|
const_utils.raise_type_error(
|
||||||
|
"max() cannot support tensor in list or tuple nested now.")
|
||||||
return max_(x)
|
return max_(x)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2485,10 +2513,12 @@ def ms_max(*data):
|
||||||
if tensor_num == len_data:
|
if tensor_num == len_data:
|
||||||
return max_tensor(*data)
|
return max_tensor(*data)
|
||||||
if tensor_num != 0:
|
if tensor_num != 0:
|
||||||
const_utils.raise_type_error("max() cannot contain both tensor and non-tensor type.")
|
const_utils.raise_type_error(
|
||||||
|
"max() cannot contain both tensor and non-tensor type.")
|
||||||
# exist tensor in list/tuple
|
# exist tensor in list/tuple
|
||||||
if exist_tensor(data):
|
if exist_tensor(data):
|
||||||
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
|
const_utils.raise_value_error(
|
||||||
|
"The truth value of an array with more than one element is ambiguous.")
|
||||||
return max_(*data)
|
return max_(*data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2525,9 +2555,11 @@ def ms_min_one_element(x):
|
||||||
tensor_shape = F.shape(x)
|
tensor_shape = F.shape(x)
|
||||||
tensor_shape_len = len(tensor_shape)
|
tensor_shape_len = len(tensor_shape)
|
||||||
if tensor_shape_len == 0:
|
if tensor_shape_len == 0:
|
||||||
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
const_utils.raise_type_error(
|
||||||
|
"Cannot iterate over a scalar tensor.")
|
||||||
if tensor_shape_len >= 2:
|
if tensor_shape_len >= 2:
|
||||||
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
|
const_utils.raise_value_error(
|
||||||
|
"The truth value of an array with more than one element is ambiguous.")
|
||||||
return x.min()
|
return x.min()
|
||||||
# Deal with Tensor in tuple or list
|
# Deal with Tensor in tuple or list
|
||||||
if isinstance(x, (list, tuple)):
|
if isinstance(x, (list, tuple)):
|
||||||
|
@ -2537,9 +2569,11 @@ def ms_min_one_element(x):
|
||||||
if tensor_num == len(x):
|
if tensor_num == len(x):
|
||||||
return min_tensor(x)
|
return min_tensor(x)
|
||||||
if tensor_num != 0:
|
if tensor_num != 0:
|
||||||
const_utils.raise_type_error("min() cannot contain both tensor and non-tensor type.")
|
const_utils.raise_type_error(
|
||||||
|
"min() cannot contain both tensor and non-tensor type.")
|
||||||
if exist_tensor(x):
|
if exist_tensor(x):
|
||||||
const_utils.raise_type_error("min() cannot support tensor in list or tuple nested now.")
|
const_utils.raise_type_error(
|
||||||
|
"min() cannot support tensor in list or tuple nested now.")
|
||||||
return min_(x)
|
return min_(x)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2557,10 +2591,12 @@ def ms_min(*data):
|
||||||
if tensor_num == len_data:
|
if tensor_num == len_data:
|
||||||
return min_tensor(*data)
|
return min_tensor(*data)
|
||||||
if tensor_num != 0:
|
if tensor_num != 0:
|
||||||
const_utils.raise_type_error("min() cannot contain both tensor and non-tensor type.")
|
const_utils.raise_type_error(
|
||||||
|
"min() cannot contain both tensor and non-tensor type.")
|
||||||
# exist tensor in list/tuple
|
# exist tensor in list/tuple
|
||||||
if exist_tensor(data):
|
if exist_tensor(data):
|
||||||
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
|
const_utils.raise_value_error(
|
||||||
|
"The truth value of an array with more than one element is ambiguous.")
|
||||||
return min_(*data)
|
return min_(*data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2572,11 +2608,13 @@ def ms_sum(*data):
|
||||||
x = data[0]
|
x = data[0]
|
||||||
if not isinstance(x, Tensor) and not hasattr(x, "__ms_iter__"):
|
if not isinstance(x, Tensor) and not hasattr(x, "__ms_iter__"):
|
||||||
data_type = F.typeof(x)
|
data_type = F.typeof(x)
|
||||||
const_utils.raise_type_error(str(data_type) + " object is not iterable.")
|
const_utils.raise_type_error(
|
||||||
|
str(data_type) + " object is not iterable.")
|
||||||
if isinstance(x, Tensor):
|
if isinstance(x, Tensor):
|
||||||
tensor_shape = F.shape(x)
|
tensor_shape = F.shape(x)
|
||||||
if len(tensor_shape) == 0:
|
if len(tensor_shape) == 0:
|
||||||
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
const_utils.raise_type_error(
|
||||||
|
"Cannot iterate over a scalar tensor.")
|
||||||
if isinstance(x, dict):
|
if isinstance(x, dict):
|
||||||
x = x.keys()
|
x = x.keys()
|
||||||
result = 0
|
result = 0
|
||||||
|
@ -2607,7 +2645,8 @@ def ms_len(data):
|
||||||
def python_len_with_check(data):
|
def python_len_with_check(data):
|
||||||
"""Return the result of python built-in len function with iterable check"""
|
"""Return the result of python built-in len function with iterable check"""
|
||||||
if not hasattr(data, "__iter__"):
|
if not hasattr(data, "__iter__"):
|
||||||
raise TypeError(str(type(data)) + " object is not iterable in graph mode.")
|
raise TypeError(str(type(data)) +
|
||||||
|
" object is not iterable in graph mode.")
|
||||||
return len(data)
|
return len(data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2617,7 +2656,8 @@ def ms_len_with_iterable_check(data):
|
||||||
return python_len_with_check(data)
|
return python_len_with_check(data)
|
||||||
if not hasattr(data, "__len__"):
|
if not hasattr(data, "__len__"):
|
||||||
type_str = str(F.typeof(data))
|
type_str = str(F.typeof(data))
|
||||||
const_utils.raise_type_error(type_str + " object is not iterable in graph mode.")
|
const_utils.raise_type_error(
|
||||||
|
type_str + " object is not iterable in graph mode.")
|
||||||
return data.__len__()
|
return data.__len__()
|
||||||
|
|
||||||
|
|
||||||
|
@ -2680,7 +2720,6 @@ def expand_dims(x, axis):
|
||||||
"""
|
"""
|
||||||
Insert a dimension of shape 1 at the specified axis of Tensor.
|
Insert a dimension of shape 1 at the specified axis of Tensor.
|
||||||
"""
|
"""
|
||||||
check_is_int(axis, 'axis')
|
|
||||||
return P.ExpandDims()(x, axis)
|
return P.ExpandDims()(x, axis)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2688,7 +2727,6 @@ def unsqueeze(x, dim):
|
||||||
"""
|
"""
|
||||||
Insert a dimension of shape 1 at the specified axis of Tensor.
|
Insert a dimension of shape 1 at the specified axis of Tensor.
|
||||||
"""
|
"""
|
||||||
check_is_int(dim, 'dim')
|
|
||||||
return P.ExpandDims()(x, dim)
|
return P.ExpandDims()(x, dim)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2740,7 +2778,8 @@ def check_select_condition(cond_type):
|
||||||
"""
|
"""
|
||||||
if isinstance(cond_type, mstype.tensor_type):
|
if isinstance(cond_type, mstype.tensor_type):
|
||||||
return
|
return
|
||||||
raise TypeError(f"For select, the argument condition should be Tensor, but got {cond_type}.")
|
raise TypeError(
|
||||||
|
f"For select, the argument condition should be Tensor, but got {cond_type}.")
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
@ -2893,8 +2932,6 @@ def ge(x, y):
|
||||||
def while_cond(x):
|
def while_cond(x):
|
||||||
"""For while condition, if the condition is a tensor, the loop will not be unrolled"""
|
"""For while condition, if the condition is a tensor, the loop will not be unrolled"""
|
||||||
if issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
|
if issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
|
||||||
is_cond = check_is_tensor_bool_cond(F.shape(x))
|
|
||||||
if is_cond:
|
|
||||||
return F.cast(x, mstype.bool_)
|
return F.cast(x, mstype.bool_)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -3046,14 +3083,16 @@ def coo_to_dense(x):
|
||||||
def coo_coalesce(x):
|
def coo_coalesce(x):
|
||||||
"""Returns the coalesced sparse tensor of the input."""
|
"""Returns the coalesced sparse tensor of the input."""
|
||||||
shape = const_utils.make_tensor(x.shape)
|
shape = const_utils.make_tensor(x.shape)
|
||||||
res_indices, res_values, _ = P.Coalesce()(x.indices.transpose(), x.values, shape)
|
res_indices, res_values, _ = P.Coalesce()(
|
||||||
|
x.indices.transpose(), x.values, shape)
|
||||||
return COOTensor(res_indices.transpose(), res_values, x.shape)
|
return COOTensor(res_indices.transpose(), res_values, x.shape)
|
||||||
|
|
||||||
|
|
||||||
def csr_to_coo(x):
|
def csr_to_coo(x):
|
||||||
"""convert csr to coo."""
|
"""convert csr to coo."""
|
||||||
if x.ndim != 2:
|
if x.ndim != 2:
|
||||||
const_utils.raise_value_error("Currently only support 2-D CSRTensor when converting to COOTensor.")
|
const_utils.raise_value_error(
|
||||||
|
"Currently only support 2-D CSRTensor when converting to COOTensor.")
|
||||||
row_indices = F.csr2coo(x.indptr, x.values.shape[0])
|
row_indices = F.csr2coo(x.indptr, x.values.shape[0])
|
||||||
coo_indices = P.Stack(1)((row_indices, x.indices))
|
coo_indices = P.Stack(1)((row_indices, x.indices))
|
||||||
return COOTensor(coo_indices, x.values, x.shape)
|
return COOTensor(coo_indices, x.values, x.shape)
|
||||||
|
@ -3069,8 +3108,6 @@ def random_categorical_(x, num_sample, seed=0, dtype=mstype.int64):
|
||||||
Generates random samples from a given categorical distribution tensor.
|
Generates random samples from a given categorical distribution tensor.
|
||||||
Refer to :func:`mindspore.ops.random_categorical` for more detail.
|
Refer to :func:`mindspore.ops.random_categorical` for more detail.
|
||||||
"""
|
"""
|
||||||
validator.check_is_int(num_sample, 'num_sample')
|
|
||||||
validator.check_is_int(seed, 'seed')
|
|
||||||
return F.random_categorical(x, num_sample, seed, dtype)
|
return F.random_categorical(x, num_sample, seed, dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@ -3099,33 +3136,31 @@ def check_is_tuple_or_list_or_tensor(x, op_name, arg_name):
|
||||||
"""check whether x is list or tuple or tensor."""
|
"""check whether x is list or tuple or tensor."""
|
||||||
if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)):
|
if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)):
|
||||||
return True
|
return True
|
||||||
raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.")
|
raise TypeError(
|
||||||
|
f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.")
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def check_is_const_int(x, op_name, arg_name):
|
def check_is_const_int(x, op_name, arg_name):
|
||||||
"""check whether x is const int."""
|
"""check whether x is const int."""
|
||||||
if x is None:
|
if x is None:
|
||||||
raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.")
|
raise TypeError(
|
||||||
|
f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.")
|
||||||
if not isinstance(x, int):
|
if not isinstance(x, int):
|
||||||
raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.")
|
raise TypeError(
|
||||||
|
f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_is_tensor_bool_cond(shp):
|
|
||||||
"""check if tensor is a bool condition"""
|
|
||||||
if shp in ((), (1,)):
|
|
||||||
return True
|
|
||||||
if None in shp:
|
|
||||||
raise ValueError(f"Only tensor which shape is () or (1,) can be converted to bool, but got tensor shape is "
|
|
||||||
f"None")
|
|
||||||
raise ValueError(f"Only tensor which shape is () or (1,) can be converted to bool, but got tensor shape is {shp}")
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def const_tensor_to_bool(x):
|
def const_tensor_to_bool(x):
|
||||||
"""convert bool tensor to bool condition"""
|
"""convert bool tensor to bool condition
|
||||||
|
def const_tensor_to_bool(x):
|
||||||
|
convert bool tensor to bool condition
|
||||||
|
if x.shape == (1,):
|
||||||
|
return bool(x[0])
|
||||||
|
return bool(x)
|
||||||
|
"""
|
||||||
if x is None:
|
if x is None:
|
||||||
raise ValueError("Only tensor which shape is () or (1,) can be converted to bool, but got None")
|
raise ValueError("Only tensor which shape is () or (1,) can be converted to bool, but got None")
|
||||||
x = x.asnumpy()
|
x = x.asnumpy()
|
||||||
|
@ -3153,36 +3188,22 @@ def check_view_shape(x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
# convert normal param_check functions to constexpr functions
|
|
||||||
check_astype_dtype_const = constexpr(validator.check_astype_dtype)
|
check_astype_dtype_const = constexpr(validator.check_astype_dtype)
|
||||||
check_transpose_axis_const = constexpr(validator.check_transpose_axis)
|
check_transpose_axis_const = constexpr(validator.check_transpose_axis)
|
||||||
check_reshape_shp_const = constexpr(validator.check_reshape_shp)
|
|
||||||
check_flatten_order_const = constexpr(validator.check_flatten_order)
|
check_flatten_order_const = constexpr(validator.check_flatten_order)
|
||||||
check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis)
|
|
||||||
prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze)
|
|
||||||
check_axis_in_range_const = constexpr(validator.check_axis_in_range)
|
|
||||||
check_axis_valid = constexpr(validator.check_axis_valid)
|
|
||||||
max_ = constexpr(validator.max_)
|
max_ = constexpr(validator.max_)
|
||||||
min_ = constexpr(validator.min_)
|
min_ = constexpr(validator.min_)
|
||||||
expanded_shape = constexpr(validator.expanded_shape)
|
expanded_shape = validator.expanded_shape
|
||||||
tuple_slice = constexpr(validator.tuple_slice)
|
tuple_slice = validator.tuple_slice
|
||||||
infer_out_shape = constexpr(validator.infer_out_shape)
|
empty_compile = validator.empty_compile
|
||||||
get_log2_size = constexpr(validator.get_log2_size)
|
|
||||||
check_axis_type = constexpr(validator.check_axis_type)
|
|
||||||
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
|
|
||||||
empty_compile = constexpr(validator.empty_compile)
|
|
||||||
check_type_support = constexpr(validator.check_type_support)
|
check_type_support = constexpr(validator.check_type_support)
|
||||||
check_is_int = constexpr(validator.check_is_int)
|
|
||||||
check_type_name = constexpr(validator.check_type_name)
|
check_type_name = constexpr(validator.check_type_name)
|
||||||
check_value_type = constexpr(validator.check_value_type)
|
check_value_type = constexpr(validator.check_value_type)
|
||||||
check_int = constexpr(validator.check_int)
|
|
||||||
check_bool = constexpr(validator.check_bool)
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_bool(x):
|
def tensor_bool(x):
|
||||||
"""tensor as condition, if is constant, return immediate bool value"""
|
"""tensor as condition, if is constant, return immediate bool value"""
|
||||||
is_cond = check_is_tensor_bool_cond(F.shape(x))
|
if F.isconstant(x):
|
||||||
if is_cond and F.isconstant(x):
|
|
||||||
return const_tensor_to_bool(x)
|
return const_tensor_to_bool(x)
|
||||||
return F.cast(x, mstype.bool_)
|
return F.cast(x, mstype.bool_)
|
||||||
|
|
||||||
|
@ -3340,8 +3361,6 @@ def top_k(input_x, k, sorted=True):
|
||||||
"""
|
"""
|
||||||
Finds values and indices of the `k` largest entries along the last dimension.
|
Finds values and indices of the `k` largest entries along the last dimension.
|
||||||
"""
|
"""
|
||||||
check_is_int(k, 'k')
|
|
||||||
check_bool(sorted, 'sorted')
|
|
||||||
return F.top_k(input_x, k, sorted)
|
return F.top_k(input_x, k, sorted)
|
||||||
|
|
||||||
|
|
||||||
|
@ -3588,7 +3607,6 @@ def bernoulli(x, p=0.5, seed=-1):
|
||||||
"""
|
"""
|
||||||
Randomly draws binary numbers from a Bernoulli distribution.
|
Randomly draws binary numbers from a Bernoulli distribution.
|
||||||
"""
|
"""
|
||||||
check_is_int(seed, 'bernoulli', 'seed')
|
|
||||||
return F.bernoulli(x, p, seed)
|
return F.bernoulli(x, p, seed)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,6 @@ import mindspore.ops as ops
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
import mindspore.ops.composite as C
|
import mindspore.ops.composite as C
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops.primitive import constexpr
|
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,17 +95,6 @@ def apply_offload_iterators(data, offload_model):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_input_dims(x_shape, required_dim, offload_op_name):
|
|
||||||
"""
|
|
||||||
Check if input has the required number of dimensions for the operation.
|
|
||||||
"""
|
|
||||||
input_dim = len(x_shape)
|
|
||||||
if input_dim is not required_dim:
|
|
||||||
raise ValueError("For %s offload operation, the dimension of input should be %d, but got %d." %
|
|
||||||
(offload_op_name, required_dim, input_dim))
|
|
||||||
|
|
||||||
|
|
||||||
def assign_min_max_params(in_params, center=1):
|
def assign_min_max_params(in_params, center=1):
|
||||||
"""
|
"""
|
||||||
Adjust input parameters for ops.
|
Adjust input parameters for ops.
|
||||||
|
@ -175,7 +163,6 @@ class RandomHorizontalFlip(nn.Cell):
|
||||||
|
|
||||||
x = self.cast(x, mstype.float32)
|
x = self.cast(x, mstype.float32)
|
||||||
x_shape = self.shape(x)
|
x_shape = self.shape(x)
|
||||||
check_input_dims(x_shape, 4, 'RandomHorizontalFlip')
|
|
||||||
bs, h, w, c = x_shape
|
bs, h, w, c = x_shape
|
||||||
|
|
||||||
flip_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
flip_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
||||||
|
@ -208,7 +195,6 @@ class RandomVerticalFlip(nn.Cell):
|
||||||
|
|
||||||
x = self.cast(x, mstype.float32)
|
x = self.cast(x, mstype.float32)
|
||||||
x_shape = self.shape(x)
|
x_shape = self.shape(x)
|
||||||
check_input_dims(x_shape, 4, 'RandomVerticalFlip')
|
|
||||||
bs, h, w, c = x_shape
|
bs, h, w, c = x_shape
|
||||||
|
|
||||||
flip_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
flip_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
||||||
|
@ -293,7 +279,6 @@ class RandomColorAdjust(nn.Cell):
|
||||||
|
|
||||||
x = self.cast(x, mstype.float32)
|
x = self.cast(x, mstype.float32)
|
||||||
x_shape = self.shape(x)
|
x_shape = self.shape(x)
|
||||||
check_input_dims(x_shape, 4, 'RandomColorAdjust')
|
|
||||||
bs, h, w, c = x_shape
|
bs, h, w, c = x_shape
|
||||||
|
|
||||||
br_rand_factor = self.generate_rand_batch(self.br_min, self.br_max, self.check_rand_br, x_shape)
|
br_rand_factor = self.generate_rand_batch(self.br_min, self.br_max, self.check_rand_br, x_shape)
|
||||||
|
@ -402,7 +387,6 @@ class RandomSharpness(nn.Cell):
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x = self.cast(x, mstype.float32)
|
x = self.cast(x, mstype.float32)
|
||||||
x_shape = self.shape(x)
|
x_shape = self.shape(x)
|
||||||
check_input_dims(x_shape, 4, 'RandomSharpness')
|
|
||||||
bs, h, w, c = x_shape
|
bs, h, w, c = x_shape
|
||||||
|
|
||||||
degree_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
degree_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
||||||
|
@ -449,11 +433,8 @@ class HwcToChw(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(HwcToChw, self).__init__()
|
super(HwcToChw, self).__init__()
|
||||||
self.trans = P.Transpose()
|
self.trans = P.Transpose()
|
||||||
self.shape = P.Shape()
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x_shape = self.shape(x)
|
|
||||||
check_input_dims(x_shape, 4, 'HwcToChw')
|
|
||||||
return self.trans(x, (0, 3, 1, 2))
|
return self.trans(x, (0, 3, 1, 2))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,6 @@ from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops.operations import nn_ops as NN_OPS
|
from mindspore.ops.operations import nn_ops as NN_OPS
|
||||||
from mindspore.ops.primitive import constexpr
|
|
||||||
from mindspore.nn.cell import Cell
|
from mindspore.nn.cell import Cell
|
||||||
from mindspore import ops
|
from mindspore import ops
|
||||||
|
|
||||||
|
@ -195,18 +194,9 @@ class Softmax2d(Cell):
|
||||||
"""Initialize Softmax2d."""
|
"""Initialize Softmax2d."""
|
||||||
super(Softmax2d, self).__init__()
|
super(Softmax2d, self).__init__()
|
||||||
self.softmax = P.Softmax(axis=-3)
|
self.softmax = P.Softmax(axis=-3)
|
||||||
self.shape = P.Shape()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim not in (3, 4):
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 3 or 4 dims, but got {dim}.")
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x_shape = self.shape(x)
|
|
||||||
self._check_input_dim(x_shape, self.cls_name)
|
|
||||||
return self.softmax(x)
|
return self.softmax(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -87,15 +87,18 @@ class L1Regularizer(Cell):
|
||||||
super(L1Regularizer, self).__init__()
|
super(L1Regularizer, self).__init__()
|
||||||
Validator.check_value_type("scale", scale, [int, float], self.cls_name)
|
Validator.check_value_type("scale", scale, [int, float], self.cls_name)
|
||||||
if scale <= 0:
|
if scale <= 0:
|
||||||
raise ValueError(f"For '{self.cls_name}', the 'scale' must be greater than 0, but got {scale}.")
|
raise ValueError(
|
||||||
|
f"For '{self.cls_name}', the 'scale' must be greater than 0, but got {scale}.")
|
||||||
if math.isinf(scale) or math.isnan(scale):
|
if math.isinf(scale) or math.isnan(scale):
|
||||||
raise ValueError(f"For '{self.cls_name}', the 'scale' can not be INF or NAN, but got {scale}.")
|
raise ValueError(
|
||||||
|
f"For '{self.cls_name}', the 'scale' can not be INF or NAN, but got {scale}.")
|
||||||
self.abs = P.Abs()
|
self.abs = P.Abs()
|
||||||
self.reduce_sum = P.ReduceSum()
|
self.reduce_sum = P.ReduceSum()
|
||||||
self.scale = Tensor(scale, dtype=mstype.float32)
|
self.scale = Tensor(scale, dtype=mstype.float32)
|
||||||
|
|
||||||
def construct(self, weights):
|
def construct(self, weights):
|
||||||
const_utils.check_type_valid(F.dtype(weights), mstype.number_type, 'weights')
|
const_utils.check_type_valid(
|
||||||
|
F.dtype(weights), mstype.number_type, 'weights')
|
||||||
l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
|
l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
|
||||||
return l1_regularization
|
return l1_regularization
|
||||||
|
|
||||||
|
@ -155,13 +158,16 @@ class Dropout(Cell):
|
||||||
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
|
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
|
||||||
"""Initialize Dropout."""
|
"""Initialize Dropout."""
|
||||||
super(Dropout, self).__init__()
|
super(Dropout, self).__init__()
|
||||||
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
|
Validator.check_value_type('keep_prob', keep_prob, [
|
||||||
|
float], self.cls_name)
|
||||||
if keep_prob <= 0 or keep_prob > 1:
|
if keep_prob <= 0 or keep_prob > 1:
|
||||||
raise ValueError(f"For '{self.cls_name}', the 'keep_prob' must be a number in range (0, 1], "
|
raise ValueError(f"For '{self.cls_name}', the 'keep_prob' must be a number in range (0, 1], "
|
||||||
f"but got {keep_prob}.")
|
f"but got {keep_prob}.")
|
||||||
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
|
Validator.check_subclass(
|
||||||
|
"dtype", dtype, mstype.number_type, self.cls_name)
|
||||||
if dtype != mstype.float32:
|
if dtype != mstype.float32:
|
||||||
logger.info("This parameter `dtype` will be deleted or invisible in the future. Please don't use it.")
|
logger.info(
|
||||||
|
"This parameter `dtype` will be deleted or invisible in the future. Please don't use it.")
|
||||||
self.keep_prob = keep_prob
|
self.keep_prob = keep_prob
|
||||||
seed0, seed1 = _get_graph_seed(0, "dropout")
|
seed0, seed1 = _get_graph_seed(0, "dropout")
|
||||||
self.seed0 = seed0
|
self.seed0 = seed0
|
||||||
|
@ -623,13 +629,6 @@ class Flatten(Cell):
|
||||||
return F.reshape(x, (F.shape(x)[0], -1))
|
return F.reshape(x, (F.shape(x)[0], -1))
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_dense_input_shape(x, prim_name=None):
|
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
||||||
if len(x) < 2:
|
|
||||||
raise ValueError(f"{msg_prefix} dimension of 'x' should not be less than 2, but got {len(x)}.")
|
|
||||||
|
|
||||||
|
|
||||||
class Identity(Cell):
|
class Identity(Cell):
|
||||||
"""
|
"""
|
||||||
Returns a Tensor with the same shape and contents as input.
|
Returns a Tensor with the same shape and contents as input.
|
||||||
|
@ -727,9 +726,12 @@ class Dense(Cell):
|
||||||
activation=None):
|
activation=None):
|
||||||
"""Initialize Dense."""
|
"""Initialize Dense."""
|
||||||
super(Dense, self).__init__()
|
super(Dense, self).__init__()
|
||||||
self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
|
self.in_channels = Validator.check_positive_int(
|
||||||
self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
|
in_channels, "in_channels", self.cls_name)
|
||||||
self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name)
|
self.out_channels = Validator.check_positive_int(
|
||||||
|
out_channels, "out_channels", self.cls_name)
|
||||||
|
self.has_bias = Validator.check_bool(
|
||||||
|
has_bias, "has_bias", self.cls_name)
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.shape_op = P.Shape()
|
self.shape_op = P.Shape()
|
||||||
|
|
||||||
|
@ -740,7 +742,8 @@ class Dense(Cell):
|
||||||
f"be equal to 2, and the first dim must be equal to 'out_channels', and the "
|
f"be equal to 2, and the first dim must be equal to 'out_channels', and the "
|
||||||
f"second dim must be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
|
f"second dim must be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
|
||||||
f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
|
f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
|
||||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
self.weight = Parameter(initializer(
|
||||||
|
weight_init, [out_channels, in_channels]), name="weight")
|
||||||
|
|
||||||
self.bias = None
|
self.bias = None
|
||||||
if self.has_bias:
|
if self.has_bias:
|
||||||
|
@ -749,11 +752,13 @@ class Dense(Cell):
|
||||||
raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
|
raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
|
||||||
f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
|
f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
|
||||||
f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
|
f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
|
||||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
self.bias = Parameter(initializer(
|
||||||
|
bias_init, [out_channels]), name="bias")
|
||||||
self.bias_add = P.BiasAdd()
|
self.bias_add = P.BiasAdd()
|
||||||
|
|
||||||
self.matmul = P.MatMul(transpose_b=True)
|
self.matmul = P.MatMul(transpose_b=True)
|
||||||
self.activation = get_activation(activation) if isinstance(activation, str) else activation
|
self.activation = get_activation(activation) if isinstance(
|
||||||
|
activation, str) else activation
|
||||||
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
|
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
|
||||||
raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got "
|
raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got "
|
||||||
f"{type(activation).__name__}.")
|
f"{type(activation).__name__}.")
|
||||||
|
@ -761,7 +766,6 @@ class Dense(Cell):
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x_shape = self.shape_op(x)
|
x_shape = self.shape_op(x)
|
||||||
check_dense_input_shape(x_shape, self.cls_name)
|
|
||||||
if len(x_shape) != 2:
|
if len(x_shape) != 2:
|
||||||
x = self.reshape(x, (-1, x_shape[-1]))
|
x = self.reshape(x, (-1, x_shape[-1]))
|
||||||
x = self.matmul(x, self.weight)
|
x = self.matmul(x, self.weight)
|
||||||
|
@ -775,7 +779,8 @@ class Dense(Cell):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
|
s = 'input_channels={}, output_channels={}'.format(
|
||||||
|
self.in_channels, self.out_channels)
|
||||||
if self.has_bias:
|
if self.has_bias:
|
||||||
s += ', has_bias={}'.format(self.has_bias)
|
s += ', has_bias={}'.format(self.has_bias)
|
||||||
if self.activation_flag:
|
if self.activation_flag:
|
||||||
|
@ -787,14 +792,15 @@ class Dense(Cell):
|
||||||
def _is_equal_one(x):
|
def _is_equal_one(x):
|
||||||
if x is None:
|
if x is None:
|
||||||
return False
|
return False
|
||||||
return bool(x.asnumpy().mean() == 1.0)
|
return F.equal(F.reduce_mean(x), 1.0)
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _dtype_check(x_dtype, prim_name=None):
|
def _dtype_check(x_dtype, prim_name=None):
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
||||||
if x_dtype not in [mstype.float32, mstype.float16]:
|
if x_dtype not in [mstype.float32, mstype.float16]:
|
||||||
raise TypeError(f"{msg_prefix} x_dtype must be float32 or float16, but got {x_dtype}.")
|
raise TypeError(
|
||||||
|
f"{msg_prefix} x_dtype must be float32 or float16, but got {x_dtype}.")
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
@ -923,7 +929,8 @@ class Norm(Cell):
|
||||||
def __init__(self, axis=(), keep_dims=False):
|
def __init__(self, axis=(), keep_dims=False):
|
||||||
"""Initialize Norm."""
|
"""Initialize Norm."""
|
||||||
super(Norm, self).__init__()
|
super(Norm, self).__init__()
|
||||||
Validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
|
Validator.check_value_type(
|
||||||
|
"keep_dims", keep_dims, [bool], self.cls_name)
|
||||||
self.axis = axis
|
self.axis = axis
|
||||||
self.keep_dims = keep_dims
|
self.keep_dims = keep_dims
|
||||||
self.reduce_sum = P.ReduceSum(True)
|
self.reduce_sum = P.ReduceSum(True)
|
||||||
|
@ -1209,7 +1216,8 @@ class Pad(Cell):
|
||||||
super(Pad, self).__init__()
|
super(Pad, self).__init__()
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.paddings = paddings
|
self.paddings = paddings
|
||||||
Validator.check_string(self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
|
Validator.check_string(
|
||||||
|
self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
|
||||||
if not isinstance(paddings, tuple):
|
if not isinstance(paddings, tuple):
|
||||||
raise TypeError(f"For '{self.cls_name}', the type of 'paddings' must be tuple, "
|
raise TypeError(f"For '{self.cls_name}', the type of 'paddings' must be tuple, "
|
||||||
f"but got {type(paddings).__name__}.")
|
f"but got {type(paddings).__name__}.")
|
||||||
|
@ -1239,20 +1247,17 @@ def bilinear(shape, size, scale, align_corners, prim_name=None):
|
||||||
"""Check input and calculate shape"""
|
"""Check input and calculate shape"""
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
||||||
if not isinstance(align_corners, bool):
|
if not isinstance(align_corners, bool):
|
||||||
raise TypeError(f"{msg_prefix} type of 'align_corners' must be boolean, "
|
raise TypeError(
|
||||||
f"but got {type(align_corners).__name__}.")
|
f"{msg_prefix} type of 'align_corners' must be bool, but got {type(align_corners).__name__}.")
|
||||||
if size is None and scale is None:
|
if size is None and scale is None:
|
||||||
raise ValueError(f"{msg_prefix} 'size' and 'scale' both none.")
|
raise ValueError(f"{msg_prefix} 'size' and 'scale' both none.")
|
||||||
if size is not None and scale is not None:
|
if size is not None and scale is not None:
|
||||||
raise ValueError(f"{msg_prefix} 'size' and 'scale' both not none.")
|
raise ValueError(f"{msg_prefix} 'size' and 'scale' both not none.")
|
||||||
if size is not None:
|
if size is not None:
|
||||||
if not isinstance(size, (tuple, list)):
|
if not isinstance(size, (tuple, list)):
|
||||||
raise ValueError(f"{msg_prefix} 'size' must be tuple or list or None, but got {type(size).__name__}.")
|
raise ValueError(
|
||||||
Validator.check_int(len(size), 2, Rel.EQ, "size", "bilinear")
|
f"{msg_prefix} 'size' must be tuple or list or None, but got {type(size).__name__}.")
|
||||||
Validator.check_int(size[0], 1, Rel.GE, "size[0]", "bilinear")
|
|
||||||
Validator.check_int(size[1], 1, Rel.GE, "size[1]", "bilinear")
|
|
||||||
return size
|
return size
|
||||||
Validator.check_int(scale, 1, Rel.GE, "scale factor", "bilinear")
|
|
||||||
ret = (scale * shape[2], scale * shape[3])
|
ret = (scale * shape[2], scale * shape[3])
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@ -1323,8 +1328,10 @@ class ResizeBilinear(Cell):
|
||||||
self.half_pixel_centers = half_pixel_centers
|
self.half_pixel_centers = half_pixel_centers
|
||||||
|
|
||||||
def construct(self, x, size=None, scale_factor=None, align_corners=False):
|
def construct(self, x, size=None, scale_factor=None, align_corners=False):
|
||||||
shape = bilinear(x.shape, size, scale_factor, align_corners, self.cls_name)
|
shape = bilinear(x.shape, size, scale_factor,
|
||||||
resize_bilinear = P.ResizeBilinear(shape, align_corners, self.half_pixel_centers)
|
align_corners, self.cls_name)
|
||||||
|
resize_bilinear = P.ResizeBilinear(
|
||||||
|
shape, align_corners, self.half_pixel_centers)
|
||||||
return resize_bilinear(x)
|
return resize_bilinear(x)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1390,7 +1397,8 @@ class Unfold(Cell):
|
||||||
super(Unfold, self).__init__()
|
super(Unfold, self).__init__()
|
||||||
|
|
||||||
def _check_tuple_or_list(arg_name, arg_val, prim_name):
|
def _check_tuple_or_list(arg_name, arg_val, prim_name):
|
||||||
Validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.cls_name)
|
Validator.check_value_type(f"{arg_name}s", ksizes, [
|
||||||
|
tuple, list], self.cls_name)
|
||||||
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
|
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
|
||||||
raise ValueError(f"For '{prim_name}' the format of '{arg_name}s' must be [1, {arg_name}_row, "
|
raise ValueError(f"For '{prim_name}' the format of '{arg_name}s' must be [1, {arg_name}_row, "
|
||||||
f"{arg_name}_col, 1], but got {arg_val}.")
|
f"{arg_name}_col, 1], but got {arg_val}.")
|
||||||
|
@ -1405,19 +1413,17 @@ class Unfold(Cell):
|
||||||
ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2]
|
ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2]
|
||||||
strides = strides[0], strides[3], strides[1], strides[2]
|
strides = strides[0], strides[3], strides[1], strides[2]
|
||||||
rates = rates[0], rates[3], rates[1], rates[2]
|
rates = rates[0], rates[3], rates[1], rates[2]
|
||||||
self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding)
|
self.extract_image_patches = inner.ExtractImagePatches(
|
||||||
|
ksizes, strides, rates, padding)
|
||||||
|
|
||||||
def construct(self, input_x):
|
def construct(self, input_x):
|
||||||
result = self.extract_image_patches(input_x)
|
result = self.extract_image_patches(input_x)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def tril(x_shape, x_dtype, k):
|
def tril(x_shape, x_dtype, k):
|
||||||
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "tril")
|
value = F.cast(P.Tril(diagonal=k)(F.ones(x_shape, x_dtype)), x_dtype)
|
||||||
Validator.check_is_int(k, "k value", "tril")
|
return value
|
||||||
mask = np.tril(np.ones(x_shape), k)
|
|
||||||
return Tensor(mask, x_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class Tril(Cell):
|
class Tril(Cell):
|
||||||
|
@ -1510,16 +1516,14 @@ class Tril(Cell):
|
||||||
|
|
||||||
def construct(self, x, k=0):
|
def construct(self, x, k=0):
|
||||||
assist = tril(x.shape, self.dtype(x), k)
|
assist = tril(x.shape, self.dtype(x), k)
|
||||||
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
|
result = self.mul(self.cast(x, mstype.float32),
|
||||||
|
self.cast(assist, mstype.float32))
|
||||||
return self.cast(result, self.dtype(x))
|
return self.cast(result, self.dtype(x))
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def triu(x_shape, x_dtype, k):
|
def triu(x_shape, x_dtype, k):
|
||||||
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "triu")
|
value = F.cast(P.Triu(k)(F.ones(x_shape, x_dtype)), x_dtype)
|
||||||
Validator.check_is_int(k, "k value", "triu")
|
return value
|
||||||
mask = np.triu(np.ones(x_shape), k)
|
|
||||||
return Tensor(mask, x_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class Triu(Cell):
|
class Triu(Cell):
|
||||||
|
@ -1603,24 +1607,34 @@ class Triu(Cell):
|
||||||
|
|
||||||
def construct(self, x, k=0):
|
def construct(self, x, k=0):
|
||||||
assist = triu(x.shape, self.dtype(x), k)
|
assist = triu(x.shape, self.dtype(x), k)
|
||||||
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
|
result = self.mul(self.cast(x, mstype.float32),
|
||||||
|
self.cast(assist, mstype.float32))
|
||||||
return self.cast(result, self.dtype(x))
|
return self.cast(result, self.dtype(x))
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _get_matrix_diag_assist(x_shape, x_dtype):
|
def _get_matrix_diag_assist(x_shape, x_dtype):
|
||||||
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist")
|
"""Get matrix diag assist"""
|
||||||
base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
|
base_eye = F.reshape(
|
||||||
assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
|
F.eye(x_shape[-1], x_shape[-1], x_dtype), (x_shape[-1] * x_shape[-1],))
|
||||||
return Tensor(assist, x_dtype)
|
if len(x_shape) == 1:
|
||||||
|
assist = F.reshape(base_eye, x_shape + (x_shape[-1],))
|
||||||
|
else:
|
||||||
|
assist = F.reshape(
|
||||||
|
F.tile(base_eye, x_shape[:-1]), x_shape + (x_shape[-1],))
|
||||||
|
value = F.cast(assist, x_dtype)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _get_matrix_diag_part_assist(x_shape, x_dtype):
|
def _get_matrix_diag_part_assist(x_shape, x_dtype):
|
||||||
Validator.check_int(len(x_shape), 2, Rel.GE, "x rank", "_get_matrix_diag_part_assist")
|
"""Get matrix diag part assist"""
|
||||||
base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
|
base_eye = F.reshape(
|
||||||
assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
|
F.eye(x_shape[-2], x_shape[-1], x_dtype), (x_shape[-2] * x_shape[-1],))
|
||||||
return Tensor(assist, x_dtype)
|
if len(x_shape) <= 2:
|
||||||
|
assist = F.reshape(base_eye, x_shape)
|
||||||
|
else:
|
||||||
|
assist = F.reshape(F.tile(base_eye, x_shape[:-2]), x_shape)
|
||||||
|
value = F.cast(assist, x_dtype)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class MatrixDiag(Cell):
|
class MatrixDiag(Cell):
|
||||||
|
@ -1867,8 +1881,10 @@ class Roll(Cell):
|
||||||
def __init__(self, shift, axis):
|
def __init__(self, shift, axis):
|
||||||
"""Initialize Roll"""
|
"""Initialize Roll"""
|
||||||
super(Roll, self).__init__()
|
super(Roll, self).__init__()
|
||||||
Validator.check_value_type("shift", shift, [int, tuple, list], self.cls_name)
|
Validator.check_value_type(
|
||||||
Validator.check_value_type("axis", axis, [int, tuple, list], self.cls_name)
|
"shift", shift, [int, tuple, list], self.cls_name)
|
||||||
|
Validator.check_value_type(
|
||||||
|
"axis", axis, [int, tuple, list], self.cls_name)
|
||||||
self.shape_op = P.Shape()
|
self.shape_op = P.Shape()
|
||||||
self.shift = shift
|
self.shift = shift
|
||||||
self.axis = axis
|
self.axis = axis
|
||||||
|
@ -1894,14 +1910,16 @@ class Roll(Cell):
|
||||||
f"and the length of 'axis' {len(self.axis)}.")
|
f"and the length of 'axis' {len(self.axis)}.")
|
||||||
else:
|
else:
|
||||||
if not isinstance(self.axis, (list, tuple)):
|
if not isinstance(self.axis, (list, tuple)):
|
||||||
self.op_list.append((P.Roll(shift=self.shift, axis=0), self.axis))
|
self.op_list.append(
|
||||||
|
(P.Roll(shift=self.shift, axis=0), self.axis))
|
||||||
else:
|
else:
|
||||||
if len(self.shift) != len(self.axis):
|
if len(self.shift) != len(self.axis):
|
||||||
raise ValueError(f"For '{self.cls_name}', the shape of 'shift' and the shape of 'axis' must be "
|
raise ValueError(f"For '{self.cls_name}', the shape of 'shift' and the shape of 'axis' must be "
|
||||||
f"the same, but got the length of 'shift' {len(self.shift)} "
|
f"the same, but got the length of 'shift' {len(self.shift)} "
|
||||||
f"and the length of 'axis' {len(self.axis)}.")
|
f"and the length of 'axis' {len(self.axis)}.")
|
||||||
for idx, _ in enumerate(self.axis):
|
for idx, _ in enumerate(self.axis):
|
||||||
self.op_list.append((P.Roll(shift=self.shift[idx], axis=0), self.axis[idx]))
|
self.op_list.append(
|
||||||
|
(P.Roll(shift=self.shift[idx], axis=0), self.axis[idx]))
|
||||||
|
|
||||||
def construct(self, input_x):
|
def construct(self, input_x):
|
||||||
dim = len(self.shape_op(input_x))
|
dim = len(self.shape_op(input_x))
|
||||||
|
@ -1965,7 +1983,8 @@ class Unflatten(Cell):
|
||||||
self.shape = P.Shape()
|
self.shape = P.Shape()
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
Validator.check_is_int(axis, 'axis', 'Unflatten')
|
Validator.check_is_int(axis, 'axis', 'Unflatten')
|
||||||
Validator.check_value_type('unflattended_size', unflattened_size, (list, tuple), 'Unflatten')
|
Validator.check_value_type(
|
||||||
|
'unflattended_size', unflattened_size, (list, tuple), 'Unflatten')
|
||||||
self.axis = axis
|
self.axis = axis
|
||||||
if isinstance(unflattened_size, list):
|
if isinstance(unflattened_size, list):
|
||||||
unflattened_size = tuple(unflattened_size)
|
unflattened_size = tuple(unflattened_size)
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""channel shuffle"""
|
"""channel shuffle"""
|
||||||
from mindspore.ops.primitive import constexpr
|
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.nn.cell import Cell
|
from mindspore.nn.cell import Cell
|
||||||
|
|
||||||
|
@ -82,21 +81,9 @@ class ChannelShuffle(Cell):
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.transpose = P.Transpose()
|
self.transpose = P.Transpose()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, channels, groups, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim < 3:
|
|
||||||
raise ValueError(f"For {cls_name}, the in_shape must have more than 2 dims, but got {dim}.")
|
|
||||||
|
|
||||||
if channels % groups != 0:
|
|
||||||
raise ValueError(f"For {cls_name}, number of channels must be divisible by groups, "
|
|
||||||
f"but got {channels} channels and {groups} groups.")
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x_shape = self.shape(x)
|
x_shape = self.shape(x)
|
||||||
n, c = x_shape[0], x_shape[1]
|
n, c = x_shape[0], x_shape[1]
|
||||||
self._check_input_dim(x_shape, c, self.groups, self.cls_name)
|
|
||||||
out = self.reshape(x, (n, self.groups, c // self.groups, -1))
|
out = self.reshape(x, (n, self.groups, c // self.groups, -1))
|
||||||
out = self.transpose(out, (0, 2, 1, 3))
|
out = self.transpose(out, (0, 2, 1, 3))
|
||||||
return self.reshape(out, x_shape)
|
return self.reshape(out, x_shape)
|
||||||
|
|
|
@ -315,12 +315,6 @@ class Conv2d(_Conv):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_input_3d(input_shape, op_name):
|
|
||||||
if len(input_shape) != 3:
|
|
||||||
raise ValueError(f"For '{op_name}', the dimension of input must be 3d, but got {len(input_shape)}.")
|
|
||||||
|
|
||||||
|
|
||||||
class Conv1d(_Conv):
|
class Conv1d(_Conv):
|
||||||
r"""
|
r"""
|
||||||
Calculates the 1D convolution on the input tensor. The input is typically of shape :math:`(N, C_{in}, L_{in})`,
|
Calculates the 1D convolution on the input tensor. The input is typically of shape :math:`(N, C_{in}, L_{in})`,
|
||||||
|
@ -482,11 +476,8 @@ class Conv1d(_Conv):
|
||||||
self.bias_add = P.BiasAdd()
|
self.bias_add = P.BiasAdd()
|
||||||
self.expand_dims = P.ExpandDims()
|
self.expand_dims = P.ExpandDims()
|
||||||
self.squeeze = P.Squeeze(2)
|
self.squeeze = P.Squeeze(2)
|
||||||
self.shape = P.Shape()
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x_shape = self.shape(x)
|
|
||||||
_check_input_3d(x_shape, self.cls_name)
|
|
||||||
x = self.expand_dims(x, 2)
|
x = self.expand_dims(x, 2)
|
||||||
output = self.conv2d(x, self.weight)
|
output = self.conv2d(x, self.weight)
|
||||||
if self.has_bias:
|
if self.has_bias:
|
||||||
|
@ -1289,8 +1280,6 @@ class Conv1dTranspose(_Conv):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x_shape = self.shape(x)
|
|
||||||
_check_input_3d(x_shape, self.cls_name)
|
|
||||||
x = self.expand_dims(x, 2)
|
x = self.expand_dims(x, 2)
|
||||||
|
|
||||||
n, _, h, w = self.shape(x)
|
n, _, h, w = self.shape(x)
|
||||||
|
|
|
@ -30,14 +30,6 @@ from mindspore.nn.cell import Cell
|
||||||
__all__ = ['BiDense']
|
__all__ = ['BiDense']
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_dense_inputs_same_shape(input1, input2, prim_name=None):
|
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
||||||
if input1[:-1] != input2[:-1]:
|
|
||||||
raise ValueError(f"{msg_prefix} dimensions except the last of 'input1' must be same as 'input2', but got "
|
|
||||||
f"{input1} of 'input1' and {input2} of 'input2'")
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr(check=False)
|
@constexpr(check=False)
|
||||||
def _check_is_tensor(param_name, input_data, cls_name):
|
def _check_is_tensor(param_name, input_data, cls_name):
|
||||||
"""Internal function, used to check whether the input data is Tensor."""
|
"""Internal function, used to check whether the input data is Tensor."""
|
||||||
|
@ -46,14 +38,6 @@ def _check_is_tensor(param_name, input_data, cls_name):
|
||||||
f"but got '{P.typeof(input_data)}'")
|
f"but got '{P.typeof(input_data)}'")
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_last_dimension(input_dim, input_channels, input_name, input_channels_name, prim_name=None):
|
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
||||||
if input_dim != input_channels:
|
|
||||||
raise ValueError(f"{msg_prefix} last dimension of '{input_name}' must be same as '{input_channels_name}',"
|
|
||||||
f" but got {input_dim} of '{input_name}' and {input_channels} of '{input_channels_name}'")
|
|
||||||
|
|
||||||
|
|
||||||
class BiDense(Cell):
|
class BiDense(Cell):
|
||||||
r"""
|
r"""
|
||||||
The bilinear dense connected layer.
|
The bilinear dense connected layer.
|
||||||
|
@ -171,9 +155,6 @@ class BiDense(Cell):
|
||||||
_check_is_tensor("input2", input2, self.cls_name)
|
_check_is_tensor("input2", input2, self.cls_name)
|
||||||
input1_shape = input1.shape
|
input1_shape = input1.shape
|
||||||
input2_shape = input2.shape
|
input2_shape = input2.shape
|
||||||
check_last_dimension(input1_shape[-1], self.in1_channels, "input1", "in1_channels", self.cls_name)
|
|
||||||
check_last_dimension(input2_shape[-1], self.in2_channels, "input2", "in2_channels", self.cls_name)
|
|
||||||
check_dense_inputs_same_shape(input1_shape, input2_shape, self.cls_name)
|
|
||||||
if len(input1_shape) != 2:
|
if len(input1_shape) != 2:
|
||||||
input1 = input1.reshape((-1, input1_shape[-1]))
|
input1 = input1.reshape((-1, input1_shape[-1]))
|
||||||
input2 = input2.reshape((-1, input2_shape[-1]))
|
input2 = input2.reshape((-1, input2_shape[-1]))
|
||||||
|
|
|
@ -39,13 +39,6 @@ from mindspore.nn.cell import Cell
|
||||||
__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']
|
__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_input_2d(input_shape, param_name, func_name):
|
|
||||||
if len(input_shape) != 2:
|
|
||||||
raise ValueError(f"For '{func_name}', the dimension of '{param_name}' must be 2d, but got {len(input_shape)}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
||||||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||||
|
@ -623,10 +616,6 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
||||||
self.negative_inf_value = -3.402823466E+38
|
self.negative_inf_value = -3.402823466E+38
|
||||||
|
|
||||||
def construct(self, input_indices, input_values, field_ids):
|
def construct(self, input_indices, input_values, field_ids):
|
||||||
|
|
||||||
_check_input_2d(F.shape(input_indices), "input_indices", self.cls_name)
|
|
||||||
_check_input_2d(F.shape(input_values), "input_values", self.cls_name)
|
|
||||||
_check_input_2d(F.shape(field_ids), "field_ids", self.cls_name)
|
|
||||||
_check_input_dtype(F.dtype(input_indices), "input_indices", [mstype.int32, mstype.int64], self.cls_name)
|
_check_input_dtype(F.dtype(input_indices), "input_indices", [mstype.int32, mstype.int64], self.cls_name)
|
||||||
_check_input_dtype(F.dtype(input_values), "input_values", [mstype.float32], self.cls_name)
|
_check_input_dtype(F.dtype(input_values), "input_values", [mstype.float32], self.cls_name)
|
||||||
_check_input_dtype(F.dtype(field_ids), "field_ids", [mstype.int32], self.cls_name)
|
_check_input_dtype(F.dtype(field_ids), "field_ids", [mstype.int32], self.cls_name)
|
||||||
|
|
|
@ -78,7 +78,6 @@ class ImageGradients(Cell):
|
||||||
super(ImageGradients, self).__init__()
|
super(ImageGradients, self).__init__()
|
||||||
|
|
||||||
def construct(self, images):
|
def construct(self, images):
|
||||||
check = _check_input_4d(F.shape(images), "images", self.cls_name)
|
|
||||||
images = F.depend(images, check)
|
images = F.depend(images, check)
|
||||||
batch_size, depth, height, width = P.Shape()(images)
|
batch_size, depth, height, width = P.Shape()(images)
|
||||||
if height == 1:
|
if height == 1:
|
||||||
|
@ -120,21 +119,6 @@ def _get_dtype_max(dtype):
|
||||||
return dtype_max
|
return dtype_max
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_input_4d(input_shape, param_name, func_name):
|
|
||||||
if len(input_shape) != 4:
|
|
||||||
raise ValueError(f"For '{func_name}', the dimension of '{param_name}' must be 4d, "
|
|
||||||
f"but got {len(input_shape)}.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
|
|
||||||
_check_input_4d(input_shape, param_name, func_name)
|
|
||||||
validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name)
|
|
||||||
validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name)
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
||||||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||||
|
@ -281,7 +265,6 @@ class SSIM(Cell):
|
||||||
|
|
||||||
def construct(self, img1, img2):
|
def construct(self, img1, img2):
|
||||||
_check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
|
_check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
|
||||||
_check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
|
|
||||||
inner.SameTypeShape()(img1, img2)
|
inner.SameTypeShape()(img1, img2)
|
||||||
dtype_max_val = _get_dtype_max(F.dtype(img1))
|
dtype_max_val = _get_dtype_max(F.dtype(img1))
|
||||||
max_val = F.scalar_cast(self.max_val, F.dtype(img1))
|
max_val = F.scalar_cast(self.max_val, F.dtype(img1))
|
||||||
|
@ -387,8 +370,6 @@ class MSSSIM(Cell):
|
||||||
self.concat = P.Concat(axis=1)
|
self.concat = P.Concat(axis=1)
|
||||||
|
|
||||||
def construct(self, img1, img2):
|
def construct(self, img1, img2):
|
||||||
_check_input_4d(F.shape(img1), "img1", self.cls_name)
|
|
||||||
_check_input_4d(F.shape(img2), "img2", self.cls_name)
|
|
||||||
valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8]
|
valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8]
|
||||||
_check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name)
|
_check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name)
|
||||||
inner.SameTypeShape()(img1, img2)
|
inner.SameTypeShape()(img1, img2)
|
||||||
|
@ -466,8 +447,6 @@ class PSNR(Cell):
|
||||||
self.max_val = max_val
|
self.max_val = max_val
|
||||||
|
|
||||||
def construct(self, img1, img2):
|
def construct(self, img1, img2):
|
||||||
_check_input_4d(F.shape(img1), "img1", self.cls_name)
|
|
||||||
_check_input_4d(F.shape(img2), "img2", self.cls_name)
|
|
||||||
inner.SameTypeShape()(img1, img2)
|
inner.SameTypeShape()(img1, img2)
|
||||||
dtype_max_val = _get_dtype_max(F.dtype(img1))
|
dtype_max_val = _get_dtype_max(F.dtype(img1))
|
||||||
max_val = F.scalar_cast(self.max_val, F.dtype(img1))
|
max_val = F.scalar_cast(self.max_val, F.dtype(img1))
|
||||||
|
@ -481,22 +460,17 @@ class PSNR(Cell):
|
||||||
return psnr
|
return psnr
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _raise_dims_rank_error(input_shape, param_name, func_name):
|
|
||||||
"""raise error if input is not 3d or 4d"""
|
|
||||||
raise ValueError(f"{func_name} {param_name} must be 3d or 4d, but got shape {input_shape}")
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _get_bbox(rank, shape, central_fraction):
|
def _get_bbox(rank, shape, central_fraction):
|
||||||
"""get bbox start and size for slice"""
|
"""get bbox start and size for slice"""
|
||||||
|
n, c, h, w = -1, -1, -1, -1
|
||||||
if rank == 3:
|
if rank == 3:
|
||||||
c, h, w = shape
|
c, h, w = shape
|
||||||
else:
|
else:
|
||||||
n, c, h, w = shape
|
n, c, h, w = shape
|
||||||
|
|
||||||
bbox_h_start = int((float(h) - np.float32(h * central_fraction)) / 2)
|
bbox_h_start = int((float(h) - float(h * central_fraction)) / 2)
|
||||||
bbox_w_start = int((float(w) - np.float32(w * central_fraction)) / 2)
|
bbox_w_start = int((float(w) - float(w * central_fraction)) / 2)
|
||||||
bbox_h_size = h - bbox_h_start * 2
|
bbox_h_size = h - bbox_h_start * 2
|
||||||
bbox_w_size = w - bbox_w_start * 2
|
bbox_w_size = w - bbox_w_start * 2
|
||||||
|
|
||||||
|
@ -548,8 +522,6 @@ class CentralCrop(Cell):
|
||||||
def construct(self, image):
|
def construct(self, image):
|
||||||
image_shape = F.shape(image)
|
image_shape = F.shape(image)
|
||||||
rank = len(image_shape)
|
rank = len(image_shape)
|
||||||
if rank not in (3, 4):
|
|
||||||
return _raise_dims_rank_error(image_shape, "image", self.cls_name)
|
|
||||||
if self.central_fraction == 1.0:
|
if self.central_fraction == 1.0:
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
|
@ -763,11 +763,6 @@ class LBeta(Cell):
|
||||||
@constexpr
|
@constexpr
|
||||||
def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
|
def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
|
||||||
"""get broadcast_matmul shape"""
|
"""get broadcast_matmul shape"""
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
||||||
if (len(x_shape) < 2) or (len(y_shape) < 2):
|
|
||||||
raise ValueError(f"{msg_prefix} length of 'x_shape' and 'y_shape' must be equal to or greater than 2, "
|
|
||||||
f"but got the length of 'x_shape': {len(x_shape)} and the length of 'y_shape': "
|
|
||||||
f"{len(y_shape)}.")
|
|
||||||
x_shape_batch = x_shape[:-2]
|
x_shape_batch = x_shape[:-2]
|
||||||
y_shape_batch = y_shape[:-2]
|
y_shape_batch = y_shape[:-2]
|
||||||
if x_shape_batch == y_shape_batch:
|
if x_shape_batch == y_shape_batch:
|
||||||
|
@ -783,10 +778,6 @@ def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
|
||||||
broadcast_shape_back.append(x_shape[i])
|
broadcast_shape_back.append(x_shape[i])
|
||||||
elif x_shape[i] == y_shape[i]:
|
elif x_shape[i] == y_shape[i]:
|
||||||
broadcast_shape_back.append(x_shape[i])
|
broadcast_shape_back.append(x_shape[i])
|
||||||
else:
|
|
||||||
raise ValueError(f"{msg_prefix} 'x_shape[{i}]' must be equal to 1, or the 'y_shape[{i}]' must be equal "
|
|
||||||
f"to 1, or the 'x_shape[{i}]' must be equal to 'y_shape[{i}]', but got "
|
|
||||||
f"'x_shape[{i}]': {x_shape[i]}, 'y_shape[{i}]': {y_shape[i]}.")
|
|
||||||
|
|
||||||
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
||||||
x_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + x_shape[-2:]
|
x_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + x_shape[-2:]
|
||||||
|
@ -794,25 +785,6 @@ def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
|
||||||
return x_broadcast_shape, y_broadcast_shape
|
return x_broadcast_shape, y_broadcast_shape
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2, prim_name=None):
|
|
||||||
"""check col and row equal"""
|
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
||||||
if len(x1_shape) == 1:
|
|
||||||
transpose_x1 = False
|
|
||||||
x1_shape = (1,) + x1_shape
|
|
||||||
if len(x2_shape) == 1:
|
|
||||||
transpose_x2 = False
|
|
||||||
x2_shape = x2_shape + (1,)
|
|
||||||
x1_last = x1_shape[-2:]
|
|
||||||
x2_last = x2_shape[-2:]
|
|
||||||
x1_col = x1_last[not transpose_x1] # x1_col = x1_last[1] if (not transpose_a) else x1_last[0]
|
|
||||||
x2_row = x2_last[transpose_x2] # x2_row = x2_last[0] if (not transpose_b) else x2_last[1]
|
|
||||||
if x1_col != x2_row:
|
|
||||||
raise ValueError(f"{msg_prefix} column of matrix dimensions of 'x1' must be equal to "
|
|
||||||
f"the row of matrix dimensions of 'x2', but got 'x1_col' {x1_col} and 'x2_row' {x2_row}.")
|
|
||||||
|
|
||||||
|
|
||||||
def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2):
|
def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2):
|
||||||
"""select matmul op"""
|
"""select matmul op"""
|
||||||
x1_dim, x2_dim = len(x1_shape), len(x2_shape)
|
x1_dim, x2_dim = len(x1_shape), len(x2_shape)
|
||||||
|
@ -857,7 +829,6 @@ class MatMul(Cell):
|
||||||
def construct(self, x1, x2):
|
def construct(self, x1, x2):
|
||||||
x1_shape = self.shape_op(x1)
|
x1_shape = self.shape_op(x1)
|
||||||
x2_shape = self.shape_op(x2)
|
x2_shape = self.shape_op(x2)
|
||||||
check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2, self.cls_name)
|
|
||||||
matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
|
matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
|
||||||
|
|
||||||
x1_dim, x2_dim = len(x1_shape), len(x2_shape)
|
x1_dim, x2_dim = len(x1_shape), len(x2_shape)
|
||||||
|
|
|
@ -121,13 +121,8 @@ class _BatchNorm(Cell):
|
||||||
self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
|
self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
|
||||||
self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
|
self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
self._check_input_dim(self.shape(x), self.cls_name)
|
|
||||||
if self.use_batch_statistics is None:
|
if self.use_batch_statistics is None:
|
||||||
if self.training:
|
if self.training:
|
||||||
return self.bn_train(x,
|
return self.bn_train(x,
|
||||||
|
@ -227,13 +222,6 @@ class BatchNorm1d(_BatchNorm):
|
||||||
[ 0.4999975 0.399998 0.59999704 0.89999545 ]]
|
[ 0.4999975 0.399998 0.59999704 0.89999545 ]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim != 2:
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 2 dims, but got {dim}.")
|
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm2d(_BatchNorm):
|
class BatchNorm2d(_BatchNorm):
|
||||||
r"""
|
r"""
|
||||||
|
@ -319,13 +307,6 @@ class BatchNorm2d(_BatchNorm):
|
||||||
[ 0.999995 0.999995 ]]]]
|
[ 0.999995 0.999995 ]]]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim != 4:
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
|
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm3d(Cell):
|
class BatchNorm3d(Cell):
|
||||||
r"""
|
r"""
|
||||||
|
@ -413,16 +394,9 @@ class BatchNorm3d(Cell):
|
||||||
self.shape = P.Shape()
|
self.shape = P.Shape()
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim != 5:
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x_shape = self.shape(x)
|
x_shape = self.shape(x)
|
||||||
self._check_input_dim(x_shape, self.cls_name)
|
|
||||||
x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
|
x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
|
||||||
bn2d_out = self.bn2d(x)
|
bn2d_out = self.bn2d(x)
|
||||||
bn3d_out = self.reshape(bn2d_out, x_shape)
|
bn3d_out = self.reshape(bn2d_out, x_shape)
|
||||||
|
@ -586,12 +560,6 @@ class SyncBatchNorm(_BatchNorm):
|
||||||
SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
|
SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
|
||||||
management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
|
management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim not in (2, 4):
|
|
||||||
raise ValueError(f"For '{cls_name}', the must have 2 dims or 4 dims, but got {dim}.")
|
|
||||||
|
|
||||||
def _check_rank_ids(self, process_groups, rank_size):
|
def _check_rank_ids(self, process_groups, rank_size):
|
||||||
seen = set()
|
seen = set()
|
||||||
|
@ -729,7 +697,6 @@ class _InstanceNorm(Cell):
|
||||||
self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)
|
self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
self._check_input_dim(self.shape(x), self.cls_name)
|
|
||||||
return self.instance_bn(x,
|
return self.instance_bn(x,
|
||||||
self.gamma,
|
self.gamma,
|
||||||
self.beta,
|
self.beta,
|
||||||
|
@ -822,13 +789,6 @@ class InstanceNorm1d(_InstanceNorm):
|
||||||
(2, 3, 5)
|
(2, 3, 5)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim != 3:
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 3 dims, but got {dim}.")
|
|
||||||
|
|
||||||
|
|
||||||
class InstanceNorm2d(_InstanceNorm):
|
class InstanceNorm2d(_InstanceNorm):
|
||||||
r"""
|
r"""
|
||||||
|
@ -901,13 +861,6 @@ class InstanceNorm2d(_InstanceNorm):
|
||||||
(2, 3, 2, 2)
|
(2, 3, 2, 2)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim != 4:
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
|
|
||||||
|
|
||||||
|
|
||||||
class InstanceNorm3d(_InstanceNorm):
|
class InstanceNorm3d(_InstanceNorm):
|
||||||
r"""
|
r"""
|
||||||
|
@ -979,12 +932,6 @@ class InstanceNorm3d(_InstanceNorm):
|
||||||
>>> print(output.shape)
|
>>> print(output.shape)
|
||||||
(2, 3, 5, 2, 2)
|
(2, 3, 5, 2, 2)
|
||||||
"""
|
"""
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim != 5:
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")
|
|
||||||
|
|
||||||
|
|
||||||
class GroupNorm(Cell):
|
class GroupNorm(Cell):
|
||||||
|
@ -1068,7 +1015,6 @@ class GroupNorm(Cell):
|
||||||
def _cal_output(self, x):
|
def _cal_output(self, x):
|
||||||
"""calculate groupnorm output"""
|
"""calculate groupnorm output"""
|
||||||
batch, channel, height, width = self.shape(x)
|
batch, channel, height, width = self.shape(x)
|
||||||
self._channel_check(channel, self.num_channels, self.cls_name)
|
|
||||||
x = self.reshape(x, (batch, self.num_groups, -1))
|
x = self.reshape(x, (batch, self.num_groups, -1))
|
||||||
mean = self.reduce_mean(x, 2)
|
mean = self.reduce_mean(x, 2)
|
||||||
var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)
|
var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)
|
||||||
|
@ -1078,21 +1024,6 @@ class GroupNorm(Cell):
|
||||||
output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1))
|
output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim != 4:
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _channel_check(channel, num_channel, prim_name=None):
|
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
||||||
if channel != num_channel:
|
|
||||||
raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to num_channels, "
|
|
||||||
f"but got channel: {channel}, num_channels: {num_channel}.")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@constexpr
|
@constexpr
|
||||||
def _check_dtype(dtype, valid_dtypes, prim_name=None):
|
def _check_dtype(dtype, valid_dtypes, prim_name=None):
|
||||||
|
@ -1102,7 +1033,6 @@ class GroupNorm(Cell):
|
||||||
return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
|
return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
self._check_input_dim(self.shape(x), self.cls_name)
|
|
||||||
self._check_dtype(x.dtype, [mstype.float16, mstype.float32], self.cls_name)
|
self._check_dtype(x.dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||||
output = self._cal_output(x)
|
output = self._cal_output(x)
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -31,9 +31,6 @@ def _check_padding_dimension(dimension, padding):
|
||||||
Validate the input padding and add placeholders if needed.
|
Validate the input padding and add placeholders if needed.
|
||||||
Note: the input 'padding' in this function is already converted to list of lists to match MirrorPad
|
Note: the input 'padding' in this function is already converted to list of lists to match MirrorPad
|
||||||
"""
|
"""
|
||||||
if dimension < len(padding):
|
|
||||||
raise ValueError(f"For padding with length {len(padding) * 2}, the dimension of the tensor should be at least "
|
|
||||||
f"{len(padding)}, but got {dimension}")
|
|
||||||
# add place holders
|
# add place holders
|
||||||
if dimension > len(padding):
|
if dimension > len(padding):
|
||||||
padding = [(0, 0) for _ in range(dimension - len(padding))] + [x for x in padding]
|
padding = [(0, 0) for _ in range(dimension - len(padding))] + [x for x in padding]
|
||||||
|
@ -56,45 +53,16 @@ def _swap_to_ms_padding_order(padding):
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _check(input_shape, padding, name):
|
def _check(input_shape, padding):
|
||||||
"""
|
"""
|
||||||
Check relationship between input shape and padding to make sure after negative dimension padding the out is
|
Check relationship between input shape and padding to make sure after negative dimension padding the out is
|
||||||
positive.
|
positive.
|
||||||
"""
|
"""
|
||||||
if len(input_shape) < len(padding):
|
|
||||||
msg = "For '{}', the dimension of input must more than or equal to len(padding)/2, " \
|
|
||||||
"but got {}".format(name, len(input_shape))
|
|
||||||
raise ValueError(msg)
|
|
||||||
if len(input_shape) > len(padding):
|
if len(input_shape) > len(padding):
|
||||||
if len(padding) == 2 and isinstance(padding[0], int):
|
if len(padding) == 2 and isinstance(padding[0], int):
|
||||||
padding = [(0, 0) for i in range(len(input_shape) - 1)] + [padding]
|
padding = [(0, 0) for i in range(len(input_shape) - 1)] + [padding]
|
||||||
else:
|
else:
|
||||||
padding = [(0, 0) for i in range(len(input_shape) - len(padding))] + [x for x in padding]
|
padding = [(0, 0) for i in range(len(input_shape) - len(padding))] + [x for x in padding]
|
||||||
for index, item in enumerate(padding):
|
|
||||||
if index == 0:
|
|
||||||
dim_name = '1st'
|
|
||||||
elif index == 1:
|
|
||||||
dim_name = '2nd'
|
|
||||||
elif index == 2:
|
|
||||||
dim_name = '3rd'
|
|
||||||
else:
|
|
||||||
dim_name = str(index + 1) + 'th'
|
|
||||||
|
|
||||||
if item[0] < -input_shape[index]:
|
|
||||||
msg = "For '{}', the shape of input after padding must be positive, the input shape is {}, " \
|
|
||||||
"value of parameter 'padding' applied to the {} dimension of input must " \
|
|
||||||
"no less than -{}, but got {}".format(name, input_shape, dim_name, input_shape[index], item[0])
|
|
||||||
raise ValueError(msg)
|
|
||||||
if item[1] < -input_shape[index]:
|
|
||||||
msg = "For '{}', the shape of input after padding must be positive, the input shape is {}, " \
|
|
||||||
"value of parameter 'padding' applied to the {} dimension of input must " \
|
|
||||||
"no less than -{}, but got {}".format(name, input_shape, dim_name, input_shape[index], item[1])
|
|
||||||
raise ValueError(msg)
|
|
||||||
if input_shape[index] + item[0] + item[1] <= 0:
|
|
||||||
msg = "For '{}', the shape of input after padding must be positive, the input shape is {}, " \
|
|
||||||
"but the {} dimension of input shape {} plus padding {} and {} resulted in a non-positive output " \
|
|
||||||
"shape.".format(name, input_shape, dim_name, input_shape[index], item[0], item[1])
|
|
||||||
raise ValueError(msg)
|
|
||||||
return padding
|
return padding
|
||||||
|
|
||||||
|
|
||||||
|
@ -199,7 +167,7 @@ class _ConstantPadNd(Cell):
|
||||||
"""Construct the pad net."""
|
"""Construct the pad net."""
|
||||||
input_shape = x.shape
|
input_shape = x.shape
|
||||||
input_type = x.dtype
|
input_type = x.dtype
|
||||||
padding = _check(input_shape, self.padding, self._name)
|
padding = _check(input_shape, self.padding)
|
||||||
new_padding, start, end = _get_new_padding(padding)
|
new_padding, start, end = _get_new_padding(padding)
|
||||||
mask = ops.Ones()(input_shape, input_type)
|
mask = ops.Ones()(input_shape, input_type)
|
||||||
output = ops.Pad(new_padding)(x)
|
output = ops.Pad(new_padding)(x)
|
||||||
|
@ -671,10 +639,6 @@ class _ReplicationPadNd(Cell):
|
||||||
self.padding = padding
|
self.padding = padding
|
||||||
self.padv3 = nn_ops.PadV3(mode="edge")
|
self.padv3 = nn_ops.PadV3(mode="edge")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@constexpr
|
@constexpr
|
||||||
|
@ -682,7 +646,6 @@ class _ReplicationPadNd(Cell):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
self._check_input_dim(x.shape, self.name)
|
|
||||||
need_expend_dims = self._need_expend_dim(x)
|
need_expend_dims = self._need_expend_dim(x)
|
||||||
if need_expend_dims:
|
if need_expend_dims:
|
||||||
x = x.expand_dims(0)
|
x = x.expand_dims(0)
|
||||||
|
@ -743,12 +706,6 @@ class ReplicationPad1d(_ReplicationPadNd):
|
||||||
padding = (padding, padding)
|
padding = (padding, padding)
|
||||||
super(ReplicationPad1d, self).__init__(padding, name="ReplicationPad1d")
|
super(ReplicationPad1d, self).__init__(padding, name="ReplicationPad1d")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim not in (2, 3):
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 2 or 3 dims, but got {dim}.")
|
|
||||||
|
|
||||||
def _need_expend_dim(self, x):
|
def _need_expend_dim(self, x):
|
||||||
input_shape = x.shape
|
input_shape = x.shape
|
||||||
|
@ -814,12 +771,6 @@ class ReplicationPad2d(_ReplicationPadNd):
|
||||||
padding = (padding, padding, padding, padding)
|
padding = (padding, padding, padding, padding)
|
||||||
super(ReplicationPad2d, self).__init__(padding, name="ReplicationPad2d")
|
super(ReplicationPad2d, self).__init__(padding, name="ReplicationPad2d")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim not in (3, 4):
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 3 or 4 dims, but got {dim}.")
|
|
||||||
|
|
||||||
def _need_expend_dim(self, x):
|
def _need_expend_dim(self, x):
|
||||||
input_shape = x.shape
|
input_shape = x.shape
|
||||||
|
@ -885,12 +836,6 @@ class ReplicationPad3d(_ReplicationPadNd):
|
||||||
padding = (padding, padding, padding, padding, padding, padding)
|
padding = (padding, padding, padding, padding, padding, padding)
|
||||||
super(ReplicationPad3d, self).__init__(padding, name="ReplicationPad3d")
|
super(ReplicationPad3d, self).__init__(padding, name="ReplicationPad3d")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@constexpr
|
|
||||||
def _check_input_dim(shape, cls_name):
|
|
||||||
dim = len(shape)
|
|
||||||
if dim not in (4, 5):
|
|
||||||
raise ValueError(f"For '{cls_name}', the in_shape must have 4 or 5 dims, but got {dim}.")
|
|
||||||
|
|
||||||
def _need_expend_dim(self, x):
|
def _need_expend_dim(self, x):
|
||||||
input_shape = x.shape
|
input_shape = x.shape
|
||||||
|
|
|
@ -73,13 +73,6 @@ class _PoolNd(Cell):
|
||||||
return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
|
return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _shape_check(in_shape, prim_name=None):
|
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
||||||
if len(in_shape) != 3:
|
|
||||||
raise ValueError(f"{msg_prefix} input must has 3 dim, but got {len(in_shape)}")
|
|
||||||
|
|
||||||
|
|
||||||
class LPPool1d(Cell):
|
class LPPool1d(Cell):
|
||||||
r"""
|
r"""
|
||||||
Applies a 1D power lp pooling over an input signal composed of several input planes.
|
Applies a 1D power lp pooling over an input signal composed of several input planes.
|
||||||
|
@ -489,7 +482,6 @@ class MaxPool1d(_PoolNd):
|
||||||
self.squeeze = P.Squeeze(2)
|
self.squeeze = P.Squeeze(2)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
_shape_check(self.shape(x), self.cls_name)
|
|
||||||
x = self.expand(x, 2)
|
x = self.expand(x, 2)
|
||||||
output = self.max_pool(x)
|
output = self.max_pool(x)
|
||||||
output = self.squeeze(output)
|
output = self.squeeze(output)
|
||||||
|
@ -743,7 +735,6 @@ class AvgPool1d(_PoolNd):
|
||||||
self.squeeze = P.Squeeze(2)
|
self.squeeze = P.Squeeze(2)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x = F.depend(x, _shape_check(self.shape(x), self.cls_name))
|
|
||||||
batch, channel, width = self.shape(x)
|
batch, channel, width = self.shape(x)
|
||||||
if width == self.kernel_size[1]:
|
if width == self.kernel_size[1]:
|
||||||
x = self.reduce_mean(x, 2)
|
x = self.reduce_mean(x, 2)
|
||||||
|
@ -757,20 +748,6 @@ class AvgPool1d(_PoolNd):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _adaptive_shape_check(in_shape, output_size, prim_name):
|
|
||||||
"""Check shape."""
|
|
||||||
msg_prefix = "For {}, the".format(prim_name)
|
|
||||||
if len(in_shape) != 3:
|
|
||||||
raise ValueError("{} input must has 3 dim, but got {}.".format(msg_prefix, len(in_shape)))
|
|
||||||
if in_shape[2] < output_size:
|
|
||||||
raise ValueError("{} input's last dimension must be greater or equal to "
|
|
||||||
"output size {}, but got {}.".format(msg_prefix, output_size, in_shape[2]))
|
|
||||||
if in_shape[2] % output_size != 0:
|
|
||||||
raise ValueError("{} input's last dimension must be divisible by "
|
|
||||||
"output size {}, but got {}.".format(msg_prefix, output_size, in_shape[2]))
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _adaptive_dtype_check(x_dtype, prim_name):
|
def _adaptive_dtype_check(x_dtype, prim_name):
|
||||||
"""Check dtype."""
|
"""Check dtype."""
|
||||||
|
@ -837,7 +814,6 @@ class AdaptiveAvgPool1d(Cell):
|
||||||
self.dtype = P.DType()
|
self.dtype = P.DType()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
_adaptive_shape_check(self.shape(x), self.output_size, self.cls_name)
|
|
||||||
_adaptive_dtype_check(self.dtype(x), self.cls_name)
|
_adaptive_dtype_check(self.dtype(x), self.cls_name)
|
||||||
|
|
||||||
_, _, width = self.shape(x)
|
_, _, width = self.shape(x)
|
||||||
|
@ -1052,7 +1028,6 @@ class AdaptiveMaxPool1d(Cell):
|
||||||
self.dtype = P.DType()
|
self.dtype = P.DType()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
_adaptive_shape_check(self.shape(x), self.output_size, self.cls_name)
|
|
||||||
_adaptive_dtype_check(self.dtype(x), self.cls_name)
|
_adaptive_dtype_check(self.dtype(x), self.cls_name)
|
||||||
|
|
||||||
_, _, width = self.shape(x)
|
_, _, width = self.shape(x)
|
||||||
|
|
|
@ -22,6 +22,7 @@ import mindspore.nn as nn
|
||||||
import mindspore.ops as P
|
import mindspore.ops as P
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.common.parameter import ParameterTuple, Parameter
|
from mindspore.common.parameter import ParameterTuple, Parameter
|
||||||
|
@ -446,6 +447,17 @@ class _RNNBase(Cell):
|
||||||
self.b_ih_list = ParameterTuple(self.b_ih_list)
|
self.b_ih_list = ParameterTuple(self.b_ih_list)
|
||||||
self.b_hh_list = ParameterTuple(self.b_hh_list)
|
self.b_hh_list = ParameterTuple(self.b_hh_list)
|
||||||
|
|
||||||
|
# TODO: remove this func
|
||||||
|
def _shape_dynamic(self, shape):
|
||||||
|
"""use this func for dynamic. del it when ShapeOp is supported"""
|
||||||
|
x = []
|
||||||
|
for i in shape:
|
||||||
|
if not F.isconstant(i):
|
||||||
|
x.append(-1)
|
||||||
|
else:
|
||||||
|
x.append(i)
|
||||||
|
return tuple(x)
|
||||||
|
|
||||||
def _stacked_bi_dynamic_rnn(self, x, h, seq_length):
|
def _stacked_bi_dynamic_rnn(self, x, h, seq_length):
|
||||||
"""stacked bidirectional dynamic_rnn"""
|
"""stacked bidirectional dynamic_rnn"""
|
||||||
pre_layer = x
|
pre_layer = x
|
||||||
|
@ -491,9 +503,11 @@ class _RNNBase(Cell):
|
||||||
if self.is_lstm:
|
if self.is_lstm:
|
||||||
h_n = P.Concat(0)(h_n)
|
h_n = P.Concat(0)(h_n)
|
||||||
c_n = P.Concat(0)(c_n)
|
c_n = P.Concat(0)(c_n)
|
||||||
h_n = h_n.view(h[0].shape)
|
h0_shape = self._shape_dynamic(h[0].shape)
|
||||||
c_n = c_n.view(h[1].shape)
|
h1_shape = self._shape_dynamic(h[1].shape)
|
||||||
return output, (h_n.view(h[0].shape), c_n.view(h[1].shape))
|
h_n = h_n.view(h0_shape)
|
||||||
|
c_n = c_n.view(h1_shape)
|
||||||
|
return output, (h_n.view(h0_shape), c_n.view(h1_shape))
|
||||||
h_n = P.Concat(0)(h_n)
|
h_n = P.Concat(0)(h_n)
|
||||||
return output, h_n.view(h.shape)
|
return output, h_n.view(h.shape)
|
||||||
|
|
||||||
|
@ -523,9 +537,11 @@ class _RNNBase(Cell):
|
||||||
if self.is_lstm:
|
if self.is_lstm:
|
||||||
h_n = P.Concat(0)(h_n)
|
h_n = P.Concat(0)(h_n)
|
||||||
c_n = P.Concat(0)(c_n)
|
c_n = P.Concat(0)(c_n)
|
||||||
h_n = h_n.view(h[0].shape)
|
h0_shape = self._shape_dynamic(h[0].shape)
|
||||||
c_n = c_n.view(h[1].shape)
|
h1_shape = self._shape_dynamic(h[1].shape)
|
||||||
return output, (h_n.view(h[0].shape), c_n.view(h[1].shape))
|
h_n = h_n.view(h0_shape)
|
||||||
|
c_n = c_n.view(h1_shape)
|
||||||
|
return output, (h_n.view(h0_shape), c_n.view(h1_shape))
|
||||||
h_n = P.Concat(0)(h_n)
|
h_n = P.Concat(0)(h_n)
|
||||||
return output, h_n.view(h.shape)
|
return output, h_n.view(h.shape)
|
||||||
|
|
||||||
|
|
|
@ -33,9 +33,9 @@ from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context, \
|
||||||
_set_rank_id, _insert_hash_table_size, _set_cache_enable
|
_set_rank_id, _insert_hash_table_size, _set_cache_enable
|
||||||
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
|
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.ops.primitive import constexpr
|
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.nn.layer.basic import ClipByNorm
|
from mindspore.nn.layer.basic import ClipByNorm
|
||||||
|
from mindspore.ops.primitive import constexpr
|
||||||
|
|
||||||
__all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor', 'EmbeddingLookupThor']
|
__all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor', 'EmbeddingLookupThor']
|
||||||
|
|
||||||
|
|
|
@ -24,25 +24,6 @@ from mindspore.nn.cell import Cell
|
||||||
__all__ = ['TimeDistributed']
|
__all__ = ['TimeDistributed']
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_reshape_pos(reshape_pos, inputs_shape, outputs_shape, prim_name=None):
|
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
||||||
if reshape_pos >= len(outputs_shape) or inputs_shape[reshape_pos] != outputs_shape[reshape_pos]:
|
|
||||||
raise ValueError(f"{msg_prefix} 'reshape_with_axis' is invalid in the input and output. "
|
|
||||||
f"The 'reshape_pos' must be less than the length of 'outputs_shape', and the "
|
|
||||||
f"'inputs_shape[reshape_pos]' must be equal to 'outputs_shape[reshape_pos]', but got "
|
|
||||||
f"'reshape_pos': {reshape_pos}, 'inputs_shape': {inputs_shape}, 'outputs_shape': "
|
|
||||||
f"{outputs_shape}. You may try pass parameters without 'reshape_with_axis'.")
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_expand_dims_axis(time_axis, ndim, prim_name=None):
|
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
||||||
if time_axis > ndim:
|
|
||||||
raise ValueError(f"{msg_prefix} value of 'time_axis' must be in range of [{-ndim - 1}, {ndim}], "
|
|
||||||
f"but got {time_axis}.")
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _generate_perm(axis_a, axis_b, length):
|
def _generate_perm(axis_a, axis_b, length):
|
||||||
perm = tuple(range(length))
|
perm = tuple(range(length))
|
||||||
|
@ -57,13 +38,6 @@ def _check_data(flag, prim_name=None):
|
||||||
raise TypeError(f"{msg_prefix} inputs and outputs must be a Tensor.")
|
raise TypeError(f"{msg_prefix} inputs and outputs must be a Tensor.")
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_inputs_dim(shape, prim_name=None):
|
|
||||||
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
||||||
if len(shape) < 3:
|
|
||||||
raise ValueError(f"{msg_prefix} inputs shape must be at least 3D, but got {len(shape)}.")
|
|
||||||
|
|
||||||
|
|
||||||
class TimeDistributed(Cell):
|
class TimeDistributed(Cell):
|
||||||
r"""
|
r"""
|
||||||
The time distributed layer.
|
The time distributed layer.
|
||||||
|
@ -119,7 +93,6 @@ class TimeDistributed(Cell):
|
||||||
|
|
||||||
def construct(self, inputs):
|
def construct(self, inputs):
|
||||||
_check_data(isinstance(inputs, Tensor), self.cls_name)
|
_check_data(isinstance(inputs, Tensor), self.cls_name)
|
||||||
_check_inputs_dim(inputs.shape, self.cls_name)
|
|
||||||
time_axis = self.time_axis % len(inputs.shape)
|
time_axis = self.time_axis % len(inputs.shape)
|
||||||
if self.reshape_with_axis is not None:
|
if self.reshape_with_axis is not None:
|
||||||
reshape_with_axis = self.reshape_with_axis % len(inputs.shape)
|
reshape_with_axis = self.reshape_with_axis % len(inputs.shape)
|
||||||
|
@ -134,7 +107,6 @@ class TimeDistributed(Cell):
|
||||||
inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:])
|
inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:])
|
||||||
outputs = self.layer(inputs)
|
outputs = self.layer(inputs)
|
||||||
_check_data(isinstance(outputs, Tensor), self.cls_name)
|
_check_data(isinstance(outputs, Tensor), self.cls_name)
|
||||||
_check_reshape_pos(reshape_pos, inputs.shape, outputs.shape, self.cls_name)
|
|
||||||
outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2]
|
outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2]
|
||||||
if reshape_pos + 1 < len(outputs.shape):
|
if reshape_pos + 1 < len(outputs.shape):
|
||||||
outputs_shape_new += outputs.shape[reshape_pos + 1:]
|
outputs_shape_new += outputs.shape[reshape_pos + 1:]
|
||||||
|
@ -147,7 +119,6 @@ class TimeDistributed(Cell):
|
||||||
for item in inputs:
|
for item in inputs:
|
||||||
outputs = self.layer(item)
|
outputs = self.layer(item)
|
||||||
_check_data(isinstance(outputs, Tensor), self.cls_name)
|
_check_data(isinstance(outputs, Tensor), self.cls_name)
|
||||||
_check_expand_dims_axis(time_axis, outputs.ndim, self.cls_name)
|
|
||||||
y += (outputs,)
|
y += (outputs,)
|
||||||
y = Stack(time_axis)(y)
|
y = Stack(time_axis)(y)
|
||||||
return y
|
return y
|
||||||
|
|
|
@ -809,7 +809,6 @@ class DiceLoss(LossBase):
|
||||||
def construct(self, logits, label):
|
def construct(self, logits, label):
|
||||||
_check_is_tensor('logits', logits, self.cls_name)
|
_check_is_tensor('logits', logits, self.cls_name)
|
||||||
_check_is_tensor('labels', label, self.cls_name)
|
_check_is_tensor('labels', label, self.cls_name)
|
||||||
_check_shape(logits.shape, label.shape, self.cls_name)
|
|
||||||
if logits.dtype == mstype.uint8:
|
if logits.dtype == mstype.uint8:
|
||||||
raise TypeError(f"For '{self.cls_name}', the dtype of 'logits' can not be uint8.")
|
raise TypeError(f"For '{self.cls_name}', the dtype of 'logits' can not be uint8.")
|
||||||
if label.dtype == mstype.uint8:
|
if label.dtype == mstype.uint8:
|
||||||
|
@ -824,31 +823,6 @@ class DiceLoss(LossBase):
|
||||||
return dice_loss
|
return dice_loss
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_shape(logits_shape, label_shape, prim_name=None):
|
|
||||||
"""Internal function, used to check whether the shape of logits and labels meets the requirements."""
|
|
||||||
validator.check('logits_shape', logits_shape, 'label_shape', label_shape, prim_name=prim_name)
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_ndim_multi(logits_dim, label_dim, prim_name=None):
|
|
||||||
"""Internal function, used to check whether the dimension of logits and label meets the requirements."""
|
|
||||||
msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
|
|
||||||
if logits_dim < 2:
|
|
||||||
raise ValueError(f"{msg_prefix} 'logits' dimension must be greater than 1, but got {logits_dim}.")
|
|
||||||
if label_dim < 2:
|
|
||||||
raise ValueError(f"{msg_prefix} 'labels' dimension must be greater than 1, but got {label_dim}.")
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_weights(weight_shape, label_shape, prim_name=None):
|
|
||||||
"""Internal function, used to check whether the reduced shape meets the requirements."""
|
|
||||||
msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
|
|
||||||
if weight_shape != label_shape:
|
|
||||||
raise ValueError(f"{msg_prefix} weight_shape[0] must be equal to label_shape[1], "
|
|
||||||
f"but got weight_shape[0]: {weight_shape} and label_shape[1]: {label_shape}.")
|
|
||||||
|
|
||||||
|
|
||||||
class MultiClassDiceLoss(LossBase):
|
class MultiClassDiceLoss(LossBase):
|
||||||
r"""
|
r"""
|
||||||
When there are multiple classifications, label is transformed into multiple binary classifications by one hot.
|
When there are multiple classifications, label is transformed into multiple binary classifications by one hot.
|
||||||
|
@ -919,8 +893,6 @@ class MultiClassDiceLoss(LossBase):
|
||||||
def construct(self, logits, label):
|
def construct(self, logits, label):
|
||||||
_check_is_tensor('logits', logits, self.cls_name)
|
_check_is_tensor('logits', logits, self.cls_name)
|
||||||
_check_is_tensor('labels', label, self.cls_name)
|
_check_is_tensor('labels', label, self.cls_name)
|
||||||
_check_shape(logits.shape, label.shape, self.cls_name)
|
|
||||||
_check_ndim_multi(logits.ndim, label.ndim, self.cls_name)
|
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
|
||||||
if self.activation is not None:
|
if self.activation is not None:
|
||||||
|
@ -930,7 +902,6 @@ class MultiClassDiceLoss(LossBase):
|
||||||
if i != self.ignore_indiex:
|
if i != self.ignore_indiex:
|
||||||
dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
|
dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
|
||||||
if self.weights is not None:
|
if self.weights is not None:
|
||||||
_check_weights(self.weights.shape[0], label.shape[1], self.cls_name)
|
|
||||||
dice_loss *= self.weights[i]
|
dice_loss *= self.weights[i]
|
||||||
total_loss += dice_loss
|
total_loss += dice_loss
|
||||||
|
|
||||||
|
@ -1462,12 +1433,6 @@ class BCELoss(LossBase):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_reduced_shape_valid(ori_shape, reduced_shape, axis, cls_name, arg_name1, arg_name2):
|
|
||||||
"""Internal function, used to check whether the reduced shape meets the requirements."""
|
|
||||||
validator.check_reduce_shape(ori_shape, reduced_shape, axis, cls_name, arg_name1, arg_name2)
|
|
||||||
|
|
||||||
|
|
||||||
class CosineEmbeddingLoss(LossBase):
|
class CosineEmbeddingLoss(LossBase):
|
||||||
r"""
|
r"""
|
||||||
CosineEmbeddingLoss creates a criterion to measure the similarity between two tensors using cosine distance.
|
CosineEmbeddingLoss creates a criterion to measure the similarity between two tensors using cosine distance.
|
||||||
|
@ -1527,7 +1492,6 @@ class CosineEmbeddingLoss(LossBase):
|
||||||
_check_is_tensor('logits_x2', logits_x2, self.cls_name)
|
_check_is_tensor('logits_x2', logits_x2, self.cls_name)
|
||||||
_check_is_tensor('labels', labels, self.cls_name)
|
_check_is_tensor('labels', labels, self.cls_name)
|
||||||
inner.same_type_shape_(logits_x1, logits_x2)
|
inner.same_type_shape_(logits_x1, logits_x2)
|
||||||
_check_reduced_shape_valid(F.shape(logits_x1), F.shape(labels), (1,), self.cls_name, "logits_x1", "labels")
|
|
||||||
# if labels > 0, 1-cosine(logits_x1, logits_x2)
|
# if labels > 0, 1-cosine(logits_x1, logits_x2)
|
||||||
# else, max(0, cosine(logits_x1, logits_x2)-margin)
|
# else, max(0, cosine(logits_x1, logits_x2)-margin)
|
||||||
prod_sum = self.reduce_sum(logits_x1 * logits_x2, (1,))
|
prod_sum = self.reduce_sum(logits_x1 * logits_x2, (1,))
|
||||||
|
@ -1705,32 +1669,6 @@ class BCEWithLogitsLoss(LossBase):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_ndim(logits_nidm, labels_ndim, prime_name=None):
|
|
||||||
'''Internal function, used to check whether the dimension of logits and labels meets the requirements.'''
|
|
||||||
msg_prefix = f'For \'{prime_name}\', the' if prime_name else "The"
|
|
||||||
if logits_nidm < 2 or logits_nidm > 4:
|
|
||||||
raise ValueError(f"{msg_prefix} dimensions of 'logits' must be in [2, 4], but got "
|
|
||||||
f"dimension of 'logits' {logits_nidm}.")
|
|
||||||
if labels_ndim < 2 or labels_ndim > 4:
|
|
||||||
raise ValueError(f"{msg_prefix} dimensions of 'labels' must be in [2, 4], but got "
|
|
||||||
f"dimension of 'labels' {labels_ndim}.")
|
|
||||||
if logits_nidm != labels_ndim:
|
|
||||||
raise ValueError(f"{msg_prefix} dimensions of 'logits' and 'labels' must be equal, but got "
|
|
||||||
f"dimension of 'logits' {logits_nidm} and dimension of 'labels' {labels_ndim}.")
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_channel_and_shape(logits, labels, prime_name=None):
|
|
||||||
'''Internal function, used to check whether the channels or shape of logits and labels meets the requirements.'''
|
|
||||||
msg_prefix = f'For \'{prime_name}\', the' if prime_name else "The"
|
|
||||||
if logits == 1:
|
|
||||||
raise ValueError(f"{msg_prefix} 'logits'.shape[1] cannot be one, but got {logits}.")
|
|
||||||
if labels not in (1, logits):
|
|
||||||
raise ValueError(f"{msg_prefix} 'labels'.shape[1] must be one or equal to 'logits'.shape[1]: {logits}, "
|
|
||||||
f"but got {labels}.")
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _check_input_dtype(labels_dtype, cls_name):
|
def _check_input_dtype(labels_dtype, cls_name):
|
||||||
"""Internal function, used to check whether the data type of labels meets the requirements."""
|
"""Internal function, used to check whether the data type of labels meets the requirements."""
|
||||||
|
@ -1814,8 +1752,6 @@ class FocalLoss(LossBase):
|
||||||
_check_is_tensor('logits', logits, self.cls_name)
|
_check_is_tensor('logits', logits, self.cls_name)
|
||||||
_check_is_tensor('labels', labels, self.cls_name)
|
_check_is_tensor('labels', labels, self.cls_name)
|
||||||
labelss = labels
|
labelss = labels
|
||||||
_check_ndim(logits.ndim, labelss.ndim, self.cls_name)
|
|
||||||
_check_channel_and_shape(logits.shape[1], labelss.shape[1], self.cls_name)
|
|
||||||
_check_input_dtype(self.dtype(labelss), self.cls_name)
|
_check_input_dtype(self.dtype(labelss), self.cls_name)
|
||||||
|
|
||||||
if logits.ndim > 2:
|
if logits.ndim > 2:
|
||||||
|
|
|
@ -150,7 +150,8 @@ def test_cummin():
|
||||||
def construct(self, a):
|
def construct(self, a):
|
||||||
return self.func(a, 0)
|
return self.func(a, 0)
|
||||||
|
|
||||||
a = Tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220], ms.float32)
|
a = Tensor([-0.2284, -0.6628, 0.0975, 0.2680,
|
||||||
|
-1.3298, -0.4220], ms.float32)
|
||||||
expect = Tensor(np.array(
|
expect = Tensor(np.array(
|
||||||
[-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298]), ms.float32)
|
[-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298]), ms.float32)
|
||||||
net = Net()
|
net = Net()
|
||||||
|
@ -235,6 +236,32 @@ def test_calculate_expert_capacity():
|
||||||
assert net(10.1, 2.0, 3.3, 4) == 17
|
assert net(10.1, 2.0, 3.3, 4) == 17
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_unsqueeze():
|
||||||
|
"""
|
||||||
|
Feature: unsqueeze func
|
||||||
|
Description: Verify the result of unsqueeze
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import unsqueeze
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = unsqueeze
|
||||||
|
|
||||||
|
def construct(self, x, dim):
|
||||||
|
return self.func(x, dim)
|
||||||
|
x = Tensor([[4.0, 9.0, 2.0, 10.0]]).astype("float32")
|
||||||
|
net = Net()
|
||||||
|
x2 = net(x, 0)
|
||||||
|
assert x2.shape == (1, 1, 4)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level1
|
@pytest.mark.level1
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_cpu
|
||||||
@pytest.mark.platform_x86_gpu_training
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@ -248,6 +275,7 @@ def test_infer_out_shape():
|
||||||
"""
|
"""
|
||||||
from mindspore.numpy.utils_const import _infer_out_shape
|
from mindspore.numpy.utils_const import _infer_out_shape
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
@ -259,6 +287,32 @@ def test_infer_out_shape():
|
||||||
assert net((5,), (6, 1), (7, 1, 5), (8, 1, 6, 1)) == (8, 7, 6, 5)
|
assert net((5,), (6, 1), (7, 1, 5), (8, 1, 6, 1)) == (8, 7, 6, 5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_tensor_bool():
|
||||||
|
"""
|
||||||
|
Feature: tensor_bool func
|
||||||
|
Description: Verify the result of tensor_bool
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import tensor_bool
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = tensor_bool
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.func(x)
|
||||||
|
x = Tensor([4.0]).astype("float32")
|
||||||
|
net = Net()
|
||||||
|
x2 = net(x)
|
||||||
|
assert bool(x2) is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level1
|
@pytest.mark.level1
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_cpu
|
||||||
@pytest.mark.platform_x86_gpu_training
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@ -281,3 +335,543 @@ def test_canonicalize_axis():
|
||||||
return self.func(axis, ndim)
|
return self.func(axis, ndim)
|
||||||
net = Net()
|
net = Net()
|
||||||
assert net(0, 2) == 0
|
assert net(0, 2) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_top_k():
|
||||||
|
"""
|
||||||
|
Feature: top_k func
|
||||||
|
Description: Verify the result of top_k
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import top_k
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = top_k
|
||||||
|
|
||||||
|
def construct(self, x, k):
|
||||||
|
return self.func(x, k)
|
||||||
|
x = Tensor([4.0, 9.0, 2.0, 10.0]).astype("float32")
|
||||||
|
net = Net()
|
||||||
|
output = net(x, 3)
|
||||||
|
expect = ([10.0, 9.0, 4.0], [3, 1, 0])
|
||||||
|
assert np.allclose(output[0].asnumpy(), expect[0])
|
||||||
|
assert np.allclose(output[1].asnumpy(), expect[1])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_bernoulli():
|
||||||
|
"""
|
||||||
|
Feature: bernoulli func
|
||||||
|
Description: Verify the result of bernoulli
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import bernoulli
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = bernoulli
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.func(x)
|
||||||
|
x = Tensor(4).astype("float32")
|
||||||
|
print(x)
|
||||||
|
net = Net()
|
||||||
|
output = net(x)
|
||||||
|
print(output)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_view():
|
||||||
|
"""
|
||||||
|
Feature: view func
|
||||||
|
Description: Verify the result of view
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import view
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = view
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.func(x, y)
|
||||||
|
x = Tensor(np.array([[1, 2, 3], [2, 3, 4]], dtype=np.float32))
|
||||||
|
net = Net()
|
||||||
|
output = net(x, (3, 2))
|
||||||
|
expect = [[1., 2.], [3., 2.], [3., 4.]]
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_reshape():
|
||||||
|
"""
|
||||||
|
Feature: reshape func
|
||||||
|
Description: Verify the result of reshape
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import reshape
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = reshape
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.func(x, y)
|
||||||
|
x = Tensor([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=ms.float32)
|
||||||
|
expect = [[-0.1, 0.3], [3.6, 0.4], [0.5, -3.2]]
|
||||||
|
net = Net()
|
||||||
|
output = net(x, (3, 2))
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_swapaxes():
|
||||||
|
"""
|
||||||
|
Feature: swapaxes func
|
||||||
|
Description: Verify the result of swapaxes
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import swapaxes
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = swapaxes
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
return self.func(x, y, z)
|
||||||
|
x = Tensor(np.ones((2, 3, 4), dtype=np.float32))
|
||||||
|
expect = [[[1., 1.], [1., 1.], [1., 1.]],
|
||||||
|
|
||||||
|
[[1., 1.], [1., 1.], [1., 1.]],
|
||||||
|
|
||||||
|
[[1., 1.], [1., 1.], [1., 1.]],
|
||||||
|
|
||||||
|
[[1., 1.], [1., 1.], [1., 1.]]]
|
||||||
|
net = Net()
|
||||||
|
output = net(x, 0, 2)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_squeeze():
|
||||||
|
"""
|
||||||
|
Feature: squeeze func
|
||||||
|
Description: Verify the result of squeeze
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import squeeze
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = squeeze
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.func(x, y)
|
||||||
|
x = Tensor(np.ones((1, 2, 2, 1), dtype=np.float32))
|
||||||
|
expect = [[1., 1.],
|
||||||
|
[1., 1.]]
|
||||||
|
net = Net()
|
||||||
|
output = net(x, (0, 3))
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_argmax():
|
||||||
|
"""
|
||||||
|
Feature: argmax func
|
||||||
|
Description: Verify the result of argmax
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import argmax
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = argmax
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.func(x, y)
|
||||||
|
a = Tensor(np.arange(10, 16).reshape(2, 3).astype("float32"))
|
||||||
|
net = Net()
|
||||||
|
output = net(a, None)
|
||||||
|
expect = 5
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_diagonal():
|
||||||
|
"""
|
||||||
|
Feature: diagonal func
|
||||||
|
Description: Verify the result of diagonal
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import diagonal
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = diagonal
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.func(x)
|
||||||
|
a = Tensor(np.arange(4).reshape(2, 2))
|
||||||
|
expect = [0, 3]
|
||||||
|
net = Net()
|
||||||
|
output = net(a)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_take():
|
||||||
|
"""
|
||||||
|
Feature: take func
|
||||||
|
Description: Verify the result of take
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import take
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = take
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.func(x, y)
|
||||||
|
a = Tensor(np.array([4, 3, 5, 7, 6, 8]))
|
||||||
|
indices = Tensor(np.array([0, 1, 4]))
|
||||||
|
expect = [4, 3, 6]
|
||||||
|
net = Net()
|
||||||
|
output = net(a, indices)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_choose():
|
||||||
|
"""
|
||||||
|
Feature: choose func
|
||||||
|
Description: Verify the result of choose
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import choose
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = choose
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.func(x, y)
|
||||||
|
a = Tensor(np.array([2, 3, 1, 0]))
|
||||||
|
choices = [[0, 1, 2, 3], [10, 11, 12, 13],
|
||||||
|
[20, 21, 22, 23], [30, 31, 32, 33]]
|
||||||
|
expect = [20, 31, 12, 3]
|
||||||
|
net = Net()
|
||||||
|
output = net(a, choices)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_var():
|
||||||
|
"""
|
||||||
|
Feature: var func
|
||||||
|
Description: Verify the result of var
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import var
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = var
|
||||||
|
|
||||||
|
def construct(self, x,):
|
||||||
|
return self.func(x,)
|
||||||
|
a = Tensor(np.array([1., 2., 3., 4.]))
|
||||||
|
expect = 1.25
|
||||||
|
net = Net()
|
||||||
|
output = net(a)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_searchsorted():
|
||||||
|
"""
|
||||||
|
Feature: searchsorted func
|
||||||
|
Description: Verify the result of searchsorted
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore._extends.parse.standard_method import searchsorted
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = searchsorted
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.func(x, y)
|
||||||
|
a = Tensor(np.array([1., 2., 3., 4., 5.]))
|
||||||
|
expect = 2
|
||||||
|
net = Net()
|
||||||
|
output = net(a, 3)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_is_equal_one():
|
||||||
|
"""
|
||||||
|
Feature: _is_equal_one func
|
||||||
|
Description: Verify the result of _is_equal_one
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore.nn.layer.basic import _is_equal_one
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.func = _is_equal_one
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.func(x)
|
||||||
|
a = Tensor(np.array([1., 2., 3., 4., 5.]))
|
||||||
|
expect = False
|
||||||
|
net = Net()
|
||||||
|
output = net(a)
|
||||||
|
assert output == expect
|
||||||
|
|
||||||
|
x = Tensor(np.array([[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
[10, 11, 12, 13],
|
||||||
|
[14, 15, 16, 17]]))
|
||||||
|
expect = [[0, 0, 0, 0],
|
||||||
|
[5, 0, 0, 0],
|
||||||
|
[10, 11, 0, 0],
|
||||||
|
[14, 15, 16, 0]]
|
||||||
|
net = nn.Tril()
|
||||||
|
output = net(x, -1)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_triu():
|
||||||
|
"""
|
||||||
|
Feature: Triu func
|
||||||
|
Description: Verify the result of Triu
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
x = Tensor(np.array([[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
[10, 11, 12, 13],
|
||||||
|
[14, 15, 16, 17]]))
|
||||||
|
expect = [[0, 0, 3, 4],
|
||||||
|
[0, 0, 0, 8],
|
||||||
|
[0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0]]
|
||||||
|
net = nn.Triu()
|
||||||
|
output = net(x, 2)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_resizebilinear():
|
||||||
|
"""
|
||||||
|
Feature: ResizeBilinear func
|
||||||
|
Description: Verify the result of ResizeBilinear
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
x = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], ms.float32)
|
||||||
|
expect = [[[[1., 1.8, 2.6, 3.4, 4.],
|
||||||
|
[2.6, 3.4, 4.2, 5., 5.6],
|
||||||
|
[4.2, 5., 5.8, 6.6000004, 7.2],
|
||||||
|
[5., 5.8, 6.6, 7.4, 8.],
|
||||||
|
[5., 5.8, 6.6, 7.4, 8.]]]]
|
||||||
|
resize_bilinear = nn.ResizeBilinear()
|
||||||
|
output = resize_bilinear(x, size=(5, 5))
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_matrixdiag():
|
||||||
|
"""
|
||||||
|
Feature: MatrixDiag func
|
||||||
|
Description: Verify the result of MatrixDiag
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
x = Tensor(np.array([[1, -1, 1], [1, -1, 1]]), ms.float32)
|
||||||
|
matrix_diag = nn.MatrixDiag()
|
||||||
|
output = matrix_diag(x)
|
||||||
|
print(output)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_embeddinglookup():
|
||||||
|
"""
|
||||||
|
Feature: EmbeddingLookup func
|
||||||
|
Description: Verify the result of EmbeddingLookup
|
||||||
|
Expectation: no error
|
||||||
|
"""
|
||||||
|
# _check_input_2d, _check_input_dtype
|
||||||
|
# mindspore/python/mindspore/nn/layer/embedding.py
|
||||||
|
input_indices = Tensor(np.array([[1, 0], [3, 2]]), ms.int32)
|
||||||
|
output = nn.EmbeddingLookup(4, 2, max_norm=0.002)(input_indices)
|
||||||
|
assert output is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_centralcrop():
|
||||||
|
"""
|
||||||
|
Feature: CentralCrop func
|
||||||
|
Description: Verify the result of CentralCrop
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
net = nn.CentralCrop(central_fraction=0.5)
|
||||||
|
image = Tensor(np.random.random((4, 3, 4, 4)), ms.float32)
|
||||||
|
expect = (4, 3, 2, 2)
|
||||||
|
output = net(image)
|
||||||
|
assert np.allclose(output.shape, expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_matmul_2():
|
||||||
|
"""
|
||||||
|
Feature: MatMul func
|
||||||
|
Description: Verify the result of MatMul
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
net = nn.MatMul()
|
||||||
|
a = Tensor(np.arange(1, 17).reshape((4, 4)), ms.float32)
|
||||||
|
b = Tensor(np.arange(1, 17).reshape((4, 4)), ms.float32)
|
||||||
|
expect = [[90., 100., 110., 120.],
|
||||||
|
[202., 228., 254., 280.],
|
||||||
|
[314., 356., 398., 440.],
|
||||||
|
[426., 484., 542., 600.]]
|
||||||
|
output = net(a, b)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_reflectionpad1d():
|
||||||
|
"""
|
||||||
|
Feature: ReflectionPad1d func
|
||||||
|
Description: Verify the result of ReflectionPad1d
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore.nn import ReflectionPad1d
|
||||||
|
x = Tensor(np.array([[[0, 1, 2, 3], [4, 5, 6, 7]]]).astype(np.float32))
|
||||||
|
padding = (3, 1)
|
||||||
|
pad1d = ReflectionPad1d(padding)
|
||||||
|
expect = [[[3., 2., 1., 0., 1., 2., 3., 2.],
|
||||||
|
[7., 6., 5., 4., 5., 6., 7., 6.]]]
|
||||||
|
output = pad1d(x)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level1
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_constantpad1d():
|
||||||
|
"""
|
||||||
|
Feature: ConstantPad1d func
|
||||||
|
Description: Verify the result of ConstantPad1d
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
from mindspore.nn import ConstantPad1d
|
||||||
|
x = np.ones(shape=(1, 2, 3, 4)).astype(np.float32)
|
||||||
|
x = Tensor(x)
|
||||||
|
expect = [[[[1., 1., 1., 1., 0.5],
|
||||||
|
[1., 1., 1., 1., 0.5],
|
||||||
|
[1., 1., 1., 1., 0.5]],
|
||||||
|
|
||||||
|
[[1., 1., 1., 1., 0.5],
|
||||||
|
[1., 1., 1., 1., 0.5],
|
||||||
|
[1., 1., 1., 1., 0.5]]]]
|
||||||
|
padding = (0, 1)
|
||||||
|
value = 0.5
|
||||||
|
pad1d = ConstantPad1d(padding, value)
|
||||||
|
output = pad1d(x)
|
||||||
|
assert np.allclose(output.asnumpy(), expect)
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" test control ops """
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||||
|
|
||||||
|
grad_by_list = C.GradOperation(get_by_list=True)
|
||||||
|
grad_all = C.GradOperation(get_all=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_switch_layer_with_single_prim():
|
||||||
|
"""
|
||||||
|
Feature: SwitchLayer
|
||||||
|
Description: run switch layer case
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
class SwitchLayerCell(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SwitchLayerCell, self).__init__()
|
||||||
|
self.layers = (nn.ReLU(), nn.ReLU())
|
||||||
|
self.z3 = Parameter(
|
||||||
|
Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
|
||||||
|
|
||||||
|
def construct(self, index, x):
|
||||||
|
ret = self.layers[index](x) * self.z3
|
||||||
|
return ret
|
||||||
|
|
||||||
|
index = Tensor(0, dtype=mstype.int32)
|
||||||
|
net = SwitchLayerCell()
|
||||||
|
net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||||
|
grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
|
||||||
|
Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||||
|
grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
@ -16,7 +16,7 @@ import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor, context
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
|
@ -24,6 +24,16 @@ class Net(nn.Cell):
|
||||||
return x.tril(diagonal)
|
return x.tril(diagonal)
|
||||||
|
|
||||||
|
|
||||||
|
class TrilNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TrilNet, self).__init__()
|
||||||
|
self.tril = nn.Tril()
|
||||||
|
|
||||||
|
def construct(self, value, k):
|
||||||
|
|
||||||
|
return self.tril(value, k)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_cpu
|
||||||
@pytest.mark.platform_arm_cpu
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@ -44,3 +54,99 @@ def test_tril(mode):
|
||||||
output = net(x)
|
output = net(x)
|
||||||
expect_output = np.array([[-1.8297, 0., 0.], [-1.2167, 0.5574, 0.], [-0.6702, 0.2276, 1.2421]], dtype=np.float32)
|
expect_output = np.array([[-1.8297, 0., 0.], [-1.2167, 0.5574, 0.], [-0.6702, 0.2276, 1.2421]], dtype=np.float32)
|
||||||
assert np.allclose(output.asnumpy(), expect_output)
|
assert np.allclose(output.asnumpy(), expect_output)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
|
||||||
|
def test_tril_0(mode):
|
||||||
|
"""
|
||||||
|
Feature: test_tril
|
||||||
|
Description: Verify the result of test_tril
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
|
net = TrilNet()
|
||||||
|
out = net(value, 0)
|
||||||
|
assert np.sum(out.asnumpy()) == 34
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
|
||||||
|
def test_tril_1(mode):
|
||||||
|
"""
|
||||||
|
Feature: test_tril_1
|
||||||
|
Description: Verify the result of test_tril_1
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
|
net = TrilNet()
|
||||||
|
out = net(value, 1)
|
||||||
|
assert np.sum(out.asnumpy()) == 42
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
|
||||||
|
def test_tril_2(mode):
|
||||||
|
"""
|
||||||
|
Feature: test_tril_2
|
||||||
|
Description: Verify the result of test_tril_2
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
|
net = TrilNet()
|
||||||
|
out = net(value, -1)
|
||||||
|
assert np.sum(out.asnumpy()) == 19
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
|
||||||
|
def test_tril_parameter(mode):
|
||||||
|
"""
|
||||||
|
Feature: test_tril_parameter
|
||||||
|
Description: Verify the result of test_tril_parameter
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
net = TrilNet()
|
||||||
|
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
|
||||||
|
def test_tril_parameter_1(mode):
|
||||||
|
"""
|
||||||
|
Feature: test_tril_parameter_1
|
||||||
|
Description: Verify the result of test_tril_parameter_1
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
net = TrilNet()
|
||||||
|
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE,])
|
||||||
|
def test_tril_parameter_2(mode):
|
||||||
|
"""
|
||||||
|
Feature: test_tril_parameter_2
|
||||||
|
Description: Verify the result of test_tril_parameter_2
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
net = TrilNet()
|
||||||
|
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 0)
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
@ -6,8 +20,8 @@ from mindspore import context
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def construct(self, x):
|
def construct(self, x, k):
|
||||||
return x.triu(diagonal=1)
|
return x.triu(diagonal=k)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
@ -16,10 +30,10 @@ class Net(nn.Cell):
|
||||||
@pytest.mark.platform_x86_gpu_training
|
@pytest.mark.platform_x86_gpu_training
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||||
def test_subtract(mode):
|
def test_triu_0(mode):
|
||||||
"""
|
"""
|
||||||
Feature: tensor.subtract()
|
Feature: test_triu_0
|
||||||
Description: Verify the result of tensor.subtract
|
Description: Verify the result of test_triu_0
|
||||||
Expectation: success
|
Expectation: success
|
||||||
"""
|
"""
|
||||||
context.set_context(mode=mode)
|
context.set_context(mode=mode)
|
||||||
|
@ -28,9 +42,66 @@ def test_subtract(mode):
|
||||||
[5, 6, 7, 8],
|
[5, 6, 7, 8],
|
||||||
[10, 11, 12, 13],
|
[10, 11, 12, 13],
|
||||||
[14, 15, 16, 17]]))
|
[14, 15, 16, 17]]))
|
||||||
output = net(x)
|
output = net(x, 1)
|
||||||
expected = np.array([[0, 2, 3, 4],
|
expected = np.array([[0, 2, 3, 4],
|
||||||
[0, 0, 7, 8],
|
[0, 0, 7, 8],
|
||||||
[0, 0, 0, 13],
|
[0, 0, 0, 13],
|
||||||
[0, 0, 0, 0]])
|
[0, 0, 0, 0]])
|
||||||
assert np.array_equal(output.asnumpy(), expected)
|
assert np.array_equal(output.asnumpy(), expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||||
|
def test_triu_1(mode):
|
||||||
|
"""
|
||||||
|
Feature: test_triu_1
|
||||||
|
Description: test_triu_1
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
context.set_context(mode=mode)
|
||||||
|
net = Net()
|
||||||
|
x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
|
output = net(x, 0)
|
||||||
|
assert np.sum(output.asnumpy()) == 26
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||||
|
def test_triu_2(mode):
|
||||||
|
"""
|
||||||
|
Feature: test_triu_2
|
||||||
|
Description: test_triu_2
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
context.set_context(mode=mode)
|
||||||
|
net = Net()
|
||||||
|
x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
|
output = net(x, 1)
|
||||||
|
assert np.sum(output.asnumpy()) == 11
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||||
|
def test_triu_3(mode):
|
||||||
|
"""
|
||||||
|
Feature: test_triu_3
|
||||||
|
Description: test_triu_3
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
context.set_context(mode=mode)
|
||||||
|
net = Net()
|
||||||
|
x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
|
output = net(x, -1)
|
||||||
|
assert np.sum(output.asnumpy()) == 38
|
||||||
|
|
|
@ -384,8 +384,7 @@ def test_offload_dim_check():
|
||||||
dataset = dataset.map(operations=[C.Decode()], input_columns="image")
|
dataset = dataset.map(operations=[C.Decode()], input_columns="image")
|
||||||
dataset = dataset.map(operations=[C.HWC2CHW()], input_columns="image", offload=True)
|
dataset = dataset.map(operations=[C.HWC2CHW()], input_columns="image", offload=True)
|
||||||
|
|
||||||
error_msg = "For HwcToChw offload operation, the dimension of input should be 4, but got 3."
|
with pytest.raises(ValueError):
|
||||||
with pytest.raises(ValueError, match=error_msg):
|
|
||||||
for (_, _) in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
for (_, _) in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
@ -82,14 +82,14 @@ def test_check_multifield_embedding_false_type_field_id():
|
||||||
|
|
||||||
@non_graph_engine
|
@non_graph_engine
|
||||||
def test_check_multifield_embedding_false_input_shape():
|
def test_check_multifield_embedding_false_input_shape():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(TypeError):
|
||||||
compile_multi_field_embedding((8,), (8, 200), (8, 200),
|
compile_multi_field_embedding((8,), (8, 200), (8, 200),
|
||||||
dtype.int16, dtype.float32, dtype.int16)
|
dtype.int16, dtype.float32, dtype.int16)
|
||||||
|
|
||||||
|
|
||||||
@non_graph_engine
|
@non_graph_engine
|
||||||
def test_check_multifield_embedding_false_value_shape():
|
def test_check_multifield_embedding_false_value_shape():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(TypeError):
|
||||||
compile_multi_field_embedding((8, 200), (8,), (8, 200),
|
compile_multi_field_embedding((8, 200), (8,), (8, 200),
|
||||||
dtype.int16, dtype.float32, dtype.int16)
|
dtype.int16, dtype.float32, dtype.int16)
|
||||||
|
|
||||||
|
|
|
@ -85,21 +85,3 @@ def test_psnr_different_dtype():
|
||||||
net = PSNRNet()
|
net = PSNRNet()
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
_cell_graph_executor.compile(net, img1, img2)
|
_cell_graph_executor.compile(net, img1, img2)
|
||||||
|
|
||||||
|
|
||||||
def test_psnr_invalid_5d_input():
|
|
||||||
shape_1 = (8, 3, 16, 16)
|
|
||||||
shape_2 = (8, 3, 8, 8)
|
|
||||||
invalid_shape = (8, 3, 16, 16, 1)
|
|
||||||
img1 = Tensor(np.random.random(shape_1))
|
|
||||||
invalid_img1 = Tensor(np.random.random(invalid_shape))
|
|
||||||
img2 = Tensor(np.random.random(shape_2))
|
|
||||||
invalid_img2 = Tensor(np.random.random(invalid_shape))
|
|
||||||
|
|
||||||
net = PSNRNet()
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
_cell_graph_executor.compile(net, invalid_img1, img2)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
_cell_graph_executor.compile(net, img1, invalid_img2)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
_cell_graph_executor.compile(net, invalid_img1, invalid_img2)
|
|
||||||
|
|
|
@ -1,108 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""
|
|
||||||
test nn.Tril()
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mindspore.nn as nn
|
|
||||||
from mindspore import Tensor
|
|
||||||
from mindspore import context
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
|
||||||
|
|
||||||
|
|
||||||
def test_tril():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
tril = nn.Tril()
|
|
||||||
return tril(self.value, 0)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
out = net()
|
|
||||||
assert np.sum(out.asnumpy()) == 34
|
|
||||||
|
|
||||||
|
|
||||||
def test_tril_1():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
tril = nn.Tril()
|
|
||||||
return tril(self.value, 1)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
out = net()
|
|
||||||
assert np.sum(out.asnumpy()) == 42
|
|
||||||
|
|
||||||
|
|
||||||
def test_tril_2():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
tril = nn.Tril()
|
|
||||||
return tril(self.value, -1)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
out = net()
|
|
||||||
assert np.sum(out.asnumpy()) == 19
|
|
||||||
|
|
||||||
|
|
||||||
def test_tril_parameter():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
tril = nn.Tril()
|
|
||||||
return tril(x, 0)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_tril_parameter_1():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
tril = nn.Tril()
|
|
||||||
return tril(x, 1)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_tril_parameter_2():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
tril = nn.Tril()
|
|
||||||
return tril(x, -1)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
|
|
@ -1,102 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""
|
|
||||||
test nn.Triu()
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mindspore.nn as nn
|
|
||||||
from mindspore import Tensor
|
|
||||||
from mindspore import context
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
|
||||||
|
|
||||||
class TriuNet(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(TriuNet, self).__init__()
|
|
||||||
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
triu = nn.Triu()
|
|
||||||
return triu(self.value, 0)
|
|
||||||
|
|
||||||
def test_triu():
|
|
||||||
"""
|
|
||||||
Feature: None
|
|
||||||
Description: test TriuNet with vm backend
|
|
||||||
Expectation: None
|
|
||||||
"""
|
|
||||||
net = TriuNet()
|
|
||||||
out = net()
|
|
||||||
assert np.sum(out.asnumpy()) == 26
|
|
||||||
|
|
||||||
def test_triu_1():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
triu = nn.Triu()
|
|
||||||
return triu(self.value, 1)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
out = net()
|
|
||||||
assert np.sum(out.asnumpy()) == 11
|
|
||||||
|
|
||||||
|
|
||||||
def test_triu_2():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
triu = nn.Triu()
|
|
||||||
return triu(self.value, -1)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
out = net()
|
|
||||||
assert np.sum(out.asnumpy()) == 38
|
|
||||||
|
|
||||||
|
|
||||||
def test_triu_parameter():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def construct(self, x):
|
|
||||||
triu = nn.Triu()
|
|
||||||
return triu(x, 0)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_triu_parameter_1():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def construct(self, x):
|
|
||||||
triu = nn.Triu()
|
|
||||||
return triu(x, 1)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_triu_parameter_2():
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def construct(self, x):
|
|
||||||
triu = nn.Triu()
|
|
||||||
return triu(x, -1)
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
|
|
|
@ -533,26 +533,6 @@ def test_parser_switch_layer_inputs_tuple():
|
||||||
back_out = back_net(i, input1, input2, grad)
|
back_out = back_net(i, input1, input2, grad)
|
||||||
|
|
||||||
|
|
||||||
def test_switch_layer_with_single_prim():
|
|
||||||
class SwitchLayerCell(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(SwitchLayerCell, self).__init__()
|
|
||||||
self.layers = (nn.ReLU(), nn.ReLU())
|
|
||||||
self.z3 = Parameter(
|
|
||||||
Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
|
|
||||||
|
|
||||||
def construct(self, index, x):
|
|
||||||
ret = self.layers[index](x) * self.z3
|
|
||||||
return ret
|
|
||||||
|
|
||||||
index = Tensor(0, dtype=mstype.int32)
|
|
||||||
net = SwitchLayerCell()
|
|
||||||
net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
||||||
grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
|
|
||||||
Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
||||||
grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
|
||||||
|
|
||||||
|
|
||||||
def test_switch_layer_env_eliminate():
|
def test_switch_layer_env_eliminate():
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -1003,7 +983,6 @@ def test_recursive_call():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
self.fc = nn.Dense(10, 10) # padding=0
|
self.fc = nn.Dense(10, 10) # padding=0
|
||||||
# self.net2 = Net2()
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
net2 = Net2()
|
net2 = Net2()
|
||||||
|
@ -1037,8 +1016,6 @@ def test_recursive_call():
|
||||||
# grad for Tensor(Bool) input and eliminate AddN(MakeTuple(Xs, zeros_like(Bool)))
|
# grad for Tensor(Bool) input and eliminate AddN(MakeTuple(Xs, zeros_like(Bool)))
|
||||||
def test_grad_tensor_bool():
|
def test_grad_tensor_bool():
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
|
|
||||||
def construct(self, x, y, z):
|
def construct(self, x, y, z):
|
||||||
out = z
|
out = z
|
||||||
|
|
Loading…
Reference in New Issue