From 3bc94b97d725ab49a293dd3ce5fe4004a13b5586 Mon Sep 17 00:00:00 2001
From: lishixing3
Date: Mon, 7 Dec 2020 15:10:01 +0800
Subject: [PATCH] fix some liness

---
 .../kernel_compiler/cpu/dropout_cpu_kernel.cc |   1 +
 .../kernel_compiler/cpu/unpack_cpu_kernel.cc  | 133 +++++++++++
 .../kernel_compiler/cpu/unpack_cpu_kernel.h   |  88 +++++++
 mindspore/ops/operations/array_ops.py         |   2 +-
 tests/st/ops/cpu/test_unpack_op.py            | 215 ++++++++++++++++++
 5 files changed, 438 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/unpack_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/unpack_cpu_kernel.h
 create mode 100644 tests/st/ops/cpu/test_unpack_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.cc
index cab48c08a3d..61542935453 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/dropout_cpu_kernel.cc
@@ -66,6 +66,7 @@ void DropoutCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const
 }
 
 void DropoutCPUKernel::CheckParam(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
   if (input_num != 1) {
     MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DropoutCPUKernel needs 1 input.";
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unpack_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/unpack_cpu_kernel.cc
new file mode 100644
index 00000000000..53c6882828f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unpack_cpu_kernel.cc
@@ -0,0 +1,133 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/cpu/unpack_cpu_kernel.h"
+#include <algorithm>
+#include <thread>
+#include <vector>
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+// Smallest number of elements that is worth handing to a separate worker thread.
+constexpr size_t kElementsPerThread = 128;
+}  // namespace
+
+// Caches the element counts derived from the "axis"/"num" attributes and the inferred input shape.
+template <typename T>
+void UnpackCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
+  CheckParam(kernel_node);
+  int64_t axis_tmp = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis");
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  if (axis_tmp < 0) {
+    // A negative axis counts back from the last dimension.
+    axis_tmp += SizeToLong(input_shape.size());
+  }
+  const size_t axis = LongToSize(axis_tmp);
+  output_num_ = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "num"));
+  // Restart both products from 1 so that a repeated InitKernel call does not compound the previous values.
+  input_size_ = 1;
+  dims_after_axis_ = 1;
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    input_size_ *= input_shape[i];
+    if (i > axis) {
+      dims_after_axis_ *= input_shape[i];
+    }
+  }
+  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
+}
+
+// Adds one workspace buffer large enough to hold a T* for each of the output_num_ outputs.
+template <typename T>
+void UnpackCPUKernel<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  workspace_size_list_.emplace_back(sizeof(T *) * output_num_);
+}
+
+// Runs LaunchKernel and always returns true; invalid arguments are reported by the MS_LOG(EXCEPTION) checks.
+template <typename T>
+bool UnpackCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                const std::vector<kernel::AddressPtr> &workspace,
+                                const std::vector<kernel::AddressPtr> &outputs) {
+  LaunchKernel(inputs, workspace, outputs);
+  return true;
+}
+
+// Copies every input element to its output buffer, splitting the flat element range across worker threads.
+template <typename T>
+void UnpackCPUKernel<T>::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                      const std::vector<kernel::AddressPtr> &workspace,
+                                      const std::vector<kernel::AddressPtr> &outputs) {
+  if (inputs.empty() || workspace.empty()) {
+    MS_LOG(EXCEPTION) << "UnpackCPUKernel needs 1 input and 1 workspace, but got " << inputs.size() << " and "
+                      << workspace.size() << ".";
+  }
+  if (outputs.size() != output_num_) {
+    // The workspace only has room for output_num_ pointers, and UnpackResult indexes all of them.
+    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but attribute 'num' is " << output_num_ << ".";
+  }
+  input_ = reinterpret_cast<T *>(inputs[0]->addr);
+  MS_EXCEPTION_IF_NULL(input_);
+  outputs_host_ = reinterpret_cast<T **>(workspace[0]->addr);
+  MS_EXCEPTION_IF_NULL(outputs_host_);
+  for (size_t i = 0; i < outputs.size(); i++) {
+    outputs_host_[i] = reinterpret_cast<T *>(outputs[i]->addr);
+    MS_EXCEPTION_IF_NULL(outputs_host_[i]);
+  }
+  if (input_size_ == 0) {
+    // Nothing to copy for an empty tensor.
+    return;
+  }
+  // hardware_concurrency() may return 0 when the value is not computable; always keep at least one worker.
+  const size_t max_thread_num = std::max<size_t>(1, std::thread::hardware_concurrency());
+  // One thread per kElementsPerThread elements (rounded up), capped at max_thread_num.
+  const size_t thread_num = std::min(max_thread_num, (input_size_ + kElementsPerThread - 1) / kElementsPerThread);
+  const size_t one_gap = (input_size_ + thread_num - 1) / thread_num;
+  std::vector<std::thread> threads;
+  threads.reserve(thread_num);
+  for (size_t start = 0; start < input_size_; start += one_gap) {
+    const size_t end = std::min(start + one_gap, input_size_);
+    threads.emplace_back(&UnpackCPUKernel::UnpackResult, this, start, end);
+  }
+  for (auto &thread : threads) {
+    thread.join();
+  }
+}
+
+// Scatters the input elements with flat index in [start, end) into the output buffers.
+template <typename T>
+void UnpackCPUKernel<T>::UnpackResult(const size_t start, const size_t end) {
+  // Number of input elements covered by one step of the dimensions before the axis.
+  const size_t number_of_reset = output_num_ * dims_after_axis_;
+  for (size_t i = start; i < end; ++i) {
+    size_t output_index = (i / dims_after_axis_) % output_num_;
+    size_t tensor_index = i / number_of_reset * dims_after_axis_ + i % dims_after_axis_;
+    outputs_host_[output_index][tensor_index] = input_[i];
+  }
+}
+
+// Raises if the node is null or does not have exactly one input tensor.
+template <typename T>
+void UnpackCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 1) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but UnpackCPUKernel needs 1 input.";
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unpack_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/unpack_cpu_kernel.h
new file mode 100644
index 00000000000..3d080b00638
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unpack_cpu_kernel.h
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2020 Huawei Technologies
Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNPACK_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNPACK_CPU_KERNEL_H_
+#include <cmath>
+#include <cstddef>
+#include <memory>
+#include <thread>
+#include <vector>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+// CPU kernel for Unpack: splits the input along the "axis" attribute into "num" outputs of element type T.
+template <typename T>
+class UnpackCPUKernel : public CPUKernel {
+ public:
+  UnpackCPUKernel() = default;
+  ~UnpackCPUKernel() override = default;
+  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+  void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+                    const std::vector<kernel::AddressPtr> &outputs);
+  void InitInputOutputSize(const CNodePtr &kernel_node) override;
+
+ protected:
+  virtual void CheckParam(const CNodePtr &kernel_node);
+  virtual void UnpackResult(const size_t start, const size_t end);
+  size_t input_size_{1};        // Product of all input dimensions, i.e. the total element count.
+  size_t output_num_{0};        // Value of the "num" attribute: how many output tensors are produced.
+  size_t dims_after_axis_{1};   // Product of the input dimensions that follow the unpack axis.
+  T *input_{nullptr};           // Input buffer of the current launch.
+  T **outputs_host_{nullptr};   // Workspace array holding one output buffer pointer per output.
+  TypeId dtype_{kTypeUnknown};  // Inferred data type of the input.
+};
+MS_REG_CPU_KERNEL_T(Unpack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                    UnpackCPUKernel, int8_t);
+MS_REG_CPU_KERNEL_T(Unpack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+                    UnpackCPUKernel, int16_t);
+MS_REG_CPU_KERNEL_T(Unpack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                    UnpackCPUKernel, int);
+MS_REG_CPU_KERNEL_T(Unpack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                    UnpackCPUKernel, int64_t);
+MS_REG_CPU_KERNEL_T(Unpack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                    UnpackCPUKernel, bool);
+MS_REG_CPU_KERNEL_T(Unpack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+                    UnpackCPUKernel, uint8_t);
+MS_REG_CPU_KERNEL_T(Unpack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
+                    UnpackCPUKernel, uint16_t);
+MS_REG_CPU_KERNEL_T(Unpack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
+                    UnpackCPUKernel, uint32_t);
+MS_REG_CPU_KERNEL_T(Unpack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
+                    UnpackCPUKernel, uint64_t);
+MS_REG_CPU_KERNEL_T(
+  Unpack, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  UnpackCPUKernel, float16);
+MS_REG_CPU_KERNEL_T(
+  Unpack, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  UnpackCPUKernel, float);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNPACK_CPU_KERNEL_H_
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index b9c4cb7cc09..fb59b137a06 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -2308,7 +2308,7 @@ class Unpack(PrimitiveWithInfer):
         ValueError: If axis is out of the range
[-len(input_x.shape), len(input_x.shape)).
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> unpack = ops.Unpack()
diff --git a/tests/st/ops/cpu/test_unpack_op.py b/tests/st/ops/cpu/test_unpack_op.py
new file mode 100644
index 00000000000..addda203d8f
--- /dev/null
+++ b/tests/st/ops/cpu/test_unpack_op.py
@@ -0,0 +1,215 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""System tests for the Unpack operator on the CPU backend.
+
+Every covered dtype is checked twice: once through a Cell compiled in graph
+mode and once by calling the primitive directly in PyNative mode.
+"""
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore.ops.operations.array_ops as P
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+
+
+def _make_input(nptype):
+    """Return the shared test input of shape (3, 3, 2, 2, 2) cast to ``nptype``.
+
+    Along axis 3, index 0 holds only zeros and index 1 holds the consecutive
+    values -2..33, so unpacking along that axis separates the two.
+    """
+    return np.array([[[[[0, 0],
+                        [-2, -1]],
+                       [[0, 0],
+                        [0, 1]]],
+                      [[[0, 0],
+                        [2, 3]],
+                       [[0, 0],
+                        [4, 5]]],
+                      [[[0, 0],
+                        [6, 7]],
+                       [[0, 0],
+                        [8, 9]]]],
+                     [[[[0, 0],
+                        [10, 11]],
+                       [[0, 0],
+                        [12, 13]]],
+                      [[[0, 0],
+                        [14, 15]],
+                       [[0, 0],
+                        [16, 17]]],
+                      [[[0, 0],
+                        [18, 19]],
+                       [[0, 0],
+                        [20, 21]]]],
+                     [[[[0, 0],
+                        [22, 23]],
+                       [[0, 0],
+                        [24, 25]]],
+                      [[[0, 0],
+                        [26, 27]],
+                       [[0, 0],
+                        [28, 29]]],
+                      [[[0, 0],
+                        [30, 31]],
+                       [[0, 0],
+                        [32, 33]]]]]).astype(nptype)
+
+
+def _expected_outputs(nptype):
+    """Return the two (3, 3, 2, 2) tensors expected from unpacking the shared input along axis 3."""
+    return (np.reshape(np.array([0] * 36).astype(nptype), (3, 3, 2, 2)),
+            np.arange(-2, 34, 1).reshape(3, 3, 2, 2).astype(nptype))
+
+
+def _check_outputs(output, nptype):
+    """Assert that each unpacked tensor equals its expected value element-wise."""
+    for i, exp in enumerate(_expected_outputs(nptype)):
+        assert (output[i].asnumpy() == exp).all()
+
+
+class Net(nn.Cell):
+    """Cell that unpacks a fixed parameter tensor along axis 3."""
+
+    def __init__(self, nptype):
+        super(Net, self).__init__()
+
+        self.unpack = P.Unpack(axis=3)
+        self.data_np = _make_input(nptype)
+        # Keep the input as a Parameter so that construct() needs no arguments.
+        self.x1 = Parameter(initializer(Tensor(self.data_np), [3, 3, 2, 2, 2]), name='x1')
+
+    @ms_function
+    def construct(self):
+        return self.unpack(self.x1)
+
+
+def unpack(nptype):
+    """Run Unpack on CPU in graph mode and compare the outputs with the expected slices."""
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    unpack_ = Net(nptype)
+    output = unpack_()
+    _check_outputs(output, nptype)
+
+
+def unpack_pynative(nptype):
+    """Run Unpack on CPU in PyNative mode and compare the outputs with the expected slices."""
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
+    x1 = Tensor(_make_input(nptype))
+    output = P.Unpack(axis=3)(x1)
+    _check_outputs(output, nptype)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_graph_float32():
+    """Unpack a float32 tensor in graph mode."""
+    unpack(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_graph_float16():
+    """Unpack a float16 tensor in graph mode."""
+    unpack(np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_graph_int32():
+    """Unpack an int32 tensor in graph mode."""
+    unpack(np.int32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_graph_int16():
+    """Unpack an int16 tensor in graph mode."""
+    unpack(np.int16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_graph_uint8():
+    """Unpack a uint8 tensor in graph mode."""
+    unpack(np.uint8)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_graph_bool():
+    """Unpack a bool tensor in graph mode."""
+    unpack(np.bool_)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_pynative_float32():
+    """Unpack a float32 tensor in PyNative mode."""
+    unpack_pynative(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_pynative_float16():
+    """Unpack a float16 tensor in PyNative mode."""
+    unpack_pynative(np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_pynative_int32():
+    """Unpack an int32 tensor in PyNative mode."""
+    unpack_pynative(np.int32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_pynative_int16():
+    """Unpack an int16 tensor in PyNative mode."""
+    unpack_pynative(np.int16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_pynative_uint8():
+    """Unpack a uint8 tensor in PyNative mode."""
+    unpack_pynative(np.uint8)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_unpack_pynative_bool():
+    """Unpack a bool tensor in PyNative mode."""
+    unpack_pynative(np.bool_)