From cb8e1fa3f14516ae7078af56c0e714c5c35a1aae Mon Sep 17 00:00:00 2001 From: kanghui <405194527@qq.com> Date: Fri, 14 May 2021 05:18:07 +0000 Subject: [PATCH] add ops :depthtospace,spacetodepth Signed-off-by: kanghui <405194527@qq.com> --- .../cpu/depthtospace_cpu_kernel.cc | 89 ++++++++++++++ .../cpu/depthtospace_cpu_kernel.h | 85 +++++++++++++ .../cpu/spacetodepth_cpu_kernel.cc | 91 ++++++++++++++ .../cpu/spacetodepth_cpu_kernel.h | 84 +++++++++++++ tests/st/ops/cpu/test_depthtospace_op.py | 115 ++++++++++++++++++ tests/st/ops/cpu/test_spacetodepth_op.py | 109 +++++++++++++++++ 6 files changed, 573 insertions(+) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.h create mode 100644 tests/st/ops/cpu/test_depthtospace_op.py create mode 100644 tests/st/ops/cpu/test_spacetodepth_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.cc new file mode 100644 index 00000000000..b0aa95a5a00 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/depthtospace_cpu_kernel.h" + +#include + +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +template +void DepthToSpaceCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + CheckParam(kernel_node); + input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + block_size_ = AnfAlgo::GetNodeAttr(kernel_node, "block_size"); +} + +template +bool DepthToSpaceCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + size_t size = IntToSize(inputs[0]->size / sizeof(T)); + std::vector input_shape = input_shape_; + std::vector output_shape = output_shape_; + size_t block_size = block_size_; + size_t input_dimension = input_shape.size(); + size_t output_strides[3] = {1, 1, 1}; + + for (size_t i = input_dimension - 1; i >= 1; --i) { + for (size_t j = 0; j < i; ++j) { + output_strides[j] *= output_shape[i]; + } + } + + auto task = [&, input_addr, output_addr](size_t start, size_t end) { + std::vector output_pos_array(input_dimension, 0); + for (size_t i = start; i < end; ++i) { + size_t tmp_pos = i; + for (size_t j = 0; j < input_dimension - 1; ++j) { + output_pos_array[j] = tmp_pos / output_strides[j]; + tmp_pos %= output_strides[j]; + } + output_pos_array.back() = tmp_pos; + size_t input_pos = output_pos_array[0]; + input_pos = + (input_pos * input_shape[1]) + + (output_pos_array[1] + + (block_size * (output_pos_array[2] % block_size) + output_pos_array[3] % block_size) * output_shape[1]); + input_pos = (input_pos * input_shape[2]) + (output_pos_array[2] / block_size); + input_pos = (input_pos * input_shape[3]) + (output_pos_array[3] / block_size); + output_addr[i] = input_addr[input_pos]; + } + }; + + CPUKernelUtils::ParallelFor(task, size); + return true; +} + +template +void DepthToSpaceCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DepthToSpaceCPUKerrnel needs 1 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DepthToSpaceCPUKernel needs 1 output."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.h new file mode 100644 index 00000000000..57e4b8339fd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/depthtospace_cpu_kernel.h @@ -0,0 +1,85 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_ +#include +#include +#include + +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +namespace mindspore { +namespace kernel { +template +class DepthToSpaceCPUKernel : public CPUKernel { + public: + DepthToSpaceCPUKernel() = default; + ~DepthToSpaceCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CheckParam(const CNodePtr &kernel_node); + std::vector input_shape_; + std::vector output_shape_; + size_t block_size_; +}; + +MS_REG_CPU_KERNEL_T( + DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + DepthToSpaceCPUKernel, float); + +MS_REG_CPU_KERNEL_T( + DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + DepthToSpaceCPUKernel, float16); + +MS_REG_CPU_KERNEL_T(DepthToSpace, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + DepthToSpaceCPUKernel, int8_t); + +MS_REG_CPU_KERNEL_T(DepthToSpace, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + DepthToSpaceCPUKernel, int16_t); + +MS_REG_CPU_KERNEL_T(DepthToSpace, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + DepthToSpaceCPUKernel, int); + +MS_REG_CPU_KERNEL_T(DepthToSpace, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + DepthToSpaceCPUKernel, int64_t); + +MS_REG_CPU_KERNEL_T(DepthToSpace, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + DepthToSpaceCPUKernel, uint8_t); + +MS_REG_CPU_KERNEL_T(DepthToSpace, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + DepthToSpaceCPUKernel, uint16_t); + +MS_REG_CPU_KERNEL_T(DepthToSpace, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + DepthToSpaceCPUKernel, uint32_t); + +MS_REG_CPU_KERNEL_T(DepthToSpace, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + DepthToSpaceCPUKernel, uint64_t); + +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.cc new file mode 100644 index 00000000000..e28142b6599 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.h" + +#include + +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +template +void SpaceToDepthCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + CheckParam(kernel_node); + + input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + block_size_ = AnfAlgo::GetNodeAttr(kernel_node, "block_size"); +} + +template +bool SpaceToDepthCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + size_t size = IntToSize(inputs[0]->size / sizeof(T)); + + std::vector input_shape = input_shape_; + std::vector output_shape = output_shape_; + size_t block_size = block_size_; + size_t input_dimension = input_shape.size(); + size_t input_strides[3] = {1, 1, 1}; + + for (size_t i = input_dimension - 1; i >= 1; --i) { + for (size_t j = 0; j < i; ++j) { + input_strides[j] *= input_shape[i]; + } + } + + auto task = [&, input_addr, output_addr](size_t start, size_t end) { + std::vector input_pos_array(input_dimension, 0); + for (size_t i = start; i < end; ++i) { + size_t tmp_pos = i; + for (size_t j = 0; j < input_dimension - 1; ++j) { + input_pos_array[j] = tmp_pos / input_strides[j]; + tmp_pos %= input_strides[j]; + } + input_pos_array.back() = tmp_pos; + size_t output_pos = input_pos_array[0]; + output_pos = + (output_pos * output_shape[1]) + + (input_pos_array[1] + + (block_size * (input_pos_array[2] % block_size) + input_pos_array[3] % block_size) * input_shape[1]); + output_pos = (output_pos * output_shape[2]) + (input_pos_array[2] / block_size); + output_pos = (output_pos * output_shape[3]) + (input_pos_array[3] / block_size); + output_addr[output_pos] = input_addr[i]; + } + }; + + CPUKernelUtils::ParallelFor(task, size); + return true; +} + +template +void SpaceToDepthCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DepthToSpaceCPUKerrnel needs 1 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DepthToSpaceCPUKernel needs 1 output."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.h new file mode 100644 index 00000000000..6e12ff85371 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.h @@ -0,0 +1,84 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_ +#include +#include + +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +namespace mindspore { +namespace kernel { +template +class SpaceToDepthCPUKernel : public CPUKernel { + public: + SpaceToDepthCPUKernel() = default; + ~SpaceToDepthCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CheckParam(const CNodePtr &kernel_node); + std::vector input_shape_; + std::vector output_shape_; + size_t block_size_; +}; + +MS_REG_CPU_KERNEL_T( + SpaceToDepth, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SpaceToDepthCPUKernel, float); + +MS_REG_CPU_KERNEL_T( + SpaceToDepth, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SpaceToDepthCPUKernel, float16); + +MS_REG_CPU_KERNEL_T(SpaceToDepth, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + SpaceToDepthCPUKernel, int8_t); + +MS_REG_CPU_KERNEL_T(SpaceToDepth, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + SpaceToDepthCPUKernel, int16_t); + +MS_REG_CPU_KERNEL_T(SpaceToDepth, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SpaceToDepthCPUKernel, int); + +MS_REG_CPU_KERNEL_T(SpaceToDepth, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + SpaceToDepthCPUKernel, int64_t); + +MS_REG_CPU_KERNEL_T(SpaceToDepth, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + SpaceToDepthCPUKernel, uint8_t); + +MS_REG_CPU_KERNEL_T(SpaceToDepth, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + SpaceToDepthCPUKernel, uint16_t); + +MS_REG_CPU_KERNEL_T(SpaceToDepth, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + SpaceToDepthCPUKernel, uint32_t); + +MS_REG_CPU_KERNEL_T(SpaceToDepth, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + SpaceToDepthCPUKernel, uint64_t); + +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_ diff --git a/tests/st/ops/cpu/test_depthtospace_op.py b/tests/st/ops/cpu/test_depthtospace_op.py new file mode 100644 index 00000000000..a790152f1ce --- /dev/null +++ b/tests/st/ops/cpu/test_depthtospace_op.py @@ -0,0 +1,115 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.array_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') +class DepthToSpaceNet(nn.Cell): + def __init__(self, nptype, block_size=2, input_shape=(1, 12, 1, 1)): + super(DepthToSpaceNet, self).__init__() + self.DepthToSpace = P.DepthToSpace(block_size) + + input_size = 1 + for i in input_shape: + input_size = input_size*i + data_np = np.arange(input_size).reshape(input_shape).astype(nptype) + self.x1 = Parameter(initializer(Tensor(data_np), input_shape), name='x1') + + @ms_function + def construct(self): + y1 = self.DepthToSpace(self.x1) + return y1 + + +def DepthToSpace(nptype, block_size=2, input_shape=(1, 12, 1, 1)): + input_size = 1 + for i in input_shape: + input_size = input_size*i + expect = np.array([[[[0, 3], + [6, 9]], + [[1, 4], + [7, 10]], + [[2, 5], + [8, 11]]]]).astype(nptype) + dts = DepthToSpaceNet(nptype, block_size, input_shape) + output = dts() + assert (output.asnumpy() == expect).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_float32(): + DepthToSpace(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_float16(): + DepthToSpace(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_int32(): + DepthToSpace(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_int64(): + DepthToSpace(np.int64) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_int8(): + DepthToSpace(np.int8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_int16(): + DepthToSpace(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_uint8(): + DepthToSpace(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_uint16(): + DepthToSpace(np.uint16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_uint32(): + DepthToSpace(np.uint32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_uint64(): + DepthToSpace(np.uint64) diff --git a/tests/st/ops/cpu/test_spacetodepth_op.py b/tests/st/ops/cpu/test_spacetodepth_op.py new file mode 100644 index 00000000000..865cc4fd575 --- /dev/null +++ b/tests/st/ops/cpu/test_spacetodepth_op.py @@ -0,0 +1,109 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.array_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') +class SpaceToDepthNet(nn.Cell): + def __init__(self, nptype): + super(SpaceToDepthNet, self).__init__() + self.SpaceToDepth = P.SpaceToDepth(2) + + data_np = np.array([[[[0, 3], + [6, 9]], + [[1, 4], + [7, 10]], + [[2, 5], + [8, 11]]]]).astype(nptype) + self.data_np = data_np + self.x = Parameter(initializer(Tensor(self.data_np), (1, 3, 2, 2)), name='x') + + @ms_function + def construct(self): + return self.SpaceToDepth(self.x) + + +def SpaceToDepth(nptype): + expect = np.arange(12).reshape((1, 12, 1, 1)).astype(nptype) + std = SpaceToDepthNet(nptype) + output = std() + assert (output.asnumpy() == expect).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_float32(): + SpaceToDepth(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_float16(): + SpaceToDepth(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_int32(): + SpaceToDepth(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_int64(): + SpaceToDepth(np.int64) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_int8(): + SpaceToDepth(np.int8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_int16(): + SpaceToDepth(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_uint8(): + SpaceToDepth(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_uint16(): + SpaceToDepth(np.uint16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_uint32(): + SpaceToDepth(np.uint32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_uint64(): + SpaceToDepth(np.uint64)