forked from mindspore-Ecosystem/mindspore
commit 2ff013c3cc
@ -270,6 +270,7 @@ mindspore.ops
mindspore.ops.positive
mindspore.ops.pow
mindspore.ops.rad2deg
mindspore.ops.ravel
mindspore.ops.real
mindspore.ops.reciprocal
mindspore.ops.remainder

@ -454,6 +455,7 @@ Array操作
mindspore.ops.flipud
mindspore.ops.fold
mindspore.ops.full
mindspore.ops.full_like
mindspore.ops.gather
mindspore.ops.gather_d
mindspore.ops.gather_elements

@ -0,0 +1,19 @@
mindspore.ops.full_like
=======================

.. py:function:: mindspore.ops.full_like(x, fill_value, *, dtype=None)

Returns a Tensor with the same size as the input, filled with `fill_value`. ``ops.full_like(x, fill_value)`` is equivalent to ``ops.full(x.shape, fill_value, dtype=x.dtype)``.

Args:
    - **x** (Tensor) - The shape of `x` determines the shape of the output Tensor.
    - **fill_value** (number.Number) - The value used to fill the output Tensor.

Keyword Args:
    - **dtype** (mindspore.dtype) - The specified data type of the output Tensor. Only `bool_` and `number` are supported; for details, see :class:`mindspore.dtype` . Default: None.

Returns:
    Tensor.

Raises:
    - **TypeError** - If `x` is not a Tensor.
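
Note (not part of the diff): a minimal usage sketch of the full_like API documented above. It assumes only the behaviour described here, namely that the output copies x.shape and that dtype falls back to x.dtype when not given.

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.array([[1, 2], [3, 4]]), ms.int32)
# Same shape as x, filled with 9; dtype falls back to x.dtype (int32).
y = ops.full_like(x, 9)
# The keyword-only dtype argument overrides the inferred type.
z = ops.full_like(x, 9, dtype=ms.float32)
print(y.shape)   # expected: (2, 2)
print(z.dtype)   # expected: Float32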

@ -0,0 +1,15 @@
mindspore.ops.ravel
===================

.. py:function:: mindspore.ops.ravel(x)

Returns a flattened, one-dimensional Tensor.

Args:
    - **x** (Tensor) - The Tensor to be flattened.

Returns:
    A 1-D Tensor containing the same elements as the input.

Raises:
    - **TypeError** - If `x` is not a Tensor.
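
Note (not part of the diff): a short usage sketch of the ravel API documented above; the flattening order shown is the row-major order implied by the reshape-based implementation later in this diff.

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor(np.arange(6).reshape(2, 3), ms.float32)
out = ops.ravel(x)        # 1-D tensor with the same six elements, row by row
print(out.shape)          # expected: (6,)
print(out)                # expected: [0. 1. 2. 3. 4. 5.]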

@ -270,6 +270,7 @@ Element-by-Element Operations
mindspore.ops.positive
mindspore.ops.pow
mindspore.ops.rad2deg
mindspore.ops.ravel
mindspore.ops.real
mindspore.ops.reciprocal
mindspore.ops.remainder

@ -454,6 +455,7 @@ Array Operation
mindspore.ops.flipud
mindspore.ops.fold
mindspore.ops.full
mindspore.ops.full_like
mindspore.ops.gather
mindspore.ops.gather_d
mindspore.ops.gather_elements

@ -36,6 +36,7 @@ from .array_func import (
    fill,
    fill_,
    full,
    full_like,
    chunk,
    tile,
    size,

@ -49,6 +50,7 @@ from .array_func import (
    dyn_shape,
    rank,
    hamming_window,
    ravel,
    reshape,
    reshape_,
    reverse,

@ -464,6 +464,33 @@ def reverse(x, axis):
    return P.ReverseV2(axis)(x)


def ravel(x):
    """
    Return a contiguous flattened tensor.

    Args:
        x (Tensor): A tensor to be flattened.

    Returns:
        Tensor, a 1-D tensor containing the same elements as the input.

    Raises:
        TypeError: If `x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
        >>> output = ops.ravel(x)
        >>> print(output)
        [0. 1. 2. 1.]
        >>> print(output.shape)
        (4,)
    """
    return ops.reshape(x, (-1,))


def matrix_band_part(x, lower, upper):
    r"""
    Copy a tensor setting everything outside a central band in each innermost matrix to zero.

@ -713,6 +740,46 @@ def full(size, fill_value, *, dtype=None):  # pylint: disable=redefined-outer-name
    return fill_(dtype, size, fill_value)


def full_like(x, fill_value, *, dtype=None):
    """
    Returns a Tensor with the same size as `x`, filled with `fill_value`. ``ops.full_like(x, fill_value)`` is
    equivalent to ``ops.full(x.shape, fill_value, dtype=x.dtype)``.

    Args:
        x (Tensor): The shape of `x` determines the shape of the output Tensor.
        fill_value (number.Number): Value to fill the returned Tensor.

    Keyword Args:
        dtype (mindspore.dtype, optional): The specified type of output tensor. The data type only supports
            `bool_` and `number`, for details, please refer to :class:`mindspore.dtype` . Default: None.

    Returns:
        Tensor.

    Raises:
        TypeError: If `x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([[0, 1], [2, 1]], mindspore.int32)
        >>> output = ops.full_like(x, 1)
        >>> print(output)
        [[1 1]
         [1 1]]
        >>> output = ops.full_like(x, 0, dtype=mindspore.float32)
        >>> print(output)
        [[0. 0.]
         [0. 0.]]
    """
    if not isinstance(x, Tensor):
        raise TypeError(f"For ops.full_like, the argument 'x' must be tensor, but got {type(x)}")
    if dtype is None:
        dtype = x.dtype
    return full(x.shape, fill_value, dtype=dtype)


def chunk(x, chunks, axis=0):
    """
    Splits the Tensor into chunks along the given axis.

@ -6195,6 +6262,7 @@ __all__ = [
    'hamming_window',
    'chunk',
    'full',
    'full_like',
    'dyn_shape',
    'rank',
    'range',

@ -6245,6 +6313,7 @@ __all__ = [
    'masked_select',
    'where',
    'narrow',
    'ravel',
    'scatter_add',
    'scatter_mul',
    'scatter_max',

@ -0,0 +1,51 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor


class Net(nn.Cell):
    def construct(self, x, fill_value):
        output = ops.full_like(x, fill_value)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_full_like_normal(mode):
    """
    Feature: ops.full_like
    Description: Verify the result of ops.full_like
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net()
    x = Tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]])
    fill_value = 11
    out = net(x, fill_value)
    expect_out = np.array([[[11, 11, 11, 11], [11, 11, 11, 11], [11, 11, 11, 11]]])
    assert np.allclose(out.asnumpy(), expect_out)
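
Note (not part of the diff): the test above exercises only the default dtype path. A hypothetical companion check for the dtype keyword could look like the sketch below; the function name test_full_like_dtype is illustrative, and imports are repeated so the sketch stands on its own.

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops


def test_full_like_dtype():
    """Hypothetical extra check: the dtype keyword should override x.dtype."""
    x = Tensor([[0, 1], [2, 3]], ms.int32)
    out = ops.full_like(x, 2, dtype=ms.float32)
    assert out.dtype == ms.float32
    assert np.allclose(out.asnumpy(), np.full((2, 2), 2.0, dtype=np.float32))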

@ -0,0 +1,47 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import ops


class Net(nn.Cell):
    def construct(self, x):
        return ops.ravel(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_ravel(mode):
    """
    Feature: ops.ravel
    Description: Verify the result of ravel
    Expectation: success
    """
    ms.set_context(mode=mode)
    x = Tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], ms.int32)
    net = Net()
    output = net(x)
    expect_output = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    assert np.allclose(output.asnumpy(), expect_output)
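
Note (not part of the diff): a hypothetical shape-focused check one might add alongside the value test above; the name test_ravel_shape is illustrative, and the sketch carries its own imports.

import numpy as np
from mindspore import Tensor, ops


def test_ravel_shape():
    """Hypothetical extra check: ravel returns a rank-1 tensor with the same element count."""
    x = Tensor(np.ones((2, 3, 4), dtype=np.float32))
    out = ops.ravel(x)
    assert out.shape == (24,)
    assert out.size == x.size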