forked from mindspore-Ecosystem/mindspore
add array_ops, math_ops and tensor ops
This commit is contained in:
parent
b48320d9eb
commit
72903c11c8
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -595,6 +595,123 @@ class Validator:
|
|||
raise ValueError(f'For {prim_name}, {ori_shape} reduce on {axis} should be '
|
||||
f'{tuple(exp_shape)}, but got {shape}.')
|
||||
|
||||
@staticmethod
|
||||
def check_astype_dtype(dtype):
|
||||
"""Check whether dtype is a valid input, and convert to mstype"""
|
||||
all_types = mstype.__dtype__ + ["int", "float", "bool"]
|
||||
if isinstance(dtype, str):
|
||||
if dtype.lower() not in all_types:
|
||||
raise TypeError(f"`{dtype}` not understood.")
|
||||
dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower()))
|
||||
elif isinstance(dtype, type):
|
||||
dtype = mstype.pytype_to_dtype(dtype)
|
||||
elif not dtype in mstype.number_type + (mstype.bool_,):
|
||||
raise TypeError(f"`{dtype}` not understood.")
|
||||
return dtype
|
||||
|
||||
@staticmethod
|
||||
def check_transpose_axis(axes, ndim):
|
||||
"""Check the axis argument for tensor.transpose"""
|
||||
if not axes or (len(axes) == 1 and axes[0] is None):
|
||||
return tuple(range(ndim-1, -1, -1))
|
||||
|
||||
if len(axes) == 1:
|
||||
perm = axes[0]
|
||||
# if only one argument provided, it must be tuple or list
|
||||
if isinstance(perm, list):
|
||||
perm = tuple(perm)
|
||||
else:
|
||||
if not isinstance(perm, tuple):
|
||||
raise TypeError(f"The `axes` should be a tuple/list, or series of int, but got {type(axes[0])}")
|
||||
return perm
|
||||
|
||||
# if multiple arguments provided, it must be `ndim` number of ints
|
||||
if len(axes) != ndim:
|
||||
raise ValueError("The number of axes must equal to the dimension of tensor.")
|
||||
return axes
|
||||
|
||||
@staticmethod
|
||||
def check_reshape_shp(shp):
|
||||
"""Check the shape argument for tensor.reshape"""
|
||||
|
||||
if len(shp) == 1:
|
||||
new_shape = shp[0]
|
||||
# if only one argument provided, it must be int, tuple or list
|
||||
if isinstance(new_shape, int):
|
||||
return shp
|
||||
if isinstance(new_shape, list):
|
||||
new_shape = tuple(new_shape)
|
||||
else:
|
||||
if not isinstance(new_shape, tuple):
|
||||
raise TypeError(
|
||||
f"The `shape` should be an int, or tuple/list, or series of int, but got {type(shp[0])}")
|
||||
return new_shape
|
||||
|
||||
return shp
|
||||
|
||||
@staticmethod
|
||||
def check_flatten_order(order):
|
||||
"""Check flatten function input order"""
|
||||
if not isinstance(order, str):
|
||||
raise TypeError(f"The order variable should be a string, but got {type(order)}")
|
||||
if order not in ('C', 'F'):
|
||||
raise ValueError(f"only `C` and `F` are supported as order, but got {order}")
|
||||
return order
|
||||
|
||||
@staticmethod
|
||||
def check_swapaxes_axis(axes, ndim):
|
||||
"""Check all the axes argument for tensor.swapaxes"""
|
||||
if isinstance(axes, int):
|
||||
check_axis_in_range(axes, ndim)
|
||||
return axes % ndim
|
||||
if isinstance(axes, (tuple, list)):
|
||||
for axis in axes:
|
||||
if not isinstance(axis, int):
|
||||
raise TypeError(f"axis argument should be integer, but got {type(axis)}.")
|
||||
check_axis_in_range(axis, ndim)
|
||||
axes = tuple(map(lambda x: x % ndim, axes))
|
||||
return axes
|
||||
raise TypeError(f"axes should be integer, list or tuple for check, but got {type(axes)}.")
|
||||
|
||||
@staticmethod
|
||||
def prepare_shape_for_squeeze(shape, axes):
|
||||
"""
|
||||
Creates the squeezed new shape based on the tensor and given axes.
|
||||
|
||||
Args:
|
||||
shape (tuple): the shape of the tensor
|
||||
axes Union[int, tuple(int), list(int)]: the axes with dimensions need to
|
||||
be squeezed.
|
||||
|
||||
Returns:
|
||||
new_shape(tuple): the shape with dimensions squeezed.
|
||||
"""
|
||||
new_shape = []
|
||||
ndim = len(shape)
|
||||
|
||||
# Convert to set
|
||||
if isinstance(axes, int):
|
||||
if axes >= ndim or axes < -ndim:
|
||||
raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {ndim}")
|
||||
axes = {axes}
|
||||
|
||||
elif isinstance(axes, (list, tuple)):
|
||||
for axis in axes:
|
||||
if axis >= ndim or axis < -ndim:
|
||||
raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {ndim}")
|
||||
axes = set(axes)
|
||||
|
||||
else:
|
||||
raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}")
|
||||
|
||||
for idx, s in enumerate(shape):
|
||||
if s != 1 or (idx not in axes) and (idx - ndim not in axes):
|
||||
new_shape.append(s)
|
||||
# if an axis is selected with shape entry greater than one, an error is raised.
|
||||
if s != 1 and ((idx in axes) or (idx - ndim in axes)):
|
||||
raise ValueError(f"axis {axes} has shape entry {s} > 1, cannot be squeezed.")
|
||||
return tuple(new_shape)
|
||||
|
||||
|
||||
def check_input_format(input_param):
|
||||
"""Judge input format."""
|
||||
|
@ -623,6 +740,13 @@ def _expand_tuple(n_dimensions):
|
|||
return convert
|
||||
|
||||
|
||||
def check_axis_in_range(axis, ndim):
|
||||
"""Checks axes are with the bounds of ndim"""
|
||||
if -ndim <= axis < ndim:
|
||||
return True
|
||||
raise ValueError(f'axis {axis} is out of bounds for tensor of dimension {ndim}')
|
||||
|
||||
|
||||
def _check_data_type_valid(data, valid_type):
|
||||
"""Check data type valid."""
|
||||
if valid_type is None:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -15,10 +15,13 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""standard_method"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...ops import functional as F
|
||||
from ...ops import operations as P
|
||||
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
|
||||
|
@ -28,14 +31,17 @@ from ...ops.primitive import constexpr
|
|||
|
||||
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
|
||||
|
||||
trans = P.Transpose()
|
||||
shape_ = P.Shape()
|
||||
reshape_ = P.Reshape()
|
||||
dtype_ = P.DType()
|
||||
abs_ = P.Abs()
|
||||
ndim_ = P.Rank()
|
||||
size_ = P.Size()
|
||||
|
||||
itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1,
|
||||
mstype.float16: 2, mstype.int16: 2, mstype.uint16: 2,
|
||||
mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4,
|
||||
mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8}
|
||||
|
||||
def mean(x, axis=(), keep_dims=False):
|
||||
"""
|
||||
Reduces a dimension of a tensor by averaging all elements in the dimension.
|
||||
|
@ -93,23 +99,150 @@ def any_(x, axis=(), keep_dims=False):
|
|||
return reduce_any(x, axis)
|
||||
|
||||
|
||||
def itemsize_(x):
|
||||
"""
|
||||
Return length of one tensor element in bytes.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
itemsize(int).
|
||||
"""
|
||||
return get_itemsize(x.dtype)
|
||||
|
||||
|
||||
def nbytes_(x):
|
||||
"""
|
||||
Return total number of bytes taken by the tensor.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
nbytes(int).
|
||||
"""
|
||||
return itemsize_(x) * F.shape_mul(shape_(x))
|
||||
|
||||
|
||||
def strides_(x):
|
||||
"""
|
||||
Return the tuple of bytes to step in each dimension when traversing a tensor.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
strides (tuple[int]).
|
||||
"""
|
||||
strides = ()
|
||||
ndim = P.Rank()(x)
|
||||
tensor_shape = shape_(x)
|
||||
for i in F.make_range(0, ndim):
|
||||
stride = itemsize_(x)
|
||||
for j in F.make_range(i + 1, ndim):
|
||||
stride *= tensor_shape[j]
|
||||
strides += (stride,)
|
||||
return strides
|
||||
|
||||
|
||||
def astype(x, dtype, copy=True):
|
||||
"""Implementation of `astype`."""
|
||||
dtype = check_astype_dtype_const(dtype)
|
||||
if not copy and dtype == x.dtype:
|
||||
return x
|
||||
return F.cast(x, dtype)
|
||||
|
||||
|
||||
def transpose(x, *axis):
|
||||
"""Implementation of `transpose`."""
|
||||
new_order = None
|
||||
ndim = F.rank(x)
|
||||
perm = check_transpose_axis_const(axis, ndim)
|
||||
return F.transpose(x, perm)
|
||||
|
||||
|
||||
# `tensor.T` is used as a property in graph mode
|
||||
T_ = transpose
|
||||
|
||||
|
||||
def reshape(x, *shape):
|
||||
"""Implementation of `reshape`."""
|
||||
new_shape = check_reshape_shp_const(shape)
|
||||
return F.reshape(x, new_shape)
|
||||
|
||||
|
||||
def ravel(x):
|
||||
"""Implementation of `ravel`."""
|
||||
return reshape(x, (-1,))
|
||||
|
||||
|
||||
def flatten(x, order='C'):
|
||||
"""
|
||||
Returns a copy of the array collapsed into one dimension.
|
||||
|
||||
Args:
|
||||
order (str, optional): Can choose between `C` and `F`. `C` means to
|
||||
flatten in row-major (C-style) order. ‘F’ means to flatten in column-major
|
||||
(Fortran- style) order. Only `C` and `F` are supported.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same data type as x.
|
||||
"""
|
||||
order = check_flatten_order_const(order)
|
||||
if order == 'C':
|
||||
return F.reshape(x, (-1,))
|
||||
|
||||
perm = F.make_range(0, F.rank(x))
|
||||
new_order = F.tuple_reversed(perm)
|
||||
return F.reshape(F.transpose(x, new_order), (-1,))
|
||||
|
||||
|
||||
def swapaxes(x, axis1, axis2):
|
||||
"""
|
||||
Interchanges two axes of a tensor.
|
||||
|
||||
Args:
|
||||
axis1 (int): First axis.
|
||||
axis2 (int): Second axis.
|
||||
|
||||
Returns:
|
||||
Transposed tensor, has the same data type as the original tensor x.
|
||||
"""
|
||||
axis1, axis2 = check_swapaxes_axis_const((axis1, axis2), x.ndim)
|
||||
|
||||
if axis1 == axis2:
|
||||
return x
|
||||
if axis1 > axis2:
|
||||
axis1, axis2 = axis2, axis1
|
||||
|
||||
perm = F.make_range(0, x.ndim)
|
||||
new_perm = None
|
||||
if axis2 + 1 < x.ndim:
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
|
||||
perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:]
|
||||
else:
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
|
||||
perm[axis1+1:axis2] + perm[axis1:axis1+1]
|
||||
|
||||
return F.transpose(x, new_perm)
|
||||
|
||||
|
||||
def squeeze(x, axis=None):
|
||||
"""
|
||||
Removes single-dimensional entries from the shape of an tensor.
|
||||
|
||||
Args:
|
||||
axis: Union[None, int, list(int), tuple(list)]. Default is None.
|
||||
|
||||
Returns:
|
||||
Tensor, with all or a subset of the dimensions of length 1 removed.
|
||||
"""
|
||||
shape = F.shape(x)
|
||||
length = F.tuple_len(shape)
|
||||
if not axis:
|
||||
perm = F.make_range(0, length)
|
||||
new_order = F.tuple_reversed(perm)
|
||||
|
||||
elif len(axis) == 1:
|
||||
new_order = convert_list_to_tuple(axis[0])
|
||||
|
||||
elif len(axis) == length:
|
||||
new_order = axis
|
||||
|
||||
out = trans(x, new_order)
|
||||
return out
|
||||
if axis is None:
|
||||
return F.squeeze(x)
|
||||
# yield squeezed shape based on the axes
|
||||
new_shape = prepare_shape_for_squeeze_const(shape, axis)
|
||||
return F.reshape(x, new_shape)
|
||||
|
||||
|
||||
def getitem(data, item):
|
||||
|
@ -200,7 +333,7 @@ def expand_tensor_as(x, y):
|
|||
def view(x, *shape):
|
||||
"""Reshape tensor, if shape is -1, reshape tensor into one dimension"""
|
||||
shape = check_view_shape(shape)
|
||||
return reshape_(x, shape)
|
||||
return F.reshape(x, shape)
|
||||
|
||||
|
||||
def isinstance_(x, base_type):
|
||||
|
@ -240,6 +373,12 @@ def check_type_same(x_type, base_type):
|
|||
raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'")
|
||||
|
||||
|
||||
@constexpr
|
||||
def get_itemsize(x_type):
|
||||
"""get itemsize from tensor's dtype."""
|
||||
return itemsize_map[x_type]
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_is_tensor(x):
|
||||
"""check whether x is tensor."""
|
||||
|
@ -298,14 +437,14 @@ def check_view_shape(x):
|
|||
x = x[0]
|
||||
return x
|
||||
|
||||
@constexpr
|
||||
def convert_list_to_tuple(shp):
|
||||
"""Check the type of the shape, if is list, convert to tuple"""
|
||||
if not isinstance(shp, (list, tuple)):
|
||||
raise ValueError(f"The shape variable should be a list or tuple, but got {type(shp)}")
|
||||
if isinstance(shp, list):
|
||||
shp = tuple(shp)
|
||||
return shp
|
||||
|
||||
# convert noraml param_check functions to constexpr functions
|
||||
check_astype_dtype_const = constexpr(validator.check_astype_dtype)
|
||||
check_transpose_axis_const = constexpr(validator.check_transpose_axis)
|
||||
check_reshape_shp_const = constexpr(validator.check_reshape_shp)
|
||||
check_flatten_order_const = constexpr(validator.check_flatten_order)
|
||||
check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis)
|
||||
prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze)
|
||||
|
||||
def tensor_bool(x):
|
||||
"""tensor as conditon, if is constant, return immediate bool value"""
|
||||
|
|
|
@ -178,6 +178,12 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
|
||||
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
|
||||
{"transpose", std::string("transpose")}, // P.transpose
|
||||
{"flatten", std::string("flatten")}, // P.reshape(,-1)
|
||||
{"reshape", std::string("reshape")}, // P.reshape()
|
||||
{"ravel", std::string("ravel")}, // P.reshape(,(-1,))
|
||||
{"swapaxes", std::string("swapaxes")}, // P.transpose()
|
||||
{"squeeze", std::string("squeeze")}, // P.squeeze()
|
||||
{"astype", std::string("astype")}, // P.cast()
|
||||
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
|
||||
}},
|
||||
{kObjectTypeJTagged, {}},
|
||||
|
@ -190,10 +196,14 @@ BuiltInTypeMap &GetAttrMap() {
|
|||
static BuiltInTypeMap attr_map = {
|
||||
{kObjectTypeTensorType,
|
||||
{
|
||||
{"shape", std::string("shape_")}, // C.shape_
|
||||
{"dtype", std::string("dtype_")}, // C.dtype_
|
||||
{"size", std::string("size_")}, // C.size_
|
||||
{"ndim", std::string("ndim_")}, // C.ndim_
|
||||
{"shape", std::string("shape_")}, // C.shape_
|
||||
{"dtype", std::string("dtype_")}, // C.dtype_
|
||||
{"size", std::string("size_")}, // C.size_
|
||||
{"ndim", std::string("ndim_")}, // C.ndim_
|
||||
{"T", std::string("T_")}, // C.T_
|
||||
{"itemsize", std::string("itemsize_")}, // C.itemsize_
|
||||
{"nbytes", std::string("nbytes_")}, // C.nbytes_
|
||||
{"strides", std::string("strides_")}, // C.strides_
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
|
|
|
@ -258,6 +258,20 @@ py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
|
|||
return dims;
|
||||
}
|
||||
|
||||
py::tuple TensorPy::GetPyTupleStrides(const Tensor &tensor) {
|
||||
std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end());
|
||||
std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize());
|
||||
py::tuple py_strides(strides.size());
|
||||
for (size_t i = 0; i < strides.size(); ++i) {
|
||||
py_strides[i] = py::int_(strides[i]);
|
||||
}
|
||||
return py_strides;
|
||||
}
|
||||
|
||||
py::int_ TensorPy::GetPyItemSize(const Tensor &tensor) { return tensor.data().itemsize(); }
|
||||
|
||||
py::int_ TensorPy::GetPyNBytes(const Tensor &tensor) { return tensor.data().nbytes(); }
|
||||
|
||||
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
|
||||
if (tensor.NeedWait()) {
|
||||
py::gil_scoped_release gil_release;
|
||||
|
@ -381,6 +395,40 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
>>> data.size
|
||||
6
|
||||
)mydelimiter")
|
||||
.def_property_readonly("_itemsize", TensorPy::GetPyItemSize, R"mydelimiter(
|
||||
Get the tensor's length of one element in bytes.
|
||||
|
||||
Returns:
|
||||
itemsize, length of one element in bytes.
|
||||
|
||||
Examples:
|
||||
>>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
|
||||
>>> data.itemsize
|
||||
4
|
||||
)mydelimiter")
|
||||
.def_property_readonly("_nbytes", TensorPy::GetPyNBytes, R"mydelimiter(
|
||||
Get the tensor's total number of bytes.
|
||||
|
||||
Returns:
|
||||
nbytes, total number of bytes taken by the tensor.
|
||||
|
||||
Examples:
|
||||
>>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
|
||||
>>> data.nbytes
|
||||
4
|
||||
)mydelimiter")
|
||||
.def_property_readonly("_strides", TensorPy::GetPyTupleStrides, R"mydelimiter(
|
||||
Get the tensor's tuple of bytes to step in each dimension
|
||||
when traversing an array.
|
||||
|
||||
Returns:
|
||||
tuple[int], the strides of the tensor.
|
||||
|
||||
Examples:
|
||||
>>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
|
||||
>>> data.strides
|
||||
(4, 4)
|
||||
)mydelimiter")
|
||||
.def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter(
|
||||
Creates a Tensor from a numpy.ndarray without copy.
|
||||
|
||||
|
|
|
@ -109,6 +109,12 @@ class TensorPy {
|
|||
static py::array AsNumpy(const Tensor &tensor);
|
||||
|
||||
static py::tuple GetPyTupleShape(const Tensor &tensor);
|
||||
|
||||
static py::tuple GetPyTupleStrides(const Tensor &tensor);
|
||||
|
||||
static py::int_ GetPyItemSize(const Tensor &tensor);
|
||||
|
||||
static py::int_ GetPyNBytes(const Tensor &tensor);
|
||||
};
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -267,6 +267,26 @@ class Tensor(Tensor_):
|
|||
"""tensor is inited."""
|
||||
return self.init is not None
|
||||
|
||||
@property
|
||||
def itemsize(self):
|
||||
"""The length of one tensor element in bytes."""
|
||||
return self._itemsize
|
||||
|
||||
@property
|
||||
def strides(self):
|
||||
"""The tuple of bytes to step in each dimension when traversing a tensor."""
|
||||
return self._strides
|
||||
|
||||
@property
|
||||
def nbytes(self):
|
||||
"""The total number of bytes taken by the tensor."""
|
||||
return self._nbytes
|
||||
|
||||
@property
|
||||
def T(self):
|
||||
"""The transposed tensor."""
|
||||
return self.transpose()
|
||||
|
||||
@property
|
||||
def virtual_flag(self):
|
||||
"""Mark tensor is virtual."""
|
||||
|
@ -384,6 +404,145 @@ class Tensor(Tensor_):
|
|||
axis = ()
|
||||
return tensor_operator_registry.get('mean')(keep_dims)(self, axis)
|
||||
|
||||
def transpose(self, *axes):
|
||||
"""
|
||||
Returns a view of the array with axes transposed.
|
||||
|
||||
For a 1-D array this has no effect, as a transposed vector is simply the
|
||||
same vector. For a 2-D array, this is a standard matrix transpose. For an
|
||||
n-D array, if axes are given, their order indicates how the axes are permuted
|
||||
(see Examples). If axes are not provided and a.shape = (i[0], i[1],...
|
||||
i[n-2], i[n-1]), then a.transpose().shape = (i[n-1], i[n-2], ... i[1], i[0]).
|
||||
|
||||
Args:
|
||||
axes(Union[None, tuple(int), list(int), n ints], optional):
|
||||
None or no argument: reverses the order of the axes.
|
||||
Tuple of ints: i in the j-th place in the tuple means a’s i-th
|
||||
axis becomes a.transpose()’s j-th axis.
|
||||
n ints: this form is intended simply as a `convenience alternative
|
||||
to the tuple form.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same dimension as input tensor, with axes suitably permuted.
|
||||
"""
|
||||
perm = validator.check_transpose_axis(axes, self.ndim)
|
||||
return tensor_operator_registry.get('transpose')()(self, perm)
|
||||
|
||||
def reshape(self, *shape):
|
||||
"""
|
||||
Gives a new shape to an array without changing its data.
|
||||
|
||||
Args:
|
||||
shape(Union[int, tuple(int), list(int)]): The new shape should be compatible
|
||||
with the original shape. If an integer, then the result will be a 1-D
|
||||
array of that length. One shape dimension can be -1. In this case, the
|
||||
value is inferred from the length of the array and remaining dimensions.
|
||||
|
||||
Returns:
|
||||
reshaped_tensor(Tensor): This will be a new view object if possible;
|
||||
otherwise, it will be a copy.
|
||||
"""
|
||||
new_shape = validator.check_reshape_shp(shape)
|
||||
return tensor_operator_registry.get('reshape')()(self, new_shape)
|
||||
|
||||
def ravel(self):
|
||||
"""
|
||||
Returns a contiguous flattened tensor.
|
||||
A 1-D tensor, containing the elements of the input, is returned.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same data type as x.
|
||||
"""
|
||||
reshape_op = tensor_operator_registry.get('reshape')()
|
||||
return reshape_op(self, (-1,))
|
||||
|
||||
def flatten(self, order='C'):
|
||||
"""
|
||||
Returns a copy of the array collapsed into one dimension.
|
||||
|
||||
Args:
|
||||
order (str, optional): Can choose between `C` and `F`. `C` means to
|
||||
flatten in row-major (C-style) order. ‘F’ means to flatten in column-major
|
||||
(Fortran- style) order. Only `C` and `F` are supported.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same data type as x.
|
||||
"""
|
||||
reshape_op = tensor_operator_registry.get('reshape')()
|
||||
trans_op = tensor_operator_registry.get('transpose')()
|
||||
|
||||
order = validator.check_flatten_order(order)
|
||||
if order == 'C':
|
||||
return reshape_op(self, (-1,))
|
||||
|
||||
perm = tuple(range(self.ndim-1, -1, -1))
|
||||
return reshape_op(trans_op(self, perm), (-1,))
|
||||
|
||||
def swapaxes(self, axis1, axis2):
|
||||
"""
|
||||
Interchanges two axes of a tensor.
|
||||
|
||||
Args:
|
||||
axis1 (int): First axis.
|
||||
axis2 (int): Second axis.
|
||||
|
||||
Returns:
|
||||
Transposed tensor, has the same data type as the original tensor x.
|
||||
"""
|
||||
axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim)
|
||||
|
||||
if axis1 == axis2:
|
||||
return self
|
||||
if axis1 > axis2:
|
||||
axis1, axis2 = axis2, axis1
|
||||
|
||||
perm = tuple(range(0, self.ndim))
|
||||
new_perm = None
|
||||
if axis2 + 1 < self.ndim:
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
|
||||
perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:]
|
||||
else:
|
||||
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
|
||||
perm[axis1+1:axis2] + perm[axis1:axis1+1]
|
||||
|
||||
return tensor_operator_registry.get('transpose')()(self, new_perm)
|
||||
|
||||
def squeeze(self, axis=None):
|
||||
"""
|
||||
Removes single-dimensional entries from the shape of an tensor.
|
||||
|
||||
Args:
|
||||
axis: Union[None, int, list(int), tuple(list)]. Default is None.
|
||||
|
||||
Returns:
|
||||
Tensor, with all or a subset of the dimensions of length 1 removed.
|
||||
"""
|
||||
if axis is None:
|
||||
return tensor_operator_registry.get('squeeze')(self)
|
||||
new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
|
||||
return tensor_operator_registry.get('reshape')()(self, new_shape)
|
||||
|
||||
def astype(self, dtype, copy=True):
|
||||
"""
|
||||
Returns a copy of the array, cast to a specified type.
|
||||
|
||||
Args:
|
||||
dtype(Union[mstype.dtype, numpy.dtype, str]): Designated tensor dtype,
|
||||
can be in format of np.float32, mstype.float32 or `float32`. Default
|
||||
is mstype.float32.
|
||||
|
||||
copy(bool, optional): By default, astype always returns a newly allocated
|
||||
tensor. If this is set to false, the input tensor is returned instead
|
||||
of a copy if possible.
|
||||
|
||||
Returns:
|
||||
Tensor, with the designated dtype.
|
||||
"""
|
||||
dtype = validator.check_astype_dtype(dtype)
|
||||
if not copy and dtype == self.dtype:
|
||||
return self
|
||||
return tensor_operator_registry.get('cast')(self, dtype)
|
||||
|
||||
|
||||
def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -25,22 +25,33 @@ Note:
|
|||
- random/ defines all the random operations.
|
||||
"""
|
||||
|
||||
from .array_ops import (array, asarray, asfarray, ones, zeros, full, arange,
|
||||
linspace, logspace, eye, identity, transpose, expand_dims,
|
||||
squeeze, rollaxis, swapaxes, reshape, ravel, concatenate)
|
||||
from .array_ops import copy_ as copy
|
||||
from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, reshape,
|
||||
ravel, concatenate, where, atleast_1d, atleast_2d, atleast_3d,
|
||||
column_stack, hstack, dstack, vstack, stack, unique)
|
||||
from .array_creations import copy_ as copy
|
||||
from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange,
|
||||
linspace, logspace, eye, identity, empty, empty_like,
|
||||
ones_like, zeros_like, full_like, diagonal, tril, triu,
|
||||
tri, trace)
|
||||
from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16,
|
||||
uint32, uint64, float_, float16, float32, float64, bool_, inf,
|
||||
numeric_types)
|
||||
from .math_ops import mean, inner
|
||||
from .math_ops import (mean, inner, add, subtract, multiply, divide, power,
|
||||
dot, outer, tensordot, absolute)
|
||||
|
||||
|
||||
array_ops_module = ['array', 'asarray', 'asfarray', 'copy', 'ones', 'zeros', 'arange',
|
||||
'linspace', 'logspace', 'eye', 'identity', 'transpose', 'expand_dims',
|
||||
'squeeze', 'rollaxis', 'swapaxes', 'reshape', 'ravel', 'concatenate']
|
||||
array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape',
|
||||
'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d',
|
||||
'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique']
|
||||
|
||||
math_module = ['mean', 'inner']
|
||||
array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange',
|
||||
'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like',
|
||||
'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu',
|
||||
'tri', 'trace']
|
||||
|
||||
__all__ = array_ops_module + math_module + numeric_types
|
||||
math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'power',
|
||||
'dot', 'outer', 'tensordot', 'absolute']
|
||||
|
||||
__all__ = array_ops_module + array_creations_module + math_module + numeric_types
|
||||
|
||||
__all__.sort()
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,14 +14,15 @@
|
|||
# ============================================================================
|
||||
"""Dtypes and utilities"""
|
||||
|
||||
from ..common.dtype import (int8, int16, int32, int64, uint8, uint16, uint32, uint64, \
|
||||
float16, float32, float64, bool_)
|
||||
from ..common.dtype import (int8, int16, int32, int64, uint8, uint16, uint32, uint64,
|
||||
float16, float32, float64, bool_)
|
||||
|
||||
# original numpy has int->int64, float->float64, uint->uint64 mapping. we map
|
||||
# them to 32 bit, since 64 bit calculation is not supported from mindspore
|
||||
# backend for now.
|
||||
|
||||
inf = float('inf')
|
||||
nan = float('nan')
|
||||
|
||||
int_ = int32
|
||||
uint = uint32
|
||||
|
@ -94,3 +95,69 @@ all_types = [
|
|||
'np.float32',
|
||||
'np.float64',
|
||||
'np.bool']
|
||||
|
||||
promotion_rule = {
|
||||
(uint8, uint16): uint16,
|
||||
(uint8, uint32): uint32,
|
||||
(uint8, uint64): uint64,
|
||||
(uint16, uint32): uint32,
|
||||
(uint16, uint64): uint64,
|
||||
(uint32, uint64): uint64,
|
||||
(uint8, int8): int16,
|
||||
(uint8, int16): int16,
|
||||
(uint8, int32): int32,
|
||||
(uint8, int64): int64,
|
||||
(uint16, int8): int32,
|
||||
(uint16, int16): int32,
|
||||
(uint16, int32): int32,
|
||||
(uint16, int64): int64,
|
||||
(uint32, int8): int64,
|
||||
(uint32, int16): int64,
|
||||
(uint32, int32): int64,
|
||||
(uint32, int64): int64,
|
||||
(uint64, int8): float64,
|
||||
(uint64, int16): float64,
|
||||
(uint64, int32): float64,
|
||||
(uint64, int64): float64,
|
||||
(uint8, float16): float16,
|
||||
(uint8, float32): float32,
|
||||
(uint8, float64): float64,
|
||||
(uint16, float16): float16,
|
||||
(uint16, float32): float32,
|
||||
(uint16, float64): float32,
|
||||
(uint32, float16): float16,
|
||||
(uint32, float32): float32,
|
||||
(uint32, float64): float64,
|
||||
(uint64, float16): float16,
|
||||
(uint64, float32): float32,
|
||||
(uint64, float64): float64,
|
||||
(int8, int16): int16,
|
||||
(int8, int32): int32,
|
||||
(int8, int64): int64,
|
||||
(int16, int32): int32,
|
||||
(int16, int64): int64,
|
||||
(int32, int64): int64,
|
||||
(int8, float16): float16,
|
||||
(int8, float32): float32,
|
||||
(int8, float64): float64,
|
||||
(int16, float16): float16,
|
||||
(int16, float32): float32,
|
||||
(int16, float64): float64,
|
||||
(int32, float16): float16,
|
||||
(int32, float32): float32,
|
||||
(int32, float64): float64,
|
||||
(float16, float32): float32,
|
||||
(float16, float64): float64,
|
||||
(float32, float64): float64,
|
||||
(bool_, uint8): uint8,
|
||||
(bool_, uint16): uint16,
|
||||
(bool_, uint32): uint32,
|
||||
(bool_, uint64): uint64,
|
||||
(bool_, int8): int8,
|
||||
(bool_, int16): int16,
|
||||
(bool_, int32): int32,
|
||||
(bool_, int64): int64,
|
||||
(bool_, float16): float16,
|
||||
(bool_, float32): float32,
|
||||
(bool_, float64): float64,
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -15,10 +15,354 @@
|
|||
"""math operations, the function docs are adapted from Numpy API."""
|
||||
from ..ops import operations as P
|
||||
from ..ops import functional as F
|
||||
from ..ops import composite as C
|
||||
from ..ops.primitive import constexpr
|
||||
from .array_ops import squeeze, asarray
|
||||
from .utils import _infer_out_shape, _is_scalar, _check_axis_valid, _get_device_compile, \
|
||||
_check_shape_aligned
|
||||
from ..common import dtype as mstype
|
||||
from .array_ops import ravel
|
||||
from .array_ops import where as where_
|
||||
from .array_creations import asarray, full
|
||||
from .utils import _is_scalar, _expand, _broadcast_to, _is_empty
|
||||
from .utils_const import _infer_out_shape, _check_axis_valid, _get_device_compile, \
|
||||
_check_shape_aligned, _empty, _check_is_tensor, _raise_type_error, _check_same_type, \
|
||||
_check_is_float, _check_input_tensor
|
||||
from .dtypes import nan
|
||||
|
||||
|
||||
_mean_default = P.ReduceMean()
|
||||
_mean_keepdims = P.ReduceMean(True)
|
||||
_matmul = P.MatMul(False, False)
|
||||
_matmul_T = P.MatMul(False, True)
|
||||
|
||||
|
||||
def absolute(x, out=None, where=True, dtype=None):
|
||||
"""
|
||||
Calculates the absolute value element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments casting, order, dtype, subok, signature, and extobj are
|
||||
not supported.
|
||||
When argument where is provided, argument out must have a tensor value.
|
||||
Argument out is not supported for storing the result, however it can be
|
||||
used in combination with argument where to set the value at indices for
|
||||
which where is set to False.
|
||||
Currently the backend kernel only supports float calculation, if the input
|
||||
is not a float, then it will be casted to float32 and casted back.
|
||||
|
||||
Args:
|
||||
x (Tensor): Tensor to be used for calculation.
|
||||
out (Tensor or None): optional, defaults to None.
|
||||
where (Tensor or None): optional. For any non-default value of type other
|
||||
than Tensor or None, the output retains its original value.
|
||||
This condition is broadcasted over the input. At locations where the
|
||||
condition is True, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default out=None,
|
||||
locations within it where the condition is False will remain
|
||||
uninitialized.
|
||||
dtype (data type): optional, defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = np.asarray([1, 2, 3, -4, -5], np.float64)
|
||||
>>> output = np.absolute(x)
|
||||
>>> print(output)
|
||||
[1. 2. 3. 4. 5.]
|
||||
"""
|
||||
if not _check_is_tensor(F.typeof(x)):
|
||||
_raise_type_error("Input is expected to be a tensor, but got ", x)
|
||||
original_dtype = x.dtype
|
||||
if not _check_is_float(original_dtype) and dtype is None:
|
||||
x = x.astype(mstype.float32)
|
||||
return _apply_tensor_op(F.absolute, x, out=out, where=where, dtype=dtype).astype(original_dtype)
|
||||
return _apply_tensor_op(F.absolute, x, out=out, where=where, dtype=dtype)
|
||||
|
||||
|
||||
def add(x1, x2, out=None, where=True, dtype=None):
|
||||
"""
|
||||
Adds arguments element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments casting, order, dtype, subok, signature, and extobj are
|
||||
not supported.
|
||||
When argument where is provided, argument out must have a tensor value.
|
||||
Argument out is not supported for storing the result, however it can be
|
||||
used in combination with argument where to set the value at indices for
|
||||
which where is set to False.
|
||||
On GPU, the supported dtypes are np.float16, np.float32, np.int32,
|
||||
and np.int64.
|
||||
On CPU, the supported dtypes are np.float16, np.float32, np.float64,
|
||||
np.int16, np.int32, and np.int64.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): input to be added.
|
||||
x2 (Tensor): input to be added.
|
||||
out (Tensor or None): optional, defaults to None.
|
||||
where (Tensor or None): optional. For any non-default value of type other
|
||||
than Tensor or None, the output retains its original value.
|
||||
This condition is broadcast over the input. At locations where the
|
||||
condition is True, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default out=None,
|
||||
locations within it where the condition is False will remain
|
||||
uninitialized.
|
||||
dtype (data type): optional, defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar, the sum of x1 and x2, element-wise. This is a scalar
|
||||
if both x1 and x2 are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x1 = np.full((3, 2), [1, 2])
|
||||
>>> x2 = np.full((3, 2), [3, 4])
|
||||
>>> output = np.add(x1, x2)
|
||||
>>> print(output)
|
||||
[[4, 6],
|
||||
[4, 6],
|
||||
[4, 6]]
|
||||
"""
|
||||
# broadcast is not fully supported in tensor_add on CPU,
|
||||
# so we use tensor_sub as a substitute solution
|
||||
if _get_device_compile() == 'CPU':
|
||||
return subtract(x1, F.neg_tensor(x2), out=out, where=where, dtype=dtype)
|
||||
return _apply_tensor_op(F.tensor_add, x1, x2, out=out, where=where, dtype=dtype)
|
||||
|
||||
|
||||
def subtract(x1, x2, out=None, where=True, dtype=None):
|
||||
"""
|
||||
Subtracts arguments, element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments casting, order, dtype, subok, signature, and extobj are
|
||||
not supported.
|
||||
When argument where is provided, argument out must have a tensor value.
|
||||
Argument out is not supported for storing the result, however it can be
|
||||
used in combination with argument where to set the value at indices for
|
||||
which where is set to False.
|
||||
On GPU, the supported dtypes are np.float16, np.float32, np.int32,
|
||||
and np.int64.
|
||||
On CPU, the supported dtypes are np.float16, np.float32, np.float64,
|
||||
np.int16, np.int32, and np.int64.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): the input to be subtracted from.
|
||||
x2 (Tensor): the input to be subtracted by.
|
||||
out (Tensor or None): optional, defaults to None.
|
||||
where (Tensor or None): optional. For any non-default value of type other
|
||||
than Tensor or None, the output retains its original value.
|
||||
This condition is broadcast over the input. At locations where the
|
||||
condition is True, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default out=None,
|
||||
locations within it where the condition is False will remain
|
||||
uninitialized.
|
||||
dtype (data type): optional, defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar, the difference of x1 and x2, element-wise. This is a
|
||||
scalar if both x1 and x2 are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x1 = np.full((3, 2), [1, 2])
|
||||
>>> x2 = np.full((3, 2), [3, 4])
|
||||
>>> output = np.subtract(x1, x2)
|
||||
>>> print(output)
|
||||
[[-2, -2],
|
||||
[-2, -2],
|
||||
[-2, -2]]
|
||||
"""
|
||||
return _apply_tensor_op(F.tensor_sub, x1, x2, out=out, where=where, dtype=dtype)
|
||||
|
||||
|
||||
def multiply(x1, x2, out=None, where=True, dtype=None):
|
||||
"""
|
||||
Multiplies arguments element-wise.
|
||||
|
||||
Note:
|
||||
Numpy arguments casting, order, dtype, subok, signature, and extobj are
|
||||
not supported.
|
||||
When argument where is provided, argument out must have a tensor value.
|
||||
Argument out is not supported for storing the result, however it can be
|
||||
used in combination with argument where to set the value at indices for
|
||||
which where is set to False.
|
||||
On GPU, the supported dtypes are np.float16, np.float32, np.int32,
|
||||
and np.int64.
|
||||
On CPU, the supported dtypes are np.float16, np.float32, np.float64,
|
||||
np.int16, np.int32, and np.int64.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): input tensor to be multiplied.
|
||||
x2 (Tensor): input tensor to be multiplied.
|
||||
out (Tensor or None): optional, defaults to None.
|
||||
where (Tensor or None): optional. For any non-default value of type other
|
||||
than Tensor or None, the output retains its original value.
|
||||
This condition is broadcast over the input. At locations where the
|
||||
condition is True, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default out=None,
|
||||
locations within it where the condition is False will remain
|
||||
uninitialized.
|
||||
dtype (data type): optional, defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar, the product of x1 and x2, element-wise. This is a scalar
|
||||
if both x1 and x2 are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x1 = np.full((3, 2), [1, 2])
|
||||
>>> x2 = np.full((3, 2), [3, 4])
|
||||
>>> output = np.multiply(x1, x2)
|
||||
>>> print(output)
|
||||
[[3, 8],
|
||||
[3, 8],
|
||||
[3, 8]]
|
||||
"""
|
||||
if _get_device_compile() == 'CPU':
|
||||
# broadcast is not fully supported on CPU backend,
|
||||
# and explicit broadcasting is performed
|
||||
shape_out = _infer_out_shape(F.shape(x1), F.shape(x2))
|
||||
ndim_out = F.tuple_len(shape_out)
|
||||
x1 = _expand(x1, ndim_out)
|
||||
x2 = _expand(x2, ndim_out)
|
||||
x1 = _broadcast_to(x1, F.shape(x1), shape_out, ndim_out)
|
||||
x2 = _broadcast_to(x2, F.shape(x2), shape_out, ndim_out)
|
||||
return _apply_tensor_op(F.tensor_mul, x1, x2, out=out, where=where, dtype=dtype)
|
||||
|
||||
|
||||
def divide(x1, x2, out=None, where=True, dtype=None):
|
||||
"""
|
||||
Returns a true division of the inputs, element-wise.
|
||||
|
||||
Instead of the Python traditional ‘floor division’, this returns a true
|
||||
division.
|
||||
|
||||
Note:
|
||||
Numpy arguments casting, order, dtype, subok, signature, and extobj are
|
||||
not supported.
|
||||
When argument where is provided, argument out must have a tensor value.
|
||||
Argument out is not supported for storing the result, however it can be
|
||||
used in combination with argument where to set the value at indices for
|
||||
which where is set to False.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
On CPU, the supported dtypes are np.float16, np.float32, np.float64,
|
||||
np.int16, np.int32, and np.int64.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): the divident.
|
||||
x2 (Tensor): the divisor.
|
||||
out (Tensor or None): optional, defaults to None.
|
||||
where (Tensor or None): optional. For any non-default value of type other
|
||||
than Tensor or None, the output retains its original value.
|
||||
This condition is broadcast over the input. At locations where the
|
||||
condition is True, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default out=None,
|
||||
locations within it where the condition is False will remain
|
||||
uninitialized.
|
||||
dtype (data type): optional, defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar, this is a scalar if both x1 and x2 are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x1 = np.full((3, 2), [1, 2])
|
||||
>>> x2 = np.full((3, 2), [3, 4])
|
||||
>>> output = np.divide(x1, x2)
|
||||
>>> print(output)
|
||||
[[0.33333333, 0.5],
|
||||
[0.33333333, 0.5],
|
||||
[0.33333333, 0.5]]
|
||||
"""
|
||||
if not _check_is_float(F.dtype(x1)) and not _check_is_float(F.dtype(x2)):
|
||||
x1 = F.cast(x1, mstype.float32)
|
||||
x2 = F.cast(x2, mstype.float32)
|
||||
return _apply_tensor_op(F.tensor_div, x1, x2, out=out, where=where, dtype=dtype)
|
||||
|
||||
|
||||
def power(x1, x2, out=None, where=True, dtype=None):
|
||||
"""
|
||||
First array elements raised to powers from second array, element-wise.
|
||||
|
||||
Raises each base in x1 to the positionally-corresponding power in x2.
|
||||
|
||||
Note:
|
||||
Numpy arguments casting, order, dtype, subok, signature, and extobj are
|
||||
not supported.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
On CPU, the supported dtypes are np.float16, np.float32, np.float64,
|
||||
np.int16, np.int32, and np.int64.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): the bases.
|
||||
x2 (Tensor): the exponenets.
|
||||
out (Tensor or None): optional, defaults to None.
|
||||
where (Tensor or None): optional. For any non-default value of type other
|
||||
than Tensor or None, the output retains its original value.
|
||||
This condition is broadcast over the input. At locations where the
|
||||
condition is True, the out array will be set to the ufunc result.
|
||||
Elsewhere, the out array will retain its original value. Note that
|
||||
if an uninitialized out array is created via the default out=None,
|
||||
locations within it where the condition is False will remain
|
||||
uninitialized.
|
||||
dtype (data type): optional, defaults to None. Overrides the dtype of the
|
||||
output Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar, the bases in x1 raised to the exponents in x2. This
|
||||
is a scalarif both x1 and x2 are scalars.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x1 = np.full((3, 2), [1, 2])
|
||||
>>> x2 = np.full((3, 2), [3, 4])
|
||||
>>> output = np.power(x1, x2)
|
||||
>>> print(output)
|
||||
[[ 1, 16],
|
||||
[ 1, 16],
|
||||
[ 1, 16]]
|
||||
"""
|
||||
return _apply_tensor_op(F.tensor_pow, x1, x2, out=out, where=where, dtype=dtype)
|
||||
|
||||
|
||||
def mean(a, axis=None, keepdims=False):
|
||||
|
@ -31,8 +375,9 @@ def mean(a, axis=None, keepdims=False):
|
|||
|
||||
Note:
|
||||
Numpy arguments dtype and out are not supported.
|
||||
On GPU, the supported dtypes are mstype.float16, and mstype.float32.
|
||||
On CPU, the supported dtypes are mstype.float16, and mstype.float32.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
On CPU, the supported dtypes are np.float16, np.float32, and
|
||||
np.float64.
|
||||
|
||||
Args:
|
||||
a (Tensor): input tensor containing numbers whose mean is desired.
|
||||
|
@ -63,17 +408,29 @@ def mean(a, axis=None, keepdims=False):
|
|||
>>> print(output)
|
||||
2.5
|
||||
"""
|
||||
axis = _check_axis_valid(axis, P.Rank()(a))
|
||||
if _is_empty(F.shape(a)):
|
||||
return _nan()
|
||||
if _is_scalar(a.shape):
|
||||
|
||||
axis = _check_axis_valid(axis, F.rank(a))
|
||||
shape_a = F.shape(a)
|
||||
|
||||
if _is_empty(shape_a):
|
||||
if keepdims:
|
||||
shape_out = _shape_reduced_keepdims(shape_a, axis)
|
||||
else:
|
||||
shape_out = _shape_reduced(shape_a, axis)
|
||||
if _is_empty(shape_out):
|
||||
return _empty(F.dtype(a), shape_out)
|
||||
return _full_compile(shape_out, nan)
|
||||
|
||||
if _is_scalar(shape_a):
|
||||
if keepdims:
|
||||
return a
|
||||
return squeeze(a)
|
||||
shape_out = _shape_reduced(shape_a, axis)
|
||||
return F.reshape(a, shape_out)
|
||||
|
||||
if keepdims:
|
||||
res = P.ReduceMean(True)(a, axis)
|
||||
res = _mean_keepdims(a, axis)
|
||||
else:
|
||||
res = P.ReduceMean(False)(a, axis)
|
||||
res = _mean_default(a, axis)
|
||||
return res
|
||||
|
||||
|
||||
|
@ -87,8 +444,9 @@ def inner(a, b):
|
|||
|
||||
Note:
|
||||
Numpy argument out is not supported.
|
||||
On GPU, the supported dtypes are mstype.float16, and mstype.float32.
|
||||
On CPU, the supported dtype is mstype.float32.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
On CPU, the supported dtypes are np.float16, np.float32, and
|
||||
np.float64.
|
||||
|
||||
Args:
|
||||
a (Tensor): input tensor. If a and b are nonscalar, their last
|
||||
|
@ -127,19 +485,21 @@ def inner(a, b):
|
|||
[[3. 3. 3. 3. 3. 3. 3.]
|
||||
[3. 3. 3. 3. 3. 3. 3.]]]
|
||||
"""
|
||||
if P.Rank()(a) == 0 or P.Rank()(b) == 0:
|
||||
if _is_scalar(a.shape):
|
||||
if F.rank(a) == 0 or F.rank(b) == 0:
|
||||
a = _expand(a, 1)
|
||||
b = _expand(b, 1)
|
||||
if F.rank(a) < F.rank(b):
|
||||
a, b = b, a
|
||||
return _apply_bin_op(P.Mul(), a, b)
|
||||
return F.tensor_mul(a, b)
|
||||
|
||||
_ = _check_shape_aligned(a.shape, b.shape)
|
||||
aligned_shape_a = (F.shape_mul(a.shape[:-1]), a.shape[-1])
|
||||
aligned_shape_b = (F.shape_mul(b.shape[:-1]), a.shape[-1])
|
||||
a_aligned = P.Reshape()(a, aligned_shape_a)
|
||||
b_aligned = P.Reshape()(b, aligned_shape_b)
|
||||
_ = _check_shape_aligned(F.shape(a), F.shape(b))
|
||||
aligned_shape_a = (F.shape_mul(F.shape(a)[:-1]), F.shape(a)[-1])
|
||||
aligned_shape_b = (F.shape_mul(F.shape(b)[:-1]), F.shape(a)[-1])
|
||||
a_aligned = F.reshape(a, aligned_shape_a)
|
||||
b_aligned = F.reshape(b, aligned_shape_b)
|
||||
|
||||
res = P.MatMul(False, True)(a_aligned, b_aligned)
|
||||
res = P.Reshape()(res, a.shape[:-1] + b.shape[:-1])
|
||||
res = _matmul_T(a_aligned, b_aligned)
|
||||
res = F.reshape(res, F.shape(a)[:-1] + F.shape(b)[:-1])
|
||||
return res
|
||||
|
||||
|
||||
|
@ -149,28 +509,245 @@ def _nan():
|
|||
return asarray(float('nan'))
|
||||
|
||||
|
||||
def _is_empty(shape):
|
||||
"""Checks if the shape is empty"""
|
||||
return F.shape_mul(shape) == 0
|
||||
def dot(a, b):
|
||||
"""
|
||||
Dot product of two arrays.
|
||||
|
||||
Specifically,
|
||||
If both a and b are 1-D arrays, it is inner product of vectors
|
||||
(without complex conjugation).
|
||||
If both a and b are 2-D arrays, it is matrix multiplication.
|
||||
If either a or b is 0-D (scalar), it is equivalent to multiply.
|
||||
If a is an N-D array and b is a 1-D array, it is a sum product
|
||||
over the last axis of a and b.
|
||||
If a is an N-D array and b is an M-D array (where M>=2), it is a
|
||||
sum product over the last axis of a and the second-to-last axis of b:
|
||||
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
|
||||
|
||||
Note:
|
||||
Numpy argument out is not supported.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
On CPU, the supported dtypes are np.float16, np.float32, and
|
||||
np.float64.
|
||||
|
||||
Args:
|
||||
a (Tensor): input tensor
|
||||
b (Tensor): input tensor
|
||||
|
||||
Returns:
|
||||
Tensor or scalar, the dot product of a and b. If a and b are
|
||||
both scalars or both 1-D arrays then a scalar is returned;
|
||||
otherwise an array is returned
|
||||
|
||||
Raises:
|
||||
ValueError: If the last dimension of a is not the same size
|
||||
as the second-to-last dimension of b.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.full((1, 3), 7)
|
||||
>>> b = np.full((2, 3, 4), 5)
|
||||
>>> output = np.dot(a, b)
|
||||
>>> print(output)
|
||||
[[[105, 105, 105, 105],
|
||||
[105, 105, 105, 105]]]
|
||||
"""
|
||||
ndim_a, ndim_b = F.rank(a), F.rank(b)
|
||||
if ndim_a > 0 and ndim_b >= 2:
|
||||
perm = F.make_range(ndim_b)
|
||||
perm = perm[:-2] + (perm[-1],) + (perm[-2],)
|
||||
b = F.transpose(b, perm)
|
||||
return inner(a, b)
|
||||
|
||||
|
||||
def _expand(x, ndim):
|
||||
"""Expand x to ndim"""
|
||||
while P.Rank()(x) < ndim:
|
||||
x = P.ExpandDims()(x, 0)
|
||||
return x
|
||||
def outer(a, b):
|
||||
"""
|
||||
Computes the outer product of two vectors.
|
||||
|
||||
Given two vectors, a = [a0, a1, ..., aM] and b = [b0, b1, ..., bN],
|
||||
the outer product [1] is:
|
||||
[[a0*b0 a0*b1 ... a0*bN ]
|
||||
[a1*b0 .
|
||||
[ ... .
|
||||
[aM*b0 aM*bN ]]
|
||||
|
||||
Note:
|
||||
Numpy argument out is not supported.
|
||||
On GPU, the supported dtypes are np.float16, and np.float32.
|
||||
On CPU, the supported dtypes are np.float16, np.float32, and
|
||||
np.float64.
|
||||
|
||||
Args:
|
||||
a (Tensor): first input vector. Input is flattened if not
|
||||
already 1-dimensional.
|
||||
b (Tensor): second input vector. Input is flattened if not
|
||||
already 1-dimensional.
|
||||
|
||||
Returns:
|
||||
Tensor or scalar, out[i, j] = a[i] * b[j].
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> a = np.full(7, 2)
|
||||
>>> b = np.full(4, 3)
|
||||
>>> output = np.outer(a, b)
|
||||
>>> print(output)
|
||||
[[6, 6, 6, 6],
|
||||
[6, 6, 6, 6],
|
||||
[6, 6, 6, 6],
|
||||
[6, 6, 6, 6],
|
||||
[6, 6, 6, 6],
|
||||
[6, 6, 6, 6],
|
||||
[6, 6, 6, 6]]
|
||||
"""
|
||||
_check_input_tensor(F.typeof(a))
|
||||
_check_input_tensor(F.typeof(b))
|
||||
|
||||
if F.rank(a) != 1:
|
||||
a = ravel(a)
|
||||
if F.rank(b) != 1:
|
||||
b = ravel(b)
|
||||
a = F.reshape(a, (F.shape(a)[0], 1))
|
||||
b = _expand(b, 2)
|
||||
return _matmul(a, b)
|
||||
|
||||
|
||||
def _apply_bin_op(fn, x1, x2):
|
||||
"""apply binary operations based on fn."""
|
||||
device = _get_device_compile()
|
||||
out_shape = _infer_out_shape(device, x1.shape, x2.shape)
|
||||
if device == 'CPU':
|
||||
# built-in operations on CPU does not support operands with
|
||||
# dimensions of size 1 or with shape 0, therefore squeeze
|
||||
# and scalar promotion is performed
|
||||
x1, x2 = squeeze(x1), squeeze(x2)
|
||||
x1, x2 = _expand(x1, 1), _expand(x2, 1)
|
||||
res = fn(x1, x2)
|
||||
res = P.Reshape()(res, out_shape)
|
||||
return res
|
||||
def tensordot(a, b, axes=2):
|
||||
"""
|
||||
Computes tensor dot product along specified axes.
|
||||
|
||||
Given two tensors, a and b, and an array_like object containing two array_like
|
||||
objects, (a_axes, b_axes), sum the products of a’s and b’s elements (components)
|
||||
over the axes specified by a_axes and b_axes. The third argument can be a single
|
||||
non-negative integer_like scalar, N; if it is such, then the last N dimensions of
|
||||
a and the first N dimensions of b are summed over.
|
||||
Three common use cases are:
|
||||
axes = 0 : tensor product
|
||||
axes = 1 : tensor dot product
|
||||
axes = 2 : (default) tensor double contraction
|
||||
When axes is integer_like, the sequence for evaluation will be: first the -Nth
|
||||
axis in a and 0th axis in b, and the -1th axis in a and Nth axis in b last.
|
||||
When there is more than one axis to sum over - and they are not the last (first)
|
||||
axes of a (b) - the argument axes should consist of two sequences of the same
|
||||
length, with the first axis to sum over given first in both sequences, the second
|
||||
axis second, and so forth.
|
||||
The shape of the result consists of the non-contracted axes of the first tensor,
|
||||
followed by the non-contracted axes of the second.
|
||||
|
||||
Note:
|
||||
On CPU, the supported dypes are np.float16 and np.float32.
|
||||
On GPU, the supported dypes are np.float16 and np.float32.
|
||||
|
||||
Args:
|
||||
a, b (Tensor): Tensors to “dot”.
|
||||
axes (int or (2,) array_like):
|
||||
integer_like: If an int N, sum over the last N axes of a and the first N
|
||||
axes of b in order. The sizes of the corresponding axes must match.
|
||||
(2,) array_like: Or, a list of axes to be summed over, first sequence
|
||||
applying to a, second to b. Both elements array_like must be of the same
|
||||
length.
|
||||
|
||||
Returns:
|
||||
Tensor, or list of tensors, the tensor dot product of the input.
|
||||
|
||||
Raises:
|
||||
TypeError: if the input is not a tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> a = np.ones((3, 4, 5))
|
||||
>>> b = np.ones((4, 3, 2))
|
||||
>>> output = np.tensordot(a, b, axes=([1,0],[0,1]))
|
||||
>>> print(output.shape)
|
||||
(5, 2)
|
||||
"""
|
||||
_check_input_tensor(F.typeof(a))
|
||||
_check_input_tensor(F.typeof(b))
|
||||
|
||||
if F.rank(a)*F.rank(b) == 0 and axes == 0:
|
||||
return F.tensor_mul(a, b)
|
||||
return C.tensor_dot(a, b, axes)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _full_compile(shape, value):
|
||||
return full(shape, value)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _shape_reduced_keepdims(shape, axes):
|
||||
"""
|
||||
Reduces dimensions corresponding to argument axes while
|
||||
keeping the number of dimensions unchanged.
|
||||
"""
|
||||
ndim_out = F.tuple_len(shape)
|
||||
shape_out = [1]*ndim_out
|
||||
for i in range(ndim_out):
|
||||
if not i in axes:
|
||||
shape_out[i] = shape[i]
|
||||
return tuple(shape_out)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _shape_reduced(shape, axes):
|
||||
"""Removes dimensions corresponding to argument axes"""
|
||||
ndim_orig = F.tuple_len(shape)
|
||||
ndim_out = ndim_orig - F.tuple_len(axes)
|
||||
shape_out = [0]*ndim_out
|
||||
idx_out = 0
|
||||
for i in range(ndim_orig):
|
||||
if not i in axes:
|
||||
shape_out[idx_out] = shape[i]
|
||||
idx_out += 1
|
||||
return tuple(shape_out)
|
||||
|
||||
|
||||
def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
|
||||
"""Infers the shape of the last two dimensions after performing matmul."""
|
||||
shape_rem = ()
|
||||
if ndim1 >= 2:
|
||||
shape_rem += (shape1[-2],)
|
||||
if transpose_b:
|
||||
if ndim2 >= 2:
|
||||
shape_rem += (shape2[-2],)
|
||||
else:
|
||||
if ndim1 >= 1:
|
||||
shape_rem += (shape2[-1],)
|
||||
return shape_rem
|
||||
|
||||
|
||||
def _apply_tensor_op(fn, *args, out=None, where=True, dtype=None):
|
||||
"""applies tensor operations based on fn"""
|
||||
for arg in args:
|
||||
_check_input_tensor(F.typeof(arg))
|
||||
res = fn(*args)
|
||||
|
||||
# if out is set to a non-default value, return tensor will have the same
|
||||
# dtype as out, which overrides the dtype passed into the keyword argument
|
||||
if _check_is_tensor(F.typeof(out)):
|
||||
dtype_out = F.dtype(out)
|
||||
elif dtype is not None:
|
||||
dtype_out = dtype
|
||||
else:
|
||||
dtype_out = F.dtype(res)
|
||||
|
||||
if _check_is_tensor(F.typeof(out)) and _check_is_tensor(F.typeof(where)):
|
||||
out = where_(where, res, out)
|
||||
elif out is None or where is not None:
|
||||
out = res
|
||||
|
||||
if not _check_same_type(F.dtype(out), dtype_out):
|
||||
out = F.cast(out, dtype_out)
|
||||
|
||||
return out
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,126 +13,14 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""internal utility functions"""
|
||||
from functools import partial
|
||||
|
||||
import numpy as onp
|
||||
|
||||
import mindspore.context as context
|
||||
from ..common import Tensor
|
||||
from ..ops import operations as P
|
||||
from ..ops import functional as F
|
||||
from ..ops.primitive import constexpr
|
||||
from ..common import dtype as mstype
|
||||
|
||||
from .dtypes import dtype_tuple, all_types, dtype_map
|
||||
|
||||
@constexpr
|
||||
def _check_shape_compile(shape):
|
||||
"""check the shape param to match the numpy style inside the graph"""
|
||||
if not isinstance(shape, (int, tuple, list)):
|
||||
raise TypeError(
|
||||
f"only int, tuple and list are allowed for shape, but got {type(shape)}")
|
||||
if isinstance(shape, int):
|
||||
shape = (shape,)
|
||||
if isinstance(shape, list):
|
||||
shape = tuple(shape)
|
||||
return shape
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_int(x):
|
||||
"""Check the type of x is int."""
|
||||
if isinstance(x, int):
|
||||
return True
|
||||
raise TypeError(f"integer argument expected, but got {type(x)}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_start_normalize(start, ndim):
|
||||
"""check and normalize start argument for rollaxis."""
|
||||
if start < -ndim or start > ndim:
|
||||
raise ValueError(
|
||||
f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.")
|
||||
if start < 0:
|
||||
start = start + ndim
|
||||
return start
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_axes_range(axes, ndim):
|
||||
"""
|
||||
Check axes are within the number of dimensions of tensor x and normalize the negative axes.
|
||||
Args:
|
||||
axes (Union[int, tuple(int), list(int)]): Axes of the tensor.
|
||||
ndim (int): The number of dimensions of the tensor.
|
||||
Return:
|
||||
Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple.
|
||||
"""
|
||||
if not isinstance(axes, int) and not isinstance(axes, tuple) and not isinstance(axes, list):
|
||||
raise TypeError(
|
||||
f"int, tuple(int) or list(int) expected, but got {type(axes)}.")
|
||||
low = -ndim
|
||||
up = ndim - 1
|
||||
if low > up:
|
||||
raise ValueError(
|
||||
f"Lower bound {low} and upper bound {up} of axes are not allowed.")
|
||||
if isinstance(axes, int):
|
||||
if axes < low or axes > up:
|
||||
raise ValueError(
|
||||
f"axis {axes} is out of bounds for tensor of dimension {ndim}.")
|
||||
return axes if axes >= 0 else axes + ndim
|
||||
new_axes = []
|
||||
for item in axes:
|
||||
if not isinstance(item, int):
|
||||
raise TypeError(
|
||||
f"int in tuple or list expected, but got {type(item)}.")
|
||||
if item < low or item > up:
|
||||
raise ValueError(
|
||||
f"axis {item} in {axes} is out of bounds for tensor of dimension {ndim}.")
|
||||
new_axes.append(item if item >= 0 else item + ndim)
|
||||
return tuple(new_axes)
|
||||
|
||||
|
||||
def _check_shape_contain_zero(shp):
|
||||
"""Check whether shape contains 0"""
|
||||
if isinstance(shp, int):
|
||||
return shp == 0
|
||||
if isinstance(shp, (list, tuple)):
|
||||
for s in shp:
|
||||
if s == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_shape(shape):
|
||||
"""check the shape param to match the numpy style outside the graph"""
|
||||
if not isinstance(shape, (int, tuple, list)):
|
||||
raise TypeError(
|
||||
f"only int, tuple and list are allowed for shape, but got {type(shape)}")
|
||||
if isinstance(shape, int):
|
||||
shape = (shape,)
|
||||
if isinstance(shape, list):
|
||||
shape = tuple(shape)
|
||||
return shape
|
||||
|
||||
|
||||
def _check_dtype(dtype):
|
||||
"""check the input dtype and make conversions"""
|
||||
# convert the string dtype to mstype.dtype
|
||||
if isinstance(dtype, str):
|
||||
dtype = dtype.lower()
|
||||
dtype = dtype_map[dtype]
|
||||
elif isinstance(dtype, type):
|
||||
if dtype is int:
|
||||
dtype = mstype.int32
|
||||
if dtype is float:
|
||||
dtype = mstype.float32
|
||||
if dtype is bool:
|
||||
dtype = mstype.bool_
|
||||
if dtype not in dtype_tuple:
|
||||
raise TypeError(
|
||||
f"only {all_types} are allowed for dtype, but got {type(dtype)}")
|
||||
return dtype
|
||||
from .utils_const import _tile_size
|
||||
|
||||
|
||||
def _deep_list(array_like):
|
||||
|
@ -170,15 +58,8 @@ def _check_input_for_asarray(array_like):
|
|||
"""check whether array_like argument is a valid type for np.asarray conversion"""
|
||||
if isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)):
|
||||
return True
|
||||
raise TypeError(
|
||||
"input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \
|
||||
f"or numpy.ndarray, but got {type(array_like)}")
|
||||
|
||||
|
||||
def _cast_to(array, dtype):
|
||||
"""cast the input to specified dtype"""
|
||||
cast = P.Cast()
|
||||
return cast(array, dtype)
|
||||
raise TypeError("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \
|
||||
f"or numpy.ndarray, but got {type(array_like)}")
|
||||
|
||||
|
||||
def _is_scalar(shape):
|
||||
|
@ -186,10 +67,9 @@ def _is_scalar(shape):
|
|||
return F.shape_mul(shape) == 1
|
||||
|
||||
|
||||
@constexpr
|
||||
def _get_device_compile():
|
||||
"""Get the current device (`GPU`, `CPU`, `Ascend`)"""
|
||||
return context.get_context('device_target')
|
||||
def _is_empty(shape):
|
||||
"""Checks if the shape is empty"""
|
||||
return F.shape_mul(shape) == 0
|
||||
|
||||
|
||||
def _get_device():
|
||||
|
@ -199,10 +79,12 @@ def _get_device():
|
|||
|
||||
def _covert_list_tensor_to_tuple_tensor(list_of_tensor):
|
||||
"""Convert a list of tensor to a tuple of tensor"""
|
||||
tuple_of_tensor = ()
|
||||
for tensor in list_of_tensor:
|
||||
tuple_of_tensor += (tensor,)
|
||||
return tuple_of_tensor
|
||||
if isinstance(list_of_tensor, list):
|
||||
tuple_of_tensor = ()
|
||||
for tensor in list_of_tensor:
|
||||
tuple_of_tensor += (tensor,)
|
||||
return tuple_of_tensor
|
||||
return list_of_tensor
|
||||
|
||||
|
||||
def _get_mode():
|
||||
|
@ -210,145 +92,14 @@ def _get_mode():
|
|||
return context.get_context('mode')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _reverse_index(idx, arr):
|
||||
"""
|
||||
Returns 1 if shape[idx:] is broadcastable to shape_out[idx:],
|
||||
2 situations if the function returns 1:
|
||||
- 1. Tensor's shape has 1 at the designated dimension.
|
||||
- 2. Tensor's dimension is less than the designated idx. (The Tensor shape
|
||||
has been reversed)
|
||||
For both cases, 2 tensors are broadcastable.
|
||||
otherwise returns the element at position of shape
|
||||
"""
|
||||
if len(arr) <= idx:
|
||||
return 1
|
||||
return arr[-1 - idx]
|
||||
def _expand(x, ndim, axis=0):
|
||||
"""Expand x to ndim."""
|
||||
while F.rank(x) < ndim:
|
||||
x = F.expand_dims(x, axis)
|
||||
return x
|
||||
|
||||
|
||||
@constexpr
|
||||
def _infer_out_shape(device, *shapes):
|
||||
"""
|
||||
Returns shape of output after broadcasting
|
||||
Raises ValueError if shape1 and shape2 cannot be broadcast
|
||||
"""
|
||||
shapes_unbroadcastable = False
|
||||
cpu_shapes_different = False
|
||||
contains_scalar = any(_is_scalar(shape) for shape in shapes)
|
||||
ndim_max = max(map(len, shapes))
|
||||
shape_out = [0]*ndim_max
|
||||
i = 0
|
||||
for i in range(ndim_max):
|
||||
shape_out[-1 - i] = max(map(partial(_reverse_index, i), shapes))
|
||||
for shape in shapes:
|
||||
if _reverse_index(i, shape) != shape_out[-1 - i]:
|
||||
if _reverse_index(i, shape) != 1:
|
||||
shapes_unbroadcastable = True
|
||||
if device == 'CPU' and not contains_scalar:
|
||||
cpu_shapes_different = True
|
||||
if not shapes_unbroadcastable and not cpu_shapes_different:
|
||||
return tuple(shape_out)
|
||||
if shapes_unbroadcastable:
|
||||
raise ValueError(
|
||||
f'operands could not be broadcast together with shapes {*shapes,}')
|
||||
raise ValueError('broadcasting is currently not supported on CPU. Non-scalar' + \
|
||||
f'operands must have the same shape, but got {*shapes,}')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_axis_in_range(axis, ndim):
|
||||
"""Checks axes are with the bounds of ndim"""
|
||||
if -ndim <= axis < ndim:
|
||||
return True
|
||||
raise ValueError(
|
||||
f'axis {axis} is out of bounds for array of dimension {ndim}')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_axis_valid(axes, ndim):
|
||||
"""
|
||||
Checks axes are valid given ndim, and returns axes that can be passed
|
||||
to the built-in operator (non-negative, int or tuple)
|
||||
"""
|
||||
if isinstance(axes, int):
|
||||
_ = _check_axis_in_range(axes, ndim)
|
||||
return (axes % ndim,)
|
||||
if isinstance(axes, tuple):
|
||||
for axis in axes:
|
||||
_ = _check_axis_in_range(axis, ndim)
|
||||
axes = tuple(map(lambda x: x % ndim, axes))
|
||||
if all(axes.count(el) <= 1 for el in axes):
|
||||
return axes
|
||||
if axes is None:
|
||||
axes = F.make_range(ndim)
|
||||
return axes
|
||||
raise ValueError('duplicate value in \'axis\'')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_shape_aligned(shape1, shape2):
|
||||
"""Checks shape1 and shape2 are valid shapes to perform inner product"""
|
||||
if shape1[-1] == shape2[-1]:
|
||||
return True
|
||||
raise ValueError(
|
||||
f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_dim_cpu(shape, bound):
|
||||
"""Checks input shape is upper-bounded by parameter bound"""
|
||||
ndim = len(shape)
|
||||
if _is_scalar(shape):
|
||||
return True
|
||||
if ndim <= bound:
|
||||
return True
|
||||
raise ValueError(
|
||||
f'dimension {ndim} larger than {bound} is not supported on CPU')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _tile_size(shape, out_shape, ndim):
|
||||
"""Returns tile_size such that shape*tile_size = out_shape"""
|
||||
size = [1]*ndim
|
||||
for idx, (i, j) in enumerate(zip(shape, out_shape)):
|
||||
if i != j:
|
||||
size[idx] = j
|
||||
return tuple(size)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_core_match(shape1, shape2):
|
||||
"""Checks shape1 and shape2 are valid shapes to perform matmul"""
|
||||
ndim1, ndim2 = len(shape1), len(shape2)
|
||||
if ndim1 < 1 or ndim2 < 2:
|
||||
return True
|
||||
if shape1[-1] == shape2[-2]:
|
||||
return True
|
||||
raise ValueError(f'mismatch in core dimension of input operands (size {shape1[-1]} ' +
|
||||
f'is different from {shape2[-2]})')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _cpu_not_support(name):
|
||||
"""Checks if a function not supported on cpu is executed on cpu device"""
|
||||
if _get_device() != 'CPU':
|
||||
return True
|
||||
raise ValueError(f'{name} is not supported on CPU')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_tuple(obj):
|
||||
"""Check whether obj is a tuple"""
|
||||
return isinstance(obj, mstype.Tuple)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_list(obj):
|
||||
"""Check whether obj is a list"""
|
||||
return isinstance(obj, mstype.List)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_tensor(obj):
|
||||
"""Check whether obj is a tensor"""
|
||||
return isinstance(obj, mstype.tensor_type)
|
||||
def _broadcast_to(x, shape_cur, shape_to, ndim_to):
|
||||
"""Broadcasts x from shape_cur to shape_to."""
|
||||
size = _tile_size(shape_cur, shape_to, ndim_to)
|
||||
return F.tile(x, size)
|
||||
|
|
|
@ -0,0 +1,364 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""internal graph-compatible utility functions"""
|
||||
from functools import partial
|
||||
|
||||
import numpy as onp
|
||||
|
||||
import mindspore.context as context
|
||||
from ..common import Tensor
|
||||
from ..ops import functional as F
|
||||
from ..ops.primitive import constexpr
|
||||
from ..common import dtype as mstype
|
||||
from .._c_expression import Tensor as Tensor_
|
||||
from .._c_expression.typing import Tuple, List
|
||||
|
||||
from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_shape(shape):
|
||||
"""check the shape param to match the numpy style"""
|
||||
if not isinstance(shape, (int, tuple, list, Tuple, List)):
|
||||
raise TypeError(f"only int, tuple and list are allowed for shape, but got {type(shape)}")
|
||||
if isinstance(shape, int):
|
||||
shape = (shape,)
|
||||
if isinstance(shape, (list, List)):
|
||||
shape = tuple(shape)
|
||||
return shape
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_dtype(dtype):
|
||||
"""check the input dtype and make conversions"""
|
||||
# convert the string dtype to mstype.dtype
|
||||
if isinstance(dtype, str):
|
||||
dtype = dtype.lower()
|
||||
dtype = dtype_map[dtype]
|
||||
elif isinstance(dtype, type):
|
||||
if dtype is int:
|
||||
dtype = mstype.int32
|
||||
elif dtype is float:
|
||||
dtype = mstype.float32
|
||||
else:
|
||||
dtype = mstype.pytype_to_dtype(dtype)
|
||||
if dtype not in dtype_tuple:
|
||||
raise TypeError(f"only {all_types} are allowed for dtype, but got {type(dtype)}")
|
||||
return dtype
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_shape_contain_zero(shp):
|
||||
"""Check whether shape contains zero"""
|
||||
if isinstance(shp, int):
|
||||
return shp == 0
|
||||
return F.shape_mul(shp) == 0
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_start_normalize(start, ndim):
|
||||
"""check and normalize start argument for rollaxis."""
|
||||
if start < -ndim or start > ndim:
|
||||
raise ValueError(f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.")
|
||||
if start < 0:
|
||||
start = start + ndim
|
||||
return start
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_axes_range(axes, ndim):
|
||||
"""
|
||||
Check axes are within the number of dimensions of tensor x and normalize the negative axes.
|
||||
Args:
|
||||
axes (Union[int, tuple(int), list(int)]): Axes of the tensor.
|
||||
ndim (int): The number of dimensions of the tensor.
|
||||
Return:
|
||||
Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple.
|
||||
"""
|
||||
if not isinstance(axes, int) and not isinstance(axes, tuple) and not isinstance(axes, list):
|
||||
raise TypeError(f"int, tuple(int) or list(int) expected, but got {type(axes)}.")
|
||||
low = -ndim
|
||||
up = ndim - 1
|
||||
if low > up:
|
||||
raise ValueError(f"Lower bound {low} and upper bound {up} of axes are not allowed.")
|
||||
if isinstance(axes, int):
|
||||
if axes < low or axes > up:
|
||||
raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {ndim}.")
|
||||
return axes if axes >= 0 else axes + ndim
|
||||
new_axes = []
|
||||
for item in axes:
|
||||
if not isinstance(item, int):
|
||||
raise TypeError(f"int in tuple or list expected, but got {type(item)}.")
|
||||
if item < low or item > up:
|
||||
raise ValueError(f"axis {item} in {axes} is out of bounds for tensor of dimension {ndim}.")
|
||||
new_axes.append(item if item >= 0 else item + ndim)
|
||||
return tuple(new_axes)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _get_device_compile():
|
||||
"""Get the current device (`GPU`, `CPU`, `Ascend`)"""
|
||||
return context.get_context('device_target')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _reverse_index(idx, arr):
|
||||
"""
|
||||
Returns 1 if shape[idx:] is broadcastable to shape_out[idx:],
|
||||
2 situations if the function returns 1:
|
||||
- 1. Tensor's shape has 1 at the designated dimension.
|
||||
- 2. Tensor's dimension is less than the designated idx. (The Tensor shape
|
||||
has been reversed)
|
||||
For both cases, 2 tensors are broadcastable.
|
||||
otherwise returns the element at position of shape
|
||||
"""
|
||||
if len(arr) <= idx:
|
||||
return 1
|
||||
return arr[-1 - idx]
|
||||
|
||||
|
||||
@constexpr
|
||||
def _infer_out_shape(*shapes):
|
||||
"""
|
||||
Returns shape of output after broadcasting
|
||||
Raises ValueError if shape1 and shape2 cannot be broadcast
|
||||
"""
|
||||
shapes_unbroadcastable = False
|
||||
ndim_max = max(map(len, shapes))
|
||||
shape_out = [0]*ndim_max
|
||||
i = 0
|
||||
for i in range(ndim_max):
|
||||
shape_out[-1 - i] = max(map(partial(_reverse_index, i), shapes))
|
||||
for shape in shapes:
|
||||
if _reverse_index(i, shape) != shape_out[-1 - i]:
|
||||
if _reverse_index(i, shape) != 1:
|
||||
shapes_unbroadcastable = True
|
||||
break
|
||||
if shapes_unbroadcastable:
|
||||
break
|
||||
if not shapes_unbroadcastable:
|
||||
return tuple(shape_out)
|
||||
raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_axis_in_range(axis, ndim):
|
||||
"""Checks axes are with the bounds of ndim"""
|
||||
if -ndim <= axis < ndim:
|
||||
return True
|
||||
raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_axis_valid(axes, ndim):
|
||||
"""
|
||||
Checks axes are valid given ndim, and returns axes that can be passed
|
||||
to the built-in operator (non-negative, int or tuple)
|
||||
"""
|
||||
if isinstance(axes, int):
|
||||
_ = _check_axis_in_range(axes, ndim)
|
||||
return (axes % ndim,)
|
||||
if isinstance(axes, tuple):
|
||||
for axis in axes:
|
||||
_ = _check_axis_in_range(axis, ndim)
|
||||
axes = tuple(map(lambda x: x % ndim, axes))
|
||||
if all(axes.count(el) <= 1 for el in axes):
|
||||
return axes
|
||||
if axes is None:
|
||||
axes = F.make_range(ndim)
|
||||
return axes
|
||||
raise ValueError('duplicate value in \'axis\'')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_shape_aligned(shape1, shape2):
|
||||
"""Checks shape1 and shape2 are valid shapes to perform inner product"""
|
||||
if shape1[-1] == shape2[-1]:
|
||||
return True
|
||||
raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _tile_size(shape, out_shape, ndim):
|
||||
"""Returns tile_size such that shape*tile_size = out_shape"""
|
||||
size = [1]*ndim
|
||||
for idx, (i, j) in enumerate(zip(shape, out_shape)):
|
||||
if i != j:
|
||||
size[idx] = j
|
||||
return tuple(size)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_int(obj):
|
||||
"""Check whether obj is an integer."""
|
||||
return isinstance(obj, int)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_tuple(obj):
|
||||
"""Check whether obj is a tuple"""
|
||||
return isinstance(obj, (tuple, Tuple))
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_list(obj):
|
||||
"""Check whether obj is a list"""
|
||||
return isinstance(obj, (list, List))
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_tensor(obj):
|
||||
"""Check whether obj is a tensor"""
|
||||
return isinstance(obj, mstype.tensor_type)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _raise_type_error(info, param=None):
|
||||
"""
|
||||
Raise TypeError in both graph/pynative mode
|
||||
|
||||
Args:
|
||||
info(str): info string to display
|
||||
param(python obj): any object that can be recognized by graph mode. If is
|
||||
not None, then param's type information will be extracted and displayed.
|
||||
Default is None.
|
||||
"""
|
||||
if param is None:
|
||||
raise TypeError(info)
|
||||
raise TypeError(info + f"{type(param)}")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _raise_value_error(info, param=None):
|
||||
"""
|
||||
Raise TypeError in both graph/pynative mode
|
||||
|
||||
Args:
|
||||
info(str): info string to display
|
||||
param(python obj): any object that can be recognized by graph mode. If is
|
||||
not None, then param's value information will be extracted and displayed.
|
||||
Default is None.
|
||||
"""
|
||||
if param is None:
|
||||
raise ValueError(info)
|
||||
raise ValueError(info + f"{param}")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _empty(dtype, shape):
|
||||
"""Returns an uninitialized array with dtype and shape."""
|
||||
return Tensor_(dtype, shape)
|
||||
|
||||
|
||||
def _get_index_for_unique(input_x, unique_x):
|
||||
"""
|
||||
Return the indices of the first occurrences of the unique values in the original array.
|
||||
|
||||
Args:
|
||||
input_x (Tensor): The flattened input tensor of `mindspore.numpy.unique`.
|
||||
unique_x (Tensor): The tensor contains the unique elements in `input_x`, sorted in ascending order.
|
||||
|
||||
Returns:
|
||||
Tensor. The indices of the unique values in the original array. Has the same shape as `unique_x`.
|
||||
"""
|
||||
o_array = input_x.asnumpy()
|
||||
dic = {}
|
||||
for idx in range(o_array.size):
|
||||
val = o_array[idx]
|
||||
if val not in dic:
|
||||
dic[val] = idx
|
||||
|
||||
index_lst = []
|
||||
u_array = unique_x.asnumpy()
|
||||
for idx in range(u_array.size):
|
||||
index_lst.append(dic[u_array[idx]])
|
||||
|
||||
return Tensor(onp.array(index_lst), input_x.dtype)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _get_counts_for_unique(input_x, unique_x):
|
||||
"""
|
||||
Return the number of times each of the unique values comes up in the original tensor.
|
||||
|
||||
Args:
|
||||
input_x (Tensor): The flattened input tensor of `mindspore.numpy.unique`.
|
||||
unique_x (Tensor): The tensor contains the unique elements in `input_x`, sorted in ascending order.
|
||||
|
||||
Returns:
|
||||
Tensor. The number of times each of the unique values comes up in the original tensor.
|
||||
"""
|
||||
dic = {}
|
||||
o_array = input_x.asnumpy()
|
||||
for idx in range(o_array.size):
|
||||
val = o_array[idx]
|
||||
if val not in dic:
|
||||
dic[val] = 1
|
||||
else:
|
||||
dic[val] += 1
|
||||
|
||||
u_array = unique_x.asnumpy()
|
||||
counts_lst = [dic[val] for val in u_array]
|
||||
|
||||
return Tensor(onp.array(counts_lst), input_x.dtype)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _get_max_value(x):
|
||||
"""Returns the maximum value of the input tensor `x`. """
|
||||
return int(max(x.asnumpy()))
|
||||
|
||||
|
||||
@constexpr
|
||||
def _promote(dtype1, dtype2):
|
||||
if dtype1 == dtype2:
|
||||
return dtype1
|
||||
if (dtype1, dtype2) in promotion_rule:
|
||||
return promotion_rule[dtype1, dtype2]
|
||||
return promotion_rule[dtype2, dtype1]
|
||||
|
||||
|
||||
@constexpr
|
||||
def _max(*args):
|
||||
"""Returns the maximum value."""
|
||||
return max(*args)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _min(*args):
|
||||
""""Returns the minimum value."""
|
||||
return min(*args)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _abs(arg):
|
||||
"""Returns the absolute value."""
|
||||
return abs(arg)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_same_type(dtype1, dtype2):
|
||||
return dtype1 == dtype2
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_float(dtype):
|
||||
return dtype in mstype.float_type
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_tensor(input_type):
|
||||
if not _check_is_tensor(input_type):
|
||||
raise TypeError(f'expect Tensor, but got {input_type}')
|
|
@ -1,6 +1,6 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -32,6 +32,7 @@ isconstant.set_const_prim(True)
|
|||
|
||||
issubclass_ = P.IsSubClass()
|
||||
isinstance_ = P.IsInstance()
|
||||
eye = P.Eye()
|
||||
fill = P.Fill()
|
||||
tile = P.Tile()
|
||||
select = P.Select()
|
||||
|
@ -45,6 +46,7 @@ control_depend = P.ControlDepend()
|
|||
merge = P.Merge()
|
||||
geswitch = P.GeSwitch()
|
||||
addn = P.AddN()
|
||||
absolute = P.Abs()
|
||||
tensor_add = P.TensorAdd()
|
||||
neg_tensor = P.Neg()
|
||||
tensor_lt = P.Less()
|
||||
|
@ -67,6 +69,8 @@ assign_add = P.AssignAdd()
|
|||
assign = P.Assign()
|
||||
square = P.Square()
|
||||
sqrt = P.Sqrt()
|
||||
reduce_sum = P.ReduceSum()
|
||||
tensor_slice = P.Slice()
|
||||
|
||||
scalar_to_array = P.ScalarToArray()
|
||||
scalar_to_tensor = P.ScalarToTensor()
|
||||
|
@ -74,6 +78,8 @@ tuple_to_array = P.TupleToArray()
|
|||
scalar_cast = P.ScalarCast()
|
||||
print_ = P.Print()
|
||||
expand_dims = P.ExpandDims()
|
||||
transpose = P.Transpose()
|
||||
squeeze = P.Squeeze()
|
||||
scatter_nd = P.ScatterNd()
|
||||
gather = P.GatherV2()
|
||||
gather_nd = P.GatherNd()
|
||||
|
@ -177,6 +183,7 @@ tensor_operator_registry.register('any', P.ReduceAny)
|
|||
tensor_operator_registry.register('abs', P.Abs)
|
||||
tensor_operator_registry.register('mean', P.ReduceMean)
|
||||
tensor_operator_registry.register('reshape', P.Reshape)
|
||||
tensor_operator_registry.register('transpose', P.Transpose)
|
||||
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
|
||||
# ms cannot support Tensor(True) compare
|
||||
tensor_operator_registry.register('__eq__', equal)
|
||||
|
@ -187,6 +194,7 @@ tensor_operator_registry.register('__le__', tensor_le)
|
|||
tensor_operator_registry.register('__gt__', tensor_gt)
|
||||
tensor_operator_registry.register('__ge__', tensor_ge)
|
||||
tensor_operator_registry.register('shape', shape)
|
||||
tensor_operator_registry.register('squeeze', squeeze)
|
||||
# support GE backend for no compare operators
|
||||
tensor_operator_registry.register('cast', cast)
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
|
@ -0,0 +1,730 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""unit tests for numpy array operations"""
|
||||
|
||||
import functools
|
||||
|
||||
import pytest
|
||||
import numpy as onp
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.numpy as mnp
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
class Cases():
|
||||
def __init__(self):
|
||||
self.all_shapes = [
|
||||
0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3]
|
||||
]
|
||||
self.onp_dtypes = [onp.int32, 'int32', int,
|
||||
onp.float32, 'float32', float,
|
||||
onp.uint32, 'uint32',
|
||||
onp.bool_, 'bool', bool]
|
||||
|
||||
self.mnp_dtypes = [mnp.int32, 'int32', int,
|
||||
mnp.float32, 'float32', float,
|
||||
mnp.uint32, 'uint32',
|
||||
mnp.bool_, 'bool', bool]
|
||||
|
||||
self.array_sets = [1, 1.1, True, [1, 0, True], [1, 1.0, 2], (1,),
|
||||
[(1, 2, 3), (4, 5, 6)], onp.random.random( # pylint: disable=no-member
|
||||
(100, 100)).astype(onp.float32),
|
||||
onp.random.random((100, 100)).astype(onp.bool)]
|
||||
|
||||
self.arrs = [
|
||||
rand_int(2),
|
||||
rand_int(2, 3),
|
||||
rand_int(2, 3, 4),
|
||||
rand_int(2, 3, 4, 5),
|
||||
]
|
||||
|
||||
# scalars expanded across the 0th dimension
|
||||
self.scalars = [
|
||||
rand_int(),
|
||||
rand_int(1),
|
||||
rand_int(1, 1),
|
||||
rand_int(1, 1, 1),
|
||||
]
|
||||
|
||||
# arrays of the same size expanded across the 0th dimension
|
||||
self.expanded_arrs = [
|
||||
rand_int(2, 3),
|
||||
rand_int(1, 2, 3),
|
||||
rand_int(1, 1, 2, 3),
|
||||
rand_int(1, 1, 1, 2, 3),
|
||||
]
|
||||
|
||||
# arrays with dimensions of size 1
|
||||
self.nested_arrs = [
|
||||
rand_int(1),
|
||||
rand_int(1, 2),
|
||||
rand_int(3, 1, 8),
|
||||
rand_int(1, 3, 9, 1),
|
||||
]
|
||||
|
||||
# arrays which can be broadcast
|
||||
self.broadcastables = [
|
||||
rand_int(5),
|
||||
rand_int(6, 1),
|
||||
rand_int(7, 1, 5),
|
||||
rand_int(8, 1, 6, 1)
|
||||
]
|
||||
|
||||
# boolean arrays which can be broadcast
|
||||
self.bool_broadcastables = [
|
||||
rand_bool(),
|
||||
rand_bool(1),
|
||||
rand_bool(5),
|
||||
rand_bool(6, 1),
|
||||
rand_bool(7, 1, 5),
|
||||
rand_bool(8, 1, 6, 1),
|
||||
]
|
||||
|
||||
self.mnp_prototypes = [
|
||||
mnp.ones((2, 3, 4)),
|
||||
mnp.ones((0, 3, 0, 2, 5)),
|
||||
onp.ones((2, 7, 0)),
|
||||
onp.ones(()),
|
||||
[mnp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]],
|
||||
([(1, 2), mnp.ones(2)], (onp.ones(2), [3, 4])),
|
||||
]
|
||||
|
||||
self.onp_prototypes = [
|
||||
onp.ones((2, 3, 4)),
|
||||
onp.ones((0, 3, 0, 2, 5)),
|
||||
onp.ones((2, 7, 0)),
|
||||
onp.ones(()),
|
||||
[onp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]],
|
||||
([(1, 2), onp.ones(2)], (onp.ones(2), [3, 4])),
|
||||
]
|
||||
|
||||
|
||||
def match_array(actual, expected, error=0):
|
||||
if error > 0:
|
||||
onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
|
||||
decimal=error)
|
||||
else:
|
||||
onp.testing.assert_equal(actual.tolist(), expected.tolist())
|
||||
|
||||
|
||||
def check_all_results(onp_results, mnp_results, error=0):
|
||||
"""Check all results from numpy and mindspore.numpy"""
|
||||
for i, _ in enumerate(onp_results):
|
||||
match_array(onp_results[i], mnp_results[i].asnumpy())
|
||||
|
||||
|
||||
def check_all_unique_results(onp_results, mnp_results):
|
||||
"""
|
||||
Check all results from numpy and mindspore.numpy.
|
||||
|
||||
Args:
|
||||
onp_results (Union[tuple of numpy.arrays, numpy.array])
|
||||
mnp_results (Union[tuple of Tensors, Tensor])
|
||||
"""
|
||||
for i, _ in enumerate(onp_results):
|
||||
if isinstance(onp_results[i], tuple):
|
||||
for j in range(len(onp_results[i])):
|
||||
match_array(onp_results[i][j],
|
||||
mnp_results[i][j].asnumpy(), error=7)
|
||||
else:
|
||||
match_array(onp_results[i], mnp_results[i].asnumpy(), error=7)
|
||||
|
||||
|
||||
def run_non_kw_test(mnp_fn, onp_fn):
|
||||
"""Run tests on functions with non keyword arguments"""
|
||||
test_case = Cases()
|
||||
for i in range(len(test_case.arrs)):
|
||||
arrs = test_case.arrs[:i]
|
||||
match_res(mnp_fn, onp_fn, *arrs)
|
||||
|
||||
for i in range(len(test_case.scalars)):
|
||||
arrs = test_case.scalars[:i]
|
||||
match_res(mnp_fn, onp_fn, *arrs)
|
||||
|
||||
for i in range(len(test_case.expanded_arrs)):
|
||||
arrs = test_case.expanded_arrs[:i]
|
||||
match_res(mnp_fn, onp_fn, *arrs)
|
||||
|
||||
for i in range(len(test_case.nested_arrs)):
|
||||
arrs = test_case.nested_arrs[:i]
|
||||
match_res(mnp_fn, onp_fn, *arrs)
|
||||
|
||||
|
||||
def rand_int(*shape):
|
||||
"""return an random integer array with parameter shape"""
|
||||
res = onp.random.randint(low=1, high=5, size=shape)
|
||||
if isinstance(res, onp.ndarray):
|
||||
return res.astype(onp.float32)
|
||||
return float(res)
|
||||
|
||||
|
||||
# return an random boolean array
|
||||
def rand_bool(*shape):
|
||||
return onp.random.rand(*shape) > 0.5
|
||||
|
||||
|
||||
def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
|
||||
"""Checks results from applying mnp_fn and onp_fn on arrs respectively"""
|
||||
mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs)
|
||||
mnp_res = mnp_fn(*mnp_arrs, **kwargs)
|
||||
onp_res = onp_fn(*arrs, **kwargs)
|
||||
match_all_arrays(mnp_res, onp_res)
|
||||
|
||||
|
||||
def match_all_arrays(mnp_res, onp_res, error=0):
|
||||
if isinstance(mnp_res, (tuple, list)):
|
||||
for actual, expected in zip(mnp_res, onp_res):
|
||||
match_array(actual.asnumpy(), expected, error)
|
||||
else:
|
||||
match_array(mnp_res.asnumpy(), onp_res, error)
|
||||
|
||||
|
||||
def match_meta(actual, expected):
|
||||
# float64 and int64 are not supported, and the defualt type for
|
||||
# float and int are float32 and int32, respectively
|
||||
if expected.dtype == onp.float64:
|
||||
expected = expected.astype(onp.float32)
|
||||
elif expected.dtype == onp.int64:
|
||||
expected = expected.astype(onp.int32)
|
||||
assert actual.shape == expected.shape
|
||||
assert actual.dtype == expected.dtype
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_asarray():
|
||||
test_case = Cases()
|
||||
for array in test_case.array_sets:
|
||||
# Check for dtype matching
|
||||
actual = onp.asarray(array)
|
||||
expected = mnp.asarray(array).asnumpy()
|
||||
# Since we set float32/int32 as the default dtype in mindspore, we need
|
||||
# to make a conversion between numpy.asarray and mindspore.numpy.asarray
|
||||
if actual.dtype is onp.dtype('float64'):
|
||||
assert expected.dtype == onp.dtype('float32')
|
||||
elif actual.dtype is onp.dtype('int64'):
|
||||
assert expected.dtype == onp.dtype('int32')
|
||||
else:
|
||||
assert actual.dtype == expected.dtype
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
actual = onp.asarray(array, test_case.onp_dtypes[i])
|
||||
expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
# Additional tests for nested tensor/numpy_array mixture
|
||||
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
|
||||
actual = onp.asarray(onp_input)
|
||||
expected = mnp.asarray(mnp_input).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_array():
|
||||
# array's function is very similar to asarray, so we mainly test the
|
||||
# `copy` argument.
|
||||
test_case = Cases()
|
||||
for array in test_case.array_sets:
|
||||
arr1 = mnp.asarray(array)
|
||||
arr2 = mnp.array(arr1, copy=False)
|
||||
arr3 = mnp.array(arr1)
|
||||
arr4 = mnp.asarray(array, dtype='int32')
|
||||
arr5 = mnp.asarray(arr4, dtype=mnp.int32)
|
||||
assert arr1 is arr2
|
||||
assert arr1 is not arr3
|
||||
assert arr4 is arr5
|
||||
|
||||
# Additional tests for nested tensor/numpy_array mixture
|
||||
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
|
||||
actual = onp.asarray(onp_input)
|
||||
expected = mnp.asarray(mnp_input).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_asfarray():
|
||||
test_case = Cases()
|
||||
for array in test_case.array_sets:
|
||||
# Check for dtype matching
|
||||
actual = onp.asfarray(array)
|
||||
expected = mnp.asfarray(array).asnumpy()
|
||||
# Since we set float32/int32 as the default dtype in mindspore, we need
|
||||
# to make a conversion between numpy.asarray and mindspore.numpy.asarray
|
||||
if actual.dtype is onp.dtype('float64'):
|
||||
assert expected.dtype == onp.dtype('float32')
|
||||
else:
|
||||
assert actual.dtype == expected.dtype
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
actual = onp.asfarray(array, test_case.onp_dtypes[i])
|
||||
expected = mnp.asfarray(array, test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
# Additional tests for nested tensor/numpy_array mixture
|
||||
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
|
||||
actual = onp.asarray(onp_input)
|
||||
expected = mnp.asarray(mnp_input).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_zeros():
|
||||
test_case = Cases()
|
||||
for shape in test_case.all_shapes:
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
actual = onp.zeros(shape, test_case.onp_dtypes[i])
|
||||
expected = mnp.zeros(shape, test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
actual = onp.zeros(shape)
|
||||
expected = mnp.zeros(shape).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_ones():
|
||||
test_case = Cases()
|
||||
for shape in test_case.all_shapes:
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
actual = onp.ones(shape, test_case.onp_dtypes[i])
|
||||
expected = mnp.ones(shape, test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
actual = onp.ones(shape)
|
||||
expected = mnp.ones(shape).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_full():
|
||||
actual = onp.full((2, 2), [1, 2])
|
||||
expected = mnp.full((2, 2), [1, 2]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.full((2, 0), onp.inf)
|
||||
expected = mnp.full((2, 0), mnp.inf).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.full((2, 3), True)
|
||||
expected = mnp.full((2, 3), True).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.full((3, 4, 5), 7.5)
|
||||
expected = mnp.full((3, 4, 5), 7.5).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_eye():
|
||||
test_case = Cases()
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
for m in range(1, 5):
|
||||
actual = onp.eye(m, dtype=test_case.onp_dtypes[i])
|
||||
expected = mnp.eye(m, dtype=test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
for n in range(1, 5):
|
||||
for k in range(0, 5):
|
||||
actual = onp.eye(m, n, k, dtype=test_case.onp_dtypes[i])
|
||||
expected = mnp.eye(
|
||||
m, n, k, dtype=test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_identity():
|
||||
test_case = Cases()
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
for m in range(1, 5):
|
||||
actual = onp.identity(m, dtype=test_case.onp_dtypes[i])
|
||||
expected = mnp.identity(m, dtype=test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_arange():
|
||||
actual = onp.arange(10)
|
||||
expected = mnp.arange(10).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(0, 10)
|
||||
expected = mnp.arange(0, 10).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(start=10)
|
||||
expected = mnp.arange(start=10).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(start=10, step=0.1)
|
||||
expected = mnp.arange(start=10, step=0.1).asnumpy()
|
||||
match_array(actual, expected, error=6)
|
||||
|
||||
actual = onp.arange(10, step=0.1)
|
||||
expected = mnp.arange(10, step=0.1).asnumpy()
|
||||
match_array(actual, expected, error=6)
|
||||
|
||||
actual = onp.arange(0.1, 9.9)
|
||||
expected = mnp.arange(0.1, 9.9).asnumpy()
|
||||
match_array(actual, expected, error=6)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_linspace():
|
||||
actual = onp.linspace(2.0, 3.0, dtype=onp.float32)
|
||||
expected = mnp.linspace(2.0, 3.0).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
actual = onp.linspace(2.0, 3.0, num=5, dtype=onp.float32)
|
||||
expected = mnp.linspace(2.0, 3.0, num=5).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
actual = onp.linspace(
|
||||
2.0, 3.0, num=5, endpoint=False, dtype=onp.float32)
|
||||
expected = mnp.linspace(2.0, 3.0, num=5, endpoint=False).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
actual = onp.linspace(2.0, 3.0, num=5, retstep=True, dtype=onp.float32)
|
||||
expected = mnp.linspace(2.0, 3.0, num=5, retstep=True)
|
||||
match_array(actual[0], expected[0].asnumpy())
|
||||
assert actual[1] == expected[1]
|
||||
|
||||
actual = onp.linspace(2.0, [3, 4, 5], num=5,
|
||||
endpoint=False, dtype=onp.float32)
|
||||
expected = mnp.linspace(
|
||||
2.0, [3, 4, 5], num=5, endpoint=False).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_logspace():
|
||||
actual = onp.logspace(2.0, 3.0, dtype=onp.float32)
|
||||
expected = mnp.logspace(2.0, 3.0).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.logspace(2.0, 3.0, num=5, dtype=onp.float32)
|
||||
expected = mnp.logspace(2.0, 3.0, num=5).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.logspace(
|
||||
2.0, 3.0, num=5, endpoint=False, dtype=onp.float32)
|
||||
expected = mnp.logspace(2.0, 3.0, num=5, endpoint=False).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.logspace(2.0, [3, 4, 5], num=5,
|
||||
endpoint=False, dtype=onp.float32)
|
||||
expected = mnp.logspace(
|
||||
2.0, [3, 4, 5], num=5, endpoint=False).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_empty():
|
||||
test_case = Cases()
|
||||
for shape in test_case.all_shapes:
|
||||
for mnp_dtype, onp_dtype in zip(test_case.mnp_dtypes,
|
||||
test_case.onp_dtypes):
|
||||
actual = mnp.empty(shape, mnp_dtype).asnumpy()
|
||||
expected = onp.empty(shape, onp_dtype)
|
||||
match_meta(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_empty_like():
|
||||
test_case = Cases()
|
||||
for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes):
|
||||
actual = mnp.empty_like(mnp_proto).asnumpy()
|
||||
expected = onp.empty_like(onp_proto)
|
||||
assert actual.shape == expected.shape
|
||||
|
||||
for mnp_dtype, onp_dtype in zip(test_case.mnp_dtypes,
|
||||
test_case.onp_dtypes):
|
||||
actual = mnp.empty_like(mnp_proto, dtype=mnp_dtype).asnumpy()
|
||||
expected = onp.empty_like(onp_proto, dtype=onp_dtype)
|
||||
match_meta(actual, expected)
|
||||
|
||||
|
||||
def run_x_like(mnp_fn, onp_fn):
|
||||
test_case = Cases()
|
||||
for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes):
|
||||
actual = mnp_fn(mnp_proto).asnumpy()
|
||||
expected = onp_fn(onp_proto)
|
||||
match_array(actual, expected)
|
||||
|
||||
for shape in test_case.all_shapes:
|
||||
actual = mnp_fn(mnp_proto, shape=shape).asnumpy()
|
||||
expected = onp_fn(onp_proto, shape=shape)
|
||||
match_array(actual, expected)
|
||||
|
||||
for mnp_dtype, onp_dtype in zip(test_case.mnp_dtypes,
|
||||
test_case.onp_dtypes):
|
||||
actual = mnp_fn(mnp_proto, dtype=mnp_dtype).asnumpy()
|
||||
expected = onp_fn(onp_proto, dtype=onp_dtype)
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = mnp_fn(mnp_proto, dtype=mnp_dtype,
|
||||
shape=shape).asnumpy()
|
||||
expected = onp_fn(onp_proto, dtype=onp_dtype, shape=shape)
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_ones_like():
|
||||
run_x_like(mnp.ones_like, onp.ones_like)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_zeros_like():
|
||||
run_x_like(mnp.zeros_like, onp.zeros_like)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_full_like():
|
||||
test_case = Cases()
|
||||
for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes):
|
||||
shape = onp.zeros_like(onp_proto).shape
|
||||
fill_value = rand_int()
|
||||
actual = mnp.full_like(mnp_proto, fill_value).asnumpy()
|
||||
expected = onp.full_like(onp_proto, fill_value)
|
||||
match_array(actual, expected)
|
||||
|
||||
for i in range(len(shape) - 1, 0, -1):
|
||||
fill_value = rand_int(*shape[i:])
|
||||
actual = mnp.full_like(mnp_proto, fill_value).asnumpy()
|
||||
expected = onp.full_like(onp_proto, fill_value)
|
||||
match_array(actual, expected)
|
||||
|
||||
fill_value = rand_int(1, *shape[i + 1:])
|
||||
actual = mnp.full_like(mnp_proto, fill_value).asnumpy()
|
||||
expected = onp.full_like(onp_proto, fill_value)
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_tri_triu_tril():
|
||||
x = mnp.ones((16, 32), dtype="bool")
|
||||
match_array(mnp.tril(x).asnumpy(), onp.tril(x.asnumpy()))
|
||||
match_array(mnp.tril(x, -1).asnumpy(), onp.tril(x.asnumpy(), -1))
|
||||
match_array(mnp.triu(x).asnumpy(), onp.triu(x.asnumpy()))
|
||||
match_array(mnp.triu(x, -1).asnumpy(), onp.triu(x.asnumpy(), -1))
|
||||
|
||||
x = mnp.ones((64, 64), dtype="uint8")
|
||||
match_array(mnp.tril(x).asnumpy(), onp.tril(x.asnumpy()))
|
||||
match_array(mnp.tril(x, 25).asnumpy(), onp.tril(x.asnumpy(), 25))
|
||||
match_array(mnp.triu(x).asnumpy(), onp.triu(x.asnumpy()))
|
||||
match_array(mnp.triu(x, 25).asnumpy(), onp.triu(x.asnumpy(), 25))
|
||||
|
||||
match_array(mnp.tri(64, 64).asnumpy(), onp.tri(64, 64))
|
||||
match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10))
|
||||
|
||||
|
||||
def mnp_diagonal(arr):
|
||||
return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0)
|
||||
|
||||
|
||||
def onp_diagonal(arr):
|
||||
return onp.diagonal(arr, offset=2, axis1=-1, axis2=0)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_diagonal():
|
||||
arr = rand_int(0, 0)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=1)
|
||||
|
||||
arr = rand_int(3, 5)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=1)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=1, axis2=0)
|
||||
|
||||
arr = rand_int(7, 4, 9)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=-1)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=-2, axis2=2)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr,
|
||||
offset=i, axis1=-1, axis2=-2)
|
||||
|
||||
arr = rand_int(2, 5, 8, 1)
|
||||
match_res(mnp_diagonal, onp_diagonal, arr)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=-3, axis2=2)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=1, axis2=3)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=-2)
|
||||
match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=2, axis2=-1)
|
||||
|
||||
|
||||
def mnp_trace(arr):
|
||||
return mnp.trace(arr, offset=4, axis1=1, axis2=2)
|
||||
|
||||
|
||||
def onp_trace(arr):
|
||||
return onp.trace(arr, offset=4, axis1=1, axis2=2)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_trace():
|
||||
arr = rand_int(0, 0)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=1)
|
||||
|
||||
arr = rand_int(3, 5)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=1)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=1, axis2=0)
|
||||
|
||||
arr = rand_int(7, 4, 9)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=-1)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-2, axis2=2)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-1, axis2=-2)
|
||||
|
||||
arr = rand_int(2, 5, 8, 1)
|
||||
match_res(mnp_trace, onp_trace, arr)
|
||||
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-3, axis2=2)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=1, axis2=3)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=-2)
|
||||
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=2, axis2=-1)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_asarray_exception():
|
||||
with pytest.raises(TypeError):
|
||||
mnp.asarray({1, 2, 3})
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_linspace_exception():
|
||||
with pytest.raises(TypeError):
|
||||
mnp.linspace(0, 1, num=2.5)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_empty_like_exception():
|
||||
with pytest.raises(ValueError):
|
||||
mnp.empty_like([[1, 2, 3], [4, 5]])
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,567 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""unit tests for numpy math operations"""
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import numpy as onp
|
||||
|
||||
import mindspore.numpy as mnp
|
||||
from mindspore import context
|
||||
|
||||
|
||||
def rand_int(*shape):
|
||||
"""return an random integer array with parameter shape"""
|
||||
res = onp.random.randint(low=1, high=5, size=shape)
|
||||
if isinstance(res, onp.ndarray):
|
||||
return res.astype(onp.float32)
|
||||
return float(res)
|
||||
|
||||
|
||||
# return an random boolean array
|
||||
def rand_bool(*shape):
|
||||
return onp.random.rand(*shape) > 0.5
|
||||
|
||||
|
||||
class Cases():
|
||||
def __init__(self):
|
||||
self.device_cpu = context.get_context('device_target')
|
||||
|
||||
self.arrs = [
|
||||
rand_int(2),
|
||||
rand_int(2, 3),
|
||||
rand_int(2, 3, 4),
|
||||
rand_int(2, 3, 4, 5),
|
||||
]
|
||||
|
||||
# scalars expanded across the 0th dimension
|
||||
self.scalars = [
|
||||
rand_int(),
|
||||
rand_int(1),
|
||||
rand_int(1, 1),
|
||||
rand_int(1, 1, 1, 1),
|
||||
]
|
||||
|
||||
# empty arrays
|
||||
self.empty_arrs = [
|
||||
rand_int(0),
|
||||
rand_int(4, 0),
|
||||
rand_int(2, 0, 2),
|
||||
rand_int(5, 0, 7, 0),
|
||||
]
|
||||
|
||||
# arrays of the same size expanded across the 0th dimension
|
||||
self.expanded_arrs = [
|
||||
rand_int(2, 3),
|
||||
rand_int(1, 2, 3),
|
||||
rand_int(1, 1, 2, 3),
|
||||
rand_int(1, 1, 1, 2, 3),
|
||||
]
|
||||
|
||||
# arrays of the same size expanded across the 0th dimension
|
||||
self.expanded_arrs = [
|
||||
rand_int(2, 3),
|
||||
rand_int(1, 2, 3),
|
||||
rand_int(1, 1, 2, 3),
|
||||
rand_int(1, 1, 1, 2, 3),
|
||||
]
|
||||
|
||||
# arrays with last dimension aligned
|
||||
self.aligned_arrs = [
|
||||
rand_int(2, 3),
|
||||
rand_int(1, 4, 3),
|
||||
rand_int(5, 1, 2, 3),
|
||||
rand_int(4, 2, 1, 1, 3),
|
||||
]
|
||||
|
||||
# arrays which can be broadcast
|
||||
self.broadcastables = [
|
||||
rand_int(5),
|
||||
rand_int(6, 1),
|
||||
rand_int(7, 1, 5),
|
||||
rand_int(8, 1, 6, 1)
|
||||
]
|
||||
|
||||
# boolean arrays which can be broadcast
|
||||
self.bool_broadcastables = [
|
||||
rand_bool(),
|
||||
rand_bool(1),
|
||||
rand_bool(5),
|
||||
rand_bool(6, 1),
|
||||
rand_bool(7, 1, 5),
|
||||
rand_bool(8, 1, 6, 1),
|
||||
]
|
||||
|
||||
# core dimension 0 is matched for each
|
||||
# pair of array[i] and array[i + 1]
|
||||
self.core_broadcastables = [
|
||||
rand_int(3),
|
||||
rand_int(3),
|
||||
rand_int(6),
|
||||
rand_int(6, 4),
|
||||
rand_int(5, 2),
|
||||
rand_int(2),
|
||||
rand_int(2, 9),
|
||||
rand_int(9, 8),
|
||||
rand_int(6),
|
||||
rand_int(2, 6, 5),
|
||||
rand_int(9, 2, 7),
|
||||
rand_int(7),
|
||||
rand_int(5, 2, 4),
|
||||
rand_int(6, 1, 4, 9),
|
||||
rand_int(7, 1, 5, 3, 2),
|
||||
rand_int(8, 1, 6, 1, 2, 9),
|
||||
]
|
||||
|
||||
# arrays with dimensions of size 1
|
||||
self.nested_arrs = [
|
||||
rand_int(1),
|
||||
rand_int(1, 2),
|
||||
rand_int(3, 1, 8),
|
||||
rand_int(1, 3, 9, 1),
|
||||
]
|
||||
|
||||
|
||||
test_case = Cases()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
def mnp_add(x1, x2):
|
||||
return mnp.add(x1, x2)
|
||||
|
||||
|
||||
def onp_add(x1, x2):
|
||||
return onp.add(x1, x2)
|
||||
|
||||
|
||||
def mnp_subtract(x1, x2):
|
||||
return mnp.subtract(x1, x2)
|
||||
|
||||
|
||||
def onp_subtract(x1, x2):
|
||||
return onp.subtract(x1, x2)
|
||||
|
||||
|
||||
def mnp_mutiply(x1, x2):
|
||||
return mnp.multiply(x1, x2)
|
||||
|
||||
|
||||
def onp_multiply(x1, x2):
|
||||
return onp.multiply(x1, x2)
|
||||
|
||||
|
||||
def mnp_divide(x1, x2):
|
||||
return mnp.divide(x1, x2)
|
||||
|
||||
|
||||
def onp_divide(x1, x2):
|
||||
return onp.divide(x1, x2)
|
||||
|
||||
|
||||
def mnp_power(x1, x2):
|
||||
return mnp.power(x1, x2)
|
||||
|
||||
|
||||
def onp_power(x1, x2):
|
||||
return onp.power(x1, x2)
|
||||
|
||||
|
||||
def mnp_inner(a, b):
|
||||
return mnp.inner(a, b)
|
||||
|
||||
|
||||
def onp_inner(a, b):
|
||||
return onp.inner(a, b)
|
||||
|
||||
|
||||
def mnp_dot(a, b):
|
||||
return mnp.dot(a, b)
|
||||
|
||||
|
||||
def onp_dot(a, b):
|
||||
return onp.dot(a, b)
|
||||
|
||||
|
||||
def mnp_outer(a, b):
|
||||
return mnp.outer(a, b)
|
||||
|
||||
|
||||
def onp_outer(a, b):
|
||||
return onp.outer(a, b)
|
||||
|
||||
|
||||
def mnp_add_kwargs(x, y, where=None, out=None):
|
||||
return mnp.add(x, y, where=where, out=out)
|
||||
|
||||
|
||||
def onp_add_kwargs(x, y, where=None, out=None):
|
||||
return onp.add(x, y, where=where, out=out)
|
||||
|
||||
|
||||
def mnp_subtract_kwargs(x, y, where=None, out=None):
|
||||
return mnp.subtract(x, y, where=where, out=out)
|
||||
|
||||
|
||||
def onp_subtract_kwargs(x, y, where=None, out=None):
|
||||
return onp.subtract(x, y, where=where, out=out)
|
||||
|
||||
|
||||
def mnp_multiply_kwargs(x, y, where=None, out=None):
|
||||
return mnp.multiply(x, y, where=where, out=out)
|
||||
|
||||
|
||||
def onp_multiply_kwargs(x, y, where=None, out=None):
|
||||
return onp.multiply(x, y, where=where, out=out)
|
||||
|
||||
|
||||
def mnp_divide_kwargs(x, y, where=None, out=None):
|
||||
return mnp.divide(x, y, where=where, out=out)
|
||||
|
||||
|
||||
def onp_divide_kwargs(x, y, where=None, out=None):
|
||||
return onp.divide(x, y, where=where, out=out)
|
||||
|
||||
|
||||
def mnp_power_kwargs(x, y, where=None, out=None):
|
||||
return mnp.power(x, y, where=where, out=out)
|
||||
|
||||
|
||||
def onp_power_kwargs(x, y, where=None, out=None):
|
||||
return onp.power(x, y, where=where, out=out)
|
||||
|
||||
|
||||
def mnp_tensordot(x, y):
|
||||
a = mnp.tensordot(x, y)
|
||||
b = mnp.tensordot(x, y, axes=0)
|
||||
c = mnp.tensordot(x, y, axes=1)
|
||||
d = mnp.tensordot(x, y, axes=2)
|
||||
e = mnp.tensordot(x, y, axes=(3, 0))
|
||||
f = mnp.tensordot(x, y, axes=[2, 1])
|
||||
g = mnp.tensordot(x, y, axes=((2, 3), (0, 1)))
|
||||
h = mnp.tensordot(x, y, axes=[[3, 2], [1, 0]])
|
||||
return a, b, c, d, e, f, g, h
|
||||
|
||||
|
||||
def onp_tensordot(x, y):
|
||||
a = onp.tensordot(x, y)
|
||||
b = onp.tensordot(x, y, axes=0)
|
||||
c = onp.tensordot(x, y, axes=1)
|
||||
d = onp.tensordot(x, y, axes=2)
|
||||
e = onp.tensordot(x, y, axes=(3, 0))
|
||||
f = onp.tensordot(x, y, axes=[2, 1])
|
||||
g = onp.tensordot(x, y, axes=((2, 3), (0, 1)))
|
||||
h = onp.tensordot(x, y, axes=[[3, 2], [1, 0]])
|
||||
return a, b, c, d, e, f, g, h
|
||||
|
||||
|
||||
def run_binop_test(mnp_fn, onp_fn):
|
||||
for arr in test_case.arrs:
|
||||
match_res(mnp_fn, onp_fn, arr, arr)
|
||||
|
||||
for scalar in test_case.scalars:
|
||||
match_res(mnp_fn, onp_fn, arr, scalar)
|
||||
match_res(mnp_fn, onp_fn, scalar, arr)
|
||||
|
||||
for scalar1 in test_case.scalars:
|
||||
for scalar2 in test_case.scalars:
|
||||
match_res(mnp_fn, onp_fn, scalar1, scalar2)
|
||||
|
||||
for expanded_arr1 in test_case.expanded_arrs:
|
||||
for expanded_arr2 in test_case.expanded_arrs:
|
||||
match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2)
|
||||
|
||||
for broadcastable1 in test_case.broadcastables:
|
||||
for broadcastable2 in test_case.broadcastables:
|
||||
match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2)
|
||||
|
||||
|
||||
def run_multi_test(mnp_fn, onp_fn, arrs):
|
||||
mnp_arrs = map(mnp.asarray, arrs)
|
||||
for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)):
|
||||
match_array(actual.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_add():
|
||||
run_binop_test(mnp_add, onp_add)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_subtract():
|
||||
run_binop_test(mnp_subtract, onp_subtract)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_multiply():
|
||||
run_binop_test(mnp_mutiply, onp_multiply)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_divide():
|
||||
run_binop_test(mnp_divide, onp_divide)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_power():
|
||||
run_binop_test(mnp_power, onp_power)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_inner():
|
||||
for arr1 in test_case.aligned_arrs:
|
||||
for arr2 in test_case.aligned_arrs:
|
||||
match_res(mnp_inner, onp_inner, arr1, arr2)
|
||||
|
||||
for scalar1 in test_case.scalars:
|
||||
for scalar2 in test_case.scalars:
|
||||
match_res(mnp_inner, onp_inner,
|
||||
scalar1, scalar2)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_dot():
|
||||
# test case (1D, 1D)
|
||||
match_res(mnp_dot, onp_dot, rand_int(3), rand_int(3))
|
||||
|
||||
# test case (2D, 2D)
|
||||
match_res(mnp_dot, onp_dot, rand_int(4, 7), rand_int(7, 2))
|
||||
|
||||
# test case (0D, _) (_, 0D)
|
||||
match_res(mnp_dot, onp_dot, rand_int(), rand_int(1, 9, 3))
|
||||
match_res(mnp_dot, onp_dot, rand_int(8, 5, 6, 3), rand_int())
|
||||
|
||||
# test case (ND, 1D)
|
||||
match_res(mnp_dot, onp_dot, rand_int(2, 4, 5), rand_int(5))
|
||||
|
||||
# test case (ND, MD)
|
||||
match_res(mnp_dot, onp_dot, rand_int(5, 4, 1, 8), rand_int(8, 3))
|
||||
|
||||
for i in range(8):
|
||||
match_res(mnp_dot, onp_dot,
|
||||
test_case.core_broadcastables[2*i], test_case.core_broadcastables[2*i + 1])
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_outer():
|
||||
run_binop_test(mnp_outer, onp_outer)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_add_kwargs():
|
||||
for where in test_case.bool_broadcastables[:2]:
|
||||
for x in test_case.broadcastables[:2]:
|
||||
for y in test_case.broadcastables[:2]:
|
||||
shape_out = onp.broadcast(where, x, y).shape
|
||||
out = rand_int(*shape_out)
|
||||
match_res(mnp_add_kwargs, onp_add_kwargs, x, y, where, out)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_tensordot():
|
||||
x = rand_int(4, 2, 7, 7)
|
||||
y = rand_int(7, 7, 6)
|
||||
run_multi_test(mnp_tensordot, onp_tensordot, (x, y))
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_type_promotion():
|
||||
arr = rand_int(2, 3)
|
||||
onp_sum = onp_add(arr, arr)
|
||||
|
||||
a = mnp.asarray(arr, dtype='float16')
|
||||
b = mnp.asarray(arr, dtype='float32')
|
||||
c = mnp.asarray(arr, dtype='int32')
|
||||
|
||||
match_array(mnp_add(a, b).asnumpy(), onp_sum)
|
||||
match_array(mnp_add(b, c).asnumpy(), onp_sum)
|
||||
|
||||
|
||||
def mnp_absolute(x):
|
||||
return mnp.absolute(x)
|
||||
|
||||
|
||||
def onp_absolute(x):
|
||||
return onp.absolute(x)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_absolute():
|
||||
arr = rand_int(2, 3)
|
||||
|
||||
a = mnp.asarray(arr, dtype='float16')
|
||||
b = mnp.asarray(arr, dtype='float32')
|
||||
c = mnp.asarray(arr, dtype='uint8')
|
||||
d = mnp.asarray(arr, dtype='bool')
|
||||
|
||||
match_array(mnp_absolute(a).asnumpy(), onp_absolute(a.asnumpy()))
|
||||
match_array(mnp_absolute(b).asnumpy(), onp_absolute(b.asnumpy()))
|
||||
match_array(mnp_absolute(c).asnumpy(), onp_absolute(c.asnumpy()))
|
||||
match_array(mnp_absolute(d).asnumpy(), onp_absolute(d.asnumpy()))
|
||||
|
||||
where = rand_int(2, 3).astype('bool')
|
||||
out = rand_int(2, 3)
|
||||
match_array(mnp.absolute(a, out=mnp.asarray(out), where=mnp.asarray(where)).asnumpy(),
|
||||
onp.absolute(a.asnumpy(), out=out, where=where))
|
||||
|
||||
|
||||
def mnp_add_dtype(x1, x2, out, where):
|
||||
a = mnp.add(x1, x2, dtype=mnp.float16)
|
||||
b = mnp.add(x1, x2, out=out, dtype=mnp.float16)
|
||||
c = mnp.add(x1, x2, where=where, dtype=mnp.float16)
|
||||
d = mnp.add(x1, x2, out=out, where=where, dtype=mnp.float16)
|
||||
return a, b, c, d
|
||||
|
||||
|
||||
def onp_add_dtype(x1, x2, out, where):
|
||||
a = onp.add(x1, x2, dtype=onp.float16)
|
||||
b = onp.add(x1, x2, out=out, dtype=onp.float16)
|
||||
c = onp.add(x1, x2, where=where, dtype=onp.float16)
|
||||
d = onp.add(x1, x2, out=out, where=where, dtype=onp.float16)
|
||||
return a, b, c, d
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_add_dtype():
|
||||
x1 = rand_int(2, 3).astype('int32')
|
||||
x2 = rand_int(2, 3).astype('int32')
|
||||
out = rand_int(2, 3).astype('float32')
|
||||
where = rand_bool(2, 3)
|
||||
arrs = (x1, x2, out, where)
|
||||
mnp_arrs = map(mnp.array, arrs)
|
||||
mnp_res = mnp_add_dtype(*mnp_arrs)
|
||||
onp_res = onp_add_dtype(*arrs)
|
||||
for actual, expected in zip(mnp_res, onp_res):
|
||||
assert actual.asnumpy().dtype == expected.dtype
|
||||
|
||||
|
||||
# check if the output from mnp function and onp function applied on the arrays are matched
|
||||
|
||||
|
||||
def match_res(mnp_fn, onp_fn, *arrs):
|
||||
mnp_arrs = map(partial(mnp.asarray, dtype='float32'), arrs)
|
||||
mnp_res = mnp_fn(*mnp_arrs)
|
||||
onp_res = onp_fn(*arrs)
|
||||
if isinstance(mnp_res, (tuple, list)):
|
||||
for actual, expected in zip(mnp_res, onp_res):
|
||||
match_array(actual.asnumpy(), expected)
|
||||
else:
|
||||
match_array(mnp_res.asnumpy(), onp_res)
|
||||
|
||||
|
||||
def match_array(actual, expected, error=5):
|
||||
if error > 0:
|
||||
onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
|
||||
decimal=error)
|
||||
else:
|
||||
onp.testing.assert_equal(actual.tolist(), expected.tolist())
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_exception_innner():
|
||||
with pytest.raises(ValueError):
|
||||
mnp.inner(mnp.asarray(test_case.arrs[0]),
|
||||
mnp.asarray(test_case.arrs[1]))
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_exception_add():
|
||||
with pytest.raises(ValueError):
|
||||
mnp.add(mnp.asarray(test_case.arrs[1]), mnp.asarray(test_case.arrs[2]))
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_exception_mean():
|
||||
with pytest.raises(ValueError):
|
||||
mnp.mean(mnp.asarray(test_case.arrs[0]), (-1, 0))
|
|
@ -69,6 +69,24 @@ def test_tensor_size():
|
|||
assert arr.size == b.size
|
||||
|
||||
|
||||
def test_tensor_itemsize():
|
||||
arr = np.ones((1, 2, 3))
|
||||
b = ms.Tensor(arr)
|
||||
assert arr.itemsize == b.itemsize
|
||||
|
||||
|
||||
def test_tensor_strides():
|
||||
arr = np.ones((3, 4, 5, 6))
|
||||
b = ms.Tensor(arr)
|
||||
assert arr.strides == b.strides
|
||||
|
||||
|
||||
def test_tensor_nbytes():
|
||||
arr = np.ones((3, 4, 5, 6))
|
||||
b = ms.Tensor(arr)
|
||||
assert arr.nbytes == b.nbytes
|
||||
|
||||
|
||||
def test_dtype():
|
||||
a = ms.Tensor(np.ones((2, 3), dtype=np.int32))
|
||||
assert a.dtype == ms.int32
|
||||
|
|
|
@ -1,591 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""unit tests for numpy array operations"""
|
||||
|
||||
import functools
|
||||
|
||||
import pytest
|
||||
import numpy as onp
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.numpy as mnp
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from ..ut_filter import non_graph_engine
|
||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
||||
|
||||
|
||||
class Cases():
|
||||
def __init__(self):
|
||||
self.all_shapes = [
|
||||
0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3]
|
||||
]
|
||||
self.onp_dtypes = [onp.int32, 'int32', int,
|
||||
onp.float32, 'float32', float,
|
||||
onp.uint32, 'uint32',
|
||||
onp.bool_, 'bool', bool]
|
||||
|
||||
self.mnp_dtypes = [mnp.int32, 'int32', int,
|
||||
mnp.float32, 'float32', float,
|
||||
mnp.uint32, 'uint32',
|
||||
mnp.bool_, 'bool', bool]
|
||||
|
||||
self.array_sets = [1, 1.1, True, [1, 0, True], [1, 1.0, 2], (1,),
|
||||
[(1, 2, 3), (4, 5, 6)], onp.random.random(
|
||||
(100, 100)).astype(onp.float32),
|
||||
onp.random.random((100, 100)).astype(onp.bool)]
|
||||
|
||||
|
||||
def match_array(actual, expected, error=0):
|
||||
if error > 0:
|
||||
onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
|
||||
decimal=error)
|
||||
else:
|
||||
onp.testing.assert_equal(actual.tolist(), expected.tolist())
|
||||
|
||||
|
||||
def check_all_results(onp_results, mnp_results):
|
||||
"""Check all results from numpy and mindspore.numpy"""
|
||||
for i, _ in enumerate(onp_results):
|
||||
match_array(onp_results[i], mnp_results[i].asnumpy())
|
||||
|
||||
|
||||
def test_asarray():
|
||||
test_case = Cases()
|
||||
for array in test_case.array_sets:
|
||||
# Check for dtype matching
|
||||
actual = onp.asarray(array)
|
||||
expected = mnp.asarray(array).asnumpy()
|
||||
# Since we set float32/int32 as the default dtype in mindspore, we need
|
||||
# to make a conversion between numpy.asarray and mindspore.numpy.asarray
|
||||
if actual.dtype is onp.dtype('float64'):
|
||||
assert expected.dtype == onp.dtype('float32')
|
||||
elif actual.dtype is onp.dtype('int64'):
|
||||
assert expected.dtype == onp.dtype('int32')
|
||||
else:
|
||||
assert actual.dtype == expected.dtype
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
actual = onp.asarray(array, test_case.onp_dtypes[i])
|
||||
expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
# Additional tests for nested tensor/numpy_array mixture
|
||||
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
|
||||
actual = onp.asarray(onp_input)
|
||||
expected = mnp.asarray(mnp_input).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
|
||||
def test_array():
|
||||
# array's function is very similar to asarray, so we mainly test the
|
||||
# `copy` argument.
|
||||
test_case = Cases()
|
||||
for array in test_case.array_sets:
|
||||
arr1 = mnp.asarray(array)
|
||||
arr2 = mnp.array(arr1, copy=False)
|
||||
arr3 = mnp.array(arr1)
|
||||
arr4 = mnp.asarray(array, dtype='int32')
|
||||
arr5 = mnp.asarray(arr4, dtype=mnp.int32)
|
||||
assert arr1 is arr2
|
||||
assert arr1 is not arr3
|
||||
assert arr4 is arr5
|
||||
|
||||
# Additional tests for nested tensor/numpy_array mixture
|
||||
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
|
||||
actual = onp.asarray(onp_input)
|
||||
expected = mnp.asarray(mnp_input).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
|
||||
def test_asfarray():
|
||||
test_case = Cases()
|
||||
for array in test_case.array_sets:
|
||||
# Check for dtype matching
|
||||
actual = onp.asfarray(array)
|
||||
expected = mnp.asfarray(array).asnumpy()
|
||||
# Since we set float32/int32 as the default dtype in mindspore, we need
|
||||
# to make a conversion between numpy.asarray and mindspore.numpy.asarray
|
||||
if actual.dtype is onp.dtype('float64'):
|
||||
assert expected.dtype == onp.dtype('float32')
|
||||
else:
|
||||
assert actual.dtype == expected.dtype
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
actual = onp.asfarray(array, test_case.onp_dtypes[i])
|
||||
expected = mnp.asfarray(array, test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
# Additional tests for nested tensor/numpy_array mixture
|
||||
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
|
||||
|
||||
actual = onp.asarray(onp_input)
|
||||
expected = mnp.asarray(mnp_input).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
|
||||
def test_zeros():
|
||||
test_case = Cases()
|
||||
for shape in test_case.all_shapes:
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
actual = onp.zeros(shape, test_case.onp_dtypes[i])
|
||||
expected = mnp.zeros(shape, test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
actual = onp.zeros(shape)
|
||||
expected = mnp.zeros(shape).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
def test_ones():
|
||||
test_case = Cases()
|
||||
for shape in test_case.all_shapes:
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
actual = onp.ones(shape, test_case.onp_dtypes[i])
|
||||
expected = mnp.ones(shape, test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
actual = onp.ones(shape)
|
||||
expected = mnp.ones(shape).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
def test_full():
|
||||
actual = onp.full((2, 2), [1, 2])
|
||||
expected = mnp.full((2, 2), [1, 2]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.full((2, 0), onp.inf)
|
||||
expected = mnp.full((2, 0), mnp.inf).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.full((2, 3), True)
|
||||
expected = mnp.full((2, 3), True).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.full((3, 4, 5), 7.5)
|
||||
expected = mnp.full((3, 4, 5), 7.5).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
def test_eye():
|
||||
test_case = Cases()
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
for m in range(1, 5):
|
||||
actual = onp.eye(m, dtype=test_case.onp_dtypes[i])
|
||||
expected = mnp.eye(m, dtype=test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
for n in range(1, 5):
|
||||
for k in range(0, 5):
|
||||
actual = onp.eye(m, n, k, dtype=test_case.onp_dtypes[i])
|
||||
expected = mnp.eye(
|
||||
m, n, k, dtype=test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
def test_identity():
|
||||
test_case = Cases()
|
||||
for i in range(len(test_case.onp_dtypes)):
|
||||
for m in range(1, 5):
|
||||
actual = onp.identity(m, dtype=test_case.onp_dtypes[i])
|
||||
expected = mnp.identity(m, dtype=test_case.mnp_dtypes[i]).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
def test_arange():
|
||||
actual = onp.arange(10)
|
||||
expected = mnp.arange(10).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(0, 10)
|
||||
expected = mnp.arange(0, 10).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(start=10)
|
||||
expected = mnp.arange(start=10).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.arange(start=10, step=0.1)
|
||||
expected = mnp.arange(start=10, step=0.1).asnumpy()
|
||||
match_array(actual, expected, error=6)
|
||||
|
||||
actual = onp.arange(10, step=0.1)
|
||||
expected = mnp.arange(10, step=0.1).asnumpy()
|
||||
match_array(actual, expected, error=6)
|
||||
|
||||
actual = onp.arange(0.1, 9.9)
|
||||
expected = mnp.arange(0.1, 9.9).asnumpy()
|
||||
match_array(actual, expected, error=6)
|
||||
|
||||
|
||||
def test_linspace():
|
||||
actual = onp.linspace(2.0, 3.0, dtype=onp.float32)
|
||||
expected = mnp.linspace(2.0, 3.0).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
actual = onp.linspace(2.0, 3.0, num=5, dtype=onp.float32)
|
||||
expected = mnp.linspace(2.0, 3.0, num=5).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
actual = onp.linspace(
|
||||
2.0, 3.0, num=5, endpoint=False, dtype=onp.float32)
|
||||
expected = mnp.linspace(2.0, 3.0, num=5, endpoint=False).asnumpy()
|
||||
match_array(actual, expected, error=7)
|
||||
|
||||
actual = onp.linspace(2.0, 3.0, num=5, retstep=True, dtype=onp.float32)
|
||||
expected = mnp.linspace(2.0, 3.0, num=5, retstep=True)
|
||||
match_array(actual[0], expected[0].asnumpy())
|
||||
assert actual[1] == expected[1]
|
||||
|
||||
actual = onp.linspace(2.0, [3, 4, 5], num=5,
|
||||
endpoint=False, dtype=onp.float32)
|
||||
expected = mnp.linspace(
|
||||
2.0, [3, 4, 5], num=5, endpoint=False).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
def test_logspace():
|
||||
actual = onp.logspace(2.0, 3.0, dtype=onp.float32)
|
||||
expected = mnp.logspace(2.0, 3.0).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.logspace(2.0, 3.0, num=5, dtype=onp.float32)
|
||||
expected = mnp.logspace(2.0, 3.0, num=5).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.logspace(
|
||||
2.0, 3.0, num=5, endpoint=False, dtype=onp.float32)
|
||||
expected = mnp.logspace(2.0, 3.0, num=5, endpoint=False).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
actual = onp.logspace(2.0, [3, 4, 5], num=5,
|
||||
endpoint=False, dtype=onp.float32)
|
||||
expected = mnp.logspace(
|
||||
2.0, [3, 4, 5], num=5, endpoint=False).asnumpy()
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
# Test np.transpose and np.ndarray.transpose
|
||||
|
||||
|
||||
def mnp_transpose(input_tensor):
|
||||
a = mnp.transpose(input_tensor, (0, 2, 1))
|
||||
b = mnp.transpose(input_tensor, [2, 1, 0])
|
||||
c = mnp.transpose(input_tensor, (1, 0, 2))
|
||||
d = mnp.transpose(input_tensor)
|
||||
return a, b, c, d
|
||||
|
||||
|
||||
def onp_transpose(input_array):
|
||||
a = onp.transpose(input_array, (0, 2, 1))
|
||||
b = onp.transpose(input_array, [2, 1, 0])
|
||||
c = onp.transpose(input_array, (1, 0, 2))
|
||||
d = onp.transpose(input_array)
|
||||
return a, b, c, d
|
||||
|
||||
# Test np.expand_dims
|
||||
|
||||
|
||||
def mnp_expand_dims(input_tensor):
|
||||
a = mnp.expand_dims(input_tensor, 0)
|
||||
b = mnp.expand_dims(input_tensor, -1)
|
||||
c = mnp.expand_dims(input_tensor, axis=2)
|
||||
d = mnp.expand_dims(input_tensor, axis=-2)
|
||||
return a, b, c, d
|
||||
|
||||
|
||||
def onp_expand_dims(input_array):
|
||||
a = onp.expand_dims(input_array, 0)
|
||||
b = onp.expand_dims(input_array, -1)
|
||||
c = onp.expand_dims(input_array, axis=2)
|
||||
d = onp.expand_dims(input_array, axis=-2)
|
||||
return a, b, c, d
|
||||
|
||||
# Test np.squeeze
|
||||
|
||||
|
||||
def mnp_squeeze(input_tensor):
|
||||
a = mnp.squeeze(input_tensor)
|
||||
b = mnp.squeeze(input_tensor, 0)
|
||||
c = mnp.squeeze(input_tensor, axis=None)
|
||||
d = mnp.squeeze(input_tensor, axis=-3)
|
||||
e = mnp.squeeze(input_tensor, (2,))
|
||||
f = mnp.squeeze(input_tensor, (0, 2))
|
||||
return a, b, c, d, e, f
|
||||
|
||||
|
||||
def onp_squeeze(input_array):
|
||||
a = onp.squeeze(input_array)
|
||||
b = onp.squeeze(input_array, 0)
|
||||
c = onp.squeeze(input_array, axis=None)
|
||||
d = onp.squeeze(input_array, axis=-3)
|
||||
e = onp.squeeze(input_array, (2,))
|
||||
f = onp.squeeze(input_array, (0, 2))
|
||||
return a, b, c, d, e, f
|
||||
|
||||
# Test np.rollaxis
|
||||
|
||||
|
||||
def mnp_rollaxis(input_tensor):
|
||||
a = mnp.rollaxis(input_tensor, 0, 1)
|
||||
b = mnp.rollaxis(input_tensor, 0, 2)
|
||||
c = mnp.rollaxis(input_tensor, 2, 1)
|
||||
d = mnp.rollaxis(input_tensor, 2, 2)
|
||||
e = mnp.rollaxis(input_tensor, 0)
|
||||
f = mnp.rollaxis(input_tensor, 1)
|
||||
return a, b, c, d, e, f
|
||||
|
||||
|
||||
def onp_rollaxis(input_array):
|
||||
a = onp.rollaxis(input_array, 0, 1)
|
||||
b = onp.rollaxis(input_array, 0, 2)
|
||||
c = onp.rollaxis(input_array, 2, 1)
|
||||
d = onp.rollaxis(input_array, 2, 2)
|
||||
e = onp.rollaxis(input_array, 0)
|
||||
f = onp.rollaxis(input_array, 1)
|
||||
return a, b, c, d, e, f
|
||||
|
||||
# Test np.swapaxes
|
||||
|
||||
|
||||
def mnp_swapaxes(input_tensor):
|
||||
a = mnp.swapaxes(input_tensor, 0, 1)
|
||||
b = mnp.swapaxes(input_tensor, 1, 0)
|
||||
c = mnp.swapaxes(input_tensor, 1, 1)
|
||||
d = mnp.swapaxes(input_tensor, 2, 1)
|
||||
e = mnp.swapaxes(input_tensor, 1, 2)
|
||||
f = mnp.swapaxes(input_tensor, 2, 2)
|
||||
return a, b, c, d, e, f
|
||||
|
||||
|
||||
def onp_swapaxes(input_array):
|
||||
a = onp.swapaxes(input_array, 0, 1)
|
||||
b = onp.swapaxes(input_array, 1, 0)
|
||||
c = onp.swapaxes(input_array, 1, 1)
|
||||
d = onp.swapaxes(input_array, 2, 1)
|
||||
e = onp.swapaxes(input_array, 1, 2)
|
||||
f = onp.swapaxes(input_array, 2, 2)
|
||||
return a, b, c, d, e, f
|
||||
|
||||
# Test np.reshape
|
||||
|
||||
|
||||
def mnp_reshape(input_tensor):
|
||||
a = mnp.reshape(input_tensor, (3, 8))
|
||||
b = mnp.reshape(input_tensor, [3, -1])
|
||||
c = mnp.reshape(input_tensor, (-1, 12))
|
||||
d = mnp.reshape(input_tensor, (-1,))
|
||||
e = mnp.reshape(input_tensor, 24)
|
||||
f = mnp.reshape(input_tensor, [2, 4, -1])
|
||||
return a, b, c, d, e, f
|
||||
|
||||
|
||||
def onp_reshape(input_array):
|
||||
a = onp.reshape(input_array, (3, 8))
|
||||
b = onp.reshape(input_array, [3, -1])
|
||||
c = onp.reshape(input_array, (-1, 12))
|
||||
d = onp.reshape(input_array, (-1,))
|
||||
e = onp.reshape(input_array, 24)
|
||||
f = onp.reshape(input_array, [2, 4, -1])
|
||||
return a, b, c, d, e, f
|
||||
|
||||
# Test np.ravel
|
||||
|
||||
|
||||
def mnp_ravel(input_tensor):
|
||||
a = mnp.ravel(input_tensor)
|
||||
return a
|
||||
|
||||
|
||||
def onp_ravel(input_array):
|
||||
a = onp.ravel(input_array)
|
||||
return a
|
||||
|
||||
# Test np.concatenate
|
||||
|
||||
|
||||
def mnp_concatenate(input_tensor):
|
||||
a = mnp.concatenate(input_tensor, None)
|
||||
b = mnp.concatenate(input_tensor, 0)
|
||||
c = mnp.concatenate(input_tensor, 1)
|
||||
d = mnp.concatenate(input_tensor, 2)
|
||||
return a, b, c, d
|
||||
|
||||
|
||||
def onp_concatenate(input_array):
|
||||
a = onp.concatenate(input_array, None)
|
||||
b = onp.concatenate(input_array, 0)
|
||||
c = onp.concatenate(input_array, 1)
|
||||
d = onp.concatenate(input_array, 2)
|
||||
return a, b, c, d
|
||||
|
||||
|
||||
def test_transpose():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
o_transposed = onp_transpose(onp_array)
|
||||
m_transposed = mnp_transpose(mnp_array)
|
||||
check_all_results(o_transposed, m_transposed)
|
||||
|
||||
|
||||
def test_expand_dims():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
o_expanded = onp_expand_dims(onp_array)
|
||||
m_expanded = mnp_expand_dims(mnp_array)
|
||||
check_all_results(o_expanded, m_expanded)
|
||||
|
||||
|
||||
def test_squeeze():
|
||||
onp_array = onp.random.random((1, 3, 1, 4, 2)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
o_squeezed = onp_squeeze(onp_array)
|
||||
m_squeezed = mnp_squeeze(mnp_array)
|
||||
check_all_results(o_squeezed, m_squeezed)
|
||||
|
||||
onp_array = onp.random.random((1, 1, 1, 1, 1)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
o_squeezed = onp_squeeze(onp_array)
|
||||
m_squeezed = mnp_squeeze(mnp_array)
|
||||
check_all_results(o_squeezed, m_squeezed)
|
||||
|
||||
|
||||
def test_rollaxis():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
o_rolled = onp_rollaxis(onp_array)
|
||||
m_rolled = mnp_rollaxis(mnp_array)
|
||||
check_all_results(o_rolled, m_rolled)
|
||||
|
||||
|
||||
def test_swapaxes():
|
||||
onp_array = onp.random.random((3, 4, 5)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
o_swaped = onp_swapaxes(onp_array)
|
||||
m_swaped = mnp_swapaxes(mnp_array)
|
||||
check_all_results(o_swaped, m_swaped)
|
||||
|
||||
|
||||
def test_reshape():
|
||||
onp_array = onp.random.random((2, 3, 4)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
o_reshaped = onp_reshape(onp_array)
|
||||
m_reshaped = mnp_reshape(mnp_array)
|
||||
check_all_results(o_reshaped, m_reshaped)
|
||||
|
||||
|
||||
def test_ravel():
|
||||
onp_array = onp.random.random((2, 3, 4)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
o_ravel = onp_ravel(onp_array)
|
||||
m_ravel = mnp_ravel(mnp_array).asnumpy()
|
||||
match_array(o_ravel, m_ravel)
|
||||
|
||||
|
||||
def test_concatenate():
|
||||
onp_array = onp.random.random((5, 4, 3, 2)).astype('float32')
|
||||
mnp_array = mnp.asarray(onp_array)
|
||||
o_concatenate = onp_concatenate(onp_array)
|
||||
m_concatenate = mnp_concatenate(mnp_array)
|
||||
check_all_results(o_concatenate, m_concatenate)
|
||||
|
||||
|
||||
class ReshapeExpandSqueeze(Cell):
|
||||
def __init__(self):
|
||||
super(ReshapeExpandSqueeze, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
x = mnp.expand_dims(x, 2)
|
||||
x = mnp.reshape(x, (1, 2, 3, 4, 1, 1))
|
||||
x = mnp.squeeze(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransposeConcatRavel(Cell):
|
||||
def __init__(self):
|
||||
super(TransposeConcatRavel, self).__init__()
|
||||
|
||||
def construct(self, x1, x2, x3):
|
||||
x1 = mnp.transpose(x1, [0, 2, 1])
|
||||
x2 = x2.transpose(0, 2, 1)
|
||||
x = mnp.concatenate((x1, x2, x3), -1)
|
||||
x = mnp.ravel(x)
|
||||
return x
|
||||
|
||||
|
||||
class RollSwap(Cell):
|
||||
def __init__(self):
|
||||
super(RollSwap, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
x = mnp.rollaxis(x, 2)
|
||||
x = mnp.swapaxes(x, 0, 1)
|
||||
return x
|
||||
|
||||
|
||||
test_case_array_ops = [
|
||||
('ReshapeExpandSqueeze', {
|
||||
'block': ReshapeExpandSqueeze(),
|
||||
'desc_inputs': [mnp.ones((2, 3, 4))]}),
|
||||
|
||||
('TransposeConcatRavel', {
|
||||
'block': TransposeConcatRavel(),
|
||||
'desc_inputs': [mnp.ones((2, 3, 4)),
|
||||
mnp.ones((2, 3, 4)),
|
||||
mnp.ones((2, 4, 1))]}),
|
||||
|
||||
('RollSwap', {
|
||||
'block': RollSwap(),
|
||||
'desc_inputs': [mnp.ones((2, 3, 4))]})
|
||||
]
|
||||
|
||||
test_case_lists = [test_case_array_ops]
|
||||
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
|
||||
# use -k to select certain testcast
|
||||
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
|
||||
def test_exec():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
return test_exec_case
|
||||
|
||||
|
||||
def test_expand_dims_exception():
|
||||
with pytest.raises(TypeError):
|
||||
mnp.expand_dims(mnp.ones((3, 3)), 1.2)
|
||||
|
||||
|
||||
def test_asarray_exception():
|
||||
with pytest.raises(TypeError):
|
||||
mnp.asarray({1, 2, 3})
|
||||
|
||||
|
||||
def test_swapaxes_exception():
|
||||
with pytest.raises(ValueError):
|
||||
mnp.swapaxes(mnp.ones((3, 3)), 1, 10)
|
||||
|
||||
|
||||
def test_linspace_exception():
|
||||
with pytest.raises(TypeError):
|
||||
mnp.linspace(0, 1, num=2.5)
|
|
@ -1,102 +0,0 @@
|
|||
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""unit tests for numpy math operations"""
|
||||
|
||||
import pytest
|
||||
import numpy as onp
|
||||
|
||||
import mindspore.numpy as mnp
|
||||
|
||||
|
||||
def rand_int(*shape):
|
||||
"""return an random integer array with parameter shape"""
|
||||
res = onp.random.randint(low=1, high=5, size=shape)
|
||||
if isinstance(res, onp.ndarray):
|
||||
res = res.astype(onp.float32)
|
||||
return res
|
||||
|
||||
|
||||
class Cases():
|
||||
def __init__(self):
|
||||
|
||||
self.arrs = [
|
||||
rand_int(2),
|
||||
rand_int(2, 3),
|
||||
rand_int(2, 3, 4),
|
||||
rand_int(2, 3, 4, 5),
|
||||
]
|
||||
|
||||
# scalars expanded across the 0th dimension
|
||||
self.scalars = [
|
||||
rand_int(),
|
||||
rand_int(1),
|
||||
rand_int(1, 1),
|
||||
rand_int(1, 1, 1),
|
||||
]
|
||||
|
||||
# arrays with last dimension aligned
|
||||
self.aligned_arrs = [
|
||||
rand_int(2, 3),
|
||||
rand_int(1, 4, 3),
|
||||
rand_int(5, 1, 2, 3),
|
||||
rand_int(4, 2, 1, 1, 3),
|
||||
]
|
||||
|
||||
|
||||
test_case = Cases()
|
||||
|
||||
|
||||
def mnp_inner(a, b):
|
||||
return mnp.inner(a, b)
|
||||
|
||||
|
||||
def onp_inner(a, b):
|
||||
return onp.inner(a, b)
|
||||
|
||||
|
||||
def test_inner():
|
||||
for arr1 in test_case.aligned_arrs:
|
||||
for arr2 in test_case.aligned_arrs:
|
||||
match_res(mnp_inner, onp_inner, arr1, arr2)
|
||||
|
||||
for scalar1 in test_case.scalars:
|
||||
for scalar2 in test_case.scalars:
|
||||
match_res(mnp_inner, onp_inner,
|
||||
scalar1, scalar2)
|
||||
|
||||
|
||||
# check if the output from mnp function and onp function applied on the arrays are matched
|
||||
|
||||
|
||||
def match_res(mnp_fn, onp_fn, arr1, arr2):
|
||||
actual = mnp_fn(mnp.asarray(arr1, dtype='float32'),
|
||||
mnp.asarray(arr2, dtype='float32')).asnumpy()
|
||||
expected = onp_fn(arr1, arr2)
|
||||
match_array(actual, expected)
|
||||
|
||||
|
||||
def match_array(actual, expected, error=5):
|
||||
if error > 0:
|
||||
onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
|
||||
decimal=error)
|
||||
else:
|
||||
onp.testing.assert_equal(actual.tolist(), expected.tolist())
|
||||
|
||||
|
||||
def test_exception_innner():
|
||||
with pytest.raises(ValueError):
|
||||
mnp.inner(mnp.asarray(test_case.arrs[0]),
|
||||
mnp.asarray(test_case.arrs[1]))
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test astype"""
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_astype():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.astype("float16")
|
||||
|
||||
net = Net()
|
||||
res = net()
|
||||
assert res.dtype == mstype.float16
|
||||
|
||||
|
||||
def test_astype_1():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.int64)
|
||||
|
||||
def construct(self):
|
||||
return self.value.astype(mstype.bool_)
|
||||
|
||||
net = Net()
|
||||
res = net()
|
||||
assert res.dtype == mstype.bool_
|
||||
|
||||
|
||||
def test_astype_2():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float64)
|
||||
|
||||
def construct(self):
|
||||
return self.value.astype(mstype.uint64)
|
||||
|
||||
net = Net()
|
||||
res = net()
|
||||
assert res.dtype == mstype.uint64
|
||||
|
||||
|
||||
def test_astype_error_1():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.astype("float88")
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError):
|
||||
net()
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test flatten"""
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_flatten():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.flatten()
|
||||
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_flatten_1():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.flatten(order='F')
|
||||
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_flatten_error():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.flatten(order='X')
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
net()
|
||||
|
||||
|
||||
def test_flatten_error_1():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.flatten(order=123)
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError):
|
||||
net()
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test tensor properties in graph mode"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_ndim():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor(np.random.random(
|
||||
(2, 3, 4, 5)), dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.ndim
|
||||
|
||||
net = Net()
|
||||
res = net()
|
||||
assert res == 4
|
||||
|
||||
|
||||
def test_nbytes():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor(np.random.random(
|
||||
(2, 3, 4, 5)), dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.nbytes
|
||||
|
||||
net = Net()
|
||||
res = net()
|
||||
assert res == 480
|
||||
|
||||
|
||||
def test_size():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor(np.random.random(
|
||||
(2, 3, 4, 5)), dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.size
|
||||
|
||||
net = Net()
|
||||
res = net()
|
||||
assert res == 120
|
||||
|
||||
|
||||
def test_strides():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor(np.random.random(
|
||||
(2, 3, 4, 5)), dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.strides
|
||||
|
||||
net = Net()
|
||||
res = net()
|
||||
assert res == (240, 80, 20, 4)
|
||||
|
||||
|
||||
def test_itemsize():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value1 = Tensor(np.random.random(
|
||||
(2, 3, 4, 5)), dtype=mstype.float64)
|
||||
self.value2 = Tensor(np.random.random(
|
||||
(2, 3, 4, 5)), dtype=mstype.int32)
|
||||
self.value3 = Tensor(np.random.random(
|
||||
(2, 3, 4, 5)), dtype=mstype.bool_)
|
||||
|
||||
def construct(self):
|
||||
return (self.value1.itemsize, self.value2.itemsize, self.value3.itemsize)
|
||||
|
||||
net = Net()
|
||||
res = net()
|
||||
assert res == (8, 4, 1)
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test reshape"""
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_reshape():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.reshape(-1)
|
||||
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_reshape_1():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.reshape([3, 2, 1])
|
||||
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_reshape_2():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.reshape((-1, 2))
|
||||
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_reshape_error():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.reshape(1, 2, 4)
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
net()
|
||||
|
||||
|
||||
def test_reshape_error_1():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.reshape((1, 2, 3.5))
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError):
|
||||
net()
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test transpose"""
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_transpose():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.transpose()
|
||||
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_transpose_1():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.transpose(1, 0)
|
||||
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_transpose_2():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.transpose([1, 0])
|
||||
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_transpose_3():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.transpose((1, 0))
|
||||
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_transpose_error():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.transpose(0, 2, 1)
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(ValueError):
|
||||
net()
|
||||
|
||||
|
||||
def test_transpose_error_1():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
|
||||
|
||||
def construct(self):
|
||||
return self.value.transpose(1.0, 0)
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError):
|
||||
net()
|
Loading…
Reference in New Issue