!33968 [assistant][ops]Add New GPU operator Sinh
Merge pull request !33968 from 顾月/Sinh
This commit is contained in: commit 1fb04cc1f0
@@ -39,6 +39,7 @@ enum UnaryOptype {
   UNARY_OP_SQRT,
   UNARY_OP_RSQRT,
   UNARY_OP_SIN,
+  UNARY_OP_SINH,
   UNARY_OP_COS,
   UNARY_OP_TAN,
   UNARY_OP_COSH,
@@ -77,7 +78,7 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
   {"Real", UNARY_OP_REAL}, {"Imag", UNARY_OP_IMAG},
   {"Sign", UNARY_OP_SIGN}, {"Conj", UNARY_OP_CONJ},
   {"Atanh", UNARY_OP_ATANH}, {"Tan", UNARY_OP_TAN},
-};
+  {"Sinh", UNARY_OP_SINH}};

 template <typename T>
 class UnaryHelperGpuKernel : public GpuKernelHelperBase {
@@ -126,8 +127,7 @@ class UnaryHelperGpuKernel : public GpuKernelHelperBase {
     {UNARY_OP_FLOOR, Floor<T>}, {UNARY_OP_CEIL, Ceil<T>},
     {UNARY_OP_RINT, Rint<T>}, {UNARY_OP_ROUND, Round<T>},
     {UNARY_OP_SIGN, Sign<T>}, {UNARY_OP_ATANH, Atanh<T>},
-    {UNARY_OP_TAN, Tan<T>},
-    };
+    {UNARY_OP_TAN, Tan<T>}, {UNARY_OP_SINH, Sinh<T>}};

     auto iter = func_map.find(unary_op_type_);
     if (iter != func_map.end()) {
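
Adding a unary GPU op in this helper touches three pieces: the UnaryOptype enum value, the name-to-enum entry in kUnaryOpTypeMap, and the enum-to-functor entry in func_map. The snippet below is a minimal standalone illustration of that two-level lookup; it is plain host-side C++ with stand-in functions, not MindSpore code.

#include <cmath>
#include <iostream>
#include <map>
#include <string>

enum UnaryOptype { UNARY_OP_SIN, UNARY_OP_SINH, UNARY_OP_INVALID_TYPE };

// Name -> enum, mirroring the role of kUnaryOpTypeMap.
static const std::map<std::string, UnaryOptype> kOpTypeMap = {{"Sin", UNARY_OP_SIN}, {"Sinh", UNARY_OP_SINH}};

double SinHost(double x) { return std::sin(x); }
double SinhHost(double x) { return std::sinh(x); }

int main() {
  // Enum -> function, mirroring the func_map built inside the kernel helper.
  static const std::map<UnaryOptype, double (*)(double)> kFuncMap = {{UNARY_OP_SIN, SinHost},
                                                                     {UNARY_OP_SINH, SinhHost}};
  auto type_it = kOpTypeMap.find("Sinh");
  if (type_it == kOpTypeMap.end()) return 1;
  auto func_it = kFuncMap.find(type_it->second);
  if (func_it == kFuncMap.end()) return 1;
  std::cout << func_it->second(1.0) << std::endl;  // sinh(1.0) ~ 1.1752
  return 0;
}
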
@@ -242,6 +242,27 @@ __global__ void SinKernel(const half *input, half *output, const size_t count) {
   return;
 }
+template <typename T>
+__global__ void SinhKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
+    output[i] = sinhf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void SinhKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
+    output[i] = sinh(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void SinhKernel(const half *input, half *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
+    output[i] = half(0.5) * (hexp(input[i]) - hexp(-input[i]));
+  }
+  return;
+}
 template <typename T>
 __global__ void TanKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = tanf(input[i]);
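
The generic SinhKernel path calls sinhf (the single-precision CUDA math function), the double specialization calls sinh, and the half specialization falls back to the identity sinh(x) = (e^x - e^-x) / 2 via hexp, since the CUDA half math library offers no sinh intrinsic. Below is a standalone sketch of the same grid-stride pattern, checked against std::sinh on the host; it is illustrative only and not part of this commit (SinhKernelSketch is a made-up name).

#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

// Same grid-stride loop shape as the SinhKernel added above (float path only).
__global__ void SinhKernelSketch(const float *input, float *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    output[i] = sinhf(input[i]);
  }
}

int main() {
  const size_t n = 1024;
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (size_t i = 0; i < n; ++i) in[i] = 0.001f * static_cast<float>(i);
  SinhKernelSketch<<<static_cast<unsigned int>((n + 255) / 256), 256>>>(in, out, n);
  cudaDeviceSynchronize();
  double max_err = 0.0;
  for (size_t i = 0; i < n; ++i) {
    max_err = std::fmax(max_err, std::fabs(static_cast<double>(out[i]) - std::sinh(static_cast<double>(in[i]))));
  }
  std::printf("max abs error vs std::sinh: %g\n", max_err);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
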
@@ -718,6 +739,11 @@ void Sin(const T *input, T *output, const size_t count, cudaStream_t cuda_stream
   return;
 }
+template <typename T>
+void Sinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  SinhKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
 template <typename T>
 void Tan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
   TanKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
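
The host-side wrapper only sizes and launches the kernel on the caller's stream; GET_BLOCKS and GET_THREADS are provided by MindSpore's CUDA helpers. The sketch below shows the usual ceiling-division sizing such helpers implement (kThreadsPerBlock and BlocksFor are illustrative names, not the real macros).

#include <cstddef>
#include <cstdio>

constexpr int kThreadsPerBlock = 256;

inline int BlocksFor(std::size_t count) {
  // Ceiling division: enough blocks to cover `count` elements. The kernel's
  // grid-stride loop tolerates launching with fewer or more blocks than this.
  return static_cast<int>((count + kThreadsPerBlock - 1) / kThreadsPerBlock);
}

int main() {
  std::printf("count=1000 -> <<<%d, %d>>>\n", BlocksFor(1000), kThreadsPerBlock);
  return 0;
}
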
@@ -860,6 +886,8 @@ template CUDA_LIB_EXPORT void Sqrt<double>(const double *input, double *output,
                                            cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Sin<double>(const double *input, double *output, const size_t count,
                                           cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<double>(const double *input, double *output, const size_t count,
+                                           cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cos<double>(const double *input, double *output, const size_t count,
                                           cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<double>(const double *input, double *output, const size_t count,

@@ -926,6 +954,8 @@ template CUDA_LIB_EXPORT void Sqrt<float>(const float *input, float *output, con
                                           cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Sin<float>(const float *input, float *output, const size_t count,
                                          cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<float>(const float *input, float *output, const size_t count,
+                                          cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cos<float>(const float *input, float *output, const size_t count,
                                          cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<float>(const float *input, float *output, const size_t count,

@@ -987,6 +1017,7 @@ template CUDA_LIB_EXPORT void Square<half>(const half *input, half *output, cons
                                            cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Sqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Sin<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cosh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);

@@ -1037,6 +1068,7 @@ template CUDA_LIB_EXPORT void Sqrt<char>(const char *input, char *output, const
 template CUDA_LIB_EXPORT void Sin<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cos<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cosh<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Asin<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void ACos<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);

@@ -1092,6 +1124,8 @@ template CUDA_LIB_EXPORT void Cos<unsigned char>(const unsigned char *input, uns
                                                  cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
                                                  cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
+                                                  cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cosh<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
                                                   cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Asin<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,

@@ -1147,6 +1181,7 @@ template CUDA_LIB_EXPORT void Sqrt<int>(const int *input, int *output, const siz
 template CUDA_LIB_EXPORT void Sin<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cos<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cosh<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Asin<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void ACos<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);

@@ -1196,6 +1231,8 @@ template CUDA_LIB_EXPORT void Cos<uint32_t>(const uint32_t *input, uint32_t *out
                                             cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
                                             cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
+                                             cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cosh<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
                                              cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Asin<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,

@@ -1256,6 +1293,8 @@ template CUDA_LIB_EXPORT void Cos<int16_t>(const int16_t *input, int16_t *output
                                            cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<int16_t>(const int16_t *input, int16_t *output, const size_t count,
                                            cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<int16_t>(const int16_t *input, int16_t *output, const size_t count,
+                                            cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cosh<int16_t>(const int16_t *input, int16_t *output, const size_t count,
                                             cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Asin<int16_t>(const int16_t *input, int16_t *output, const size_t count,

@@ -1316,6 +1355,8 @@ template CUDA_LIB_EXPORT void Cos<uint16_t>(const uint16_t *input, uint16_t *out
                                             cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
                                             cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
+                                             cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cosh<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
                                              cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Asin<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,

@@ -1376,6 +1417,8 @@ template CUDA_LIB_EXPORT void Cos<int64_t>(const int64_t *input, int64_t *output
                                            cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<int64_t>(const int64_t *input, int64_t *output, const size_t count,
                                            cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<int64_t>(const int64_t *input, int64_t *output, const size_t count,
+                                            cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cosh<int64_t>(const int64_t *input, int64_t *output, const size_t count,
                                             cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Asin<int64_t>(const int64_t *input, int64_t *output, const size_t count,

@@ -1436,6 +1479,8 @@ template CUDA_LIB_EXPORT void Cos<uint64_t>(const uint64_t *input, uint64_t *out
                                             cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Tan<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
                                             cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Sinh<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
+                                             cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Cosh<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
                                              cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Asin<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
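
The long run of `template CUDA_LIB_EXPORT void Sinh<T>(...)` lines exists because the templated wrapper is defined inside the .cu translation unit: every element type the kernel mod may request has to be explicitly instantiated there so the linker can resolve it. A compressed single-file sketch of the idea follows; it is illustrative only, and the stand-in body just copies its input.

#include <cstddef>

template <typename T>
void Sinh(const T *input, T *output, std::size_t count) {
  for (std::size_t i = 0; i < count; ++i) output[i] = input[i];  // stand-in body
}

// Explicit instantiations, analogous to the CUDA_LIB_EXPORT lines above:
// without them, callers in other translation units would fail to link.
template void Sinh<float>(const float *, float *, std::size_t);
template void Sinh<double>(const double *, double *, std::size_t);

int main() {
  float in[2] = {0.5f, 1.0f}, out[2];
  Sinh<float>(in, out, 2);
  return 0;
}
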
@@ -47,6 +47,8 @@ CUDA_LIB_EXPORT void Rsqrt(const T *input, T *output, const size_t count, cudaSt
 template <typename T>
 CUDA_LIB_EXPORT void Sin(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+CUDA_LIB_EXPORT void Sinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 CUDA_LIB_EXPORT void Tan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 CUDA_LIB_EXPORT void Cos(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
@@ -47,6 +47,7 @@ constexpr auto kRsqrt = "Rsqrt";
 constexpr auto kSign = "Sign";
 constexpr auto kSin = "Sin";
 constexpr auto kTan = "Tan";
+constexpr auto kSinh = "Sinh";
 constexpr auto kSqrt = "Sqrt";
 constexpr auto kSquare = "Square";
 }  // namespace

@@ -169,6 +170,13 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpGpuKernelMod::Una
     &UnaryOpGpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
     &UnaryOpGpuKernelMod::LaunchKernel<half>}}},
+  {kSinh,
+   {{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+     &UnaryOpGpuKernelMod::LaunchKernel<double>},
+    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+     &UnaryOpGpuKernelMod::LaunchKernel<float>},
+    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+     &UnaryOpGpuKernelMod::LaunchKernel<half>}}},
  {kTan,
   {{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
     &UnaryOpGpuKernelMod::LaunchKernel<double>},

@@ -346,8 +354,7 @@ bool UnaryOpGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
     {kAsin, Asin<T>}, {kACos, ACos<T>}, {kAtan, Atan<T>}, {kAsinh, Asinh<T>},
     {kAcosh, Acosh<T>}, {kAbs, Abs<T>}, {kFloor, Floor<T>}, {kCeil, Ceil<T>},
     {kRint, Rint<T>}, {kRound, Round<T>}, {kSign, Sign<T>}, {kAtanh, Atanh<T>},
-    {kTan, Tan<T>},
-    };
+    {kTan, Tan<T>}, {kSinh, Sinh<T>}};
     copy(func_map_normal.begin(), func_map_normal.end(), inserter(func_map, func_map.begin()));
   }

@@ -407,6 +414,8 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Rsqrt,
 MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Sign,
                                  []() { return std::make_shared<UnaryOpGpuKernelMod>(kSign); });
 MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Sin, []() { return std::make_shared<UnaryOpGpuKernelMod>(kSin); });
+MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Sinh,
+                                 []() { return std::make_shared<UnaryOpGpuKernelMod>(kSinh); });
 MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Tan, []() { return std::make_shared<UnaryOpGpuKernelMod>(kTan); });
 MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Sqrt,
                                  []() { return std::make_shared<UnaryOpGpuKernelMod>(kSqrt); });
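
On the kernel-mod side the new op follows the existing pattern: a kSinh name constant, KernelAttr entries mapping float64/float32/float16 inputs to LaunchKernel<double/float/half>, a {kSinh, Sinh<T>} entry in the launch-time function map, and a factory registration that creates the UnaryOpGpuKernelMod for the "Sinh" name. The sketch below illustrates the creator-based registration idea only; Factory and KernelMod are stand-ins, not the MS_KERNEL_FACTORY_REG_BY_CREATOR machinery.

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

struct KernelMod {
  explicit KernelMod(std::string op) : op_(std::move(op)) {}
  std::string op_;
};

// Op name -> creator lambda; each creator builds the shared kernel-mod object.
std::map<std::string, std::function<std::shared_ptr<KernelMod>()>> &Factory() {
  static std::map<std::string, std::function<std::shared_ptr<KernelMod>()>> factory;
  return factory;
}

int main() {
  Factory()["Sinh"] = []() { return std::make_shared<KernelMod>("Sinh"); };
  auto mod = Factory()["Sinh"]();
  return mod->op_ == "Sinh" ? 0 : 1;
}
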
@@ -1257,7 +1257,7 @@ def sinh(x):
         TypeError: If `x` is not a Tensor.

     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> x = Tensor(np.array([0.62, 0.28, 0.43, 0.62]), mindspore.float32)
@@ -3350,7 +3350,7 @@ class Sinh(Primitive):
     Refer to :func:`mindspore.ops.sinh` for more detail.

     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> sinh = ops.Sinh()
@@ -0,0 +1,110 @@ (new test file)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.ops import composite as C
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")


class NetSinh(nn.Cell):
    def __init__(self):
        super(NetSinh, self).__init__()
        self.sinh = P.Sinh()

    def construct(self, x):
        return self.sinh(x)


class SinhGradNet(nn.Cell):
    def __init__(self, network):
        super(SinhGradNet, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, x, grad_np):
        grad_out = self.grad(self.network)(x, grad_np)
        return grad_out


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sinh_fp16():
    """
    Feature: Sinh
    Description: test cases for Sinh of float16
    Expectation: the results are as expected
    """
    x_np = np.array([0.5, 1, 3, 3.2]).astype(np.float16)
    input_x = Tensor(x_np)
    net = NetSinh()
    output_ms = net(input_x)
    grad_np = np.array([0.5, 1, 3, 3.2]).astype(np.float16)
    grad_net = SinhGradNet(net)
    output_grad_ms = grad_net(Tensor(x_np), Tensor(grad_np))
    expect_output = np.array([0.521, 1.176, 10.016, 12.234]).astype(np.float16)
    expect_grad_output = np.array([0.5635, 1.543, 30.19, 39.28]).astype(np.float16)
    assert np.allclose(output_ms.asnumpy(), expect_output)
    assert np.allclose(output_grad_ms[0].asnumpy(), expect_grad_output)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sinh_fp32():
    """
    Feature: Sinh
    Description: test cases for Sinh of float32
    Expectation: the results are as expected
    """
    x_np = np.array([1, 2, 3, 4]).astype(np.float32)
    input_x = Tensor(x_np)
    net = NetSinh()
    output_ms = net(input_x)
    grad_np = np.array([1, 2, 3, 4]).astype(np.float32)
    grad_net = SinhGradNet(net)
    output_grad_ms = grad_net(Tensor(x_np), Tensor(grad_np))
    expect_output = np.array([1.1752012, 3.6268604, 10.017875, 27.289917]).astype(np.float32)
    expect_grad_output = np.array([1.5430806, 7.5243917, 30.202988, 109.232925]).astype(np.float32)
    assert np.allclose(output_ms.asnumpy(), expect_output)
    assert np.allclose(output_grad_ms[0].asnumpy(), expect_grad_output)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sinh_fp64():
    """
    Feature: Sinh
    Description: test cases for Sinh of float64
    Expectation: the results are as expected
    """
    x_np = np.array([0.2, 0.9, 2.4, 8.8]).astype(np.float64)
    input_x = Tensor(x_np)
    net = NetSinh()
    output_ms = net(input_x)
    grad_np = np.array([0.2, 0.9, 2.4, 8.8]).astype(np.float64)
    grad_net = SinhGradNet(net)
    output_grad_ms = grad_net(Tensor(x_np), Tensor(grad_np))
    expect_output = np.array([2.01336003e-01, 1.02651673e+00, 5.46622921e+00, 3.31712193e+03]).astype(np.float64)
    expect_grad_output = np.array([2.04013351e-01, 1.28977775e+00, 1.33366732e+01, 2.91906743e+04]).astype(np.float64)
    assert np.allclose(output_ms.asnumpy(), expect_output)
    assert np.allclose(output_grad_ms[0].asnumpy(), expect_grad_output)
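
The expected gradient values in these tests follow d/dx sinh(x) = cosh(x), scaled by the sens tensor passed through GradOperation. A quick standalone check of the float32 case (not part of the commit):

#include <cmath>
#include <cstdio>

int main() {
  // Inputs and sens values from test_sinh_fp32.
  const double x[4] = {1, 2, 3, 4};
  const double sens[4] = {1, 2, 3, 4};
  for (int i = 0; i < 4; ++i) {
    // Forward: sinh(x) matches expect_output; backward: cosh(x) * sens matches expect_grad_output.
    std::printf("sinh(%g) = %.7f, grad = %.7f\n", x[i], std::sinh(x[i]), std::cosh(x[i]) * sens[i]);
  }
  return 0;
}
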