From a82577c00e145d8af1dd3616aff2e454858aa038 Mon Sep 17 00:00:00 2001
From: huanghui
Date: Thu, 10 Sep 2020 14:11:16 +0800
Subject: [PATCH] Add CPU kernels: TensorAdd, Sub, Mul, Div

---
 .../cpu/arithmetic_cpu_kernel.cc              | 137 ++++++++++++++++++
 ...b_cpu_kernel.h => arithmetic_cpu_kernel.h} |  27 +++-
 .../cpu/arithmetic_self_cpu_kernel.cc         |  91 ++++++++++++
 .../cpu/arithmetic_self_cpu_kernel.h          |  50 +++++++
 .../backend/kernel_compiler/cpu/cpu_kernel.h  |   1 +
 .../kernel_compiler/cpu/sub_cpu_kernel.cc     |  90 ------------
 tests/st/ops/cpu/test_arithmetic_op.py        |  46 ++++++
 tests/st/ops/cpu/test_arithmetic_self_op.py   |  44 ++++++
 8 files changed, 388 insertions(+), 98 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
 rename mindspore/ccsrc/backend/kernel_compiler/cpu/{sub_cpu_kernel.h => arithmetic_cpu_kernel.h} (57%)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
 delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.cc
 create mode 100644 tests/st/ops/cpu/test_arithmetic_op.py
 create mode 100644 tests/st/ops/cpu/test_arithmetic_self_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
new file mode 100644
index 00000000000..fa5704057b7
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
@@ -0,0 +1,137 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h"
+#include <thread>
+#include <string>
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+template <typename T>
+void Add(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = input1[i] + (is_number ? *input2 : input2[i]);
+  }
+}
+
+template <typename T>
+void Sub(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = input1[i] - (is_number ? *input2 : input2[i]);
+  }
+}
+
+template <typename T>
+void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = input1[i] * (is_number ? *input2 : input2[i]);
+  }
+}
+
+template <typename T>
+void Div(const T *input1, const T *input2, T *out, size_t start, size_t end, bool is_number) {
+  for (size_t i = start; i < end; i++) {
+    auto div_number = is_number ? *input2 : input2[i];
+    if (div_number == 0) {
+      MS_LOG(EXCEPTION) << "Cannot divide by 0!";
+    }
+    out[i] = input1[i] / div_number;
+  }
+}
+}  // namespace
+
+void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (kernel_name == prim::kPrimTensorAdd->name()) {
+    operate_type_ = ADD;
+  } else if (kernel_name == prim::kPrimSub->name()) {
+    operate_type_ = SUB;
+  } else if (kernel_name == prim::kPrimMul->name()) {
+    operate_type_ = MUL;
+  } else if (kernel_name == "Div") {
+    operate_type_ = DIV;
+  }
+
+  auto shape0 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+  if (shape1.size() == 0) {
+    is_number_ = true;
+  } else {
+    is_number_ = false;
+    if (shape0.size() != shape1.size()) {
+      MS_LOG(EXCEPTION) << "Input0 and input1 must have the same shape";
+    }
+    for (size_t i = 0; i < shape0.size(); ++i) {
+      if (shape0[i] != shape1[i]) {
+        MS_LOG(EXCEPTION) << "Input0 and input1 must have the same shape";
+      }
+    }
+  }
+  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
+  if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) {
+    MS_LOG(EXCEPTION) << "Input0 and input1 must have the same data type";
+  }
+}
+
+bool ArithmeticCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                 const std::vector<AddressPtr> & /*workspace*/,
+                                 const std::vector<AddressPtr> &outputs) {
+  if (dtype_ == kNumberTypeInt32) {
+    LaunchKernel<int>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat32) {
+    LaunchKernel<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeInt64) {
+    LaunchKernel<int64_t>(inputs, outputs);
+  } else {
+    MS_LOG(EXCEPTION) << "Only support int32, int64, float32, but actual data type is " << TypeIdLabel(dtype_);
+  }
+  return true;
+}
+
+template <typename T>
+void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
+  T *input1 = reinterpret_cast<T *>(inputs[0]->addr);
+  T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
+  T *output = reinterpret_cast<T *>(outputs[0]->addr);
+  auto lens = inputs[0]->size / sizeof(T);
+  MS_LOG(INFO) << "lens=" << lens;
+
+  const size_t thread_num = 24;
+  std::vector<std::thread> threads;
+  threads.reserve(thread_num);
+  size_t start = 0;
+  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
+  while (start < lens) {
+    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
+    if (operate_type_ == ADD) {
+      threads.emplace_back(std::thread(Add<T>, input1, input2, output, start, end, is_number_));
+    } else if (operate_type_ == SUB) {
+      threads.emplace_back(std::thread(Sub<T>, input1, input2, output, start, end, is_number_));
+    } else if (operate_type_ == MUL) {
+      threads.emplace_back(std::thread(Mul<T>, input1, input2, output, start, end, is_number_));
+    } else if (operate_type_ == DIV) {
+      threads.emplace_back(std::thread(Div<T>, input1, input2, output, start, end, is_number_));
+    }
+    start += once_compute_size;
+  }
+  for (size_t i = 0; i < threads.size(); ++i) {
+    threads[i].join();
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
similarity index 57%
rename from mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h
rename to mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
index 53bef5f7dc6..20ea77e3502 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SUB_CPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SUB_CPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_CPU_KERNEL_H_
 #include <vector>
 #include <memory>
 #include "backend/kernel_compiler/cpu/cpu_kernel.h"
@@ -22,24 +22,35 @@
 namespace mindspore {
 namespace kernel {
-class SubCPUKernel : public CPUKernel {
+class ArithmeticCPUKernel : public CPUKernel {
  public:
-  SubCPUKernel() : offset_(0) {}
-  ~SubCPUKernel() override = default;
+  ArithmeticCPUKernel() = default;
+  ~ArithmeticCPUKernel() override = default;

   void InitKernel(const CNodePtr &kernel_node) override;

   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;

+  template <typename T>
+  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
  private:
-  int offset_;
+  bool is_number_{false};
+  OperateType operate_type_{ADD};
+  TypeId dtype_{kTypeUnknown};
 };

 MS_REG_CPU_KERNEL(
   Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  SubCPUKernel);
+  ArithmeticCPUKernel);
+MS_REG_CPU_KERNEL(
+  Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ArithmeticCPUKernel);
+MS_REG_CPU_KERNEL(
+  Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  ArithmeticCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore

-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SUB_CPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
new file mode 100644
index 00000000000..1a33cdf8d47
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
@@ -0,0 +1,91 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h"
+#include <cmath>
+#include <thread>
+#include <string>
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+template <typename T>
+void Square(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = in[i] * in[i];
+  }
+}
+
+template <typename T>
+void Sqrt(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = sqrtf(in[i]);
+  }
+}
+}  // namespace
+
+void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (kernel_name == prim::kPrimSquare->name()) {
+    operate_type_ = SQUARE;
+  } else if (kernel_name == prim::kPrimSqrt->name()) {
+    operate_type_ = SQRT;
+  }
+  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
+}
+
+bool ArithmeticSelfCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                     const std::vector<AddressPtr> & /*workspace*/,
+                                     const std::vector<AddressPtr> &outputs) {
+  if (dtype_ == kNumberTypeFloat32) {
+    LaunchKernel<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeInt32) {
+    LaunchKernel<int>(inputs, outputs);
+  } else {
+    MS_LOG(EXCEPTION) << "Only support float32, int32, but actual data type is " << TypeIdLabel(dtype_);
+  }
+  return true;
+}
+
+template <typename T>
+void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                           const std::vector<AddressPtr> &outputs) {
+  T *input = reinterpret_cast<T *>(inputs[0]->addr);
+  T *output = reinterpret_cast<T *>(outputs[0]->addr);
+  auto lens = inputs[0]->size / sizeof(T);
+  MS_LOG(INFO) << "lens=" << lens;
+
+  const size_t thread_num = 24;
+  std::vector<std::thread> threads;
+  threads.reserve(thread_num);
+  size_t start = 0;
+  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
+  while (start < lens) {
+    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
+    if (operate_type_ == SQUARE) {
+      threads.emplace_back(std::thread(Square<T>, input, output, start, end));
+    } else if (operate_type_ == SQRT) {
+      threads.emplace_back(std::thread(Sqrt<T>, input, output, start, end));
+    }
+    start += once_compute_size;
+  }
+  for (size_t i = 0; i < threads.size(); ++i) {
+    threads[i].join();
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
new file mode 100644
index 00000000000..3d3981f1792
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_SELF_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_SELF_CPU_KERNEL_H_
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class ArithmeticSelfCPUKernel : public CPUKernel {
+ public:
+  ArithmeticSelfCPUKernel() = default;
+  ~ArithmeticSelfCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+  template <typename T>
+  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+ private:
+  OperateType operate_type_{SQUARE};
+  TypeId dtype_{kTypeUnknown};
+};
+
+MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                  ArithmeticSelfCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_SELF_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
index 7e6fd8acd56..ad219c7a4c0 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
@@ -51,6 +51,7 @@ const char END[] = "end";
 const char SIZE[] = "size";
 const char USE_NESTEROV[] = "use_nesterov";
 const char GROUP[] = "group";
+enum OperateType { ADD = 0, SUB, MUL, DIV, SQUARE, SQRT };

 class CPUKernel : public kernel::KernelMod {
  public:
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.cc
deleted file mode 100644
index f99c53d577b..00000000000
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.cc
+++ /dev/null
@@ -1,90 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "backend/kernel_compiler/cpu/sub_cpu_kernel.h"
-#include <thread>
-#include <chrono>
-#include "runtime/device/cpu/cpu_device_address.h"
-
-namespace mindspore {
-namespace kernel {
-void SubCPUKernel::InitKernel(const CNodePtr &kernel_node) {
-  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-  if (shape.size() == 1) {
-    if (shape[0] != 1) {
-      MS_LOG(EXCEPTION) << "input 1 only support scalar";
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "input 1 only support scalar";
-  }
-}
-
-void sub_task(const int *in_addr, int *out_addr, size_t lens, int offset) {
-  for (size_t i = 0; i < lens; i++) {
-    out_addr[i] = in_addr[i] - offset;
-  }
-}
-
-bool SubCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
-                          const std::vector<AddressPtr> & /*workspace*/,
-                          const std::vector<AddressPtr> &outputs) {
-#if defined(_WIN32) || defined(_WIN64)
-  auto start_time = std::chrono::steady_clock::now();
-#else
-  struct timeval start_time, end_time;
-  (void)gettimeofday(&start_time, nullptr);
-#endif
-  auto input_addr = reinterpret_cast<int *>(inputs[0]->addr);
-  auto output_addr = reinterpret_cast<int *>(outputs[0]->addr);
-  offset_ = *reinterpret_cast<int *>(inputs[1]->addr);
-  MS_LOG(INFO) << "offset: " << offset_;
-  auto lens = inputs[0]->size / sizeof(int);
-  if (lens < 10000) {
-    for (size_t i = 0; i < lens; i++) {
-      output_addr[i] = input_addr[i] - offset_;
-    }
-  } else {
-    const size_t thread_num = 4;
-    std::thread threads[4];
-    size_t process_lens = (lens + thread_num - 1) / thread_num;
-    size_t process_offset = 0;
-    for (size_t i = 0; i < thread_num; i++) {
-      threads[i] =
-        std::thread(sub_task, input_addr + process_offset, output_addr + process_offset, process_lens, offset_);
-      if (process_offset + process_lens > lens) {
-        process_lens = lens - process_offset;
-        process_offset = lens;
-      } else {
-        process_offset += process_lens;
-      }
-    }
-    for (size_t i = 0; i < thread_num; i++) {
-      threads[i].join();
-    }
-  }
-#if defined(_WIN32) || defined(_WIN64)
-  auto end_time = std::chrono::steady_clock::now();
-  std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
-  MS_LOG(INFO) << "SubscaleCPUKernel, used time: " << cost.count() << " us";
-#else
-  (void)gettimeofday(&end_time, nullptr);
-  uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
-  time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
-  MS_LOG(INFO) << "SubCPUKernel, used time: " << time << " us";
-#endif
-  return true;
-}
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/tests/st/ops/cpu/test_arithmetic_op.py b/tests/st/ops/cpu/test_arithmetic_op.py
new file mode 100644
index 00000000000..d3b77843b2a
--- /dev/null
+++ b/tests/st/ops/cpu/test_arithmetic_op.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+
+
+class SubNet(nn.Cell):
+    def __init__(self):
+        super(SubNet, self).__init__()
+        self.sub = P.Sub()
+
+    def construct(self, x, y):
+        return self.sub(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_sub():
+    x = np.ones([2, 3, 4, 4]).astype(np.int32)
+    y = 1
+    net = SubNet()
+    output = net(Tensor(x), Tensor(y, mindspore.int32))
+    expect_output = np.zeros([2, 3, 4, 4]).astype(np.int32)
+    print(output)
+    assert np.all(output.asnumpy() == expect_output)
diff --git a/tests/st/ops/cpu/test_arithmetic_self_op.py b/tests/st/ops/cpu/test_arithmetic_self_op.py
new file mode 100644
index 00000000000..81ae6a9bc45
--- /dev/null
+++ b/tests/st/ops/cpu/test_arithmetic_self_op.py
@@ -0,0 +1,44 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+
+
+class SquareNet(nn.Cell):
+    def __init__(self):
+        super(SquareNet, self).__init__()
+        self.square = P.Square()
+
+    def construct(self, x):
+        return self.square(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_square():
+    x = np.array([1, 2, 3]).astype(np.float32)
+    net = SquareNet()
+    output = net(Tensor(x))
+    expect_output = np.array([1, 4, 9]).astype(np.float32)
+    print(output)
+    assert np.all(output.asnumpy() == expect_output)
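
Note: the header above registers the Sub kernel for int32, float32, and int64, but the new test only exercises the int32 tensor-minus-scalar case. A float32 tensor-tensor Sub test in the same style as test_sub could look like the sketch below. It is illustrative only and not part of the patch; the names FloatSubNet and test_sub_float32 are placeholders.

# Illustrative sketch, not part of the patch: a float32 Sub test in the same
# style as test_sub above. FloatSubNet and test_sub_float32 are placeholder names.
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class FloatSubNet(nn.Cell):
    def __init__(self):
        super(FloatSubNet, self).__init__()
        self.sub = P.Sub()

    def construct(self, x, y):
        return self.sub(x, y)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sub_float32():
    # Same-shape inputs take the elementwise (non-scalar) path in the kernel.
    x = np.array([3.0, 2.0, 1.0]).astype(np.float32)
    y = np.array([1.0, 2.0, 3.0]).astype(np.float32)
    output = FloatSubNet()(Tensor(x), Tensor(y))
    expect_output = np.array([2.0, 0.0, -2.0]).astype(np.float32)
    assert np.all(output.asnumpy() == expect_output)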