api: full, chunk

This commit is contained in:
yuzhenhua 2022-12-21 14:21:58 +08:00
parent 5db6dbb8d4
commit 1700634a6d
16 changed files with 344 additions and 2 deletions

View File

@ -412,6 +412,7 @@ Array操作
mindspore.ops.bincount
mindspore.ops.broadcast_to
mindspore.ops.cat
mindspore.ops.chunk
mindspore.ops.col2im
mindspore.ops.concat
mindspore.ops.count_nonzero
@ -426,6 +427,7 @@ Array操作
mindspore.ops.fliplr
mindspore.ops.flipud
mindspore.ops.fold
mindspore.ops.full
mindspore.ops.gather
mindspore.ops.gather_d
mindspore.ops.gather_elements

View File

@ -0,0 +1,6 @@
mindspore.Tensor.chunk
======================
.. py:method:: mindspore.Tensor.chunk(chunks, axis=0)
详情请参考 :func:`mindspore.ops.chunk`

View File

@ -72,6 +72,7 @@ mindspore.Tensor
mindspore.Tensor.cholesky
mindspore.Tensor.cholesky_inverse
mindspore.Tensor.choose
mindspore.Tensor.chunk
mindspore.Tensor.clamp
mindspore.Tensor.clip
mindspore.Tensor.col2im

View File

@ -0,0 +1,21 @@
mindspore.ops.chunk
====================
.. py:function:: mindspore.ops.chunk(x, chunks, axis=0)
根据指定的轴将输入Tensor切分成块。
参数:
- **x** (Tensor) - Tensor的shape为 :math:`(x_1, x_2, ..., x_R)`
- **chunks** (int]) - 要返回的块数。
- **axis** (int) - 指定分割轴。默认值0。
返回:
tuple[Tensor]。
异常:
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `axis` 不是int类型。
- **ValueError** - 参数 `axis` 超出 :math:`(-x.dim, x.dim)` 范围。
- **TypeError** - `chunks` 不是int。
- **ValueError** - 参数 `chunks` 不是正数。

View File

@ -0,0 +1,17 @@
mindspore.ops.full
==================
.. py:function:: mindspore.ops.full(size, fill_value, *, dtype=None)
创建一个指定shape的Tensor并用指定值填充。
参数:
- **size** (Union(tuple[int], list[int])) - 指定输出Tensor的shape。
- **fill_value** (number.Number) - 用来填充输出Tensor的值。
- **dtype** (mindspore.dtype) - 指定输出Tensor的数据类型。数据类型只支持 `bool_ <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_`number <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_
返回:
Tensor。
异常:
- **TypeError** - `size` 不是元组。
- **TypeError** - `size` 中包含小于0的成员。

View File

@ -78,6 +78,7 @@
mindspore.Tensor.cholesky
mindspore.Tensor.cholesky_inverse
mindspore.Tensor.choose
mindspore.Tensor.chunk
mindspore.Tensor.clamp
mindspore.Tensor.clip
mindspore.Tensor.col2im

View File

@ -412,6 +412,7 @@ Array Operation
mindspore.ops.bincount
mindspore.ops.broadcast_to
mindspore.ops.cat
mindspore.ops.chunk
mindspore.ops.col2im
mindspore.ops.concat
mindspore.ops.count_nonzero
@ -426,6 +427,7 @@ Array Operation
mindspore.ops.fliplr
mindspore.ops.flipud
mindspore.ops.fold
mindspore.ops.full
mindspore.ops.gather
mindspore.ops.gather_d
mindspore.ops.gather_elements

View File

@ -173,6 +173,7 @@ BuiltInTypeMap &GetMethodMap() {
{"angle", std::string("angle")}, // C.reduce_any
{"any", std::string("any_")}, // C.reduce_any
{"bincount", std::string("bincount")}, // bincount
{"chunk", std::string("chunk")}, // chunk
{"slogdet", std::string("slogdet")}, // slogdet
{"tril", std::string("tril")}, // tril
{"__add__", std::string("add")}, // C.add

View File

@ -299,6 +299,13 @@ def slogdet(x):
return F.slogdet(x)
def chunk(x, chunks, axis=0):
r"""
For details, please refer to :func:`mindspore.ops.chunk`.
"""
return F.chunk(x, chunks, axis)
def tril(x, diagonal=0):
r"""
For details, please refer to :func:`mindspore.ops.tril`.

View File

@ -748,6 +748,13 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('bincount')(self, weights, minlength)
def chunk(self, chunks, axis=0):
r"""
For details, please refer to :func:`mindspore.ops.chunk`.
"""
self._init_check()
return tensor_operator_registry.get('chunk')(self, chunks, axis)
def item(self, index=None):
"""
Get the item at the specified index of the tensor.

View File

@ -35,6 +35,8 @@ from .array_func import (
padding,
fill,
fill_,
full,
chunk,
tile,
size,
ones,

View File

@ -17,7 +17,6 @@
from __future__ import absolute_import
import builtins
import numpy as np
import mindspore.common.dtype as mstype
@ -665,6 +664,113 @@ def fills(x, value):
return fills_(x, value_)
def full(size, fill_value, *, dtype=None): # pylint: disable=redefined-outer-name
"""
Create a Tensor of the specified shape and fill it with the specified value.
Args:
size (Union(tuple[int], list[int])): The specified shape of output tensor.
fill_value (number.Number): Value to fill the returned tensor.
dtype (mindspore.dtype): The specified type of output tensor. The data type only supports
`bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ and
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
Returns:
Tensor.
Raises:
TypeError: If `size` is not a tuple or list.
TypeError: The element in `size` is less than 0.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = ops.full((2, 2), 1)
>>> print(output)
[[1. 1.]
[1. 1.]]
>>> output = ops.full((3, 3), 0)
>>> print(output)
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
"""
if not isinstance(size, (list, tuple)):
raise TypeError(f"For 'ops.full', 'size' must be a tuple or list of ints, but got {type(size)}.")
if dtype is None:
dtype = mstype.int64
if dtype not in mstype.all_types:
raise TypeError(f"For 'ops.full', 'dtype' must be mindspore.type, but got {dtype}.")
if isinstance(size, list):
size = tuple(size)
return fill_(dtype, size, fill_value)
def chunk(x, chunks, axis=0):
"""
Splits the Tensor into chunks along the given axis.
Note:
This function may return less then the specified number of chunks!
Args:
x (Tensor): A Tensor to be divided.
chunks (int): Number of chunks to return.
axis (int): The axis along which to split. Default: 0.
Returns:
A tuple of sub-tensors.
Raises:
TypeError: If argument `x` is not Tensor.
TypeError: The sum of `chunks` is not int.
TypeError: If argument `axis` is not int.
ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)` .
ValueError: If argument `chunks` is not positive number.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = np.arange(9).astype("float32")
>>> output = ops.chunk(Tensor(input_x), 3)
>>> print(output)
(Tensor(shape=[3], dtype=Float32, value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
Tensor(shape=[3], dtype=Float32, value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
"""
if not isinstance(x, Tensor):
raise TypeError(f'For ops.chunk parameter `x` must be Tensor, but got {type(x)}')
_ = validator.check_axis_type(axis, True, False, False)
axis = _canonicalize_axis(axis, x.ndim)
if not isinstance(chunks, int):
raise TypeError(f"For ops.chunk type of argument `chunks` should be integer, but got {type(chunks)}")
if chunks <= 0:
raise ValueError(f"For ops.chunk parameter 'chunks' must be greater than 0, but got {chunks}")
arr_shape = x.shape
length_along_dim = arr_shape[axis]
if chunks > length_along_dim:
res = P.Split(axis, length_along_dim)(x)
elif length_along_dim % chunks == 0:
res = P.Split(axis, chunks)(x)
else:
block_size = int(np.ceil(length_along_dim / chunks))
true_chunks = int(length_along_dim // block_size)
length1 = true_chunks * block_size
length2 = length_along_dim - length1
start1 = _list_comprehensions(rank(x), 0, True)
size1 = _tuple_setitem(arr_shape, axis, length1)
start2 = _tuple_setitem(start1, axis, length1)
size2 = _tuple_setitem(arr_shape, axis, length2)
res = P.Split(axis, true_chunks)(tensor_slice(x, start1, size1)) + \
P.Split(axis, 1)(tensor_slice(x, start2, size2))
return res
def ones(shape, dtype=None): # pylint: disable=redefined-outer-name
r"""
Creates a tensor filled with value ones.
@ -4688,7 +4794,7 @@ def split(x, split_size_or_sections, axis=0):
return res
def tril(input_x, diagonal):
def tril(input_x, diagonal=0): # pylint: disable=redefined-outer-name
"""
Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input,
the other elements of the result tensor out are set to 0.
@ -5796,6 +5902,8 @@ __all__ = [
'reverse',
'reverse_sequence',
'hamming_window',
'chunk',
'full',
'dyn_shape',
'rank',
'range',

View File

@ -135,6 +135,7 @@ tensor_operator_registry.register('rsqrt', rsqrt)
tensor_operator_registry.register('bincount', bincount)
tensor_operator_registry.register('slogdet', slogdet)
tensor_operator_registry.register('tril', tril)
tensor_operator_registry.register('chunk', chunk)
tensor_operator_registry.register('sqrt', sqrt)
tensor_operator_registry.register('square', square)
tensor_operator_registry.register('sub', sub)

View File

@ -0,0 +1,58 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x, chunks, axis):
output = ops.chunk(x, chunks, axis)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_chunk_normal(mode):
"""
Feature: ops.chunk
Description: Verify the result of chunk
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = Tensor([[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]])
chunks = 6
axis = 1
out = net(x, chunks, axis)
expect_out_1 = np.array([[[[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]]]])
expect_out_2 = np.array([[[[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]]]])
assert np.allclose(out[0].asnumpy(), expect_out_1)
assert np.allclose(out[1].asnumpy(), expect_out_2)

View File

@ -0,0 +1,51 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, size, fill_value):
output = ops.full(size, fill_value)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_full_normal(mode):
"""
Feature: ops.full
Description: Verify the result of ops.full
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
size = (1, 2, 3)
fill_value = 11
out = net(size, fill_value)
expect_out = np.array([[[11, 11, 11],
[11, 11, 11]]])
assert np.allclose(out.asnumpy(), expect_out)

View File

@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x, chunks, axis):
output = x.chunk(chunks, axis)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_chunk_normal(mode):
"""
Feature: ops.chunk
Description: Verify the result of chunk
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
x = Tensor([[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]])
chunks = 6
axis = 1
out = net(x, chunks, axis)
expect_out_1 = np.array([[[[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]]]])
expect_out_2 = np.array([[[[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]]]])
assert np.allclose(out[0].asnumpy(), expect_out_1)
assert np.allclose(out[1].asnumpy(), expect_out_2)