forked from mindspore-Ecosystem/mindspore
!42291 [ST][MS][OPS] fold & unfold Functional & Tensor APIs with STs.
Merge pull request !42291 from alashkari/functional-tensor-apis-sept-19
This commit is contained in:
commit
ac8a42d05d
|
@ -314,6 +314,7 @@ Array操作
|
|||
mindspore.ops.dyn_shape
|
||||
mindspore.ops.expand
|
||||
mindspore.ops.expand_dims
|
||||
mindspore.ops.fold
|
||||
mindspore.ops.gather
|
||||
mindspore.ops.gather_d
|
||||
mindspore.ops.gather_elements
|
||||
|
@ -360,6 +361,7 @@ Array操作
|
|||
mindspore.ops.tile
|
||||
mindspore.ops.top_k
|
||||
mindspore.ops.transpose
|
||||
mindspore.ops.unfold
|
||||
mindspore.ops.unique
|
||||
mindspore.ops.unique_consecutive
|
||||
mindspore.ops.unique_with_pad
|
||||
|
|
|
@ -193,6 +193,7 @@ Array操作
|
|||
mindspore.Tensor.dtype
|
||||
mindspore.Tensor.expand_as
|
||||
mindspore.Tensor.expand_dims
|
||||
mindspore.Tensor.fold
|
||||
mindspore.Tensor.gather
|
||||
mindspore.Tensor.gather_elements
|
||||
mindspore.Tensor.gather_nd
|
||||
|
@ -235,6 +236,7 @@ Array操作
|
|||
mindspore.Tensor.to_tensor
|
||||
mindspore.Tensor.trace
|
||||
mindspore.Tensor.transpose
|
||||
mindspore.Tensor.unfold
|
||||
mindspore.Tensor.unique_consecutive
|
||||
mindspore.Tensor.unique_with_pad
|
||||
mindspore.Tensor.unsorted_segment_max
|
||||
|
|
|
@ -198,6 +198,7 @@ Array Methods
|
|||
mindspore.Tensor.dtype
|
||||
mindspore.Tensor.expand_as
|
||||
mindspore.Tensor.expand_dims
|
||||
mindspore.Tensor.fold
|
||||
mindspore.Tensor.gather
|
||||
mindspore.Tensor.gather_elements
|
||||
mindspore.Tensor.gather_nd
|
||||
|
@ -240,6 +241,7 @@ Array Methods
|
|||
mindspore.Tensor.to_tensor
|
||||
mindspore.Tensor.trace
|
||||
mindspore.Tensor.transpose
|
||||
mindspore.Tensor.unfold
|
||||
mindspore.Tensor.unique_consecutive
|
||||
mindspore.Tensor.unique_with_pad
|
||||
mindspore.Tensor.unsorted_segment_max
|
||||
|
|
|
@ -314,6 +314,7 @@ Array Operation
|
|||
mindspore.ops.dyn_shape
|
||||
mindspore.ops.expand
|
||||
mindspore.ops.expand_dims
|
||||
mindspore.ops.fold
|
||||
mindspore.ops.gather
|
||||
mindspore.ops.gather_d
|
||||
mindspore.ops.gather_elements
|
||||
|
@ -360,6 +361,7 @@ Array Operation
|
|||
mindspore.ops.tile
|
||||
mindspore.ops.top_k
|
||||
mindspore.ops.transpose
|
||||
mindspore.ops.unfold
|
||||
mindspore.ops.unique
|
||||
mindspore.ops.unique_consecutive
|
||||
mindspore.ops.unique_with_pad
|
||||
|
|
|
@ -330,6 +330,8 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"cross", std::string("cross")}, // cross()
|
||||
{"erfinv", std::string("erfinv")}, // erfinv()
|
||||
{"less_equal", std::string("less_equal")}, // less_equal()
|
||||
{"fold", std::string("fold")}, // fold()
|
||||
{"unfold", std::string("unfold")}, // unfold()
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
|
|
|
@ -3297,4 +3297,17 @@ def less_equal(input, other):
|
|||
Computes the boolean value of :math:`input\_x <= other` element-wise.
|
||||
"""
|
||||
return F.less_equal(input, other)
|
||||
|
||||
|
||||
|
||||
def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
|
||||
r"""
|
||||
Combines an array of sliding local blocks into a large containing tensor.
|
||||
"""
|
||||
return F.fold(input, output_size, kernel_size, dilation, padding, stride)
|
||||
|
||||
|
||||
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
|
||||
r"""
|
||||
Extracts sliding local blocks from a batched input tensor.
|
||||
"""
|
||||
return F.unfold(input, kernel_size, dilation, padding, stride)
|
||||
|
|
|
@ -6180,6 +6180,96 @@ class Tensor(Tensor_):
|
|||
return tensor_operator_registry.get('less_equal')(self, other)
|
||||
|
||||
|
||||
def fold(self, output_size, kernel_size, dilation=1, padding=0, stride=1):
|
||||
r"""
|
||||
Combines an array of sliding local blocks into a large containing tensor.
|
||||
|
||||
.. warning::
|
||||
- Currently, only 4-D input tensors (batched image-like tensors) are supported.
|
||||
|
||||
Args:
|
||||
output_size (Tensor): 1D tensor with `2` elements of data type int.
|
||||
kernel_size (Union[int, tuple[int], list[int]]): The size of the kernel, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Must be specified.
|
||||
dilation (Union[int, tuple[int], list[int]]): The size of the dilation, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
padding (Union[int, tuple[int], list[int]]): The size of the padding, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 0.
|
||||
stride (Union[int, tuple[int], list[int]]): The size of the stride, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
|
||||
Returns:
|
||||
A Tensor, with same type as input tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If :attr:`kernel_size`, `dilation`, `padding`, `stride` data type is not in
|
||||
Union[int, tuple[int], list[int]].
|
||||
ValueError: If :attr:`kernel_size`, `dilation`, `padding`, `stride` value is not
|
||||
greater than zero or elements number more than `2`.
|
||||
ValueError: If :attr:`padding` value is less than zero or elements number more than `2`.
|
||||
ValueError: If `(input tensor).shape[2] != kernel_size[0] * kernel_size[1]`.
|
||||
ValueError: If `(input tensor).shape[3]` does not match the calculated number of sliding blocks.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(input_data=np.random.rand(16, 16, 4, 25), dtype=mstype.float32)
|
||||
>>> output_size = Tensor(input_data=[8, 8], dtype=mstype.int32)
|
||||
>>> output = ops.fold(x, output_size, [2, 2], [2, 2], [2, 2], [2, 2])
|
||||
>>> print(output.shape)
|
||||
(16, 16, 8, 8)
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('fold')(self, output_size, kernel_size, dilation, padding, stride)
|
||||
|
||||
|
||||
def unfold(self, kernel_size, dilation=1, padding=0, stride=1):
|
||||
r"""
|
||||
Extracts sliding local blocks from a batched input tensor.
|
||||
|
||||
.. warning::
|
||||
- Currently, only 4-D input tensors (batched image-like tensors) are supported.
|
||||
|
||||
Args:
|
||||
kernel_size (Union[int, tuple[int], list[int]]): The size of the kernel, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Must be specified.
|
||||
dilation (Union[int, tuple[int], list[int]]): The dilation of the window, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
padding (Union[int, tuple[int], list[int]]): The pad of the window, that must be
|
||||
a tuple of one or two or four `int` for height and width.
|
||||
If one int, pad_height = pad_width.
|
||||
If two int, pad_height = padding[0], pad_width = padding[1].
|
||||
If four int, padding = [pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]
|
||||
Default: 0.
|
||||
stride (Union[int, tuple[int], list[int]]): The stride of the window, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
|
||||
Returns:
|
||||
A Tensor, with same type as input tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If :attr:`kernel_size` data type is not in Union[int, tuple[int], list[int]].
|
||||
TypeError: If :attr:`stride` data type is not in Union[int, tuple[int], list[int]].
|
||||
TypeError: If :attr:`dilation` data type is not in Union[int, tuple[int], list[int]].
|
||||
ValueError: If :attr:`kernel_size` value is not greater than zero or elements number more than `2`.
|
||||
ValueError: If :attr:`stride` value is not greater than zero or elements number more than `2`.
|
||||
ValueError: If :attr:`dilation` value is not greater than zero or elements number more than `2`.
|
||||
ValueError: If :attr:`padding` value is not greater than zero.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.random.rand(4, 4, 32, 32), mindspore.float64)
|
||||
>>> output = x.unfold(kernel_size=3, dilation=1, stride=1)
|
||||
>>> print(output.shape)
|
||||
(4, 36, 30, 30)
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('unfold')(self, kernel_size, dilation, padding, stride)
|
||||
|
||||
|
||||
class RowTensor(RowTensor_):
|
||||
"""
|
||||
A sparse representation of a set of tensor slices at given indices.
|
||||
|
|
|
@ -114,7 +114,9 @@ from .array_func import (
|
|||
min,
|
||||
population_count,
|
||||
top_k,
|
||||
expand
|
||||
expand,
|
||||
fold,
|
||||
unfold,
|
||||
)
|
||||
from .parameter_func import (
|
||||
assign,
|
||||
|
|
|
@ -34,11 +34,14 @@ from ..operations.array_ops import (
|
|||
ScatterNdMul,
|
||||
IndexFill,
|
||||
AffineGrid,
|
||||
Im2Col,
|
||||
)
|
||||
from ..operations.nn_ops import AdaptiveMaxPool2D
|
||||
from ..operations.array_ops import TensorScatterElements
|
||||
from ...common import Tensor
|
||||
from .._primitive_cache import _get_cache_prim
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
|
||||
eye_ = P.Eye()
|
||||
fill_ = P.Fill()
|
||||
|
@ -4508,6 +4511,141 @@ def expand(input_x, size):
|
|||
return expand_op(input_x, size)
|
||||
|
||||
|
||||
def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
|
||||
"""
|
||||
Combines an array of sliding local blocks into a large containing tensor.
|
||||
|
||||
.. warning::
|
||||
- Currently, only 4-D input tensors (batched image-like tensors) are supported.
|
||||
|
||||
Args:
|
||||
input (Tensor): a tensor with data type float16 or float.
|
||||
output_size (Tensor): 1D tensor with `2` elements of data type int.
|
||||
kernel_size (Union[int, tuple[int], list[int]]): The size of the kernel, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Must be specified.
|
||||
dilation (Union[int, tuple[int], list[int]]): The size of the dilation, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
padding (Union[int, tuple[int], list[int]]): The size of the padding, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 0.
|
||||
stride (Union[int, tuple[int], list[int]]): The size of the stride, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
|
||||
Returns:
|
||||
A Tensor, with same type as 'input'.
|
||||
|
||||
Raises:
|
||||
TypeError: If :attr:`kernel_size`, `dilation`, `padding`, `stride` data type is not in
|
||||
Union[int, tuple[int], list[int]].
|
||||
ValueError: If :attr:`kernel_size`, `dilation`, `padding`, `stride` value is not
|
||||
greater than zero or elements number more than `2`.
|
||||
ValueError: If :attr:`padding` value is less than zero or elements number more than `2`.
|
||||
ValueError: If `input.shape[2] != kernel_size[0] * kernel_size[1]`.
|
||||
ValueError: If `input.shape[3]` does not match the calculated number of sliding blocks.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(input_data=np.random.rand(16, 16, 4, 25), dtype=mstype.float32)
|
||||
>>> output_size = Tensor(input_data=[8, 8], dtype=mstype.int32)
|
||||
>>> output = ops.fold(x, output_size, [2, 2], [2, 2], [2, 2], [2, 2])
|
||||
>>> print(output.shape)
|
||||
(16, 16, 8, 8)
|
||||
"""
|
||||
validator.check_value_type('kernel_size', kernel_size, [int, list, tuple], 'fold')
|
||||
validator.check_value_type('dilation', dilation, [int, list, tuple], 'fold')
|
||||
validator.check_value_type('padding', padding, [int, list, tuple], 'fold')
|
||||
validator.check_value_type('stride', stride, [int, list, tuple], 'fold')
|
||||
|
||||
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
|
||||
padding = (padding, padding) if isinstance(padding, int) else padding
|
||||
stride = (stride, stride) if isinstance(stride, int) else stride
|
||||
|
||||
validator.check("kernel_size size", len(kernel_size), "", 2, Rel.EQ, 'fold')
|
||||
validator.check_positive_int_sequence(kernel_size, "kernel_size", 'fold')
|
||||
validator.check("dilation size", len(dilation), "", 2, Rel.EQ, 'fold')
|
||||
validator.check_positive_int_sequence(dilation, "dilation", 'fold')
|
||||
validator.check("padding size", len(padding), "", 2, Rel.EQ, 'fold')
|
||||
validator.check_non_negative_int_sequence(padding, "padding", 'fold')
|
||||
validator.check("stride size", len(stride), "", 2, Rel.EQ, 'fold')
|
||||
validator.check_positive_int_sequence(stride, "stride", 'fold')
|
||||
|
||||
fold_op = _get_cache_prim(Col2Im)(kernel_size, dilation, padding, stride)
|
||||
return fold_op(input, output_size)
|
||||
|
||||
|
||||
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
|
||||
"""
|
||||
Extracts sliding local blocks from a batched input tensor.
|
||||
|
||||
.. warning::
|
||||
- Currently, only 4-D input tensors (batched image-like tensors) are supported.
|
||||
|
||||
Args:
|
||||
input (Tensor): input tensor. Support all real number data type.
|
||||
kernel_size (Union[int, tuple[int], list[int]]): The size of the kernel, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Must be specified.
|
||||
dilation (Union[int, tuple[int], list[int]]): The dilation of the window, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
padding (Union[int, tuple[int], list[int]]): The pad of the window, that must be
|
||||
a tuple of one or two or four `int` for height and width.
|
||||
If one int, pad_height = pad_width.
|
||||
If two int, pad_height = padding[0], pad_width = padding[1].
|
||||
If four int, padding = [pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]
|
||||
Default: 0.
|
||||
stride (Union[int, tuple[int], list[int]]): The stride of the window, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
|
||||
Returns:
|
||||
A Tensor, with same type as 'input'.
|
||||
|
||||
Raises:
|
||||
TypeError: If :attr:`kernel_size` data type is not in Union[int, tuple[int], list[int]].
|
||||
TypeError: If :attr:`stride` data type is not in Union[int, tuple[int], list[int]].
|
||||
TypeError: If :attr:`dilation` data type is not in Union[int, tuple[int], list[int]].
|
||||
ValueError: If :attr:`kernel_size` value is not greater than zero or elements number more than `2`.
|
||||
ValueError: If :attr:`stride` value is not greater than zero or elements number more than `2`.
|
||||
ValueError: If :attr:`dilation` value is not greater than zero or elements number more than `2`.
|
||||
ValueError: If :attr:`padding` value is not greater than zero.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.random.rand(4, 4, 32, 32), mindspore.float64)
|
||||
>>> output = ops.unfold(x, kernel_size=3, dilation=1, stride=1)
|
||||
>>> print(output.shape)
|
||||
(4, 36, 30, 30)
|
||||
"""
|
||||
validator.check_value_type('ksizes', kernel_size, [int, tuple, list], 'unfold')
|
||||
validator.check_value_type('stride', stride, [int, tuple, list], 'unfold')
|
||||
validator.check_value_type('dilation', dilation, [int, tuple, list], 'unfold')
|
||||
validator.check_value_type('padding', padding, [int, tuple, list], 'unfold')
|
||||
|
||||
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
stride = (stride, stride) if isinstance(stride, int) else stride
|
||||
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
|
||||
padding = (padding, padding, padding, padding) if isinstance(padding, int) else padding
|
||||
|
||||
validator.check("ksizes size", len(kernel_size), "", [1, 2], Rel.IN, 'unfold')
|
||||
validator.check_positive_int_sequence(kernel_size, "ksizes", 'unfold')
|
||||
validator.check("stride size", len(stride), "", [1, 2], Rel.IN, 'unfold')
|
||||
validator.check_positive_int_sequence(stride, "stride", 'unfold')
|
||||
validator.check("dilation size", len(dilation), "", [1, 2], Rel.IN, 'unfold')
|
||||
validator.check_positive_int_sequence(dilation, "dilation", 'unfold')
|
||||
|
||||
validator.check("padding size", len(padding), "", [1, 2, 4], Rel.IN, 'unfold')
|
||||
validator.check_non_negative_int_sequence(padding, "padding", 'unfold')
|
||||
|
||||
unfold_op = _get_cache_prim(Im2Col)(ksizes=kernel_size,
|
||||
strides=stride,
|
||||
dilations=dilation,
|
||||
padding_mode="CALCULATED",
|
||||
pads=padding)
|
||||
return unfold_op(input)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'unique',
|
||||
'unique_with_pad',
|
||||
|
@ -4597,6 +4735,8 @@ __all__ = [
|
|||
'unsorted_segment_sum',
|
||||
'population_count',
|
||||
'top_k',
|
||||
'expand'
|
||||
'expand',
|
||||
'fold',
|
||||
'unfold',
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -421,6 +421,8 @@ tensor_operator_registry.register('conj', conj)
|
|||
tensor_operator_registry.register('cross', cross)
|
||||
tensor_operator_registry.register('erfinv', erfinv)
|
||||
tensor_operator_registry.register('less_equal', less_equal)
|
||||
tensor_operator_registry.register('fold', fold)
|
||||
tensor_operator_registry.register('unfold', unfold)
|
||||
# ms cannot support Tensor(True) compare
|
||||
tensor_operator_registry.register('__eq__', equal)
|
||||
tensor_operator_registry.register('__ne__', not_equal)
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
def test_fold_functional_api():
|
||||
"""
|
||||
Feature: test fold functional API.
|
||||
Description: test case for fold functional API.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.ones([16, 16, 4, 25]), mstype.float32)
|
||||
output_size = Tensor([8, 8], mstype.int32)
|
||||
output = F.fold(x, output_size, kernel_size=[2, 2], dilation=[2, 2], padding=[2, 2], stride=[2, 2])
|
||||
expected_shape = (16, 16, 8, 8)
|
||||
assert output.dtype == x.dtype
|
||||
assert output.shape == expected_shape
|
||||
|
||||
|
||||
def test_fold_tensor_api():
|
||||
"""
|
||||
Feature: test fold tensor API.
|
||||
Description: test case for fold tensor API.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.ones([16, 16, 4, 25]), mstype.float32)
|
||||
output_size = Tensor([8, 8], mstype.int32)
|
||||
output = x.fold(output_size, kernel_size=[2, 2], dilation=[2, 2], padding=[2, 2], stride=[2, 2])
|
||||
expected_shape = (16, 16, 8, 8)
|
||||
assert output.dtype == x.dtype
|
||||
assert output.shape == expected_shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_fold_tensor_functional_api_modes():
|
||||
"""
|
||||
Feature: test fold tensor and functional APIs for different modes.
|
||||
Description: test case for fold tensor API.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
test_fold_functional_api()
|
||||
test_fold_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
test_fold_functional_api()
|
||||
test_fold_tensor_api()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_fold_tensor_functional_api_modes()
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
def test_fold_functional_api():
|
||||
"""
|
||||
Feature: test fold functional API.
|
||||
Description: test case for fold functional API.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.ones([16, 16, 4, 25]), mstype.float32)
|
||||
output_size = Tensor([8, 8], mstype.int32)
|
||||
output = F.fold(x, output_size, kernel_size=[2, 2], dilation=[2, 2], padding=[2, 2], stride=[2, 2])
|
||||
expected_shape = (16, 16, 8, 8)
|
||||
assert output.dtype == x.dtype
|
||||
assert output.shape == expected_shape
|
||||
|
||||
|
||||
def test_fold_tensor_api():
|
||||
"""
|
||||
Feature: test fold tensor API.
|
||||
Description: test case for fold tensor API.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
x = Tensor(np.ones([16, 16, 4, 25]), mstype.float32)
|
||||
output_size = Tensor([8, 8], mstype.int32)
|
||||
output = x.fold(output_size, kernel_size=[2, 2], dilation=[2, 2], padding=[2, 2], stride=[2, 2])
|
||||
expected_shape = (16, 16, 8, 8)
|
||||
assert output.dtype == x.dtype
|
||||
assert output.shape == expected_shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fold_tensor_functional_api_modes():
|
||||
"""
|
||||
Feature: test fold tensor and functional APIs for different modes.
|
||||
Description: test case for fold tensor API.
|
||||
Expectation: the result match with expected result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_fold_functional_api()
|
||||
test_fold_tensor_api()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
test_fold_functional_api()
|
||||
test_fold_tensor_api()
|
Loading…
Reference in New Issue