Add Asin, ACos, AsinGrad, ACosGrad for CPU

wangrao 2021-01-22 15:22:35 +08:00
parent 007bd6c7d3
commit f726faee4f
10 changed files with 292 additions and 1 deletion
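
In short: the forward kernels wrap the C library's asin/acos element-wise, and the grad kernels implement the analytic derivatives. For reference (standard calculus, matching the `expect` values in the tests at the end of this diff):

\frac{d}{dx}\arcsin x = \frac{1}{\sqrt{1 - x^{2}}}, \qquad \frac{d}{dx}\arccos x = -\frac{1}{\sqrt{1 - x^{2}}}

so AsinGrad(x, dy) = dy / sqrt(1 - x^2) and ACosGrad(x, dy) = -dy / sqrt(1 - x^2).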

View File

@@ -93,6 +93,20 @@ void Gelu(const T *in, T *out, size_t start, size_t end) {
    out[i] = x * ((T)1.0 + tanh_res) / (T)2.0;
  }
}

template <typename T>
void Asin(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    out[i] = asin(in[i]);
  }
}

template <typename T>
void ACos(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    out[i] = acos(in[i]);
  }
}
}  // namespace

void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -116,6 +130,10 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
    operate_type_ = RECIPROCAL;
  } else if (kernel_name == prim::kPrimGelu->name()) {
    operate_type_ = GELU;
  } else if (kernel_name == prim::kPrimAsin->name()) {
    operate_type_ = ASIN;
  } else if (kernel_name == prim::kPrimACos->name()) {
    operate_type_ = ACOS;
  }
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
  target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
@@ -216,6 +234,10 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
      threads.emplace_back(std::thread(Reciprocal<T>, input, output, start, end));
    } else if (operate_type_ == GELU) {
      threads.emplace_back(std::thread(Gelu<T>, input, output, start, end));
    } else if (operate_type_ == ASIN) {
      threads.emplace_back(std::thread(Asin<T>, input, output, start, end));
    } else if (operate_type_ == ACOS) {
      threads.emplace_back(std::thread(ACos<T>, input, output, start, end));
    }
    start += once_compute_size;
  }
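
The LaunchKernel hunk above follows the file's existing pattern: the flat element range is cut into fixed-size slices and each slice is handed to a worker thread. A minimal Python sketch of that pattern (only `once_compute_size` is a name from the diff; the worker count and helper names are illustrative):

import math
import threading

def launch_elementwise(fn, inp, out, num_threads=4):
    size = len(inp)
    once_compute_size = (size + num_threads - 1) // num_threads  # elements per slice
    threads = []
    start = 0
    while start < size:
        end = min(start + once_compute_size, size)  # half-open slice [start, end)
        threads.append(threading.Thread(target=fn, args=(inp, out, start, end)))
        threads[-1].start()
        start += once_compute_size
    for t in threads:
        t.join()

def asin_slice(inp, out, start, end):
    # mirrors Asin<T>: out[i] = asin(in[i]) over the slice
    for i in range(start, end):
        out[i] = math.asin(inp[i])

# usage: data = [0.0, 0.5, -0.5, 1.0]; res = [0.0] * 4
# launch_elementwise(asin_slice, data, res)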

View File

@@ -70,6 +70,14 @@ MS_REG_CPU_KERNEL(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                  ArithmeticSelfCPUKernel);
}  // namespace kernel
}  // namespace mindspore
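
Note that Asin and ACos are each registered twice, once for float32 and once for int32. With the int32 signature the template instantiates with T = int, so the floating-point result of asin/acos is truncated toward zero on store. A small illustration of that consequence (NumPy used only for the reference values; not part of the commit):

import numpy as np

# Assumed behaviour of the Int32 registrations above: compute in
# floating point, then truncate toward zero into the int32 output.
print(int(np.arcsin(1.0)))   # 1  (pi/2 ~= 1.57 truncates to 1)
print(int(np.arccos(-1.0)))  # 3  (pi  ~= 3.14 truncates to 3)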

View File

@@ -93,6 +93,10 @@ enum OperateType {
  RECIPROCAL,
  GELU,
  GELUGRAD,
  ASIN,
  ACOS,
  ASINGRAD,
  ACOSGRAD,
};

class CPUKernel : public kernel::KernelMod {
View File

@@ -90,6 +90,48 @@ void EltWiseGradCPUKernel::GeluGrad(const T *input1, const T *input2, T *out, si
  }
}

template <typename T>
void EltWiseGradCPUKernel::AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    T dividend = input2[i];
    T divisor = sqrt(1 - input1[i] * input1[i]);
    if (divisor == 0) {
      if (dividend == 0) {
        out[i] = std::numeric_limits<T>::quiet_NaN();
        continue;
      }
      if (std::numeric_limits<T>::has_infinity) {
        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
      } else {
        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
      }
      continue;
    }
    out[i] = dividend / divisor;
  }
}

template <typename T>
void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    T dividend = -input2[i];
    T divisor = sqrt(1 - input1[i] * input1[i]);
    if (divisor == 0) {
      if (dividend == 0) {
        out[i] = std::numeric_limits<T>::quiet_NaN();
        continue;
      }
      if (std::numeric_limits<T>::has_infinity) {
        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
      } else {
        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
      }
      continue;
    }
    out[i] = dividend / divisor;
  }
}

void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
@@ -107,6 +149,10 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
    operate_type_ = SQRTGRAD;
  } else if (kernel_name == "GeluGrad") {
    operate_type_ = GELUGRAD;
  } else if (kernel_name == "AsinGrad") {
    operate_type_ = ASINGRAD;
  } else if (kernel_name == "ACosGrad") {
    operate_type_ = ACOSGRAD;
  } else {
    MS_LOG(EXCEPTION) << "Not support " << kernel_name;
  }
@@ -188,6 +234,10 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c
      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::SqrtGrad<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == GELUGRAD) {
      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::GeluGrad<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == ASINGRAD) {
      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinGrad<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == ACOSGRAD) {
      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end));
    } else {
      MS_LOG(EXCEPTION) << "Not support " << operate_type_;
    }
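
Both grad kernels share the same guard: the divisor sqrt(1 - x^2) vanishes at x = +/-1, where the true derivative diverges. The code returns NaN for the 0/0 case, a signed infinity when the type has one, and the type's extrema otherwise (the integer instantiations). A scalar Python sketch of the float path (function name is illustrative, not part of the commit):

import math

def asin_grad_scalar(x, dy):
    # mirrors AsinGrad<T> for floating T; ACosGrad only flips the sign of dy
    divisor = math.sqrt(1 - x * x)
    if divisor == 0:                        # |x| == 1: derivative diverges
        if dy == 0:
            return float("nan")             # 0/0 is undefined
        return math.copysign(math.inf, dy)  # dy/0: signed infinity
    return dy / divisor

# asin_grad_scalar(1.0, 2.0) -> inf; asin_grad_scalar(0.5, 1.0) -> ~1.1547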

View File

@@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELTWISE_GRAD_CPU_KERNEL_H_
#include <memory>
#include <vector>
#include <limits>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
@@ -49,6 +50,10 @@ class EltWiseGradCPUKernel : public CPUKernel {
  void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
  template <typename T>
  void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
  template <typename T>
  void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
  template <typename T>
  void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
  std::vector<size_t> input_shape0_;
  std::vector<size_t> input_shape1_;
  std::vector<size_t> input_element_num0_;
@@ -90,6 +95,20 @@ MS_REG_CPU_KERNEL(GeluGrad,
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
  AsinGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
  AsinGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
  ACosGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
  ACosGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  EltWiseGradCPUKernel);
}  // namespace kernel
}  // namespace mindspore

View File

@@ -278,6 +278,10 @@ inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV");
inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace");
inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign");
inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference");
inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin");
inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos");
inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad");
inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad");
// Statements
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
@@ -351,7 +355,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
-// Other primitve not used by backend but used in core;
+// Other primitive not used by backend but used in core;
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops.operations import _grad_ops as G

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class NetACosGrad(nn.Cell):
    def __init__(self):
        super(NetACosGrad, self).__init__()
        self.acosGrad = G.ACosGrad()

    def construct(self, x, dy):
        return self.acosGrad(x, dy)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_acos_grad():
    x = np.array([-0.5, 0, 0.5]).astype('float32')
    dy = np.array([1, 0, -1]).astype('float32')
    acos_grad = NetACosGrad()
    output = acos_grad(Tensor(x), Tensor(dy))
    print(output)
    expect = -dy / np.sqrt(1 - x * x)
    assert np.allclose(output.asnumpy(), expect)

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class NetACos(nn.Cell):
    def __init__(self):
        super(NetACos, self).__init__()
        self.acos = P.ACos()

    def construct(self, x):
        return self.acos(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_acos():
    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
    input_x = Tensor(np_array)
    net = NetACos()
    output = net(input_x)
    print(output)
    expect = np.arccos(np_array)
    assert np.allclose(output.asnumpy(), expect)

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops.operations import _grad_ops as G

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class NetAsinGrad(nn.Cell):
    def __init__(self):
        super(NetAsinGrad, self).__init__()
        self.asinGrad = G.AsinGrad()

    def construct(self, x, dy):
        return self.asinGrad(x, dy)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_asin_grad():
    x = np.array([-0.5, 0, 0.5]).astype('float32')
    dy = np.array([1, 0, -1]).astype('float32')
    asin_grad = NetAsinGrad()
    output = asin_grad(Tensor(x), Tensor(dy))
    print(output)
    expect = dy / np.sqrt(1 - x * x)
    assert np.allclose(output.asnumpy(), expect)

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class NetAsin(nn.Cell):
    def __init__(self):
        super(NetAsin, self).__init__()
        self.asin = P.Asin()

    def construct(self, x):
        return self.asin(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_asin():
    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
    input_x = Tensor(np_array)
    net = NetAsin()
    output = net(input_x)
    print(output)
    expect = np.arcsin(np_array)
    assert np.allclose(output.asnumpy(), expect)