fix mindpsore.numpy packaging issue and format API comments

This commit is contained in:
yanglf1121 2020-12-15 10:39:15 +08:00
parent 280db3d651
commit 9072283395
8 changed files with 164 additions and 151 deletions

View File

@ -281,6 +281,7 @@ install(
${CMAKE_SOURCE_DIR}/mindspore/_extends ${CMAKE_SOURCE_DIR}/mindspore/_extends
${CMAKE_SOURCE_DIR}/mindspore/parallel ${CMAKE_SOURCE_DIR}/mindspore/parallel
${CMAKE_SOURCE_DIR}/mindspore/mindrecord ${CMAKE_SOURCE_DIR}/mindspore/mindrecord
${CMAKE_SOURCE_DIR}/mindspore/numpy
${CMAKE_SOURCE_DIR}/mindspore/train ${CMAKE_SOURCE_DIR}/mindspore/train
${CMAKE_SOURCE_DIR}/mindspore/common ${CMAKE_SOURCE_DIR}/mindspore/common
${CMAKE_SOURCE_DIR}/mindspore/ops ${CMAKE_SOURCE_DIR}/mindspore/ops

View File

@ -42,3 +42,5 @@ array_ops_module = ['array', 'asarray', 'asfarray', 'copy', 'ones', 'zeros', 'ar
math_module = ['mean', 'inner'] math_module = ['mean', 'inner']
__all__ = array_ops_module + math_module + numeric_types __all__ = array_ops_module + math_module + numeric_types
__all__.sort()

View File

@ -17,37 +17,37 @@ from copy import copy as py_copy
import numpy as onp import numpy as onp
import mindspore from ..common import Tensor
from mindspore import Tensor from ..common import dtype as mstype
from mindspore.ops import operations as P from ..ops import operations as P
from mindspore.ops import functional as F from ..ops import functional as F
from mindspore.ops.primitive import constexpr from ..ops.primitive import constexpr
from .utils import _check_shape, _check_shape_compile, _check_dtype, _check_is_int, \ from .utils import _check_shape, _check_shape_compile, _check_dtype, _check_is_int, \
_check_axes_range, _check_start_normalize, _check_shape_contain_zero, _check_is_tensor, \ _check_axes_range, _check_start_normalize, _check_shape_contain_zero, _check_is_tensor, \
_check_input_for_asarray _check_input_for_asarray
DEFAULT_FLOAT_DTYPE = mindspore.float32 DEFAULT_FLOAT_DTYPE = mstype.float32
DEFAULT_INT_DTYPE = mindspore.int32 DEFAULT_INT_DTYPE = mstype.int32
def array(obj, dtype=None, copy=True, ndmin=0): def array(obj, dtype=None, copy=True, ndmin=0):
""" """
Create a tensor. Creates a tensor.
This function creat tensors from an array-like object. This function creates tensors from an array-like object.
Args: Args:
obj (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in obj (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
any form that can be converted to an array. This includes lists, lists of any form that can be converted to a tensor. This includes lists, lists of
tuples, tuples, tuples of tuples, tuples of lists and ndarrays. tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.int32, or `int32`. If dtype is None, the data type be in format of np.int32, or `int32`. If dtype is None, the data type
of the new tensor will be inferred from obj. Default is None. of the new tensor will be inferred from obj. Default is None.
copy (bool): If true, then the object is copied. Otherwise, a copy will copy (bool): If true, then the object is copied. Otherwise, a copy will
only be made if necessary. Default: True. only be made if necessary. Default: True.
ndmin (int): Specifies the minimum number of dimensions that the resulting ndmin (int): Specifies the minimum number of dimensions that the resulting
array should have. Ones will be pre-pended to the shape as needed to tensor should have. Ones will be pre-pended to the shape as needed to
meet this requirement. Default: 0 meet this requirement. Default: 0
Returns: Returns:
@ -76,15 +76,15 @@ def array(obj, dtype=None, copy=True, ndmin=0):
def asarray(a, dtype=None): def asarray(a, dtype=None):
""" """
Convert the input to tensor. Converts the input to tensor.
This function convert tensors from an array-like object. This function converts tensors from an array-like object.
Args: Args:
a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
any form that can be converted to an array. This includes lists, lists of any form that can be converted to a tensor. This includes lists, lists of
tuples, tuples, tuples of tuples, tuples of lists and ndarrays. tuples, tuples, tuples of tuples, tuples of lists and ndarrays.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.int32, or `int32`. If dtype is None, the data type be in format of np.int32, or `int32`. If dtype is None, the data type
of the new tensor will be inferred from a. Default is None. of the new tensor will be inferred from a. Default is None.
@ -112,7 +112,7 @@ def asarray(a, dtype=None):
dtype = DEFAULT_INT_DTYPE dtype = DEFAULT_INT_DTYPE
if isinstance(a, bool) and (dtype is None): if isinstance(a, bool) and (dtype is None):
dtype = mindspore.bool_ dtype = mstype.bool_
if isinstance(a, (list, tuple)): if isinstance(a, (list, tuple)):
a = onp.asarray(a) a = onp.asarray(a)
@ -126,14 +126,14 @@ def asarray(a, dtype=None):
if isinstance(a, onp.ndarray) and dtype is None: if isinstance(a, onp.ndarray) and dtype is None:
if a.dtype is onp.dtype('bool'): if a.dtype is onp.dtype('bool'):
dtype = mindspore.bool_ dtype = mstype.bool_
elif a.dtype is onp.dtype('int'): elif a.dtype is onp.dtype('int'):
dtype = DEFAULT_INT_DTYPE dtype = DEFAULT_INT_DTYPE
elif a.dtype is onp.dtype('float'): elif a.dtype is onp.dtype('float'):
dtype = DEFAULT_FLOAT_DTYPE dtype = DEFAULT_FLOAT_DTYPE
a = Tensor.from_numpy(a) a = Tensor.from_numpy(a)
# If a is already an tensor and we don't need to cast dtype, return a # If a is already a tensor and we don't need to cast dtype, return a
if isinstance(a, Tensor): if isinstance(a, Tensor):
if dtype is None: if dtype is None:
return a return a
@ -146,16 +146,16 @@ def asarray(a, dtype=None):
def asfarray(a, dtype=DEFAULT_FLOAT_DTYPE): def asfarray(a, dtype=DEFAULT_FLOAT_DTYPE):
""" """
Similar to asarray, convert the input to an float array. Similar to asarray, converts the input to a float tensor.
If non-float dtype is defined, this function will return a float32 Tensor instead. If non-float dtype is defined, this function will return a float32 tensor instead.
Args: Args:
a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in
any form that can be converted to an array. This includes lists, lists of any form that can be converted to a tensor. This includes lists, lists of
tuples, tuples, tuples of tuples, tuples of lists and ndarrays. tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.float32, or `float32`. Default is mindspore.float32. be in format of np.float32, or `float32`. Default is mstype.float32.
Returns: Returns:
Tensor, generated tensor with the specified float dtype. Tensor, generated tensor with the specified float dtype.
@ -171,7 +171,7 @@ def asfarray(a, dtype=DEFAULT_FLOAT_DTYPE):
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
_ = _check_input_for_asarray(a) _ = _check_input_for_asarray(a)
if dtype not in (mindspore.float16, mindspore.float32, mindspore.float64): if dtype not in (mstype.float16, mstype.float32, mstype.float64):
dtype = DEFAULT_FLOAT_DTYPE dtype = DEFAULT_FLOAT_DTYPE
if isinstance(a, (list, tuple)): if isinstance(a, (list, tuple)):
@ -185,7 +185,7 @@ def asfarray(a, dtype=DEFAULT_FLOAT_DTYPE):
def copy_(a): def copy_(a):
""" """
Return an tensor copy of the given object. Returns a tensor copy of the given object.
Args: Args:
a (Tensor): Input tensor. a (Tensor): Input tensor.
@ -198,20 +198,22 @@ def copy_(a):
Examples: Examples:
>>> import mindspore.numpy as np >>> import mindspore.numpy as np
>>> print(np.copy([1,2,3])) >>> x = np.ones((2,2))
[1. 2. 3.] >>> print(np.copy(x))
[[1. 1.]
[1. 1.]]
""" """
return py_copy(a) return py_copy(a)
def ones(shape, dtype=DEFAULT_FLOAT_DTYPE): def ones(shape, dtype=DEFAULT_FLOAT_DTYPE):
""" """
Return a new array of given shape and type, filled with ones. Returns a new tensor of given shape and type, filled with ones.
Args: Args:
shape (Union[int, tuple, list]): the shape of the new array. shape (Union[int, tuple, list]): the shape of the new tensor.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.float32, or `float32`. Default is mindspore.float32. be in format of np.float32, or `float32`. Default is mstype.float32.
Returns: Returns:
Tensor, with the designated shape and dtype, filled with ones. Tensor, with the designated shape and dtype, filled with ones.
@ -231,17 +233,17 @@ def ones(shape, dtype=DEFAULT_FLOAT_DTYPE):
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
fill = P.Fill() fill = P.Fill()
output = fill(dtype, shape, 1) output = fill(dtype, shape, 1)
return Tensor(output, dtype=dtype) return output
def zeros(shape, dtype=DEFAULT_FLOAT_DTYPE): def zeros(shape, dtype=DEFAULT_FLOAT_DTYPE):
""" """
Return a new array of given shape and type, filled with zeros. Returns a new tensor of given shape and type, filled with zeros.
Args: Args:
shape (Union[int, tuple, list]): the shape of the new array. shape (Union[int, tuple, list]): the shape of the new tensor.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.float32, or `float32`. Default is mindspore.float32. be in format of np.float32, or `float32`. Default is mstype.float32.
Returns: Returns:
Tensor, with the designated shape and dtype, filled with zeros. Tensor, with the designated shape and dtype, filled with zeros.
@ -261,19 +263,19 @@ def zeros(shape, dtype=DEFAULT_FLOAT_DTYPE):
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
fill = P.Fill() fill = P.Fill()
output = fill(dtype, shape, 0) output = fill(dtype, shape, 0)
return Tensor(output, dtype=dtype) return output
def full(shape, fill_value, dtype=None): def full(shape, fill_value, dtype=None):
""" """
Return a new array of given shape and type, filled with fill_value. Returns a new tensor of given shape and type, filled with fill_value.
Args: Args:
shape (Union[int, tuple(int), list(int)]): Shape of the new array, e.g., shape (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g.,
(2, 3) or 2. (2, 3) or 2.
fill_value (Union[int, float, bool, list, tuple]): scalar or array_like fill_value (Union[int, float, bool, list, tuple]): scalar or array_like
fill value. fill value.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.float32, or `float32`, if dtype is None, the data type be in format of np.float32, or `float32`, if dtype is None, the data type
of the new tensor will be inferred from fill_value. Default is None. of the new tensor will be inferred from fill_value. Default is None.
@ -301,16 +303,16 @@ def full(shape, fill_value, dtype=None):
# if fill_value is array_like or shape contains zero. fall back to original # if fill_value is array_like or shape contains zero. fall back to original
# numpy creation # numpy creation
return Tensor(onp.full(shape, fill_value, mindspore.dtype_to_nptype(dtype))) return Tensor(onp.full(shape, fill_value, mstype.dtype_to_nptype(dtype)))
def arange(*args, **kwargs): def arange(*args, **kwargs):
""" """
Return evenly spaced values within a given interval. Returns evenly spaced values within a given interval.
Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`]. Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`].
The endpoint of the interval can optionally be excluded. The endpoint of the interval can optionally be excluded.
The current implementation is a direct wrapper on top of numpy.arange, except The current implementation is a direct wrapper on top of numpy.arange, except that
the default dtype is float32 and int32, compare to float64 and int64 for numpy the default dtype is float32 and int32, compare to float64 and int64 for numpy
implementation. implementation.
@ -324,12 +326,12 @@ def arange(*args, **kwargs):
out, this is the distance between two adjacent values, out[i+1] - out[i]. out, this is the distance between two adjacent values, out[i+1] - out[i].
The default step size is 1. If step is specified as a position argument, The default step size is 1. If step is specified as a position argument,
start must also be given. start must also be given.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.float32, or `float32`. If dtype is None, the data type be in format of np.float32, or `float32`. If dtype is None, the data type
of the new tensor will be inferred from start, stop and step. Default is None. of the new tensor will be inferred from start, stop and step. Default is None.
Returns: Returns:
arangend Tensor, array of evenly spaced values. arangend tensor of evenly spaced values.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
@ -361,16 +363,16 @@ def arange(*args, **kwargs):
if 'dtype' in kwargs and kwargs['dtype'] is not None: if 'dtype' in kwargs and kwargs['dtype'] is not None:
final_dtype = _check_dtype(kwargs['dtype']) final_dtype = _check_dtype(kwargs['dtype'])
final_dtype = mindspore.dtype_to_nptype(final_dtype) final_dtype = mstype.dtype_to_nptype(final_dtype)
kwargs['dtype'] = final_dtype kwargs['dtype'] = final_dtype
out = onp.arange(*args, **kwargs) out = onp.arange(*args, **kwargs)
out = Tensor.from_numpy(out) out = Tensor.from_numpy(out)
return Tensor(out) return out
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0): def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
""" """
Return evenly spaced values within a given interval. Returns evenly spaced values within a given interval.
The current implementation is a direct wrapper on top of numpy.linspace, except The current implementation is a direct wrapper on top of numpy.linspace, except
the default dtype is float32, compare to float64 for numpy, the default dtype is float32, compare to float64 for numpy,
@ -386,7 +388,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
not included. Default is True. not included. Default is True.
retstep (bool, optional): If True, return (`samples`, `step`), where `step` is retstep (bool, optional): If True, return (`samples`, `step`), where `step` is
the spacing between samples. the spacing between samples.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.float32, or `float32`.If `dtype` is None, infer the data be in format of np.float32, or `float32`.If `dtype` is None, infer the data
type from other input arguments. Default is None. type from other input arguments. Default is None.
axis (int, optional): The axis in the result to store the samples. Relevant axis (int, optional): The axis in the result to store the samples. Relevant
@ -420,7 +422,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
final_dtype = None final_dtype = None
if dtype is not None: if dtype is not None:
final_dtype = _check_dtype(dtype) final_dtype = _check_dtype(dtype)
final_dtype = mindspore.dtype_to_nptype(final_dtype) final_dtype = mstype.dtype_to_nptype(final_dtype)
else: else:
final_dtype = onp.float32 final_dtype = onp.float32
@ -438,7 +440,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
""" """
Return numbers spaced evenly on a log scale. Returns numbers spaced evenly on a log scale.
In linear space, the sequence starts at base ** start (base to the power of In linear space, the sequence starts at base ** start (base to the power of
start) and ends with base ** stop (see endpoint below). start) and ends with base ** stop (see endpoint below).
@ -457,11 +459,11 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
base (Union[int, float], optional): The base of the log space. The step size base (Union[int, float], optional): The base of the log space. The step size
between the elements in ln(samples) / ln(base) (or log_base(samples)) between the elements in ln(samples) / ln(base) (or log_base(samples))
is uniform. Default is 10.0. is uniform. Default is 10.0.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.float32, or `float32`.If `dtype` is None, infer the data be in format of np.float32, or `float32`.If `dtype` is None, infer the data
type from other input arguments. Default is None. type from other input arguments. Default is None.
axis (int, optional): The axis in the result to store the samples. Relevant axis (int, optional): The axis in the result to store the samples. Relevant
only if start or stop are array-like. By default (0), the samples will only if start or stop is array-like. By default (0), the samples will
be along a new axis inserted at the beginning. Use -1 to get an axis at the end. be along a new axis inserted at the beginning. Use -1 to get an axis at the end.
Default is 0. Default is 0.
@ -486,7 +488,7 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
final_dtype = None final_dtype = None
if dtype is not None: if dtype is not None:
final_dtype = _check_dtype(dtype) final_dtype = _check_dtype(dtype)
final_dtype = mindspore.dtype_to_nptype(final_dtype) final_dtype = mstype.dtype_to_nptype(final_dtype)
else: else:
final_dtype = onp.float32 final_dtype = onp.float32
@ -499,7 +501,7 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
def eye(N, M=None, k=0, dtype=DEFAULT_FLOAT_DTYPE): def eye(N, M=None, k=0, dtype=DEFAULT_FLOAT_DTYPE):
""" """
Return a 2-D array with ones on the diagnoal and zeros elsewhere. Returns a 2-D tensor with ones on the diagnoal and zeros elsewhere.
Args: Args:
N (int): Number of rows in the output, must be larger than 0. N (int): Number of rows in the output, must be larger than 0.
@ -508,11 +510,11 @@ def eye(N, M=None, k=0, dtype=DEFAULT_FLOAT_DTYPE):
k (int, optional): Index of the diagonal: 0 (the default) refers to the main k (int, optional): Index of the diagonal: 0 (the default) refers to the main
diagonal, a positive value refers to an upper diagonal, and a negative value diagonal, a positive value refers to an upper diagonal, and a negative value
to a lower diagonal. Default is 0. to a lower diagonal. Default is 0.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.float32, or `float32`. Default is mindspore.float32. be in format of np.float32, or `float32`. Default is mstype.float32.
Returns: Returns:
result (Tensor): A tensor array of shape (N,M). An array where all elements result (Tensor): A tensor of shape (N,M). A tensor where all elements
are equal to zero, except for the k-th diagonal, whose values are equal to one. are equal to zero, except for the k-th diagonal, whose values are equal to one.
Supported Platforms: Supported Platforms:
@ -542,15 +544,15 @@ def eye(N, M=None, k=0, dtype=DEFAULT_FLOAT_DTYPE):
def identity(n, dtype=DEFAULT_FLOAT_DTYPE): def identity(n, dtype=DEFAULT_FLOAT_DTYPE):
""" """
Return the identity array. Returns the identity tensor.
Args: Args:
n (int): Number of rows and columns in the output, must be larger than 0. n (int): Number of rows and columns in the output, must be larger than 0.
dtype (Union[mindspore.dtype, str], optional): Designated array dtype, can dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can
be in format of np.float32, or `float32`. Default is mindspore.float32. be in format of np.float32, or `float32`. Default is mstype.float32.
Returns: Returns:
result (Tensor): A tensor array of shape (n,n). An array where all elements result (Tensor): A tensor of shape (n,n). A tensor where all elements
are equal to zero, except for the diagonal, whose values are equal to one. are equal to zero, except for the diagonal, whose values are equal to one.
Supported Platforms: Supported Platforms:
@ -569,7 +571,7 @@ def identity(n, dtype=DEFAULT_FLOAT_DTYPE):
@constexpr @constexpr
def _prepare_shape_for_expand_dims(shape, axes): def _prepare_shape_for_expand_dims(shape, axes):
""" """
Creat the expanded new shape based on the shape and given axes Creates the expanded new shape based on the shape and given axes
Args: Args:
shape (tuple): the shape of the tensor shape (tuple): the shape of the tensor
@ -588,7 +590,7 @@ def _prepare_shape_for_expand_dims(shape, axes):
new_shape_length += 1 new_shape_length += 1
if axes >= new_shape_length or axes < -new_shape_length: if axes >= new_shape_length or axes < -new_shape_length:
raise ValueError( raise ValueError(
f"axis {axes} is out of bounds for array of dimension {new_shape_length}") f"axis {axes} is out of bounds for tensor of dimension {new_shape_length}")
axes = {axes} axes = {axes}
elif isinstance(axes, (list, tuple)): elif isinstance(axes, (list, tuple)):
@ -596,7 +598,7 @@ def _prepare_shape_for_expand_dims(shape, axes):
for axis in axes: for axis in axes:
if axis >= new_shape_length or axis < -new_shape_length: if axis >= new_shape_length or axis < -new_shape_length:
raise ValueError( raise ValueError(
f"axis {axis} is out of bounds for array of dimension {new_shape_length}") f"axis {axis} is out of bounds for tensor of dimension {new_shape_length}")
axes = set(axes) axes = set(axes)
else: else:
@ -614,9 +616,9 @@ def _prepare_shape_for_expand_dims(shape, axes):
def expand_dims(a, axis): def expand_dims(a, axis):
""" """
Expand the shape of an array. Expands the shape of a tensor.
Insert a new axis that will appear at the axis position in the expanded array shape. Inserts a new axis that will appear at the axis position in the expanded tensor shape.
Args: Args:
a (Tensor): Input tensor array. a (Tensor): Input tensor array.
@ -633,8 +635,8 @@ def expand_dims(a, axis):
>>> import mindspore.numpy as np >>> import mindspore.numpy as np
>>> x = np.ones((2,2)) >>> x = np.ones((2,2))
>>> x = np.expand_dims(x,0) >>> x = np.expand_dims(x,0)
>>> print(x,shape) >>> print(x.shape)
(2,2,1) (1, 2, 2)
""" """
shape = F.shape(a) shape = F.shape(a)
# yield expanded shape based on the axes # yield expanded shape based on the axes
@ -645,11 +647,11 @@ def expand_dims(a, axis):
@constexpr @constexpr
def _prepare_shape_for_squeeze(shape, axes): def _prepare_shape_for_squeeze(shape, axes):
""" """
Creat the squeezed new shape based on the tensor and given axes. Creates the squeezed new shape based on the tensor and given axes.
Args: Args:
shape (tuple): the shape of the tensor shape (tuple): the shape of the tensor
axes Union(None, int, tuple(int), list(int)): the axes with dimensions squeezed. axes Union[None, int, tuple(int), list(int)]: the axes with dimensions squeezed.
Returns: Returns:
new_shape(tuple): the shape with dimensions squeezed. new_shape(tuple): the shape with dimensions squeezed.
@ -661,14 +663,14 @@ def _prepare_shape_for_squeeze(shape, axes):
if isinstance(axes, int): if isinstance(axes, int):
if axes >= ndim or axes < -ndim: if axes >= ndim or axes < -ndim:
raise ValueError( raise ValueError(
f"axis {axes} is out of bounds for array of dimension {ndim}") f"axis {axes} is out of bounds for tensor of dimension {ndim}")
axes = {axes} axes = {axes}
elif isinstance(axes, (list, tuple)): elif isinstance(axes, (list, tuple)):
for axis in axes: for axis in axes:
if axis >= ndim or axis < -ndim: if axis >= ndim or axis < -ndim:
raise ValueError( raise ValueError(
f"axis {axis} is out of bounds for array of dimension {ndim}") f"axis {axis} is out of bounds for tensor of dimension {ndim}")
axes = set(axes) axes = set(axes)
elif axes is not None: elif axes is not None:
@ -690,7 +692,7 @@ def _prepare_shape_for_squeeze(shape, axes):
def squeeze(a, axis=None): def squeeze(a, axis=None):
""" """
Remove single-dimensional entries from the shape of an array. Removes single-dimensional entries from the shape of an tensor.
This is a temporary solution to support CPU backend. Will be changed This is a temporary solution to support CPU backend. Will be changed
once CPU backend supports P.Squeeze(). once CPU backend supports P.Squeeze().
@ -709,8 +711,8 @@ def squeeze(a, axis=None):
>>> import mindspore.numpy as np >>> import mindspore.numpy as np
>>> x = np.ones((1,2,2,1)) >>> x = np.ones((1,2,2,1))
>>> x = np.squeeze(x) >>> x = np.squeeze(x)
>>> print(x,shape) >>> print(x.shape)
(2,2) (2, 2)
""" """
shape = F.shape(a) shape = F.shape(a)
# yield squeezed shape based on the axes # yield squeezed shape based on the axes
@ -720,11 +722,11 @@ def squeeze(a, axis=None):
def transpose(a, axes=None): def transpose(a, axes=None):
""" """
Reverse or permute the axes of an array; returns the modified array. Reverses or permutes the axes of a tensor; returns the modified tensor.
Args: Args:
a (Tensor): a tensor to be transposed a (Tensor): a tensor to be transposed
axes Union[None, tuple, list]: the axes order, if axes is None, transpose axes (Union[None, tuple, list]): the axes order, if axes is None, transpose
the entire tensor. Default is None. the entire tensor. Default is None.
Returns: Returns:
@ -737,8 +739,8 @@ def transpose(a, axes=None):
>>> import mindspore.numpy as np >>> import mindspore.numpy as np
>>> x = np.ones((1,2,3)) >>> x = np.ones((1,2,3))
>>> x = np.transpose(x) >>> x = np.transpose(x)
>>> print(x,shape) >>> print(x.shape)
(3,2,1) (3, 2, 1)
""" """
if axes is None: if axes is None:
shape = F.shape(a) shape = F.shape(a)
@ -753,7 +755,7 @@ def transpose(a, axes=None):
def rollaxis(x, axis, start=0): def rollaxis(x, axis, start=0):
""" """
Roll the specified axis backwards, until it lies in the given position. Rolls the specified axis backwards, until it lies in the given position.
The positions of the other axes do not change relative to one another. The positions of the other axes do not change relative to one another.
Args: Args:
@ -761,8 +763,10 @@ def rollaxis(x, axis, start=0):
axis (int): The axis to be rolled. axis (int): The axis to be rolled.
start (int): start (int):
- When start >= 0: - When start >= 0:
- When start <= axis: the axis is rolled back until it lies in this position (start). - When start <= axis: the axis is rolled back until it lies in
- When start > axis: the axis is rolled until it lies before this position (start). this position (start).
- When start > axis: the axis is rolled until it lies before this
position (start).
- When start < 0: the start will be normalized as follows: - When start < 0: the start will be normalized as follows:
start ........... Normalized start start ........... Normalized start
-(x.ndim+1) raise ValueError -(x.ndim+1) raise ValueError
@ -786,14 +790,11 @@ def rollaxis(x, axis, start=0):
start is not in the range from -ndim to ndim. start is not in the range from -ndim to ndim.
Examples: Examples:
>>> import mindspore >>> import mindspore.numpy as np
>>> import mindspore.numpy as mnp >>> x = np.ones((2,3,4))
>>> from mindspore import Tensor >>> output = np.rollaxis(x, 0, 2)
>>> import numpy as onp
>>> input_x = Tensor(onp.ones((2,3,4)), mindspore.float32)
>>> output = mnp.rollaxis(x, 0, 2)
>>> print(output.shape) >>> print(output.shape)
(3,2,4) (3, 2, 4)
""" """
_check_is_int(axis) _check_is_int(axis)
_check_is_int(start) _check_is_int(start)
@ -826,15 +827,15 @@ def rollaxis(x, axis, start=0):
def swapaxes(x, axis1, axis2): def swapaxes(x, axis1, axis2):
""" """
Interchange two axes of a tensor. Interchanges two axes of a tensor.
Args: Args:
x (Tensor): A Tensor to be transposed. x (Tensor): A tensor to be transposed.
axis1 (int): First axis. axis1 (int): First axis.
axis2 (int): Second axis. axis2 (int): Second axis.
Returns: Returns:
Transposed Tensor. Has the same data type as the original tensor x. Transposed tensor, has the same data type as the original tensor x.
Raises: Raises:
TypeError: If axis1 or axis2 is not integer. TypeError: If axis1 or axis2 is not integer.
@ -844,12 +845,9 @@ def swapaxes(x, axis1, axis2):
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore.numpy as np
>>> import mindspore.numpy as mnp >>> x = np.ones((2,3,4))
>>> from mindspore import Tensor >>> output = np.swapaxes(x, 0, 2)
>>> import numpy as onp
>>> input_x = Tensor(onp.ones((2,3,4)), mindspore.float32)
>>> output = mnp.swapaxes(x, 0, 2)
>>> print(output.shape) >>> print(output.shape)
(4,3,2) (4,3,2)
""" """
@ -881,10 +879,10 @@ def swapaxes(x, axis1, axis2):
def reshape(x, new_shape): def reshape(x, new_shape):
""" """
Reshape a tensor without changing its data. Reshapes a tensor without changing its data.
Args: Args:
x (Tensor): A Tensor to be reshaped. x (Tensor): A tensor to be reshaped.
new_shape (Union[int, list(int), tuple(int)]): The new shape should be new_shape (Union[int, list(int), tuple(int)]): The new shape should be
compatible with the original shape. If the tuple has only one element, compatible with the original shape. If the tuple has only one element,
the result will be a 1-D tensor of that length. One shape dimension the result will be a 1-D tensor of that length. One shape dimension
@ -902,19 +900,19 @@ def reshape(x, new_shape):
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples: Examples:
>>> x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) >>> import mindspore.numpy as np
>>> reshape = mindspore.numpy.reshape() >>> x = np.asarray([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> output = reshape(x, (3, 2)) >>> output = np.reshape(x, (3, 2))
>>> print(output) >>> print(output)
[[-0.1 0.3] [[-0.1 0.3]
[ 3.6 0.4] [ 3.6 0.4]
[ 0.5 -3.2]] [ 0.5 -3.2]]
>>> output = reshape(x, (3, -1)) >>> output = np.reshape(x, (3, -1))
>>> print(output) >>> print(output)
[[-0.1 0.3] [[-0.1 0.3]
[ 3.6 0.4] [ 3.6 0.4]
[ 0.5 -3.2]] [ 0.5 -3.2]]
>>> output = reshape(x, (6, )) >>> output = np.reshape(x, (6, ))
>>> print(output) >>> print(output)
[-0.1 0.3 3.6 0.4 0.5 -3.2] [-0.1 0.3 3.6 0.4 0.5 -3.2]
""" """
@ -924,7 +922,7 @@ def reshape(x, new_shape):
def ravel(x): def ravel(x):
""" """
Return a contiguous flattened tensor. Returns a contiguous flattened tensor.
A 1-D tensor, containing the elements of the input, is returned. A 1-D tensor, containing the elements of the input, is returned.
@ -932,18 +930,15 @@ def ravel(x):
x (Tensor): A tensor to be flattened. x (Tensor): A tensor to be flattened.
Returns: Returns:
Flattened Tensor. Has the same data type as the original tensor x. Flattened tensor, has the same data type as the original tensor x.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore.numpy as np
>>> import mindspore.numpy as mnp >>> x = np.ones((2,3,4))
>>> from mindspore import Tensor >>> output = np.ravel(x)
>>> import numpy as onp
>>> input_x = Tensor(onp.ones((2,3,4)), mindspore.float32)
>>> output = mnp.ravel(x)
>>> print(output.shape) >>> print(output.shape)
(24,) (24,)
""" """
@ -953,8 +948,8 @@ def ravel(x):
@constexpr @constexpr
def _move_axes_for_concatenate(arr_shape, axis): def _move_axes_for_concatenate(arr_shape, axis):
""" """
move axis 0 to the disiganated position, while keep other axes' relative Moves axis 0 to the disiganated position, while keeps other axes' relative
positions unchanged, only used if a single array is concatenated. positions unchanged, only used if a single tensor is concatenated.
""" """
original_axes = tuple(range(len(arr_shape))) original_axes = tuple(range(len(arr_shape)))
@ -966,17 +961,17 @@ def _move_axes_for_concatenate(arr_shape, axis):
def concatenate(arrays, axis=0): def concatenate(arrays, axis=0):
""" """
Join a sequence of arrays along an existing axis. Joins a sequence of tensors along an existing axis.
Args: Args:
arrays: Union[Tensor, tuple(Tensor), list(Tensor)], a Tensor or a list arrays: Union[Tensor, tuple(Tensor), list(Tensor)], a tensor or a list
of Tensor to be concatenated. of tensors to be concatenated.
axis (int, optional): The axis along which the arrays will be joined, axis (int, optional): The axis along which the tensors will be joined,
if axis is None, arrays are flattened before use. Default is 0. if axis is None, tensors are flattened before use. Default is 0.
Returns: Returns:
Tensor, a Tensor concatenated from a Tensor or a list of Tensors. Tensor, a tensor concatenated from a tensor or a list of tensors.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
@ -986,8 +981,8 @@ def concatenate(arrays, axis=0):
>>> x1 = np.ones((1,2,3)) >>> x1 = np.ones((1,2,3))
>>> x2 = np.ones((1,2,1)) >>> x2 = np.ones((1,2,1))
>>> x = np.concatenate((x1, x2), axis=-1) >>> x = np.concatenate((x1, x2), axis=-1)
>>> print(x,shape) >>> print(x.shape)
(1,2,4) (1, 2, 4)
""" """
array_type = F.typeof(arrays) array_type = F.typeof(arrays)
if _check_is_tensor(array_type): if _check_is_tensor(array_type):

View File

@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
"""Dtypes and utilities""" """Dtypes and utilities"""
from mindspore import (int8, int16, int32, int64, uint8, uint16, uint32, uint64, \ from ..common.dtype import (int8, int16, int32, int64, uint8, uint16, uint32, uint64, \
float16, float32, float64, bool_) float16, float32, float64, bool_)
# original numpy has int->int64, float->float64, uint->uint64 mapping. we map # original numpy has int->int64, float->float64, uint->uint64 mapping. we map

View File

@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""math operations, the function docs are adapted from Numpy API.""" """math operations, the function docs are adapted from Numpy API."""
from mindspore.ops import operations as P from ..ops import operations as P
from mindspore.ops import functional as F from ..ops import functional as F
from .array_ops import squeeze from .array_ops import squeeze
from .utils import _infer_out_shape, _is_scalar, _check_axis_valid, _get_device_compile, \ from .utils import _infer_out_shape, _is_scalar, _check_axis_valid, _get_device_compile, \
_check_shape_aligned _check_shape_aligned
@ -22,7 +22,7 @@ from .utils import _infer_out_shape, _is_scalar, _check_axis_valid, _get_device_
def mean(a, axis=None, keepdims=False): def mean(a, axis=None, keepdims=False):
""" """
Compute the arithmetic mean along the specified axis. Computes the arithmetic mean along the specified axis.
Returns the average of the array elements. The average is taken Returns the average of the array elements. The average is taken
over the flattened array by default, otherwise over the specified over the flattened array by default, otherwise over the specified
@ -30,8 +30,8 @@ def mean(a, axis=None, keepdims=False):
Note: Note:
Numpy arguments dtype and out are not supported. Numpy arguments dtype and out are not supported.
On GPU, the supported dtypes are np.float16, and np.float32. On GPU, the supported dtypes are mstype.float16, and mstype.float32.
On CPU, the supported dtypes are np.float16, and np.float32. On CPU, the supported dtypes are mstype.float16, and mstype.float32.
Args: Args:
a (Tensor): input tensor containing numbers whose mean is desired. a (Tensor): input tensor containing numbers whose mean is desired.
@ -56,6 +56,7 @@ def mean(a, axis=None, keepdims=False):
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore.numpy as np
>>> a = np.arange(6, dtype='float32') >>> a = np.arange(6, dtype='float32')
>>> output = np.mean(a, 0) >>> output = np.mean(a, 0)
>>> print(output) >>> print(output)
@ -83,8 +84,8 @@ def inner(a, b):
Note: Note:
Numpy argument out is not supported. Numpy argument out is not supported.
On GPU, the supported dtypes are np.float16, and np.float32. On GPU, the supported dtypes are mstype.float16, and mstype.float32.
On CPU, the supported dtype is np.float32. On CPU, the supported dtype is mstype.float32.
Args: Args:
a (Tensor): input tensor. If a and b are nonscalar, their last a (Tensor): input tensor. If a and b are nonscalar, their last
@ -103,6 +104,7 @@ def inner(a, b):
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore.numpy as np
>>> a = np.ones((5, 3)) >>> a = np.ones((5, 3))
>>> b = np.ones((2, 7, 3)) >>> b = np.ones((2, 7, 3))
>>> output = np.inner(a, b) >>> output = np.inner(a, b)

View File

@ -17,13 +17,12 @@ from functools import partial
import numpy as onp import numpy as onp
import mindspore
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from ..common import Tensor
from mindspore.ops import operations as P from ..ops import operations as P
from mindspore.ops import functional as F from ..ops import functional as F
from mindspore.ops.primitive import constexpr from ..ops.primitive import constexpr
from mindspore.common import dtype as mstype from ..common import dtype as mstype
from .dtypes import dtype_tuple, all_types, dtype_map from .dtypes import dtype_tuple, all_types, dtype_map
@ -119,17 +118,17 @@ def _check_shape(shape):
def _check_dtype(dtype): def _check_dtype(dtype):
"""check the input dtype and make conversions""" """check the input dtype and make conversions"""
# convert the string dtype to mindspore.dtype # convert the string dtype to mstype.dtype
if isinstance(dtype, str): if isinstance(dtype, str):
dtype = dtype.lower() dtype = dtype.lower()
dtype = dtype_map[dtype] dtype = dtype_map[dtype]
elif isinstance(dtype, type): elif isinstance(dtype, type):
if dtype is int: if dtype is int:
dtype = mindspore.int32 dtype = mstype.int32
if dtype is float: if dtype is float:
dtype = mindspore.float32 dtype = mstype.float32
if dtype is bool: if dtype is bool:
dtype = mindspore.bool_ dtype = mstype.bool_
if dtype not in dtype_tuple: if dtype not in dtype_tuple:
raise TypeError( raise TypeError(
f"only {all_types} are allowed for dtype, but got {type(dtype)}") f"only {all_types} are allowed for dtype, but got {type(dtype)}")

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""unit tests for array operations""" """unit tests for numpy array operations"""
import functools import functools

View File

@ -1,8 +1,23 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""unit tests for numpy math operations"""
import pytest import pytest
import numpy as onp import numpy as onp
import mindspore.context as context
import mindspore.numpy as mnp import mindspore.numpy as mnp
@ -16,7 +31,6 @@ def rand_int(*shape):
class Cases(): class Cases():
def __init__(self): def __init__(self):
self.device_cpu = context.get_context('device_target') == 'CPU'
self.arrs = [ self.arrs = [
rand_int(2), rand_int(2),