!12384 Add Sin, Cos, Tan, Atan, AtanGrad for CPU

From: @wangrao124
Reviewed-by: @wuxuejian, @kisnwang
Signed-off-by: @wuxuejian
Authored by mindspore-ci-bot on 2021-02-22 10:17:46 +08:00; committed by Gitee
commit 90a56c9c23
12 changed files with 339 additions and 29 deletions

View File

@@ -16,6 +16,7 @@
#include <cmath>
#include <string>
#include <thread>
#include <map>
#include "backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
@@ -107,6 +108,34 @@ void ACos(const T *in, T *out, size_t start, size_t end) {
    out[i] = acos(in[i]);
  }
}

template <typename T>
void Atan(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    out[i] = atan(in[i]);
  }
}

template <typename T>
void Sin(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    out[i] = sin(in[i]);
  }
}

template <typename T>
void Cos(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    out[i] = cos(in[i]);
  }
}

template <typename T>
void Tan(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    out[i] = tan(in[i]);
  }
}
}  // namespace

void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -134,6 +163,14 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
    operate_type_ = ASIN;
  } else if (kernel_name == prim::kPrimACos->name()) {
    operate_type_ = ACOS;
  } else if (kernel_name == prim::kPrimAtan->name()) {
    operate_type_ = ATAN;
  } else if (kernel_name == prim::kPrimSin->name()) {
    operate_type_ = SIN;
  } else if (kernel_name == prim::kPrimCos->name()) {
    operate_type_ = COS;
  } else if (kernel_name == prim::kPrimTan->name()) {
    operate_type_ = TAN;
  }
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
  target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
@@ -214,31 +251,18 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
     MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
     return;
   }
+  static const std::map<OperateType, std::function<void(const T *in, T *out, size_t start, size_t end)>>
+    kArithmeticOpFuncMap = {{SQUARE, Square<T>},     {SIGN, Sign<T>},
+                            {NEG, Neg<T>},           {LOGICALNOT, LogicalNot<T>},
+                            {ONESLIKE, OnesLike<T>}, {ZEROSLIKE, ZerosLike<T>},
+                            {FLOOR, Floor<T>},       {RECIPROCAL, Reciprocal<T>},
+                            {GELU, Gelu<T>},         {SIN, Sin<T>},
+                            {COS, Cos<T>},           {TAN, Tan<T>},
+                            {ASIN, Asin<T>},         {ACOS, ACos<T>},
+                            {ATAN, Atan<T>}};
   while (start < lens) {
     size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
-    if (operate_type_ == SQUARE) {
-      threads.emplace_back(std::thread(Square<T>, input, output, start, end));
-    } else if (operate_type_ == NEG) {
-      threads.emplace_back(std::thread(Neg<T>, input, output, start, end));
-    } else if (operate_type_ == LOGICALNOT) {
-      threads.emplace_back(std::thread(LogicalNot<T>, input, output, start, end));
-    } else if (operate_type_ == ONESLIKE) {
-      threads.emplace_back(std::thread(OnesLike<T>, input, output, start, end));
-    } else if (operate_type_ == ZEROSLIKE) {
-      threads.emplace_back(std::thread(ZerosLike<T>, input, output, start, end));
-    } else if (operate_type_ == SIGN) {
-      threads.emplace_back(std::thread(Sign<T>, input, output, start, end));
-    } else if (operate_type_ == FLOOR) {
-      threads.emplace_back(std::thread(Floor<T>, input, output, start, end));
-    } else if (operate_type_ == RECIPROCAL) {
-      threads.emplace_back(std::thread(Reciprocal<T>, input, output, start, end));
-    } else if (operate_type_ == GELU) {
-      threads.emplace_back(std::thread(Gelu<T>, input, output, start, end));
-    } else if (operate_type_ == ASIN) {
-      threads.emplace_back(std::thread(Asin<T>, input, output, start, end));
-    } else if (operate_type_ == ACOS) {
-      threads.emplace_back(std::thread(ACos<T>, input, output, start, end));
-    }
+    threads.emplace_back(std::thread(kArithmeticOpFuncMap.at(operate_type_), input, output, start, end));
     start += once_compute_size;
   }
   for (size_t i = 0; i < threads.size(); ++i) {
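The hunk above replaces the per-operator if/else ladder with a static lookup table keyed by OperateType, so adding a unary op now takes one map entry instead of another dispatch branch. Below is a minimal, self-contained C++ sketch of the same dispatch-plus-sharding pattern; the names (ApplyElementwise, kFuncMap, OpType) are hypothetical stand-ins, not the MindSpore API, and the kernel's size and dtype checks are omitted:

#include <algorithm>
#include <cmath>
#include <functional>
#include <map>
#include <thread>
#include <vector>

enum OpType { SIN, COS, ATAN };  // trimmed stand-in for the OperateType enum

template <typename T>
void ApplyElementwise(OpType op, const T *in, T *out, size_t lens, size_t num_threads) {
  // Single registration point per op, mirroring kArithmeticOpFuncMap above.
  static const std::map<OpType, std::function<void(const T *, T *, size_t, size_t)>> kFuncMap = {
    {SIN, [](const T *i, T *o, size_t s, size_t e) { for (; s < e; ++s) o[s] = std::sin(i[s]); }},
    {COS, [](const T *i, T *o, size_t s, size_t e) { for (; s < e; ++s) o[s] = std::cos(i[s]); }},
    {ATAN, [](const T *i, T *o, size_t s, size_t e) { for (; s < e; ++s) o[s] = std::atan(i[s]); }},
  };
  size_t chunk = (lens + num_threads - 1) / num_threads;
  std::vector<std::thread> threads;
  for (size_t start = 0; start < lens; start += chunk) {
    size_t end = std::min(start + chunk, lens);
    // Each worker applies the selected function to its own disjoint [start, end) slice.
    threads.emplace_back(kFuncMap.at(op), in, out, start, end);
  }
  for (auto &t : threads) {
    t.join();
  }
}

int main() {
  std::vector<float> x = {-1.0f, 0.0f, 1.0f};
  std::vector<float> y(x.size());
  ApplyElementwise(ATAN, x.data(), y.data(), x.size(), 2);  // y ≈ {-0.7854f, 0.0f, 0.7854f}
  return 0;
}

A function-local static map is initialized once and, since C++11, thread-safely, so the table is built on the first launch and only read afterwards.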

View File

@@ -78,6 +78,22 @@ MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                  ArithmeticSelfCPUKernel);
}  // namespace kernel
}  // namespace mindspore

View File

@@ -95,8 +95,13 @@ enum OperateType {
  GELUGRAD,
  ASIN,
  ACOS,
  ATAN,
  ASINGRAD,
  ACOSGRAD,
  ATANGRAD,
  SIN,
  COS,
  TAN,
};

class CPUKernel : public kernel::KernelMod {

View File

@@ -132,6 +132,27 @@ void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, si
  }
}

template <typename T>
void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    T dividend = input2[i];
    T divisor = 1 + input1[i] * input1[i];
    if (divisor == 0) {
      if (dividend == 0) {
        out[i] = std::numeric_limits<T>::quiet_NaN();
        continue;
      }
      if (std::numeric_limits<T>::has_infinity) {
        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
      } else {
        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
      }
      continue;
    }
    out[i] = dividend / divisor;
  }
}

void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
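For reference, the AtanGrad function added above implements the backward rule of y = arctan(x). Since d/dx arctan(x) = 1/(1 + x^2), the chain rule scales the incoming gradient dy elementwise:

  \frac{\partial L}{\partial x_i} \;=\; dy_i \cdot \frac{1}{1 + x_i^2} \;=\; \frac{dy_i}{1 + x_i^2}

For IEEE floating-point T the divisor 1 + x_i^2 is never less than 1, so the divisor == 0 branch cannot trigger there; it presumably guards the Int32 instantiation registered below, where x * x can wrap and leave 1 + x * x equal to zero.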
@@ -153,6 +174,8 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
    operate_type_ = ASINGRAD;
  } else if (kernel_name == "ACosGrad") {
    operate_type_ = ACOSGRAD;
  } else if (kernel_name == "AtanGrad") {
    operate_type_ = ATANGRAD;
  } else {
    MS_LOG(EXCEPTION) << "Not support " << kernel_name;
  }
@@ -238,6 +261,8 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c
      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinGrad<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == ACOSGRAD) {
      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == ATANGRAD) {
      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AtanGrad<T>, this, input1, input2, output, start, end));
    } else {
      MS_LOG(EXCEPTION) << "Not support " << operate_type_;
    }

View File

@@ -54,6 +54,8 @@ class EltWiseGradCPUKernel : public CPUKernel {
  void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
  template <typename T>
  void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
  template <typename T>
  void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
  std::vector<size_t> input_shape0_;
  std::vector<size_t> input_shape1_;
  std::vector<size_t> input_element_num0_;
@@ -109,6 +111,13 @@ MS_REG_CPU_KERNEL(
MS_REG_CPU_KERNEL(
  ACosGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
  AtanGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
  AtanGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  EltWiseGradCPUKernel);
}  // namespace kernel
}  // namespace mindspore

View File

@@ -388,6 +388,7 @@ inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign");
inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos");
inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad");
inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad");
inline const PrimitivePtr kPrimAtanGrad = std::make_shared<Primitive>("AtanGrad");
inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where");

View File

@@ -3342,7 +3342,7 @@ class Cos(PrimitiveWithInfer):
         Tensor, has the same shape as `input_x`.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> cos = ops.Cos()
@@ -3379,7 +3379,7 @@ class ACos(PrimitiveWithInfer):
         Tensor, has the same shape as `input_x`.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> acos = ops.ACos()
@@ -3412,7 +3412,7 @@ class Sin(PrimitiveWithInfer):
         Tensor, has the same shape as `input_x`.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> sin = ops.Sin()
@@ -3449,7 +3449,7 @@ class Asin(PrimitiveWithInfer):
         Tensor, has the same shape as `input_x`.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> asin = ops.Asin()
@@ -3666,7 +3666,7 @@ class Tan(PrimitiveWithInfer):
         Tensor, has the same shape as `input_x`.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``CPU``
 
     Examples:
         >>> tan = ops.Tan()
@@ -3704,7 +3704,7 @@ class Atan(PrimitiveWithInfer):
         A Tensor, has the same type as the input.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([1.0, 0.0]), mindspore.float32)

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops.operations import _grad_ops as G

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class NetAtanGrad(nn.Cell):
    def __init__(self):
        super(NetAtanGrad, self).__init__()
        self.atanGrad = G.AtanGrad()

    def construct(self, x, dy):
        return self.atanGrad(x, dy)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atan_grad():
    x = np.array([-0.5, 0, 0.5]).astype('float32')
    dy = np.array([1, 0, -1]).astype('float32')
    atan_grad = NetAtanGrad()
    output = atan_grad(Tensor(x), Tensor(dy))
    print(output)
    expect = dy / (1 + x * x)
    assert np.allclose(output.asnumpy(), expect)
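With x = [-0.5, 0, 0.5] and dy = [1, 0, -1], the expected gradient works out to dy / (1 + x * x) = [1 / 1.25, 0 / 1.0, -1 / 1.25] = [0.8, 0.0, -0.8].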

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class NetAtan(nn.Cell):
    def __init__(self):
        super(NetAtan, self).__init__()
        self.atan = P.Atan()

    def construct(self, x):
        return self.atan(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atan():
    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
    input_x = Tensor(np_array)
    net = NetAtan()
    output = net(input_x)
    print(output)
    expect = np.arctan(np_array)
    assert np.allclose(output.asnumpy(), expect)

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class NetCos(nn.Cell):
    def __init__(self):
        super(NetCos, self).__init__()
        self.cos = P.Cos()

    def construct(self, x):
        return self.cos(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cos():
    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
    input_x = Tensor(np_array)
    net = NetCos()
    output = net(input_x)
    print(output)
    expect = np.cos(np_array)
    assert np.allclose(output.asnumpy(), expect)

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class NetSin(nn.Cell):
    def __init__(self):
        super(NetSin, self).__init__()
        self.sin = P.Sin()

    def construct(self, x):
        return self.sin(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sin():
    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
    input_x = Tensor(np_array)
    net = NetSin()
    output = net(input_x)
    print(output)
    expect = np.sin(np_array)
    assert np.allclose(output.asnumpy(), expect)

View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class NetTan(nn.Cell):
    def __init__(self):
        super(NetTan, self).__init__()
        self.tan = P.Tan()

    def construct(self, x):
        return self.tan(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tan():
    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
    input_x = Tensor(np_array)
    net = NetTan()
    output = net(input_x)
    print(output)
    expect = np.tan(np_array)
    assert np.allclose(output.asnumpy(), expect)