flatten support functional and tensor interface, dynamic shape and vmap

Merge pull request  from polyhedral/flatten
This commit is contained in:
i-robot 2022-06-06 09:50:51 +00:00 committed by Gitee
commit 621b48760b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 256 additions and 129 deletions

View File

@ -314,6 +314,7 @@ Array操作
mindspore.ops.range
mindspore.ops.rank
mindspore.ops.reshape
mindspore.ops.flatten
mindspore.ops.scatter_nd
mindspore.ops.select
mindspore.ops.shape

View File

@ -1,19 +1,8 @@
mindspore.ops.Flatten
======================
.. py:class:: mindspore.ops.Flatten
.. py:class:: mindspore.ops.Flatten()
扁平化Flatten输入Tensor不改变0轴的size。
**输入:**
- **input_x** (Tensor) - 待扁平化的Tensor其shape为 :math:`(N, \ldots)` :math:`N` 表示batch size。
**输出:**
Tensor输出shape为 :math:`(N, X)` 的Tensor其中 :math:`X` 是余下维度的乘积。
**异常:**
- **TypeError** - `input_x` 不是Tensor。
- **ValueError** - `input_x` 的shape长度小于1。
更多参考详见 :func:`mindspore.ops.flatten`

View File

@ -0,0 +1,19 @@
mindspore.ops.flatten
======================
.. py:function:: mindspore.ops.flatten(input_x)
扁平化Flatten输入Tensor不改变0轴的size。
**参数:**
- **input_x** (Tensor) - 待扁平化的Tensor其shape为 :math:`(N, \ldots)` :math:`N` 表示batch size。
**返回:**
Tensor其shape为 :math:`(N, X)` 的Tensor其中 :math:`X` 是余下维度的乘积。
**异常:**
- **TypeError** - `input_x` 不是Tensor。
- **ValueError** - `input_x` 的shape长度小于1。

View File

@ -314,6 +314,7 @@ Array Operation
mindspore.ops.range
mindspore.ops.rank
mindspore.ops.reshape
mindspore.ops.flatten
mindspore.ops.scatter_nd
mindspore.ops.select
mindspore.ops.shape

View File

@ -375,7 +375,6 @@ from .log1p import _log1p_tbe
from .resize_bilinear import _resize_bilinear_tbe
from .resize_bilinear_grad import _resize_bilinear_grad_tbe
from .flatten import _flatten_tbe
from .flatten_ds import _flatten_ds_tbe
from .roi_align import _roi_align_tbe
from .roi_align_grad import _roi_align_grad_tbe
from .bounding_box_decode import _bounding_box_decode_tbe

View File

@ -23,6 +23,8 @@ flatten_op_info = TBERegOp("Flatten") \
.compute_cost(10) \
.kernel_name("flatten") \
.partial_flag(True) \
.dynamic_compile_static(True) \
.dynamic_shape(True) \
.attr("axis", "optional", "int", "all", "1") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \

View File

@ -1,46 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Flatten op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
flatten_ds_op_info = TBERegOp("Flatten") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("flatten.so") \
.compute_cost(10) \
.kernel_name("flatten") \
.partial_flag(True) \
.dynamic_shape(True) \
.attr("axis", "optional", "int", "all", "1") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(flatten_ds_op_info)
def _flatten_ds_tbe():
"""Flatten TBE register"""
return

View File

@ -197,6 +197,23 @@ def get_reshape_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(P.Flatten)
def get_flatten_vmap_rule(prim, axis_size):
"""VmapRule for `Flatten` operation."""
def vmap_rule(x_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
if is_all_none:
return result
x, x_dim = x_bdim
x = _bdim_at_front(x, x_dim, axis_size)
output = prim(x)
return (output, 0)
return vmap_rule
@vmap_rules_getters.register(P.Select)
def get_select_vmap_rule(prim, axis_size):
"""VmapRule for 'Select' operation."""

View File

@ -45,6 +45,7 @@ from .array_func import (
rank,
reshape,
reshape_,
flatten,
tensor_slice,
slice,
scalar_to_array,

View File

@ -36,6 +36,7 @@ shape_ = P.Shape()
rank_ = P.Rank()
tensor_shape_ = P.TensorShape()
reshape_ = P.Reshape()
flatten_ = P.Flatten()
tensor_slice = P.Slice()
expand_dims_ = P.ExpandDims()
transpose_ = P.Transpose()
@ -785,6 +786,33 @@ def reshape(input_x, input_shape):
return reshape_(input_x, input_shape)
def flatten(input_x):
r"""
Flattens a tensor without changing its batch size on the 0-th axis.
Args:
input_x (Tensor): Tensor of shape :math:`(N, \ldots)` to be flattened, where :math:`N` is batch size.
Returns:
Tensor, the shape of the output tensor is :math:`(N, X)`, where :math:`X` is
the product of the remaining dimension.
Raises:
TypeError: If `input_x` is not a Tensor.
ValueError: If length of shape of `input_x` is less than 1.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
>>> output = ops.flatten(input_x)
>>> print(output.shape)
(1, 24)
"""
return flatten_(input_x)
@constexpr
def _check_select_type_match(scalar, tensor_type, scalar_name, tensor_name):
if isinstance(scalar, int) and tensor_type != mstype.int32:
@ -2702,6 +2730,7 @@ __all__ = [
'range',
'reshape',
'reshape_',
'flatten',
'tensor_slice',
'slice',
'scalar_cast',

View File

@ -911,6 +911,7 @@ tensor_operator_registry.register('pow', P.Pow)
tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('round', P.Round)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('flatten', P.Flatten)
tensor_operator_registry.register('transpose', P.Transpose)
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
tensor_operator_registry.register('matmul', P.MatMul)

View File

@ -14,98 +14,140 @@
# ============================================================================
import numpy as np
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
class FlattenNet(nn.Cell):
def __init__(self):
super(Net, self).__init__()
super(FlattenNet, self).__init__()
self.flatten = P.Flatten()
def construct(self, tensor):
return self.flatten(tensor)
def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
def flatten_net(nptype):
x = np.random.randn(1, 16, 1, 1).astype(nptype)
net = FlattenNet()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.flatten())
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
def flatten_net_int8():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net(np.int8)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net(np.int8)
def flatten_net_uint8():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net(np.uint8)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net(np.uint8)
def flatten_net_int16():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net(np.int16)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net(np.int16)
def flatten_net_uint16():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net(np.uint16)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net(np.uint16)
def flatten_net_int32():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net(np.int32)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net(np.int32)
def flatten_net_uint32():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net(np.uint32)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net(np.uint32)
def flatten_net_int64():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net(np.int64)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net(np.int64)
def flatten_net_uint64():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net(np.uint64)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net(np.uint64)
def flatten_net_float16():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net(np.float16)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net(np.float16)
def flatten_net_float32():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net(np.float32)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net(np.float32)
def flatten_net_dynamic(nptype, mstype):
x = np.random.randn(1, 16, 3, 1).astype(nptype)
x_dy = Tensor(shape=(1, None, 3, 1), dtype=mstype)
net = FlattenNet()
net.set_inputs(x_dy)
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.flatten())
def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.flatten())
def flatten_net_dynamic_float16():
# graph mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net_dynamic(np.float16, mindspore.float16)
# pynative mode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
flatten_net_dynamic(np.float16, mindspore.float16)
def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.flatten())
def flatten_net_dynamic_float32():
# graph mode
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
flatten_net_dynamic(np.float32, mindspore.float32)
# pynative mode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_flatten_net_dynamic(np.float32, mindspore.float32)
def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.flatten())
def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.flatten())
def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.flatten())
def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.flatten())
def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.flatten())
def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.flatten())
if __name__ == "__main__":
flatten_net_dynamic_float16()
flatten_net_dynamic_float32()

View File

@ -18,8 +18,10 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
class NetFlatten(nn.Cell):
@ -140,4 +142,74 @@ def test_last_flatten():
flatten = NetLastFlatten()
output = flatten(x)
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_flatten_tensor_interface():
"""
Feature: test_flatten_tensor_interface.
Description: test cases for tensor interface
Expectation: raise TypeError.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
in_np = np.random.randn(1, 16, 3, 1).astype(np.float32)
in_tensor = Tensor(in_np)
output_ms = in_tensor.flatten()
output_np = in_np.flatten()
np.testing.assert_allclose(output_ms.asnumpy(), output_np, rtol=1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_flatten_functional_interface():
"""
Feature: test_flatten_functional_interface.
Description: test cases for functional interface.
Expectation: raise TypeError.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
in_np = np.random.randn(1, 16, 3, 1).astype(np.float32)
in_tensor = Tensor(in_np)
output_ms = F.flatten(in_tensor)
output_np = np.reshape(in_np, (1, 48))
np.testing.assert_allclose(output_ms.asnumpy(), output_np, rtol=1e-3)
def flatten_graph(x):
return P.Flatten()(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_flatten_vmap():
"""
Feature: test flatten vmap.
Description: test cases for vmap.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
np.random.seed(0)
in_np = np.random.rand(3, 4, 5).astype(np.float32)
output_np = np.reshape(in_np, (3, 20))
in_tensor = Tensor(in_np)
vmap_round_net = ops.vmap(flatten_graph)
output = vmap_round_net(in_tensor)
np.testing.assert_allclose(output.asnumpy(), output_np, rtol=1e-3)
if __name__ == "__main__":
test_flatten_tensor_interface()
test_flatten_functional_interface()
test_flatten_vmap()