!13977 fix numpy docstring errors
From: @yanglf1121 Reviewed-by: @liangchenghui,@c_34 Signed-off-by: @liangchenghui,@c_34
This commit is contained in:
commit
c008eb0aa1
113
RELEASE.md
113
RELEASE.md
|
@ -106,6 +106,119 @@ def construct(self, *inputs):
|
|||
</tr>
|
||||
</table>
|
||||
|
||||
###### `mindspore.numpy.array()`, `mindspore.numpy.asarray()`, `mindspore.numpy.asfarray()`, `mindspore.numpy.copy()` now support GRAPH mode, but cannot accept `numpy.ndarray` as input arguments anymore([!12726](https://gitee.com/mindspore/mindspore/pulls/12726))
|
||||
|
||||
Previously, these interfaces can accept numpy.ndarray as arguments and convert numpy.ndarray to Tensor, but cannot be used in GRAPH mode.
|
||||
However, currently MindSpore Parser cannot parse numpy.ndarray in JIT-graph. To support these interfaces in graph mode, we have to remove `numpy.ndarray` support. With that being said, users can still use `Tensor` to convert `numpy.ndarray` to tensors.
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td style="text-align:center"> 1.1.1 </td> <td style="text-align:center"> 1.2.0-rc1 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
```python
|
||||
>>> import mindspore.numpy as mnp
|
||||
>>> import numpy
|
||||
>>>
|
||||
>>> nd_array = numpy.array([1,2,3])
|
||||
>>> tensor = mnp.asarray(nd_array) # this line cannot be parsed in GRAPH mode
|
||||
```
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
```python
|
||||
>>> import mindspore.numpy as mnp
|
||||
>>> import numpy
|
||||
>>>
|
||||
>>> tensor = mnp.asarray([1,2,3]) # this line can be parsed in GRAPH mode
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
###### mindspore.numpy interfaces remove support for keyword arguments `out` and `where`([!12726](https://gitee.com/mindspore/mindspore/pulls/12726))
|
||||
|
||||
Previously, we have incomplete support for keyword arguments `out` and `where` in mindspore.numpy interfaces, however, the `out` argument is only functional when `where` argument is also provided, and `out` cannot be used to pass reference to numpy functions. Therefore, we have removed these two arguments to avoid any confusion users may have. Their original functionality can be found in [np.where](https://www.mindspore.cn/doc/api_python/zh-CN/master/mindspore/numpy/mindspore.numpy.where.html#mindspore.numpy.where)
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td style="text-align:center"> 1.1.1 </td> <td style="text-align:center"> 1.2.0-rc1 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
```python
|
||||
>>> import mindspore.numpy as np
|
||||
>>>
|
||||
>>> a = np.ones((3,3))
|
||||
>>> b = np.ones((3,3))
|
||||
>>> out = np.zeros((3,3))
|
||||
>>> where = np.asarray([[True, False, True],[False, False, True],[True, True, True]])
|
||||
>>> res = np.add(a, b, out=out, where=where) # `out` cannot be used as a reference, therefore it is misleading
|
||||
```
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
```python
|
||||
>>> import mindspore.numpy as np
|
||||
>>>
|
||||
>>> a = np.ones((3,3))
|
||||
>>> b = np.ones((3,3))
|
||||
>>> out = np.zeros((3,3))
|
||||
>>> where = np.asarray([[True, False, True],[False, False, True],[True, True, True]])
|
||||
>>> res = np.add(a, b)
|
||||
>>> out = np.where(where, x=res, y=out) # instead of np.add(a, b, out=out, where=where)
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
#### Deprecations
|
||||
|
||||
##### Python API
|
||||
|
||||
###### `nn.MatMul` is now deprecated in favor of `ops.matmul` ([!12817](https://gitee.com/mindspore/mindspore/pulls/12817))
|
||||
|
||||
[ops.matmul](https://www.mindspore.cn/doc/api_python/zh-CN/master/mindspore/ops/mindspore.ops.matmul.html#mindspore.ops.matmul) follows the API of [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html) as closely as possible. As a function interface, [ops.matmul](https://www.mindspore.cn/doc/api_python/zh-CN/master/mindspore/ops/mindspore.ops.matmul.html#mindspore.ops.matmul) is applied without instantiation, as opposed to `nn.MatMul`, which should only be used as a class instance.
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td style="text-align:center"> 1.1.1 </td> <td style="text-align:center"> 1.2.0-rc1 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
```python
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor, nn
|
||||
>>>
|
||||
>>> x = Tensor(np.ones((2, 3)).astype(onp.float32)
|
||||
>>> y = Tensor(np.ones((3, 4)).astype(onp.float32)
|
||||
>>> nn.MatMul()(x, y)
|
||||
```
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
```python
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>>
|
||||
>>> x = Tensor(np.ones((2, 3)).astype(onp.float32)
|
||||
>>> y = Tensor(np.ones((3, 4)).astype(onp.float32)
|
||||
>>> ops.matmul(x, y)
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
# MindSpore 1.1.1 Release Notes
|
||||
|
||||
## MindSpore
|
||||
|
|
|
@ -1293,14 +1293,14 @@ def meshgrid(*xi, sparse=False, indexing='xy'):
|
|||
>>> y = np.linspace(0, 1, 2)
|
||||
>>> xv, yv = np.meshgrid(x, y)
|
||||
>>> print(xv)
|
||||
[[0. , 0.5, 1. ],
|
||||
[0. , 0.5, 1. ]]
|
||||
[[0. 0.5 1. ]
|
||||
[0. 0.5 1. ]]
|
||||
>>> print(yv)
|
||||
[[0., 0., 0.],
|
||||
[1., 1., 1.]]
|
||||
[[0. 0. 0.],
|
||||
[1. 1. 1.]]
|
||||
>>> xv, yv = np.meshgrid(x, y, sparse=True)
|
||||
>>> print(xv)
|
||||
[[0. , 0.5, 1. ]]
|
||||
[[0. 0.5 1. ]]
|
||||
>>> print(yv)
|
||||
[[0.],
|
||||
[1.]
|
||||
|
@ -1426,19 +1426,19 @@ class mGridClass(nd_grid):
|
|||
>>> from mindspore.numpy import mgrid
|
||||
>>> output = mgrid[0:5, 0:5]
|
||||
>>> print(output)
|
||||
[[[0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
[3, 3, 3, 3, 3],
|
||||
[4, 4, 4, 4, 4]],
|
||||
[[0, 1, 2, 3, 4],
|
||||
[0, 1, 2, 3, 4],
|
||||
[0, 1, 2, 3, 4],
|
||||
[0, 1, 2, 3, 4],
|
||||
[0, 1, 2, 3, 4]]]
|
||||
[[[0 0 0 0 0]
|
||||
[1 1 1 1 1]
|
||||
[2 2 2 2 2]
|
||||
[3 3 3 3 3]
|
||||
[4 4 4 4 4]]
|
||||
[[0 1 2 3 4]
|
||||
[0 1 2 3 4]
|
||||
[0 1 2 3 4]
|
||||
[0 1 2 3 4]
|
||||
[0 1 2 3 4]]]
|
||||
>>> output = mgrid[-1:1:5j]
|
||||
>>> print(output)
|
||||
[-1. , -0.5, 0. , 0.5, 1. ]
|
||||
[-1. -0.5 0. 0.5 1. ]
|
||||
"""
|
||||
def __init__(self):
|
||||
super(mGridClass, self).__init__(sparse=False)
|
||||
|
@ -1473,13 +1473,13 @@ class oGridClass(nd_grid):
|
|||
[Tensor(shape=[5, 1], dtype=Int32, value=
|
||||
[[0],
|
||||
[1],
|
||||
[2],
|
||||
[2]
|
||||
[3],
|
||||
[4]]), Tensor(shape=[1, 5], dtype=Int32, value=
|
||||
[[0, 1, 2, 3, 4]])]
|
||||
>>> output = ogrid[-1:1:5j]
|
||||
>>> print(output)
|
||||
[-1. , -0.5, 0. , 0.5, 1. ]
|
||||
[-1. -0.5 0. 0.5 1. ]
|
||||
"""
|
||||
def __init__(self):
|
||||
super(oGridClass, self).__init__(sparse=True)
|
||||
|
@ -1684,10 +1684,10 @@ def ix_(*args):
|
|||
>>> import mindspore.numpy as np
|
||||
>>> ixgrid = np.ix_(np.array([0, 1]), np.array([2, 4]))
|
||||
>>> print(ixgrid)
|
||||
[Tensor(shape=[2, 1], dtype=Int32, value=
|
||||
(Tensor(shape=[2, 1], dtype=Int32, value=
|
||||
[[0],
|
||||
[1]]), Tensor(shape=[1, 2], dtype=Int32, value=
|
||||
[[2, 4]])]
|
||||
[[2, 4]]))
|
||||
"""
|
||||
# TODO boolean mask
|
||||
_check_input_tensor(*args)
|
||||
|
@ -1784,8 +1784,9 @@ def indices(dimensions, dtype=mstype.int32, sparse=False):
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as np
|
||||
>>> grid = np.indices((2, 3))
|
||||
>>> print(indices)
|
||||
>>> print(grid)
|
||||
[Tensor(shape=[2, 3], dtype=Int32, value=
|
||||
[[0, 0, 0],
|
||||
[1, 1, 1]]), Tensor(shape=[2, 3], dtype=Int32, value=
|
||||
|
|
|
@ -492,16 +492,14 @@ def column_stack(tup):
|
|||
ValueError: If `tup` is empty.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as mnp
|
||||
>>> import numpy as onp
|
||||
>>> from mindspore import Tensor
|
||||
>>> x1 = Tensor(onp.array([1, 2, 3]).astype('int32'))
|
||||
>>> x2 = Tensor(onp.array([4, 5, 6]).astype('int32'))
|
||||
>>> output = mnp.column_stack((x1, x2))
|
||||
>>> import mindspore.numpy as np
|
||||
>>> x1 = np.array([1, 2, 3]).astype('int32')
|
||||
>>> x2 = np.array([4, 5, 6]).astype('int32')
|
||||
>>> output = np.column_stack((x1, x2))
|
||||
>>> print(output)
|
||||
[[1, 4],
|
||||
[2, 5],
|
||||
[3, 6]]
|
||||
[[1 4]
|
||||
[2 5]
|
||||
[3 6]]
|
||||
"""
|
||||
if isinstance(tup, Tensor):
|
||||
return tup
|
||||
|
@ -541,15 +539,13 @@ def vstack(tup):
|
|||
ValueError: If `tup` is empty.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.numpy as mnp
|
||||
>>> import numpy as onp
|
||||
>>> from mindspore import Tensor
|
||||
>>> x1 = Tensor(onp.array([1, 2, 3]).astype('int32'))
|
||||
>>> x2 = Tensor(onp.array([4, 5, 6]).astype('int32'))
|
||||
>>> output = mnp.vstack((x1, x2))
|
||||
>>> import mindspore.numpy as np
|
||||
>>> x1 = np.array([1, 2, 3]).astype('int32')
|
||||
>>> x2 = np.array([4, 5, 6]).astype('int32')
|
||||
>>> output = np.vstack((x1, x2))
|
||||
>>> print(output)
|
||||
[[1, 2, 3],
|
||||
[4, 5, 6]]
|
||||
[[1 2 3]
|
||||
[4 5 6]]
|
||||
"""
|
||||
if isinstance(tup, Tensor):
|
||||
return tup
|
||||
|
@ -690,12 +686,13 @@ def where(condition, x=None, y=None):
|
|||
>>> y = np.full((2, 1, 1), 7)
|
||||
>>> output = np.where(condition, x, y)
|
||||
>>> print(output)
|
||||
[[[7, 5],
|
||||
[7, 5],
|
||||
[7, 5]],
|
||||
[[7, 5],
|
||||
[7, 5],
|
||||
[7, 5]]]
|
||||
[[[7 5]
|
||||
[7 5]
|
||||
[7 5]]
|
||||
|
||||
[[7 5]
|
||||
[7 5]
|
||||
[7 5]]]
|
||||
"""
|
||||
# type promotes input tensors
|
||||
dtype1 = F.dtype(x)
|
||||
|
@ -978,7 +975,7 @@ def unique(x, return_inverse=False):
|
|||
>>> input_x = np.asarray([1, 2, 2, 2, 3, 4, 5]).astype('int32')
|
||||
>>> output_x = np.unique(input_x)
|
||||
>>> print(output_x)
|
||||
[1, 2, 3, 4, 5]
|
||||
[1 2 3 4 5]
|
||||
>>> output_x = np.unique(input_x, return_inverse=True)
|
||||
>>> print(output_x)
|
||||
(Tensor(shape=[5], dtype=Int32, value= [ 1, 2, 3, 4, 5]), Tensor(shape=[7], dtype=Int32,
|
||||
|
@ -1055,7 +1052,7 @@ def roll(a, shift, axis=None):
|
|||
Tensor, with the same shape as a.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
|
|
|
@ -57,8 +57,8 @@ def not_equal(x1, x2, dtype=None):
|
|||
>>> a = np.asarray([1, 2])
|
||||
>>> b = np.asarray([[1, 3],[1, 4]])
|
||||
>>> print(np.not_equal(a, b))
|
||||
>>> [[False True]
|
||||
[False True]]
|
||||
[[False True]
|
||||
[False True]]
|
||||
"""
|
||||
_check_input_tensor(x1, x2)
|
||||
return _apply_tensor_op(F.not_equal, x1, x2, dtype=dtype)
|
||||
|
@ -253,9 +253,6 @@ def isfinite(x, dtype=None):
|
|||
>>> output = np.isfinite(np.array([np.inf, 1., np.nan]).astype('float32'))
|
||||
>>> print(output)
|
||||
[False True False]
|
||||
>>> output = np.isfinite(np.log(np.array(-1.).astype('float32')))
|
||||
>>> print(output)
|
||||
False
|
||||
"""
|
||||
return _apply_tensor_op(F.isfinite, x, dtype=dtype)
|
||||
|
||||
|
|
|
@ -258,9 +258,9 @@ def add(x1, x2, dtype=None):
|
|||
>>> x2 = np.full((3, 2), [3, 4])
|
||||
>>> output = np.add(x1, x2)
|
||||
>>> print(output)
|
||||
[[4, 6],
|
||||
[4, 6],
|
||||
[4, 6]]
|
||||
[[4 6]
|
||||
[4 6]
|
||||
[4 6]]
|
||||
"""
|
||||
# broadcast is not fully supported in tensor_add on CPU,
|
||||
# so we use tensor_sub as a substitute solution
|
||||
|
@ -297,9 +297,9 @@ def subtract(x1, x2, dtype=None):
|
|||
>>> x2 = np.full((3, 2), [3, 4])
|
||||
>>> output = np.subtract(x1, x2)
|
||||
>>> print(output)
|
||||
[[-2, -2],
|
||||
[-2, -2],
|
||||
[-2, -2]]
|
||||
[[-2 -2]
|
||||
[-2 -2]
|
||||
[-2 -2]]
|
||||
"""
|
||||
return _apply_tensor_op(F.tensor_sub, x1, x2, dtype=dtype)
|
||||
|
||||
|
@ -331,9 +331,9 @@ def multiply(x1, x2, dtype=None):
|
|||
>>> x2 = np.full((3, 2), [3, 4])
|
||||
>>> output = np.multiply(x1, x2)
|
||||
>>> print(output)
|
||||
[[3, 8],
|
||||
[3, 8],
|
||||
[3, 8]]
|
||||
[[3 8]
|
||||
[3 8]
|
||||
[3 8]]
|
||||
"""
|
||||
if _get_device() == 'CPU':
|
||||
_check_input_tensor(x1, x2)
|
||||
|
@ -374,9 +374,9 @@ def divide(x1, x2, dtype=None):
|
|||
>>> x2 = np.full((3, 2), [3, 4])
|
||||
>>> output = np.divide(x1, x2)
|
||||
>>> print(output)
|
||||
[[0.33333333, 0.5],
|
||||
[0.33333333, 0.5],
|
||||
[0.33333333, 0.5]]
|
||||
[[0.33333334 0.5 ]
|
||||
[0.33333334 0.5 ]
|
||||
[0.33333334 0.5 ]]
|
||||
"""
|
||||
if not _check_is_float(F.dtype(x1)) and not _check_is_float(F.dtype(x2)):
|
||||
x1 = F.cast(x1, mstype.float32)
|
||||
|
@ -413,9 +413,9 @@ def true_divide(x1, x2, dtype=None):
|
|||
>>> x2 = np.full((3, 2), [3, 4])
|
||||
>>> output = np.true_divide(x1, x2)
|
||||
>>> print(output)
|
||||
[[0.33333333, 0.5],
|
||||
[0.33333333, 0.5],
|
||||
[0.33333333, 0.5]]
|
||||
[[0.33333334 0.5 ]
|
||||
[0.33333334 0.5 ]
|
||||
[0.33333334 0.5 ]]
|
||||
"""
|
||||
return divide(x1, x2, dtype=dtype)
|
||||
|
||||
|
@ -450,9 +450,9 @@ def power(x1, x2, dtype=None):
|
|||
>>> x2 = np.full((3, 2), [3, 4]).astype('float32')
|
||||
>>> output = np.power(x1, x2)
|
||||
>>> print(output)
|
||||
[[ 1, 16],
|
||||
[ 1, 16],
|
||||
[ 1, 16]]
|
||||
[[ 1 16]
|
||||
[ 1 16]
|
||||
[ 1 16]]
|
||||
"""
|
||||
return _apply_tensor_op(F.tensor_pow, x1, x2, dtype=dtype)
|
||||
|
||||
|
@ -708,8 +708,8 @@ def dot(a, b):
|
|||
>>> b = np.full((2, 3, 4), 5).astype('float32')
|
||||
>>> output = np.dot(a, b)
|
||||
>>> print(output)
|
||||
[[[105, 105, 105, 105],
|
||||
[105, 105, 105, 105]]]
|
||||
[[[105. 105. 105. 105.]
|
||||
[105. 105. 105. 105.]]]
|
||||
"""
|
||||
ndim_a, ndim_b = F.rank(a), F.rank(b)
|
||||
if ndim_a > 0 and ndim_b >= 2:
|
||||
|
@ -760,13 +760,13 @@ def outer(a, b):
|
|||
>>> b = np.full(4, 3).astype('float32')
|
||||
>>> output = np.outer(a, b)
|
||||
>>> print(output)
|
||||
[[6, 6, 6, 6],
|
||||
[6, 6, 6, 6],
|
||||
[6, 6, 6, 6],
|
||||
[6, 6, 6, 6],
|
||||
[6, 6, 6, 6],
|
||||
[6, 6, 6, 6],
|
||||
[6, 6, 6, 6]]
|
||||
[[6. 6. 6. 6.]
|
||||
[6. 6. 6. 6.]
|
||||
[6. 6. 6. 6.]
|
||||
[6. 6. 6. 6.]
|
||||
[6. 6. 6. 6.]
|
||||
[6. 6. 6. 6.]
|
||||
[6. 6. 6. 6.]]
|
||||
"""
|
||||
_check_input_tensor(a, b)
|
||||
if F.rank(a) != 1:
|
||||
|
@ -1478,7 +1478,7 @@ def amin(a, axis=None, keepdims=False, initial=None, where=True):
|
|||
[0. 1.]
|
||||
>>> output = np.amin(a, axis=1)
|
||||
>>> print(output)
|
||||
[0, 2]
|
||||
[0. 2.]
|
||||
>>> output = np.amin(a, where=np.array([False, True]), initial=10, axis=0)
|
||||
>>> print(output)
|
||||
[10. 1.]
|
||||
|
@ -3733,7 +3733,7 @@ def promote_types(type1, type2):
|
|||
>>> import mindspore.numpy as np
|
||||
>>> output = np.promote_types(np.float32, np.float64)
|
||||
>>> print(output)
|
||||
np.float64
|
||||
Float64
|
||||
"""
|
||||
type1 = _check_dtype(type1)
|
||||
type2 = _check_dtype(type2)
|
||||
|
|
Loading…
Reference in New Issue