diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index bce2f96eee8..6b26263e72b 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -168,3 +168,4 @@ from .floor import _floor_tbe from .log1p import _log1p_tbe from .resize_bilinear import _resize_bilinear_tbe from .resize_bilinear_grad import _resize_bilinear_grad_tbe +from .flatten import _flatten_tbe diff --git a/mindspore/ops/_op_impl/tbe/flatten.py b/mindspore/ops/_op_impl/tbe/flatten.py new file mode 100644 index 00000000000..0413422fb04 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/flatten.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Flatten op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +flatten_op_info = TBERegOp("Flatten") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("flatten.so") \ + .compute_cost(10) \ + .kernel_name("flatten") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(flatten_op_info) +def _flatten_tbe(): + """Flatten TBE register""" + return diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index 992d7957ae6..a245f7b5374 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -447,6 +447,17 @@ class UnfoldNetSame(nn.Cell): return self.unfold(x) +class FlattenNet(nn.Cell): + """ FlattenNet definition """ + + def __init__(self): + super(FlattenNet, self).__init__() + self.flatten = P.Flatten() + + def construct(self, x): + return self.flatten(x) + + test_cases = [ ('SoftMaxGrad', { 'block': SoftMaxGrad(VirtualNetWithLoss(P.Softmax())), @@ -532,6 +543,10 @@ test_cases = [ 'desc_inputs': [Tensor(np.array([3, 4, 5, 6]).astype(np.float32))], 'desc_bprop': [Tensor(np.array([1, 2, 3, 4]).astype(np.float32))], 'skip': ['backward']}), + ('FlattenNet', { + 'block': FlattenNet(), + 'desc_inputs': [Tensor(np.ones([1, 2, 3, 4], np.float32))], + }), ] test_cases_for_verify_exception = [