Add new np interfaces and add graph support

This commit is contained in:
yanglf1121 2020-12-24 15:15:26 +08:00
parent 90777cd3bf
commit c5ea8223f5
14 changed files with 6516 additions and 1176 deletions


@@ -22,36 +22,59 @@ Note:
- array_ops.py defines all the array operation interfaces.
- array_creations.py defines all the array generation interfaces.
- math_ops.py defines all the math operations on tensors.
- logic_ops.py defines all the logical operations on tensors.
- dtypes.py defines all the mindspore.numpy dtypes (mainly redirected from mindspore)
"""
from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, reshape,
ravel, concatenate, where, atleast_1d, atleast_2d, atleast_3d,
column_stack, hstack, dstack, vstack, stack, unique)
column_stack, hstack, dstack, vstack, stack, unique, moveaxis,
tile, broadcast_to, broadcast_arrays, roll, append, split, vsplit,
flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat)
from .array_creations import copy_ as copy
from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange,
linspace, logspace, eye, identity, empty, empty_like,
ones_like, zeros_like, full_like, diagonal, tril, triu,
tri, trace)
tri, trace, cumsum, meshgrid, mgrid, ogrid, diagflat,
diag, diag_indices, ix_)
from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16,
uint32, uint64, float_, float16, float32, float64, bool_, inf,
numeric_types)
from .math_ops import (mean, inner, add, subtract, multiply, divide, power,
dot, outer, tensordot, absolute)
uint32, uint64, float_, float16, float32, float64, bool_, inf, nan,
numeric_types, PINF, NINF)
from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide, power,
dot, outer, tensordot, absolute, std, var, average, minimum,
matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin,
hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero,
positive, negative, clip, floor_divide, remainder, fix, fmod, trunc,
exp, expm1)
from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite,
isnan, isinf, isposinf, isneginf, isscalar)
mod = remainder
fabs = absolute
array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape',
'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d',
'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique']
'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique', 'moveaxis',
'tile', 'broadcast_to', 'broadcast_arrays', 'append', 'roll', 'split', 'vsplit',
'flip', 'flipud', 'fliplr', 'hsplit', 'dsplit', 'take_along_axis', 'take',
'repeat']
array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange',
'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like',
'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu',
'tri', 'trace']
'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag',
'diag_indices', 'ix_', 'cumsum']
math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'power',
'dot', 'outer', 'tensordot', 'absolute']
math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_divide', 'power',
'dot', 'outer', 'tensordot', 'absolute', 'std', 'var', 'average', 'not_equal',
'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum',
'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad',
'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 'floor_divide',
'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs']
__all__ = array_ops_module + array_creations_module + math_module + numeric_types
logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite',
'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar']
__all__ = array_ops_module + array_creations_module + math_module + logic_module + numeric_types
__all__.sort()
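
For reference, a minimal usage sketch of a few of the newly exported interfaces, mirroring the docstring examples added elsewhere in this commit (expected outputs shown as comments):

import mindspore.numpy as np

a = np.asarray([1, 2])
b = np.asarray([[1, 3], [1, 4]])
print(np.not_equal(a, b))    # [[False True]
                             #  [False True]]
print(np.isposinf(np.array([-np.inf, 0., np.inf], np.float32)))    # [False False True]
print(np.isscalar(3.1))    # True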

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -22,7 +22,12 @@ from ..common.dtype import (int8, int16, int32, int64, uint8, uint16, uint32, ui
# backend for now.
inf = float('inf')
PINF = float('inf')
NINF = float('-inf')
nan = float('nan')
# all three of inf, PINF, and NINF are defined in the original numpy, and as we aim for
# consistency, the same thing is done here
pi = 3.141592653589793
int_ = int32
uint = uint32
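
As a quick check (a sketch based only on the constants defined above and the isnan docstring below), the new constants mirror their numpy counterparts:

import mindspore.numpy as np

print(np.inf == np.PINF)    # True, both are float('inf')
print(np.NINF)    # -inf
print(np.isnan(np.array(np.nan, np.float32)))    # True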


@@ -0,0 +1,576 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""logical operations, the function docs are adapted from Numpy API."""
from .math_ops import _apply_tensor_op
from ..ops import functional as F
from ..common import dtype as mstype
from .._c_expression import typing
from .array_creations import zeros, ones
from .utils import _check_input_tensor
def not_equal(x1, x2, out=None, where=True, dtype=None):
"""
Returns (x1 != x2) element-wise.
Args:
x1 (Tensor): First input tensor to be compared.
x2 (Tensor): Second input tensor to be compared.
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type
bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
scalars.
Raises:
TypeError: If the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.numpy as np
>>> a = np.asarray([1, 2])
>>> b = np.asarray([[1, 3],[1, 4]])
>>> print(np.not_equal(a, b))
[[False True]
[False True]]
"""
_check_input_tensor(x1, x2)
return _apply_tensor_op(F.not_equal, x1, x2, out=out, where=where, dtype=dtype)
def less_equal(x1, x2, out=None, where=True, dtype=None):
"""
Returns the truth value of ``(x1 <= x2)`` element-wise.
Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Args:
x1 (Tensor): Input array.
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type
bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are
scalars.
Raises:
TypeError: if the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.less_equal(np.array([4, 2, 1]), np.array([2, 2, 2]))
>>> print(output)
[False True True]
"""
_check_input_tensor(x1, x2)
return _apply_tensor_op(F.tensor_le, x1, x2, out=out, where=where, dtype=dtype)
def less(x1, x2, out=None, where=True, dtype=None):
"""
Returns the truth value of ``(x1 < x2)`` element-wise.
Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Args:
x1 (Tensor): input array.
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type
bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are
scalars.
Raises:
TypeError: if the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.less(np.array([1, 2]), np.array([2, 2]))
>>> print(output)
[ True False]
"""
return _apply_tensor_op(F.tensor_lt, x1, x2, out=out, where=where, dtype=dtype)
def greater_equal(x1, x2, out=None, where=True, dtype=None):
"""
Returns the truth value of ``(x1 >= x2)`` element-wise.
Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Args:
x1 (Tensor): Input array.
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type
bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are
scalars.
Raises:
TypeError: if the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.greater_equal(np.array([4, 2, 1]), np.array([2, 2, 2]))
>>> print(output)
[ True True False]
"""
return _apply_tensor_op(F.tensor_ge, x1, x2, out=out, where=where, dtype=dtype)
def greater(x1, x2, out=None, where=True, dtype=None):
"""
Returns the truth value of ``(x1 > x2)`` element-wise.
Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Args:
x1 (Tensor): Input array.
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type
bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are
scalars.
Raises:
TypeError: if the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.greater(np.array([4, 2]), np.array([2, 2]))
>>> print(output)
[ True False]
"""
return _apply_tensor_op(F.tensor_gt, x1, x2, out=out, where=where, dtype=dtype)
def equal(x1, x2, out=None, where=True, dtype=None):
"""
Returns the truth value of ``(x1 == x2)`` element-wise.
Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
Args:
x1 (Tensor): Input array.
x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be
broadcastable to a common shape (which becomes the shape of the output).
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type
bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are
scalars.
Raises:
TypeError: if the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.equal(np.array([0, 1, 3]), np.arange(3))
>>> print(output)
[ True True False]
"""
return _apply_tensor_op(F.equal, x1, x2, out=out, where=where, dtype=dtype)
def isfinite(x, out=None, where=True, dtype=None):
"""
Tests element-wise for finiteness (not infinity and not Not a Number).
The result is returned as a boolean array.
Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
On GPU, the supported dtypes are np.float16, and np.float32.
Args:
x (Tensor): Input values.
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor or scalar, true where `x` is not positive infinity, negative infinity,
or NaN; false otherwise. This is a scalar if `x` is a scalar.
Raises:
TypeError: if the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.isfinite(np.array([np.inf, 1., np.nan]).astype('float32'))
>>> print(output)
[False True False]
>>> output = np.isfinite(np.log(np.array(-1.).astype('float32')))
>>> print(output)
False
"""
return _apply_tensor_op(F.isfinite, x, out=out, where=where, dtype=dtype)
def _isnan(x):
"""Compures isnan without applying keyword arguments."""
shape = F.shape(x)
zeros_tensor = zeros(shape, mstype.float32)
ones_tensor = ones(shape, mstype.float32)
non_neg = F.tensor_ge(x, zeros_tensor)
non_pos = F.tensor_le(x, zeros_tensor)
res = F.select(non_neg, zeros_tensor, ones_tensor)
res = F.select(non_pos, zeros_tensor, res)
return F.cast(res, mstype.bool_)
def isnan(x, out=None, where=True, dtype=None):
"""
Tests element-wise for NaN and return result as a boolean array.
Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
On GPU, the supported dtypes are np.float16, and np.float32.
Args:
x (Tensor): Input values.
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor or scalar, true where `x` is NaN, false otherwise. This is a scalar if
`x` is a scalar.
Raises:
TypeError: if the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.isnan(np.array(np.nan, np.float32))
>>> print(output)
True
>>> output = np.isnan(np.array(np.inf, np.float32))
>>> print(output)
False
"""
return _apply_tensor_op(_isnan, x, out=out, where=where, dtype=dtype)
def _isinf(x):
"""Computes isinf without applying keyword arguments."""
shape = F.shape(x)
zeros_tensor = zeros(shape, mstype.float32)
ones_tensor = ones(shape, mstype.float32)
not_inf = F.isfinite(x)
is_nan = _isnan(x)
res = F.select(not_inf, zeros_tensor, ones_tensor)
res = F.select(is_nan, zeros_tensor, res)
return F.cast(res, mstype.bool_)
def isinf(x, out=None, where=True, dtype=None):
"""
Tests element-wise for positive or negative infinity.
Returns a boolean array of the same shape as `x`, True where ``x == +/-inf``, otherwise False.
Note:
Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
not supported.
When `where` is provided, `out` must have a tensor value. `out` is not supported
for storing the result, however it can be used in combination with `where` to set
the value at indices for which `where` is set to False.
On GPU, the supported dtypes are np.float16, and np.float32.
Args:
x (Tensor): Input values.
out (Tensor or None, optional): defaults to None.
where (Tensor or None, optional): For any non-default value of type other
than :class:`Tensor` or :class:`None`, the output retains its original value.
This condition is broadcasted over the input. At locations where the
condition is `True`, the out array will be set to the ufunc result.
Elsewhere, the out array will retain its original value. Note that
if an uninitialized out array is created via the default ``out=None``,
locations within it where the condition is `False` will remain
uninitialized.
dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
output Tensor.
Returns:
Tensor or scalar, true where `x` is positive or negative infinity, false
otherwise. This is a scalar if `x` is a scalar.
Raises:
TypeError: if the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.isinf(np.array(np.inf, np.float32))
>>> print(output)
True
>>> output = np.isinf(np.array([np.inf, -np.inf, 1.0, np.nan], np.float32))
>>> print(output)
[ True True False False]
"""
return _apply_tensor_op(_isinf, x, out=out, where=where, dtype=dtype)
def _is_sign_inf(x, fn):
"""Tests element-wise for inifinity with sign."""
shape = F.shape(x)
zeros_tensor = zeros(shape, mstype.float32)
ones_tensor = ones(shape, mstype.float32)
not_inf = F.isfinite(x)
is_sign = fn(x, zeros_tensor)
res = F.select(not_inf, zeros_tensor, ones_tensor)
res = F.select(is_sign, res, zeros_tensor)
return F.cast(res, mstype.bool_)
def isposinf(x):
"""
Tests element-wise for positive infinity, returns result as bool array.
Note:
Numpy argument `out` is not supported.
On GPU, the supported dtypes are np.float16, and np.float32.
Args:
x (Tensor): Input values.
Returns:
Tensor or scalar, true where `x` is positive infinity, false otherwise.
This is a scalar if `x` is a scalar.
Raises:
TypeError: if the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.isposinf(np.array([-np.inf, 0., np.inf], np.float32))
>>> print(output)
[False False True]
"""
_check_input_tensor(x)
return _is_sign_inf(x, F.tensor_gt)
def isneginf(x):
"""
Tests element-wise for negative infinity, returns result as bool array.
Note:
Numpy argument `out` is not supported.
On GPU, the supported dtypes are np.float16, and np.float32.
Args:
x (Tensor): Input values.
Returns:
Tensor or scalar, true where `x` is negative infinity, false otherwise.
This is a scalar if `x` is a scalar.
Raises:
TypeError: if the input is not a tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.isneginf(np.array([-np.inf, 0., np.inf], np.float32))
>>> print(output)
[ True False False]
"""
return _is_sign_inf(x, F.tensor_lt)
def isscalar(element):
"""
Returns True if the type of element is a scalar type.
Note:
Only object types recognized by the mindspore parser are supported,
which includes objects, types, methods and functions defined within
the scope of mindspore. Other built-in types are not supported.
Args:
element (any): Input argument, can be of any type and shape.
Returns:
Boolean, True if `element` is a scalar type, False if it is not.
Raises:
TypeError: if the type of `element` is not supported by mindspore parser.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = np.isscalar(3.1)
>>> print(output)
True
>>> output = np.isscalar(np.array(3.1))
>>> print(output)
False
>>> output = np.isscalar(False)
>>> print(output)
True
>>> output = np.isscalar('numpy')
>>> print(output)
True
"""
return isinstance(F.typeof(element), (typing.Number, typing.Int, typing.UInt,
typing.Float, typing.Bool, typing.String))

File diff suppressed because it is too large


@@ -16,11 +16,11 @@
import numpy as onp
import mindspore.context as context
from ..common import Tensor
from ..ops import functional as F
from ..common import dtype as mstype
from .utils_const import _tile_size
from .utils_const import _tile_size, _add_unit_axes, _raise_type_error
def _deep_list(array_like):
@@ -56,10 +56,9 @@ def _deep_tensor_to_nparray(array_like):
def _check_input_for_asarray(array_like):
"""check whether array_like argument is a valid type for np.asarray conversion"""
if isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)):
return True
raise TypeError("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \
f"or numpy.ndarray, but got {type(array_like)}")
if not isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)):
_raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \
"or numpy.ndarray, but got ", array_like)
def _is_scalar(shape):
@@ -67,16 +66,6 @@ def _is_scalar(shape):
return F.shape_mul(shape) == 1
def _is_empty(shape):
"""Checks if the shape is empty"""
return F.shape_mul(shape) == 0
def _get_device():
"""Get the current device (`GPU`, `CPU`, `Ascend`)"""
return context.get_context('device_target')
def _convert_list_tensor_to_tuple_tensor(list_of_tensor):
"""Convert a list of tensor to a tuple of tensor"""
if isinstance(list_of_tensor, list):
@@ -87,19 +76,66 @@ def _convert_list_tensor_to_tuple_tensor(list_of_tensor):
return list_of_tensor
def _get_mode():
"""Get the current mode (0 is Graph mode, 1 is PyNative mode)"""
return context.get_context('mode')
def _expand(x, ndim, axis=0):
"""Expand x to ndim."""
while F.rank(x) < ndim:
x = F.expand_dims(x, axis)
return x
"""Expand x to ndim from axis, which can be 0 or -1."""
shape = _add_unit_axes(F.shape(x), ndim, axis == -1)
return F.reshape(x, shape)
def _broadcast_to(x, shape_cur, shape_to, ndim_to):
"""Broadcasts x from shape_cur to shape_to."""
size = _tile_size(shape_cur, shape_to, ndim_to)
return F.tile(x, size)
def _broadcast_to_shape(x, shape):
"""Broadcasts x from current shape to shape"""
ndim_to = len(shape)
x = _expand(x, ndim_to)
return _broadcast_to(x, F.shape(x), shape, ndim_to)
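# Together, _expand and _broadcast_to implement numpy-style broadcasting: e.g. taking
# x of shape (3, 1) to (2, 3, 4) first reshapes x to (1, 3, 1), then tiles it by
# (2, 1, 4), matching onp.broadcast_to(x, (2, 3, 4)).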
def _get_size(x, axis=None):
"""Get the number of elements along the given axis of tensor x."""
if axis is None or F.tuple_len(axis) == 0:
axis = F.make_range(x.ndim)
nums = 1
for ax in axis:
nums *= x.shape[ax]
return nums
def _check_input_tensor(*tensors):
for tensor in tensors:
if not isinstance(tensor, Tensor):
_raise_type_error('expect Tensor, but got ', F.typeof(tensor))
return True
def _convert_64_to_32(tensor):
"""Convert tensor with float64/int64 types to float32/int32."""
if tensor.dtype == mstype.float64:
return tensor.astype("float32")
if tensor.dtype == mstype.int64:
return tensor.astype("int32")
return tensor
def _get_dtype_from_scalar(*input_numbers):
"""
Get the final dtype from a series of input numbers. Compared with F.typeof, we
return int32/float32 for python int/float instead.
"""
bool_flag = True
int_flag = True
for number in input_numbers:
if number is not None:
if not isinstance(number, bool):
bool_flag = False
if not isinstance(number, int):
int_flag = False
if bool_flag:
return mstype.bool_
if int_flag:
return mstype.int32
return mstype.float32
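# Illustrative only -- the promotion rule implemented above:
#   _get_dtype_from_scalar(True, False) -> mstype.bool_
#   _get_dtype_from_scalar(1, 2)        -> mstype.int32
#   _get_dtype_from_scalar(1, 2.5)      -> mstype.float32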


@ -13,14 +13,16 @@
# limitations under the License.
# ============================================================================
"""internal graph-compatible utility functions"""
import math
from functools import partial
import mindspore.context as context
from ..ops import functional as F
from ..ops.primitive import constexpr
from ..common import dtype as mstype
from ..common import Tensor
from .._c_expression import Tensor as Tensor_
from .._c_expression.typing import Tuple, List
from .._c_expression import typing
from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map
@@ -28,12 +30,17 @@ from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map
@constexpr
def _check_shape(shape):
"""check the shape param to match the numpy style"""
if not isinstance(shape, (int, tuple, list, Tuple, List)):
if not isinstance(shape, (int, tuple, list, typing.Tuple, typing.List)):
raise TypeError(f"only int, tuple and list are allowed for shape, but got {type(shape)}")
if isinstance(shape, int):
shape = (shape,)
if isinstance(shape, (list, List)):
if isinstance(shape, (list, typing.List)):
shape = tuple(shape)
for s in shape:
if not isinstance(s, int):
raise TypeError("each entry in shape should be int.")
if s < 0:
raise ValueError("each entry in shape should no less than 0.")
return shape
@@ -57,7 +64,7 @@ def _check_dtype(dtype):
@constexpr
def _check_shape_contain_zero(shp):
def _is_shape_empty(shp):
"""Check whether shape contains zero"""
if isinstance(shp, int):
return shp == 0
@@ -77,35 +84,28 @@ def _check_start_normalize(start, ndim):
@constexpr
def _check_axes_range(axes, ndim):
"""
Check axes are within the number of dimensions of tensor x and normalize the negative axes.
Check axes type and normalize the negative axes.
Args:
axes (Union[int, tuple(int), list(int)]): Axes of the tensor.
axes: Axes of the tensor.
ndim (int): The number of dimensions of the tensor.
Return:
Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple.
Raises:
TypeError: If the axes are not integer, tuple(int) or list(int).
ValueError: If duplicate axes exists or some axis is out of bounds.
"""
if not isinstance(axes, int) and not isinstance(axes, tuple) and not isinstance(axes, list):
raise TypeError(f"int, tuple(int) or list(int) expected, but got {type(axes)}.")
low = -ndim
up = ndim - 1
if low > up:
raise ValueError(f"Lower bound {low} and upper bound {up} of axes are not allowed.")
if isinstance(axes, int):
if axes < low or axes > up:
raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {ndim}.")
return axes if axes >= 0 else axes + ndim
new_axes = []
for item in axes:
if not isinstance(item, int):
raise TypeError(f"int in tuple or list expected, but got {type(item)}.")
if item < low or item > up:
raise ValueError(f"axis {item} in {axes} is out of bounds for tensor of dimension {ndim}.")
new_axes.append(item if item >= 0 else item + ndim)
return tuple(new_axes)
_check_axis_type(axes, True, True, True)
if isinstance(axes, (list, tuple)):
_check_element_int(axes)
axes = _canonicalize_axis(axes, ndim)
return axes
@constexpr
def _get_device_compile():
def _get_device():
"""Get the current device (`GPU`, `CPU`, `Ascend`)"""
return context.get_context('device_target')
@@ -153,9 +153,10 @@ def _infer_out_shape(*shapes):
@constexpr
def _check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""
if -ndim <= axis < ndim:
return True
raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
if not isinstance(axis, int):
raise TypeError(f'axes should be integers, not {type(axis)}')
if not -ndim <= axis < ndim:
raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
@constexpr
@@ -165,26 +166,25 @@ def _check_axis_valid(axes, ndim):
to the built-in operator (non-negative, int or tuple)
"""
if isinstance(axes, int):
_ = _check_axis_in_range(axes, ndim)
_check_axis_in_range(axes, ndim)
return (axes % ndim,)
if isinstance(axes, tuple):
if isinstance(axes, (tuple, list)):
for axis in axes:
_ = _check_axis_in_range(axis, ndim)
_check_axis_in_range(axis, ndim)
axes = tuple(map(lambda x: x % ndim, axes))
if all(axes.count(el) <= 1 for el in axes):
return axes
if axes is None:
axes = F.make_range(ndim)
return axes
raise ValueError('duplicate value in \'axis\'')
raise ValueError('duplicate value in "axis"')
@constexpr
def _check_shape_aligned(shape1, shape2):
"""Checks shape1 and shape2 are valid shapes to perform inner product"""
if shape1[-1] == shape2[-1]:
return True
raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)')
if shape1[-1] != shape2[-1]:
raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)')
@constexpr
@@ -197,30 +197,6 @@ def _tile_size(shape, out_shape, ndim):
return tuple(size)
@constexpr
def _check_is_int(obj):
"""Check whether obj is an integer."""
return isinstance(obj, int)
@constexpr
def _check_is_tuple(obj):
"""Check whether obj is a tuple"""
return isinstance(obj, (tuple, Tuple))
@constexpr
def _check_is_list(obj):
"""Check whether obj is a list"""
return isinstance(obj, (list, List))
@constexpr
def _check_is_tensor(obj):
"""Check whether obj is a tensor"""
return isinstance(obj, mstype.tensor_type)
@constexpr
def _raise_type_error(info, param=None):
"""
@@ -298,6 +274,177 @@ def _check_is_float(dtype):
@constexpr
def _check_input_tensor(input_type):
if not _check_is_tensor(input_type):
raise TypeError(f'expect Tensor, but got {input_type}')
def _check_is_int(dtype):
return isinstance(dtype, typing.Int)
@constexpr
def _check_matmul_shapes(shape1, shape2):
"""Checks shape1 and shape2 are valid shapes to perform matmul"""
ndim1, ndim2 = len(shape1), len(shape2)
if ndim1 < 1 or ndim2 < 1:
raise ValueError('input operands must have at least 1 dimension')
if ndim2 >= 2 and shape1[-1] != shape2[-2]:
raise ValueError(f'mismatch in core dimension of input operands (size '
f'{shape1[-1]} is different from {shape2[-2]})')
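# For example, shape1=(2, 3) with shape2=(3, 4) passes, while shape1=(2, 3) with
# shape2=(4, 5) raises, mirroring numpy's matmul core-dimension rule.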
@constexpr
def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
"""Check axis argument type."""
if type_int and isinstance(axis, int):
return True
if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
for ax in axis:
if not isinstance(ax, int):
raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axis}.")
return True
type_str = ""
if type_int: type_str += "int, "
if type_tuple: type_str += "tuple, "
if type_list: type_str += "list, "
raise TypeError(f"Axis should be {type_str}but got {type(axis)}.")
@constexpr
def _canonicalize_axis(axis, ndim):
"""
Check axes are within the number of dimensions of tensor x and normalize the negative axes.
Args:
axis (Union[int, tuple(int), list(int)]): Axes of the tensor.
ndim (int): The number of dimensions of the tensor.
Return:
Axis (Union[int, tuple(int)]). If input is integer, return integer, else tuple.
"""
if isinstance(axis, int):
axis = [axis]
for ax in axis:
_check_axis_in_range(ax, ndim)
def canonicalizer(ax):
return ax + ndim if ax < 0 else ax
axis = tuple([canonicalizer(ax) for ax in axis])
if all(axis.count(el) <= 1 for el in axis):
return axis if len(axis) > 1 else axis[0]
raise ValueError(f"duplicate axes in {axis}.")
@constexpr
def _broadcast_tuples(tup1, tup2):
"""
Broadcast two 1D tuples to the same length. If inputs are ints, convert to
tuples first.
"""
tup1 = (tup1,) if isinstance(tup1, int) else tup1
tup2 = (tup2,) if isinstance(tup2, int) else tup2
if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)):
raise TypeError("input shift and axis must be tuple or list or int.")
if len(tup1) == len(tup2):
return tup1, tup2
if len(tup1) == 1:
tup1 *= len(tup2)
elif len(tup2) == 1:
tup2 *= len(tup1)
else:
raise ValueError("shape mismatch: objects cannot be broadcast to a single shape")
return tup1, tup2
@constexpr
def _expanded_shape(ndim, axis_size, axis):
"""
Returns a shape with size = 1 for all dimensions
except at axis.
"""
return tuple([axis_size if i == axis else 1 for i in range(ndim)])
@constexpr
def _add_unit_axes(shape, ndim, append=False):
"""
Prepends shape with 1s so that it has the number of dimensions ndim.
If append is set to True, returns shape appended with 1s instead.
"""
if isinstance(shape, int):
shape = (shape,)
ndim_diff = ndim - len(shape)
if ndim_diff > 0:
if append:
shape = [i for i in shape] + [1]*ndim_diff
else:
shape = [1]*ndim_diff + [i for i in shape]
return tuple(shape)
@constexpr
def _check_element_int(lst):
"""
Check whether each element in `lst` is an integer.
"""
for item in lst:
if not isinstance(item, int):
raise TypeError(f"Each element in {lst} should be integer, but got {type(item)}.")
return True
@constexpr
def _type_convert(force, obj):
"""
Convert type of `obj` to `force`.
"""
return force(obj)
@constexpr
def _list_comprehensions(obj, item=None, return_tuple=False):
"""
Generates a new list/tuple by list comprehension.
Args:
obj (Union[int, list, tuple]):
If integer, it will be the length of the returned tuple/list.
item: The value to be filled. Default: None.
If None, the values in the new list/tuple are the same as obj
or range(obj) when obj is integer.
return_tuple(bool): If true, returns tuple, else returns list.
Returns:
List or tuple.
"""
res = []
lst = obj
if isinstance(obj, int):
lst = range(obj)
if item is None:
res = [i for i in lst]
else:
res = [item for i in lst]
if return_tuple:
return tuple(res)
return res
@constexpr
def _tuple_getitem(tup, idx, startswith=True):
"""
Returns a slice from tup starting with idx. If startswith is False,
returns a slice from tup ending with idx instead.
"""
if startswith:
return tup[idx:]
return tup[:idx]
@constexpr
def _iota(dtype, num):
"""Creates a 1-D tensor with value: [0,1,...num-1] and dtype."""
# TODO: Change to P.Linspace when the kernel is implemented on CPU.
return Tensor(list(range(int(num))), dtype)
@constexpr
def _ceil(number):
"""Ceils the number in graph mode."""
return math.ceil(number)


@@ -59,18 +59,25 @@ tensor_div = P.RealDiv()
tensor_floordiv = P.FloorDiv()
tensor_pow = P.Pow()
tensor_mod = P.FloorMod()
tensor_exp = P.Exp()
tensor_expm1 = P.Expm1()
strided_slice = P.StridedSlice()
same_type_shape = P.SameTypeShape()
check_bprop = P.CheckBprop()
equal = P.Equal()
not_equal = P.NotEqual()
isfinite = P.IsFinite()
assign_sub = P.AssignSub()
assign_add = P.AssignAdd()
assign = P.Assign()
square = P.Square()
sqrt = P.Sqrt()
log = P.Log()
reduce_sum = P.ReduceSum()
tensor_slice = P.Slice()
maximum = P.Maximum()
minimum = P.Minimum()
floor = P.Floor()
scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()
@@ -82,6 +89,7 @@ transpose = P.Transpose()
squeeze = P.Squeeze()
scatter_nd = P.ScatterNd()
gather = P.Gather()
gather_d = P.GatherD()
gather_nd = P.GatherNd()
scatter_update = P.ScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()


@@ -14,15 +14,13 @@
# ============================================================================
"""unit tests for numpy array operations"""
import functools
import pytest
import numpy as onp
import mindspore.context as context
import mindspore.numpy as mnp
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \
match_all_arrays
class Cases():
@@ -97,10 +95,10 @@ class Cases():
self.mnp_prototypes = [
mnp.ones((2, 3, 4)),
mnp.ones((0, 3, 0, 2, 5)),
onp.ones((2, 7, 0)),
onp.ones(()),
[mnp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]],
([(1, 2), mnp.ones(2)], (onp.ones(2), [3, 4])),
mnp.ones((2, 7, 0)),
mnp.ones(()),
[mnp.ones(3), (1, 2, 3), mnp.ones(3), [4, 5, 6]],
([(1, 2), mnp.ones(2)], (mnp.ones(2), [3, 4])),
]
self.onp_prototypes = [
@@ -113,97 +111,6 @@
]
def match_array(actual, expected, error=0):
if error > 0:
onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
decimal=error)
else:
onp.testing.assert_equal(actual.tolist(), expected.tolist())
def check_all_results(onp_results, mnp_results, error=0):
"""Check all results from numpy and mindspore.numpy"""
for i, _ in enumerate(onp_results):
match_array(onp_results[i], mnp_results[i].asnumpy())
def check_all_unique_results(onp_results, mnp_results):
"""
Check all results from numpy and mindspore.numpy.
Args:
onp_results (Union[tuple of numpy.arrays, numpy.array])
mnp_results (Union[tuple of Tensors, Tensor])
"""
for i, _ in enumerate(onp_results):
if isinstance(onp_results[i], tuple):
for j in range(len(onp_results[i])):
match_array(onp_results[i][j],
mnp_results[i][j].asnumpy(), error=7)
else:
match_array(onp_results[i], mnp_results[i].asnumpy(), error=7)
def run_non_kw_test(mnp_fn, onp_fn):
"""Run tests on functions with non keyword arguments"""
test_case = Cases()
for i in range(len(test_case.arrs)):
arrs = test_case.arrs[:i]
match_res(mnp_fn, onp_fn, *arrs)
for i in range(len(test_case.scalars)):
arrs = test_case.scalars[:i]
match_res(mnp_fn, onp_fn, *arrs)
for i in range(len(test_case.expanded_arrs)):
arrs = test_case.expanded_arrs[:i]
match_res(mnp_fn, onp_fn, *arrs)
for i in range(len(test_case.nested_arrs)):
arrs = test_case.nested_arrs[:i]
match_res(mnp_fn, onp_fn, *arrs)
def rand_int(*shape):
"""return an random integer array with parameter shape"""
res = onp.random.randint(low=1, high=5, size=shape)
if isinstance(res, onp.ndarray):
return res.astype(onp.float32)
return float(res)
# return an random boolean array
def rand_bool(*shape):
return onp.random.rand(*shape) > 0.5
def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
"""Checks results from applying mnp_fn and onp_fn on arrs respectively"""
mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs)
mnp_res = mnp_fn(*mnp_arrs, **kwargs)
onp_res = onp_fn(*arrs, **kwargs)
match_all_arrays(mnp_res, onp_res)
def match_all_arrays(mnp_res, onp_res, error=0):
if isinstance(mnp_res, (tuple, list)):
for actual, expected in zip(mnp_res, onp_res):
match_array(actual.asnumpy(), expected, error)
else:
match_array(mnp_res.asnumpy(), onp_res, error)
def match_meta(actual, expected):
# float64 and int64 are not supported, and the default type for
# float and int are float32 and int32, respectively
if expected.dtype == onp.float64:
expected = expected.astype(onp.float32)
elif expected.dtype == onp.int64:
expected = expected.astype(onp.int32)
assert actual.shape == expected.shape
assert actual.dtype == expected.dtype
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@@ -440,27 +347,50 @@ def test_arange():
def test_linspace():
actual = onp.linspace(2.0, 3.0, dtype=onp.float32)
expected = mnp.linspace(2.0, 3.0).asnumpy()
match_array(actual, expected, error=7)
match_array(actual, expected, error=6)
actual = onp.linspace(2.0, 3.0, num=5, dtype=onp.float32)
expected = mnp.linspace(2.0, 3.0, num=5).asnumpy()
match_array(actual, expected, error=7)
match_array(actual, expected, error=6)
actual = onp.linspace(
2.0, 3.0, num=5, endpoint=False, dtype=onp.float32)
expected = mnp.linspace(2.0, 3.0, num=5, endpoint=False).asnumpy()
match_array(actual, expected, error=7)
match_array(actual, expected, error=6)
actual = onp.linspace(2.0, 3.0, num=5, retstep=True, dtype=onp.float32)
expected = mnp.linspace(2.0, 3.0, num=5, retstep=True)
match_array(actual[0], expected[0].asnumpy())
assert actual[1] == expected[1]
assert actual[1] == expected[1].asnumpy()
actual = onp.linspace(2.0, [3, 4, 5], num=5,
endpoint=False, dtype=onp.float32)
expected = mnp.linspace(
2.0, [3, 4, 5], num=5, endpoint=False).asnumpy()
match_array(actual, expected)
match_array(actual, expected, error=6)
start = onp.random.random([2, 1, 4])
stop = onp.random.random([1, 5, 1])
actual = onp.linspace(start, stop, num=20, retstep=True,
endpoint=False, dtype=onp.float32)
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,
retstep=True, endpoint=False)
match_array(actual[0], expected[0].asnumpy(), error=6)
match_array(actual[1], expected[1].asnumpy(), error=6)
actual = onp.linspace(start, stop, num=20, retstep=True,
endpoint=False, dtype=onp.int16)
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,
retstep=True, endpoint=False, dtype=mnp.int16)
match_array(actual[0], expected[0].asnumpy(), error=6)
match_array(actual[1], expected[1].asnumpy(), error=6)
for axis in range(2):
actual = onp.linspace(start, stop, num=20, retstep=False,
endpoint=False, dtype=onp.float32, axis=axis)
expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20,
retstep=False, endpoint=False, dtype=mnp.float32, axis=axis)
match_array(actual, expected.asnumpy(), error=6)
@pytest.mark.level1
@@ -472,22 +402,22 @@ def test_linspace():
def test_logspace():
actual = onp.logspace(2.0, 3.0, dtype=onp.float32)
expected = mnp.logspace(2.0, 3.0).asnumpy()
match_array(actual, expected)
match_array(actual, expected, error=3)
actual = onp.logspace(2.0, 3.0, num=5, dtype=onp.float32)
expected = mnp.logspace(2.0, 3.0, num=5).asnumpy()
match_array(actual, expected)
match_array(actual, expected, error=3)
actual = onp.logspace(
2.0, 3.0, num=5, endpoint=False, dtype=onp.float32)
expected = mnp.logspace(2.0, 3.0, num=5, endpoint=False).asnumpy()
match_array(actual, expected)
match_array(actual, expected, error=3)
actual = onp.logspace(2.0, [3, 4, 5], num=5,
actual = onp.logspace(2.0, [3, 4, 5], num=5, base=2,
endpoint=False, dtype=onp.float32)
expected = mnp.logspace(
2.0, [3, 4, 5], num=5, endpoint=False).asnumpy()
match_array(actual, expected)
2.0, [3, 4, 5], num=5, base=2, endpoint=False).asnumpy()
match_array(actual, expected, error=3)
@pytest.mark.level1
@@ -537,7 +467,6 @@ def run_x_like(mnp_fn, onp_fn):
actual = mnp_fn(mnp_proto, shape=shape).asnumpy()
expected = onp_fn(onp_proto, shape=shape)
match_array(actual, expected)
for mnp_dtype, onp_dtype in zip(test_case.mnp_dtypes,
test_case.onp_dtypes):
actual = mnp_fn(mnp_proto, dtype=mnp_dtype).asnumpy()
@@ -581,18 +510,18 @@ def test_full_like():
for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes):
shape = onp.zeros_like(onp_proto).shape
fill_value = rand_int()
actual = mnp.full_like(mnp_proto, fill_value).asnumpy()
actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy()
expected = onp.full_like(onp_proto, fill_value)
match_array(actual, expected)
for i in range(len(shape) - 1, 0, -1):
fill_value = rand_int(*shape[i:])
actual = mnp.full_like(mnp_proto, fill_value).asnumpy()
actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy()
expected = onp.full_like(onp_proto, fill_value)
match_array(actual, expected)
fill_value = rand_int(1, *shape[i + 1:])
actual = mnp.full_like(mnp_proto, fill_value).asnumpy()
actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy()
expected = onp.full_like(onp_proto, fill_value)
match_array(actual, expected)
@@ -620,6 +549,26 @@ def test_tri_triu_tril():
match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10))
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cumsum():
x = mnp.ones((16, 16), dtype="bool")
match_array(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy()))
match_array(mnp.cumsum(x, axis=0).asnumpy(),
onp.cumsum(x.asnumpy(), axis=0))
match_meta(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy()))
x = rand_int(3, 4, 5)
match_array(mnp.cumsum(mnp.asarray(x), dtype="bool").asnumpy(),
onp.cumsum(x, dtype="bool"))
match_array(mnp.cumsum(mnp.asarray(x), axis=-1).asnumpy(),
onp.cumsum(x, axis=-1))
def mnp_diagonal(arr):
return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0)
@@ -697,6 +646,138 @@ def test_trace():
match_res(mnp.trace, onp.trace, arr, offset=i, axis1=2, axis2=-1)
def mnp_meshgrid(*xi):
a = mnp.meshgrid(*xi)
b = mnp.meshgrid(*xi, sparse=True)
c = mnp.meshgrid(*xi, indexing='ij')
d = mnp.meshgrid(*xi, sparse=False, indexing='ij')
return a, b, c, d
def onp_meshgrid(*xi):
a = onp.meshgrid(*xi)
b = onp.meshgrid(*xi, sparse=True)
c = onp.meshgrid(*xi, indexing='ij')
d = onp.meshgrid(*xi, sparse=False, indexing='ij')
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_meshgrid():
xi = (onp.full(3, 2), onp.full(1, 5), onp.full(
(2, 3), 9), onp.full((4, 5, 6), 7))
for i in range(len(xi)):
arrs = xi[i:]
mnp_arrs = map(mnp.asarray, arrs)
for mnp_res, onp_res in zip(mnp_meshgrid(*mnp_arrs), onp_meshgrid(*arrs)):
match_all_arrays(mnp_res, onp_res)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mgrid():
mnp_res = mnp.mgrid[0:5]
onp_res = onp.mgrid[0:5]
match_all_arrays(mnp_res, onp_res, error=5)
mnp_res = mnp.mgrid[2:30:4j, -10:20:7, 2:5:0.5]
onp_res = onp.mgrid[2:30:4j, -10:20:7, 2:5:0.5]
match_all_arrays(mnp_res, onp_res, error=5)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ogrid():
mnp_res = mnp.ogrid[0:5]
onp_res = onp.ogrid[0:5]
match_all_arrays(mnp_res, onp_res, error=5)
mnp_res = mnp.ogrid[2:30:4j, -10:20:7, 2:5:0.5]
onp_res = onp.ogrid[2:30:4j, -10:20:7, 2:5:0.5]
match_all_arrays(mnp_res, onp_res, error=5)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_diagflat():
arrs = [rand_int(0), rand_int(2, 3), rand_int(3, 5, 0)]
for arr in arrs:
for i in [-2, 0, 7]:
match_res(mnp.diagflat, onp.diagflat, arr, k=i)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_diag():
arrs = [rand_int(0), rand_int(0, 0), rand_int(7), rand_int(5, 5),
rand_int(3, 8), rand_int(9, 6)]
for arr in arrs:
for i in [-10, -5, -1, 0, 2, 5, 6, 10]:
match_res(mnp.diag, onp.diag, arr, k=i)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_diag_indices():
mnp_res = mnp.diag_indices(0)
onp_res = onp.diag_indices(0)
match_all_arrays(mnp_res, onp_res)
mnp_res = mnp.diag_indices(3, 0)
onp_res = onp.diag_indices(3, 0)
match_all_arrays(mnp_res, onp_res)
mnp_res = mnp.diag_indices(5, 7)
onp_res = onp.diag_indices(5, 7)
match_all_arrays(mnp_res, onp_res)
def mnp_ix_(*args):
return mnp.ix_(*args)
def onp_ix_(*args):
return onp.ix_(*args)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_ix_():
arrs = [rand_int(i + 1) for i in range(10)]
for i in range(10):
test_arrs = arrs[:i + 1]
match_res(mnp_ix_, onp_ix_, *test_arrs)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training


@@ -22,6 +22,9 @@ import numpy as onp
import mindspore.numpy as mnp
from mindspore.nn import Cell
from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \
rand_bool, match_res, run_multi_test
class Cases():
def __init__(self):
@@ -111,81 +114,6 @@
]
def match_array(actual, expected, error=0):
if error > 0:
onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
decimal=error)
else:
onp.testing.assert_equal(actual.tolist(), expected.tolist())
def check_all_results(onp_results, mnp_results, error=0):
"""Check all results from numpy and mindspore.numpy"""
for i, _ in enumerate(onp_results):
match_array(onp_results[i], mnp_results[i].asnumpy())
def run_non_kw_test(mnp_fn, onp_fn):
"""Run tests on functions with non keyword arguments"""
test_case = Cases()
for i in range(len(test_case.arrs)):
arrs = test_case.arrs[:i]
match_res(mnp_fn, onp_fn, *arrs)
for i in range(len(test_case.scalars)):
arrs = test_case.scalars[:i]
match_res(mnp_fn, onp_fn, *arrs)
for i in range(len(test_case.expanded_arrs)):
arrs = test_case.expanded_arrs[:i]
match_res(mnp_fn, onp_fn, *arrs)
for i in range(len(test_case.nested_arrs)):
arrs = test_case.nested_arrs[:i]
match_res(mnp_fn, onp_fn, *arrs)
def rand_int(*shape):
"""return an random integer array with parameter shape"""
res = onp.random.randint(low=1, high=5, size=shape)
if isinstance(res, onp.ndarray):
return res.astype(onp.float32)
return float(res)
# return an random boolean array
def rand_bool(*shape):
return onp.random.rand(*shape) > 0.5
def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
"""Checks results from applying mnp_fn and onp_fn on arrs respectively"""
mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs)
mnp_res = mnp_fn(*mnp_arrs, **kwargs)
onp_res = onp_fn(*arrs, **kwargs)
match_all_arrays(mnp_res, onp_res)
def match_all_arrays(mnp_res, onp_res, error=0):
if isinstance(mnp_res, (tuple, list)):
assert len(mnp_res) == len(onp_res)
for actual, expected in zip(mnp_res, onp_res):
match_array(actual.asnumpy(), expected, error)
else:
match_array(mnp_res.asnumpy(), onp_res, error)
def match_meta(actual, expected):
# float64 and int64 are not supported, and the default type for
# float and int are float32 and int32, respectively
if expected.dtype == onp.float64:
expected = expected.astype(onp.float32)
elif expected.dtype == onp.int64:
expected = expected.astype(onp.int32)
assert actual.shape == expected.shape
assert actual.dtype == expected.dtype
# Test np.transpose and np.ndarray.transpose
def mnp_transpose(input_tensor):
a = mnp.transpose(input_tensor, (0, 2, 1))
@@ -458,6 +386,34 @@ def test_concatenate():
check_all_results(o_concatenate, m_concatenate)
def mnp_append(arr1, arr2):
a = mnp.append(arr1, arr2)
b = mnp.append(arr1, arr2, axis=0)
c = mnp.append(arr1, arr2, axis=-1)
return a, b, c
def onp_append(arr1, arr2):
a = onp.append(arr1, arr2)
b = onp.append(arr1, arr2, axis=0)
c = onp.append(arr1, arr2, axis=-1)
return a, b, c
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_append():
onp_array = onp.random.random((4, 3, 2)).astype('float32')
onp_value = onp.random.random((4, 3, 2)).astype('float32')
mnp_array = mnp.asarray(onp_array)
mnp_value = mnp.asarray(onp_value)
onp_res = onp_append(onp_array, onp_value)
mnp_res = mnp_append(mnp_array, mnp_value)
check_all_results(onp_res, mnp_res)
def construct_arrays(n=1, ndim=1, axis=None, low=1, high=5):
onp_array_lst = []
mnp_array_lst = []
@@ -629,7 +585,7 @@ def onp_atleast3d(*arys):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atleast1d():
run_non_kw_test(mnp_atleast1d, onp_atleast1d)
run_non_kw_test(mnp_atleast1d, onp_atleast1d, Cases())
@pytest.mark.level1
@@ -639,7 +595,7 @@ def test_atleast1d():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atleast2d():
run_non_kw_test(mnp_atleast2d, onp_atleast2d)
run_non_kw_test(mnp_atleast2d, onp_atleast2d, Cases())
@pytest.mark.level1
@@ -649,7 +605,7 @@ def test_atleast2d():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atleast3d():
run_non_kw_test(mnp_atleast3d, onp_atleast3d)
run_non_kw_test(mnp_atleast3d, onp_atleast3d, Cases())
# Test np.where
@@ -858,6 +814,444 @@ def test_stack():
match_res(mnp.stack, onp.stack, arrs, axis=i)
def mnp_roll(input_tensor):
a = mnp.roll(input_tensor, -3)
b = mnp.roll(input_tensor, [-2, -3], 1)
c = mnp.roll(input_tensor, (3, 0, -5), (-1, -2, 0))
d = mnp.roll(input_tensor, (4,), [0, 0, 1])
return a, b, c, d
def onp_roll(input_array):
a = onp.roll(input_array, -3)
b = onp.roll(input_array, [-2, -3], 1)
c = onp.roll(input_array, (3, 0, -5), (-1, -2, 0))
d = onp.roll(input_array, (4,), [0, 0, 1])
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_roll():
arr = rand_int(3, 4, 5)
match_res(mnp_roll, onp_roll, arr)
arr = rand_int(1, 4, 6).astype("int64")
match_res(mnp_roll, onp_roll, arr)
def mnp_moveaxis(a):
a = mnp.moveaxis(a, 3, 3)
b = mnp.moveaxis(a, -1, 4)
c = mnp.moveaxis(a, (2, 1, 4), (0, 3, 2))
d = mnp.moveaxis(a, [-2, -5], [2, -4])
return a, b, c, d
def onp_moveaxis(a):
a = onp.moveaxis(a, 3, 3)
b = onp.moveaxis(a, -1, 4)
c = onp.moveaxis(a, (2, 1, 4), (0, 3, 2))
d = onp.moveaxis(a, [-2, -5], [2, -4])
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_moveaxis():
a = rand_int(2, 4, 5, 9, 6)
match_res(mnp_moveaxis, onp_moveaxis, a)
a = rand_int(2, 4, 5, 0, 6, 7, 1, 3, 8)
match_res(mnp_moveaxis, onp_moveaxis, a)
def mnp_tile(x):
a = mnp.tile(x, 0)
b = mnp.tile(x, 1)
c = mnp.tile(x, 3)
d = mnp.tile(x, [5, 1])
e = mnp.tile(x, (3, 1, 0))
f = mnp.tile(x, [5, 1, 2, 3, 7])
return a, b, c, d, e, f
def onp_tile(x):
a = onp.tile(x, 0)
b = onp.tile(x, 1)
c = onp.tile(x, 3)
d = onp.tile(x, [5, 1])
e = onp.tile(x, (3, 1, 0))
f = onp.tile(x, [5, 1, 2, 3, 7])
return a, b, c, d, e, f
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tile():
a = rand_int(2, 3, 4)
match_res(mnp_tile, onp_tile, a)
b = rand_int(5, 0, 8)
match_res(mnp_tile, onp_tile, b)
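# Test np.broadcast_to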
def mnp_broadcast_to(x):
a = mnp.broadcast_to(x, (2, 3))
b = mnp.broadcast_to(x, (8, 1, 3))
return a, b
def onp_broadcast_to(x):
a = onp.broadcast_to(x, (2, 3))
b = onp.broadcast_to(x, (8, 1, 3))
return a, b
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_broadcast_to():
x = rand_int()
match_res(mnp_broadcast_to, onp_broadcast_to, x)
x = rand_int(3)
match_res(mnp_broadcast_to, onp_broadcast_to, x)
x = rand_int(1, 3)
match_res(mnp_broadcast_to, onp_broadcast_to, x)
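# Test np.broadcast_arrays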
def mnp_broadcast_arrays(*args):
return mnp.broadcast_arrays(*args)
def onp_broadcast_arrays(*args):
return onp.broadcast_arrays(*args)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_broadcast_arrays():
test_case = Cases()
broadcastables = test_case.broadcastables
for i in range(len(broadcastables)):
arrs = broadcastables[i:]
match_res(mnp_broadcast_arrays, onp_broadcast_arrays, *arrs)
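# Test np.flip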
def mnp_flip(x):
a = mnp.flip(x)
b = mnp.flip(x, 0)
c = mnp.flip(x, 1)
d = mnp.flip(x, (-3, -1))
return a, b, c, d
def onp_flip(x):
a = onp.flip(x)
b = onp.flip(x, 0)
c = onp.flip(x, 1)
d = onp.flip(x, (-3, -1))
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_flip():
x = rand_int(2, 3, 4)
run_multi_test(mnp_flip, onp_flip, (x,))
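# Test np.flipud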
def mnp_flipud(x):
return mnp.flipud(x)
def onp_flipud(x):
return onp.flipud(x)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_flipud():
x = rand_int(2, 3, 4)
run_multi_test(mnp_flipud, onp_flipud, (x,))
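# Test np.fliplr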
def mnp_fliplr(x):
return mnp.fliplr(x)
def onp_fliplr(x):
return onp.fliplr(x)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_fliplr():
x = rand_int(2, 3, 4)
run_multi_test(mnp_fliplr, onp_fliplr, (x,))
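# Test np.split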
def mnp_split(input_tensor):
a = mnp.split(input_tensor, indices_or_sections=1)
b = mnp.split(input_tensor, indices_or_sections=3)
c = mnp.split(input_tensor, indices_or_sections=(-9, -8, 6))
d = mnp.split(input_tensor, indices_or_sections=(3, 2, 1))
e = mnp.split(input_tensor, indices_or_sections=(-10, -4, 5, 10))
f = mnp.split(input_tensor, indices_or_sections=[0, 2], axis=1)
return a, b, c, d, e, f
def onp_split(input_array):
a = onp.split(input_array, indices_or_sections=1)
b = onp.split(input_array, indices_or_sections=3)
c = onp.split(input_array, indices_or_sections=(-9, -8, 6))
d = onp.split(input_array, indices_or_sections=(3, 2, 1))
e = onp.split(input_array, indices_or_sections=(-10, -4, 5, 10))
f = onp.split(input_array, indices_or_sections=[0, 2], axis=1)
return a, b, c, d, e, f
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_split():
onp_arrs = [
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32'),
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float64')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
o_split = onp_split(onp_arr)
m_split = mnp_split(mnp_arr)
for expect_lst, actual_lst in zip(o_split, m_split):
for expect, actual in zip(expect_lst, actual_lst):
match_array(expect, actual.asnumpy())
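# Test np.vsplit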
def mnp_vsplit(input_tensor):
a = mnp.vsplit(input_tensor, indices_or_sections=3)
b = mnp.vsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
c = mnp.vsplit(input_tensor, indices_or_sections=[0, 2])
return a, b, c
def onp_vsplit(input_array):
a = onp.vsplit(input_array, indices_or_sections=3)
b = onp.vsplit(input_array, indices_or_sections=(-10, -4, 5, 10))
c = onp.vsplit(input_array, indices_or_sections=[0, 2])
return a, b, c
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vsplit():
onp_arrs = [
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32'),
onp.random.randint(1, 5, size=(9, 4, 5)).astype('float64')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
o_vsplit = onp_vsplit(onp_arr)
m_vsplit = mnp_vsplit(mnp_arr)
for expect_lst, actual_lst in zip(o_vsplit, m_vsplit):
for expect, actual in zip(expect_lst, actual_lst):
match_array(expect, actual.asnumpy())
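# Test np.hsplit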
def mnp_hsplit(input_tensor):
a = mnp.hsplit(input_tensor, indices_or_sections=3)
b = mnp.hsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
c = mnp.hsplit(input_tensor, indices_or_sections=[0, 2])
return a, b, c
def onp_hsplit(input_array):
a = onp.hsplit(input_array, indices_or_sections=3)
b = onp.hsplit(input_array, indices_or_sections=(-10, -4, 5, 10))
c = onp.hsplit(input_array, indices_or_sections=[0, 2])
return a, b, c
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_hsplit():
onp_arrs = [
onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32'),
onp.random.randint(1, 5, size=(4, 9, 5)).astype('float64')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
o_hsplit = onp_hsplit(onp_arr)
m_hsplit = mnp_hsplit(mnp_arr)
for expect_lst, actual_lst in zip(o_hsplit, m_hsplit):
for expect, actual in zip(expect_lst, actual_lst):
match_array(expect, actual.asnumpy())
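# Test np.dsplit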
def mnp_dsplit(input_tensor):
a = mnp.dsplit(input_tensor, indices_or_sections=3)
b = mnp.dsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10))
c = mnp.dsplit(input_tensor, indices_or_sections=[0, 2])
return a, b, c
def onp_dsplit(input_array):
a = onp.dsplit(input_array, indices_or_sections=3)
b = onp.dsplit(input_array, indices_or_sections=(-10, -4, 5, 10))
c = onp.dsplit(input_array, indices_or_sections=[0, 2])
return a, b, c
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dsplit():
onp_arrs = [
onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32'),
onp.random.randint(1, 5, size=(5, 4, 9)).astype('float64')
]
mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs]
for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs):
o_dsplit = onp_dsplit(onp_arr)
m_dsplit = mnp_dsplit(mnp_arr)
for expect_lst, actual_lst in zip(o_dsplit, m_dsplit):
for expect, actual in zip(expect_lst, actual_lst):
match_array(expect, actual.asnumpy())
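# Test np.take_along_axis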
def mnp_take_along_axis(*arrs):
x = arrs[0]
a = mnp.take_along_axis(x, arrs[1], axis=None)
b = mnp.take_along_axis(x, arrs[2], axis=1)
c = mnp.take_along_axis(x, arrs[3], axis=-1)
d = mnp.take_along_axis(x, arrs[4], axis=0)
return a, b, c, d
def onp_take_along_axis(*arrs):
x = arrs[0]
a = onp.take_along_axis(x, arrs[1], axis=None)
b = onp.take_along_axis(x, arrs[2], axis=1)
c = onp.take_along_axis(x, arrs[3], axis=-1)
d = onp.take_along_axis(x, arrs[4], axis=0)
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_take_along_axis():
x = rand_int(6, 7, 8, 9)
indices1 = rand_int(2).astype(onp.int32)
indices2 = rand_int(6, 3, 8, 1).astype(onp.int32)
indices3 = rand_int(6, 1, 8, 5).astype(onp.int32)
indices4 = rand_int(4, 1, 1, 1).astype(onp.int32)
run_multi_test(mnp_take_along_axis, onp_take_along_axis,
(x, indices1, indices2, indices3, indices4))
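# Test np.take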
def mnp_take(x, indices):
a = mnp.take(x, indices)
b = mnp.take(x, indices, axis=-1)
c = mnp.take(x, indices, axis=0, mode='wrap')
d = mnp.take(x, indices, axis=1, mode='clip')
return a, b, c, d
def onp_take(x, indices):
a = onp.take(x, indices)
b = onp.take(x, indices, axis=-1)
c = onp.take(x, indices, axis=0, mode='wrap')
d = onp.take(x, indices, axis=1, mode='clip')
return a, b, c, d
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_take():
x = rand_int(2, 3, 4, 5)
indices = rand_int(2, 3).astype(onp.int32)
run_multi_test(mnp_take, onp_take, (x, indices))
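# Test np.repeat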
def mnp_repeat(x):
a = mnp.repeat(x, 2)
b = mnp.repeat(x, 3, axis=0)
c = mnp.repeat(x, (4, 1, 5), axis=1)
d = mnp.repeat(x, (3, 2, 1, 0, 4), axis=-1)
e = mnp.repeat(x, 0)
return a, b, c, d, e
def onp_repeat(x):
a = onp.repeat(x, 2)
b = onp.repeat(x, 3, axis=0)
c = onp.repeat(x, (4, 1, 5), axis=1)
d = onp.repeat(x, (3, 2, 1, 0, 4), axis=-1)
e = onp.repeat(x, 0)
return a, b, c, d, e
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_repeat():
x = rand_int(2, 3, 4, 5)
run_multi_test(mnp_repeat, onp_repeat, (x,))
class ReshapeExpandSqueeze(Cell):
def __init__(self):
super(ReshapeExpandSqueeze, self).__init__()

View File

@@ -0,0 +1,263 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""unit tests for numpy logical operations"""
import pytest
import numpy as onp
import mindspore.numpy as mnp
from .utils import rand_int, run_binop_test, match_res
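# Shared test inputs: plain arrays, scalar-like arrays, broadcast-compatible
# shapes, and an array containing infs and nans.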
class Cases():
def __init__(self):
self.arrs = [
rand_int(2),
rand_int(2, 3),
rand_int(2, 3, 4),
rand_int(2, 3, 4, 5),
]
# scalars expanded across the 0th dimension
self.scalars = [
rand_int(),
rand_int(1),
rand_int(1, 1),
rand_int(1, 1, 1, 1),
]
# arrays of the same size expanded across the 0th dimension
self.expanded_arrs = [
rand_int(2, 3),
rand_int(1, 2, 3),
rand_int(1, 1, 2, 3),
rand_int(1, 1, 1, 2, 3),
]
# arrays which can be broadcast
self.broadcastables = [
rand_int(5),
rand_int(6, 1),
rand_int(7, 1, 5),
rand_int(8, 1, 6, 1)
]
# array which contains infs and nans
self.infs = onp.array([[1.0, onp.nan], [onp.inf, onp.NINF], [2.3, -4.5], [onp.nan, 0.0]])
test_case = Cases()
def mnp_not_equal(a, b):
return mnp.not_equal(a, b)
def onp_not_equal(a, b):
return onp.not_equal(a, b)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_not_equal():
run_binop_test(mnp_not_equal, onp_not_equal, test_case)
def mnp_less_equal(a, b):
return mnp.less_equal(a, b)
def onp_less_equal(a, b):
return onp.less_equal(a, b)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_less_equal():
run_binop_test(mnp_less_equal, onp_less_equal, test_case)
def mnp_less(a, b):
return mnp.less(a, b)
def onp_less(a, b):
return onp.less(a, b)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_less():
run_binop_test(mnp_less, onp_less, test_case)
def mnp_greater_equal(a, b):
return mnp.greater_equal(a, b)
def onp_greater_equal(a, b):
return onp.greater_equal(a, b)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_greater_equal():
run_binop_test(mnp_greater_equal, onp_greater_equal, test_case)
def mnp_greater(a, b):
return mnp.greater(a, b)
def onp_greater(a, b):
return onp.greater(a, b)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_greater():
run_binop_test(mnp_greater, onp_greater, test_case)
def mnp_equal(a, b):
return mnp.equal(a, b)
def onp_equal(a, b):
return onp.equal(a, b)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_equal():
run_binop_test(mnp_equal, onp_equal, test_case)
def mnp_isfinite(x):
return mnp.isfinite(x)
def onp_isfinite(x):
return onp.isfinite(x)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_isfinite():
match_res(mnp_isfinite, onp_isfinite, test_case.infs)
def mnp_isnan(x):
return mnp.isnan(x)
def onp_isnan(x):
return onp.isnan(x)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_isnan():
match_res(mnp_isnan, onp_isnan, test_case.infs)
def mnp_isinf(x):
return mnp.isinf(x)
def onp_isinf(x):
return onp.isinf(x)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_isinf():
match_res(mnp_isinf, onp_isinf, test_case.infs)
def mnp_isposinf(x):
return mnp.isposinf(x)
def onp_isposinf(x):
return onp.isposinf(x)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_isposinf():
match_res(mnp_isposinf, onp_isposinf, test_case.infs)
def mnp_isneginf(x):
return mnp.isneginf(x)
def onp_isneginf(x):
return onp.isneginf(x)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_isneginf():
match_res(mnp_isneginf, onp_isneginf, test_case.infs)
def test_isscalar():
assert mnp.isscalar(1) == onp.isscalar(1)
assert mnp.isscalar(2.3) == onp.isscalar(2.3)
assert mnp.isscalar([4.5]) == onp.isscalar([4.5])
assert mnp.isscalar(False) == onp.isscalar(False)
assert mnp.isscalar(mnp.array(True)) == onp.isscalar(onp.array(True))
assert mnp.isscalar('numpy') == onp.isscalar('numpy')

File diff suppressed because it is too large

View File

@@ -0,0 +1,165 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""utility functions for mindspore.numpy st tests"""
import functools
import numpy as onp
import mindspore.numpy as mnp
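# Compare one result against its expected value; error > 0 compares to that
# many decimal places, otherwise the match must be exact.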
def match_array(actual, expected, error=0):
if isinstance(actual, int):
actual = onp.asarray(actual)
if isinstance(expected, int):
expected = onp.asarray(expected)
if error > 0:
onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
decimal=error)
else:
onp.testing.assert_equal(actual.tolist(), expected.tolist())
def check_all_results(onp_results, mnp_results, error=0):
"""Check all results from numpy and mindspore.numpy"""
for i, _ in enumerate(onp_results):
        match_array(onp_results[i], mnp_results[i].asnumpy(), error)
def check_all_unique_results(onp_results, mnp_results):
"""
Check all results from numpy and mindspore.numpy.
Args:
onp_results (Union[tuple of numpy.arrays, numpy.array])
mnp_results (Union[tuple of Tensors, Tensor])
"""
for i, _ in enumerate(onp_results):
if isinstance(onp_results[i], tuple):
for j in range(len(onp_results[i])):
match_array(onp_results[i][j],
mnp_results[i][j].asnumpy(), error=7)
else:
match_array(onp_results[i], mnp_results[i].asnumpy(), error=7)
def run_non_kw_test(mnp_fn, onp_fn, test_case):
"""Run tests on functions with non keyword arguments"""
for i in range(len(test_case.arrs)):
arrs = test_case.arrs[:i]
match_res(mnp_fn, onp_fn, *arrs)
for i in range(len(test_case.scalars)):
arrs = test_case.scalars[:i]
match_res(mnp_fn, onp_fn, *arrs)
for i in range(len(test_case.expanded_arrs)):
arrs = test_case.expanded_arrs[:i]
match_res(mnp_fn, onp_fn, *arrs)
for i in range(len(test_case.nested_arrs)):
arrs = test_case.nested_arrs[:i]
match_res(mnp_fn, onp_fn, *arrs)
def rand_int(*shape):
"""return an random integer array with parameter shape"""
res = onp.random.randint(low=1, high=5, size=shape)
if isinstance(res, onp.ndarray):
return res.astype(onp.float32)
return float(res)
# return a random boolean array
def rand_bool(*shape):
return onp.random.rand(*shape) > 0.5
def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
"""Checks results from applying mnp_fn and onp_fn on arrs respectively"""
mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs)
    error = kwargs.pop('error', 0)
mnp_res = mnp_fn(*mnp_arrs, **kwargs)
onp_res = onp_fn(*arrs, **kwargs)
match_all_arrays(mnp_res, onp_res, error=error)
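# Compare a result (or a tuple/list of results) from mindspore.numpy against
# the numpy reference.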
def match_all_arrays(mnp_res, onp_res, error=0):
if isinstance(mnp_res, (tuple, list)):
assert len(mnp_res) == len(onp_res)
for actual, expected in zip(mnp_res, onp_res):
match_array(actual.asnumpy(), expected, error)
else:
match_array(mnp_res.asnumpy(), onp_res, error)
def match_meta(actual, expected):
    # float64 and int64 are not supported, and the default types for
    # float and int are float32 and int32, respectively
if expected.dtype == onp.float64:
expected = expected.astype(onp.float32)
elif expected.dtype == onp.int64:
expected = expected.astype(onp.int32)
assert actual.shape == expected.shape
assert actual.dtype == expected.dtype
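# Exercise a binary function on array/array, array/scalar, scalar/scalar,
# expanded and broadcastable operand pairs.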
def run_binop_test(mnp_fn, onp_fn, test_case):
for arr in test_case.arrs:
match_res(mnp_fn, onp_fn, arr, arr)
for scalar in test_case.scalars:
match_res(mnp_fn, onp_fn, arr, scalar)
match_res(mnp_fn, onp_fn, scalar, arr)
for scalar1 in test_case.scalars:
for scalar2 in test_case.scalars:
match_res(mnp_fn, onp_fn, scalar1, scalar2)
for expanded_arr1 in test_case.expanded_arrs:
for expanded_arr2 in test_case.expanded_arrs:
match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2)
for broadcastable1 in test_case.broadcastables:
for broadcastable2 in test_case.broadcastables:
match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2)
def run_unary_test(mnp_fn, onp_fn, test_case, error=0):
for arr in test_case.arrs:
match_res(mnp_fn, onp_fn, arr, error=error)
for arr in test_case.scalars:
match_res(mnp_fn, onp_fn, arr, error=error)
for arr in test_case.expanded_arrs:
match_res(mnp_fn, onp_fn, arr, error=error)
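# Apply mnp_fn and onp_fn to the same group of arrays and compare each pair
# of outputs.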
def run_multi_test(mnp_fn, onp_fn, arrs, error=0):
mnp_arrs = map(mnp.asarray, arrs)
for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)):
match_array(actual.asnumpy(), expected, error)
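# Apply mnp_fn and onp_fn to a single array and compare the outputs,
# unpacking nested tuples of results.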
def run_single_test(mnp_fn, onp_fn, arr, error=0):
mnp_arr = mnp.asarray(arr)
for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)):
        if isinstance(expected, tuple):
            for actual_arr, expected_arr in zip(actual, expected):
                match_array(actual_arr.asnumpy(), expected_arr, error)
        else:
            match_array(actual.asnumpy(), expected, error)
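# Illustrative usage sketch (not part of the test suite): a comparison test
# typically defines paired mnp_*/onp_* wrappers and routes both through
# match_res, for example:
#
#     def mnp_square(x):
#         return mnp.square(x)
#
#     def onp_square(x):
#         return onp.square(x)
#
#     match_res(mnp_square, onp_square, rand_int(2, 3))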