forked from mindspore-Ecosystem/mindspore
!10580 Add new GPU operators: Asinh, AsinhGrad, Acosh, AcoshGrad, Atan and AtanGrad
From: @david-he91 Reviewed-by: @linqingke,@liangchenghui,@liangchenghui Signed-off-by: @liangchenghui,@liangchenghui
This commit is contained in:
commit
ee07285420
|
@ -76,6 +76,33 @@ __global__ void ACosGradKernel(const half *input, const half *dout, half *output
|
|||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void AtanGradKernel(const T *input, const T *dout, T *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
T one = 1;
|
||||
T divisor = one + input[i] * input[i];
|
||||
output[i] = dout[i] / divisor;
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
float inputf = static_cast<float>(input[i]);
|
||||
T coshy = static_cast<T>(coshf(inputf));
|
||||
output[i] = dout[i] / coshy;
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
float inputf = static_cast<float>(input[i]);
|
||||
T sinhy = static_cast<T>(sinhf(inputf));
|
||||
output[i] = dout[i] / sinhy;
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
|
||||
return;
|
||||
|
@ -98,6 +125,24 @@ void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cuda
|
|||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AtanGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
AtanGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
AsinhGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
AcoshGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
|
||||
return;
|
||||
}
|
||||
|
||||
template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||
|
@ -106,6 +151,12 @@ template void AsinGrad<float>(const float *input, const float *dout, float *outp
|
|||
cudaStream_t cuda_stream);
|
||||
template void ACosGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void AtanGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void AsinhGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void AcoshGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
|
||||
|
@ -114,3 +165,9 @@ template void AsinGrad<half>(const half *input, const half *dout, half *output,
|
|||
cudaStream_t cuda_stream);
|
||||
template void ACosGrad<half>(const half *input, const half *dout, half *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void AtanGrad<half>(const half *input, const half *dout, half *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void AsinhGrad<half>(const half *input, const half *dout, half *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void AcoshGrad<half>(const half *input, const half *dout, half *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
|
@ -26,5 +26,12 @@ template <typename T>
|
|||
void AsinGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void AtanGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
|
||||
|
|
|
@ -146,6 +146,15 @@ __global__ void AsinKernel(const T *input, T *output, const size_t count) {
|
|||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void AsinhKernel(const T *input, T *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
float inputf = static_cast<float>(input[i]);
|
||||
T res = static_cast<T>(asinhf(inputf));
|
||||
output[i] = res;
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void CosKernel(const T *input, T *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = cos(input[i]);
|
||||
|
@ -169,6 +178,24 @@ __global__ void ACosKernel(const T *input, T *output, const size_t count) {
|
|||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void AcoshKernel(const T *input, T *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
float inputf = static_cast<float>(input[i]);
|
||||
T res = static_cast<T>(acoshf(inputf));
|
||||
output[i] = res;
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void AtanKernel(const T *input, T *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
float inputf = static_cast<float>(input[i]);
|
||||
T res = static_cast<T>(atanf(inputf));
|
||||
output[i] = res;
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void ZeroslikeKernel(T *output, const size_t count) {
|
||||
T zero = 0.0;
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
|
@ -281,6 +308,21 @@ void ACos(const T *input, T *output, const size_t count, cudaStream_t cuda_strea
|
|||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void Atan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
AtanKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
AsinhKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
AcoshKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void Rsqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
RsqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
|
@ -315,6 +357,9 @@ template void Sin<float>(const float *input, float *output, const size_t count,
|
|||
template void Cos<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Asin<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void ACos<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Atan<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Asinh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Acosh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Zeroslike<float>(float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
@ -333,6 +378,9 @@ template void Sin<half>(const half *input, half *output, const size_t count, cud
|
|||
template void Cos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Asin<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void ACos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Atan<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Asinh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Acosh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Zeroslike<half>(half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -49,6 +49,12 @@ void Asin(const T *input, T *output, const size_t count, cudaStream_t cuda_strea
|
|||
template <typename T>
|
||||
void ACos(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Atan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -74,6 +74,10 @@ MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
|
|||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
|
@ -82,6 +86,14 @@ MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
|
|||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
|
|
|
@ -44,6 +44,9 @@ enum UnaryOptype {
|
|||
UNARY_OP_COS,
|
||||
UNARY_OP_ASIN,
|
||||
UNARY_OP_ACOS,
|
||||
UNARY_OP_ATAN,
|
||||
UNARY_OP_ASINH,
|
||||
UNARY_OP_ACOSH,
|
||||
UNARY_OP_ABS,
|
||||
UNARY_OP_FLOOR,
|
||||
UNARY_OP_INVALID_TYPE = 255
|
||||
|
@ -64,6 +67,9 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY
|
|||
{"Cos", UNARY_OP_COS},
|
||||
{"Asin", UNARY_OP_ASIN},
|
||||
{"ACos", UNARY_OP_ACOS},
|
||||
{"Atan", UNARY_OP_ATAN},
|
||||
{"Asinh", UNARY_OP_ASINH},
|
||||
{"Acosh", UNARY_OP_ACOSH},
|
||||
{"Abs", UNARY_OP_ABS},
|
||||
{"Floor", UNARY_OP_FLOOR}};
|
||||
template <typename T>
|
||||
|
@ -142,6 +148,18 @@ class UnaryOpGpuKernel : public GpuKernel {
|
|||
ACos(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_ATAN: {
|
||||
Atan(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_ASINH: {
|
||||
Asinh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_ACOSH: {
|
||||
Acosh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_ZEROSLIKE: {
|
||||
Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
|
|
|
@ -50,5 +50,29 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
ACosGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryGradOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
AtanGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryGradOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
AtanGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryGradOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
AsinhGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryGradOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
AsinhGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryGradOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
AcoshGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryGradOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
AcoshGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryGradOpGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,12 +32,16 @@ enum UnaryGradOptype {
|
|||
UNARY_OP_RSQRT_GRAD = 1,
|
||||
UNARY_OP_ASIN_GRAD = 2,
|
||||
UNARY_OP_ACOS_GRAD = 3,
|
||||
UNARY_OP_ATAN_GRAD = 4,
|
||||
UNARY_OP_ASINH_GRAD = 5,
|
||||
UNARY_OP_ACOSH_GRAD = 6,
|
||||
UNARY_OP_GRAD_INVALID_TYPE = 255
|
||||
};
|
||||
static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {{"SqrtGrad", UNARY_OP_SQRT_GRAD},
|
||||
{"RsqrtGrad", UNARY_OP_RSQRT_GRAD},
|
||||
{"AsinGrad", UNARY_OP_ASIN_GRAD},
|
||||
{"ACosGrad", UNARY_OP_ACOS_GRAD}};
|
||||
static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {
|
||||
{"SqrtGrad", UNARY_OP_SQRT_GRAD}, {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}, {"AsinGrad", UNARY_OP_ASIN_GRAD},
|
||||
{"ACosGrad", UNARY_OP_ACOS_GRAD}, {"AtanGrad", UNARY_OP_ATAN_GRAD}, {"AsinhGrad", UNARY_OP_ASINH_GRAD},
|
||||
{"AcoshGrad", UNARY_OP_ACOSH_GRAD}};
|
||||
|
||||
template <typename T>
|
||||
class UnaryGradOpGpuKernel : public GpuKernel {
|
||||
public:
|
||||
|
@ -77,6 +81,21 @@ class UnaryGradOpGpuKernel : public GpuKernel {
|
|||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_ATAN_GRAD: {
|
||||
AtanGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_ASINH_GRAD: {
|
||||
AsinhGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_ACOSH_GRAD: {
|
||||
AcoshGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_RSQRT_GRAD: {
|
||||
RsqrtGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
import mindspore.ops.operations._grad_ops as P
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
np.random.seed(1)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_acoshgrad_fp32():
|
||||
y_np = np.random.rand(4, 2).astype(np.float32) * 10
|
||||
dout_np = np.random.rand(4, 2).astype(np.float32) * 10
|
||||
output_ms = P.AcoshGrad()(Tensor(y_np), Tensor(dout_np))
|
||||
output_np = dout_np / np.sinh(y_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_acoshgrad_fp16():
|
||||
y_np = np.random.rand(4, 2).astype(np.float16) * 10
|
||||
dout_np = np.random.rand(4, 2).astype(np.float16) * 10
|
||||
output_ms = P.AcoshGrad()(Tensor(y_np), Tensor(dout_np))
|
||||
output_np = dout_np.astype(np.float32) / np.sinh(y_np).astype(np.float32)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3)
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
np.random.seed(1)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_acosh_fp32():
|
||||
x_np = np.random.rand(4, 2).astype(np.float32) * 10 + 1
|
||||
output_ms = P.Acosh()(Tensor(x_np))
|
||||
output_np = np.arccosh(x_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_acosh_fp16():
|
||||
x_np = np.random.rand(4, 2).astype(np.float16) * 10 + 1
|
||||
output_ms = P.Acosh()(Tensor(x_np))
|
||||
output_np = np.arccosh(x_np.astype(np.float32)).astype(np.float16)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
import mindspore.ops.operations._grad_ops as P
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
np.random.seed(1)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_asinhgrad_fp32():
|
||||
y_np = np.random.rand(4, 2).astype(np.float32) * 10
|
||||
dout_np = np.random.rand(4, 2).astype(np.float32) * 10
|
||||
output_ms = P.AsinhGrad()(Tensor(y_np), Tensor(dout_np))
|
||||
output_np = dout_np / np.cosh(y_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_asinhgrad_fp16():
|
||||
y_np = np.random.rand(4, 2).astype(np.float16) * 10
|
||||
dout_np = np.random.rand(4, 2).astype(np.float16) * 10
|
||||
output_ms = P.AsinhGrad()(Tensor(y_np), Tensor(dout_np))
|
||||
output_np = dout_np.astype(np.float32) / np.cosh(y_np).astype(np.float32)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3)
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
np.random.seed(1)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_asinh_fp32():
|
||||
x_np = np.random.rand(4, 2).astype(np.float32) * 10
|
||||
output_ms = P.Asinh()(Tensor(x_np))
|
||||
output_np = np.arcsinh(x_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_asinh_fp16():
|
||||
x_np = np.random.rand(4, 2).astype(np.float16) * 10
|
||||
output_ms = P.Asinh()(Tensor(x_np))
|
||||
output_np = np.arcsinh(x_np.astype(np.float32)).astype(np.float16)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
import mindspore.ops.operations._grad_ops as P
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
np.random.seed(1)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_atangrad_fp32():
|
||||
x_np = np.random.rand(4, 2).astype(np.float32) * 10
|
||||
dout_np = np.random.rand(4, 2).astype(np.float32) * 10
|
||||
output_ms = P.AtanGrad()(Tensor(x_np), Tensor(dout_np))
|
||||
output_np = dout_np / (1 + np.square(x_np))
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_atangrad_fp16():
|
||||
x_np = np.random.rand(4, 2).astype(np.float16) * 10
|
||||
dout_np = np.random.rand(4, 2).astype(np.float16) * 10
|
||||
output_ms = P.AtanGrad()(Tensor(x_np), Tensor(dout_np))
|
||||
output_np = dout_np.astype(np.float32) / (1 + np.square(x_np.astype(np.float32)))
|
||||
assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3)
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
np.random.seed(1)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_atan_fp32():
|
||||
x_np = np.random.rand(4, 2).astype(np.float32) * 10
|
||||
output_ms = P.Atan()(Tensor(x_np))
|
||||
output_np = np.arctan(x_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_atan_fp16():
|
||||
x_np = np.random.rand(4, 2).astype(np.float16) * 10
|
||||
output_ms = P.Atan()(Tensor(x_np))
|
||||
output_np = np.arctan(x_np.astype(np.float32)).astype(np.float16)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)
|
Loading…
Reference in New Issue