support asin and acos with dtype float on gpu

This commit is contained in:
zhouyuanshen 2020-10-19 18:31:46 +08:00
parent 92b1e7e2ba
commit f49bd92b88
6 changed files with 179 additions and 69 deletions

View File

@ -16,35 +16,35 @@
#include "unary_op_impl.cuh"
template <typename T>
__global__ void ExponentialKernel(T *input, T *output, size_t count) {
__global__ void ExponentialKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = exp(input[i]);
}
return;
}
template <>
__global__ void ExponentialKernel(half *input, half *output, size_t count) {
__global__ void ExponentialKernel(const half *input, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = hexp(input[i]);
}
return;
}
template <typename T>
__global__ void LogarithmKernel(T *input, T *output, size_t count) {
__global__ void LogarithmKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = logf(input[i]);
}
return;
}
template <>
__global__ void LogarithmKernel(half *input, half *output, size_t count) {
__global__ void LogarithmKernel(const half *input, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = hlog(input[i]);
}
return;
}
template <typename T>
__global__ void NegativeKernel(T *input, T *output, size_t count) {
__global__ void NegativeKernel(const T *input, T *output, const size_t count) {
T neg_one = -1;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = neg_one * input[i];
@ -52,7 +52,7 @@ __global__ void NegativeKernel(T *input, T *output, size_t count) {
return;
}
template <typename T>
__global__ void ReciprocalKernel(T *input, T *output, size_t count) {
__global__ void ReciprocalKernel(const T *input, T *output, const size_t count) {
T one = 1.0;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = one / input[i];
@ -60,70 +60,84 @@ __global__ void ReciprocalKernel(T *input, T *output, size_t count) {
return;
}
template <typename T>
__global__ void SquareKernel(T *input, T *output, size_t count) {
__global__ void SquareKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = input[i] * input[i];
}
return;
}
template <typename T>
__global__ void SqrtKernel(T *input, T *output, size_t count) {
__global__ void SqrtKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = sqrt(input[i]);
}
return;
}
template <>
__global__ void SqrtKernel(half *input, half *output, size_t count) {
__global__ void SqrtKernel(const half *input, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = hsqrt(input[i]);
}
return;
}
template <typename T>
__global__ void RsqrtKernel(T *input, T *output, size_t count) {
__global__ void RsqrtKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = rsqrt(input[i]);
}
return;
}
template <>
__global__ void RsqrtKernel(half *input, half *output, size_t count) {
__global__ void RsqrtKernel(const half *input, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = hrsqrt(input[i]);
}
return;
}
template <typename T>
__global__ void SinKernel(T *input, T *output, size_t count) {
__global__ void SinKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = sin(input[i]);
}
return;
}
template <>
__global__ void SinKernel(half *input, half *output, size_t count) {
__global__ void SinKernel(const half *input, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = hsin(input[i]);
}
return;
}
template <typename T>
__global__ void CosKernel(T *input, T *output, size_t count) {
__global__ void AsinKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = asinf(input[i]);
}
return;
}
template <typename T>
__global__ void CosKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = cos(input[i]);
}
return;
}
template <>
__global__ void CosKernel(half *input, half *output, size_t count) {
__global__ void CosKernel(const half *input, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = hcos(input[i]);
}
return;
}
template <typename T>
__global__ void ZeroslikeKernel(T *output, size_t count) {
__global__ void ACosKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = acosf(input[i]);
}
return;
}
template <typename T>
__global__ void ZeroslikeKernel(T *output, const size_t count) {
T zero = 0.0;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = zero;
@ -131,14 +145,14 @@ __global__ void ZeroslikeKernel(T *output, size_t count) {
return;
}
template <typename T>
__global__ void AbsKernel(T *input, T *output, size_t count) {
__global__ void AbsKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = abs(input[i]);
}
return;
}
template <>
__global__ void AbsKernel(half *input, half *output, size_t count) {
__global__ void AbsKernel(const half *input, half *output, const size_t count) {
half zero = 0.0;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = input[i] < zero ? -input[i] : input[i];
@ -146,106 +160,120 @@ __global__ void AbsKernel(half *input, half *output, size_t count) {
return;
}
template <typename T>
__global__ void FloorKernel(T *input, T *output, size_t count) {
__global__ void FloorKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = floor(input[i]);
}
return;
}
template <>
__global__ void FloorKernel(half *input, half *output, size_t count) {
__global__ void FloorKernel(const half *input, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = hfloor(input[i]);
}
return;
}
template <typename T>
void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Logarithm(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Logarithm(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
LogarithmKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Negative(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Negative(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
NegativeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Reciprocal(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
ReciprocalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Square(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
SquareKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Pow(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Pow(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
PowKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Sqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
SqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Sin(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Sin(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
SinKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Cos(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Cos(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
CosKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Asin(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
AsinKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void ACos(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
ACosKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Rsqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
RsqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) {
void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream) {
ZeroslikeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(output, count);
return;
}
template <typename T>
void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
AbsKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Floor(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
FloorKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template void Exponential<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Logarithm<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Negative<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Reciprocal<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Square<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Sqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Sin<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Cos<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Rsqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Zeroslike<float>(float *output, size_t count, cudaStream_t cuda_stream);
template void Abs<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Floor<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
template void Exponential<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Logarithm<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Negative<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Reciprocal<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Square<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Sqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Sin<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Cos<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Rsqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Zeroslike<half>(half *output, size_t count, cudaStream_t cuda_stream);
template void Abs<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Floor<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
template void Exponential<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Logarithm<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Negative<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Reciprocal<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Square<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Sqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Sin<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Cos<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Asin<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void ACos<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Zeroslike<float>(float *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Logarithm<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Negative<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Reciprocal<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Square<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Sqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Sin<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Cos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Asin<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void ACos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Zeroslike<half>(half *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);

View File

@ -19,28 +19,32 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Logarithm(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void Logarithm(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Negative(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void Negative(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void Reciprocal(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void Square(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void Sqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void Rsqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Sin(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void Sin(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Cos(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void Cos(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream);
void Asin(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void ACos(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Floor(T *input, T *output, size_t count, cudaStream_t cuda_stream);
void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_

View File

@ -54,10 +54,14 @@ MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),

View File

@ -38,6 +38,8 @@ enum UnaryOptype {
UNARY_OP_RSQRT,
UNARY_OP_SIN,
UNARY_OP_COS,
UNARY_OP_ASIN,
UNARY_OP_ACOS,
UNARY_OP_ABS,
UNARY_OP_FLOOR,
UNARY_OP_INVALID_TYPE = 255
@ -52,6 +54,8 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY
{"Rsqrt", UNARY_OP_RSQRT},
{"Sin", UNARY_OP_SIN},
{"Cos", UNARY_OP_COS},
{"Asin", UNARY_OP_ASIN},
{"ACos", UNARY_OP_ACOS},
{"Abs", UNARY_OP_ABS},
{"Floor", UNARY_OP_FLOOR}};
template <typename T>
@ -112,6 +116,14 @@ class UnaryOpGpuKernel : public GpuKernel {
Cos(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ASIN: {
Asin(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ACOS: {
ACos(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ZEROSLIKE: {
Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
return true;

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_acos_fp32():
x_np = np.array([0.74, 0.04, 0.30, 0.56]).astype(np.float32)
output_ms = P.ACos()(Tensor(x_np))
output_np = np.arccos(x_np)
assert np.allclose(output_ms.asnumpy(), output_np)

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_asin_fp32():
x_np = np.array([0.74, 0.04, 0.30, 0.56]).astype(np.float32)
output_ms = P.Asin()(Tensor(x_np))
output_np = np.arcsin(x_np)
assert np.allclose(output_ms.asnumpy(), output_np)