diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 1bf9c5bd632..f247be9c887 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -595,6 +595,123 @@ class Validator: raise ValueError(f'For {prim_name}, {ori_shape} reduce on {axis} should be ' f'{tuple(exp_shape)}, but got {shape}.') + @staticmethod + def check_astype_dtype(dtype): + """Check whether dtype is a valid input, and convert to mstype""" + all_types = mstype.__dtype__ + ["int", "float", "bool"] + if isinstance(dtype, str): + if dtype.lower() not in all_types: + raise TypeError(f"`{dtype}` not understood.") + dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower())) + elif isinstance(dtype, type): + dtype = mstype.pytype_to_dtype(dtype) + elif not dtype in mstype.number_type + (mstype.bool_,): + raise TypeError(f"`{dtype}` not understood.") + return dtype + + @staticmethod + def check_transpose_axis(axes, ndim): + """Check the axis argument for tensor.transpose""" + if not axes or (len(axes) == 1 and axes[0] is None): + return tuple(range(ndim-1, -1, -1)) + + if len(axes) == 1: + perm = axes[0] + # if only one argument provided, it must be tuple or list + if isinstance(perm, list): + perm = tuple(perm) + else: + if not isinstance(perm, tuple): + raise TypeError(f"The `axes` should be a tuple/list, or series of int, but got {type(axes[0])}") + return perm + + # if multiple arguments provided, it must be `ndim` number of ints + if len(axes) != ndim: + raise ValueError("The number of axes must equal to the dimension of tensor.") + return axes + + @staticmethod + def check_reshape_shp(shp): + """Check the shape argument for tensor.reshape""" + + if len(shp) == 1: + new_shape = shp[0] + # if only one argument provided, it must be int, tuple or list + if isinstance(new_shape, int): + return shp + if isinstance(new_shape, list): + new_shape = tuple(new_shape) + else: + if not isinstance(new_shape, tuple): + raise TypeError( + f"The `shape` should be an int, or tuple/list, or series of int, but got {type(shp[0])}") + return new_shape + + return shp + + @staticmethod + def check_flatten_order(order): + """Check flatten function input order""" + if not isinstance(order, str): + raise TypeError(f"The order variable should be a string, but got {type(order)}") + if order not in ('C', 'F'): + raise ValueError(f"only `C` and `F` are supported as order, but got {order}") + return order + + @staticmethod + def check_swapaxes_axis(axes, ndim): + """Check all the axes argument for tensor.swapaxes""" + if isinstance(axes, int): + check_axis_in_range(axes, ndim) + return axes % ndim + if isinstance(axes, (tuple, list)): + for axis in axes: + if not isinstance(axis, int): + raise TypeError(f"axis argument should be integer, but got {type(axis)}.") + check_axis_in_range(axis, ndim) + axes = tuple(map(lambda x: x % ndim, axes)) + return axes + raise TypeError(f"axes should be integer, list or tuple for check, but got {type(axes)}.") + + @staticmethod + def prepare_shape_for_squeeze(shape, axes): + """ + Creates the squeezed new shape based on the tensor and given axes. + + Args: + shape (tuple): the shape of the tensor + axes Union[int, tuple(int), list(int)]: the axes with dimensions need to + be squeezed. + + Returns: + new_shape(tuple): the shape with dimensions squeezed. + """ + new_shape = [] + ndim = len(shape) + + # Convert to set + if isinstance(axes, int): + if axes >= ndim or axes < -ndim: + raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {ndim}") + axes = {axes} + + elif isinstance(axes, (list, tuple)): + for axis in axes: + if axis >= ndim or axis < -ndim: + raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {ndim}") + axes = set(axes) + + else: + raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}") + + for idx, s in enumerate(shape): + if s != 1 or (idx not in axes) and (idx - ndim not in axes): + new_shape.append(s) + # if an axis is selected with shape entry greater than one, an error is raised. + if s != 1 and ((idx in axes) or (idx - ndim in axes)): + raise ValueError(f"axis {axes} has shape entry {s} > 1, cannot be squeezed.") + return tuple(new_shape) + def check_input_format(input_param): """Judge input format.""" @@ -623,6 +740,13 @@ def _expand_tuple(n_dimensions): return convert +def check_axis_in_range(axis, ndim): + """Checks axes are with the bounds of ndim""" + if -ndim <= axis < ndim: + return True + raise ValueError(f'axis {axis} is out of bounds for tensor of dimension {ndim}') + + def _check_data_type_valid(data, valid_type): """Check data type valid.""" if valid_type is None: diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index e30777b7f7b..000d411d18b 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -1,6 +1,6 @@ # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). # -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +15,13 @@ # limitations under the License. # ============================================================================ """standard_method""" + from dataclasses import dataclass from mindspore import Tensor from mindspore import dtype as mstype + +from ..._checkparam import Validator as validator from ...ops import functional as F from ...ops import operations as P from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \ @@ -28,14 +31,17 @@ from ...ops.primitive import constexpr __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like'] -trans = P.Transpose() shape_ = P.Shape() -reshape_ = P.Reshape() dtype_ = P.DType() abs_ = P.Abs() ndim_ = P.Rank() size_ = P.Size() +itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1, + mstype.float16: 2, mstype.int16: 2, mstype.uint16: 2, + mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4, + mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8} + def mean(x, axis=(), keep_dims=False): """ Reduces a dimension of a tensor by averaging all elements in the dimension. @@ -93,23 +99,150 @@ def any_(x, axis=(), keep_dims=False): return reduce_any(x, axis) +def itemsize_(x): + """ + Return length of one tensor element in bytes. + + Args: + x (Tensor): Input tensor. + + Returns: + itemsize(int). + """ + return get_itemsize(x.dtype) + + +def nbytes_(x): + """ + Return total number of bytes taken by the tensor. + + Args: + x (Tensor): Input tensor. + + Returns: + nbytes(int). + """ + return itemsize_(x) * F.shape_mul(shape_(x)) + + +def strides_(x): + """ + Return the tuple of bytes to step in each dimension when traversing a tensor. + + Args: + x (Tensor): Input tensor. + + Returns: + strides (tuple[int]). + """ + strides = () + ndim = P.Rank()(x) + tensor_shape = shape_(x) + for i in F.make_range(0, ndim): + stride = itemsize_(x) + for j in F.make_range(i + 1, ndim): + stride *= tensor_shape[j] + strides += (stride,) + return strides + + +def astype(x, dtype, copy=True): + """Implementation of `astype`.""" + dtype = check_astype_dtype_const(dtype) + if not copy and dtype == x.dtype: + return x + return F.cast(x, dtype) + + def transpose(x, *axis): """Implementation of `transpose`.""" - new_order = None + ndim = F.rank(x) + perm = check_transpose_axis_const(axis, ndim) + return F.transpose(x, perm) + + +# `tensor.T` is used as a property in graph mode +T_ = transpose + + +def reshape(x, *shape): + """Implementation of `reshape`.""" + new_shape = check_reshape_shp_const(shape) + return F.reshape(x, new_shape) + + +def ravel(x): + """Implementation of `ravel`.""" + return reshape(x, (-1,)) + + +def flatten(x, order='C'): + """ + Returns a copy of the array collapsed into one dimension. + + Args: + order (str, optional): Can choose between `C` and `F`. `C` means to + flatten in row-major (C-style) order. ‘F’ means to flatten in column-major + (Fortran- style) order. Only `C` and `F` are supported. + + Returns: + Tensor, has the same data type as x. + """ + order = check_flatten_order_const(order) + if order == 'C': + return F.reshape(x, (-1,)) + + perm = F.make_range(0, F.rank(x)) + new_order = F.tuple_reversed(perm) + return F.reshape(F.transpose(x, new_order), (-1,)) + + +def swapaxes(x, axis1, axis2): + """ + Interchanges two axes of a tensor. + + Args: + axis1 (int): First axis. + axis2 (int): Second axis. + + Returns: + Transposed tensor, has the same data type as the original tensor x. + """ + axis1, axis2 = check_swapaxes_axis_const((axis1, axis2), x.ndim) + + if axis1 == axis2: + return x + if axis1 > axis2: + axis1, axis2 = axis2, axis1 + + perm = F.make_range(0, x.ndim) + new_perm = None + if axis2 + 1 < x.ndim: + new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ + perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:] + else: + new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ + perm[axis1+1:axis2] + perm[axis1:axis1+1] + + return F.transpose(x, new_perm) + + +def squeeze(x, axis=None): + """ + Removes single-dimensional entries from the shape of an tensor. + + Args: + axis: Union[None, int, list(int), tuple(list)]. Default is None. + + Returns: + Tensor, with all or a subset of the dimensions of length 1 removed. + """ shape = F.shape(x) - length = F.tuple_len(shape) - if not axis: - perm = F.make_range(0, length) - new_order = F.tuple_reversed(perm) - - elif len(axis) == 1: - new_order = convert_list_to_tuple(axis[0]) - - elif len(axis) == length: - new_order = axis - - out = trans(x, new_order) - return out + if axis is None: + return F.squeeze(x) + # yield squeezed shape based on the axes + new_shape = prepare_shape_for_squeeze_const(shape, axis) + return F.reshape(x, new_shape) def getitem(data, item): @@ -200,7 +333,7 @@ def expand_tensor_as(x, y): def view(x, *shape): """Reshape tensor, if shape is -1, reshape tensor into one dimension""" shape = check_view_shape(shape) - return reshape_(x, shape) + return F.reshape(x, shape) def isinstance_(x, base_type): @@ -240,6 +373,12 @@ def check_type_same(x_type, base_type): raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'") +@constexpr +def get_itemsize(x_type): + """get itemsize from tensor's dtype.""" + return itemsize_map[x_type] + + @constexpr def check_is_tensor(x): """check whether x is tensor.""" @@ -298,14 +437,14 @@ def check_view_shape(x): x = x[0] return x -@constexpr -def convert_list_to_tuple(shp): - """Check the type of the shape, if is list, convert to tuple""" - if not isinstance(shp, (list, tuple)): - raise ValueError(f"The shape variable should be a list or tuple, but got {type(shp)}") - if isinstance(shp, list): - shp = tuple(shp) - return shp + +# convert noraml param_check functions to constexpr functions +check_astype_dtype_const = constexpr(validator.check_astype_dtype) +check_transpose_axis_const = constexpr(validator.check_transpose_axis) +check_reshape_shp_const = constexpr(validator.check_reshape_shp) +check_flatten_order_const = constexpr(validator.check_flatten_order) +check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis) +prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze) def tensor_bool(x): """tensor as conditon, if is constant, return immediate bool value""" diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index 8ad5b228b4f..9d39f62f013 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -178,6 +178,12 @@ BuiltInTypeMap &GetMethodMap() { {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, {"transpose", std::string("transpose")}, // P.transpose + {"flatten", std::string("flatten")}, // P.reshape(,-1) + {"reshape", std::string("reshape")}, // P.reshape() + {"ravel", std::string("ravel")}, // P.reshape(,(-1,)) + {"swapaxes", std::string("swapaxes")}, // P.transpose() + {"squeeze", std::string("squeeze")}, // P.squeeze() + {"astype", std::string("astype")}, // P.cast() {"__bool__", std::string("tensor_bool")}, // C.tensor_bool }}, {kObjectTypeJTagged, {}}, @@ -190,10 +196,14 @@ BuiltInTypeMap &GetAttrMap() { static BuiltInTypeMap attr_map = { {kObjectTypeTensorType, { - {"shape", std::string("shape_")}, // C.shape_ - {"dtype", std::string("dtype_")}, // C.dtype_ - {"size", std::string("size_")}, // C.size_ - {"ndim", std::string("ndim_")}, // C.ndim_ + {"shape", std::string("shape_")}, // C.shape_ + {"dtype", std::string("dtype_")}, // C.dtype_ + {"size", std::string("size_")}, // C.size_ + {"ndim", std::string("ndim_")}, // C.ndim_ + {"T", std::string("T_")}, // C.T_ + {"itemsize", std::string("itemsize_")}, // C.itemsize_ + {"nbytes", std::string("nbytes_")}, // C.nbytes_ + {"strides", std::string("strides_")}, // C.strides_ }}, {kObjectTypeRowTensorType, { diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc index c8a4a2a9492..3dd5be92cdd 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc @@ -258,6 +258,20 @@ py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) { return dims; } +py::tuple TensorPy::GetPyTupleStrides(const Tensor &tensor) { + std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end()); + std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize()); + py::tuple py_strides(strides.size()); + for (size_t i = 0; i < strides.size(); ++i) { + py_strides[i] = py::int_(strides[i]); + } + return py_strides; +} + +py::int_ TensorPy::GetPyItemSize(const Tensor &tensor) { return tensor.data().itemsize(); } + +py::int_ TensorPy::GetPyNBytes(const Tensor &tensor) { return tensor.data().nbytes(); } + py::array TensorPy::SyncAsNumpy(const Tensor &tensor) { { py::gil_scoped_release gil_release; @@ -383,6 +397,40 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { >>> data.size 6 )mydelimiter") + .def_property_readonly("_itemsize", TensorPy::GetPyItemSize, R"mydelimiter( + Get the tensor's length of one element in bytes. + + Returns: + itemsize, length of one element in bytes. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) + >>> data.itemsize + 4 + )mydelimiter") + .def_property_readonly("_nbytes", TensorPy::GetPyNBytes, R"mydelimiter( + Get the tensor's total number of bytes. + + Returns: + nbytes, total number of bytes taken by the tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) + >>> data.nbytes + 4 + )mydelimiter") + .def_property_readonly("_strides", TensorPy::GetPyTupleStrides, R"mydelimiter( + Get the tensor's tuple of bytes to step in each dimension + when traversing an array. + + Returns: + tuple[int], the strides of the tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) + >>> data.strides + (4, 4) + )mydelimiter") .def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter( Creates a Tensor from a numpy.ndarray without copy. diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.h b/mindspore/ccsrc/pybind_api/ir/tensor_py.h index a091edc3922..647ea282bbd 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.h +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.h @@ -109,6 +109,12 @@ class TensorPy { static py::array AsNumpy(const Tensor &tensor); static py::tuple GetPyTupleShape(const Tensor &tensor); + + static py::tuple GetPyTupleStrides(const Tensor &tensor); + + static py::int_ GetPyItemSize(const Tensor &tensor); + + static py::int_ GetPyNBytes(const Tensor &tensor); }; } // namespace tensor } // namespace mindspore diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 2add3237ce4..5e8cbbc857c 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -267,6 +267,26 @@ class Tensor(Tensor_): """tensor is inited.""" return self.init is not None + @property + def itemsize(self): + """The length of one tensor element in bytes.""" + return self._itemsize + + @property + def strides(self): + """The tuple of bytes to step in each dimension when traversing a tensor.""" + return self._strides + + @property + def nbytes(self): + """The total number of bytes taken by the tensor.""" + return self._nbytes + + @property + def T(self): + """The transposed tensor.""" + return self.transpose() + @property def virtual_flag(self): """Mark tensor is virtual.""" @@ -384,6 +404,145 @@ class Tensor(Tensor_): axis = () return tensor_operator_registry.get('mean')(keep_dims)(self, axis) + def transpose(self, *axes): + """ + Returns a view of the array with axes transposed. + + For a 1-D array this has no effect, as a transposed vector is simply the + same vector. For a 2-D array, this is a standard matrix transpose. For an + n-D array, if axes are given, their order indicates how the axes are permuted + (see Examples). If axes are not provided and a.shape = (i[0], i[1],... + i[n-2], i[n-1]), then a.transpose().shape = (i[n-1], i[n-2], ... i[1], i[0]). + + Args: + axes(Union[None, tuple(int), list(int), n ints], optional): + None or no argument: reverses the order of the axes. + Tuple of ints: i in the j-th place in the tuple means a’s i-th + axis becomes a.transpose()’s j-th axis. + n ints: this form is intended simply as a `convenience alternative + to the tuple form. + + Returns: + Tensor, has the same dimension as input tensor, with axes suitably permuted. + """ + perm = validator.check_transpose_axis(axes, self.ndim) + return tensor_operator_registry.get('transpose')()(self, perm) + + def reshape(self, *shape): + """ + Gives a new shape to an array without changing its data. + + Args: + shape(Union[int, tuple(int), list(int)]): The new shape should be compatible + with the original shape. If an integer, then the result will be a 1-D + array of that length. One shape dimension can be -1. In this case, the + value is inferred from the length of the array and remaining dimensions. + + Returns: + reshaped_tensor(Tensor): This will be a new view object if possible; + otherwise, it will be a copy. + """ + new_shape = validator.check_reshape_shp(shape) + return tensor_operator_registry.get('reshape')()(self, new_shape) + + def ravel(self): + """ + Returns a contiguous flattened tensor. + A 1-D tensor, containing the elements of the input, is returned. + + Returns: + Tensor, has the same data type as x. + """ + reshape_op = tensor_operator_registry.get('reshape')() + return reshape_op(self, (-1,)) + + def flatten(self, order='C'): + """ + Returns a copy of the array collapsed into one dimension. + + Args: + order (str, optional): Can choose between `C` and `F`. `C` means to + flatten in row-major (C-style) order. ‘F’ means to flatten in column-major + (Fortran- style) order. Only `C` and `F` are supported. + + Returns: + Tensor, has the same data type as x. + """ + reshape_op = tensor_operator_registry.get('reshape')() + trans_op = tensor_operator_registry.get('transpose')() + + order = validator.check_flatten_order(order) + if order == 'C': + return reshape_op(self, (-1,)) + + perm = tuple(range(self.ndim-1, -1, -1)) + return reshape_op(trans_op(self, perm), (-1,)) + + def swapaxes(self, axis1, axis2): + """ + Interchanges two axes of a tensor. + + Args: + axis1 (int): First axis. + axis2 (int): Second axis. + + Returns: + Transposed tensor, has the same data type as the original tensor x. + """ + axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim) + + if axis1 == axis2: + return self + if axis1 > axis2: + axis1, axis2 = axis2, axis1 + + perm = tuple(range(0, self.ndim)) + new_perm = None + if axis2 + 1 < self.ndim: + new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ + perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:] + else: + new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ + perm[axis1+1:axis2] + perm[axis1:axis1+1] + + return tensor_operator_registry.get('transpose')()(self, new_perm) + + def squeeze(self, axis=None): + """ + Removes single-dimensional entries from the shape of an tensor. + + Args: + axis: Union[None, int, list(int), tuple(list)]. Default is None. + + Returns: + Tensor, with all or a subset of the dimensions of length 1 removed. + """ + if axis is None: + return tensor_operator_registry.get('squeeze')(self) + new_shape = validator.prepare_shape_for_squeeze(self.shape, axis) + return tensor_operator_registry.get('reshape')()(self, new_shape) + + def astype(self, dtype, copy=True): + """ + Returns a copy of the array, cast to a specified type. + + Args: + dtype(Union[mstype.dtype, numpy.dtype, str]): Designated tensor dtype, + can be in format of np.float32, mstype.float32 or `float32`. Default + is mstype.float32. + + copy(bool, optional): By default, astype always returns a newly allocated + tensor. If this is set to false, the input tensor is returned instead + of a copy if possible. + + Returns: + Tensor, with the designated dtype. + """ + dtype = validator.check_astype_dtype(dtype) + if not copy and dtype == self.dtype: + return self + return tensor_operator_registry.get('cast')(self, dtype) + def init_data(self, slice_index=None, shape=None, opt_shard_group=None): """ diff --git a/mindspore/numpy/__init__.py b/mindspore/numpy/__init__.py index 95bae77a727..add6582f530 100644 --- a/mindspore/numpy/__init__.py +++ b/mindspore/numpy/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,22 +25,33 @@ Note: - random/ defines all the random operations. """ -from .array_ops import (array, asarray, asfarray, ones, zeros, full, arange, - linspace, logspace, eye, identity, transpose, expand_dims, - squeeze, rollaxis, swapaxes, reshape, ravel, concatenate) -from .array_ops import copy_ as copy +from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, reshape, + ravel, concatenate, where, atleast_1d, atleast_2d, atleast_3d, + column_stack, hstack, dstack, vstack, stack, unique) +from .array_creations import copy_ as copy +from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange, + linspace, logspace, eye, identity, empty, empty_like, + ones_like, zeros_like, full_like, diagonal, tril, triu, + tri, trace) from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float_, float16, float32, float64, bool_, inf, numeric_types) -from .math_ops import mean, inner +from .math_ops import (mean, inner, add, subtract, multiply, divide, power, + dot, outer, tensordot, absolute) -array_ops_module = ['array', 'asarray', 'asfarray', 'copy', 'ones', 'zeros', 'arange', - 'linspace', 'logspace', 'eye', 'identity', 'transpose', 'expand_dims', - 'squeeze', 'rollaxis', 'swapaxes', 'reshape', 'ravel', 'concatenate'] +array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape', + 'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d', + 'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique'] -math_module = ['mean', 'inner'] +array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange', + 'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like', + 'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu', + 'tri', 'trace'] -__all__ = array_ops_module + math_module + numeric_types +math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'power', + 'dot', 'outer', 'tensordot', 'absolute'] + +__all__ = array_ops_module + array_creations_module + math_module + numeric_types __all__.sort() diff --git a/mindspore/numpy/array_creations.py b/mindspore/numpy/array_creations.py new file mode 100644 index 00000000000..0fe486dbd4b --- /dev/null +++ b/mindspore/numpy/array_creations.py @@ -0,0 +1,1196 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""array operations, the function docs are adapted from Numpy API.""" +from copy import copy as py_copy +from itertools import groupby + +import numpy as onp + +from ..common import Tensor +from ..common import dtype as mstype +from ..ops import functional as F +from ..ops.primitive import constexpr +from ..nn.layer.basic import tril as nn_tril +from ..nn.layer.basic import triu as nn_triu +from .._c_expression import Tensor as Tensor_ +from .._c_expression.typing import Float + +from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \ + _expand, _broadcast_to, _is_empty +from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, _check_same_type, \ + _check_shape_contain_zero, _check_shape, _check_dtype +from .array_ops import transpose + +# According to official numpy reference, the dimension of a numpy array must be less +# than 32 +MAX_NUMPY_DIMS = 32 + + +def array(obj, dtype=None, copy=True, ndmin=0): + """ + Creates a tensor. + + This function creates tensors from an array-like object. + + Args: + obj (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in + any form that can be converted to a tensor. This includes lists, lists of + tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.int32, or `int32`. If dtype is None, the data type + of the new tensor will be inferred from obj. Default is None. + copy (bool): If true, then the object is copied. Otherwise, a copy will + only be made if necessary. Default: True. + ndmin (int): Specifies the minimum number of dimensions that the resulting + tensor should have. Ones will be pre-pended to the shape as needed to + meet this requirement. Default: 0 + + Returns: + Tensor, generated tensor with the specified dtype. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If input `obj` has different sizes at different dimensions. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.array([1,2,3])) + [1 2 3] + """ + if ndmin > 0: + # Fall back to original numpy creation. + if isinstance(obj, Tensor): + obj = obj.asnumpy() + return asarray(onp.array(obj, dtype, copy=copy, ndmin=ndmin)) + + if not copy: + return asarray(obj, dtype=dtype) + + obj = py_copy(obj) + return asarray(obj, dtype=dtype) + + +def asarray(a, dtype=None): + """ + Converts the input to tensor. + + This function converts tensors from an array-like object. + + Args: + a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in + any form that can be converted to a tensor. This includes lists, lists of + tuples, tuples, tuples of tuples, tuples of lists and ndarrays. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.int32, or `int32`. If dtype is None, the data type + of the new tensor will be inferred from a. Default is None. + + Returns: + Tensor, generated tensor with the specified dtype. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If input `a` has different sizes at different dimensions. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.asarray([1,2,3])) + [1 2 3] + """ + + if dtype is not None: + dtype = _check_dtype(dtype) + + _ = _check_input_for_asarray(a) + + if isinstance(a, float) and (dtype is None): + dtype = mstype.float32 + + if isinstance(a, int) and not isinstance(a, bool) and (dtype is None): + dtype = mstype.int32 + + if isinstance(a, bool) and (dtype is None): + dtype = mstype.bool_ + + if isinstance(a, (list, tuple)): + # Convert all tuple/nested tuples to lists + a = _deep_list(a) + # Convert all tensor sub-elements to numpy arrays + a = _deep_tensor_to_nparray(a) + a = onp.asarray(a) + if a.dtype is onp.dtype('object'): + raise ValueError('Input array must have the same size across all dimensions.') + # If dtype is not specified, we keep consistent with numpy decision + # only exceptions are: we use int/float32 + if dtype is None: + if a.dtype is onp.dtype('int64'): + dtype = mstype.int32 + elif a.dtype is onp.dtype('float64'): + dtype = mstype.float32 + + if isinstance(a, onp.ndarray) and dtype is None: + if a.dtype is onp.dtype('bool'): + dtype = mstype.bool_ + elif a.dtype is onp.dtype('int'): + dtype = mstype.int32 + elif a.dtype is onp.dtype('float'): + dtype = mstype.float32 + elif a.dtype is onp.dtype('object'): + raise TypeError(f"For Tensor convertion, the input_data is {a} that contains unsupported element.") + a = Tensor.from_numpy(a) + + # If a is already a tensor and we don't need to cast dtype, return a + if isinstance(a, Tensor): + if dtype is None: + return a + dtype = _check_dtype(dtype) + if dtype == a.dtype: + return a + + return Tensor(a, dtype=dtype) + + +def asfarray(a, dtype=mstype.float32): + """ + Similar to asarray, converts the input to a float tensor. + + If non-float dtype is defined, this function will return a float32 tensor instead. + + Args: + a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in + any form that can be converted to a tensor. This includes lists, lists of + tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.float32, or `float32`. Default is mstype.float32. + + Returns: + Tensor, generated tensor with the specified float dtype. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If input `a` has different sizes at different dimensions. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.asfarray([1,2,3])) + [1. 2. 3.] + """ + dtype = _check_dtype(dtype) + _ = _check_input_for_asarray(a) + + if dtype not in (mstype.float16, mstype.float32, mstype.float64): + dtype = mstype.float32 + + if isinstance(a, (list, tuple)): + # Convert all tuple/nested tuples to lists + a = _deep_list(a) + # Convert all tensor sub-elements to numpy arrays + a = _deep_tensor_to_nparray(a) + a = onp.asarray(a) + if a.dtype is onp.dtype('object'): + raise TypeError(f"For Tensor convertion, the input_data is {a} that contains unsupported element.") + if isinstance(a, onp.ndarray): + a = Tensor.from_numpy(a) + + return Tensor(a, dtype) + + +def copy_(a): + """ + Returns a tensor copy of the given object. + + Args: + a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in + any form that can be converted to a tensor. This includes lists, lists of + tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. + + Returns: + Tensor, has the same data as `a`. + + Raises: + TypeError: If input `a` has type not specified above. + ValueError: If input `a` has different sizes at different dimensions. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.ones((2,2)) + >>> print(np.copy(x)) + [[1. 1.] + [1. 1.]] + """ + if not isinstance(a, Tensor): + a = asarray(a) + return py_copy(a) + + +@constexpr +def _fill(shape, value, dtype): + """Original numpy.full function.""" + return Tensor(onp.full(shape, value), dtype) + + +def ones(shape, dtype=mstype.float32): + """ + Returns a new tensor of given shape and type, filled with ones. + + Args: + shape (Union[int, tuple, list]): the shape of the new tensor. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.float32, or `float32`. Default is mstype.float32. + + Returns: + Tensor, with the designated shape and dtype, filled with ones. + + Raises: + TypeError: If input arguments have types not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.ones((2,2))) + [[1. 1.] + [1. 1.]] + """ + shape = _check_shape(shape) + dtype = _check_dtype(dtype) + if _check_shape_contain_zero(shape): + return _fill(shape, 1.0, dtype) + output = F.fill(dtype, shape, 1) + return output + + +def zeros(shape, dtype=mstype.float32): + """ + Returns a new tensor of given shape and type, filled with zeros. + + Args: + shape (Union[int, tuple, list]): the shape of the new tensor. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.float32, or `float32`. Default is mstype.float32. + + Returns: + Tensor, with the designated shape and dtype, filled with zeros. + + Raises: + TypeError: If input arguments have types not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.zeros((2,2))) + [[0. 0.] + [0. 0.]] + """ + shape = _check_shape(shape) + dtype = _check_dtype(dtype) + if _check_shape_contain_zero(shape): + return _fill(shape, 0.0, dtype) + output = F.fill(dtype, shape, 0) + return output + + +def full(shape, fill_value, dtype=None): + """ + Returns a new tensor of given shape and type, filled with fill_value. + + Args: + shape (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g., + (2, 3) or 2. + fill_value (Union[int, float, bool, list, tuple]): scalar or array_like + fill value. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.float32, or `float32`, if dtype is None, the data type + of the new tensor will be inferred from fill_value. Default is None. + + Returns: + Tensor, with the designated shape and dtype, filled with `fill_value`. + + Raises: + TypeError: If input arguments have types not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.full((2,2), True)) + [[True True] + [True True]] + """ + if dtype is None: + dtype = array(fill_value).dtype + + shape = _check_shape(shape) + _ = _check_input_for_asarray(fill_value) + dtype = _check_dtype(dtype) + + if isinstance(fill_value, (int, float, bool)) and not _check_shape_contain_zero(shape): + return F.fill(dtype, shape, fill_value) + + # if fill_value is array_like or shape contains zero. fall back to original + # numpy creation + return Tensor(onp.full(shape, fill_value, mstype.dtype_to_nptype(dtype))) + + +def arange(*args, **kwargs): + """ + Returns evenly spaced values within a given interval. + + Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`]. + The endpoint of the interval can optionally be excluded. + The current implementation is a direct wrapper on top of numpy.arange, except that + the default dtype is float32 and int32, compare to float64 and int64 for numpy + implementation. + + Args: + start(Union[int, float]): Start of interval. The interval includes this value. + When stop is provided as a position argument, start must be given, when stop + is a normal argument, start can be optional, and default is 0. + Please see additional examples below. + stop(Union[int, float], optional): End of interval. The interval does not + include this value, except in some cases where step is not an integer + and floating point round-off affects the length of out. + step(Union[int, float], optional): Spacing between values. For any output + out, this is the distance between two adjacent values, out[i+1] - out[i]. + The default step size is 1. If step is specified as a position argument, + start must also be given. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.float32, or `float32`. If dtype is None, the data type + of the new tensor will be inferred from start, stop and step. Default is None. + + Returns: + arangend tensor of evenly spaced values. + + Raises: + TypeError: If input arguments have types not specified above, or arguments are + not given in the correct orders specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.arange(0, 5, 1)) + [0 1 2 3 4] + >>> print(np.arange(3)) + [0 1 2] + >>> print(np.arange(start=0, stop=3)) + [0 1 2] + >>> print(np.arange(0, stop=3, step=0.5)) + [0. 0.5 1. 1.5 2. 2.5] + >>> print(np.arange(stop=3)) # This will lead to TypeError + """ + # infer the dtype, if either of start, end, step is float, default dtype is + # float32, else int32. + int_flag = True + final_dtype = None + + if args: + for item in args: + if isinstance(item, float): + int_flag = False + if kwargs: + if ('start' in kwargs and isinstance(kwargs['start'], float)) or \ + ('stop' in kwargs and isinstance(kwargs['stop'], float)) or \ + ('step' in kwargs and isinstance(kwargs['step'], float)): + int_flag = False + + if int_flag: + final_dtype = onp.int32 + else: + final_dtype = onp.float32 + + if 'dtype' in kwargs and kwargs['dtype'] is not None: + final_dtype = _check_dtype(kwargs['dtype']) + final_dtype = mstype.dtype_to_nptype(final_dtype) + kwargs['dtype'] = final_dtype + out = onp.arange(*args, **kwargs) + out = Tensor.from_numpy(out) + return out + + +def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0): + """ + Returns evenly spaced values within a given interval. + + The current implementation is a direct wrapper on top of numpy.linspace, except + the default dtype is float32, compare to float64 for numpy, + + Args: + start (Union[int, list(int), tuple(int),tensor]):The starting value of the sequence. + stop (Union[int, list(int), tuple(int),tensor]):The end value of the sequence, + unless `endpoint` is set to False. In that case, the sequence consists + of all but the last of ``num + 1` evenly spaced samples, so that `stop` + is excluded. Note that the step size changes when `endpoint` is False. + num (int, optional): Number of samples to generate. Default is 50. + endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is + not included. Default is True. + retstep (bool, optional): If True, return (`samples`, `step`), where `step` is + the spacing between samples. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.float32, or `float32`.If `dtype` is None, infer the data + type from other input arguments. Default is None. + axis (int, optional): The axis in the result to store the samples. Relevant + only if start or stop are array-like. By default (0), the samples will + be along a new axis inserted at the beginning. Use -1 to get an axis at the end. + Default is 0. + + Returns: + samples (Tensor): There are `num` equally spaced samples in the closed interval + ``[start, stop]`` or the half-open interval ``[start, stop)`` + (depending on whether `endpoint` is True or False). + + step (float, optional): Only returned if `retstep` is True. + Size of spacing between samples. + + Raises: + TypeError: If input arguments have types not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.linspace(0, 5, 6)) + [0. 1. 2. 3. 4. 5.] + """ + + if isinstance(start, Tensor): + start = start.asnumpy() + + if isinstance(stop, Tensor): + stop = stop.asnumpy() + + if not isinstance(num, int): + raise TypeError(f"num should be an integer, but got {type(num)}") + + final_dtype = None + if dtype is not None: + final_dtype = _check_dtype(dtype) + final_dtype = mstype.dtype_to_nptype(final_dtype) + else: + final_dtype = onp.float32 + + dtype = final_dtype + out = onp.linspace(start, stop, num, endpoint, retstep, dtype, axis) + + if retstep: + array_out, step_out = out[0], out[1] + tensor_out = Tensor(array_out) + return tensor_out, step_out + + tensor_out = Tensor.from_numpy(out) + return tensor_out + + +def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): + """ + Returns numbers spaced evenly on a log scale. + + In linear space, the sequence starts at base ** start (base to the power of + start) and ends with base ** stop (see endpoint below). + The current implementation is a direct wrapper on top of numpy.logspace, except + the default dtype is float32, compare to float64 for numpy, + + Args: + start (Union[int, list(int), tuple(int), tensor]):The starting value of the sequence. + stop (Union[int, list(int), tuple(int), tensor]):The end value of the sequence, + unless `endpoint` is set to False. In that case, the sequence consists + of all but the last of ``num + 1` evenly spaced samples, so that `stop` + is excluded. Note that the step size changes when `endpoint` is False. + num (int, optional): Number of samples to generate. Default is 50. + endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is + not included. Default is True. + base (Union[int, float], optional): The base of the log space. The step size + between the elements in ln(samples) / ln(base) (or log_base(samples)) + is uniform. Default is 10.0. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.float32, or `float32`.If `dtype` is None, infer the data + type from other input arguments. Default is None. + axis (int, optional): The axis in the result to store the samples. Relevant + only if start or stop is array-like. By default (0), the samples will + be along a new axis inserted at the beginning. Use -1 to get an axis at the end. + Default is 0. + + Returns: + samples (Tensor): num samples, equally spaced on a log scale. + + Raises: + TypeError: If input arguments have types not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.logspace(0, 5, 6, base=2.0)) + [ 1. 2. 4. 8. 16. 32.] + """ + + if isinstance(start, Tensor): + start = start.asnumpy() + + if isinstance(stop, Tensor): + stop = stop.asnumpy() + + final_dtype = None + if dtype is not None: + final_dtype = _check_dtype(dtype) + final_dtype = mstype.dtype_to_nptype(final_dtype) + else: + final_dtype = onp.float32 + + dtype = final_dtype + out = onp.logspace(start, stop, num, endpoint, base, dtype, axis) + + tensor_out = Tensor.from_numpy(out) + return tensor_out + + +def eye(N, M=None, k=0, dtype=mstype.float32): + """ + Returns a 2-D tensor with ones on the diagnoal and zeros elsewhere. + + Args: + N (int): Number of rows in the output, must be larger than 0. + M (int, optional): Number of columns in the output. If None, defaults to N, + if defined, must be larger than 0. Deault is None. + k (int, optional): Index of the diagonal: 0 (the default) refers to the main + diagonal, a positive value refers to an upper diagonal, and a negative value + to a lower diagonal. Default is 0. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.float32, or `float32`. Default is mstype.float32. + + Returns: + result (Tensor): A tensor of shape (N,M). A tensor where all elements + are equal to zero, except for the k-th diagonal, whose values are equal to one. + + Raises: + TypeError: If input arguments have types not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.eye(2, 2)) + [[1. 0.] + [0. 1.]] + """ + dtype = _check_dtype(dtype) + if M is None: + M = N + if not (isinstance(M, int) and isinstance(N, int) and isinstance(k, int)): + raise TypeError("Input tensor dimensions should be integers.") + out = None + if k != 0 or N == 0 or M == 0: + # Fall back to original numpy creation method + out = onp.eye(N, M, k) + else: + out = F.eye(N, M, dtype) + return asarray(out, dtype=dtype) + + +def identity(n, dtype=mstype.float32): + """ + Returns the identity tensor. + + Args: + n (int): Number of rows and columns in the output, must be larger than 0. + dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can + be in format of np.float32, or `float32`. Default is mstype.float32. + + Returns: + result (Tensor): A tensor of shape (n,n). A tensor where all elements + are equal to zero, except for the diagonal, whose values are equal to one. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + TypeError: If input arguments have types not specified above. + + Examples: + >>> import mindspore.numpy as np + >>> print(np.identity(2)) + [[1. 0.] + [0. 1.]] + """ + dtype = _check_dtype(dtype) + return eye(n, dtype=dtype) + + +def empty(shape, dtype=mstype.float32): + """ + Returns a new array of given shape and type, without initializing + entries. + + Note: + Numpy argument order is not supported. + Object arrays are not supported. + + Args: + shape (int or tuple of int): Shape of the empty array, e.g., + (2, 3) or 2. + dtype (data-type): optional. Desired output data-type for the + array, e.g, numpy.int8. Default is numpy.float32. + + Returns: + Tensor, array of uninitialized (arbitrary) data of the given + shape and dtype. + + Raises: + TypeError: if the input shape or dtype is invalid. + + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> output = np.empty((2, 3)) + >>> print(output) + Tensor(shape=[2, 3], dtype=Float32, value= + <uninitialized>) + """ + shape = _check_shape(shape) + dtype = _check_dtype(dtype) + return Tensor_(dtype, shape) + + +def _shape_matched(fn, arr): + """Returns the matched shape of elements in arr""" + shapes_all = groupby(map(fn, arr)) + shape = next(shapes_all)[0] + if next(shapes_all, False): + return _raise_value_error('Input array must have the same size across a dimension.') + return shape + + +def _get_shape(array_like): + """Returns the shape of the array like object by recursion.""" + if isinstance(array_like, Tensor): + return F.shape(array_like) + if isinstance(array_like, onp.ndarray): + return array_like.shape + if isinstance(array_like, (list, tuple)): + shape = _shape_matched(_get_shape, array_like) + return (len(array_like),) + shape + return () + + +def _get_dtype(array_like): + """Returns the data type of the array like object.""" + if isinstance(array_like, Tensor): + return F.dtype(array_like) + if isinstance(array_like, onp.ndarray): + return mstype.pytype_to_dtype(array_like.dtype) + if isinstance(array_like, (list, tuple)): + return asarray(array_like).dtype + return mstype.float32 + + +def _x_like(prototype, dtype, shape, constructor, fill_value=None): + """ + Returns a tensor with the same shape and type as prototype, + using constructor. + """ + _ = _check_input_for_asarray(prototype) + dtype_out = dtype + shape_out = shape + if not dtype_out: + dtype_out = _get_dtype(prototype) + if not shape_out and shape_out != 0: + shape_out = _get_shape(prototype) + if fill_value is not None: + return constructor(shape_out, fill_value, dtype_out) + return constructor(shape_out, dtype_out) + + +def empty_like(prototype, dtype=None, shape=None): + """ + Returns a new array with the same shape and type as a given array. + + Note: + Since list or tuple arrays are not supported, input array + must have the same size across a dimension. + If prototype is not a Tensor or a numpy array, dtype is + float32 by default if not provided. + + Args: + prototype (array_like): The shape and data-type of prototype + define these same attributes of the returned array. + dtype (data-type): optional. Overrides the data type of the + result. + shape (int or sequence of ints): optional. Overrides the shape + of the result. + + Returns: + Tensor, array of uninitialized (arbitrary) data with the same + shape and type as prototype. + + Raises: + ValueError: if prototype does not have the same shape across each + dimension. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = [[(1, 2)], onp.ones((1, 2)), [[2, 3]], onp.ones((1, 2))] + >>> output = np.empty_like(a) + >>> print(output) + Tensor(shape=[4, 1, 2], dtype=Float32, value= + <uninitialized>) + """ + return _x_like(prototype, dtype, shape, empty) + + +def ones_like(a, dtype=None, shape=None): + """ + Returns an array of ones with the same shape and type as a given array. + + Note: + Since list or tuple arrays are not supported, input array + must have the same size across a dimension. + If a is not a Tensor or a numpy array, dtype is float32 by default + if not provided. + + Args: + a (array_like): The shape and data-type of a define these same + attributes of the returned array. + dtype (data-type): optional. Overrides the data type of the + result. + shape (int or sequence of ints): optional. Overrides the shape + of the result. + + Returns: + Tensor, array of ones with the same shape and type as a. + + Raises: + ValueError: if prototype does not have the same shape across each + dimension. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))] + >>> output = np.ones_like(a) + >>> print(output) + [[[1. 1.]] + + [[1. 1.]] + + [[1. 1.]] + + [[1. 1.]]] + """ + return _x_like(a, dtype, shape, ones) + + +def zeros_like(a, dtype=None, shape=None): + """ + Returns an array of zeros with the same shape and type as a given array. + + Note: + Since list or tuple arrays are not supported, input array + must have the same size across a dimension. + If a is not a Tensor or a numpy array, dtype is float32 by default + if not provided. + + Args: + a (array_like): The shape and data-type of a define these same + attributes of the returned array. + dtype (data-type): optional. Overrides the data type of the + result. + shape (int or sequence of ints): optional. Overrides the shape + of the result. + + Returns: + Tensor, array of zeros with the same shape and type as a. + + Raises: + ValueError: if prototype does not have the same shape across each + dimension. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))] + >>> output = np.zeros_like(a) + >>> print(output) + [[[0. 0.]] + + [[0. 0.]] + + [[0. 0.]] + + [[0. 0.]]] + """ + return _x_like(a, dtype, shape, zeros) + + +def full_like(a, fill_value, dtype=None, shape=None): + """ + Returns a full array with the same shape and type as a given array. + + Note: + Since list or tuple arrays are not supported, input array + must have the same size across a dimension. + If a is not a Tensor or a numpy array, dtype is float32 by default + if not provided. + + Args: + a (array_like): The shape and data-type of a define these same + attributes of the returned array. + fill_value (scalar): Fill value. + dtype (data-type): optional. Overrides the data type of the + result. + shape (int or sequence of ints): optional. Overrides the shape + of the result. + + Returns: + Tensor, array of fill_value with the same shape and type as a. + + Raises: + ValueError: if prototype does not have the same shape across each + dimension. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = [[(1, 2)], onp.ones((1, 2)), [[2, 3]], onp.ones((1, 2))] + >>> output = np.full_like(a, 0.5) + >>> print(output) + [[[0.5 0.5]] + + [[0.5 0.5]] + + [[0.5 0.5]] + + [[0.5 0.5]]] + """ + return _x_like(a, dtype, shape, full, fill_value=fill_value) + + +def tri(N, M=None, k=0, dtype=mstype.float32): + """ + Returns an array with ones at and below the given diagonal and zeros elsewhere. + + Args: + N(int): Number of rows in the array. + M(int, optional): Number of columns in the array. By default, M is taken + equal to N. + k(int, optional): The sub-diagonal at and below which the array is filled. + k = 0 is the main diagonal, while k < 0 is below it, and k > 0 is above. + The default is 0. + dtype(mstype.dtype, optional): Data type of the returned array. The default + is mstype.float32. + + Returns: + tri(Tensor): Tensor with shape (N, M), with its lower triangle filled with + ones and zeros elsewhere; in other words T[i,j] == 1 for j <= i + k, + 0 otherwise. + + Raises: + TypeError: If input arguments have types not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> output = np.tri(3, 3, 1) + >>> print(output) + [[1. 1. 0.] + [1. 1. 1.] + [1. 1. 1.]] + """ + if M is None: + M = N + return nn_tril((N, M), dtype, k) + + +def tril(m, k=0): + """ + Returns a lower triangle of an array. + + Returns a copy of an array with elements above the k-th diagonal zeroed. + + Args: + m(array_like): The shape and data-type of a define these same + attributes of the returned array. + k(int, optional): Diagonal above which to zero elements. k = 0 (the default) + is the main diagonal, k < 0 is below it and k > 0 is above. + + Returns: + tril(Tensor): Lower triangle of m, of same shape and data-type as m. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If input m's rank < 1. + + Examples: + >>> import mindspore.numpy as np + >>> output = np.tril(np.ones((3, 3))) + >>> print(output) + [[1. 0. 0.] + [1. 1. 0.] + [1. 1. 1.]] + """ + m = asarray(m) + shape = _get_shape(m) + dtype = _get_dtype(m) + m = m.astype(mstype.float32) + assist = nn_tril(shape, mstype.float32, k) + return F.tensor_mul(assist, m).astype(dtype) + + +def triu(m, k=0): + """ + Returns an upper triangle of an array. + + Returns a copy of an array with elements above the k-th diagonal zeroed. + + Args: + m(array_like): The shape and data-type of a define these same + attributes of the returned array. + k(int, optional): Diagonal above which to zero elements. k = 0 (the default) + is the main diagonal, k < 0 is below it and k > 0 is above. + + Returns: + triu(Tensor): Lower triangle of m, of same shape and data-type as m. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If input m's rank < 1. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> output = np.triu(np.ones((3, 3))) + >>> print(output) + [[1. 1. 1.] + [0. 1. 1.] + [0. 0. 1.]] + """ + m = asarray(m) + shape = _get_shape(m) + dtype = _get_dtype(m) + m = m.astype(mstype.float32) + assist = nn_triu(shape, mstype.float32, k) + return F.tensor_mul(assist, m).astype(dtype) + + +def diagonal(a, offset=0, axis1=0, axis2=1): + """ + Returns specified diagonals. + + If `a` is 2-D, returns the diagonal of a with the given offset, i.e., the + collection of elements of the form a[i, i+offset]. If `a` has more than two + dimensions, then the axes specified by axis1 and axis2 are used to determine + the 2-D sub-array whose diagonal is returned. The shape of the resulting + array can be determined by removing axis1 and axis2 and appending an index + to the right equal to the size of the resulting diagonals. + + Args: + a (Tensor): Array from which the diagonals are taken. + offset (int): optional. Offset of the diagonal from the main diagonal. + Can be positive or negative. Defaults to main diagonal (0). + axis1 (int): optional. Axis to be used as the first axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + first axis (0). + axis2 (int): optional. Axis to be used as the second axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + second axis (1). + + Returns: + Tensor, if `a` is 2-D, then a 1-D array containing the diagonal. If + a.ndim > 2, then the dimensions specified by axis1 and axis2 are removed, + and a new axis inserted at the end corresponding to the diagonal. + + Raises: + ValueError: if the input tensor has less than two dimensions. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.arange(4).reshape(2,2) + >>> print(a) + [[0 1] + [2 3]] + >>> output = np.diagonal(a) + >>> print(output) + [0 3] + >>> output = np.diagonal(a, 1) + >>> print(output) + [1] + >>> a = np.arange(8).reshape(2, 2, 2) + >>> print(a) + [[[0 1] + [2 3]] + + [[4 5] + [6 7]]] + >>> output = np.diagonal(a, 0, 0, 1) + >>> print(output) + [[0 6] + [1 7]] + """ + ndim = F.rank(a) + if ndim < 2: + return _raise_value_error('diagonal requires an array of at least two dimensions') + dtype = F.dtype(a) + + if _is_empty(F.shape(a)): + return _empty(dtype, (0,)) + + cast_type = dtype + if not isinstance(dtype, Float): + # reduce_sum only supports float types + cast_type = mstype.float32 + a = F.cast(a, cast_type) + + axes = _check_axis_valid((axis1, axis2), ndim) + perm = () + for i in range(ndim): + if i not in axes: + perm += (i,) + perm += axes + a = transpose(a, perm) + + shape = F.shape(a) + n, m = shape[-2:] + e = _eye(n, m, offset, cast_type) + e = _expand(e, ndim) + e = _broadcast_to(e, F.shape(e), F.shape(a), ndim) + + prod = F.tensor_mul(a, e) + res = F.reduce_sum(prod, -1) + + begin = () + for i in range(ndim-2): + begin += (0,) + last_dim_begin = _max(0, -offset) + begin += (last_dim_begin,) + size = F.shape(res)[:-1] + last_dim_end = _min( + shape[-2], _max(0, shape[-1] - offset)) - last_dim_begin + if last_dim_end <= 0: + return _empty(dtype, size + (0,)) + size += (last_dim_end,) + res = F.tensor_slice(res, begin, size) + if not _check_same_type(cast_type, dtype): + res = F.cast(res, dtype) + return res + + +@constexpr +def _eye(N, M, k, dtype): + return eye(N=N, M=M, k=k, dtype=dtype) + + +def trace(a, offset=0, axis1=0, axis2=1): + """ + Returns the sum along diagonals of the array. + + If `a` is 2-D, the sum along its diagonal with the given offset is returned, + i.e., the sum of elements a[i,i+offset] for all i. + If `a` has more than two dimensions, then the axes specified by axis1 and + axis2 are used to determine the 2-D sub-arrays whose traces are returned. + The shape of the resulting array is the same as that of a with axis1 and + axis2 removed. + + Note: + Numpy arguments dtype and out are not supported. + + Args: + a (Tensor): Array from which the diagonals are taken. + offset (int): optional. Offset of the diagonal from the main diagonal. + Can be positive or negative. Defaults to main diagonal (0). + axis1 (int): optional. Axis to be used as the first axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + first axis (0). + axis2 (int): optional. Axis to be used as the second axis of the 2-D + sub-arrays from which the diagonals should be taken. Defaults to + second axis (1). + + Returns: + Tensor, sum_along_diagonals. If a is 2-D, the sum along the diagonal + is returned. If a has larger dimensions, then an array of sums along + diagonals is returned. + + Raises: + ValueError: if the input tensor has less than two dimensions. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.trace(np.eye(3)) + >>> print(output) + 3.0 + >>> a = np.arange(8).reshape((2,2,2)) + >>> output = np.trace(a) + >>> print(output) + [6 8] + >>> a = np.arange(24).reshape((2,2,2,3)) + >>> output = np.trace.shape + >>> print(output) + (2, 3) + """ + d = diagonal(a, offset, axis1=axis1, axis2=axis2) + shape = F.shape(d) + dtype = F.dtype(d) + if shape[-1] == 0: + return _empty(dtype, shape[:-1]) + + cast_type = dtype + if not isinstance(dtype, Float): + # reduce sum only supports float types + cast_type = mstype.float32 + d = F.cast(d, cast_type) + res = F.reduce_sum(d, -1) + if not _check_same_type(cast_type, dtype): + res = F.cast(res, dtype) + return res diff --git a/mindspore/numpy/array_ops.py b/mindspore/numpy/array_ops.py index 3bb9cc28484..b26b92a133c 100644 --- a/mindspore/numpy/array_ops.py +++ b/mindspore/numpy/array_ops.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,583 +13,24 @@ # limitations under the License. # ============================================================================ """array operations, the function docs are adapted from Numpy API.""" -from copy import copy as py_copy -import numpy as onp - -from ..common import Tensor from ..common import dtype as mstype from ..ops import operations as P from ..ops import functional as F from ..ops.primitive import constexpr +from ..nn import Cell -from .utils import _check_shape, _check_shape_compile, _check_dtype, _check_is_int, \ - _check_axes_range, _check_start_normalize, _check_shape_contain_zero, _check_is_tensor, \ - _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, _check_is_list, \ - _covert_list_tensor_to_tuple_tensor +from .utils import _covert_list_tensor_to_tuple_tensor, _expand, _broadcast_to, \ + _is_empty +from .utils_const import _check_is_int, _check_axes_range, _check_start_normalize, \ + _check_is_tensor, _check_is_tuple, _check_is_list, _raise_type_error, _raise_value_error, \ + _infer_out_shape, _get_index_for_unique, _get_counts_for_unique, _empty, _promote, \ + _min, _check_same_type, _check_input_tensor -DEFAULT_FLOAT_DTYPE = mstype.float32 -DEFAULT_INT_DTYPE = mstype.int32 # According to official numpy reference, the dimension of a numpy array must be less # than 32 MAX_NUMPY_DIMS = 32 -def array(obj, dtype=None, copy=True, ndmin=0): - """ - Creates a tensor. - - This function creates tensors from an array-like object. - - Args: - obj (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in - any form that can be converted to a tensor. This includes lists, lists of - tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.int32, or `int32`. If dtype is None, the data type - of the new tensor will be inferred from obj. Default is None. - copy (bool): If true, then the object is copied. Otherwise, a copy will - only be made if necessary. Default: True. - ndmin (int): Specifies the minimum number of dimensions that the resulting - tensor should have. Ones will be pre-pended to the shape as needed to - meet this requirement. Default: 0 - - Returns: - Tensor, generated tensor with the specified dtype. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> print(np.array([1,2,3])) - [1 2 3] - """ - if ndmin > 0: - # Fall back to original numpy creation. - if isinstance(obj, Tensor): - obj = obj.asnumpy() - return asarray(onp.array(obj, dtype, copy=copy, ndmin=ndmin)) - - if not copy: - return asarray(obj, dtype=dtype) - - obj = py_copy(obj) - return asarray(obj, dtype=dtype) - - -def asarray(a, dtype=None): - """ - Converts the input to tensor. - - This function converts tensors from an array-like object. - - Args: - a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in - any form that can be converted to a tensor. This includes lists, lists of - tuples, tuples, tuples of tuples, tuples of lists and ndarrays. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.int32, or `int32`. If dtype is None, the data type - of the new tensor will be inferred from a. Default is None. - - Returns: - Tensor, generated tensor with the specified dtype. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> print(np.asarray([1,2,3])) - [1 2 3] - """ - - if dtype is not None: - dtype = _check_dtype(dtype) - - _ = _check_input_for_asarray(a) - - if isinstance(a, float) and (dtype is None): - dtype = DEFAULT_FLOAT_DTYPE - - if isinstance(a, int) and not isinstance(a, bool) and (dtype is None): - dtype = DEFAULT_INT_DTYPE - - if isinstance(a, bool) and (dtype is None): - dtype = mstype.bool_ - - if isinstance(a, (list, tuple)): - # Convert all tuple/nested tuples to lists - a = _deep_list(a) - # Convert all tensor sub-elements to numpy arrays - a = _deep_tensor_to_nparray(a) - a = onp.asarray(a) - # If dtype is not specified, we keep consistent with numpy decision - # only exceptions are: we use int/float32 - if dtype is None: - if a.dtype is onp.dtype('int64'): - dtype = DEFAULT_INT_DTYPE - elif a.dtype is onp.dtype('float64'): - dtype = DEFAULT_FLOAT_DTYPE - - if isinstance(a, onp.ndarray) and dtype is None: - if a.dtype is onp.dtype('bool'): - dtype = mstype.bool_ - elif a.dtype is onp.dtype('int'): - dtype = DEFAULT_INT_DTYPE - elif a.dtype is onp.dtype('float'): - dtype = DEFAULT_FLOAT_DTYPE - a = Tensor.from_numpy(a) - - # If a is already a tensor and we don't need to cast dtype, return a - if isinstance(a, Tensor): - if dtype is None: - return a - dtype = _check_dtype(dtype) - if dtype == a.dtype: - return a - - return Tensor(a, dtype=dtype) - - -def asfarray(a, dtype=DEFAULT_FLOAT_DTYPE): - """ - Similar to asarray, converts the input to a float tensor. - - If non-float dtype is defined, this function will return a float32 tensor instead. - - Args: - a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in - any form that can be converted to a tensor. This includes lists, lists of - tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. Default is mstype.float32. - - Returns: - Tensor, generated tensor with the specified float dtype. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> print(np.asfarray([1,2,3])) - [1. 2. 3.] - """ - dtype = _check_dtype(dtype) - _ = _check_input_for_asarray(a) - - if dtype not in (mstype.float16, mstype.float32, mstype.float64): - dtype = DEFAULT_FLOAT_DTYPE - - if isinstance(a, (list, tuple)): - # Convert all tuple/nested tuples to lists - a = _deep_list(a) - # Convert all tensor sub-elements to numpy arrays - a = _deep_tensor_to_nparray(a) - a = onp.asarray(a) - - if isinstance(a, onp.ndarray): - a = Tensor.from_numpy(a) - - return Tensor(a, dtype) - - -def copy_(a): - """ - Returns a tensor copy of the given object. - - Args: - a (Tensor): Input tensor. - - Returns: - Tensor, has the same data as `a`. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> x = np.ones((2,2)) - >>> print(np.copy(x)) - [[1. 1.] - [1. 1.]] - """ - return py_copy(a) - - -def ones(shape, dtype=DEFAULT_FLOAT_DTYPE): - """ - Returns a new tensor of given shape and type, filled with ones. - - Args: - shape (Union[int, tuple, list]): the shape of the new tensor. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. Default is mstype.float32. - - Returns: - Tensor, with the designated shape and dtype, filled with ones. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> print(np.ones((2,2))) - [[1. 1.] - [1. 1.]] - """ - if _check_shape_contain_zero(shape): - return asarray(onp.ones(shape), dtype=dtype) - shape = _check_shape(shape) - dtype = _check_dtype(dtype) - fill = P.Fill() - output = fill(dtype, shape, 1) - return output - - -def zeros(shape, dtype=DEFAULT_FLOAT_DTYPE): - """ - Returns a new tensor of given shape and type, filled with zeros. - - Args: - shape (Union[int, tuple, list]): the shape of the new tensor. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. Default is mstype.float32. - - Returns: - Tensor, with the designated shape and dtype, filled with zeros. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> print(np.zeros((2,2))) - [[0. 0.] - [0. 0.]] - """ - if _check_shape_contain_zero(shape): - return asarray(onp.zeros(shape), dtype=dtype) - shape = _check_shape(shape) - dtype = _check_dtype(dtype) - fill = P.Fill() - output = fill(dtype, shape, 0) - return output - - -def full(shape, fill_value, dtype=None): - """ - Returns a new tensor of given shape and type, filled with fill_value. - - Args: - shape (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g., - (2, 3) or 2. - fill_value (Union[int, float, bool, list, tuple]): scalar or array_like - fill value. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`, if dtype is None, the data type - of the new tensor will be inferred from fill_value. Default is None. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Returns: - Tensor, with the designated shape and dtype, filled with `fill_value`. - - Examples: - >>> import mindspore.numpy as np - >>> print(np.full((2,2), True)) - [[True True] - [True True]] - """ - if dtype is None: - dtype = array(fill_value).dtype - - shape = _check_shape(shape) - _ = _check_input_for_asarray(fill_value) - dtype = _check_dtype(dtype) - - if isinstance(fill_value, (int, float, bool)) and not _check_shape_contain_zero(shape): - return P.Fill()(dtype, shape, fill_value) - - # if fill_value is array_like or shape contains zero. fall back to original - # numpy creation - return Tensor(onp.full(shape, fill_value, mstype.dtype_to_nptype(dtype))) - - -def arange(*args, **kwargs): - """ - Returns evenly spaced values within a given interval. - - Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`]. - The endpoint of the interval can optionally be excluded. - The current implementation is a direct wrapper on top of numpy.arange, except that - the default dtype is float32 and int32, compare to float64 and int64 for numpy - implementation. - - Args: - start(Union[int, float]): Start of interval. The interval includes this value. - When stop is provided as a position argument, start must be given, when stop - is a normal argument, start can be optional, and default is 0. - Please see additional examples below. - stop(Union[int, float], optional): End of interval. The interval does not - include this value, except in some cases where step is not an integer - and floating point round-off affects the length of out. - step(Union[int, float], optional): Spacing between values. For any output - out, this is the distance between two adjacent values, out[i+1] - out[i]. - The default step size is 1. If step is specified as a position argument, - start must also be given. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. If dtype is None, the data type - of the new tensor will be inferred from start, stop and step. Default is None. - - Returns: - arangend tensor of evenly spaced values. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> print(np.arange(0, 5, 1)) - [0 1 2 3 4] - >>> print(np.arange(3)) - [0 1 2] - >>> print(np.arange(start=0, stop=3)) - [0 1 2] - >>> print(np.arange(0, stop=3, step=0.5)) - [0. 0.5 1. 1.5 2. 2.5] - >>> print(np.arange(stop=3)) # This will lead to TypeError - """ - # infer the dtype, if either of start, end, step is float, default dtype is - # float32, else int32. - int_flag = True - final_dtype = None - - if args: - for item in args: - if isinstance(item, float): - int_flag = False - if kwargs: - if ('start' in kwargs and isinstance(kwargs['start'], float)) or \ - ('stop' in kwargs and isinstance(kwargs['stop'], float)) or \ - ('step' in kwargs and isinstance(kwargs['step'], float)): - int_flag = False - - if int_flag: - final_dtype = onp.int32 - else: - final_dtype = onp.float32 - - if 'dtype' in kwargs and kwargs['dtype'] is not None: - final_dtype = _check_dtype(kwargs['dtype']) - final_dtype = mstype.dtype_to_nptype(final_dtype) - kwargs['dtype'] = final_dtype - out = onp.arange(*args, **kwargs) - out = Tensor.from_numpy(out) - return out - - -def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0): - """ - Returns evenly spaced values within a given interval. - - The current implementation is a direct wrapper on top of numpy.linspace, except - the default dtype is float32, compare to float64 for numpy, - - Args: - start (Union[int, list(int), tuple(int),tensor]):The starting value of the sequence. - stop (Union[int, list(int), tuple(int),tensor]):The end value of the sequence, - unless `endpoint` is set to False. In that case, the sequence consists - of all but the last of ``num + 1` evenly spaced samples, so that `stop` - is excluded. Note that the step size changes when `endpoint` is False. - num (int, optional): Number of samples to generate. Default is 50. - endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is - not included. Default is True. - retstep (bool, optional): If True, return (`samples`, `step`), where `step` is - the spacing between samples. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`.If `dtype` is None, infer the data - type from other input arguments. Default is None. - axis (int, optional): The axis in the result to store the samples. Relevant - only if start or stop are array-like. By default (0), the samples will - be along a new axis inserted at the beginning. Use -1 to get an axis at the end. - Default is 0. - - Returns: - samples (Tensor): There are `num` equally spaced samples in the closed interval - ``[start, stop]`` or the half-open interval ``[start, stop)`` - (depending on whether `endpoint` is True or False). - - step (float, optional): Only returned if `retstep` is True. - Size of spacing between samples. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> print(np.linspace(0, 5, 6)) - [0. 1. 2. 3. 4. 5.] - """ - - if isinstance(start, Tensor): - start = start.asnumpy() - - if isinstance(stop, Tensor): - stop = stop.asnumpy() - - if not isinstance(num, int): - raise TypeError(f"num should be an integer, but got {type(num)}") - - final_dtype = None - if dtype is not None: - final_dtype = _check_dtype(dtype) - final_dtype = mstype.dtype_to_nptype(final_dtype) - else: - final_dtype = onp.float32 - - dtype = final_dtype - out = onp.linspace(start, stop, num, endpoint, retstep, dtype, axis) - - if retstep: - array_out, step_out = out[0], out[1] - tensor_out = Tensor(array_out) - return tensor_out, step_out - - tensor_out = Tensor(out) - return tensor_out - - -def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): - """ - Returns numbers spaced evenly on a log scale. - - In linear space, the sequence starts at base ** start (base to the power of - start) and ends with base ** stop (see endpoint below). - The current implementation is a direct wrapper on top of numpy.logspace, except - the default dtype is float32, compare to float64 for numpy, - - Args: - start (Union[int, list(int), tuple(int), tensor]):The starting value of the sequence. - stop (Union[int, list(int), tuple(int), tensor]):The end value of the sequence, - unless `endpoint` is set to False. In that case, the sequence consists - of all but the last of ``num + 1` evenly spaced samples, so that `stop` - is excluded. Note that the step size changes when `endpoint` is False. - num (int, optional): Number of samples to generate. Default is 50. - endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is - not included. Default is True. - base (Union[int, float], optional): The base of the log space. The step size - between the elements in ln(samples) / ln(base) (or log_base(samples)) - is uniform. Default is 10.0. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`.If `dtype` is None, infer the data - type from other input arguments. Default is None. - axis (int, optional): The axis in the result to store the samples. Relevant - only if start or stop is array-like. By default (0), the samples will - be along a new axis inserted at the beginning. Use -1 to get an axis at the end. - Default is 0. - - Returns: - samples (Tensor): num samples, equally spaced on a log scale. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> print(np.logspace(0, 5, 6, base=2.0)) - [ 1. 2. 4. 8. 16. 32.] - """ - - if isinstance(start, Tensor): - start = start.asnumpy() - - if isinstance(stop, Tensor): - stop = stop.asnumpy() - - final_dtype = None - if dtype is not None: - final_dtype = _check_dtype(dtype) - final_dtype = mstype.dtype_to_nptype(final_dtype) - else: - final_dtype = onp.float32 - - dtype = final_dtype - out = onp.logspace(start, stop, num, endpoint, base, dtype, axis) - - tensor_out = Tensor.from_numpy(out) - return tensor_out - - -def eye(N, M=None, k=0, dtype=DEFAULT_FLOAT_DTYPE): - """ - Returns a 2-D tensor with ones on the diagnoal and zeros elsewhere. - - Args: - N (int): Number of rows in the output, must be larger than 0. - M (int, optional): Number of columns in the output. If None, defaults to N, - if defined, must be larger than 0. Deault is None. - k (int, optional): Index of the diagonal: 0 (the default) refers to the main - diagonal, a positive value refers to an upper diagonal, and a negative value - to a lower diagonal. Default is 0. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. Default is mstype.float32. - - Returns: - result (Tensor): A tensor of shape (N,M). A tensor where all elements - are equal to zero, except for the k-th diagonal, whose values are equal to one. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> print(np.eye(2, 2)) - [[1. 0.] - [0. 1.]] - """ - dtype = _check_dtype(dtype) - make_eye = P.Eye() - if M is None: - M = N - M = int(M) - N = int(N) - k = int(k) - out = None - if k != 0 or N == 0 or M == 0: - # Fall back to original numpy creation method - out = onp.eye(N, M, k) - else: - out = make_eye(N, M, dtype) - return asarray(out, dtype=dtype) - - -def identity(n, dtype=DEFAULT_FLOAT_DTYPE): - """ - Returns the identity tensor. - - Args: - n (int): Number of rows and columns in the output, must be larger than 0. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. Default is mstype.float32. - - Returns: - result (Tensor): A tensor of shape (n,n). A tensor where all elements - are equal to zero, except for the diagonal, whose values are equal to one. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> import mindspore.numpy as np - >>> print(np.identity(2)) - [[1. 0.] - [0. 1.]] - """ - dtype = _check_dtype(dtype) - return eye(n, dtype=dtype) - @constexpr def _prepare_shape_for_expand_dims(shape, axes): @@ -612,21 +53,18 @@ def _prepare_shape_for_expand_dims(shape, axes): if isinstance(axes, int): new_shape_length += 1 if axes >= new_shape_length or axes < -new_shape_length: - raise ValueError( - f"axis {axes} is out of bounds for tensor of dimension {new_shape_length}") + raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {new_shape_length}") axes = {axes} elif isinstance(axes, (list, tuple)): new_shape_length += len(axes) for axis in axes: if axis >= new_shape_length or axis < -new_shape_length: - raise ValueError( - f"axis {axis} is out of bounds for tensor of dimension {new_shape_length}") + raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {new_shape_length}") axes = set(axes) else: - raise TypeError( - f"only int, tuple and list are allowed for axes, but got {type(axes)}") + raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}") for new_shape_idx in range(new_shape_length): if new_shape_idx in axes or new_shape_idx - new_shape_length in axes: @@ -651,6 +89,10 @@ def expand_dims(a, axis): Returns: Tensor, view of a tensor with the number of dimensions increased. + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If axis exceeds a.ndim. + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -661,65 +103,18 @@ def expand_dims(a, axis): >>> print(x.shape) (1, 2, 2) """ + if not _check_is_tensor(F.typeof(a)): + _raise_type_error("Input is not Tensor.") shape = F.shape(a) # yield expanded shape based on the axes new_shape = _prepare_shape_for_expand_dims(shape, axis) - return P.Reshape()(a, new_shape) - - -@constexpr -def _prepare_shape_for_squeeze(shape, axes): - """ - Creates the squeezed new shape based on the tensor and given axes. - - Args: - shape (tuple): the shape of the tensor - axes Union[None, int, tuple(int), list(int)]: the axes with dimensions squeezed. - - Returns: - new_shape(tuple): the shape with dimensions squeezed. - """ - new_shape = [] - ndim = len(shape) - - # Convert to set - if isinstance(axes, int): - if axes >= ndim or axes < -ndim: - raise ValueError( - f"axis {axes} is out of bounds for tensor of dimension {ndim}") - axes = {axes} - - elif isinstance(axes, (list, tuple)): - for axis in axes: - if axis >= ndim or axis < -ndim: - raise ValueError( - f"axis {axis} is out of bounds for tensor of dimension {ndim}") - axes = set(axes) - - elif axes is not None: - raise TypeError( - f"only int, tuple and list are allowed for axes, but got {type(axes)}") - - if axes is None: - new_shape = [s for s in shape if s != 1] - else: - for idx, s in enumerate(shape): - if s != 1 or (idx not in axes) and (idx - ndim not in axes): - new_shape.append(s) - # if an axis is selected with shape entry greater than one, an error is raised. - if s != 1 and ((idx in axes) or (idx - ndim in axes)): - raise ValueError( - f"axis {axes} has shape entry {s} > 1, cannot be squeezed.") - return tuple(new_shape) + return F.reshape(a, new_shape) def squeeze(a, axis=None): """ Removes single-dimensional entries from the shape of an tensor. - This is a temporary solution to support CPU backend. Will be changed - once CPU backend supports P.Squeeze(). - Args: a (Tensor): Input tensor array. axis: Union[None, int, list(int), tuple(list)]. Default is None. @@ -727,6 +122,10 @@ def squeeze(a, axis=None): Returns: Tensor, with all or a subset of the dimensions of length 1 removed. + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If specified axis has shape entry > 1. + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -737,10 +136,9 @@ def squeeze(a, axis=None): >>> print(x.shape) (2, 2) """ - shape = F.shape(a) - # yield squeezed shape based on the axes - new_shape = _prepare_shape_for_squeeze(shape, axis) - return P.Reshape()(a, new_shape) + if not _check_is_tensor(F.typeof(a)): + _raise_type_error("Input is not Tensor.") + return a.squeeze(axis) def transpose(a, axes=None): @@ -755,6 +153,10 @@ def transpose(a, axes=None): Returns: Tensor, the transposed tensor array. + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If the number of axes is not euqal to a.ndim. + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -765,15 +167,9 @@ def transpose(a, axes=None): >>> print(x.shape) (3, 2, 1) """ - if axes is None: - shape = F.shape(a) - length = F.tuple_len(shape) - perm = F.make_range(0, length) - new_order = F.tuple_reversed(perm) - return P.Transpose()(a, new_order) - - axes = _check_shape_compile(axes) - return P.Transpose()(a, axes) + if not _check_is_tensor(F.typeof(a)): + _raise_type_error("Input is not Tensor.") + return a.transpose(axes) def rollaxis(x, axis, start=0): @@ -808,7 +204,7 @@ def rollaxis(x, axis, start=0): ``Ascend`` ``GPU`` ``CPU`` Raises: - TypeError: If axis or start is not integer. + TypeError: If axis or start is not integer, or x is not tensor. ValueError: If axis is not in the range from -ndim to ndim-1 or start is not in the range from -ndim to ndim. @@ -819,8 +215,12 @@ def rollaxis(x, axis, start=0): >>> print(output.shape) (3, 2, 4) """ - _check_is_int(axis) - _check_is_int(start) + if not _check_is_tensor(F.typeof(x)): + _raise_type_error("Input is not Tensor.") + if not _check_is_int(axis): + _raise_type_error("integer argument expected, but got ", axis) + if not _check_is_int(start): + _raise_type_error("integer argument expected, but got ", start) shape = F.shape(x) ndim = F.tuple_len(shape) @@ -845,7 +245,7 @@ def rollaxis(x, axis, start=0): new_perm = perm[0:axis] + perm[axis+1:start] + \ perm[axis:axis+1] - return P.Transpose()(x, new_perm) + return F.transpose(x, new_perm) def swapaxes(x, axis1, axis2): @@ -861,7 +261,7 @@ def swapaxes(x, axis1, axis2): Transposed tensor, has the same data type as the original tensor x. Raises: - TypeError: If axis1 or axis2 is not integer. + TypeError: If axis1 or axis2 is not integer, or x is not tensor. ValueError: If axis1 or axis2 is not in the range from -ndim to ndim-1. Supported Platforms: @@ -874,30 +274,9 @@ def swapaxes(x, axis1, axis2): >>> print(output.shape) (4,3,2) """ - _check_is_int(axis1) - _check_is_int(axis2) - - shape = F.shape(x) - ndim = F.tuple_len(shape) - - axes = _check_axes_range((axis1, axis2), ndim) - axis1, axis2 = axes[0], axes[1] - - if axis1 == axis2: - return x - if axis1 > axis2: - axis1, axis2 = axis2, axis1 - - perm = F.make_range(0, ndim) - new_perm = None - if axis2 + 1 < ndim: - new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ - perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:] - else: - new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \ - perm[axis1+1:axis2] + perm[axis1:axis1+1] - - return P.Transpose()(x, new_perm) + if not _check_is_tensor(F.typeof(x)): + _raise_type_error("Input is not Tensor.") + return x.swapaxes(axis1, axis2) def reshape(x, new_shape): @@ -916,7 +295,7 @@ def reshape(x, new_shape): Reshaped Tensor. Has the same data type as the original tensor x. Raises: - TypeError: If new_shape is not integer, list or tuple. + TypeError: If new_shape is not integer, list or tuple, or x is not tensor. ValueError: If new_shape does not compatible with the original shape. Supported Platforms: @@ -939,8 +318,9 @@ def reshape(x, new_shape): >>> print(output) [-0.1 0.3 3.6 0.4 0.5 -3.2] """ - new_shape = _check_shape_compile(new_shape) - return P.Reshape()(x, new_shape) + if not _check_is_tensor(F.typeof(x)): + _raise_type_error("Input is not Tensor.") + return x.reshape(new_shape) def ravel(x): @@ -955,6 +335,9 @@ def ravel(x): Returns: Flattened tensor, has the same data type as the original tensor x. + Raises: + If x is not tensor. + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -965,13 +348,15 @@ def ravel(x): >>> print(output.shape) (24,) """ - return reshape(x, (-1,)) + if not _check_is_tensor(F.typeof(x)): + _raise_type_error("Input is not Tensor.") + return x.ravel() @constexpr def _move_axes_for_concatenate(arr_shape, axis): """ - Moves axis 0 to the disiganated position, while keeps other axes' relative + Moves axis 0 to the desiganated position, while keeps other axes' relative positions unchanged, only used if a single tensor is concatenated. """ @@ -982,6 +367,34 @@ def _move_axes_for_concatenate(arr_shape, axis): return new_axes, new_shape +def _promote_type_for_concatenate(tuple_of_tensors): + """ + Checks dtype for all tensors in the tuple. If dtypes are not the same, promote + them to the `highest` dtype in the tuple, so that they are ready for the concat + operator. + + Args: + tuple_of_tensors(tuple(tensor)): A tuple of tensors + + Returns: + tuple of tensors, with each tensor promoted to ths same dtype. + """ + need_cast = False + final_type = tuple_of_tensors[0].dtype + + for tensor in tuple_of_tensors: + if not _check_same_type(final_type, tensor.dtype): + need_cast = True + final_type = _promote(final_type, tensor.dtype) + + if not need_cast: + return tuple_of_tensors + tuple_of_casted_tensors = () + for tensor in tuple_of_tensors: + tuple_of_casted_tensors += (tensor.astype(final_type, copy=False),) + return tuple_of_casted_tensors + + def concatenate(arrays, axis=0): """ Joins a sequence of tensors along an existing axis. @@ -996,6 +409,10 @@ def concatenate(arrays, axis=0): Returns: Tensor, a tensor concatenated from a tensor or a list of tensors. + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If specified axis < 0, and exceeds tensor.ndim. + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -1029,11 +446,11 @@ def concatenate(arrays, axis=0): for arr in arrays: flattened_arrays += (ravel(arr),) axis = -1 + flattened_arrays = _promote_type_for_concatenate(flattened_arrays) return P.Concat(axis)(flattened_arrays) # convert a list of tensor to a tuple of tensor - if _check_is_list(array_type): - arrays = _covert_list_tensor_to_tuple_tensor(arrays) + arrays = _covert_list_tensor_to_tuple_tensor(arrays) arr_shape = F.shape(arrays[0]) _check_axes_range((axis,), len(arr_shape)) @@ -1042,4 +459,570 @@ def concatenate(arrays, axis=0): if len(arrays) == 1: return arrays[0] + arrays = _promote_type_for_concatenate(arrays) return P.Concat(axis)(arrays) + + +def column_stack(tup): + """ + Stacks 1-D tensors as columns into a 2-D tensor. 2-D tensors are stacked as-is, + like np.hstack. + + Args: + tup (Union[Tensor, tuple, list]): A sequence of 1-D or 2-D tensors. All + of them must have the same shape except the axis to be concatenated. + + Returns: + 2-D Tensor, formed by stacking the given tensors. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + TypeError: If tup is not Tensor, list or tuple. + ValueError: If tup is empty. + + Examples: + >>> import mindspore.numpy as mnp + >>> import numpy as onp + >>> from mindspore import Tensor + >>> x1 = Tensor(onp.array([1, 2, 3]).astype('int32')) + >>> x2 = Tensor(onp.array([4, 5, 6]).astype('int32')) + >>> output = mnp.column_stack((x1, x2)) + >>> print(output) + [[1, 4], + [2, 5], + [3, 6]] + """ + if _check_is_tensor(F.typeof(tup)): + return tup + if not _check_is_list(tup) and not _check_is_tuple(tup): + _raise_type_error("Tensor or, list or tuple of tensors are required, but got ", tup) + if not tup: + _raise_value_error("Need at least one tensor to concatenate.") + + trans_tup = () + for tensor in tup: + shape = F.shape(tensor) + if F.tuple_len(shape) == 1: + reshape_tensor = F.reshape(tensor, shape+(1,)) + trans_tup += (reshape_tensor,) + else: + trans_tup += (tensor,) + return P.Concat(axis=1)(trans_tup) + + +def vstack(tup): + """ + Stacks tensors in sequence vertically. + This is equivalent to concatenation along the first axis. 1-D tensors should firstly be reshaped to (1, N), + and then be concatenated along the first axis. + + Args: + tup (Union[Tensor, tuple, list]): A sequence of 1-D or 2-D tensors. The tensors must have the same shape + along all but the first axis. 1-D tensors must have the same shape. + + Returns: + Stacked Tensor, formed by stacking the given tensors. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + TypeError: If tup is not Tensor, list or tuple. + ValueError: If tup is empty. + + Examples: + >>> import mindspore.numpy as mnp + >>> import numpy as onp + >>> from mindspore import Tensor + >>> x1 = Tensor(onp.array([1, 2, 3]).astype('int32')) + >>> x2 = Tensor(onp.array([4, 5, 6]).astype('int32')) + >>> output = mnp.vstack((x1, x2)) + >>> print(output) + [[1, 2, 3], + [4, 5, 6]] + """ + if _check_is_tensor(F.typeof(tup)): + return tup + if not _check_is_list(tup) and not _check_is_tuple(tup): + _raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup) + if not tup: + _raise_value_error("Need at least one tensor to concatenate.") + + trans_tup = () + for tensor in tup: + shape = F.shape(tensor) + if F.tuple_len(shape) == 1: + reshape_tensor = F.reshape(tensor, (1,)+shape) + trans_tup += (reshape_tensor,) + else: + trans_tup += (tensor,) + return P.Concat(axis=0)(trans_tup) + + +def hstack(tup): + """ + Stacks tensors in sequence horizontally. + This is equivalent to concatenation along the second axis, except for 1-D tensors + where it concatenates along the first axis. + + Args: + tup (Union[Tensor, tuple, list]): A sequence of 1-D or 2-D tensors. The + tensors must have the same shape along all but the second axis, except + 1-D tensors which can be any length. + + Returns: + Stacked Tensor, formed by stacking the given tensors. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + TypeError: If tup is not Tensor, list or tuple. + ValueError: If tup is empty. + + Examples: + >>> import mindspore.numpy as mnp + >>> import numpy as onp + >>> from mindspore import Tensor + >>> x1 = Tensor(onp.array([1, 2, 3]).astype('int32')) + >>> x2 = Tensor(onp.array([4, 5, 6]).astype('int32')) + >>> output = mnp.hstack((x1, x2)) + >>> print(output) + [1, 2, 3, 4, 5, 6] + """ + if _check_is_tensor(F.typeof(tup)): + return tup + if not _check_is_list(tup) and not _check_is_tuple(tup): + _raise_type_error(f"Tensor or, list or tuple of tensors are required, but got", tup) + if not tup: + _raise_value_error("Need at least one tensor to concatenate.") + + tuple_of_tensor = () + if _check_is_list(tup): + for tensor in tup: + tuple_of_tensor += (tensor,) + else: + tuple_of_tensor = tup + + if F.tuple_len(F.shape(tup[0])) == 1: + return P.Concat(axis=0)(tuple_of_tensor) + return P.Concat(axis=1)(tuple_of_tensor) + + +def dstack(tup): + """ + Stacks tensors in sequence depth wise (along the third axis). + This is equivalent to concatenation along the third axis. 1-D tensors (N,) should be reshaped to (1,N,1). + 2-D tensors (M,N) should be reshaped to (M,N,1) before concatenation. + + Args: + tup (Union[Tensor, tuple, list]): A sequence of tensors. The tensors must have the same shape along all but + the third axis. 1-D or 2-D tensors must have the same shape. + + Returns: + Stacked Tensor, formed by stacking the given tensors. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + TypeError: If tup is not Tensor, list or tuple. + ValueError: If tup is empty. + + Examples: + >>> import mindspore.numpy as mnp + >>> import numpy as onp + >>> from mindspore import Tensor + >>> x1 = Tensor(onp.array([1, 2, 3]).astype('int32')) + >>> x2 = Tensor(onp.array([4, 5, 6]).astype('int32')) + >>> output = mnp.dstack((x1, x2)) + >>> print(output) + [[[1, 4], + [2, 5], + [3, 6]]] + """ + if _check_is_tensor(F.typeof(tup)): + return tup + if not _check_is_list(tup) and not _check_is_tuple(tup): + _raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup) + if not tup: + _raise_value_error("Need at least one tensor to concatenate.") + + trans_tup = () + for tensor in tup: + shape = F.shape(tensor) + if F.tuple_len(shape) == 1: + reshape_tensor = F.reshape(tensor, (1,)+shape+(1,)) + trans_tup += (reshape_tensor,) + elif F.tuple_len(shape) == 2: + reshape_tensor = F.reshape(tensor, shape+(1,)) + trans_tup += (reshape_tensor,) + else: + trans_tup += (tensor,) + return P.Concat(axis=2)(trans_tup) + + +def where(condition, x=None, y=None): + """ + Returns elements chosen from x or y depending on condition. + + Note: + As nonzero is not supported, neither x or y can be None. + On CPU, the supported dtypes are np.float16, np.float32, np.int16, + and np.int32. + On GPU, the supported dtypes are np.float16, np.float32, np.int16, + and np.int32. + + Args: + condition (Tensor): where True, yield x, otherwise yield y. + x, y (Tensor): Values from which to choose. x, y and condition need + to be broadcastable to some shape. + + Returns: + Tensor or scalar, with elements from x where condition is True, and + elements from y elsewhere. + + Raises: + ValueError: if operands cannot be broadcast. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> condition = np.full((1, 1, 2), [False, True]) + >>> x = np.full((1, 3, 2), 5) + >>> y = np.full((2, 1, 1), 7) + >>> output = np.where(condition, x, y) + >>> print(output) + [[[7, 5], + [7, 5], + [7, 5]], + + [[7, 5], + [7, 5], + [7, 5]]] + """ + # type promotes input tensors + dtype1 = F.dtype(x) + dtype2 = F.dtype(y) + dtype = _promote(dtype1, dtype2) + if not _check_same_type(dtype1, dtype): + x = F.cast(x, dtype) + if not _check_same_type(dtype2, dtype): + y = F.cast(y, dtype) + is_bool = _check_same_type(dtype1, mstype.bool_) and _check_same_type( + dtype2, mstype.bool_) + if is_bool: + # select does not support bool type for x or y + x = F.cast(x, mstype.float32) + y = F.cast(y, mstype.float32) + + # broadcasts input tensors + shape_out = _infer_out_shape(F.shape(condition), + F.shape(x), F.shape(y)) + ndim_out = len(shape_out) + condition = _expand(condition, ndim_out) + x = _expand(x, ndim_out) + y = _expand(y, ndim_out) + condition = _broadcast_to( + condition, F.shape(condition), shape_out, ndim_out) + x = _broadcast_to(x, F.shape(x), shape_out, ndim_out) + y = _broadcast_to(y, F.shape(y), shape_out, ndim_out) + if not _check_same_type(F.dtype(condition), mstype.bool_): + condition = F.cast(condition, mstype.bool_) + res = F.select(condition, x, y) + if is_bool: + res = F.cast(res, mstype.bool_) + return res + + +def _expand_atleast(arr, ndim): + """Expands arr to at least ndim.""" + arr = _expand(arr, _min(ndim, 2)) + if ndim > 2: + arr = _expand(arr, ndim, axis=-1) + return arr + + +def _atleast_xd(ndim, arys): + """Returns arys with at least ndim.""" + for arr in arys: + _check_input_tensor(F.typeof(arr)) + + if F.tuple_len(arys) == 1: + return _expand_atleast(*arys, ndim) + res = [] + for arr in res: + res.append(_expand_atleast(arr, ndim)) + return res + + +def atleast_1d(*arys): + """ + Converts inputs to arrays with at least one dimension. + + Scalar inputs are converted to 1-dimensional arrays, whilst + higher-dimensional inputs are preserved. + + Note: + In graph mode, returns a tuple of tensor instead of a list of + tensors. + On CPU, the supported dtypes are np.float16, np.float32, np.int16, + and np.int32. + On GPU, the supported dtypes are np.float16, np.float32, np.int16, + and np.int32. + Args: + arys1, arys2, … (Tensor): one or more input tensors. + + Returns: + Tensor, or list of tensors, each with a.ndim >= 1. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.ones((2, 3)) + >>> b = np.ones(0) + >>> c = np.ones(5) + >>> output = np.atleast_1d(a, b, c) + >>> print(output) + (Tensor(shape=[2, 3], dtype=Float32, value= + [[1.00000000e+000, 1.00000000e+000, 1.00000000e+000], + [1.00000000e+000, 1.00000000e+000, 1.00000000e+000]]), + Tensor(shape=[1], dtype=Float32, value= [1.00000000e+000]), + Tensor(shape=[5], dtype=Float32, + value= [1.00000000e+000, 1.00000000e+000, 1.00000000e+000, + 1.00000000e+000, 1.00000000e+000])) + """ + return _atleast_xd(1, arys) + + +def atleast_2d(*arys): + """ + Views inputs as arrays with at least two dimensions. + + Note: + In graph mode, returns a tuple of tensor instead of a list of + tensors. + On CPU, the supported dtypes are np.float16, np.float32, np.int16, + and np.int32. + On GPU, the supported dtypes are np.float16, np.float32, np.int16, + and np.int32. + Args: + arys1, arys2, … (Tensor): one or more input tensors. + + Returns: + Tensor, or list of tensors, each with a.ndim >= 2. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.ones((2, 3)) + >>> b = np.ones(0) + >>> c = np.ones(5) + >>> output = np.atleast_2d(a, b, c) + >>> print(output) + (Tensor(shape=[2, 3], dtype=Float32, value= + [[1.00000000e+000, 1.00000000e+000, 1.00000000e+000], + [1.00000000e+000, 1.00000000e+000, 1.00000000e+000]]), + Tensor(shape=[1, 1], dtype=Float32, value= [[1.00000000e+000]]), + Tensor(shape=[1, 5], dtype=Float32, + value= [[1.00000000e+000, 1.00000000e+000, 1.00000000e+000, + 1.00000000e+000, 1.00000000e+000]])) + """ + return _atleast_xd(2, arys) + + +def atleast_3d(*arys): + """ + Views inputs as arrays with at least three dimensions. + + Note: + In graph mode, returns a tuple of tensor instead of a list of + tensors. + On CPU, the supported dtypes are np.float16, np.float32, np.int16, + and np.int32. + On GPU, the supported dtypes are np.float16, np.float32, np.int16, + and np.int32. + Args: + arys1, arys2, … (Tensor): one or more input tensors. + + Returns: + Tensor, or list of tensors, each with a.ndim >= 3. For example, + a 1-D array of shape (N,) becomes a view of shape (1, N, 1), and + a 2-D array of shape (M, N) becomes a view of shape (M, N, 1). + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.ones((2, 3)) + >>> b = np.ones(0) + >>> c = np.ones(5) + >>> output = np.atleast_3d(a, b, c) + >>> print(output) + (Tensor(shape=[2, 3, 1], dtype=Float32, value= + [[[1.00000000e+000], [1.00000000e+000], [1.00000000e+000]], + [[1.00000000e+000], [1.00000000e+000], [1.00000000e+000]]]), + Tensor(shape=[1, 1, 1], dtype=Float32, value= [[[1.00000000e+000]]]), + Tensor(shape=[1, 5, 1], dtype=Float32, + value= [[[1.00000000e+000], [1.00000000e+000], [1.00000000e+000], + [1.00000000e+000], [1.00000000e+000]]])) + """ + return _atleast_xd(3, arys) + + +def stack(arrays, axis=0): + """ + Joins a sequence of arrays along a new axis. + + The axis parameter specifies the index of the new axis in the + dimensions of the result. For example, if axis=0 it will be the + first dimension and if axis=-1 it will be the last dimension. + + + Note: + Numpy argument out is not supported. + + Args: + arrays (sequence of Tensor): Each array must have the same shape. + axis (int): optional. The axis in the result array along which the + input arrays are stacked. + + Returns: + Tensor, The stacked array has one more dimension than the input + arrays. + + Raises: + ValueError: if input is not Tensor, tuple, or list. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> arrays = [np.ones((3, 4)) for _ in range(10)] + >>> output = np.stack(arrays, axis=0) + >>> print(output.shape) + (10, 3, 4) + >>> output = np.stack(arrays, axis=1) + >>> print(output.shape) + (3, 10, 4) + >>> output = np.stack(arrays, axis=2) + >>> print(output.shape) + (3, 4, 10) + """ + arr_type = F.typeof(arrays) + + if _check_is_tensor(arr_type): + shape = F.shape(arrays) + ndim = F.rank(arrays) + axis = axis % ndim + axes = F.make_range(ndim) + perm = axes[1:axis+1] + (0,) + axes[axis+1:] + if _is_empty(shape): + return _empty(mstype.float32, shape[1:axis+1] + (shape[0],) + shape[axis+1:]) + return transpose(arrays, perm) + + if _check_is_tuple(arr_type) or _check_is_list(arr_type): + shape = (len(arrays),) + F.shape(arrays[0]) + ndim = len(shape) + axis = axis % ndim + if _is_empty(shape): + return _empty(mstype.float32, shape[1:axis+1] + (shape[0],) + shape[axis+1:]) + seq = () + for arr in arrays: + seq += (F.expand_dims(arr, axis),) + return concatenate(seq, axis) + return _raise_value_error('input arrays must be Tensor, tuple, or list') + + +class UniqueNet(Cell): + """The operation `mindspore.ops.Unique` must be wrapped inside a model and executed in graph mode. """ + + def __init__(self): + super(UniqueNet, self).__init__() + self.unique = P.Unique() + + def construct(self, x): + return self.unique(x) + + +def unique(x, return_index=False, return_inverse=False, return_counts=False): + """ + Finds the unique elements of a tensor. The input tensor will be flattened first + when it has more than one dimension. + + Note: + The operation is derived from mindspore.ops.Unique. + Numpy arguments `axis` is not supported. + + Args: + x (Tensor): The input tensor to be processed. + return_index (bool): If True, also return the indices of tensor x (along + the specified axis, if provided, or in the flattened tensor) that result + in the unique tensor. Default: False. + return_inverse (bool): If True, also return the indices of the unique tensor. + Default: False. + return_counts (bool): If True, also return the number of times each unique + item appears in input tensor `x`. Default: False. + + Returns: + Tensor or tuple of Tensors. + - If all of the three bool arguments (`return_index`, `return_inverse`, `return_counts`) + are False, just return the unique tensor. + - If parts of the three bool arguments are True, the corresponding results (Tensor) + will be added in the tuple. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + TypeError: If x is not tensor. + + Examples: + >>> import mindspore.numpy as mnp + >>> import numpy as onp + >>> input_x = mnp.asarray(onp.array([1, 2, 2, 2, 3, 4, 5]).astype('float32')) + >>> output_x = mnp.unique(input_x) + >>> print(output_x) + [1. 2. 3. 4. 5.] + >>> output_x = mnp.unique(input_x, return_index=True) + >>> print(output_x) + (Tensor(shape=[5], dtype=Float32, value= [ 1. 2. 3. 4. 5.]), Tensor(shape=[5], dtype=Float32, + value= [ 0. 1. 4. 5. 6.])) + >>> output_x = mnp.unique(input_x, return_inverse=True) + >>> print(output_x) + (Tensor(shape=[5], dtype=Float32, value= [ 1. 2. 3. 4. 5.]), Tensor(shape=[7], dtype=Int32, + value= [0, 1, 1, 1, 2, 3, 4])) + """ + if not _check_is_tensor(F.typeof(x)): + _raise_type_error("Tensor is expected, but got", x) + if F.tuple_len(F.shape(x)) > 1: + x = ravel(x) + uniq = UniqueNet() + unique_x, inverse_index = uniq(x) + if not return_index and not return_inverse and not return_counts: + return unique_x + res_tup = (unique_x,) + if return_index: + res_index = _get_index_for_unique(x, unique_x) + res_tup += (res_index,) + if return_inverse: + res_tup += (inverse_index,) + if return_counts: + res_counts = _get_counts_for_unique(x, unique_x) + res_tup += (res_counts,) + return res_tup diff --git a/mindspore/numpy/dtypes.py b/mindspore/numpy/dtypes.py index 201528b0186..211cf9cd1b5 100644 --- a/mindspore/numpy/dtypes.py +++ b/mindspore/numpy/dtypes.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +14,15 @@ # ============================================================================ """Dtypes and utilities""" -from ..common.dtype import (int8, int16, int32, int64, uint8, uint16, uint32, uint64, \ - float16, float32, float64, bool_) +from ..common.dtype import (int8, int16, int32, int64, uint8, uint16, uint32, uint64, + float16, float32, float64, bool_) # original numpy has int->int64, float->float64, uint->uint64 mapping. we map # them to 32 bit, since 64 bit calculation is not supported from mindspore # backend for now. inf = float('inf') +nan = float('nan') int_ = int32 uint = uint32 @@ -94,3 +95,69 @@ all_types = [ 'np.float32', 'np.float64', 'np.bool'] + +promotion_rule = { + (uint8, uint16): uint16, + (uint8, uint32): uint32, + (uint8, uint64): uint64, + (uint16, uint32): uint32, + (uint16, uint64): uint64, + (uint32, uint64): uint64, + (uint8, int8): int16, + (uint8, int16): int16, + (uint8, int32): int32, + (uint8, int64): int64, + (uint16, int8): int32, + (uint16, int16): int32, + (uint16, int32): int32, + (uint16, int64): int64, + (uint32, int8): int64, + (uint32, int16): int64, + (uint32, int32): int64, + (uint32, int64): int64, + (uint64, int8): float64, + (uint64, int16): float64, + (uint64, int32): float64, + (uint64, int64): float64, + (uint8, float16): float16, + (uint8, float32): float32, + (uint8, float64): float64, + (uint16, float16): float16, + (uint16, float32): float32, + (uint16, float64): float32, + (uint32, float16): float16, + (uint32, float32): float32, + (uint32, float64): float64, + (uint64, float16): float16, + (uint64, float32): float32, + (uint64, float64): float64, + (int8, int16): int16, + (int8, int32): int32, + (int8, int64): int64, + (int16, int32): int32, + (int16, int64): int64, + (int32, int64): int64, + (int8, float16): float16, + (int8, float32): float32, + (int8, float64): float64, + (int16, float16): float16, + (int16, float32): float32, + (int16, float64): float64, + (int32, float16): float16, + (int32, float32): float32, + (int32, float64): float64, + (float16, float32): float32, + (float16, float64): float64, + (float32, float64): float64, + (bool_, uint8): uint8, + (bool_, uint16): uint16, + (bool_, uint32): uint32, + (bool_, uint64): uint64, + (bool_, int8): int8, + (bool_, int16): int16, + (bool_, int32): int32, + (bool_, int64): int64, + (bool_, float16): float16, + (bool_, float32): float32, + (bool_, float64): float64, +} diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py index 650c0554d1b..9d617abcc97 100644 --- a/mindspore/numpy/math_ops.py +++ b/mindspore/numpy/math_ops.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +15,354 @@ """math operations, the function docs are adapted from Numpy API.""" from ..ops import operations as P from ..ops import functional as F +from ..ops import composite as C from ..ops.primitive import constexpr -from .array_ops import squeeze, asarray -from .utils import _infer_out_shape, _is_scalar, _check_axis_valid, _get_device_compile, \ - _check_shape_aligned +from ..common import dtype as mstype +from .array_ops import ravel +from .array_ops import where as where_ +from .array_creations import asarray, full +from .utils import _is_scalar, _expand, _broadcast_to, _is_empty +from .utils_const import _infer_out_shape, _check_axis_valid, _get_device_compile, \ + _check_shape_aligned, _empty, _check_is_tensor, _raise_type_error, _check_same_type, \ + _check_is_float, _check_input_tensor +from .dtypes import nan + + +_mean_default = P.ReduceMean() +_mean_keepdims = P.ReduceMean(True) +_matmul = P.MatMul(False, False) +_matmul_T = P.MatMul(False, True) + + +def absolute(x, out=None, where=True, dtype=None): + """ + Calculates the absolute value element-wise. + + Note: + Numpy arguments casting, order, dtype, subok, signature, and extobj are + not supported. + When argument where is provided, argument out must have a tensor value. + Argument out is not supported for storing the result, however it can be + used in combination with argument where to set the value at indices for + which where is set to False. + Currently the backend kernel only supports float calculation, if the input + is not a float, then it will be casted to float32 and casted back. + + Args: + x (Tensor): Tensor to be used for calculation. + out (Tensor or None): optional, defaults to None. + where (Tensor or None): optional. For any non-default value of type other + than Tensor or None, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is True, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default out=None, + locations within it where the condition is False will remain + uninitialized. + dtype (data type): optional, defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor. + + Raises: + TypeError: If input arguments have types not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = np.asarray([1, 2, 3, -4, -5], np.float64) + >>> output = np.absolute(x) + >>> print(output) + [1. 2. 3. 4. 5.] + """ + if not _check_is_tensor(F.typeof(x)): + _raise_type_error("Input is expected to be a tensor, but got ", x) + original_dtype = x.dtype + if not _check_is_float(original_dtype) and dtype is None: + x = x.astype(mstype.float32) + return _apply_tensor_op(F.absolute, x, out=out, where=where, dtype=dtype).astype(original_dtype) + return _apply_tensor_op(F.absolute, x, out=out, where=where, dtype=dtype) + + +def add(x1, x2, out=None, where=True, dtype=None): + """ + Adds arguments element-wise. + + Note: + Numpy arguments casting, order, dtype, subok, signature, and extobj are + not supported. + When argument where is provided, argument out must have a tensor value. + Argument out is not supported for storing the result, however it can be + used in combination with argument where to set the value at indices for + which where is set to False. + On GPU, the supported dtypes are np.float16, np.float32, np.int32, + and np.int64. + On CPU, the supported dtypes are np.float16, np.float32, np.float64, + np.int16, np.int32, and np.int64. + + Args: + x1 (Tensor): input to be added. + x2 (Tensor): input to be added. + out (Tensor or None): optional, defaults to None. + where (Tensor or None): optional. For any non-default value of type other + than Tensor or None, the output retains its original value. + This condition is broadcast over the input. At locations where the + condition is True, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default out=None, + locations within it where the condition is False will remain + uninitialized. + dtype (data type): optional, defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the sum of x1 and x2, element-wise. This is a scalar + if both x1 and x2 are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x1 = np.full((3, 2), [1, 2]) + >>> x2 = np.full((3, 2), [3, 4]) + >>> output = np.add(x1, x2) + >>> print(output) + [[4, 6], + [4, 6], + [4, 6]] + """ + # broadcast is not fully supported in tensor_add on CPU, + # so we use tensor_sub as a substitute solution + if _get_device_compile() == 'CPU': + return subtract(x1, F.neg_tensor(x2), out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_add, x1, x2, out=out, where=where, dtype=dtype) + + +def subtract(x1, x2, out=None, where=True, dtype=None): + """ + Subtracts arguments, element-wise. + + Note: + Numpy arguments casting, order, dtype, subok, signature, and extobj are + not supported. + When argument where is provided, argument out must have a tensor value. + Argument out is not supported for storing the result, however it can be + used in combination with argument where to set the value at indices for + which where is set to False. + On GPU, the supported dtypes are np.float16, np.float32, np.int32, + and np.int64. + On CPU, the supported dtypes are np.float16, np.float32, np.float64, + np.int16, np.int32, and np.int64. + + Args: + x1 (Tensor): the input to be subtracted from. + x2 (Tensor): the input to be subtracted by. + out (Tensor or None): optional, defaults to None. + where (Tensor or None): optional. For any non-default value of type other + than Tensor or None, the output retains its original value. + This condition is broadcast over the input. At locations where the + condition is True, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default out=None, + locations within it where the condition is False will remain + uninitialized. + dtype (data type): optional, defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the difference of x1 and x2, element-wise. This is a + scalar if both x1 and x2 are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x1 = np.full((3, 2), [1, 2]) + >>> x2 = np.full((3, 2), [3, 4]) + >>> output = np.subtract(x1, x2) + >>> print(output) + [[-2, -2], + [-2, -2], + [-2, -2]] + """ + return _apply_tensor_op(F.tensor_sub, x1, x2, out=out, where=where, dtype=dtype) + + +def multiply(x1, x2, out=None, where=True, dtype=None): + """ + Multiplies arguments element-wise. + + Note: + Numpy arguments casting, order, dtype, subok, signature, and extobj are + not supported. + When argument where is provided, argument out must have a tensor value. + Argument out is not supported for storing the result, however it can be + used in combination with argument where to set the value at indices for + which where is set to False. + On GPU, the supported dtypes are np.float16, np.float32, np.int32, + and np.int64. + On CPU, the supported dtypes are np.float16, np.float32, np.float64, + np.int16, np.int32, and np.int64. + + Args: + x1 (Tensor): input tensor to be multiplied. + x2 (Tensor): input tensor to be multiplied. + out (Tensor or None): optional, defaults to None. + where (Tensor or None): optional. For any non-default value of type other + than Tensor or None, the output retains its original value. + This condition is broadcast over the input. At locations where the + condition is True, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default out=None, + locations within it where the condition is False will remain + uninitialized. + dtype (data type): optional, defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the product of x1 and x2, element-wise. This is a scalar + if both x1 and x2 are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x1 = np.full((3, 2), [1, 2]) + >>> x2 = np.full((3, 2), [3, 4]) + >>> output = np.multiply(x1, x2) + >>> print(output) + [[3, 8], + [3, 8], + [3, 8]] + """ + if _get_device_compile() == 'CPU': + # broadcast is not fully supported on CPU backend, + # and explicit broadcasting is performed + shape_out = _infer_out_shape(F.shape(x1), F.shape(x2)) + ndim_out = F.tuple_len(shape_out) + x1 = _expand(x1, ndim_out) + x2 = _expand(x2, ndim_out) + x1 = _broadcast_to(x1, F.shape(x1), shape_out, ndim_out) + x2 = _broadcast_to(x2, F.shape(x2), shape_out, ndim_out) + return _apply_tensor_op(F.tensor_mul, x1, x2, out=out, where=where, dtype=dtype) + + +def divide(x1, x2, out=None, where=True, dtype=None): + """ + Returns a true division of the inputs, element-wise. + + Instead of the Python traditional ‘floor division’, this returns a true + division. + + Note: + Numpy arguments casting, order, dtype, subok, signature, and extobj are + not supported. + When argument where is provided, argument out must have a tensor value. + Argument out is not supported for storing the result, however it can be + used in combination with argument where to set the value at indices for + which where is set to False. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, np.float32, np.float64, + np.int16, np.int32, and np.int64. + + Args: + x1 (Tensor): the divident. + x2 (Tensor): the divisor. + out (Tensor or None): optional, defaults to None. + where (Tensor or None): optional. For any non-default value of type other + than Tensor or None, the output retains its original value. + This condition is broadcast over the input. At locations where the + condition is True, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default out=None, + locations within it where the condition is False will remain + uninitialized. + dtype (data type): optional, defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, this is a scalar if both x1 and x2 are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x1 = np.full((3, 2), [1, 2]) + >>> x2 = np.full((3, 2), [3, 4]) + >>> output = np.divide(x1, x2) + >>> print(output) + [[0.33333333, 0.5], + [0.33333333, 0.5], + [0.33333333, 0.5]] + """ + if not _check_is_float(F.dtype(x1)) and not _check_is_float(F.dtype(x2)): + x1 = F.cast(x1, mstype.float32) + x2 = F.cast(x2, mstype.float32) + return _apply_tensor_op(F.tensor_div, x1, x2, out=out, where=where, dtype=dtype) + + +def power(x1, x2, out=None, where=True, dtype=None): + """ + First array elements raised to powers from second array, element-wise. + + Raises each base in x1 to the positionally-corresponding power in x2. + + Note: + Numpy arguments casting, order, dtype, subok, signature, and extobj are + not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, np.float32, np.float64, + np.int16, np.int32, and np.int64. + + Args: + x1 (Tensor): the bases. + x2 (Tensor): the exponenets. + out (Tensor or None): optional, defaults to None. + where (Tensor or None): optional. For any non-default value of type other + than Tensor or None, the output retains its original value. + This condition is broadcast over the input. At locations where the + condition is True, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default out=None, + locations within it where the condition is False will remain + uninitialized. + dtype (data type): optional, defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the bases in x1 raised to the exponents in x2. This + is a scalarif both x1 and x2 are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x1 = np.full((3, 2), [1, 2]) + >>> x2 = np.full((3, 2), [3, 4]) + >>> output = np.power(x1, x2) + >>> print(output) + [[ 1, 16], + [ 1, 16], + [ 1, 16]] + """ + return _apply_tensor_op(F.tensor_pow, x1, x2, out=out, where=where, dtype=dtype) def mean(a, axis=None, keepdims=False): @@ -31,8 +375,9 @@ def mean(a, axis=None, keepdims=False): Note: Numpy arguments dtype and out are not supported. - On GPU, the supported dtypes are mstype.float16, and mstype.float32. - On CPU, the supported dtypes are mstype.float16, and mstype.float32. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, np.float32, and + np.float64. Args: a (Tensor): input tensor containing numbers whose mean is desired. @@ -63,17 +408,29 @@ def mean(a, axis=None, keepdims=False): >>> print(output) 2.5 """ - axis = _check_axis_valid(axis, P.Rank()(a)) - if _is_empty(F.shape(a)): - return _nan() - if _is_scalar(a.shape): + + axis = _check_axis_valid(axis, F.rank(a)) + shape_a = F.shape(a) + + if _is_empty(shape_a): + if keepdims: + shape_out = _shape_reduced_keepdims(shape_a, axis) + else: + shape_out = _shape_reduced(shape_a, axis) + if _is_empty(shape_out): + return _empty(F.dtype(a), shape_out) + return _full_compile(shape_out, nan) + + if _is_scalar(shape_a): if keepdims: return a - return squeeze(a) + shape_out = _shape_reduced(shape_a, axis) + return F.reshape(a, shape_out) + if keepdims: - res = P.ReduceMean(True)(a, axis) + res = _mean_keepdims(a, axis) else: - res = P.ReduceMean(False)(a, axis) + res = _mean_default(a, axis) return res @@ -87,8 +444,9 @@ def inner(a, b): Note: Numpy argument out is not supported. - On GPU, the supported dtypes are mstype.float16, and mstype.float32. - On CPU, the supported dtype is mstype.float32. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, np.float32, and + np.float64. Args: a (Tensor): input tensor. If a and b are nonscalar, their last @@ -127,19 +485,21 @@ def inner(a, b): [[3. 3. 3. 3. 3. 3. 3.] [3. 3. 3. 3. 3. 3. 3.]]] """ - if P.Rank()(a) == 0 or P.Rank()(b) == 0: - if _is_scalar(a.shape): + if F.rank(a) == 0 or F.rank(b) == 0: + a = _expand(a, 1) + b = _expand(b, 1) + if F.rank(a) < F.rank(b): a, b = b, a - return _apply_bin_op(P.Mul(), a, b) + return F.tensor_mul(a, b) - _ = _check_shape_aligned(a.shape, b.shape) - aligned_shape_a = (F.shape_mul(a.shape[:-1]), a.shape[-1]) - aligned_shape_b = (F.shape_mul(b.shape[:-1]), a.shape[-1]) - a_aligned = P.Reshape()(a, aligned_shape_a) - b_aligned = P.Reshape()(b, aligned_shape_b) + _ = _check_shape_aligned(F.shape(a), F.shape(b)) + aligned_shape_a = (F.shape_mul(F.shape(a)[:-1]), F.shape(a)[-1]) + aligned_shape_b = (F.shape_mul(F.shape(b)[:-1]), F.shape(a)[-1]) + a_aligned = F.reshape(a, aligned_shape_a) + b_aligned = F.reshape(b, aligned_shape_b) - res = P.MatMul(False, True)(a_aligned, b_aligned) - res = P.Reshape()(res, a.shape[:-1] + b.shape[:-1]) + res = _matmul_T(a_aligned, b_aligned) + res = F.reshape(res, F.shape(a)[:-1] + F.shape(b)[:-1]) return res @@ -149,28 +509,245 @@ def _nan(): return asarray(float('nan')) -def _is_empty(shape): - """Checks if the shape is empty""" - return F.shape_mul(shape) == 0 +def dot(a, b): + """ + Dot product of two arrays. + + Specifically, + If both a and b are 1-D arrays, it is inner product of vectors + (without complex conjugation). + If both a and b are 2-D arrays, it is matrix multiplication. + If either a or b is 0-D (scalar), it is equivalent to multiply. + If a is an N-D array and b is a 1-D array, it is a sum product + over the last axis of a and b. + If a is an N-D array and b is an M-D array (where M>=2), it is a + sum product over the last axis of a and the second-to-last axis of b: + dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]) + + Note: + Numpy argument out is not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, np.float32, and + np.float64. + + Args: + a (Tensor): input tensor + b (Tensor): input tensor + + Returns: + Tensor or scalar, the dot product of a and b. If a and b are + both scalars or both 1-D arrays then a scalar is returned; + otherwise an array is returned + + Raises: + ValueError: If the last dimension of a is not the same size + as the second-to-last dimension of b. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.full((1, 3), 7) + >>> b = np.full((2, 3, 4), 5) + >>> output = np.dot(a, b) + >>> print(output) + [[[105, 105, 105, 105], + [105, 105, 105, 105]]] + """ + ndim_a, ndim_b = F.rank(a), F.rank(b) + if ndim_a > 0 and ndim_b >= 2: + perm = F.make_range(ndim_b) + perm = perm[:-2] + (perm[-1],) + (perm[-2],) + b = F.transpose(b, perm) + return inner(a, b) -def _expand(x, ndim): - """Expand x to ndim""" - while P.Rank()(x) < ndim: - x = P.ExpandDims()(x, 0) - return x +def outer(a, b): + """ + Computes the outer product of two vectors. + + Given two vectors, a = [a0, a1, ..., aM] and b = [b0, b1, ..., bN], + the outer product [1] is: + [[a0*b0 a0*b1 ... a0*bN ] + [a1*b0 . + [ ... . + [aM*b0 aM*bN ]] + + Note: + Numpy argument out is not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, np.float32, and + np.float64. + + Args: + a (Tensor): first input vector. Input is flattened if not + already 1-dimensional. + b (Tensor): second input vector. Input is flattened if not + already 1-dimensional. + + Returns: + Tensor or scalar, out[i, j] = a[i] * b[j]. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.full(7, 2) + >>> b = np.full(4, 3) + >>> output = np.outer(a, b) + >>> print(output) + [[6, 6, 6, 6], + [6, 6, 6, 6], + [6, 6, 6, 6], + [6, 6, 6, 6], + [6, 6, 6, 6], + [6, 6, 6, 6], + [6, 6, 6, 6]] + """ + _check_input_tensor(F.typeof(a)) + _check_input_tensor(F.typeof(b)) + + if F.rank(a) != 1: + a = ravel(a) + if F.rank(b) != 1: + b = ravel(b) + a = F.reshape(a, (F.shape(a)[0], 1)) + b = _expand(b, 2) + return _matmul(a, b) -def _apply_bin_op(fn, x1, x2): - """apply binary operations based on fn.""" - device = _get_device_compile() - out_shape = _infer_out_shape(device, x1.shape, x2.shape) - if device == 'CPU': - # built-in operations on CPU does not support operands with - # dimensions of size 1 or with shape 0, therefore squeeze - # and scalar promotion is performed - x1, x2 = squeeze(x1), squeeze(x2) - x1, x2 = _expand(x1, 1), _expand(x2, 1) - res = fn(x1, x2) - res = P.Reshape()(res, out_shape) - return res +def tensordot(a, b, axes=2): + """ + Computes tensor dot product along specified axes. + + Given two tensors, a and b, and an array_like object containing two array_like + objects, (a_axes, b_axes), sum the products of a’s and b’s elements (components) + over the axes specified by a_axes and b_axes. The third argument can be a single + non-negative integer_like scalar, N; if it is such, then the last N dimensions of + a and the first N dimensions of b are summed over. + Three common use cases are: + axes = 0 : tensor product + axes = 1 : tensor dot product + axes = 2 : (default) tensor double contraction + When axes is integer_like, the sequence for evaluation will be: first the -Nth + axis in a and 0th axis in b, and the -1th axis in a and Nth axis in b last. + When there is more than one axis to sum over - and they are not the last (first) + axes of a (b) - the argument axes should consist of two sequences of the same + length, with the first axis to sum over given first in both sequences, the second + axis second, and so forth. + The shape of the result consists of the non-contracted axes of the first tensor, + followed by the non-contracted axes of the second. + + Note: + On CPU, the supported dypes are np.float16 and np.float32. + On GPU, the supported dypes are np.float16 and np.float32. + + Args: + a, b (Tensor): Tensors to “dot”. + axes (int or (2,) array_like): + integer_like: If an int N, sum over the last N axes of a and the first N + axes of b in order. The sizes of the corresponding axes must match. + (2,) array_like: Or, a list of axes to be summed over, first sequence + applying to a, second to b. Both elements array_like must be of the same + length. + + Returns: + Tensor, or list of tensors, the tensor dot product of the input. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.ones((3, 4, 5)) + >>> b = np.ones((4, 3, 2)) + >>> output = np.tensordot(a, b, axes=([1,0],[0,1])) + >>> print(output.shape) + (5, 2) + """ + _check_input_tensor(F.typeof(a)) + _check_input_tensor(F.typeof(b)) + + if F.rank(a)*F.rank(b) == 0 and axes == 0: + return F.tensor_mul(a, b) + return C.tensor_dot(a, b, axes) + + +@constexpr +def _full_compile(shape, value): + return full(shape, value) + + +@constexpr +def _shape_reduced_keepdims(shape, axes): + """ + Reduces dimensions corresponding to argument axes while + keeping the number of dimensions unchanged. + """ + ndim_out = F.tuple_len(shape) + shape_out = [1]*ndim_out + for i in range(ndim_out): + if not i in axes: + shape_out[i] = shape[i] + return tuple(shape_out) + + +@constexpr +def _shape_reduced(shape, axes): + """Removes dimensions corresponding to argument axes""" + ndim_orig = F.tuple_len(shape) + ndim_out = ndim_orig - F.tuple_len(axes) + shape_out = [0]*ndim_out + idx_out = 0 + for i in range(ndim_orig): + if not i in axes: + shape_out[idx_out] = shape[i] + idx_out += 1 + return tuple(shape_out) + + +def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b): + """Infers the shape of the last two dimensions after performing matmul.""" + shape_rem = () + if ndim1 >= 2: + shape_rem += (shape1[-2],) + if transpose_b: + if ndim2 >= 2: + shape_rem += (shape2[-2],) + else: + if ndim1 >= 1: + shape_rem += (shape2[-1],) + return shape_rem + + +def _apply_tensor_op(fn, *args, out=None, where=True, dtype=None): + """applies tensor operations based on fn""" + for arg in args: + _check_input_tensor(F.typeof(arg)) + res = fn(*args) + + # if out is set to a non-default value, return tensor will have the same + # dtype as out, which overrides the dtype passed into the keyword argument + if _check_is_tensor(F.typeof(out)): + dtype_out = F.dtype(out) + elif dtype is not None: + dtype_out = dtype + else: + dtype_out = F.dtype(res) + + if _check_is_tensor(F.typeof(out)) and _check_is_tensor(F.typeof(where)): + out = where_(where, res, out) + elif out is None or where is not None: + out = res + + if not _check_same_type(F.dtype(out), dtype_out): + out = F.cast(out, dtype_out) + + return out diff --git a/mindspore/numpy/utils.py b/mindspore/numpy/utils.py index 76bd383b12e..9508c90cc2e 100644 --- a/mindspore/numpy/utils.py +++ b/mindspore/numpy/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,126 +13,14 @@ # limitations under the License. # ============================================================================ """internal utility functions""" -from functools import partial import numpy as onp import mindspore.context as context from ..common import Tensor -from ..ops import operations as P from ..ops import functional as F -from ..ops.primitive import constexpr -from ..common import dtype as mstype -from .dtypes import dtype_tuple, all_types, dtype_map - -@constexpr -def _check_shape_compile(shape): - """check the shape param to match the numpy style inside the graph""" - if not isinstance(shape, (int, tuple, list)): - raise TypeError( - f"only int, tuple and list are allowed for shape, but got {type(shape)}") - if isinstance(shape, int): - shape = (shape,) - if isinstance(shape, list): - shape = tuple(shape) - return shape - - -@constexpr -def _check_is_int(x): - """Check the type of x is int.""" - if isinstance(x, int): - return True - raise TypeError(f"integer argument expected, but got {type(x)}.") - - -@constexpr -def _check_start_normalize(start, ndim): - """check and normalize start argument for rollaxis.""" - if start < -ndim or start > ndim: - raise ValueError( - f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.") - if start < 0: - start = start + ndim - return start - - -@constexpr -def _check_axes_range(axes, ndim): - """ - Check axes are within the number of dimensions of tensor x and normalize the negative axes. - Args: - axes (Union[int, tuple(int), list(int)]): Axes of the tensor. - ndim (int): The number of dimensions of the tensor. - Return: - Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple. - """ - if not isinstance(axes, int) and not isinstance(axes, tuple) and not isinstance(axes, list): - raise TypeError( - f"int, tuple(int) or list(int) expected, but got {type(axes)}.") - low = -ndim - up = ndim - 1 - if low > up: - raise ValueError( - f"Lower bound {low} and upper bound {up} of axes are not allowed.") - if isinstance(axes, int): - if axes < low or axes > up: - raise ValueError( - f"axis {axes} is out of bounds for tensor of dimension {ndim}.") - return axes if axes >= 0 else axes + ndim - new_axes = [] - for item in axes: - if not isinstance(item, int): - raise TypeError( - f"int in tuple or list expected, but got {type(item)}.") - if item < low or item > up: - raise ValueError( - f"axis {item} in {axes} is out of bounds for tensor of dimension {ndim}.") - new_axes.append(item if item >= 0 else item + ndim) - return tuple(new_axes) - - -def _check_shape_contain_zero(shp): - """Check whether shape contains 0""" - if isinstance(shp, int): - return shp == 0 - if isinstance(shp, (list, tuple)): - for s in shp: - if s == 0: - return True - return False - - -def _check_shape(shape): - """check the shape param to match the numpy style outside the graph""" - if not isinstance(shape, (int, tuple, list)): - raise TypeError( - f"only int, tuple and list are allowed for shape, but got {type(shape)}") - if isinstance(shape, int): - shape = (shape,) - if isinstance(shape, list): - shape = tuple(shape) - return shape - - -def _check_dtype(dtype): - """check the input dtype and make conversions""" - # convert the string dtype to mstype.dtype - if isinstance(dtype, str): - dtype = dtype.lower() - dtype = dtype_map[dtype] - elif isinstance(dtype, type): - if dtype is int: - dtype = mstype.int32 - if dtype is float: - dtype = mstype.float32 - if dtype is bool: - dtype = mstype.bool_ - if dtype not in dtype_tuple: - raise TypeError( - f"only {all_types} are allowed for dtype, but got {type(dtype)}") - return dtype +from .utils_const import _tile_size def _deep_list(array_like): @@ -170,15 +58,8 @@ def _check_input_for_asarray(array_like): """check whether array_like argument is a valid type for np.asarray conversion""" if isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)): return True - raise TypeError( - "input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \ - f"or numpy.ndarray, but got {type(array_like)}") - - -def _cast_to(array, dtype): - """cast the input to specified dtype""" - cast = P.Cast() - return cast(array, dtype) + raise TypeError("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \ + f"or numpy.ndarray, but got {type(array_like)}") def _is_scalar(shape): @@ -186,10 +67,9 @@ def _is_scalar(shape): return F.shape_mul(shape) == 1 -@constexpr -def _get_device_compile(): - """Get the current device (`GPU`, `CPU`, `Ascend`)""" - return context.get_context('device_target') +def _is_empty(shape): + """Checks if the shape is empty""" + return F.shape_mul(shape) == 0 def _get_device(): @@ -199,10 +79,12 @@ def _get_device(): def _covert_list_tensor_to_tuple_tensor(list_of_tensor): """Convert a list of tensor to a tuple of tensor""" - tuple_of_tensor = () - for tensor in list_of_tensor: - tuple_of_tensor += (tensor,) - return tuple_of_tensor + if isinstance(list_of_tensor, list): + tuple_of_tensor = () + for tensor in list_of_tensor: + tuple_of_tensor += (tensor,) + return tuple_of_tensor + return list_of_tensor def _get_mode(): @@ -210,145 +92,14 @@ def _get_mode(): return context.get_context('mode') -@constexpr -def _reverse_index(idx, arr): - """ - Returns 1 if shape[idx:] is broadcastable to shape_out[idx:], - 2 situations if the function returns 1: - - 1. Tensor's shape has 1 at the designated dimension. - - 2. Tensor's dimension is less than the designated idx. (The Tensor shape - has been reversed) - For both cases, 2 tensors are broadcastable. - otherwise returns the element at position of shape - """ - if len(arr) <= idx: - return 1 - return arr[-1 - idx] +def _expand(x, ndim, axis=0): + """Expand x to ndim.""" + while F.rank(x) < ndim: + x = F.expand_dims(x, axis) + return x -@constexpr -def _infer_out_shape(device, *shapes): - """ - Returns shape of output after broadcasting - Raises ValueError if shape1 and shape2 cannot be broadcast - """ - shapes_unbroadcastable = False - cpu_shapes_different = False - contains_scalar = any(_is_scalar(shape) for shape in shapes) - ndim_max = max(map(len, shapes)) - shape_out = [0]*ndim_max - i = 0 - for i in range(ndim_max): - shape_out[-1 - i] = max(map(partial(_reverse_index, i), shapes)) - for shape in shapes: - if _reverse_index(i, shape) != shape_out[-1 - i]: - if _reverse_index(i, shape) != 1: - shapes_unbroadcastable = True - if device == 'CPU' and not contains_scalar: - cpu_shapes_different = True - if not shapes_unbroadcastable and not cpu_shapes_different: - return tuple(shape_out) - if shapes_unbroadcastable: - raise ValueError( - f'operands could not be broadcast together with shapes {*shapes,}') - raise ValueError('broadcasting is currently not supported on CPU. Non-scalar' + \ - f'operands must have the same shape, but got {*shapes,}') - - -@constexpr -def _check_axis_in_range(axis, ndim): - """Checks axes are with the bounds of ndim""" - if -ndim <= axis < ndim: - return True - raise ValueError( - f'axis {axis} is out of bounds for array of dimension {ndim}') - - -@constexpr -def _check_axis_valid(axes, ndim): - """ - Checks axes are valid given ndim, and returns axes that can be passed - to the built-in operator (non-negative, int or tuple) - """ - if isinstance(axes, int): - _ = _check_axis_in_range(axes, ndim) - return (axes % ndim,) - if isinstance(axes, tuple): - for axis in axes: - _ = _check_axis_in_range(axis, ndim) - axes = tuple(map(lambda x: x % ndim, axes)) - if all(axes.count(el) <= 1 for el in axes): - return axes - if axes is None: - axes = F.make_range(ndim) - return axes - raise ValueError('duplicate value in \'axis\'') - - -@constexpr -def _check_shape_aligned(shape1, shape2): - """Checks shape1 and shape2 are valid shapes to perform inner product""" - if shape1[-1] == shape2[-1]: - return True - raise ValueError( - f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)') - - -@constexpr -def _check_dim_cpu(shape, bound): - """Checks input shape is upper-bounded by parameter bound""" - ndim = len(shape) - if _is_scalar(shape): - return True - if ndim <= bound: - return True - raise ValueError( - f'dimension {ndim} larger than {bound} is not supported on CPU') - - -@constexpr -def _tile_size(shape, out_shape, ndim): - """Returns tile_size such that shape*tile_size = out_shape""" - size = [1]*ndim - for idx, (i, j) in enumerate(zip(shape, out_shape)): - if i != j: - size[idx] = j - return tuple(size) - - -@constexpr -def _check_core_match(shape1, shape2): - """Checks shape1 and shape2 are valid shapes to perform matmul""" - ndim1, ndim2 = len(shape1), len(shape2) - if ndim1 < 1 or ndim2 < 2: - return True - if shape1[-1] == shape2[-2]: - return True - raise ValueError(f'mismatch in core dimension of input operands (size {shape1[-1]} ' + - f'is different from {shape2[-2]})') - - -@constexpr -def _cpu_not_support(name): - """Checks if a function not supported on cpu is executed on cpu device""" - if _get_device() != 'CPU': - return True - raise ValueError(f'{name} is not supported on CPU') - - -@constexpr -def _check_is_tuple(obj): - """Check whether obj is a tuple""" - return isinstance(obj, mstype.Tuple) - - -@constexpr -def _check_is_list(obj): - """Check whether obj is a list""" - return isinstance(obj, mstype.List) - - -@constexpr -def _check_is_tensor(obj): - """Check whether obj is a tensor""" - return isinstance(obj, mstype.tensor_type) +def _broadcast_to(x, shape_cur, shape_to, ndim_to): + """Broadcasts x from shape_cur to shape_to.""" + size = _tile_size(shape_cur, shape_to, ndim_to) + return F.tile(x, size) diff --git a/mindspore/numpy/utils_const.py b/mindspore/numpy/utils_const.py new file mode 100644 index 00000000000..faab3900737 --- /dev/null +++ b/mindspore/numpy/utils_const.py @@ -0,0 +1,364 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""internal graph-compatible utility functions""" +from functools import partial + +import numpy as onp + +import mindspore.context as context +from ..common import Tensor +from ..ops import functional as F +from ..ops.primitive import constexpr +from ..common import dtype as mstype +from .._c_expression import Tensor as Tensor_ +from .._c_expression.typing import Tuple, List + +from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map + + +@constexpr +def _check_shape(shape): + """check the shape param to match the numpy style""" + if not isinstance(shape, (int, tuple, list, Tuple, List)): + raise TypeError(f"only int, tuple and list are allowed for shape, but got {type(shape)}") + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, (list, List)): + shape = tuple(shape) + return shape + + +@constexpr +def _check_dtype(dtype): + """check the input dtype and make conversions""" + # convert the string dtype to mstype.dtype + if isinstance(dtype, str): + dtype = dtype.lower() + dtype = dtype_map[dtype] + elif isinstance(dtype, type): + if dtype is int: + dtype = mstype.int32 + elif dtype is float: + dtype = mstype.float32 + else: + dtype = mstype.pytype_to_dtype(dtype) + if dtype not in dtype_tuple: + raise TypeError(f"only {all_types} are allowed for dtype, but got {type(dtype)}") + return dtype + + +@constexpr +def _check_shape_contain_zero(shp): + """Check whether shape contains zero""" + if isinstance(shp, int): + return shp == 0 + return F.shape_mul(shp) == 0 + + +@constexpr +def _check_start_normalize(start, ndim): + """check and normalize start argument for rollaxis.""" + if start < -ndim or start > ndim: + raise ValueError(f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.") + if start < 0: + start = start + ndim + return start + + +@constexpr +def _check_axes_range(axes, ndim): + """ + Check axes are within the number of dimensions of tensor x and normalize the negative axes. + Args: + axes (Union[int, tuple(int), list(int)]): Axes of the tensor. + ndim (int): The number of dimensions of the tensor. + Return: + Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple. + """ + if not isinstance(axes, int) and not isinstance(axes, tuple) and not isinstance(axes, list): + raise TypeError(f"int, tuple(int) or list(int) expected, but got {type(axes)}.") + low = -ndim + up = ndim - 1 + if low > up: + raise ValueError(f"Lower bound {low} and upper bound {up} of axes are not allowed.") + if isinstance(axes, int): + if axes < low or axes > up: + raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {ndim}.") + return axes if axes >= 0 else axes + ndim + new_axes = [] + for item in axes: + if not isinstance(item, int): + raise TypeError(f"int in tuple or list expected, but got {type(item)}.") + if item < low or item > up: + raise ValueError(f"axis {item} in {axes} is out of bounds for tensor of dimension {ndim}.") + new_axes.append(item if item >= 0 else item + ndim) + return tuple(new_axes) + + +@constexpr +def _get_device_compile(): + """Get the current device (`GPU`, `CPU`, `Ascend`)""" + return context.get_context('device_target') + + +@constexpr +def _reverse_index(idx, arr): + """ + Returns 1 if shape[idx:] is broadcastable to shape_out[idx:], + 2 situations if the function returns 1: + - 1. Tensor's shape has 1 at the designated dimension. + - 2. Tensor's dimension is less than the designated idx. (The Tensor shape + has been reversed) + For both cases, 2 tensors are broadcastable. + otherwise returns the element at position of shape + """ + if len(arr) <= idx: + return 1 + return arr[-1 - idx] + + +@constexpr +def _infer_out_shape(*shapes): + """ + Returns shape of output after broadcasting + Raises ValueError if shape1 and shape2 cannot be broadcast + """ + shapes_unbroadcastable = False + ndim_max = max(map(len, shapes)) + shape_out = [0]*ndim_max + i = 0 + for i in range(ndim_max): + shape_out[-1 - i] = max(map(partial(_reverse_index, i), shapes)) + for shape in shapes: + if _reverse_index(i, shape) != shape_out[-1 - i]: + if _reverse_index(i, shape) != 1: + shapes_unbroadcastable = True + break + if shapes_unbroadcastable: + break + if not shapes_unbroadcastable: + return tuple(shape_out) + raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}') + + +@constexpr +def _check_axis_in_range(axis, ndim): + """Checks axes are with the bounds of ndim""" + if -ndim <= axis < ndim: + return True + raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}') + + +@constexpr +def _check_axis_valid(axes, ndim): + """ + Checks axes are valid given ndim, and returns axes that can be passed + to the built-in operator (non-negative, int or tuple) + """ + if isinstance(axes, int): + _ = _check_axis_in_range(axes, ndim) + return (axes % ndim,) + if isinstance(axes, tuple): + for axis in axes: + _ = _check_axis_in_range(axis, ndim) + axes = tuple(map(lambda x: x % ndim, axes)) + if all(axes.count(el) <= 1 for el in axes): + return axes + if axes is None: + axes = F.make_range(ndim) + return axes + raise ValueError('duplicate value in \'axis\'') + + +@constexpr +def _check_shape_aligned(shape1, shape2): + """Checks shape1 and shape2 are valid shapes to perform inner product""" + if shape1[-1] == shape2[-1]: + return True + raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)') + + +@constexpr +def _tile_size(shape, out_shape, ndim): + """Returns tile_size such that shape*tile_size = out_shape""" + size = [1]*ndim + for idx, (i, j) in enumerate(zip(shape, out_shape)): + if i != j: + size[idx] = j + return tuple(size) + + +@constexpr +def _check_is_int(obj): + """Check whether obj is an integer.""" + return isinstance(obj, int) + + +@constexpr +def _check_is_tuple(obj): + """Check whether obj is a tuple""" + return isinstance(obj, (tuple, Tuple)) + + +@constexpr +def _check_is_list(obj): + """Check whether obj is a list""" + return isinstance(obj, (list, List)) + + +@constexpr +def _check_is_tensor(obj): + """Check whether obj is a tensor""" + return isinstance(obj, mstype.tensor_type) + + +@constexpr +def _raise_type_error(info, param=None): + """ + Raise TypeError in both graph/pynative mode + + Args: + info(str): info string to display + param(python obj): any object that can be recognized by graph mode. If is + not None, then param's type information will be extracted and displayed. + Default is None. + """ + if param is None: + raise TypeError(info) + raise TypeError(info + f"{type(param)}") + + +@constexpr +def _raise_value_error(info, param=None): + """ + Raise TypeError in both graph/pynative mode + + Args: + info(str): info string to display + param(python obj): any object that can be recognized by graph mode. If is + not None, then param's value information will be extracted and displayed. + Default is None. + """ + if param is None: + raise ValueError(info) + raise ValueError(info + f"{param}") + + +@constexpr +def _empty(dtype, shape): + """Returns an uninitialized array with dtype and shape.""" + return Tensor_(dtype, shape) + + +def _get_index_for_unique(input_x, unique_x): + """ + Return the indices of the first occurrences of the unique values in the original array. + + Args: + input_x (Tensor): The flattened input tensor of `mindspore.numpy.unique`. + unique_x (Tensor): The tensor contains the unique elements in `input_x`, sorted in ascending order. + + Returns: + Tensor. The indices of the unique values in the original array. Has the same shape as `unique_x`. + """ + o_array = input_x.asnumpy() + dic = {} + for idx in range(o_array.size): + val = o_array[idx] + if val not in dic: + dic[val] = idx + + index_lst = [] + u_array = unique_x.asnumpy() + for idx in range(u_array.size): + index_lst.append(dic[u_array[idx]]) + + return Tensor(onp.array(index_lst), input_x.dtype) + + +@constexpr +def _get_counts_for_unique(input_x, unique_x): + """ + Return the number of times each of the unique values comes up in the original tensor. + + Args: + input_x (Tensor): The flattened input tensor of `mindspore.numpy.unique`. + unique_x (Tensor): The tensor contains the unique elements in `input_x`, sorted in ascending order. + + Returns: + Tensor. The number of times each of the unique values comes up in the original tensor. + """ + dic = {} + o_array = input_x.asnumpy() + for idx in range(o_array.size): + val = o_array[idx] + if val not in dic: + dic[val] = 1 + else: + dic[val] += 1 + + u_array = unique_x.asnumpy() + counts_lst = [dic[val] for val in u_array] + + return Tensor(onp.array(counts_lst), input_x.dtype) + + +@constexpr +def _get_max_value(x): + """Returns the maximum value of the input tensor `x`. """ + return int(max(x.asnumpy())) + + +@constexpr +def _promote(dtype1, dtype2): + if dtype1 == dtype2: + return dtype1 + if (dtype1, dtype2) in promotion_rule: + return promotion_rule[dtype1, dtype2] + return promotion_rule[dtype2, dtype1] + + +@constexpr +def _max(*args): + """Returns the maximum value.""" + return max(*args) + + +@constexpr +def _min(*args): + """"Returns the minimum value.""" + return min(*args) + + +@constexpr +def _abs(arg): + """Returns the absolute value.""" + return abs(arg) + + +@constexpr +def _check_same_type(dtype1, dtype2): + return dtype1 == dtype2 + + +@constexpr +def _check_is_float(dtype): + return dtype in mstype.float_type + + +@constexpr +def _check_input_tensor(input_type): + if not _check_is_tensor(input_type): + raise TypeError(f'expect Tensor, but got {input_type}') diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index d02369b2a68..5e5523c6106 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -1,6 +1,6 @@ # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). # -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ isconstant.set_const_prim(True) issubclass_ = P.IsSubClass() isinstance_ = P.IsInstance() +eye = P.Eye() fill = P.Fill() tile = P.Tile() select = P.Select() @@ -45,6 +46,7 @@ control_depend = P.ControlDepend() merge = P.Merge() geswitch = P.GeSwitch() addn = P.AddN() +absolute = P.Abs() tensor_add = P.TensorAdd() neg_tensor = P.Neg() tensor_lt = P.Less() @@ -67,6 +69,8 @@ assign_add = P.AssignAdd() assign = P.Assign() square = P.Square() sqrt = P.Sqrt() +reduce_sum = P.ReduceSum() +tensor_slice = P.Slice() scalar_to_array = P.ScalarToArray() scalar_to_tensor = P.ScalarToTensor() @@ -74,6 +78,8 @@ tuple_to_array = P.TupleToArray() scalar_cast = P.ScalarCast() print_ = P.Print() expand_dims = P.ExpandDims() +transpose = P.Transpose() +squeeze = P.Squeeze() scatter_nd = P.ScatterNd() gather = P.GatherV2() gather_nd = P.GatherNd() @@ -177,6 +183,7 @@ tensor_operator_registry.register('any', P.ReduceAny) tensor_operator_registry.register('abs', P.Abs) tensor_operator_registry.register('mean', P.ReduceMean) tensor_operator_registry.register('reshape', P.Reshape) +tensor_operator_registry.register('transpose', P.Transpose) tensor_operator_registry.register('broadcast_to', P.BroadcastTo) # ms cannot support Tensor(True) compare tensor_operator_registry.register('__eq__', equal) @@ -187,6 +194,7 @@ tensor_operator_registry.register('__le__', tensor_le) tensor_operator_registry.register('__gt__', tensor_gt) tensor_operator_registry.register('__ge__', tensor_ge) tensor_operator_registry.register('shape', shape) +tensor_operator_registry.register('squeeze', squeeze) # support GE backend for no compare operators tensor_operator_registry.register('cast', cast) diff --git a/tests/ut/python/numpy_native/__init__.py b/tests/st/numpy_native/__init__.py similarity index 94% rename from tests/ut/python/numpy_native/__init__.py rename to tests/st/numpy_native/__init__.py index 43327c33547..449438f4001 100644 --- a/tests/ut/python/numpy_native/__init__.py +++ b/tests/st/numpy_native/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/st/numpy_native/test_array_creations.py b/tests/st/numpy_native/test_array_creations.py new file mode 100644 index 00000000000..635b3d67f48 --- /dev/null +++ b/tests/st/numpy_native/test_array_creations.py @@ -0,0 +1,730 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""unit tests for numpy array operations""" + +import functools + +import pytest +import numpy as onp + +import mindspore.context as context +import mindspore.numpy as mnp + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class Cases(): + def __init__(self): + self.all_shapes = [ + 0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3] + ] + self.onp_dtypes = [onp.int32, 'int32', int, + onp.float32, 'float32', float, + onp.uint32, 'uint32', + onp.bool_, 'bool', bool] + + self.mnp_dtypes = [mnp.int32, 'int32', int, + mnp.float32, 'float32', float, + mnp.uint32, 'uint32', + mnp.bool_, 'bool', bool] + + self.array_sets = [1, 1.1, True, [1, 0, True], [1, 1.0, 2], (1,), + [(1, 2, 3), (4, 5, 6)], onp.random.random( # pylint: disable=no-member + (100, 100)).astype(onp.float32), + onp.random.random((100, 100)).astype(onp.bool)] + + self.arrs = [ + rand_int(2), + rand_int(2, 3), + rand_int(2, 3, 4), + rand_int(2, 3, 4, 5), + ] + + # scalars expanded across the 0th dimension + self.scalars = [ + rand_int(), + rand_int(1), + rand_int(1, 1), + rand_int(1, 1, 1), + ] + + # arrays of the same size expanded across the 0th dimension + self.expanded_arrs = [ + rand_int(2, 3), + rand_int(1, 2, 3), + rand_int(1, 1, 2, 3), + rand_int(1, 1, 1, 2, 3), + ] + + # arrays with dimensions of size 1 + self.nested_arrs = [ + rand_int(1), + rand_int(1, 2), + rand_int(3, 1, 8), + rand_int(1, 3, 9, 1), + ] + + # arrays which can be broadcast + self.broadcastables = [ + rand_int(5), + rand_int(6, 1), + rand_int(7, 1, 5), + rand_int(8, 1, 6, 1) + ] + + # boolean arrays which can be broadcast + self.bool_broadcastables = [ + rand_bool(), + rand_bool(1), + rand_bool(5), + rand_bool(6, 1), + rand_bool(7, 1, 5), + rand_bool(8, 1, 6, 1), + ] + + self.mnp_prototypes = [ + mnp.ones((2, 3, 4)), + mnp.ones((0, 3, 0, 2, 5)), + onp.ones((2, 7, 0)), + onp.ones(()), + [mnp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]], + ([(1, 2), mnp.ones(2)], (onp.ones(2), [3, 4])), + ] + + self.onp_prototypes = [ + onp.ones((2, 3, 4)), + onp.ones((0, 3, 0, 2, 5)), + onp.ones((2, 7, 0)), + onp.ones(()), + [onp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]], + ([(1, 2), onp.ones(2)], (onp.ones(2), [3, 4])), + ] + + +def match_array(actual, expected, error=0): + if error > 0: + onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), + decimal=error) + else: + onp.testing.assert_equal(actual.tolist(), expected.tolist()) + + +def check_all_results(onp_results, mnp_results, error=0): + """Check all results from numpy and mindspore.numpy""" + for i, _ in enumerate(onp_results): + match_array(onp_results[i], mnp_results[i].asnumpy()) + + +def check_all_unique_results(onp_results, mnp_results): + """ + Check all results from numpy and mindspore.numpy. + + Args: + onp_results (Union[tuple of numpy.arrays, numpy.array]) + mnp_results (Union[tuple of Tensors, Tensor]) + """ + for i, _ in enumerate(onp_results): + if isinstance(onp_results[i], tuple): + for j in range(len(onp_results[i])): + match_array(onp_results[i][j], + mnp_results[i][j].asnumpy(), error=7) + else: + match_array(onp_results[i], mnp_results[i].asnumpy(), error=7) + + +def run_non_kw_test(mnp_fn, onp_fn): + """Run tests on functions with non keyword arguments""" + test_case = Cases() + for i in range(len(test_case.arrs)): + arrs = test_case.arrs[:i] + match_res(mnp_fn, onp_fn, *arrs) + + for i in range(len(test_case.scalars)): + arrs = test_case.scalars[:i] + match_res(mnp_fn, onp_fn, *arrs) + + for i in range(len(test_case.expanded_arrs)): + arrs = test_case.expanded_arrs[:i] + match_res(mnp_fn, onp_fn, *arrs) + + for i in range(len(test_case.nested_arrs)): + arrs = test_case.nested_arrs[:i] + match_res(mnp_fn, onp_fn, *arrs) + + +def rand_int(*shape): + """return an random integer array with parameter shape""" + res = onp.random.randint(low=1, high=5, size=shape) + if isinstance(res, onp.ndarray): + return res.astype(onp.float32) + return float(res) + + +# return an random boolean array +def rand_bool(*shape): + return onp.random.rand(*shape) > 0.5 + + +def match_res(mnp_fn, onp_fn, *arrs, **kwargs): + """Checks results from applying mnp_fn and onp_fn on arrs respectively""" + mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs) + mnp_res = mnp_fn(*mnp_arrs, **kwargs) + onp_res = onp_fn(*arrs, **kwargs) + match_all_arrays(mnp_res, onp_res) + + +def match_all_arrays(mnp_res, onp_res, error=0): + if isinstance(mnp_res, (tuple, list)): + for actual, expected in zip(mnp_res, onp_res): + match_array(actual.asnumpy(), expected, error) + else: + match_array(mnp_res.asnumpy(), onp_res, error) + + +def match_meta(actual, expected): + # float64 and int64 are not supported, and the defualt type for + # float and int are float32 and int32, respectively + if expected.dtype == onp.float64: + expected = expected.astype(onp.float32) + elif expected.dtype == onp.int64: + expected = expected.astype(onp.int32) + assert actual.shape == expected.shape + assert actual.dtype == expected.dtype + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_asarray(): + test_case = Cases() + for array in test_case.array_sets: + # Check for dtype matching + actual = onp.asarray(array) + expected = mnp.asarray(array).asnumpy() + # Since we set float32/int32 as the default dtype in mindspore, we need + # to make a conversion between numpy.asarray and mindspore.numpy.asarray + if actual.dtype is onp.dtype('float64'): + assert expected.dtype == onp.dtype('float32') + elif actual.dtype is onp.dtype('int64'): + assert expected.dtype == onp.dtype('int32') + else: + assert actual.dtype == expected.dtype + match_array(actual, expected, error=7) + + for i in range(len(test_case.onp_dtypes)): + actual = onp.asarray(array, test_case.onp_dtypes[i]) + expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected, error=7) + + # Additional tests for nested tensor/numpy_array mixture + mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] + onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] + + actual = onp.asarray(onp_input) + expected = mnp.asarray(mnp_input).asnumpy() + match_array(actual, expected, error=7) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_array(): + # array's function is very similar to asarray, so we mainly test the + # `copy` argument. + test_case = Cases() + for array in test_case.array_sets: + arr1 = mnp.asarray(array) + arr2 = mnp.array(arr1, copy=False) + arr3 = mnp.array(arr1) + arr4 = mnp.asarray(array, dtype='int32') + arr5 = mnp.asarray(arr4, dtype=mnp.int32) + assert arr1 is arr2 + assert arr1 is not arr3 + assert arr4 is arr5 + + # Additional tests for nested tensor/numpy_array mixture + mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] + onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] + + actual = onp.asarray(onp_input) + expected = mnp.asarray(mnp_input).asnumpy() + match_array(actual, expected, error=7) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_asfarray(): + test_case = Cases() + for array in test_case.array_sets: + # Check for dtype matching + actual = onp.asfarray(array) + expected = mnp.asfarray(array).asnumpy() + # Since we set float32/int32 as the default dtype in mindspore, we need + # to make a conversion between numpy.asarray and mindspore.numpy.asarray + if actual.dtype is onp.dtype('float64'): + assert expected.dtype == onp.dtype('float32') + else: + assert actual.dtype == expected.dtype + match_array(actual, expected, error=7) + + for i in range(len(test_case.onp_dtypes)): + actual = onp.asfarray(array, test_case.onp_dtypes[i]) + expected = mnp.asfarray(array, test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected, error=7) + + # Additional tests for nested tensor/numpy_array mixture + mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] + onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] + + actual = onp.asarray(onp_input) + expected = mnp.asarray(mnp_input).asnumpy() + match_array(actual, expected, error=7) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_zeros(): + test_case = Cases() + for shape in test_case.all_shapes: + for i in range(len(test_case.onp_dtypes)): + actual = onp.zeros(shape, test_case.onp_dtypes[i]) + expected = mnp.zeros(shape, test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected) + actual = onp.zeros(shape) + expected = mnp.zeros(shape).asnumpy() + match_array(actual, expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ones(): + test_case = Cases() + for shape in test_case.all_shapes: + for i in range(len(test_case.onp_dtypes)): + actual = onp.ones(shape, test_case.onp_dtypes[i]) + expected = mnp.ones(shape, test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected) + actual = onp.ones(shape) + expected = mnp.ones(shape).asnumpy() + match_array(actual, expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_full(): + actual = onp.full((2, 2), [1, 2]) + expected = mnp.full((2, 2), [1, 2]).asnumpy() + match_array(actual, expected) + + actual = onp.full((2, 0), onp.inf) + expected = mnp.full((2, 0), mnp.inf).asnumpy() + match_array(actual, expected) + + actual = onp.full((2, 3), True) + expected = mnp.full((2, 3), True).asnumpy() + match_array(actual, expected) + + actual = onp.full((3, 4, 5), 7.5) + expected = mnp.full((3, 4, 5), 7.5).asnumpy() + match_array(actual, expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_eye(): + test_case = Cases() + for i in range(len(test_case.onp_dtypes)): + for m in range(1, 5): + actual = onp.eye(m, dtype=test_case.onp_dtypes[i]) + expected = mnp.eye(m, dtype=test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected) + for n in range(1, 5): + for k in range(0, 5): + actual = onp.eye(m, n, k, dtype=test_case.onp_dtypes[i]) + expected = mnp.eye( + m, n, k, dtype=test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_identity(): + test_case = Cases() + for i in range(len(test_case.onp_dtypes)): + for m in range(1, 5): + actual = onp.identity(m, dtype=test_case.onp_dtypes[i]) + expected = mnp.identity(m, dtype=test_case.mnp_dtypes[i]).asnumpy() + match_array(actual, expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_arange(): + actual = onp.arange(10) + expected = mnp.arange(10).asnumpy() + match_array(actual, expected) + + actual = onp.arange(0, 10) + expected = mnp.arange(0, 10).asnumpy() + match_array(actual, expected) + + actual = onp.arange(start=10) + expected = mnp.arange(start=10).asnumpy() + match_array(actual, expected) + + actual = onp.arange(start=10, step=0.1) + expected = mnp.arange(start=10, step=0.1).asnumpy() + match_array(actual, expected, error=6) + + actual = onp.arange(10, step=0.1) + expected = mnp.arange(10, step=0.1).asnumpy() + match_array(actual, expected, error=6) + + actual = onp.arange(0.1, 9.9) + expected = mnp.arange(0.1, 9.9).asnumpy() + match_array(actual, expected, error=6) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_linspace(): + actual = onp.linspace(2.0, 3.0, dtype=onp.float32) + expected = mnp.linspace(2.0, 3.0).asnumpy() + match_array(actual, expected, error=7) + + actual = onp.linspace(2.0, 3.0, num=5, dtype=onp.float32) + expected = mnp.linspace(2.0, 3.0, num=5).asnumpy() + match_array(actual, expected, error=7) + + actual = onp.linspace( + 2.0, 3.0, num=5, endpoint=False, dtype=onp.float32) + expected = mnp.linspace(2.0, 3.0, num=5, endpoint=False).asnumpy() + match_array(actual, expected, error=7) + + actual = onp.linspace(2.0, 3.0, num=5, retstep=True, dtype=onp.float32) + expected = mnp.linspace(2.0, 3.0, num=5, retstep=True) + match_array(actual[0], expected[0].asnumpy()) + assert actual[1] == expected[1] + + actual = onp.linspace(2.0, [3, 4, 5], num=5, + endpoint=False, dtype=onp.float32) + expected = mnp.linspace( + 2.0, [3, 4, 5], num=5, endpoint=False).asnumpy() + match_array(actual, expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_logspace(): + actual = onp.logspace(2.0, 3.0, dtype=onp.float32) + expected = mnp.logspace(2.0, 3.0).asnumpy() + match_array(actual, expected) + + actual = onp.logspace(2.0, 3.0, num=5, dtype=onp.float32) + expected = mnp.logspace(2.0, 3.0, num=5).asnumpy() + match_array(actual, expected) + + actual = onp.logspace( + 2.0, 3.0, num=5, endpoint=False, dtype=onp.float32) + expected = mnp.logspace(2.0, 3.0, num=5, endpoint=False).asnumpy() + match_array(actual, expected) + + actual = onp.logspace(2.0, [3, 4, 5], num=5, + endpoint=False, dtype=onp.float32) + expected = mnp.logspace( + 2.0, [3, 4, 5], num=5, endpoint=False).asnumpy() + match_array(actual, expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_empty(): + test_case = Cases() + for shape in test_case.all_shapes: + for mnp_dtype, onp_dtype in zip(test_case.mnp_dtypes, + test_case.onp_dtypes): + actual = mnp.empty(shape, mnp_dtype).asnumpy() + expected = onp.empty(shape, onp_dtype) + match_meta(actual, expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_empty_like(): + test_case = Cases() + for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes): + actual = mnp.empty_like(mnp_proto).asnumpy() + expected = onp.empty_like(onp_proto) + assert actual.shape == expected.shape + + for mnp_dtype, onp_dtype in zip(test_case.mnp_dtypes, + test_case.onp_dtypes): + actual = mnp.empty_like(mnp_proto, dtype=mnp_dtype).asnumpy() + expected = onp.empty_like(onp_proto, dtype=onp_dtype) + match_meta(actual, expected) + + +def run_x_like(mnp_fn, onp_fn): + test_case = Cases() + for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes): + actual = mnp_fn(mnp_proto).asnumpy() + expected = onp_fn(onp_proto) + match_array(actual, expected) + + for shape in test_case.all_shapes: + actual = mnp_fn(mnp_proto, shape=shape).asnumpy() + expected = onp_fn(onp_proto, shape=shape) + match_array(actual, expected) + + for mnp_dtype, onp_dtype in zip(test_case.mnp_dtypes, + test_case.onp_dtypes): + actual = mnp_fn(mnp_proto, dtype=mnp_dtype).asnumpy() + expected = onp_fn(onp_proto, dtype=onp_dtype) + match_array(actual, expected) + + actual = mnp_fn(mnp_proto, dtype=mnp_dtype, + shape=shape).asnumpy() + expected = onp_fn(onp_proto, dtype=onp_dtype, shape=shape) + match_array(actual, expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ones_like(): + run_x_like(mnp.ones_like, onp.ones_like) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_zeros_like(): + run_x_like(mnp.zeros_like, onp.zeros_like) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_full_like(): + test_case = Cases() + for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes): + shape = onp.zeros_like(onp_proto).shape + fill_value = rand_int() + actual = mnp.full_like(mnp_proto, fill_value).asnumpy() + expected = onp.full_like(onp_proto, fill_value) + match_array(actual, expected) + + for i in range(len(shape) - 1, 0, -1): + fill_value = rand_int(*shape[i:]) + actual = mnp.full_like(mnp_proto, fill_value).asnumpy() + expected = onp.full_like(onp_proto, fill_value) + match_array(actual, expected) + + fill_value = rand_int(1, *shape[i + 1:]) + actual = mnp.full_like(mnp_proto, fill_value).asnumpy() + expected = onp.full_like(onp_proto, fill_value) + match_array(actual, expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tri_triu_tril(): + x = mnp.ones((16, 32), dtype="bool") + match_array(mnp.tril(x).asnumpy(), onp.tril(x.asnumpy())) + match_array(mnp.tril(x, -1).asnumpy(), onp.tril(x.asnumpy(), -1)) + match_array(mnp.triu(x).asnumpy(), onp.triu(x.asnumpy())) + match_array(mnp.triu(x, -1).asnumpy(), onp.triu(x.asnumpy(), -1)) + + x = mnp.ones((64, 64), dtype="uint8") + match_array(mnp.tril(x).asnumpy(), onp.tril(x.asnumpy())) + match_array(mnp.tril(x, 25).asnumpy(), onp.tril(x.asnumpy(), 25)) + match_array(mnp.triu(x).asnumpy(), onp.triu(x.asnumpy())) + match_array(mnp.triu(x, 25).asnumpy(), onp.triu(x.asnumpy(), 25)) + + match_array(mnp.tri(64, 64).asnumpy(), onp.tri(64, 64)) + match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10)) + + +def mnp_diagonal(arr): + return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0) + + +def onp_diagonal(arr): + return onp.diagonal(arr, offset=2, axis1=-1, axis2=0) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_diagonal(): + arr = rand_int(0, 0) + match_res(mnp.diagonal, onp.diagonal, arr, offset=1) + + arr = rand_int(3, 5) + for i in [-10, -5, -1, 0, 2, 5, 6, 10]: + match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=1) + match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=1, axis2=0) + + arr = rand_int(7, 4, 9) + for i in [-10, -5, -1, 0, 2, 5, 6, 10]: + match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=-1) + match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=-2, axis2=2) + match_res(mnp.diagonal, onp.diagonal, arr, + offset=i, axis1=-1, axis2=-2) + + arr = rand_int(2, 5, 8, 1) + match_res(mnp_diagonal, onp_diagonal, arr) + for i in [-10, -5, -1, 0, 2, 5, 6, 10]: + match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=-3, axis2=2) + match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=1, axis2=3) + match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=0, axis2=-2) + match_res(mnp.diagonal, onp.diagonal, arr, offset=i, axis1=2, axis2=-1) + + +def mnp_trace(arr): + return mnp.trace(arr, offset=4, axis1=1, axis2=2) + + +def onp_trace(arr): + return onp.trace(arr, offset=4, axis1=1, axis2=2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_trace(): + arr = rand_int(0, 0) + match_res(mnp.trace, onp.trace, arr, offset=1) + + arr = rand_int(3, 5) + for i in [-10, -5, -1, 0, 2, 5, 6, 10]: + match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=1) + match_res(mnp.trace, onp.trace, arr, offset=i, axis1=1, axis2=0) + + arr = rand_int(7, 4, 9) + for i in [-10, -5, -1, 0, 2, 5, 6, 10]: + match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=-1) + match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-2, axis2=2) + match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-1, axis2=-2) + + arr = rand_int(2, 5, 8, 1) + match_res(mnp_trace, onp_trace, arr) + for i in [-10, -5, -1, 0, 2, 5, 6, 10]: + match_res(mnp.trace, onp.trace, arr, offset=i, axis1=-3, axis2=2) + match_res(mnp.trace, onp.trace, arr, offset=i, axis1=1, axis2=3) + match_res(mnp.trace, onp.trace, arr, offset=i, axis1=0, axis2=-2) + match_res(mnp.trace, onp.trace, arr, offset=i, axis1=2, axis2=-1) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_asarray_exception(): + with pytest.raises(TypeError): + mnp.asarray({1, 2, 3}) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_linspace_exception(): + with pytest.raises(TypeError): + mnp.linspace(0, 1, num=2.5) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_empty_like_exception(): + with pytest.raises(ValueError): + mnp.empty_like([[1, 2, 3], [4, 5]]) diff --git a/tests/st/numpy_native/test_array_ops.py b/tests/st/numpy_native/test_array_ops.py new file mode 100644 index 00000000000..48b3634b19a --- /dev/null +++ b/tests/st/numpy_native/test_array_ops.py @@ -0,0 +1,1012 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""unit tests for numpy array operations""" + +import functools + +import pytest +import numpy as onp + +import mindspore.numpy as mnp +from mindspore.nn import Cell + + +class Cases(): + def __init__(self): + self.all_shapes = [ + 0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3] + ] + self.onp_dtypes = [onp.int32, 'int32', int, + onp.float32, 'float32', float, + onp.uint32, 'uint32', + onp.bool_, 'bool', bool] + + self.mnp_dtypes = [mnp.int32, 'int32', int, + mnp.float32, 'float32', float, + mnp.uint32, 'uint32', + mnp.bool_, 'bool', bool] + + self.array_sets = [1, 1.1, True, [1, 0, True], [1, 1.0, 2], (1,), + [(1, 2, 3), (4, 5, 6)], onp.random.random( # pylint: disable=no-member + (100, 100)).astype(onp.float32), + onp.random.random((100, 100)).astype(onp.bool)] + + self.arrs = [ + rand_int(2), + rand_int(2, 3), + rand_int(2, 3, 4), + rand_int(2, 3, 4, 5), + ] + + # scalars expanded across the 0th dimension + self.scalars = [ + rand_int(), + rand_int(1), + rand_int(1, 1), + rand_int(1, 1, 1), + ] + + # arrays of the same size expanded across the 0th dimension + self.expanded_arrs = [ + rand_int(2, 3), + rand_int(1, 2, 3), + rand_int(1, 1, 2, 3), + rand_int(1, 1, 1, 2, 3), + ] + + # arrays with dimensions of size 1 + self.nested_arrs = [ + rand_int(1), + rand_int(1, 2), + rand_int(3, 1, 8), + rand_int(1, 3, 9, 1), + ] + + # arrays which can be broadcast + self.broadcastables = [ + rand_int(5), + rand_int(6, 1), + rand_int(7, 1, 5), + rand_int(8, 1, 6, 1) + ] + + # boolean arrays which can be broadcast + self.bool_broadcastables = [ + rand_bool(), + rand_bool(1), + rand_bool(5), + rand_bool(6, 1), + rand_bool(7, 1, 5), + rand_bool(8, 1, 6, 1), + ] + + self.mnp_prototypes = [ + mnp.ones((2, 3, 4)), + mnp.ones((0, 3, 0, 2, 5)), + onp.ones((2, 7, 0)), + onp.ones(()), + [mnp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]], + ([(1, 2), mnp.ones(2)], (onp.ones(2), [3, 4])), + ] + + self.onp_prototypes = [ + onp.ones((2, 3, 4)), + onp.ones((0, 3, 0, 2, 5)), + onp.ones((2, 7, 0)), + onp.ones(()), + [onp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]], + ([(1, 2), onp.ones(2)], (onp.ones(2), [3, 4])), + ] + + +def match_array(actual, expected, error=0): + if error > 0: + onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), + decimal=error) + else: + onp.testing.assert_equal(actual.tolist(), expected.tolist()) + + +def check_all_results(onp_results, mnp_results, error=0): + """Check all results from numpy and mindspore.numpy""" + for i, _ in enumerate(onp_results): + match_array(onp_results[i], mnp_results[i].asnumpy()) + + +def run_non_kw_test(mnp_fn, onp_fn): + """Run tests on functions with non keyword arguments""" + test_case = Cases() + for i in range(len(test_case.arrs)): + arrs = test_case.arrs[:i] + match_res(mnp_fn, onp_fn, *arrs) + + for i in range(len(test_case.scalars)): + arrs = test_case.scalars[:i] + match_res(mnp_fn, onp_fn, *arrs) + + for i in range(len(test_case.expanded_arrs)): + arrs = test_case.expanded_arrs[:i] + match_res(mnp_fn, onp_fn, *arrs) + + for i in range(len(test_case.nested_arrs)): + arrs = test_case.nested_arrs[:i] + match_res(mnp_fn, onp_fn, *arrs) + + +def rand_int(*shape): + """return an random integer array with parameter shape""" + res = onp.random.randint(low=1, high=5, size=shape) + if isinstance(res, onp.ndarray): + return res.astype(onp.float32) + return float(res) + + +# return an random boolean array +def rand_bool(*shape): + return onp.random.rand(*shape) > 0.5 + + +def match_res(mnp_fn, onp_fn, *arrs, **kwargs): + """Checks results from applying mnp_fn and onp_fn on arrs respectively""" + mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs) + mnp_res = mnp_fn(*mnp_arrs, **kwargs) + onp_res = onp_fn(*arrs, **kwargs) + match_all_arrays(mnp_res, onp_res) + + +def match_all_arrays(mnp_res, onp_res, error=0): + if isinstance(mnp_res, (tuple, list)): + for actual, expected in zip(mnp_res, onp_res): + match_array(actual.asnumpy(), expected, error) + else: + match_array(mnp_res.asnumpy(), onp_res, error) + + +def match_meta(actual, expected): + # float64 and int64 are not supported, and the defualt type for + # float and int are float32 and int32, respectively + if expected.dtype == onp.float64: + expected = expected.astype(onp.float32) + elif expected.dtype == onp.int64: + expected = expected.astype(onp.int32) + assert actual.shape == expected.shape + assert actual.dtype == expected.dtype + + +# Test np.transpose and np.ndarray.transpose +def mnp_transpose(input_tensor): + a = mnp.transpose(input_tensor, (0, 2, 1)) + b = mnp.transpose(input_tensor, [2, 1, 0]) + c = mnp.transpose(input_tensor, (1, 0, 2)) + d = mnp.transpose(input_tensor) + return a, b, c, d + + +def onp_transpose(input_array): + a = onp.transpose(input_array, (0, 2, 1)) + b = onp.transpose(input_array, [2, 1, 0]) + c = onp.transpose(input_array, (1, 0, 2)) + d = onp.transpose(input_array) + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_transpose(): + onp_array = onp.random.random((3, 4, 5)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_transposed = onp_transpose(onp_array) + m_transposed = mnp_transpose(mnp_array) + check_all_results(o_transposed, m_transposed) + + +# Test np.expand_dims +def mnp_expand_dims(input_tensor): + a = mnp.expand_dims(input_tensor, 0) + b = mnp.expand_dims(input_tensor, -1) + c = mnp.expand_dims(input_tensor, axis=2) + d = mnp.expand_dims(input_tensor, axis=-2) + return a, b, c, d + + +def onp_expand_dims(input_array): + a = onp.expand_dims(input_array, 0) + b = onp.expand_dims(input_array, -1) + c = onp.expand_dims(input_array, axis=2) + d = onp.expand_dims(input_array, axis=-2) + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_expand_dims(): + onp_array = onp.random.random((3, 4, 5)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_expanded = onp_expand_dims(onp_array) + m_expanded = mnp_expand_dims(mnp_array) + check_all_results(o_expanded, m_expanded) + + +# Test np.squeeze +def mnp_squeeze(input_tensor): + a = mnp.squeeze(input_tensor) + b = mnp.squeeze(input_tensor, 0) + c = mnp.squeeze(input_tensor, axis=None) + d = mnp.squeeze(input_tensor, axis=-3) + e = mnp.squeeze(input_tensor, (2,)) + f = mnp.squeeze(input_tensor, (0, 2)) + return a, b, c, d, e, f + + +def onp_squeeze(input_array): + a = onp.squeeze(input_array) + b = onp.squeeze(input_array, 0) + c = onp.squeeze(input_array, axis=None) + d = onp.squeeze(input_array, axis=-3) + e = onp.squeeze(input_array, (2,)) + f = onp.squeeze(input_array, (0, 2)) + return a, b, c, d, e, f + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_squeeze(): + onp_array = onp.random.random((1, 3, 1, 4, 2)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_squeezed = onp_squeeze(onp_array) + m_squeezed = mnp_squeeze(mnp_array) + check_all_results(o_squeezed, m_squeezed) + + onp_array = onp.random.random((1, 1, 1, 1, 1)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_squeezed = onp_squeeze(onp_array) + m_squeezed = mnp_squeeze(mnp_array) + check_all_results(o_squeezed, m_squeezed) + + +# Test np.rollaxis +def mnp_rollaxis(input_tensor): + a = mnp.rollaxis(input_tensor, 0, 1) + b = mnp.rollaxis(input_tensor, 0, 2) + c = mnp.rollaxis(input_tensor, 2, 1) + d = mnp.rollaxis(input_tensor, 2, 2) + e = mnp.rollaxis(input_tensor, 0) + f = mnp.rollaxis(input_tensor, 1) + return a, b, c, d, e, f + + +def onp_rollaxis(input_array): + a = onp.rollaxis(input_array, 0, 1) + b = onp.rollaxis(input_array, 0, 2) + c = onp.rollaxis(input_array, 2, 1) + d = onp.rollaxis(input_array, 2, 2) + e = onp.rollaxis(input_array, 0) + f = onp.rollaxis(input_array, 1) + return a, b, c, d, e, f + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_rollaxis(): + onp_array = onp.random.random((3, 4, 5)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_rolled = onp_rollaxis(onp_array) + m_rolled = mnp_rollaxis(mnp_array) + check_all_results(o_rolled, m_rolled) + + +# Test np.swapaxes +def mnp_swapaxes(input_tensor): + a = mnp.swapaxes(input_tensor, 0, 1) + b = mnp.swapaxes(input_tensor, 1, 0) + c = mnp.swapaxes(input_tensor, 1, 1) + d = mnp.swapaxes(input_tensor, 2, 1) + e = mnp.swapaxes(input_tensor, 1, 2) + f = mnp.swapaxes(input_tensor, 2, 2) + return a, b, c, d, e, f + + +def onp_swapaxes(input_array): + a = onp.swapaxes(input_array, 0, 1) + b = onp.swapaxes(input_array, 1, 0) + c = onp.swapaxes(input_array, 1, 1) + d = onp.swapaxes(input_array, 2, 1) + e = onp.swapaxes(input_array, 1, 2) + f = onp.swapaxes(input_array, 2, 2) + return a, b, c, d, e, f + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_swapaxes(): + onp_array = onp.random.random((3, 4, 5)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_swaped = onp_swapaxes(onp_array) + m_swaped = mnp_swapaxes(mnp_array) + check_all_results(o_swaped, m_swaped) + + +# Test np.reshape +def mnp_reshape(input_tensor): + a = mnp.reshape(input_tensor, (3, 8)) + b = mnp.reshape(input_tensor, [3, -1]) + c = mnp.reshape(input_tensor, (-1, 12)) + d = mnp.reshape(input_tensor, (-1,)) + e = mnp.reshape(input_tensor, 24) + f = mnp.reshape(input_tensor, [2, 4, -1]) + g = input_tensor.reshape(3, 8) + h = input_tensor.reshape(3, -1) + i = input_tensor.reshape([-1, 3]) + j = input_tensor.reshape(-1) + return a, b, c, d, e, f, g, h, i, j + + +def onp_reshape(input_array): + a = onp.reshape(input_array, (3, 8)) + b = onp.reshape(input_array, [3, -1]) + c = onp.reshape(input_array, (-1, 12)) + d = onp.reshape(input_array, (-1,)) + e = onp.reshape(input_array, 24) + f = onp.reshape(input_array, [2, 4, -1]) + g = input_array.reshape(3, 8) + h = input_array.reshape(3, -1) + i = input_array.reshape([-1, 3]) + j = input_array.reshape(-1) + return a, b, c, d, e, f, g, h, i, j + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_reshape(): + onp_array = onp.random.random((2, 3, 4)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_reshaped = onp_reshape(onp_array) + m_reshaped = mnp_reshape(mnp_array) + check_all_results(o_reshaped, m_reshaped) + + +# Test np.ravel +def mnp_ravel(input_tensor): + a = mnp.ravel(input_tensor) + return a + + +def onp_ravel(input_array): + a = onp.ravel(input_array) + return a + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ravel(): + onp_array = onp.random.random((2, 3, 4)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_ravel = onp_ravel(onp_array) + m_ravel = mnp_ravel(mnp_array).asnumpy() + match_array(o_ravel, m_ravel) + + +# Test np.concatenate +def mnp_concatenate(input_tensor): + a = mnp.concatenate(input_tensor, None) + b = mnp.concatenate(input_tensor, 0) + c = mnp.concatenate(input_tensor, 1) + d = mnp.concatenate(input_tensor, 2) + return a, b, c, d + + +def onp_concatenate(input_array): + a = onp.concatenate(input_array, None) + b = onp.concatenate(input_array, 0) + c = onp.concatenate(input_array, 1) + d = onp.concatenate(input_array, 2) + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_concatenate(): + onp_array = onp.random.random((5, 4, 3, 2)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_concatenate = onp_concatenate(onp_array) + m_concatenate = mnp_concatenate(mnp_array) + check_all_results(o_concatenate, m_concatenate) + + +def construct_arrays(n=1, ndim=1, axis=None, low=1, high=5): + onp_array_lst = [] + mnp_array_lst = [] + shape = onp.random.randint(low=low, high=high, size=ndim) + new_shape = [sh for sh in shape] + while n > 0: + n -= 1 + onp_array1 = onp.random.randint( + low=low, high=high, size=shape).astype(onp.float32) + onp_array_lst.append(onp_array1) + mnp_array_lst.append(mnp.asarray(onp_array1)) + if axis is not None and axis < ndim: + new_shape[axis] += onp.random.randint(2) + onp_array2 = onp.random.randint( + low=low, high=high, size=new_shape).astype(onp.float32) + onp_array_lst.append(onp_array2) + mnp_array_lst.append(mnp.asarray(onp_array2)) + return onp_array_lst, mnp_array_lst + +# Test np.xstack + + +def prepare_array_sequences(n_lst, ndim_lst, axis=None, low=1, high=5): + onp_seq_lst = [] + mnp_seq_lst = [] + for n in n_lst: + for ndim in ndim_lst: + onp_array_lst, mnp_array_lst = construct_arrays( + n=n, ndim=ndim, axis=axis, low=low, high=high) + onp_seq_lst.append(onp_array_lst) + mnp_seq_lst.append(mnp_array_lst) + return onp_seq_lst, mnp_seq_lst + + +def mnp_column_stack(input_tensor): + return mnp.column_stack(input_tensor) + + +def onp_column_stack(input_array): + return onp.column_stack(input_array) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_column_stack(): + onp_seq_lst, mnp_seq_lst = prepare_array_sequences( + n_lst=[1, 5], ndim_lst=[1, 2, 3, 4], axis=1) + for i, onp_seq in enumerate(onp_seq_lst): + onp_seq = onp_seq_lst[i] + mnp_seq = mnp_seq_lst[i] + o_column_stack = onp_column_stack(onp_seq) + m_column_stack = mnp_column_stack(mnp_seq) + check_all_results(o_column_stack, m_column_stack) + + +def mnp_hstack(input_tensor): + return mnp.hstack(input_tensor) + + +def onp_hstack(input_array): + return onp.hstack(input_array) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_hstack(): + onp_seq_lst0, mnp_seq_lst0 = prepare_array_sequences( + n_lst=[1, 5], ndim_lst=[2, 3, 4], axis=1) + onp_seq_lst1, mnp_seq_lst1 = prepare_array_sequences( + n_lst=[1, 5], ndim_lst=[1], axis=0) + onp_seq_lst = onp_seq_lst0 + onp_seq_lst1 + mnp_seq_lst = mnp_seq_lst0 + mnp_seq_lst1 + for i, onp_seq in enumerate(onp_seq_lst): + mnp_seq = mnp_seq_lst[i] + o_hstack = onp_hstack(onp_seq) + m_hstack = mnp_hstack(mnp_seq) + check_all_results(o_hstack, m_hstack) + + +def mnp_dstack(input_tensor): + return mnp.dstack(input_tensor) + + +def onp_dstack(input_array): + return onp.dstack(input_array) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_dstack(): + onp_seq_lst, mnp_seq_lst = prepare_array_sequences( + n_lst=[1, 5], ndim_lst=[1, 2, 3, 4], axis=2) + for i, onp_seq in enumerate(onp_seq_lst): + mnp_seq = mnp_seq_lst[i] + o_dstack = onp_dstack(onp_seq) + m_dstack = mnp_dstack(mnp_seq) + check_all_results(o_dstack, m_dstack) + + +def mnp_vstack(input_tensor): + return mnp.vstack(input_tensor) + + +def onp_vstack(input_array): + return onp.vstack(input_array) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_vstack(): + onp_seq_lst0, mnp_seq_lst0 = prepare_array_sequences( + n_lst=[1, 5], ndim_lst=[2, 3, 4], axis=0) + onp_seq_lst1, mnp_seq_lst1 = prepare_array_sequences( + n_lst=[1, 5], ndim_lst=[1]) + onp_seq_lst = onp_seq_lst0 + onp_seq_lst1 + mnp_seq_lst = mnp_seq_lst0 + mnp_seq_lst1 + for i, onp_seq in enumerate(onp_seq_lst): + mnp_seq = mnp_seq_lst[i] + o_vstack = onp_vstack(onp_seq) + m_vstack = mnp_vstack(mnp_seq) + check_all_results(o_vstack, m_vstack) +# Test np.atleastxd + + +def mnp_atleast1d(*arys): + return mnp.atleast_1d(*arys) + + +def onp_atleast1d(*arys): + return onp.atleast_1d(*arys) + + +def mnp_atleast2d(*arys): + return mnp.atleast_2d(*arys) + + +def onp_atleast2d(*arys): + return onp.atleast_2d(*arys) + + +def mnp_atleast3d(*arys): + return mnp.atleast_3d(*arys) + + +def onp_atleast3d(*arys): + return onp.atleast_3d(*arys) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_atleast1d(): + run_non_kw_test(mnp_atleast1d, onp_atleast1d) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_atleast2d(): + run_non_kw_test(mnp_atleast2d, onp_atleast2d) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_atleast3d(): + run_non_kw_test(mnp_atleast3d, onp_atleast3d) + + +# Test np.where +def mnp_where(condition, x, y): + return mnp.where(condition, x, y) + + +def onp_where(condition, x, y): + return onp.where(condition, x, y) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_where(): + test_case = Cases() + for condition1 in test_case.bool_broadcastables[:2]: + for x in test_case.broadcastables[:2]: + for y in test_case.broadcastables[:2]: + for condition2 in test_case.broadcastables[:2]: + match_res(mnp_where, onp_where, condition1, x, y) + match_res(mnp_where, onp_where, condition2, x, y) + + +# Test ndarray.flatten +def mnp_ndarray_flatten(input_tensor): + a = input_tensor.flatten() + b = input_tensor.flatten(order='F') + c = input_tensor.flatten(order='C') + return a, b, c + + +def onp_ndarray_flatten(input_array): + a = input_array.flatten() + b = input_array.flatten(order='F') + c = input_array.flatten(order='C') + return a, b, c + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ndarray_flatten(): + onp_array = onp.random.random((3, 4, 5)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_flatten = onp_ndarray_flatten(onp_array) + m_flatten = mnp_ndarray_flatten(mnp_array) + check_all_results(o_flatten, m_flatten) + + +# Test ndarray.transpose +def mnp_ndarray_transpose(input_tensor): + a = input_tensor.T + b = input_tensor.transpose() + c = input_tensor.transpose((0, 2, 1)) + d = input_tensor.transpose([0, 2, 1]) + return a, b, c, d + + +def onp_ndarray_transpose(input_array): + a = input_array.T + b = input_array.transpose() + c = input_array.transpose((0, 2, 1)) + d = input_array.transpose([0, 2, 1]) + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ndarray_transpose(): + onp_array = onp.random.random((3, 4, 5)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_transposed = onp_ndarray_transpose(onp_array) + m_transposed = mnp_ndarray_transpose(mnp_array) + check_all_results(o_transposed, m_transposed) + + +# Test ndarray.astype +def mnp_ndarray_astype(input_tensor): + a = input_tensor.astype("float16") + b = input_tensor.astype(onp.float64) + c = input_tensor.astype(mnp.bool_) + return a, b, c + + +def onp_ndarray_astype(input_array): + a = input_array.astype("float16") + b = input_array.astype(onp.float64) + c = input_array.astype(onp.bool_) + return a, b, c + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ndarray_astype(): + onp_array = onp.random.random((3, 4, 5)).astype('float32') + mnp_array = mnp.asarray(onp_array) + o_astype = onp_ndarray_astype(onp_array) + m_astype = mnp_ndarray_astype(mnp_array) + for arr1, arr2 in zip(o_astype, m_astype): + assert arr1.dtype == arr2.asnumpy().dtype + + +def onp_concatenate_type_promotion(onp_array1, onp_array2, onp_array3, onp_array4): + o_concatenate = onp.concatenate((onp_array1, + onp_array2, + onp_array3, + onp_array4), -1) + return o_concatenate + + +def mnp_concatenate_type_promotion(mnp_array1, mnp_array2, mnp_array3, mnp_array4): + m_concatenate = mnp.concatenate([mnp_array1, + mnp_array2, + mnp_array3, + mnp_array4], -1) + return m_concatenate + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_concatenate_type_promotion(): + onp_array = onp.random.random((5, 1)).astype('float32') + mnp_array = mnp.asarray(onp_array) + onp_array1 = onp_array.astype(onp.float16) + onp_array2 = onp_array.astype(onp.bool_) + onp_array3 = onp_array.astype(onp.float32) + onp_array4 = onp_array.astype(onp.int32) + + mnp_array1 = mnp_array.astype(onp.float16) + mnp_array2 = mnp_array.astype(onp.bool_) + mnp_array3 = mnp_array.astype(onp.float32) + mnp_array4 = mnp_array.astype(onp.int32) + o_concatenate = onp_concatenate_type_promotion( + onp_array1, onp_array2, onp_array3, onp_array4).astype('float32') + m_concatenate = mnp_concatenate_type_promotion( + mnp_array1, mnp_array2, mnp_array3, mnp_array4) + check_all_results(o_concatenate, m_concatenate, error=1e-7) + + +def mnp_stack(*arrs): + a = mnp.stack(arrs, axis=-4) + b = mnp.stack(arrs, axis=-3) + c = mnp.stack(arrs, axis=0) + d = mnp.stack(arrs, axis=3) + e = mnp.stack(arrs, axis=2) + return a, b, c, d, e + + +def onp_stack(*arrs): + a = onp.stack(arrs, axis=-4) + b = onp.stack(arrs, axis=-3) + c = onp.stack(arrs, axis=0) + d = onp.stack(arrs, axis=3) + e = onp.stack(arrs, axis=2) + return a, b, c, d, e + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_stack(): + arr = rand_int(3, 4, 5, 6) + match_res(mnp.stack, onp.stack, arr) + for i in range(-4, 4): + match_res(mnp.stack, onp.stack, arr, axis=i) + + arr = rand_int(7, 4, 0, 3) + match_res(mnp.stack, onp.stack, arr) + for i in range(-4, 4): + match_res(mnp.stack, onp.stack, arr, axis=i) + + arrs = [rand_int(3, 4, 5) for i in range(10)] + match_res(mnp.stack, onp.stack, arrs) + match_res(mnp.stack, onp.stack, tuple(arrs)) + match_res(mnp_stack, onp_stack, *arrs) + for i in range(-4, 4): + match_res(mnp.stack, onp.stack, arrs, axis=i) + + arrs = [rand_int(3, 0, 5, 8, 0) for i in range(5)] + match_res(mnp.stack, onp.stack, arrs) + match_res(mnp.stack, onp.stack, tuple(arrs)) + match_res(mnp_stack, onp_stack, *arrs) + for i in range(-6, 6): + match_res(mnp.stack, onp.stack, arrs, axis=i) + + +class ReshapeExpandSqueeze(Cell): + def __init__(self): + super(ReshapeExpandSqueeze, self).__init__() + + def construct(self, x): + x = mnp.expand_dims(x, 2) + x = mnp.reshape(x, (1, 2, 3, 4, 1, 1)) + x = mnp.squeeze(x) + return x + + +class TransposeConcatRavel(Cell): + def __init__(self): + super(TransposeConcatRavel, self).__init__() + + def construct(self, x1, x2, x3): + x1 = mnp.transpose(x1, [0, 2, 1]) + x2 = x2.transpose(0, 2, 1) + x = mnp.concatenate((x1, x2, x3), -1) + x = mnp.ravel(x) + return x + + +class RollSwap(Cell): + def __init__(self): + super(RollSwap, self).__init__() + + def construct(self, x): + x = mnp.rollaxis(x, 2) + x = mnp.swapaxes(x, 0, 1) + return x + + +test_case_array_ops = [ + ('ReshapeExpandSqueeze', { + 'block': ReshapeExpandSqueeze(), + 'desc_inputs': [mnp.ones((2, 3, 4))]}), + + ('TransposeConcatRavel', { + 'block': TransposeConcatRavel(), + 'desc_inputs': [mnp.ones((2, 3, 4)), + mnp.ones((2, 3, 4)), + mnp.ones((2, 4, 1))]}), + + ('RollSwap', { + 'block': RollSwap(), + 'desc_inputs': [mnp.ones((2, 3, 4))]}) +] + +test_case_lists = [test_case_array_ops] +test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists) +# use -k to select certain testcast +# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_expand_dims_exception(): + with pytest.raises(TypeError): + mnp.expand_dims(mnp.ones((3, 3)), 1.2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_swapaxes_exception(): + with pytest.raises(ValueError): + mnp.swapaxes(mnp.ones((3, 3)), 1, 10) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tensor_flatten(): + lst = [[1.0, 2.0], [3.0, 4.0]] + tensor_list = mnp.asarray(lst) + assert tensor_list.flatten().asnumpy().tolist() == [1.0, 2.0, 3.0, 4.0] + assert tensor_list.flatten(order='F').asnumpy().tolist() == [ + 1.0, 3.0, 2.0, 4.0] + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tensor_reshape(): + lst = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] + tensor_list = mnp.asarray(lst) + with pytest.raises(TypeError): + tensor_list = tensor_list.reshape({0, 1, 2}) + with pytest.raises(ValueError): + tensor_list = tensor_list.reshape(1, 2, 3) + assert tensor_list.reshape([-1, 4]).shape == (2, 4) + assert tensor_list.reshape(1, -1, 4).shape == (1, 2, 4) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tensor_squeeze(): + lst = [[[1.0], [2.0], [3.0]]] + tensor_list = mnp.asarray(lst) + with pytest.raises(TypeError): + tensor_list = tensor_list.squeeze(1.2) + with pytest.raises(ValueError): + tensor_list = tensor_list.squeeze(4) + assert tensor_list.squeeze().shape == (3,) + assert tensor_list.squeeze(axis=2).shape == (1, 3) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tensor_ravel(): + lst = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]] + tensor_list = mnp.asarray(lst) + assert tensor_list.ravel().shape == (8,) + assert tensor_list.ravel().asnumpy().tolist() == [ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tensor_swapaxes(): + lst = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + tensor_list = mnp.asarray(lst) + with pytest.raises(TypeError): + tensor_list = tensor_list.swapaxes(0, (1,)) + with pytest.raises(ValueError): + tensor_list = tensor_list.swapaxes(0, 3) + assert tensor_list.swapaxes(0, 1).shape == (3, 2) diff --git a/tests/st/numpy_native/test_math_ops.py b/tests/st/numpy_native/test_math_ops.py new file mode 100644 index 00000000000..9b430fcc924 --- /dev/null +++ b/tests/st/numpy_native/test_math_ops.py @@ -0,0 +1,567 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""unit tests for numpy math operations""" +from functools import partial + +import pytest +import numpy as onp + +import mindspore.numpy as mnp +from mindspore import context + + +def rand_int(*shape): + """return an random integer array with parameter shape""" + res = onp.random.randint(low=1, high=5, size=shape) + if isinstance(res, onp.ndarray): + return res.astype(onp.float32) + return float(res) + + +# return an random boolean array +def rand_bool(*shape): + return onp.random.rand(*shape) > 0.5 + + +class Cases(): + def __init__(self): + self.device_cpu = context.get_context('device_target') + + self.arrs = [ + rand_int(2), + rand_int(2, 3), + rand_int(2, 3, 4), + rand_int(2, 3, 4, 5), + ] + + # scalars expanded across the 0th dimension + self.scalars = [ + rand_int(), + rand_int(1), + rand_int(1, 1), + rand_int(1, 1, 1, 1), + ] + + # empty arrays + self.empty_arrs = [ + rand_int(0), + rand_int(4, 0), + rand_int(2, 0, 2), + rand_int(5, 0, 7, 0), + ] + + # arrays of the same size expanded across the 0th dimension + self.expanded_arrs = [ + rand_int(2, 3), + rand_int(1, 2, 3), + rand_int(1, 1, 2, 3), + rand_int(1, 1, 1, 2, 3), + ] + + # arrays of the same size expanded across the 0th dimension + self.expanded_arrs = [ + rand_int(2, 3), + rand_int(1, 2, 3), + rand_int(1, 1, 2, 3), + rand_int(1, 1, 1, 2, 3), + ] + + # arrays with last dimension aligned + self.aligned_arrs = [ + rand_int(2, 3), + rand_int(1, 4, 3), + rand_int(5, 1, 2, 3), + rand_int(4, 2, 1, 1, 3), + ] + + # arrays which can be broadcast + self.broadcastables = [ + rand_int(5), + rand_int(6, 1), + rand_int(7, 1, 5), + rand_int(8, 1, 6, 1) + ] + + # boolean arrays which can be broadcast + self.bool_broadcastables = [ + rand_bool(), + rand_bool(1), + rand_bool(5), + rand_bool(6, 1), + rand_bool(7, 1, 5), + rand_bool(8, 1, 6, 1), + ] + + # core dimension 0 is matched for each + # pair of array[i] and array[i + 1] + self.core_broadcastables = [ + rand_int(3), + rand_int(3), + rand_int(6), + rand_int(6, 4), + rand_int(5, 2), + rand_int(2), + rand_int(2, 9), + rand_int(9, 8), + rand_int(6), + rand_int(2, 6, 5), + rand_int(9, 2, 7), + rand_int(7), + rand_int(5, 2, 4), + rand_int(6, 1, 4, 9), + rand_int(7, 1, 5, 3, 2), + rand_int(8, 1, 6, 1, 2, 9), + ] + + # arrays with dimensions of size 1 + self.nested_arrs = [ + rand_int(1), + rand_int(1, 2), + rand_int(3, 1, 8), + rand_int(1, 3, 9, 1), + ] + + +test_case = Cases() +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +def mnp_add(x1, x2): + return mnp.add(x1, x2) + + +def onp_add(x1, x2): + return onp.add(x1, x2) + + +def mnp_subtract(x1, x2): + return mnp.subtract(x1, x2) + + +def onp_subtract(x1, x2): + return onp.subtract(x1, x2) + + +def mnp_mutiply(x1, x2): + return mnp.multiply(x1, x2) + + +def onp_multiply(x1, x2): + return onp.multiply(x1, x2) + + +def mnp_divide(x1, x2): + return mnp.divide(x1, x2) + + +def onp_divide(x1, x2): + return onp.divide(x1, x2) + + +def mnp_power(x1, x2): + return mnp.power(x1, x2) + + +def onp_power(x1, x2): + return onp.power(x1, x2) + + +def mnp_inner(a, b): + return mnp.inner(a, b) + + +def onp_inner(a, b): + return onp.inner(a, b) + + +def mnp_dot(a, b): + return mnp.dot(a, b) + + +def onp_dot(a, b): + return onp.dot(a, b) + + +def mnp_outer(a, b): + return mnp.outer(a, b) + + +def onp_outer(a, b): + return onp.outer(a, b) + + +def mnp_add_kwargs(x, y, where=None, out=None): + return mnp.add(x, y, where=where, out=out) + + +def onp_add_kwargs(x, y, where=None, out=None): + return onp.add(x, y, where=where, out=out) + + +def mnp_subtract_kwargs(x, y, where=None, out=None): + return mnp.subtract(x, y, where=where, out=out) + + +def onp_subtract_kwargs(x, y, where=None, out=None): + return onp.subtract(x, y, where=where, out=out) + + +def mnp_multiply_kwargs(x, y, where=None, out=None): + return mnp.multiply(x, y, where=where, out=out) + + +def onp_multiply_kwargs(x, y, where=None, out=None): + return onp.multiply(x, y, where=where, out=out) + + +def mnp_divide_kwargs(x, y, where=None, out=None): + return mnp.divide(x, y, where=where, out=out) + + +def onp_divide_kwargs(x, y, where=None, out=None): + return onp.divide(x, y, where=where, out=out) + + +def mnp_power_kwargs(x, y, where=None, out=None): + return mnp.power(x, y, where=where, out=out) + + +def onp_power_kwargs(x, y, where=None, out=None): + return onp.power(x, y, where=where, out=out) + + +def mnp_tensordot(x, y): + a = mnp.tensordot(x, y) + b = mnp.tensordot(x, y, axes=0) + c = mnp.tensordot(x, y, axes=1) + d = mnp.tensordot(x, y, axes=2) + e = mnp.tensordot(x, y, axes=(3, 0)) + f = mnp.tensordot(x, y, axes=[2, 1]) + g = mnp.tensordot(x, y, axes=((2, 3), (0, 1))) + h = mnp.tensordot(x, y, axes=[[3, 2], [1, 0]]) + return a, b, c, d, e, f, g, h + + +def onp_tensordot(x, y): + a = onp.tensordot(x, y) + b = onp.tensordot(x, y, axes=0) + c = onp.tensordot(x, y, axes=1) + d = onp.tensordot(x, y, axes=2) + e = onp.tensordot(x, y, axes=(3, 0)) + f = onp.tensordot(x, y, axes=[2, 1]) + g = onp.tensordot(x, y, axes=((2, 3), (0, 1))) + h = onp.tensordot(x, y, axes=[[3, 2], [1, 0]]) + return a, b, c, d, e, f, g, h + + +def run_binop_test(mnp_fn, onp_fn): + for arr in test_case.arrs: + match_res(mnp_fn, onp_fn, arr, arr) + + for scalar in test_case.scalars: + match_res(mnp_fn, onp_fn, arr, scalar) + match_res(mnp_fn, onp_fn, scalar, arr) + + for scalar1 in test_case.scalars: + for scalar2 in test_case.scalars: + match_res(mnp_fn, onp_fn, scalar1, scalar2) + + for expanded_arr1 in test_case.expanded_arrs: + for expanded_arr2 in test_case.expanded_arrs: + match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2) + + for broadcastable1 in test_case.broadcastables: + for broadcastable2 in test_case.broadcastables: + match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2) + + +def run_multi_test(mnp_fn, onp_fn, arrs): + mnp_arrs = map(mnp.asarray, arrs) + for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)): + match_array(actual.asnumpy(), expected) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_add(): + run_binop_test(mnp_add, onp_add) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_subtract(): + run_binop_test(mnp_subtract, onp_subtract) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_multiply(): + run_binop_test(mnp_mutiply, onp_multiply) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_divide(): + run_binop_test(mnp_divide, onp_divide) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_power(): + run_binop_test(mnp_power, onp_power) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_inner(): + for arr1 in test_case.aligned_arrs: + for arr2 in test_case.aligned_arrs: + match_res(mnp_inner, onp_inner, arr1, arr2) + + for scalar1 in test_case.scalars: + for scalar2 in test_case.scalars: + match_res(mnp_inner, onp_inner, + scalar1, scalar2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_dot(): + # test case (1D, 1D) + match_res(mnp_dot, onp_dot, rand_int(3), rand_int(3)) + + # test case (2D, 2D) + match_res(mnp_dot, onp_dot, rand_int(4, 7), rand_int(7, 2)) + + # test case (0D, _) (_, 0D) + match_res(mnp_dot, onp_dot, rand_int(), rand_int(1, 9, 3)) + match_res(mnp_dot, onp_dot, rand_int(8, 5, 6, 3), rand_int()) + + # test case (ND, 1D) + match_res(mnp_dot, onp_dot, rand_int(2, 4, 5), rand_int(5)) + + # test case (ND, MD) + match_res(mnp_dot, onp_dot, rand_int(5, 4, 1, 8), rand_int(8, 3)) + + for i in range(8): + match_res(mnp_dot, onp_dot, + test_case.core_broadcastables[2*i], test_case.core_broadcastables[2*i + 1]) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_outer(): + run_binop_test(mnp_outer, onp_outer) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_add_kwargs(): + for where in test_case.bool_broadcastables[:2]: + for x in test_case.broadcastables[:2]: + for y in test_case.broadcastables[:2]: + shape_out = onp.broadcast(where, x, y).shape + out = rand_int(*shape_out) + match_res(mnp_add_kwargs, onp_add_kwargs, x, y, where, out) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tensordot(): + x = rand_int(4, 2, 7, 7) + y = rand_int(7, 7, 6) + run_multi_test(mnp_tensordot, onp_tensordot, (x, y)) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_type_promotion(): + arr = rand_int(2, 3) + onp_sum = onp_add(arr, arr) + + a = mnp.asarray(arr, dtype='float16') + b = mnp.asarray(arr, dtype='float32') + c = mnp.asarray(arr, dtype='int32') + + match_array(mnp_add(a, b).asnumpy(), onp_sum) + match_array(mnp_add(b, c).asnumpy(), onp_sum) + + +def mnp_absolute(x): + return mnp.absolute(x) + + +def onp_absolute(x): + return onp.absolute(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_absolute(): + arr = rand_int(2, 3) + + a = mnp.asarray(arr, dtype='float16') + b = mnp.asarray(arr, dtype='float32') + c = mnp.asarray(arr, dtype='uint8') + d = mnp.asarray(arr, dtype='bool') + + match_array(mnp_absolute(a).asnumpy(), onp_absolute(a.asnumpy())) + match_array(mnp_absolute(b).asnumpy(), onp_absolute(b.asnumpy())) + match_array(mnp_absolute(c).asnumpy(), onp_absolute(c.asnumpy())) + match_array(mnp_absolute(d).asnumpy(), onp_absolute(d.asnumpy())) + + where = rand_int(2, 3).astype('bool') + out = rand_int(2, 3) + match_array(mnp.absolute(a, out=mnp.asarray(out), where=mnp.asarray(where)).asnumpy(), + onp.absolute(a.asnumpy(), out=out, where=where)) + + +def mnp_add_dtype(x1, x2, out, where): + a = mnp.add(x1, x2, dtype=mnp.float16) + b = mnp.add(x1, x2, out=out, dtype=mnp.float16) + c = mnp.add(x1, x2, where=where, dtype=mnp.float16) + d = mnp.add(x1, x2, out=out, where=where, dtype=mnp.float16) + return a, b, c, d + + +def onp_add_dtype(x1, x2, out, where): + a = onp.add(x1, x2, dtype=onp.float16) + b = onp.add(x1, x2, out=out, dtype=onp.float16) + c = onp.add(x1, x2, where=where, dtype=onp.float16) + d = onp.add(x1, x2, out=out, where=where, dtype=onp.float16) + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_add_dtype(): + x1 = rand_int(2, 3).astype('int32') + x2 = rand_int(2, 3).astype('int32') + out = rand_int(2, 3).astype('float32') + where = rand_bool(2, 3) + arrs = (x1, x2, out, where) + mnp_arrs = map(mnp.array, arrs) + mnp_res = mnp_add_dtype(*mnp_arrs) + onp_res = onp_add_dtype(*arrs) + for actual, expected in zip(mnp_res, onp_res): + assert actual.asnumpy().dtype == expected.dtype + + +# check if the output from mnp function and onp function applied on the arrays are matched + + +def match_res(mnp_fn, onp_fn, *arrs): + mnp_arrs = map(partial(mnp.asarray, dtype='float32'), arrs) + mnp_res = mnp_fn(*mnp_arrs) + onp_res = onp_fn(*arrs) + if isinstance(mnp_res, (tuple, list)): + for actual, expected in zip(mnp_res, onp_res): + match_array(actual.asnumpy(), expected) + else: + match_array(mnp_res.asnumpy(), onp_res) + + +def match_array(actual, expected, error=5): + if error > 0: + onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), + decimal=error) + else: + onp.testing.assert_equal(actual.tolist(), expected.tolist()) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_exception_innner(): + with pytest.raises(ValueError): + mnp.inner(mnp.asarray(test_case.arrs[0]), + mnp.asarray(test_case.arrs[1])) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_exception_add(): + with pytest.raises(ValueError): + mnp.add(mnp.asarray(test_case.arrs[1]), mnp.asarray(test_case.arrs[2])) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_exception_mean(): + with pytest.raises(ValueError): + mnp.mean(mnp.asarray(test_case.arrs[0]), (-1, 0)) diff --git a/tests/ut/python/ir/test_tensor_py.py b/tests/ut/python/ir/test_tensor_py.py index 138c52ce872..5db563cffec 100644 --- a/tests/ut/python/ir/test_tensor_py.py +++ b/tests/ut/python/ir/test_tensor_py.py @@ -69,6 +69,24 @@ def test_tensor_size(): assert arr.size == b.size +def test_tensor_itemsize(): + arr = np.ones((1, 2, 3)) + b = ms.Tensor(arr) + assert arr.itemsize == b.itemsize + + +def test_tensor_strides(): + arr = np.ones((3, 4, 5, 6)) + b = ms.Tensor(arr) + assert arr.strides == b.strides + + +def test_tensor_nbytes(): + arr = np.ones((3, 4, 5, 6)) + b = ms.Tensor(arr) + assert arr.nbytes == b.nbytes + + def test_dtype(): a = ms.Tensor(np.ones((2, 3), dtype=np.int32)) assert a.dtype == ms.int32 diff --git a/tests/ut/python/numpy_native/test_array_ops.py b/tests/ut/python/numpy_native/test_array_ops.py deleted file mode 100644 index 51613b9ee69..00000000000 --- a/tests/ut/python/numpy_native/test_array_ops.py +++ /dev/null @@ -1,591 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""unit tests for numpy array operations""" - -import functools - -import pytest -import numpy as onp - -import mindspore.context as context -import mindspore.numpy as mnp -from mindspore.nn import Cell - -from ..ut_filter import non_graph_engine -from ....mindspore_test_framework.mindspore_test import mindspore_test -from ....mindspore_test_framework.pipeline.forward.compile_forward \ - import pipeline_for_compile_forward_ge_graph_for_case_by_case_config - - -class Cases(): - def __init__(self): - self.all_shapes = [ - 0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3] - ] - self.onp_dtypes = [onp.int32, 'int32', int, - onp.float32, 'float32', float, - onp.uint32, 'uint32', - onp.bool_, 'bool', bool] - - self.mnp_dtypes = [mnp.int32, 'int32', int, - mnp.float32, 'float32', float, - mnp.uint32, 'uint32', - mnp.bool_, 'bool', bool] - - self.array_sets = [1, 1.1, True, [1, 0, True], [1, 1.0, 2], (1,), - [(1, 2, 3), (4, 5, 6)], onp.random.random( - (100, 100)).astype(onp.float32), - onp.random.random((100, 100)).astype(onp.bool)] - - -def match_array(actual, expected, error=0): - if error > 0: - onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), - decimal=error) - else: - onp.testing.assert_equal(actual.tolist(), expected.tolist()) - - -def check_all_results(onp_results, mnp_results): - """Check all results from numpy and mindspore.numpy""" - for i, _ in enumerate(onp_results): - match_array(onp_results[i], mnp_results[i].asnumpy()) - - -def test_asarray(): - test_case = Cases() - for array in test_case.array_sets: - # Check for dtype matching - actual = onp.asarray(array) - expected = mnp.asarray(array).asnumpy() - # Since we set float32/int32 as the default dtype in mindspore, we need - # to make a conversion between numpy.asarray and mindspore.numpy.asarray - if actual.dtype is onp.dtype('float64'): - assert expected.dtype == onp.dtype('float32') - elif actual.dtype is onp.dtype('int64'): - assert expected.dtype == onp.dtype('int32') - else: - assert actual.dtype == expected.dtype - match_array(actual, expected, error=7) - - for i in range(len(test_case.onp_dtypes)): - actual = onp.asarray(array, test_case.onp_dtypes[i]) - expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy() - match_array(actual, expected, error=7) - - # Additional tests for nested tensor/numpy_array mixture - mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] - onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] - - actual = onp.asarray(onp_input) - expected = mnp.asarray(mnp_input).asnumpy() - match_array(actual, expected, error=7) - - -def test_array(): - # array's function is very similar to asarray, so we mainly test the - # `copy` argument. - test_case = Cases() - for array in test_case.array_sets: - arr1 = mnp.asarray(array) - arr2 = mnp.array(arr1, copy=False) - arr3 = mnp.array(arr1) - arr4 = mnp.asarray(array, dtype='int32') - arr5 = mnp.asarray(arr4, dtype=mnp.int32) - assert arr1 is arr2 - assert arr1 is not arr3 - assert arr4 is arr5 - - # Additional tests for nested tensor/numpy_array mixture - mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] - onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] - - actual = onp.asarray(onp_input) - expected = mnp.asarray(mnp_input).asnumpy() - match_array(actual, expected, error=7) - - -def test_asfarray(): - test_case = Cases() - for array in test_case.array_sets: - # Check for dtype matching - actual = onp.asfarray(array) - expected = mnp.asfarray(array).asnumpy() - # Since we set float32/int32 as the default dtype in mindspore, we need - # to make a conversion between numpy.asarray and mindspore.numpy.asarray - if actual.dtype is onp.dtype('float64'): - assert expected.dtype == onp.dtype('float32') - else: - assert actual.dtype == expected.dtype - match_array(actual, expected, error=7) - - for i in range(len(test_case.onp_dtypes)): - actual = onp.asfarray(array, test_case.onp_dtypes[i]) - expected = mnp.asfarray(array, test_case.mnp_dtypes[i]).asnumpy() - match_array(actual, expected, error=7) - - # Additional tests for nested tensor/numpy_array mixture - mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] - onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] - - actual = onp.asarray(onp_input) - expected = mnp.asarray(mnp_input).asnumpy() - match_array(actual, expected, error=7) - - -def test_zeros(): - test_case = Cases() - for shape in test_case.all_shapes: - for i in range(len(test_case.onp_dtypes)): - actual = onp.zeros(shape, test_case.onp_dtypes[i]) - expected = mnp.zeros(shape, test_case.mnp_dtypes[i]).asnumpy() - match_array(actual, expected) - actual = onp.zeros(shape) - expected = mnp.zeros(shape).asnumpy() - match_array(actual, expected) - - -def test_ones(): - test_case = Cases() - for shape in test_case.all_shapes: - for i in range(len(test_case.onp_dtypes)): - actual = onp.ones(shape, test_case.onp_dtypes[i]) - expected = mnp.ones(shape, test_case.mnp_dtypes[i]).asnumpy() - match_array(actual, expected) - actual = onp.ones(shape) - expected = mnp.ones(shape).asnumpy() - match_array(actual, expected) - - -def test_full(): - actual = onp.full((2, 2), [1, 2]) - expected = mnp.full((2, 2), [1, 2]).asnumpy() - match_array(actual, expected) - - actual = onp.full((2, 0), onp.inf) - expected = mnp.full((2, 0), mnp.inf).asnumpy() - match_array(actual, expected) - - actual = onp.full((2, 3), True) - expected = mnp.full((2, 3), True).asnumpy() - match_array(actual, expected) - - actual = onp.full((3, 4, 5), 7.5) - expected = mnp.full((3, 4, 5), 7.5).asnumpy() - match_array(actual, expected) - - -def test_eye(): - test_case = Cases() - for i in range(len(test_case.onp_dtypes)): - for m in range(1, 5): - actual = onp.eye(m, dtype=test_case.onp_dtypes[i]) - expected = mnp.eye(m, dtype=test_case.mnp_dtypes[i]).asnumpy() - match_array(actual, expected) - for n in range(1, 5): - for k in range(0, 5): - actual = onp.eye(m, n, k, dtype=test_case.onp_dtypes[i]) - expected = mnp.eye( - m, n, k, dtype=test_case.mnp_dtypes[i]).asnumpy() - match_array(actual, expected) - - -def test_identity(): - test_case = Cases() - for i in range(len(test_case.onp_dtypes)): - for m in range(1, 5): - actual = onp.identity(m, dtype=test_case.onp_dtypes[i]) - expected = mnp.identity(m, dtype=test_case.mnp_dtypes[i]).asnumpy() - match_array(actual, expected) - - -def test_arange(): - actual = onp.arange(10) - expected = mnp.arange(10).asnumpy() - match_array(actual, expected) - - actual = onp.arange(0, 10) - expected = mnp.arange(0, 10).asnumpy() - match_array(actual, expected) - - actual = onp.arange(start=10) - expected = mnp.arange(start=10).asnumpy() - match_array(actual, expected) - - actual = onp.arange(start=10, step=0.1) - expected = mnp.arange(start=10, step=0.1).asnumpy() - match_array(actual, expected, error=6) - - actual = onp.arange(10, step=0.1) - expected = mnp.arange(10, step=0.1).asnumpy() - match_array(actual, expected, error=6) - - actual = onp.arange(0.1, 9.9) - expected = mnp.arange(0.1, 9.9).asnumpy() - match_array(actual, expected, error=6) - - -def test_linspace(): - actual = onp.linspace(2.0, 3.0, dtype=onp.float32) - expected = mnp.linspace(2.0, 3.0).asnumpy() - match_array(actual, expected, error=7) - - actual = onp.linspace(2.0, 3.0, num=5, dtype=onp.float32) - expected = mnp.linspace(2.0, 3.0, num=5).asnumpy() - match_array(actual, expected, error=7) - - actual = onp.linspace( - 2.0, 3.0, num=5, endpoint=False, dtype=onp.float32) - expected = mnp.linspace(2.0, 3.0, num=5, endpoint=False).asnumpy() - match_array(actual, expected, error=7) - - actual = onp.linspace(2.0, 3.0, num=5, retstep=True, dtype=onp.float32) - expected = mnp.linspace(2.0, 3.0, num=5, retstep=True) - match_array(actual[0], expected[0].asnumpy()) - assert actual[1] == expected[1] - - actual = onp.linspace(2.0, [3, 4, 5], num=5, - endpoint=False, dtype=onp.float32) - expected = mnp.linspace( - 2.0, [3, 4, 5], num=5, endpoint=False).asnumpy() - match_array(actual, expected) - - -def test_logspace(): - actual = onp.logspace(2.0, 3.0, dtype=onp.float32) - expected = mnp.logspace(2.0, 3.0).asnumpy() - match_array(actual, expected) - - actual = onp.logspace(2.0, 3.0, num=5, dtype=onp.float32) - expected = mnp.logspace(2.0, 3.0, num=5).asnumpy() - match_array(actual, expected) - - actual = onp.logspace( - 2.0, 3.0, num=5, endpoint=False, dtype=onp.float32) - expected = mnp.logspace(2.0, 3.0, num=5, endpoint=False).asnumpy() - match_array(actual, expected) - - actual = onp.logspace(2.0, [3, 4, 5], num=5, - endpoint=False, dtype=onp.float32) - expected = mnp.logspace( - 2.0, [3, 4, 5], num=5, endpoint=False).asnumpy() - match_array(actual, expected) - - -# Test np.transpose and np.ndarray.transpose - - -def mnp_transpose(input_tensor): - a = mnp.transpose(input_tensor, (0, 2, 1)) - b = mnp.transpose(input_tensor, [2, 1, 0]) - c = mnp.transpose(input_tensor, (1, 0, 2)) - d = mnp.transpose(input_tensor) - return a, b, c, d - - -def onp_transpose(input_array): - a = onp.transpose(input_array, (0, 2, 1)) - b = onp.transpose(input_array, [2, 1, 0]) - c = onp.transpose(input_array, (1, 0, 2)) - d = onp.transpose(input_array) - return a, b, c, d - -# Test np.expand_dims - - -def mnp_expand_dims(input_tensor): - a = mnp.expand_dims(input_tensor, 0) - b = mnp.expand_dims(input_tensor, -1) - c = mnp.expand_dims(input_tensor, axis=2) - d = mnp.expand_dims(input_tensor, axis=-2) - return a, b, c, d - - -def onp_expand_dims(input_array): - a = onp.expand_dims(input_array, 0) - b = onp.expand_dims(input_array, -1) - c = onp.expand_dims(input_array, axis=2) - d = onp.expand_dims(input_array, axis=-2) - return a, b, c, d - -# Test np.squeeze - - -def mnp_squeeze(input_tensor): - a = mnp.squeeze(input_tensor) - b = mnp.squeeze(input_tensor, 0) - c = mnp.squeeze(input_tensor, axis=None) - d = mnp.squeeze(input_tensor, axis=-3) - e = mnp.squeeze(input_tensor, (2,)) - f = mnp.squeeze(input_tensor, (0, 2)) - return a, b, c, d, e, f - - -def onp_squeeze(input_array): - a = onp.squeeze(input_array) - b = onp.squeeze(input_array, 0) - c = onp.squeeze(input_array, axis=None) - d = onp.squeeze(input_array, axis=-3) - e = onp.squeeze(input_array, (2,)) - f = onp.squeeze(input_array, (0, 2)) - return a, b, c, d, e, f - -# Test np.rollaxis - - -def mnp_rollaxis(input_tensor): - a = mnp.rollaxis(input_tensor, 0, 1) - b = mnp.rollaxis(input_tensor, 0, 2) - c = mnp.rollaxis(input_tensor, 2, 1) - d = mnp.rollaxis(input_tensor, 2, 2) - e = mnp.rollaxis(input_tensor, 0) - f = mnp.rollaxis(input_tensor, 1) - return a, b, c, d, e, f - - -def onp_rollaxis(input_array): - a = onp.rollaxis(input_array, 0, 1) - b = onp.rollaxis(input_array, 0, 2) - c = onp.rollaxis(input_array, 2, 1) - d = onp.rollaxis(input_array, 2, 2) - e = onp.rollaxis(input_array, 0) - f = onp.rollaxis(input_array, 1) - return a, b, c, d, e, f - -# Test np.swapaxes - - -def mnp_swapaxes(input_tensor): - a = mnp.swapaxes(input_tensor, 0, 1) - b = mnp.swapaxes(input_tensor, 1, 0) - c = mnp.swapaxes(input_tensor, 1, 1) - d = mnp.swapaxes(input_tensor, 2, 1) - e = mnp.swapaxes(input_tensor, 1, 2) - f = mnp.swapaxes(input_tensor, 2, 2) - return a, b, c, d, e, f - - -def onp_swapaxes(input_array): - a = onp.swapaxes(input_array, 0, 1) - b = onp.swapaxes(input_array, 1, 0) - c = onp.swapaxes(input_array, 1, 1) - d = onp.swapaxes(input_array, 2, 1) - e = onp.swapaxes(input_array, 1, 2) - f = onp.swapaxes(input_array, 2, 2) - return a, b, c, d, e, f - -# Test np.reshape - - -def mnp_reshape(input_tensor): - a = mnp.reshape(input_tensor, (3, 8)) - b = mnp.reshape(input_tensor, [3, -1]) - c = mnp.reshape(input_tensor, (-1, 12)) - d = mnp.reshape(input_tensor, (-1,)) - e = mnp.reshape(input_tensor, 24) - f = mnp.reshape(input_tensor, [2, 4, -1]) - return a, b, c, d, e, f - - -def onp_reshape(input_array): - a = onp.reshape(input_array, (3, 8)) - b = onp.reshape(input_array, [3, -1]) - c = onp.reshape(input_array, (-1, 12)) - d = onp.reshape(input_array, (-1,)) - e = onp.reshape(input_array, 24) - f = onp.reshape(input_array, [2, 4, -1]) - return a, b, c, d, e, f - -# Test np.ravel - - -def mnp_ravel(input_tensor): - a = mnp.ravel(input_tensor) - return a - - -def onp_ravel(input_array): - a = onp.ravel(input_array) - return a - -# Test np.concatenate - - -def mnp_concatenate(input_tensor): - a = mnp.concatenate(input_tensor, None) - b = mnp.concatenate(input_tensor, 0) - c = mnp.concatenate(input_tensor, 1) - d = mnp.concatenate(input_tensor, 2) - return a, b, c, d - - -def onp_concatenate(input_array): - a = onp.concatenate(input_array, None) - b = onp.concatenate(input_array, 0) - c = onp.concatenate(input_array, 1) - d = onp.concatenate(input_array, 2) - return a, b, c, d - - -def test_transpose(): - onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) - o_transposed = onp_transpose(onp_array) - m_transposed = mnp_transpose(mnp_array) - check_all_results(o_transposed, m_transposed) - - -def test_expand_dims(): - onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) - o_expanded = onp_expand_dims(onp_array) - m_expanded = mnp_expand_dims(mnp_array) - check_all_results(o_expanded, m_expanded) - - -def test_squeeze(): - onp_array = onp.random.random((1, 3, 1, 4, 2)).astype('float32') - mnp_array = mnp.asarray(onp_array) - o_squeezed = onp_squeeze(onp_array) - m_squeezed = mnp_squeeze(mnp_array) - check_all_results(o_squeezed, m_squeezed) - - onp_array = onp.random.random((1, 1, 1, 1, 1)).astype('float32') - mnp_array = mnp.asarray(onp_array) - o_squeezed = onp_squeeze(onp_array) - m_squeezed = mnp_squeeze(mnp_array) - check_all_results(o_squeezed, m_squeezed) - - -def test_rollaxis(): - onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) - o_rolled = onp_rollaxis(onp_array) - m_rolled = mnp_rollaxis(mnp_array) - check_all_results(o_rolled, m_rolled) - - -def test_swapaxes(): - onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) - o_swaped = onp_swapaxes(onp_array) - m_swaped = mnp_swapaxes(mnp_array) - check_all_results(o_swaped, m_swaped) - - -def test_reshape(): - onp_array = onp.random.random((2, 3, 4)).astype('float32') - mnp_array = mnp.asarray(onp_array) - o_reshaped = onp_reshape(onp_array) - m_reshaped = mnp_reshape(mnp_array) - check_all_results(o_reshaped, m_reshaped) - - -def test_ravel(): - onp_array = onp.random.random((2, 3, 4)).astype('float32') - mnp_array = mnp.asarray(onp_array) - o_ravel = onp_ravel(onp_array) - m_ravel = mnp_ravel(mnp_array).asnumpy() - match_array(o_ravel, m_ravel) - - -def test_concatenate(): - onp_array = onp.random.random((5, 4, 3, 2)).astype('float32') - mnp_array = mnp.asarray(onp_array) - o_concatenate = onp_concatenate(onp_array) - m_concatenate = mnp_concatenate(mnp_array) - check_all_results(o_concatenate, m_concatenate) - - -class ReshapeExpandSqueeze(Cell): - def __init__(self): - super(ReshapeExpandSqueeze, self).__init__() - - def construct(self, x): - x = mnp.expand_dims(x, 2) - x = mnp.reshape(x, (1, 2, 3, 4, 1, 1)) - x = mnp.squeeze(x) - return x - - -class TransposeConcatRavel(Cell): - def __init__(self): - super(TransposeConcatRavel, self).__init__() - - def construct(self, x1, x2, x3): - x1 = mnp.transpose(x1, [0, 2, 1]) - x2 = x2.transpose(0, 2, 1) - x = mnp.concatenate((x1, x2, x3), -1) - x = mnp.ravel(x) - return x - - -class RollSwap(Cell): - def __init__(self): - super(RollSwap, self).__init__() - - def construct(self, x): - x = mnp.rollaxis(x, 2) - x = mnp.swapaxes(x, 0, 1) - return x - - -test_case_array_ops = [ - ('ReshapeExpandSqueeze', { - 'block': ReshapeExpandSqueeze(), - 'desc_inputs': [mnp.ones((2, 3, 4))]}), - - ('TransposeConcatRavel', { - 'block': TransposeConcatRavel(), - 'desc_inputs': [mnp.ones((2, 3, 4)), - mnp.ones((2, 3, 4)), - mnp.ones((2, 4, 1))]}), - - ('RollSwap', { - 'block': RollSwap(), - 'desc_inputs': [mnp.ones((2, 3, 4))]}) -] - -test_case_lists = [test_case_array_ops] -test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists) -# use -k to select certain testcast -# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm - - -@non_graph_engine -@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) -def test_exec(): - context.set_context(mode=context.GRAPH_MODE) - return test_exec_case - - -def test_expand_dims_exception(): - with pytest.raises(TypeError): - mnp.expand_dims(mnp.ones((3, 3)), 1.2) - - -def test_asarray_exception(): - with pytest.raises(TypeError): - mnp.asarray({1, 2, 3}) - - -def test_swapaxes_exception(): - with pytest.raises(ValueError): - mnp.swapaxes(mnp.ones((3, 3)), 1, 10) - - -def test_linspace_exception(): - with pytest.raises(TypeError): - mnp.linspace(0, 1, num=2.5) diff --git a/tests/ut/python/numpy_native/test_math_ops.py b/tests/ut/python/numpy_native/test_math_ops.py deleted file mode 100644 index 0758dc5200e..00000000000 --- a/tests/ut/python/numpy_native/test_math_ops.py +++ /dev/null @@ -1,102 +0,0 @@ - -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""unit tests for numpy math operations""" - -import pytest -import numpy as onp - -import mindspore.numpy as mnp - - -def rand_int(*shape): - """return an random integer array with parameter shape""" - res = onp.random.randint(low=1, high=5, size=shape) - if isinstance(res, onp.ndarray): - res = res.astype(onp.float32) - return res - - -class Cases(): - def __init__(self): - - self.arrs = [ - rand_int(2), - rand_int(2, 3), - rand_int(2, 3, 4), - rand_int(2, 3, 4, 5), - ] - - # scalars expanded across the 0th dimension - self.scalars = [ - rand_int(), - rand_int(1), - rand_int(1, 1), - rand_int(1, 1, 1), - ] - - # arrays with last dimension aligned - self.aligned_arrs = [ - rand_int(2, 3), - rand_int(1, 4, 3), - rand_int(5, 1, 2, 3), - rand_int(4, 2, 1, 1, 3), - ] - - -test_case = Cases() - - -def mnp_inner(a, b): - return mnp.inner(a, b) - - -def onp_inner(a, b): - return onp.inner(a, b) - - -def test_inner(): - for arr1 in test_case.aligned_arrs: - for arr2 in test_case.aligned_arrs: - match_res(mnp_inner, onp_inner, arr1, arr2) - - for scalar1 in test_case.scalars: - for scalar2 in test_case.scalars: - match_res(mnp_inner, onp_inner, - scalar1, scalar2) - - -# check if the output from mnp function and onp function applied on the arrays are matched - - -def match_res(mnp_fn, onp_fn, arr1, arr2): - actual = mnp_fn(mnp.asarray(arr1, dtype='float32'), - mnp.asarray(arr2, dtype='float32')).asnumpy() - expected = onp_fn(arr1, arr2) - match_array(actual, expected) - - -def match_array(actual, expected, error=5): - if error > 0: - onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), - decimal=error) - else: - onp.testing.assert_equal(actual.tolist(), expected.tolist()) - - -def test_exception_innner(): - with pytest.raises(ValueError): - mnp.inner(mnp.asarray(test_case.arrs[0]), - mnp.asarray(test_case.arrs[1])) diff --git a/tests/ut/python/pipeline/parse/test_astype.py b/tests/ut/python/pipeline/parse/test_astype.py new file mode 100644 index 00000000000..ab0586204a1 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_astype.py @@ -0,0 +1,79 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test astype""" +import pytest + +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +def test_astype(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.astype("float16") + + net = Net() + res = net() + assert res.dtype == mstype.float16 + + +def test_astype_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.int64) + + def construct(self): + return self.value.astype(mstype.bool_) + + net = Net() + res = net() + assert res.dtype == mstype.bool_ + + +def test_astype_2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float64) + + def construct(self): + return self.value.astype(mstype.uint64) + + net = Net() + res = net() + assert res.dtype == mstype.uint64 + + +def test_astype_error_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.astype("float88") + + net = Net() + with pytest.raises(TypeError): + net() diff --git a/tests/ut/python/pipeline/parse/test_flatten.py b/tests/ut/python/pipeline/parse/test_flatten.py new file mode 100644 index 00000000000..202e5a6b40d --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_flatten.py @@ -0,0 +1,77 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test flatten""" +import pytest + +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +def test_flatten(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.flatten() + + net = Net() + net() + + +def test_flatten_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.flatten(order='F') + + net = Net() + net() + + +def test_flatten_error(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.flatten(order='X') + + net = Net() + with pytest.raises(ValueError): + net() + + +def test_flatten_error_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.flatten(order=123) + + net = Net() + with pytest.raises(TypeError): + net() diff --git a/tests/ut/python/pipeline/parse/test_properties.py b/tests/ut/python/pipeline/parse/test_properties.py new file mode 100644 index 00000000000..8af7ace32ab --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_properties.py @@ -0,0 +1,103 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test tensor properties in graph mode""" + +import numpy as np + +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +def test_ndim(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor(np.random.random( + (2, 3, 4, 5)), dtype=mstype.float32) + + def construct(self): + return self.value.ndim + + net = Net() + res = net() + assert res == 4 + + +def test_nbytes(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor(np.random.random( + (2, 3, 4, 5)), dtype=mstype.float32) + + def construct(self): + return self.value.nbytes + + net = Net() + res = net() + assert res == 480 + + +def test_size(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor(np.random.random( + (2, 3, 4, 5)), dtype=mstype.float32) + + def construct(self): + return self.value.size + + net = Net() + res = net() + assert res == 120 + + +def test_strides(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor(np.random.random( + (2, 3, 4, 5)), dtype=mstype.float32) + + def construct(self): + return self.value.strides + + net = Net() + res = net() + assert res == (240, 80, 20, 4) + + +def test_itemsize(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value1 = Tensor(np.random.random( + (2, 3, 4, 5)), dtype=mstype.float64) + self.value2 = Tensor(np.random.random( + (2, 3, 4, 5)), dtype=mstype.int32) + self.value3 = Tensor(np.random.random( + (2, 3, 4, 5)), dtype=mstype.bool_) + + def construct(self): + return (self.value1.itemsize, self.value2.itemsize, self.value3.itemsize) + + net = Net() + res = net() + assert res == (8, 4, 1) diff --git a/tests/ut/python/pipeline/parse/test_reshape.py b/tests/ut/python/pipeline/parse/test_reshape.py new file mode 100644 index 00000000000..91f2fd901d5 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_reshape.py @@ -0,0 +1,90 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test reshape""" +import pytest + +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +def test_reshape(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.reshape(-1) + + net = Net() + net() + + +def test_reshape_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.reshape([3, 2, 1]) + + net = Net() + net() + + +def test_reshape_2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.reshape((-1, 2)) + + net = Net() + net() + + +def test_reshape_error(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.reshape(1, 2, 4) + + net = Net() + with pytest.raises(ValueError): + net() + + +def test_reshape_error_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.reshape((1, 2, 3.5)) + + net = Net() + with pytest.raises(TypeError): + net() diff --git a/tests/ut/python/pipeline/parse/test_transpose.py b/tests/ut/python/pipeline/parse/test_transpose.py new file mode 100644 index 00000000000..16db9c4cc84 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_transpose.py @@ -0,0 +1,103 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test transpose""" +import pytest + +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +def test_transpose(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.transpose() + + net = Net() + net() + + +def test_transpose_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.transpose(1, 0) + + net = Net() + net() + + +def test_transpose_2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.transpose([1, 0]) + + net = Net() + net() + + +def test_transpose_3(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.transpose((1, 0)) + + net = Net() + net() + + +def test_transpose_error(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.transpose(0, 2, 1) + + net = Net() + with pytest.raises(ValueError): + net() + + +def test_transpose_error_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32) + + def construct(self): + return self.value.transpose(1.0, 0) + + net = Net() + with pytest.raises(TypeError): + net()