From 0bdf08a90ecf90784c6beb849f63319f4d0ba41b Mon Sep 17 00:00:00 2001
From: buxue
Date: Mon, 2 Aug 2021 15:33:06 +0800
Subject: [PATCH] develop softplus for cpu

---
 .../backend/kernel_compiler/cpu/cpu_kernel.h  |   2 +-
 .../cpu/eltwise_grad_cpu_kernel.cc            |  97 +++++++++-------
 .../cpu/eltwise_grad_cpu_kernel.h             |   7 +-
 .../cpu/mkldnn/eltwise_cpu_kernel.cc          |  52 +++++----
 .../cpu/mkldnn/eltwise_cpu_kernel.h           |   4 +-
 .../cpu/nnacl/fp32_grad/activation_grad.c     |  27 ++++-
 .../cpu/nnacl/fp32_grad/activation_grad.h     |   3 +-
 mindspore/core/base/core_ops.h                |   4 +
 tests/st/ops/cpu/test_softplus_grad_op.py     |  78 +++++++++++++
 tests/st/ops/cpu/test_softplus_op.py          | 107 ++++++++++++++++++
 10 files changed, 316 insertions(+), 65 deletions(-)
 create mode 100644 tests/st/ops/cpu/test_softplus_grad_op.py
 create mode 100644 tests/st/ops/cpu/test_softplus_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
index 7241f6163cf..b85568f505e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
index 394fcbbd786..926d8e172ef 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,8 +13,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
 */
-#include <cmath>
+
 #include "backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h"
+#include <string>
+#include <map>
 #include "common/thread_pool.h"
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "nnacl/fp32_grad/activation_grad.h"
@@ -25,50 +27,50 @@ namespace mindspore {
 namespace kernel {
 template <typename T>
 void EltWiseGradCPUKernel<T>::ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
-  if constexpr (std::is_same_v<T, float>) {
-    int ret = ::ReluGrad(input1 + start, input2 + start, end - start, out + start);
-    if (ret == NNACL_ERR) {
-      MS_LOG(EXCEPTION) << "ReLUGrad failed.";
-    }
-  } else {
+  if constexpr (!std::is_same<T, float>::value) {
     MS_LOG(EXCEPTION) << "ReLUGrad only support float";
   }
+
+  int ret = ::ReluGrad(input1 + start, input2 + start, end - start, out + start);
+  if (ret == NNACL_ERR) {
+    MS_LOG(EXCEPTION) << "ReLUGrad execute failed.";
+  }
 }
 
 template <typename T>
 void EltWiseGradCPUKernel<T>::ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
-  if constexpr (std::is_same_v<T, float>) {
-    int ret = ::Relu6Grad(input1 + start, input2 + start, end - start, out + start);
-    if (ret == NNACL_ERR) {
-      MS_LOG(EXCEPTION) << "ReLU6Grad failed.";
-    }
-  } else {
+  if constexpr (!std::is_same<T, float>::value) {
     MS_LOG(EXCEPTION) << "ReLU6Grad only support float";
   }
+
+  int ret = ::Relu6Grad(input1 + start, input2 + start, end - start, out + start);
+  if (ret == NNACL_ERR) {
+    MS_LOG(EXCEPTION) << "ReLU6Grad execute failed.";
+  }
 }
 
 template <typename T>
 void EltWiseGradCPUKernel<T>::AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
-  if constexpr (std::is_same_v<T, float>) {
-    int ret = ::ElementAbsGrad(input1 + start, input2 + start, out + start, end - start);
-    if (ret == NNACL_ERR) {
-      MS_LOG(EXCEPTION) << "AbsGrad failed.";
-    }
-  } else {
+  if constexpr (!std::is_same<T, float>::value) {
     MS_LOG(EXCEPTION) << "AbsGrad only support float";
   }
+
+  int ret = ::ElementAbsGrad(input1 + start, input2 + start, out + start, end - start);
+  if (ret == NNACL_ERR) {
+    MS_LOG(EXCEPTION) << "AbsGrad execute failed.";
+  }
 }
 
 template <typename T>
 void EltWiseGradCPUKernel<T>::SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
-  if constexpr (std::is_same_v<T, float>) {
-    int ret = ::SigmoidGrad(input2 + start, input1 + start, end - start, out + start);
-    if (ret == NNACL_ERR) {
-      MS_LOG(EXCEPTION) << "SigmoidGrad failed.";
-    }
-  } else {
+  if constexpr (!std::is_same<T, float>::value) {
     MS_LOG(EXCEPTION) << "SigmoidGrad only support float";
   }
+
+  int ret = ::SigmoidGrad(input2 + start, input1 + start, end - start, out + start);
+  if (ret == NNACL_ERR) {
+    MS_LOG(EXCEPTION) << "SigmoidGrad execute failed.";
+  }
 }
 
 template <typename T>
@@ -80,14 +82,14 @@ void EltWiseGradCPUKernel<T>::SqrtGrad(const T *input1, const T *input2, T *out,
 
 template <typename T>
 void EltWiseGradCPUKernel<T>::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
-  if constexpr (std::is_same_v<T, float>) {
-    int ret = ::TanhGrad(input2 + start, input1 + start, end - start, out + start);
-    if (ret == NNACL_ERR) {
-      MS_LOG(EXCEPTION) << "TanhGrad failed.";
-    }
-  } else {
+  if constexpr (!std::is_same<T, float>::value) {
     MS_LOG(EXCEPTION) << "TanhGrad only support float";
   }
+
+  int ret = ::TanhGrad(input2 + start, input1 + start, end - start, out + start);
+  if (ret == NNACL_ERR) {
+    MS_LOG(EXCEPTION) << "TanhGrad execute failed.";
+  }
 }
 
 template <typename T>
@@ -207,6 +209,18 @@ void EltWiseGradCPUKernel<T>::AcoshGrad(const T *input1, const T *input2, T *out
   }
 }
 
+template <typename T>
+void EltWiseGradCPUKernel<T>::SoftplusGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
+  if constexpr (!std::is_same<T, float>::value) {
+    MS_LOG(EXCEPTION) << "SoftplusGrad only support float";
+  }
+
+  int ret = ::SoftplusGrad(input1 + start, input2 + start, end - start, out + start);
+  if (ret == NNACL_ERR) {
+    MS_LOG(EXCEPTION) << "SoftplusGrad execute failed.";
+  }
+}
+
 template <typename T>
 void EltWiseGradCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
@@ -219,12 +233,19 @@ bool EltWiseGradCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inpu
                                      const std::vector<kernel::AddressPtr> &outputs) {
   static const std::map<std::string,
                         std::function<void(EltWiseGradCPUKernel<T> *, const T *, const T *, T *, size_t, size_t)>>
-    elt_map{{"ReluGrad", &EltWiseGradCPUKernel<T>::ReluGrad},       {"ReLU6Grad", &EltWiseGradCPUKernel<T>::ReLU6Grad},
-            {"SigmoidGrad", &EltWiseGradCPUKernel<T>::SigmoidGrad}, {"AbsGrad", &EltWiseGradCPUKernel<T>::AbsGrad},
-            {"TanhGrad", &EltWiseGradCPUKernel<T>::TanhGrad},       {"SqrtGrad", &EltWiseGradCPUKernel<T>::SqrtGrad},
-            {"GeLUGrad", &EltWiseGradCPUKernel<T>::GeluGrad},       {"AsinGrad", &EltWiseGradCPUKernel<T>::AsinGrad},
-            {"ACosGrad", &EltWiseGradCPUKernel<T>::ACosGrad},       {"AtanGrad", &EltWiseGradCPUKernel<T>::AtanGrad},
-            {"AsinhGrad", &EltWiseGradCPUKernel<T>::AsinhGrad},     {"AcoshGrad", &EltWiseGradCPUKernel<T>::AcoshGrad}};
+    elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCPUKernel<T>::ReluGrad},
+            {prim::kPrimRelu6Grad->name(), &EltWiseGradCPUKernel<T>::ReLU6Grad},
+            {prim::kPrimSigmoidGrad->name(), &EltWiseGradCPUKernel<T>::SigmoidGrad},
+            {prim::kPrimAbsGrad->name(), &EltWiseGradCPUKernel<T>::AbsGrad},
+            {prim::kPrimTanhGrad->name(), &EltWiseGradCPUKernel<T>::TanhGrad},
+            {prim::kPrimSqrtGrad->name(), &EltWiseGradCPUKernel<T>::SqrtGrad},
+            {prim::kPrimGeLUGrad->name(), &EltWiseGradCPUKernel<T>::GeluGrad},
+            {prim::kPrimAsinGrad->name(), &EltWiseGradCPUKernel<T>::AsinGrad},
+            {prim::kPrimACosGrad->name(), &EltWiseGradCPUKernel<T>::ACosGrad},
+            {prim::kPrimAtanGrad->name(), &EltWiseGradCPUKernel<T>::AtanGrad},
+            {prim::kPrimAsinhGrad->name(), &EltWiseGradCPUKernel<T>::AsinhGrad},
+            {prim::kPrimAcoshGrad->name(), &EltWiseGradCPUKernel<T>::AcoshGrad},
+            {prim::kPrimSoftplusGrad->name(), &EltWiseGradCPUKernel<T>::SoftplusGrad}};
   if (inputs.size() < 2 || outputs.size() != 1) {
     MS_LOG(ERROR) << kernel_name_ << " requires at least 2 inputs and 1 output, but got " << inputs.size()
                   << " inputs and " << outputs.size() << " output.";
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h
index f085a9a80d6..9f434981f75 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -48,6 +48,7 @@ class EltWiseGradCPUKernel : public CPUKernel {
   void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
+  void SoftplusGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
 
   std::string kernel_name_ = "";
 };
@@ -103,6 +104,10 @@
 MS_REG_CPU_KERNEL_T(
   AcoshGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
+  SoftplusGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  EltWiseGradCPUKernel, float);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.cc
index ecb66469d0d..0d76cff47a9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.cc
@@ -13,39 +13,47 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 #include "backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h"
+#include <string>
+#include <unordered_map>
 #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
 #include "runtime/device/cpu/cpu_device_address.h"
 #include "utils/ms_utils.h"
 
 namespace mindspore {
 namespace kernel {
+namespace {
+struct DescParam {
+  dnnl::algorithm algorithm;
+  float alpha = 0.f;
+  float beta = 0.f;
+};
+}  // namespace
+
 dnnl::eltwise_forward::desc EltWiseCPUKernel::GetForwardEltwiseDesc(const CNodePtr &kernel_node,
                                                                     const dnnl::memory::desc src_desc) {
+  static const std::unordered_map<std::string, DescParam> eltWiseOpDescMap{
+    {prim::kPrimRelu->name(), DescParam{dnnl::algorithm::eltwise_relu}},
+    {prim::kPrimRelu6->name(), DescParam{dnnl::algorithm::eltwise_clip, 0.f, 6.f}},
+    {prim::kPrimAbs->name(), DescParam{dnnl::algorithm::eltwise_abs}},
+    {prim::kPrimExp->name(), DescParam{dnnl::algorithm::eltwise_exp}},
+    {prim::kPrimLog->name(), DescParam{dnnl::algorithm::eltwise_log}},
+    {prim::kPrimSigmoid->name(), DescParam{dnnl::algorithm::eltwise_logistic}},
+    {prim::kPrimSqrt->name(), DescParam{dnnl::algorithm::eltwise_sqrt}},
+    {prim::kPrimSquare->name(), DescParam{dnnl::algorithm::eltwise_square}},
+    {prim::kPrimTanh->name(), DescParam{dnnl::algorithm::eltwise_tanh}},
+    {prim::kPrimElu->name(), DescParam{dnnl::algorithm::eltwise_elu, 1.f, 0.f}},
+    {prim::kPrimSoftplus->name(), DescParam{dnnl::algorithm::eltwise_soft_relu}},
+  };
+
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-  if (kernel_name == "ReLU") {
-    return dnnl::eltwise_forward::desc(DnnlForward, dnnl::algorithm::eltwise_relu, src_desc, 0.0);
-  } else if (kernel_name == "ReLU6") {
-    return dnnl::eltwise_forward::desc(DnnlForward, dnnl::algorithm::eltwise_clip, src_desc, 0.0, 6.0);
-  } else if (kernel_name == "Abs") {
-    return dnnl::eltwise_forward::desc(DnnlForward, dnnl::algorithm::eltwise_abs, src_desc);
-  } else if (kernel_name == "Exp") {
-    return dnnl::eltwise_forward::desc(DnnlForward, dnnl::algorithm::eltwise_exp, src_desc);
-  } else if (kernel_name == "Log") {
-    return dnnl::eltwise_forward::desc(DnnlForward, dnnl::algorithm::eltwise_log, src_desc);
-  } else if (kernel_name == "Sigmoid") {
-    return dnnl::eltwise_forward::desc(DnnlForward, dnnl::algorithm::eltwise_logistic, src_desc);
-  } else if (kernel_name == "Sqrt") {
-    return dnnl::eltwise_forward::desc(DnnlForward, dnnl::algorithm::eltwise_sqrt, src_desc);
-  } else if (kernel_name == "Square") {
-    return dnnl::eltwise_forward::desc(DnnlForward, dnnl::algorithm::eltwise_square, src_desc);
-  } else if (kernel_name == "Tanh") {
-    return dnnl::eltwise_forward::desc(DnnlForward, dnnl::algorithm::eltwise_tanh, src_desc);
-  } else if (kernel_name == "Elu") {
-    return dnnl::eltwise_forward::desc(DnnlForward, dnnl::algorithm::eltwise_elu, src_desc, 1.0);
-  } else {
-    MS_LOG(EXCEPTION) << "Eltwise operators don't support " << kernel_name;
+  const auto desc_pair = eltWiseOpDescMap.find(kernel_name);
+  if (desc_pair == eltWiseOpDescMap.end()) {
+    MS_LOG(EXCEPTION) << "EltWiseCPUKernel does not support " << kernel_name;
   }
+  return dnnl::eltwise_forward::desc(DnnlForward, desc_pair->second.algorithm, src_desc, desc_pair->second.alpha,
+                                     desc_pair->second.beta);
 }
 
 void EltWiseCPUKernel::InitKernel(const CNodePtr &kernel_node) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h
index 18d0ae24548..cd695e2a9e6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -56,6 +56,8 @@
 MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   EltWiseCPUKernel);
 MS_REG_CPU_KERNEL(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   EltWiseCPUKernel);
+MS_REG_CPU_KERNEL(Softplus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  EltWiseCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/activation_grad.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/activation_grad.c
index 488d413727b..366d1a9cf6a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/activation_grad.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/activation_grad.c
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
 #include <math.h>
 #include "nnacl/op_base.h"
 #include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl/fp32/exp_fp32.h"
 #include "nnacl/fp32_grad/activation_grad.h"
 #include "nnacl/errorcode.h"
 
@@ -110,3 +111,27 @@ int GeluGrad(const float *src0, const float *src1, size_t length, float *dst) {
   }
   return NNACL_OK;
 }
+
+int SoftplusGrad(const float *src0, const float *src1, int length, float *dst) {
+  int i = 0;
+#if defined(ENABLE_AVX)
+  for (; i <= length - C8NUM; i += C8NUM) {
+    simd_exp_avx(-(MS_LD256_F32(src1 + i)), dst + i);
+    MS_ST256_F32(dst + i,
+                 MS_DIV256_F32(MS_LD256_F32(src0 + i), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i))));
+  }
+#endif
+
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+  for (; i <= length - C4NUM; i += C4NUM) {
+    simd_exp(MS_SUBQ_F32(MS_MOVQ_F32(0.0f), MS_LDQ_F32(src1 + i)), dst + i);
+    MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_LDQ_F32(src0 + i), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i))));
+  }
+#endif
+
+  for (; i < length; ++i) {
+    single_exp(-src1[i], dst + i);
+    dst[i] = src0[i] / (1.0f + dst[i]);
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/activation_grad.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/activation_grad.h
index e88b27addb5..7f493215fe3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/activation_grad.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/activation_grad.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -39,6 +39,7 @@ int HSwishGrad(const float *src0, const float *src1, size_t length, float *dst);
 int HSigmoidGrad(const float *src0, const float *src1, size_t length, float *dst);
 int EluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha);
 int GeluGrad(const float *src0, const float *src1, size_t length, float *dst);
+int SoftplusGrad(const float *src, const float *src1, int length, float *dst);
 
 #ifdef __cplusplus
 }
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index ab7c128ffbf..d7616e225dd 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -346,6 +346,7 @@ inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>(kReLU6);
 inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>(kReLUV2);
 inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU");
 inline const PrimitivePtr kPrimSoftplus = std::make_shared<Primitive>("Softplus");
+inline const PrimitivePtr kPrimSoftplusGrad = std::make_shared<Primitive>("SoftplusGrad");
 inline const PrimitivePtr kPrimZeros = std::make_shared<Primitive>("Zeros");
 inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>(kZerosLike);
 inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>(kOnesLike);
@@ -472,6 +473,7 @@ inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad");
 inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>(kReciprocal);
 inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
 inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs");
+inline const PrimitivePtr kPrimAbsGrad = std::make_shared<Primitive>("AbsGrad");
 inline const PrimitivePtr kPrimRint = std::make_shared<Primitive>("Rint");
 inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
 inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>(kExp);
@@ -487,6 +489,8 @@ inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos");
 inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad");
 inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad");
 inline const PrimitivePtr kPrimAtanGrad = std::make_shared<Primitive>("AtanGrad");
+inline const PrimitivePtr kPrimAsinhGrad = std::make_shared<Primitive>("AsinhGrad");
+inline const PrimitivePtr kPrimAcoshGrad = std::make_shared<Primitive>("AcoshGrad");
 inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
 inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");
 inline const PrimitivePtr kPrimIdentityMath = std::make_shared<Primitive>("Identity", kSideEffectPropagate);
diff --git a/tests/st/ops/cpu/test_softplus_grad_op.py b/tests/st/ops/cpu/test_softplus_grad_op.py
new file mode 100644
index 00000000000..76879689960
--- /dev/null
+++ b/tests/st/ops/cpu/test_softplus_grad_op.py
@@ -0,0 +1,78 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class SoftplusNet(nn.Cell):
+    def __init__(self):
+        super(SoftplusNet, self).__init__()
+        self.softplus = P.Softplus()
+
+    def construct(self, x):
+        return self.softplus(x)
+
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = C.GradOperation(get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, input_data, sens):
+        gout = self.grad(self.network)(input_data, sens)
+        return gout
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_softplus_grad():
+    x = np.array([0.58401114, 0.68800163, 0.9760397, 0.14702141, 0.46563736, 0.9607501,
+                  0.14567593, 0.12261796, 0.37054458, 0.46421242]).astype(np.float32)
+    dy = np.array([0.5559598, 0.96994054, 0.24770357, 0.34646875, 0.2984393, 0.03287048,
+                   0.55681044, 0.966908, 0.06015943, 0.6099489]).astype(np.float32)
+    x_ms = Tensor(x)
+    dy_ms = Tensor(dy)
+
+    net = SoftplusNet()
+    grad = Grad(net)
+
+    output = grad(x_ms, dy_ms)
+    expect = dy * np.exp(x) / (1 + np.exp(x))
+    assert np.allclose(output[0].asnumpy(), expect, rtol=1e-3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_softplus_grad_fp16():
+    np.random.seed(42)
+    x_np = np.random.randn(5, 3, 6).astype(np.float16)
+    dy_np = np.random.randn(5, 3, 6).astype(np.float16)
+    net = SoftplusNet()
+    grad = Grad(net)
+    output = grad(Tensor(x_np), Tensor(dy_np))
+    expect = dy_np * np.exp(x_np) / (1 + np.exp(x_np))
+    assert np.allclose(output[0].asnumpy(), expect, rtol=1e-2)
diff --git a/tests/st/ops/cpu/test_softplus_op.py b/tests/st/ops/cpu/test_softplus_op.py
new file mode 100644
index 00000000000..19af2a20762
--- /dev/null
+++ b/tests/st/ops/cpu/test_softplus_op.py
@@ -0,0 +1,107 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class SoftplusNet(nn.Cell):
+    def __init__(self):
+        super(SoftplusNet, self).__init__()
+        self.softplus = P.Softplus()
+
+    def construct(self, x):
+        return self.softplus(x)
+
+
+def SoftplusCompute(x):
+    return np.log(1 + np.exp(x))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_softplus_1d():
+    x_np = np.random.random((50,)).astype(np.float32)
+    y_np = SoftplusCompute(x_np)
+
+    x_ms = Tensor(x_np)
+    net = SoftplusNet()
+    y_ms = net(x_ms)
+
+    assert np.allclose(y_np, y_ms.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_softplus_2d():
+    x_np = np.random.random((50, 40)).astype(np.float32)
+    y_np = SoftplusCompute(x_np)
+
+    x_ms = Tensor(x_np)
+    net = SoftplusNet()
+    y_ms = net(x_ms)
+
+    assert np.allclose(y_np, y_ms.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_softplus_4d():
+    x_np = np.random.random((32, 3, 224, 224)).astype(np.float32)
+    y_np = SoftplusCompute(x_np)
+
+    x_ms = Tensor(x_np)
+    net = SoftplusNet()
+    y_ms = net(x_ms)
+
+    assert np.allclose(y_np, y_ms.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_softplus_neg():
+    x_np = np.random.random((32, 3, 224, 224)).astype(np.float32) * -1
+    y_np = SoftplusCompute(x_np)
+
+    x_ms = Tensor(x_np)
+    net = SoftplusNet()
+    y_ms = net(x_ms)
+
+    assert np.allclose(y_np, y_ms.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_softplus_4d_fp16():
+    x_np = np.random.random((32, 3, 224, 224)).astype(np.float16)
+    y_np = SoftplusCompute(x_np)
+
+    x_ms = Tensor(x_np)
+    net = SoftplusNet()
+    y_ms = net(x_ms)
+
+    assert np.allclose(y_np, y_ms.asnumpy(), rtol=5e-3)
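
Reviewer note (not part of the applied patch): the new SoftplusGrad kernel computes
dx = dy * exp(x) / (1 + exp(x)), i.e. the incoming gradient scaled by sigmoid(x),
which is the derivative of softplus(x) = log(1 + exp(x)); the expectations in
test_softplus_grad_op.py encode the same closed form. Below is a minimal NumPy-only
sketch that cross-checks that formula against a central-difference estimate — the
function names here are illustrative, not MindSpore APIs:

    import numpy as np

    def softplus(x):
        # log1p(exp(x)) is the numerically safer spelling of log(1 + exp(x)) for x <= 0
        return np.log1p(np.exp(x))

    def softplus_grad(dy, x):
        # Mirrors the kernel's scalar tail loop: dy * exp(x) / (1 + exp(x)) == dy * sigmoid(x)
        return dy * np.exp(x) / (1.0 + np.exp(x))

    x = np.random.randn(16)
    dy = np.random.randn(16)

    eps = 1e-6
    fd = dy * (softplus(x + eps) - softplus(x - eps)) / (2 * eps)  # central difference
    assert np.allclose(softplus_grad(dy, x), fd, rtol=1e-4, atol=1e-6)

On the forward side, oneDNN's eltwise_soft_relu algorithm evaluates log(1 + exp(x)),
which is why the Softplus forward kernel can be routed through the existing MKL-DNN
eltwise path with no extra math in MindSpore itself.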