forked from mindspore-Ecosystem/mindspore
!35347 flatten support functional and tensor interface, dynamic shape and vmap
Merge pull request !35347 from polyhedral/flatten
This commit is contained in:
commit 621b48760b
docs/api/api_python
docs/api/api_python_en
mindspore/python/mindspore/ops
tests/st/ops
@@ -314,6 +314,7 @@ Array Operation

mindspore.ops.range
mindspore.ops.rank
mindspore.ops.reshape
mindspore.ops.flatten
mindspore.ops.scatter_nd
mindspore.ops.select
mindspore.ops.shape

@@ -1,19 +1,8 @@
mindspore.ops.Flatten
======================

.. py:class:: mindspore.ops.Flatten

.. py:class:: mindspore.ops.Flatten()

Flattens the input Tensor without changing its size along the 0th axis.

**Inputs:**

- **input_x** (Tensor) - The Tensor to be flattened, with shape :math:`(N, \ldots)`, where :math:`N` is the batch size.

**Outputs:**

Tensor with shape :math:`(N, X)`, where :math:`X` is the product of the remaining dimensions.

**Raises:**

- **TypeError** - `input_x` is not a Tensor.
- **ValueError** - The length of the shape of `input_x` is less than 1.

For more details, see :func:`mindspore.ops.flatten`.

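For orientation, the primitive form is typically wrapped in a Cell, as the ST tests further down in this change do; a minimal sketch with illustrative shape and dtype, assuming a default device context:

import numpy as np
import mindspore
from mindspore import Tensor, nn
from mindspore.ops import operations as P

class FlattenNet(nn.Cell):
    def __init__(self):
        super(FlattenNet, self).__init__()
        self.flatten = P.Flatten()

    def construct(self, x):
        return self.flatten(x)

x = Tensor(np.ones((1, 2, 3, 4)), mindspore.float32)
print(FlattenNet()(x).shape)  # (1, 24): axis 0 kept, remaining axes merged
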
@@ -0,0 +1,19 @@
mindspore.ops.flatten
======================

.. py:function:: mindspore.ops.flatten(input_x)

Flattens the input Tensor without changing its size along the 0th axis.

**Parameters:**

- **input_x** (Tensor) - The Tensor to be flattened, with shape :math:`(N, \ldots)`, where :math:`N` is the batch size.

**Returns:**

Tensor with shape :math:`(N, X)`, where :math:`X` is the product of the remaining dimensions.

**Raises:**

- **TypeError** - `input_x` is not a Tensor.
- **ValueError** - The length of the shape of `input_x` is less than 1.

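A minimal sketch of the functional interface documented above (it mirrors the example in the Python docstring added later in this change; values are illustrative):

import numpy as np
import mindspore
from mindspore import Tensor, ops

input_x = Tensor(np.ones((1, 2, 3, 4)), mindspore.float32)
output = ops.flatten(input_x)
print(output.shape)  # (1, 24)
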
@@ -314,6 +314,7 @@ Array Operation

mindspore.ops.range
mindspore.ops.rank
mindspore.ops.reshape
mindspore.ops.flatten
mindspore.ops.scatter_nd
mindspore.ops.select
mindspore.ops.shape

@@ -375,7 +375,6 @@ from .log1p import _log1p_tbe
from .resize_bilinear import _resize_bilinear_tbe
from .resize_bilinear_grad import _resize_bilinear_grad_tbe
from .flatten import _flatten_tbe
from .flatten_ds import _flatten_ds_tbe
from .roi_align import _roi_align_tbe
from .roi_align_grad import _roi_align_grad_tbe
from .bounding_box_decode import _bounding_box_decode_tbe

@@ -23,6 +23,8 @@ flatten_op_info = TBERegOp("Flatten") \
    .compute_cost(10) \
    .kernel_name("flatten") \
    .partial_flag(True) \
    .dynamic_compile_static(True) \
    .dynamic_shape(True) \
    .attr("axis", "optional", "int", "all", "1") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \

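The `dynamic_compile_static` and `dynamic_shape` flags added above are what allow Flatten to compile when an input dimension is unknown ahead of time. A minimal sketch of how a caller exercises this, following the dynamic-shape ST test added later in this change (shapes illustrative, assumes an Ascend backend):

import numpy as np
import mindspore
from mindspore import Tensor, context, nn
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class FlattenNet(nn.Cell):
    def __init__(self):
        super(FlattenNet, self).__init__()
        self.flatten = P.Flatten()

    def construct(self, x):
        return self.flatten(x)

net = FlattenNet()
# Axis 1 is declared dynamic; the concrete size comes from the actual input.
net.set_inputs(Tensor(shape=(1, None, 3, 1), dtype=mindspore.float32))
output = net(Tensor(np.random.randn(1, 16, 3, 1).astype(np.float32)))
print(output.shape)  # (1, 48)
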
@@ -1,46 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Flatten op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

flatten_ds_op_info = TBERegOp("Flatten") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("flatten.so") \
    .compute_cost(10) \
    .kernel_name("flatten") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .attr("axis", "optional", "int", "all", "1") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(flatten_ds_op_info)
def _flatten_ds_tbe():
    """Flatten TBE register"""
    return

@@ -197,6 +197,23 @@ def get_reshape_vmap_rule(prim, axis_size):
    return vmap_rule


@vmap_rules_getters.register(P.Flatten)
def get_flatten_vmap_rule(prim, axis_size):
    """VmapRule for `Flatten` operation."""

    def vmap_rule(x_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        x = _bdim_at_front(x, x_dim, axis_size)
        output = prim(x)
        return (output, 0)

    return vmap_rule


@vmap_rules_getters.register(P.Select)
def get_select_vmap_rule(prim, axis_size):
    """VmapRule for 'Select' operation."""

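The rule above moves the batched dimension to the front, applies Flatten to the batched tensor, and reports the output batch dimension as 0. A minimal sketch of the resulting user-facing behaviour, mirroring the vmap ST test added later in this change (shapes illustrative, assumes a configured device context):

import numpy as np
from mindspore import Tensor, ops
from mindspore.ops import operations as P

def flatten_graph(x):
    return P.Flatten()(x)

x = Tensor(np.random.rand(3, 4, 5).astype(np.float32))
out = ops.vmap(flatten_graph)(x)  # batch axis stays in front
print(out.shape)  # (3, 20), the same values as np.reshape(x.asnumpy(), (3, 20))
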
@@ -45,6 +45,7 @@ from .array_func import (
    rank,
    reshape,
    reshape_,
    flatten,
    tensor_slice,
    slice,
    scalar_to_array,

@@ -36,6 +36,7 @@ shape_ = P.Shape()
rank_ = P.Rank()
tensor_shape_ = P.TensorShape()
reshape_ = P.Reshape()
flatten_ = P.Flatten()
tensor_slice = P.Slice()
expand_dims_ = P.ExpandDims()
transpose_ = P.Transpose()

@@ -785,6 +786,33 @@ def reshape(input_x, input_shape):
    return reshape_(input_x, input_shape)


def flatten(input_x):
    r"""
    Flattens a tensor without changing its batch size on the 0-th axis.

    Args:
        input_x (Tensor): Tensor of shape :math:`(N, \ldots)` to be flattened, where :math:`N` is batch size.

    Returns:
        Tensor, the shape of the output tensor is :math:`(N, X)`, where :math:`X` is
        the product of the remaining dimensions.

    Raises:
        TypeError: If `input_x` is not a Tensor.
        ValueError: If the length of the shape of `input_x` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
        >>> output = ops.flatten(input_x)
        >>> print(output.shape)
        (1, 24)
    """
    return flatten_(input_x)


@constexpr
def _check_select_type_match(scalar, tensor_type, scalar_name, tensor_name):
    if isinstance(scalar, int) and tensor_type != mstype.int32:

@@ -2702,6 +2730,7 @@ __all__ = [
    'range',
    'reshape',
    'reshape_',
    'flatten',
    'tensor_slice',
    'slice',
    'scalar_cast',

@@ -911,6 +911,7 @@ tensor_operator_registry.register('pow', P.Pow)
tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('round', P.Round)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('flatten', P.Flatten)
tensor_operator_registry.register('transpose', P.Transpose)
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
tensor_operator_registry.register('matmul', P.MatMul)

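With this registration, the tensor method form dispatches to the same primitive. A minimal check mirroring the tensor-interface test added later in this change (values illustrative, assumes a configured device context):

import numpy as np
from mindspore import Tensor

in_np = np.random.randn(1, 16, 3, 1).astype(np.float32)
out = Tensor(in_np).flatten()
# Same elements, in the same order, as numpy's flatten.
np.testing.assert_allclose(out.asnumpy(), in_np.flatten(), rtol=1e-3)
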
@@ -14,98 +14,140 @@
# ============================================================================
import numpy as np

import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")


class Net(nn.Cell):
class FlattenNet(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        super(FlattenNet, self).__init__()
        self.flatten = P.Flatten()

    def construct(self, tensor):
        return self.flatten(tensor)


def test_net_int8():
    x = np.random.randn(1, 16, 1, 1).astype(np.int8)
    net = Net()
def flatten_net(nptype):
    x = np.random.randn(1, 16, 1, 1).astype(nptype)
    net = FlattenNet()
    output = net(Tensor(x))
    print(output.asnumpy())
    assert np.all(output.asnumpy() == x.flatten())


def test_net_uint8():
    x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
    net = Net()
def flatten_net_int8():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net(np.int8)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net(np.int8)


def flatten_net_uint8():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net(np.uint8)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net(np.uint8)


def flatten_net_int16():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net(np.int16)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net(np.int16)


def flatten_net_uint16():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net(np.uint16)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net(np.uint16)


def flatten_net_int32():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net(np.int32)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net(np.int32)


def flatten_net_uint32():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net(np.uint32)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net(np.uint32)


def flatten_net_int64():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net(np.int64)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net(np.int64)


def flatten_net_uint64():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net(np.uint64)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net(np.uint64)


def flatten_net_float16():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net(np.float16)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net(np.float16)


def flatten_net_float32():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net(np.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net(np.float32)


def flatten_net_dynamic(nptype, mstype):
    x = np.random.randn(1, 16, 3, 1).astype(nptype)
    x_dy = Tensor(shape=(1, None, 3, 1), dtype=mstype)
    net = FlattenNet()
    net.set_inputs(x_dy)
    output = net(Tensor(x))
    print(output.asnumpy())
    assert np.all(output.asnumpy() == x.flatten())


def test_net_int16():
    x = np.random.randn(1, 16, 1, 1).astype(np.int16)
    net = Net()
    output = net(Tensor(x))
    print(output.asnumpy())
    assert np.all(output.asnumpy() == x.flatten())
def flatten_net_dynamic_float16():
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net_dynamic(np.float16, mindspore.float16)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net_dynamic(np.float16, mindspore.float16)


def test_net_uint16():
    x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
    net = Net()
    output = net(Tensor(x))
    print(output.asnumpy())
    assert np.all(output.asnumpy() == x.flatten())
def flatten_net_dynamic_float32():
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    flatten_net_dynamic(np.float32, mindspore.float32)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    flatten_net_dynamic(np.float32, mindspore.float32)


def test_net_int32():
    x = np.random.randn(1, 16, 1, 1).astype(np.int32)
    net = Net()
    output = net(Tensor(x))
    print(output.asnumpy())
    assert np.all(output.asnumpy() == x.flatten())


def test_net_uint32():
    x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
    net = Net()
    output = net(Tensor(x))
    print(output.asnumpy())
    assert np.all(output.asnumpy() == x.flatten())


def test_net_int64():
    x = np.random.randn(1, 16, 1, 1).astype(np.int64)
    net = Net()
    output = net(Tensor(x))
    print(output.asnumpy())
    assert np.all(output.asnumpy() == x.flatten())


def test_net_uint64():
    x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
    net = Net()
    output = net(Tensor(x))
    print(output.asnumpy())
    assert np.all(output.asnumpy() == x.flatten())


def test_net_float16():
    x = np.random.randn(1, 16, 1, 1).astype(np.float16)
    net = Net()
    output = net(Tensor(x))
    print(output.asnumpy())
    assert np.all(output.asnumpy() == x.flatten())


def test_net_float32():
    x = np.random.randn(1, 16, 1, 1).astype(np.float32)
    net = Net()
    output = net(Tensor(x))
    print(output.asnumpy())
    assert np.all(output.asnumpy() == x.flatten())
if __name__ == "__main__":
    flatten_net_dynamic_float16()
    flatten_net_dynamic_float32()

@@ -18,8 +18,10 @@ import pytest

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F


class NetFlatten(nn.Cell):

@@ -140,4 +142,74 @@ def test_last_flatten():
    flatten = NetLastFlatten()
    output = flatten(x)
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_flatten_tensor_interface():
    """
    Feature: test_flatten_tensor_interface.
    Description: test cases for the tensor interface.
    Expectation: the result matches the numpy result.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    in_np = np.random.randn(1, 16, 3, 1).astype(np.float32)
    in_tensor = Tensor(in_np)

    output_ms = in_tensor.flatten()
    output_np = in_np.flatten()

    np.testing.assert_allclose(output_ms.asnumpy(), output_np, rtol=1e-3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_flatten_functional_interface():
    """
    Feature: test_flatten_functional_interface.
    Description: test cases for the functional interface.
    Expectation: the result matches the numpy result.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    in_np = np.random.randn(1, 16, 3, 1).astype(np.float32)
    in_tensor = Tensor(in_np)

    output_ms = F.flatten(in_tensor)
    output_np = np.reshape(in_np, (1, 48))

    np.testing.assert_allclose(output_ms.asnumpy(), output_np, rtol=1e-3)


def flatten_graph(x):
    return P.Flatten()(x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_flatten_vmap():
    """
    Feature: test flatten vmap.
    Description: test cases for vmap.
    Expectation: the result matches the numpy result.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    np.random.seed(0)
    in_np = np.random.rand(3, 4, 5).astype(np.float32)
    output_np = np.reshape(in_np, (3, 20))

    in_tensor = Tensor(in_np)
    vmap_flatten_net = ops.vmap(flatten_graph)
    output = vmap_flatten_net(in_tensor)
    np.testing.assert_allclose(output.asnumpy(), output_np, rtol=1e-3)


if __name__ == "__main__":
    test_flatten_tensor_interface()
    test_flatten_functional_interface()
    test_flatten_vmap()