diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
index 6d52662470d..b46b34039ea 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
@@ -93,6 +93,20 @@ void Gelu(const T *in, T *out, size_t start, size_t end) {
     out[i] = x * ((T)1.0 + tanh_res) / (T)2.0;
   }
 }
+
+template <typename T>
+void Asin(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = asin(in[i]);
+  }
+}
+
+template <typename T>
+void ACos(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = acos(in[i]);
+  }
+}
 }  // namespace
 
 void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -116,6 +130,10 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     operate_type_ = RECIPROCAL;
   } else if (kernel_name == prim::kPrimGelu->name()) {
     operate_type_ = GELU;
+  } else if (kernel_name == prim::kPrimAsin->name()) {
+    operate_type_ = ASIN;
+  } else if (kernel_name == prim::kPrimACos->name()) {
+    operate_type_ = ACOS;
   }
   dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
   target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
@@ -216,6 +234,10 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
       threads.emplace_back(std::thread(Reciprocal<T>, input, output, start, end));
     } else if (operate_type_ == GELU) {
       threads.emplace_back(std::thread(Gelu<T>, input, output, start, end));
+    } else if (operate_type_ == ASIN) {
+      threads.emplace_back(std::thread(Asin<T>, input, output, start, end));
+    } else if (operate_type_ == ACOS) {
+      threads.emplace_back(std::thread(ACos<T>, input, output, start, end));
     }
     start += once_compute_size;
   }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
index 6a28da1033a..8c22ae69c03 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
@@ -70,6 +70,14 @@ MS_REG_CPU_KERNEL(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
                   ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                  ArithmeticSelfCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
index 58759712a59..e614a627357 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
@@ -93,6 +93,10 @@ enum OperateType {
   RECIPROCAL,
   GELU,
   GELUGRAD,
+  ASIN,
+  ACOS,
+  ASINGRAD,
+  ACOSGRAD,
 };
 
 class CPUKernel : public kernel::KernelMod {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
index ee7a47c91a5..24a758f3d02 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
@@ -90,6 +90,48 @@ void EltWiseGradCPUKernel::GeluGrad(const T *input1, const T *input2, T *out, si
   }
 }
 
+template <typename T>
+void EltWiseGradCPUKernel::AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    T dividend = input2[i];
+    T divisor = sqrt(1 - input1[i] * input1[i]);
+    if (divisor == 0) {
+      if (dividend == 0) {
+        out[i] = std::numeric_limits<T>::quiet_NaN();
+        continue;
+      }
+      if (std::numeric_limits<T>::has_infinity) {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
+      } else {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
+      }
+      continue;
+    }
+    out[i] = dividend / divisor;
+  }
+}
+
+template <typename T>
+void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    T dividend = -input2[i];
+    T divisor = sqrt(1 - input1[i] * input1[i]);
+    if (divisor == 0) {
+      if (dividend == 0) {
+        out[i] = std::numeric_limits<T>::quiet_NaN();
+        continue;
+      }
+      if (std::numeric_limits<T>::has_infinity) {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
+      } else {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
+      }
+      continue;
+    }
+    out[i] = dividend / divisor;
+  }
+}
+
 void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
@@ -107,6 +149,10 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     operate_type_ = SQRTGRAD;
   } else if (kernel_name == "GeluGrad") {
     operate_type_ = GELUGRAD;
+  } else if (kernel_name == "AsinGrad") {
+    operate_type_ = ASINGRAD;
+  } else if (kernel_name == "ACosGrad") {
+    operate_type_ = ACOSGRAD;
   } else {
     MS_LOG(EXCEPTION) << "Not support " << kernel_name;
   }
@@ -188,6 +234,10 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c
       threads.emplace_back(std::thread(&EltWiseGradCPUKernel::SqrtGrad<T>, this, input1, input2, output, start, end));
     } else if (operate_type_ == GELUGRAD) {
       threads.emplace_back(std::thread(&EltWiseGradCPUKernel::GeluGrad<T>, this, input1, input2, output, start, end));
+    } else if (operate_type_ == ASINGRAD) {
+      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinGrad<T>, this, input1, input2, output, start, end));
+    } else if (operate_type_ == ACOSGRAD) {
+      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end));
     } else {
       MS_LOG(EXCEPTION) << "Not support " << operate_type_;
     }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h
index b67c632654b..5669f1b0baf 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h
@@ -17,6 +17,7 @@
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELTWISE_GRAD_CPU_KERNEL_H_
 #include <memory>
 #include <vector>
+#include <limits>
 #include "backend/kernel_compiler/cpu/cpu_kernel.h"
 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
@@ -49,6 +50,10 @@ class EltWiseGradCPUKernel : public CPUKernel {
   void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
   template <typename T>
   void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
+  void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
+  void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
   std::vector<size_t> input_shape0_;
   std::vector<size_t> input_shape1_;
   std::vector<size_t> input_element_num0_;
@@ -90,6 +95,20 @@ MS_REG_CPU_KERNEL(GeluGrad,
                     .AddInputAttr(kNumberTypeFloat32)
                     .AddOutputAttr(kNumberTypeFloat32),
                   EltWiseGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  AsinGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  EltWiseGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  AsinGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  EltWiseGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  ACosGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  EltWiseGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  ACosGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  EltWiseGradCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 25436e00a33..38b4de8eb40 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -278,6 +278,10 @@ inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV");
 inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace");
 inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign");
 inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference");
+inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin");
+inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos");
+inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad");
+inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad");
 
 // Statements
 inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
@@ -351,7 +355,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
 inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
 inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
 
-// Other primitve not used by backend but used in core;
+// Other primitive not used by backend but used in core;
 inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
 inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
diff --git a/tests/st/ops/cpu/test_acos_grad_op.py b/tests/st/ops/cpu/test_acos_grad_op.py
new file mode 100644
index 00000000000..4891cee6c4e
--- /dev/null
+++ b/tests/st/ops/cpu/test_acos_grad_op.py
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops.operations import _grad_ops as G
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetACosGrad(nn.Cell):
+    def __init__(self):
+        super(NetACosGrad, self).__init__()
+        self.acosGrad = G.ACosGrad()
+
+    def construct(self, x, dy):
+        return self.acosGrad(x, dy)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_acos_grad():
+    x = np.array([-0.5, 0, 0.5]).astype('float32')
+    dy = np.array([1, 0, -1]).astype('float32')
+    acos_grad = NetACosGrad()
+    output = acos_grad(Tensor(x), Tensor(dy))
+    print(output)
+    expect = -dy / np.sqrt(1 - x * x)
+    assert np.allclose(output.asnumpy(), expect)
diff --git a/tests/st/ops/cpu/test_acos_op.py b/tests/st/ops/cpu/test_acos_op.py
new file mode 100644
index 00000000000..7558bc2e1a7
--- /dev/null
+++ b/tests/st/ops/cpu/test_acos_op.py
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetACos(nn.Cell):
+    def __init__(self):
+        super(NetACos, self).__init__()
+        self.acos = P.ACos()
+
+    def construct(self, x):
+        return self.acos(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_acos():
+    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
+    input_x = Tensor(np_array)
+    net = NetACos()
+    output = net(input_x)
+    print(output)
+    expect = np.arccos(np_array)
+    assert np.allclose(output.asnumpy(), expect)
diff --git a/tests/st/ops/cpu/test_asin_grad_op.py b/tests/st/ops/cpu/test_asin_grad_op.py
new file mode 100644
index 00000000000..8bdb294e03a
--- /dev/null
+++ b/tests/st/ops/cpu/test_asin_grad_op.py
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops.operations import _grad_ops as G
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetAsinGrad(nn.Cell):
+    def __init__(self):
+        super(NetAsinGrad, self).__init__()
+        self.asinGrad = G.AsinGrad()
+
+    def construct(self, x, dy):
+        return self.asinGrad(x, dy)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_asin_grad():
+    x = np.array([-0.5, 0, 0.5]).astype('float32')
+    dy = np.array([1, 0, -1]).astype('float32')
+    asin_grad = NetAsinGrad()
+    output = asin_grad(Tensor(x), Tensor(dy))
+    print(output)
+    expect = dy / np.sqrt(1 - x * x)
+    assert np.allclose(output.asnumpy(), expect)
diff --git a/tests/st/ops/cpu/test_asin_op.py b/tests/st/ops/cpu/test_asin_op.py
new file mode 100644
index 00000000000..62e461bf3cb
--- /dev/null
+++ b/tests/st/ops/cpu/test_asin_op.py
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetAsin(nn.Cell):
+    def __init__(self):
+        super(NetAsin, self).__init__()
+        self.asin = P.Asin()
+
+    def construct(self, x):
+        return self.asin(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_asin():
+    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
+    input_x = Tensor(np_array)
+    net = NetAsin()
+    output = net(input_x)
+    print(output)
+    expect = np.arcsin(np_array)
+    assert np.allclose(output.asnumpy(), expect)
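
For reference, the two gradient kernels in this patch implement d/dx asin(x) = dy / sqrt(1 - x^2) and d/dx acos(x) = -dy / sqrt(1 - x^2), and at |x| = 1 (where the divisor is zero) they emit a signed infinity, or NaN when dy is also zero. Below is a minimal NumPy-only sketch of that divisor-zero contract, not part of the patch; the helper names asin_grad_ref and acos_grad_ref are illustrative.

```python
import numpy as np

def asin_grad_ref(x, dy):
    """Reference for EltWiseGradCPUKernel::AsinGrad: dy / sqrt(1 - x^2)."""
    x = np.asarray(x, dtype=np.float32)
    dy = np.asarray(dy, dtype=np.float32)
    divisor = np.sqrt(np.float32(1.0) - x * x)
    out = np.empty_like(x)
    zero = divisor == 0  # |x| == 1, where the derivative is unbounded
    out[~zero] = dy[~zero] / divisor[~zero]
    # Mirror the kernel's divisor == 0 branches: dy == 0 -> NaN,
    # otherwise a signed infinity (float32 has_infinity holds).
    out[zero] = np.where(dy[zero] == 0, np.float32(np.nan),
                         np.where(dy[zero] > 0, np.float32(np.inf),
                                  np.float32(-np.inf)))
    return out

def acos_grad_ref(x, dy):
    """Reference for EltWiseGradCPUKernel::ACosGrad: -dy / sqrt(1 - x^2)."""
    return asin_grad_ref(x, -np.asarray(dy, dtype=np.float32))

print(asin_grad_ref([-1.0, -0.5, 0.0, 0.5, 1.0], [1.0, 1.0, 0.0, 1.0, 1.0]))
# -> [inf 1.1547005 0. 1.1547005 inf]
print(acos_grad_ref([-1.0, 0.0, 1.0], [1.0, 1.0, 1.0]))
# -> [-inf -1. -inf]
```

The float32 gradient tests above stay inside (-1, 1), so they exercise only the dividend / divisor path; the NaN and infinity branches cover the |x| = 1 boundary.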