!33968 [assistant][ops] Add new GPU operator Sinh

Merge pull request !33968 from 顾月/Sinh
i-robot 2022-06-16 22:17:50 +00:00 committed by Gitee
commit 1fb04cc1f0
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 173 additions and 7 deletions

View File

@@ -39,6 +39,7 @@ enum UnaryOptype {
UNARY_OP_SQRT,
UNARY_OP_RSQRT,
UNARY_OP_SIN,
UNARY_OP_SINH,
UNARY_OP_COS,
UNARY_OP_TAN,
UNARY_OP_COSH,
@@ -77,7 +78,7 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
{"Real", UNARY_OP_REAL}, {"Imag", UNARY_OP_IMAG},
{"Sign", UNARY_OP_SIGN}, {"Conj", UNARY_OP_CONJ},
{"Atanh", UNARY_OP_ATANH}, {"Tan", UNARY_OP_TAN},
-};
+{"Sinh", UNARY_OP_SINH}};
template <typename T>
class UnaryHelperGpuKernel : public GpuKernelHelperBase {
@@ -126,8 +127,7 @@ class UnaryHelperGpuKernel : public GpuKernelHelperBase {
{UNARY_OP_FLOOR, Floor<T>}, {UNARY_OP_CEIL, Ceil<T>},
{UNARY_OP_RINT, Rint<T>}, {UNARY_OP_ROUND, Round<T>},
{UNARY_OP_SIGN, Sign<T>}, {UNARY_OP_ATANH, Atanh<T>},
-{UNARY_OP_TAN, Tan<T>},
-};
+{UNARY_OP_TAN, Tan<T>}, {UNARY_OP_SINH, Sinh<T>}};
auto iter = func_map.find(unary_op_type_);
if (iter != func_map.end()) {
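Dispatch here is two-staged: the op-name string resolves to a UnaryOptype value through kUnaryOpTypeMap, and that value selects the typed device functor through func_map. A rough Python model of the same lookup chain (an illustrative sketch, not part of the change; the names are hypothetical):

import numpy as np

# Hypothetical model of the two-level dispatch: name -> op type -> function.
UNARY_OP_TYPE_MAP = {"Sin": "UNARY_OP_SIN", "Sinh": "UNARY_OP_SINH"}
FUNC_MAP = {"UNARY_OP_SIN": np.sin, "UNARY_OP_SINH": np.sinh}

def launch(op_name, x):
    op_type = UNARY_OP_TYPE_MAP[op_name]  # stage 1: string -> op type
    return FUNC_MAP[op_type](x)           # stage 2: op type -> function

print(launch("Sinh", np.array([0.5, 1.0], dtype=np.float32)))  # [0.5211 1.1752]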

View File

@@ -242,6 +242,27 @@ __global__ void SinKernel(const half *input, half *output, const size_t count) {
  return;
}
template <typename T>
__global__ void SinhKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    output[i] = sinhf(input[i]);
  }
  return;
}
template <>
__global__ void SinhKernel(const double *input, double *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    output[i] = sinh(input[i]);
  }
  return;
}
template <>
__global__ void SinhKernel(const half *input, half *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    output[i] = half(0.5) * (hexp(input[i]) - hexp(-input[i]));
  }
  return;
}
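The float and double kernels call sinhf/sinh directly, while half has no dedicated sinh device intrinsic, so the half specialization expands the identity sinh(x) = (exp(x) - exp(-x)) / 2 with hexp. The identity is easy to sanity-check against NumPy at float16 (a quick sketch, not part of the change):

import numpy as np

x = np.array([0.5, 1, 3, 3.2], dtype=np.float16)
# Same formula as the half kernel, evaluated in float16.
via_exp = np.float16(0.5) * (np.exp(x) - np.exp(-x))
ref = np.sinh(x.astype(np.float32))  # float32 reference
print(np.allclose(via_exp.astype(np.float32), ref, rtol=1e-2))  # True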
template <typename T>
__global__ void TanKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = tanf(input[i]);
@@ -718,6 +739,11 @@ void Sin(const T *input, T *output, const size_t count, cudaStream_t cuda_stream
  return;
}
template <typename T>
void Sinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  SinhKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
}
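The host-side launcher mirrors the other unary ops: GET_BLOCKS and GET_THREADS size the grid, and the kernel's grid-stride loop has each thread visit indices start, start + stride, start + 2*stride, ... with stride = blockDim.x * gridDim.x, so any count is covered regardless of launch size. A small Python model of the index pattern (illustrative only; the helper is hypothetical):

# Which indices one CUDA thread visits in the grid-stride loop.
def thread_indices(block_id, thread_id, block_dim, grid_dim, count):
    start = block_id * block_dim + thread_id
    stride = block_dim * grid_dim
    return list(range(start, count, stride))

print(thread_indices(0, 0, 4, 2, 10))  # [0, 8]
print(thread_indices(1, 3, 4, 2, 10))  # [7]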
template <typename T>
void Tan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  TanKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
@@ -860,6 +886,8 @@ template CUDA_LIB_EXPORT void Sqrt<double>(const double *input, double *output,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sin<double>(const double *input, double *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<double>(const double *input, double *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cos<double>(const double *input, double *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<double>(const double *input, double *output, const size_t count,
@@ -926,6 +954,8 @@ template CUDA_LIB_EXPORT void Sqrt<float>(const float *input, float *output, con
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sin<float>(const float *input, float *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<float>(const float *input, float *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cos<float>(const float *input, float *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<float>(const float *input, float *output, const size_t count,
@@ -987,6 +1017,7 @@ template CUDA_LIB_EXPORT void Square<half>(const half *input, half *output, cons
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sin<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cosh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
@@ -1037,6 +1068,7 @@ template CUDA_LIB_EXPORT void Sqrt<char>(const char *input, char *output, const
template CUDA_LIB_EXPORT void Sin<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cos<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cosh<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Asin<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ACos<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
@@ -1092,6 +1124,8 @@ template CUDA_LIB_EXPORT void Cos<unsigned char>(const unsigned char *input, uns
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cosh<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Asin<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
@@ -1147,6 +1181,7 @@ template CUDA_LIB_EXPORT void Sqrt<int>(const int *input, int *output, const siz
template CUDA_LIB_EXPORT void Sin<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cos<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cosh<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Asin<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ACos<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
@@ -1196,6 +1231,8 @@ template CUDA_LIB_EXPORT void Cos<uint32_t>(const uint32_t *input, uint32_t *out
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cosh<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Asin<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
@@ -1256,6 +1293,8 @@ template CUDA_LIB_EXPORT void Cos<int16_t>(const int16_t *input, int16_t *output
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<int16_t>(const int16_t *input, int16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<int16_t>(const int16_t *input, int16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cosh<int16_t>(const int16_t *input, int16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Asin<int16_t>(const int16_t *input, int16_t *output, const size_t count,
@@ -1316,6 +1355,8 @@ template CUDA_LIB_EXPORT void Cos<uint16_t>(const uint16_t *input, uint16_t *out
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cosh<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Asin<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
@@ -1376,6 +1417,8 @@ template CUDA_LIB_EXPORT void Cos<int64_t>(const int64_t *input, int64_t *output
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<int64_t>(const int64_t *input, int64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<int64_t>(const int64_t *input, int64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cosh<int64_t>(const int64_t *input, int64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Asin<int64_t>(const int64_t *input, int64_t *output, const size_t count,
@@ -1436,6 +1479,8 @@ template CUDA_LIB_EXPORT void Cos<uint64_t>(const uint64_t *input, uint64_t *out
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Tan<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sinh<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Cosh<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Asin<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,

View File

@@ -47,6 +47,8 @@ CUDA_LIB_EXPORT void Rsqrt(const T *input, T *output, const size_t count, cudaSt
template <typename T>
CUDA_LIB_EXPORT void Sin(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void Sinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void Tan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void Cos(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);

View File

@@ -47,6 +47,7 @@ constexpr auto kRsqrt = "Rsqrt";
constexpr auto kSign = "Sign";
constexpr auto kSin = "Sin";
constexpr auto kTan = "Tan";
constexpr auto kSinh = "Sinh";
constexpr auto kSqrt = "Sqrt";
constexpr auto kSquare = "Square";
} // namespace
@@ -169,6 +170,13 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpGpuKernelMod::Una
&UnaryOpGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&UnaryOpGpuKernelMod::LaunchKernel<half>}}},
{kSinh,
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&UnaryOpGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&UnaryOpGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&UnaryOpGpuKernelMod::LaunchKernel<half>}}},
{kTan,
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&UnaryOpGpuKernelMod::LaunchKernel<double>},
@@ -346,8 +354,7 @@ bool UnaryOpGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
{kAsin, Asin<T>}, {kACos, ACos<T>}, {kAtan, Atan<T>}, {kAsinh, Asinh<T>},
{kAcosh, Acosh<T>}, {kAbs, Abs<T>}, {kFloor, Floor<T>}, {kCeil, Ceil<T>},
{kRint, Rint<T>}, {kRound, Round<T>}, {kSign, Sign<T>}, {kAtanh, Atanh<T>},
-{kTan, Tan<T>},
-};
+{kTan, Tan<T>}, {kSinh, Sinh<T>}};
copy(func_map_normal.begin(), func_map_normal.end(), inserter(func_map, func_map.begin()));
}
@@ -407,6 +414,8 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Rsqrt,
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Sign,
[]() { return std::make_shared<UnaryOpGpuKernelMod>(kSign); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Sin, []() { return std::make_shared<UnaryOpGpuKernelMod>(kSin); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Sinh,
[]() { return std::make_shared<UnaryOpGpuKernelMod>(kSinh); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Tan, []() { return std::make_shared<UnaryOpGpuKernelMod>(kTan); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Sqrt,
[]() { return std::make_shared<UnaryOpGpuKernelMod>(kSqrt); });
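With the kernel-attr entries and factory registration above, Sinh becomes selectable on GPU for float16, float32, and float64 inputs. A minimal usage sketch (assuming a GPU-enabled MindSpore build; the input values are taken from the docstring example below):

import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor, ops

context.set_context(device_target="GPU")

x = Tensor(np.array([0.62, 0.28, 0.43, 0.62]), mindspore.float32)
sinh = ops.Sinh()
print(sinh(x))  # ~ [0.66049 0.28367 0.44337 0.66049]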

View File

@@ -1257,7 +1257,7 @@ def sinh(x):
TypeError: If `x` is not a Tensor.
Supported Platforms:
-``Ascend`` ``CPU``
+``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([0.62, 0.28, 0.43, 0.62]), mindspore.float32)

View File

@@ -3350,7 +3350,7 @@ class Sinh(Primitive):
Refer to :func:`mindspore.ops.sinh` for more detail.
Supported Platforms:
-``Ascend`` ``CPU``
+``Ascend`` ``GPU`` ``CPU``
Examples:
>>> sinh = ops.Sinh()

View File

@@ -0,0 +1,110 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")


class NetSinh(nn.Cell):
    def __init__(self):
        super(NetSinh, self).__init__()
        self.sinh = P.Sinh()

    def construct(self, x):
        return self.sinh(x)


class SinhGradNet(nn.Cell):
    def __init__(self, network):
        super(SinhGradNet, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, x, grad_np):
        grad_out = self.grad(self.network)(x, grad_np)
        return grad_out


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sinh_fp16():
    """
    Feature: Sinh
    Description: test cases for Sinh of float16
    Expectation: the results are as expected
    """
    x_np = np.array([0.5, 1, 3, 3.2]).astype(np.float16)
    input_x = Tensor(x_np)
    net = NetSinh()
    output_ms = net(input_x)
    grad_np = np.array([0.5, 1, 3, 3.2]).astype(np.float16)
    grad_net = SinhGradNet(net)
    output_grad_ms = grad_net(Tensor(x_np), Tensor(grad_np))
    expect_output = np.array([0.521, 1.176, 10.016, 12.234]).astype(np.float16)
    expect_grad_output = np.array([0.5635, 1.543, 30.19, 39.28]).astype(np.float16)
    assert np.allclose(output_ms.asnumpy(), expect_output)
    assert np.allclose(output_grad_ms[0].asnumpy(), expect_grad_output)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sinh_fp32():
    """
    Feature: Sinh
    Description: test cases for Sinh of float32
    Expectation: the results are as expected
    """
    x_np = np.array([1, 2, 3, 4]).astype(np.float32)
    input_x = Tensor(x_np)
    net = NetSinh()
    output_ms = net(input_x)
    grad_np = np.array([1, 2, 3, 4]).astype(np.float32)
    grad_net = SinhGradNet(net)
    output_grad_ms = grad_net(Tensor(x_np), Tensor(grad_np))
    expect_output = np.array([1.1752012, 3.6268604, 10.017875, 27.289917]).astype(np.float32)
    expect_grad_output = np.array([1.5430806, 7.5243917, 30.202988, 109.232925]).astype(np.float32)
    assert np.allclose(output_ms.asnumpy(), expect_output)
    assert np.allclose(output_grad_ms[0].asnumpy(), expect_grad_output)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sinh_fp64():
    """
    Feature: Sinh
    Description: test cases for Sinh of float64
    Expectation: the results are as expected
    """
    x_np = np.array([0.2, 0.9, 2.4, 8.8]).astype(np.float64)
    input_x = Tensor(x_np)
    net = NetSinh()
    output_ms = net(input_x)
    grad_np = np.array([0.2, 0.9, 2.4, 8.8]).astype(np.float64)
    grad_net = SinhGradNet(net)
    output_grad_ms = grad_net(Tensor(x_np), Tensor(grad_np))
    expect_output = np.array([2.01336003e-01, 1.02651673e+00, 5.46622921e+00, 3.31712193e+03]).astype(np.float64)
    expect_grad_output = np.array([2.04013351e-01, 1.28977775e+00, 1.33366732e+01, 2.91906743e+04]).astype(np.float64)
    assert np.allclose(output_ms.asnumpy(), expect_output)
    assert np.allclose(output_grad_ms[0].asnumpy(), expect_grad_output)
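The expected constants in these tests follow directly from NumPy: the forward reference is np.sinh(x), and since d/dx sinh(x) = cosh(x), the expected gradient for a sens parameter g is cosh(x) * g. For instance, the float32 case can be reproduced with (a sketch, not part of the change):

import numpy as np

x = np.array([1, 2, 3, 4]).astype(np.float32)
sens = np.array([1, 2, 3, 4]).astype(np.float32)

print(np.sinh(x))         # ~ [ 1.1752012  3.6268604 10.017875  27.289917]
print(np.cosh(x) * sens)  # ~ [ 1.5430806  7.524392  30.202988 109.23292 ]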