forked from mindspore-Ecosystem/mindspore
api: full, chunk
This commit is contained in:
parent
5db6dbb8d4
commit
1700634a6d
|
@ -412,6 +412,7 @@ Array操作
|
|||
mindspore.ops.bincount
|
||||
mindspore.ops.broadcast_to
|
||||
mindspore.ops.cat
|
||||
mindspore.ops.chunk
|
||||
mindspore.ops.col2im
|
||||
mindspore.ops.concat
|
||||
mindspore.ops.count_nonzero
|
||||
|
@ -426,6 +427,7 @@ Array操作
|
|||
mindspore.ops.fliplr
|
||||
mindspore.ops.flipud
|
||||
mindspore.ops.fold
|
||||
mindspore.ops.full
|
||||
mindspore.ops.gather
|
||||
mindspore.ops.gather_d
|
||||
mindspore.ops.gather_elements
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.chunk
|
||||
======================
|
||||
|
||||
.. py:method:: mindspore.Tensor.chunk(chunks, axis=0)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.chunk`。
|
|
@ -72,6 +72,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.cholesky
|
||||
mindspore.Tensor.cholesky_inverse
|
||||
mindspore.Tensor.choose
|
||||
mindspore.Tensor.chunk
|
||||
mindspore.Tensor.clamp
|
||||
mindspore.Tensor.clip
|
||||
mindspore.Tensor.col2im
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.ops.chunk
|
||||
====================
|
||||
|
||||
.. py:function:: mindspore.ops.chunk(x, chunks, axis=0)
|
||||
|
||||
根据指定的轴将输入Tensor切分成块。
|
||||
|
||||
参数:
|
||||
- **x** (Tensor) - Tensor的shape为 :math:`(x_1, x_2, ..., x_R)` 。
|
||||
- **chunks** (int]) - 要返回的块数。
|
||||
- **axis** (int) - 指定分割轴。默认值:0。
|
||||
|
||||
返回:
|
||||
tuple[Tensor]。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `x` 不是Tensor。
|
||||
- **TypeError** - `axis` 不是int类型。
|
||||
- **ValueError** - 参数 `axis` 超出 :math:`(-x.dim, x.dim)` 范围。
|
||||
- **TypeError** - `chunks` 不是int。
|
||||
- **ValueError** - 参数 `chunks` 不是正数。
|
|
@ -0,0 +1,17 @@
|
|||
mindspore.ops.full
|
||||
==================
|
||||
|
||||
.. py:function:: mindspore.ops.full(size, fill_value, *, dtype=None)
|
||||
|
||||
创建一个指定shape的Tensor,并用指定值填充。
|
||||
|
||||
参数:
|
||||
- **size** (Union(tuple[int], list[int])) - 指定输出Tensor的shape。
|
||||
- **fill_value** (number.Number) - 用来填充输出Tensor的值。
|
||||
- **dtype** (mindspore.dtype) - 指定输出Tensor的数据类型。数据类型只支持 `bool_ <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_ 和 `number <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_ 。
|
||||
返回:
|
||||
Tensor。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `size` 不是元组。
|
||||
- **TypeError** - `size` 中包含小于0的成员。
|
|
@ -78,6 +78,7 @@
|
|||
mindspore.Tensor.cholesky
|
||||
mindspore.Tensor.cholesky_inverse
|
||||
mindspore.Tensor.choose
|
||||
mindspore.Tensor.chunk
|
||||
mindspore.Tensor.clamp
|
||||
mindspore.Tensor.clip
|
||||
mindspore.Tensor.col2im
|
||||
|
|
|
@ -412,6 +412,7 @@ Array Operation
|
|||
mindspore.ops.bincount
|
||||
mindspore.ops.broadcast_to
|
||||
mindspore.ops.cat
|
||||
mindspore.ops.chunk
|
||||
mindspore.ops.col2im
|
||||
mindspore.ops.concat
|
||||
mindspore.ops.count_nonzero
|
||||
|
@ -426,6 +427,7 @@ Array Operation
|
|||
mindspore.ops.fliplr
|
||||
mindspore.ops.flipud
|
||||
mindspore.ops.fold
|
||||
mindspore.ops.full
|
||||
mindspore.ops.gather
|
||||
mindspore.ops.gather_d
|
||||
mindspore.ops.gather_elements
|
||||
|
|
|
@ -173,6 +173,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"angle", std::string("angle")}, // C.reduce_any
|
||||
{"any", std::string("any_")}, // C.reduce_any
|
||||
{"bincount", std::string("bincount")}, // bincount
|
||||
{"chunk", std::string("chunk")}, // chunk
|
||||
{"slogdet", std::string("slogdet")}, // slogdet
|
||||
{"tril", std::string("tril")}, // tril
|
||||
{"__add__", std::string("add")}, // C.add
|
||||
|
|
|
@ -299,6 +299,13 @@ def slogdet(x):
|
|||
return F.slogdet(x)
|
||||
|
||||
|
||||
def chunk(x, chunks, axis=0):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.chunk`.
|
||||
"""
|
||||
return F.chunk(x, chunks, axis)
|
||||
|
||||
|
||||
def tril(x, diagonal=0):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.tril`.
|
||||
|
|
|
@ -748,6 +748,13 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('bincount')(self, weights, minlength)
|
||||
|
||||
def chunk(self, chunks, axis=0):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.chunk`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('chunk')(self, chunks, axis)
|
||||
|
||||
def item(self, index=None):
|
||||
"""
|
||||
Get the item at the specified index of the tensor.
|
||||
|
|
|
@ -35,6 +35,8 @@ from .array_func import (
|
|||
padding,
|
||||
fill,
|
||||
fill_,
|
||||
full,
|
||||
chunk,
|
||||
tile,
|
||||
size,
|
||||
ones,
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
import builtins
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -665,6 +664,113 @@ def fills(x, value):
|
|||
return fills_(x, value_)
|
||||
|
||||
|
||||
def full(size, fill_value, *, dtype=None): # pylint: disable=redefined-outer-name
|
||||
"""
|
||||
Create a Tensor of the specified shape and fill it with the specified value.
|
||||
|
||||
Args:
|
||||
size (Union(tuple[int], list[int])): The specified shape of output tensor.
|
||||
fill_value (number.Number): Value to fill the returned tensor.
|
||||
dtype (mindspore.dtype): The specified type of output tensor. The data type only supports
|
||||
`bool_ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ and
|
||||
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
|
||||
|
||||
Returns:
|
||||
Tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If `size` is not a tuple or list.
|
||||
TypeError: The element in `size` is less than 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> output = ops.full((2, 2), 1)
|
||||
>>> print(output)
|
||||
[[1. 1.]
|
||||
[1. 1.]]
|
||||
>>> output = ops.full((3, 3), 0)
|
||||
>>> print(output)
|
||||
[[0. 0. 0.]
|
||||
[0. 0. 0.]
|
||||
[0. 0. 0.]]
|
||||
"""
|
||||
if not isinstance(size, (list, tuple)):
|
||||
raise TypeError(f"For 'ops.full', 'size' must be a tuple or list of ints, but got {type(size)}.")
|
||||
if dtype is None:
|
||||
dtype = mstype.int64
|
||||
if dtype not in mstype.all_types:
|
||||
raise TypeError(f"For 'ops.full', 'dtype' must be mindspore.type, but got {dtype}.")
|
||||
if isinstance(size, list):
|
||||
size = tuple(size)
|
||||
return fill_(dtype, size, fill_value)
|
||||
|
||||
|
||||
def chunk(x, chunks, axis=0):
|
||||
"""
|
||||
Splits the Tensor into chunks along the given axis.
|
||||
|
||||
Note:
|
||||
This function may return less then the specified number of chunks!
|
||||
|
||||
Args:
|
||||
x (Tensor): A Tensor to be divided.
|
||||
chunks (int): Number of chunks to return.
|
||||
axis (int): The axis along which to split. Default: 0.
|
||||
|
||||
Returns:
|
||||
A tuple of sub-tensors.
|
||||
|
||||
Raises:
|
||||
TypeError: If argument `x` is not Tensor.
|
||||
TypeError: The sum of `chunks` is not int.
|
||||
TypeError: If argument `axis` is not int.
|
||||
ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)` .
|
||||
ValueError: If argument `chunks` is not positive number.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = np.arange(9).astype("float32")
|
||||
>>> output = ops.chunk(Tensor(input_x), 3)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[3], dtype=Float32, value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]),
|
||||
Tensor(shape=[3], dtype=Float32, value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]),
|
||||
Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]))
|
||||
"""
|
||||
if not isinstance(x, Tensor):
|
||||
raise TypeError(f'For ops.chunk parameter `x` must be Tensor, but got {type(x)}')
|
||||
_ = validator.check_axis_type(axis, True, False, False)
|
||||
axis = _canonicalize_axis(axis, x.ndim)
|
||||
|
||||
if not isinstance(chunks, int):
|
||||
raise TypeError(f"For ops.chunk type of argument `chunks` should be integer, but got {type(chunks)}")
|
||||
if chunks <= 0:
|
||||
raise ValueError(f"For ops.chunk parameter 'chunks' must be greater than 0, but got {chunks}")
|
||||
|
||||
arr_shape = x.shape
|
||||
length_along_dim = arr_shape[axis]
|
||||
|
||||
if chunks > length_along_dim:
|
||||
res = P.Split(axis, length_along_dim)(x)
|
||||
elif length_along_dim % chunks == 0:
|
||||
res = P.Split(axis, chunks)(x)
|
||||
else:
|
||||
block_size = int(np.ceil(length_along_dim / chunks))
|
||||
true_chunks = int(length_along_dim // block_size)
|
||||
length1 = true_chunks * block_size
|
||||
length2 = length_along_dim - length1
|
||||
start1 = _list_comprehensions(rank(x), 0, True)
|
||||
size1 = _tuple_setitem(arr_shape, axis, length1)
|
||||
start2 = _tuple_setitem(start1, axis, length1)
|
||||
size2 = _tuple_setitem(arr_shape, axis, length2)
|
||||
res = P.Split(axis, true_chunks)(tensor_slice(x, start1, size1)) + \
|
||||
P.Split(axis, 1)(tensor_slice(x, start2, size2))
|
||||
return res
|
||||
|
||||
|
||||
def ones(shape, dtype=None): # pylint: disable=redefined-outer-name
|
||||
r"""
|
||||
Creates a tensor filled with value ones.
|
||||
|
@ -4688,7 +4794,7 @@ def split(x, split_size_or_sections, axis=0):
|
|||
return res
|
||||
|
||||
|
||||
def tril(input_x, diagonal):
|
||||
def tril(input_x, diagonal=0): # pylint: disable=redefined-outer-name
|
||||
"""
|
||||
Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input,
|
||||
the other elements of the result tensor out are set to 0.
|
||||
|
@ -5796,6 +5902,8 @@ __all__ = [
|
|||
'reverse',
|
||||
'reverse_sequence',
|
||||
'hamming_window',
|
||||
'chunk',
|
||||
'full',
|
||||
'dyn_shape',
|
||||
'rank',
|
||||
'range',
|
||||
|
|
|
@ -135,6 +135,7 @@ tensor_operator_registry.register('rsqrt', rsqrt)
|
|||
tensor_operator_registry.register('bincount', bincount)
|
||||
tensor_operator_registry.register('slogdet', slogdet)
|
||||
tensor_operator_registry.register('tril', tril)
|
||||
tensor_operator_registry.register('chunk', chunk)
|
||||
tensor_operator_registry.register('sqrt', sqrt)
|
||||
tensor_operator_registry.register('square', square)
|
||||
tensor_operator_registry.register('sub', sub)
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, chunks, axis):
|
||||
output = ops.chunk(x, chunks, axis)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_chunk_normal(mode):
|
||||
"""
|
||||
Feature: ops.chunk
|
||||
Description: Verify the result of chunk
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = Tensor([[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]])
|
||||
chunks = 6
|
||||
axis = 1
|
||||
out = net(x, chunks, axis)
|
||||
expect_out_1 = np.array([[[[0, 1, 2, 3],
|
||||
[4, 5, 6, 7],
|
||||
[8, 9, 10, 11]]]])
|
||||
expect_out_2 = np.array([[[[0, 1, 2, 3],
|
||||
[4, 5, 6, 7],
|
||||
[8, 9, 10, 11]]]])
|
||||
assert np.allclose(out[0].asnumpy(), expect_out_1)
|
||||
assert np.allclose(out[1].asnumpy(), expect_out_2)
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, size, fill_value):
|
||||
output = ops.full(size, fill_value)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_full_normal(mode):
|
||||
"""
|
||||
Feature: ops.full
|
||||
Description: Verify the result of ops.full
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
size = (1, 2, 3)
|
||||
fill_value = 11
|
||||
out = net(size, fill_value)
|
||||
expect_out = np.array([[[11, 11, 11],
|
||||
[11, 11, 11]]])
|
||||
assert np.allclose(out.asnumpy(), expect_out)
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, chunks, axis):
|
||||
output = x.chunk(chunks, axis)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_chunk_normal(mode):
|
||||
"""
|
||||
Feature: ops.chunk
|
||||
Description: Verify the result of chunk
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = Tensor([[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]])
|
||||
chunks = 6
|
||||
axis = 1
|
||||
out = net(x, chunks, axis)
|
||||
expect_out_1 = np.array([[[[0, 1, 2, 3],
|
||||
[4, 5, 6, 7],
|
||||
[8, 9, 10, 11]]]])
|
||||
expect_out_2 = np.array([[[[0, 1, 2, 3],
|
||||
[4, 5, 6, 7],
|
||||
[8, 9, 10, 11]]]])
|
||||
assert np.allclose(out[0].asnumpy(), expect_out_1)
|
||||
assert np.allclose(out[1].asnumpy(), expect_out_2)
|
Loading…
Reference in New Issue