From e8794ce0d794152e60ebc7932ec8cb093156ffd1 Mon Sep 17 00:00:00 2001
From: TFBunny
Date: Thu, 22 Apr 2021 16:59:23 -0400
Subject: [PATCH] add gpu kernel HSigmoid and HSigmoidGrad

---
 .../gpu/cuda_impl/hsigmoid_impl.cu            |  51 ++++++++
 .../gpu/cuda_impl/hsigmoid_impl.cuh           |  29 +++++
 .../gpu/nn/hsigmoid_gpu_kernel.cc             |  26 ++++
 .../gpu/nn/hsigmoid_gpu_kernel.h              |  88 ++++++++++++++
 .../gpu/nn/hsigmoid_grad_gpu_kernel.cc        |  30 +++++
 .../gpu/nn/hsigmoid_grad_gpu_kernel.h         |  91 ++++++++++++++
 mindspore/core/abstract/infer_functions.h     |   4 +
 mindspore/core/abstract/prim_nn.cc            |  14 +++
 .../core/abstract/primitive_infer_map.cc      |   2 +
 mindspore/core/base/core_ops.h                |   2 +
 tests/st/ops/gpu/test_hsigmoid_op.py          | 111 ++++++++++++++++++
 11 files changed, 448 insertions(+)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.h
 create mode 100644 tests/st/ops/gpu/test_hsigmoid_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cu
new file mode 100644
index 00000000000..70032336fb7
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cu
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh"
+
+template <typename T>
+__global__ void HsigmoidKernel(size_t size, const T *input, T *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    T value = (input[pos] + static_cast<T>(3)) / static_cast<T>(6);
+    value = value > static_cast<T>(1) ? static_cast<T>(1) : value;
+    output[pos] = value > static_cast<T>(0) ? value : static_cast<T>(0);
+  }
+}
+
+template <typename T>
+__global__ void HsigmoidGradKernel(size_t size, const T *dout, T *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    T value = dout[pos] / static_cast<T>(6);
+    value = value > static_cast<T>(1) ? static_cast<T>(0) : value;
+    output[pos] = value > static_cast<T>(0) ? value : static_cast<T>(0);
+  }
+}
+
+template <typename T>
+void CalHSigmoid(const size_t &size, const T *input, T *output, cudaStream_t cuda_stream) {
+  HsigmoidKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
+}
+
+template <typename T>
+void CalHSigmoidGrad(const size_t &size, const T *dout, T *output, cudaStream_t cuda_stream) {
+  HsigmoidGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dout, output);
+}
+
+template void CalHSigmoid<half>(const size_t &size, const half *input, half *output, cudaStream_t cuda_stream);
+template void CalHSigmoid<float>(const size_t &size, const float *input, float *output, cudaStream_t cuda_stream);
+
+template void CalHSigmoidGrad<half>(const size_t &size, const half *dout, half *output, cudaStream_t cuda_stream);
+template void CalHSigmoidGrad<float>(const size_t &size, const float *dout, float *output, cudaStream_t cuda_stream);
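A NumPy mirror of the two kernels above may help reviewers check the clamping logic against the expected values in tests/st/ops/gpu/test_hsigmoid_op.py. This is a reviewer-side sketch, not part of the patch; the helper names here are made up.

```python
import numpy as np

def hsigmoid_ref(x):
    # Mirrors HsigmoidKernel: clamp((x + 3) / 6, 0, 1).
    return np.clip((x + 3.0) / 6.0, 0.0, 1.0)

def hsigmoid_grad_ref(dout):
    # Mirrors HsigmoidGradKernel exactly as written: dout / 6, zeroed when the
    # scaled value falls above 1 or at/below 0. Note the mask is derived from
    # dout itself; the forward input x is not consulted by this kernel.
    value = dout / 6.0
    value = np.where(value > 1.0, 0.0, value)
    return np.where(value > 0.0, value, 0.0)

x = np.array([-1, -2, 0, 2, 1], dtype=np.float32)
print(hsigmoid_ref(x))          # ~[0.33333334 0.16666667 0.5 0.8333333 0.6666667]

sens = np.array([-1.45, -2.63, 0.34, 6.43, 34.6], dtype=np.float32)
print(hsigmoid_grad_ref(sens))  # ~[0. 0. 0.05666667 0. 0.]
```

Both printed vectors match the `expect` arrays used by the new test file.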
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh
new file mode 100644
index 00000000000..99a3e377f5a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_HSIGMOID_IMPL_CUH_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_HSIGMOID_IMPL_CUH_
+
+#include <cuda_runtime.h>
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+void CalHSigmoid(const size_t &size, const T *input, T *output, cudaStream_t cuda_stream);
+
+template <typename T>
+void CalHSigmoidGrad(const size_t &size, const T *dout, T *output, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_HSIGMOID_IMPL_CUH_
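One more reviewer-side observation: the clamp exposed by CalHSigmoid is the standard hard-sigmoid, which can equivalently be written as relu6(x + 3) / 6. A small NumPy sketch of that identity (not part of the patch):

```python
import numpy as np

def relu6(x):
    # relu6: clamp into [0, 6].
    return np.minimum(np.maximum(x, 0.0), 6.0)

x = np.linspace(-5.0, 5.0, 101, dtype=np.float32)
clamp_form = np.clip((x + 3.0) / 6.0, 0.0, 1.0)  # form used by HsigmoidKernel
relu6_form = relu6(x + 3.0) / 6.0                # equivalent ReLU6 form
assert np.allclose(clamp_form, relu6_form)
```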
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_gpu_kernel.cc
new file mode 100644
index 00000000000..c099ce66a01
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/hsigmoid_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      HSigmoidKernel, float)
+MS_REG_GPU_KERNEL_ONE(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      HSigmoidKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
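With the registrations above, the new kernel is reachable from Python through P.HSigmoid on the GPU target. A minimal forward-pass sketch, mirroring the Net cell defined in the new test file (assumes a GPU build of MindSpore):

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.hsigmoid = P.HSigmoid()

    def construct(self, x):
        return self.hsigmoid(x)


x = Tensor(np.array([-1, -2, 0, 2, 1]).astype(np.float32))
print(Net()(x))  # ~[0.33333334 0.16666667 0.5 0.8333333 0.6666667]
```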
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_gpu_kernel.h
new file mode 100644
index 00000000000..3b5c7f1771d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_gpu_kernel.h
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class HSigmoidKernel : public GpuKernel {
+ public:
+  HSigmoidKernel() { ResetResource(); }
+  ~HSigmoidKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    CalHSigmoid(input_size_, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but HSigmoid needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but HSigmoid has 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    input_size_ = 1;
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      input_size_ *= input_shape[i];
+    }
+    InitSizeLists();
+    return true;
+  }
+
+  void ResetResource() noexcept override {
+    input_size_ = 1;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    output_size_list_.push_back(input_size_ * sizeof(T));
+  }
+
+ private:
+  size_t input_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.cc
new file mode 100644
index 00000000000..70f0cb5e758
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  HSigmoidGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  HSigmoidGradKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  HSigmoidGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  HSigmoidGradKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
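The backward kernel registered above is driven through GradOperation. A sketch mirroring the Grad wrapper in the new test file; note that HSigmoidGrad is registered with two inputs even though the CUDA kernel reads only the incoming gradient:

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.composite import GradOperation

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.hsigmoid = P.HSigmoid()

    def construct(self, x):
        return self.hsigmoid(x)


class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        # sens_param=True lets us feed an explicit output gradient (sens).
        self.grad = GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, x, dout):
        return self.grad(self.network)(x, dout)


x = Tensor(np.array([-1, -2, 0, 2, 1]).astype(np.float32))
sens = Tensor(np.array([-1.45, -2.63, 0.34, 6.43, 34.6]).astype(np.float32))
print(Grad(Net())(x, sens)[0])  # ~[0. 0. 0.05666667 0. 0.]
```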
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.h
new file mode 100644
index 00000000000..75cd061496c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.h
@@ -0,0 +1,91 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GRAD_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class HSigmoidGradKernel : public GpuKernel {
+ public:
+  HSigmoidGradKernel() { ResetResource(); }
+  ~HSigmoidGradKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *dout = GetDeviceAddress<T>(inputs, 0);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    CalHSigmoidGrad(input_size_, dout, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but HSigmoidGrad needs 2 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but HSigmoidGrad has 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    input_size_ = 1;
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      input_size_ *= input_shape[i];
+    }
+    InitSizeLists();
+    return true;
+  }
+
+  void ResetResource() noexcept override {
+    input_size_ = 1;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    // the second input is not read by the kernel, but its buffer still needs to be sized
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    output_size_list_.push_back(input_size_ * sizeof(T));
+  }
+
+ private:
+  size_t input_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GRAD_GPU_KERNEL_H_
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index 4a645007b23..6608fa6152a 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -59,6 +59,10 @@ AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitiveP
                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc
index a3c589866c1..6a9b7b3511c 100644
--- a/mindspore/core/abstract/prim_nn.cc
+++ b/mindspore/core/abstract/prim_nn.cc
@@ -416,6 +416,20 @@ AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &pri
   return args_spec_list[0]->Broaden();
 }
 
+AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor.
+  CheckArgsSize(primitive->name(), args_spec_list, 1);
+  return args_spec_list[0]->Broaden();
+}
+
+AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors.
+  CheckArgsSize(primitive->name(), args_spec_list, 2);
+  return args_spec_list[1]->Broaden();
+}
+
 AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tensor.
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index 2728f855e9e..46aa76460ed 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -132,6 +132,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, nullptr, true}},
     {prim::kPrimSGD, {InferImplSGD, nullptr, true}},
     {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, nullptr, true}},
+    {prim::kPrimHSigmoid, {InferImplHSigmoid, nullptr, true}},
+    {prim::kPrimHSigmoidGrad, {InferImplHSigmoidGrad, nullptr, true}},
     // Others
     {prim::kPrimIdentity, {InferImplIdentity, nullptr, true}},
     {prim::kPrimLoad, {InferImplLoad, nullptr, true}},
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index ee022322957..e88d2c77216 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -521,6 +521,8 @@ inline const PrimitivePtr kPrimSubFusion = std::make_shared<Primitive>("SubFusio
 inline const PrimitivePtr kPrimMulFusion = std::make_shared<Primitive>("MulFusion");
 inline const PrimitivePtr kPrimSigmoid = std::make_shared<Primitive>("Sigmoid");
 inline const PrimitivePtr kPrimSigmoidGrad = std::make_shared<Primitive>("SigmoidGrad");
+inline const PrimitivePtr kPrimHSigmoid = std::make_shared<Primitive>("HSigmoid");
+inline const PrimitivePtr kPrimHSigmoidGrad = std::make_shared<Primitive>("HSigmoidGrad");
 inline const PrimitivePtr kPrimClip = std::make_shared<Primitive>("Clip");
 inline const PrimitivePtr kPrimHardTanh = std::make_shared<Primitive>("HardTanh");
 inline const PrimitivePtr kPrimDepthWiseConv2DTransposeFusion =
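A note on the infer entries above: both implementations simply broaden an input abstract, so HSigmoid is shape- and type-preserving, and HSigmoidGrad takes its output shape and dtype from its second argument. A quick reviewer-side check of the forward property, calling the primitive directly in PyNative mode (a sketch; assumes a GPU build of MindSpore):

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

x = Tensor(np.ones((2, 3, 4)).astype(np.float16))
y = P.HSigmoid()(x)  # dispatches through the infer functions registered above
assert y.shape == x.shape and y.dtype == x.dtype
```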
diff --git a/tests/st/ops/gpu/test_hsigmoid_op.py b/tests/st/ops/gpu/test_hsigmoid_op.py
new file mode 100644
index 00000000000..435ca80d3c1
--- /dev/null
+++ b/tests/st/ops/gpu/test_hsigmoid_op.py
@@ -0,0 +1,111 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.composite import GradOperation
+from mindspore.ops.operations import _inner_ops as inner
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, input_x, dout):
+        return self.grad(self.network)(input_x, dout)
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.HSigmoid = P.HSigmoid()
+
+    def construct(self, x):
+        return self.HSigmoid(x)
+
+
+class DynamicNet(nn.Cell):
+    def __init__(self):
+        super(DynamicNet, self).__init__()
+        self.HSigmoid = P.HSigmoid()
+        self.d = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x):
+        x = self.d(x)
+        return self.HSigmoid(x)
+
+
+def generate_testcases(nptype):
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x = np.array([-1, -2, 0, 2, 1]).astype(nptype)
+    net = Net()
+    output = net(Tensor(x))
+    expect = np.array([0.33333334, 0.16666667, 0.5, 0.8333333, 0.6666667]).astype(nptype)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+
+    sens = np.array([-1.45, -2.63, 0.34, 6.43, 34.6]).astype(nptype)
+    backward_net = Grad(Net())
+    output = backward_net(Tensor(x), Tensor(sens))
+    expect = np.array([0, 0, 5.66666685e-02, 0, 0]).astype(nptype)
+    np.testing.assert_almost_equal(output[0].asnumpy(), expect)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    x = np.array([-1, -2, 0, 2, 1]).astype(nptype)
+    net = Net()
+    output = net(Tensor(x))
+    expect = np.array([0.33333334, 0.16666667, 0.5, 0.8333333, 0.6666667]).astype(nptype)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+
+    sens = np.array([-1.45, -2.63, 0.34, 6.43, 34.6]).astype(nptype)
+    backward_net = Grad(Net())
+    output = backward_net(Tensor(x), Tensor(sens))
+    expect = np.array([0, 0, 5.66666685e-02, 0, 0]).astype(nptype)
+    np.testing.assert_almost_equal(output[0].asnumpy(), expect)
+
+
+def generate_dynamic_testcase(nptype):
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x = np.array([-1, -2, 0, 2, 1]).astype(nptype)
+    net = DynamicNet()
+    output = net(Tensor(x))
+    expect = np.array([0.33333334, 0.16666667, 0.5, 0.8333333, 0.6666667]).astype(nptype)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_hsigmoid_dynamic_float32():
+    generate_dynamic_testcase(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_hsigmoid_float32():
+    generate_testcases(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_hsigmoid_float16():
+    generate_testcases(np.float16)