forked from mindspore-Ecosystem/mindspore
Add round op for cpu and gpu
This commit is contained in:
parent
94ed3b89a3
commit
9f62227f99
|
@ -109,6 +109,16 @@ void Rint(const T *in, T *out, size_t size) {
|
|||
CPUKernelUtils::ParallelFor(task, size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Round(const T *in, T *out, size_t size) {
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(nearbyint(in[i]));
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Reciprocal(const T *in, T *out, size_t size) {
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
|
@ -251,6 +261,7 @@ static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::k
|
|||
{prim::kPrimSign->name(), SIGN},
|
||||
{prim::kPrimFloor->name(), FLOOR},
|
||||
{prim::kPrimRint->name(), RINT},
|
||||
{prim::kPrimRound->name(), ROUND},
|
||||
{prim::kPrimReciprocal->name(), RECIPROCAL},
|
||||
{prim::kPrimGeLU->name(), GELU},
|
||||
{prim::kPrimAsin->name(), ASIN},
|
||||
|
@ -317,7 +328,7 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
|
|||
{ATAN, Atan<T>}, {SINH, Sinh<T>},
|
||||
{COSH, Cosh<T>}, {ASINH, Asinh<T>},
|
||||
{ACOSH, Acosh<T>}, {ATANH, Atanh<T>},
|
||||
{RINT, Rint<T>}};
|
||||
{RINT, Rint<T>}, {ROUND, Round<T>}};
|
||||
if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) {
|
||||
kArithmeticOpFuncMap.at(operate_type_)(input, output, lens);
|
||||
} else {
|
||||
|
|
|
@ -67,6 +67,8 @@ MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput
|
|||
ArithmeticSelfCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Round, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCPUKernel);
|
||||
MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
|
|
@ -114,6 +114,7 @@ enum OperateType {
|
|||
ACOSHGRAD,
|
||||
ATAN2,
|
||||
RINT,
|
||||
ROUND,
|
||||
};
|
||||
|
||||
class CPUKernel : public kernel::KernelMod {
|
||||
|
|
|
@ -17,6 +17,13 @@
|
|||
#include "unary_op_impl.cuh"
|
||||
// Element-wise exponential: output[k] = expf(input[k]) for k in [0, count).
// Each thread starts at its global thread index and advances by the total number of
// launched threads (grid-stride loop), so any 1-D launch configuration covers the range.
template <typename T>
__global__ void ExponentialKernel(const T *input, T *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = expf(input[pos]);
    pos += step;
  }
}
|
||||
template <>
|
||||
__global__ void ExponentialKernel(const double *input, double *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = exp(input[i]);
|
||||
}
|
||||
|
@ -32,7 +39,14 @@ __global__ void ExponentialKernel(const half *input, half *output, const size_t
|
|||
// Element-wise exp(x) - 1: output[k] = expm1f(input[k]) for k in [0, count).
// The value is computed in float; T is converted to and from float explicitly so the
// kernel does not depend on an implicit T <-> float conversion being available.
// Each element is written exactly once (the earlier duplicate store was dead).
// Threads walk the range with a grid-stride loop.
template <typename T>
__global__ void Expm1Kernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = static_cast<T>(expm1f(static_cast<float>(input[i])));
  }
  return;
}
|
||||
// Double-precision specialization: output[k] = expm1(input[k]) for k in [0, count),
// using the double math function instead of the float one. Grid-stride loop.
template <>
__global__ void Expm1Kernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = expm1(input[pos]);
    pos += step;
  }
}
|
||||
|
@ -44,6 +58,13 @@ __global__ void LogarithmKernel(const T *input, T *output, const size_t count) {
|
|||
return;
|
||||
}
|
||||
// Double-precision specialization: output[k] = log(input[k]) for k in [0, count).
// Grid-stride loop over the element range.
template <>
__global__ void LogarithmKernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = log(input[pos]);
    pos += step;
  }
}
|
||||
template <>
|
||||
__global__ void LogarithmKernel(const half *input, half *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = hlog(input[i]);
|
||||
|
@ -53,21 +74,42 @@ __global__ void LogarithmKernel(const half *input, half *output, const size_t co
|
|||
// Element-wise log(1 + x): output[k] = log1pf(input[k]) for k in [0, count).
// Computed in float with explicit T <-> float casts, so no implicit conversion of T is
// required. Each element is stored once (the earlier duplicate store was dead).
// Threads walk the range with a grid-stride loop.
template <typename T>
__global__ void Log1pKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = static_cast<T>(log1pf(static_cast<float>(input[i])));
  }
  return;
}
|
||||
// Double-precision specialization: output[k] = log1p(input[k]) for k in [0, count).
// Grid-stride loop over the element range.
template <>
__global__ void Log1pKernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = log1p(input[pos]);
    pos += step;
  }
}
|
||||
// Element-wise error function: output[k] = erff(input[k]) for k in [0, count).
// Computed in float with explicit T <-> float casts, so no implicit conversion of T is
// required. Each element is stored once (the earlier duplicate store was dead).
// Threads walk the range with a grid-stride loop.
template <typename T>
__global__ void ErfKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = static_cast<T>(erff(static_cast<float>(input[i])));
  }
  return;
}
|
||||
// Double-precision specialization: output[k] = erf(input[k]) for k in [0, count).
// Grid-stride loop over the element range.
template <>
__global__ void ErfKernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = erf(input[pos]);
    pos += step;
  }
}
|
||||
// Element-wise complementary error function: output[k] = erfcf(input[k]) for k in [0, count).
// Computed in float with explicit T <-> float casts, so no implicit conversion of T is
// required. Each element is stored once (the earlier duplicate store was dead).
// Threads walk the range with a grid-stride loop.
template <typename T>
__global__ void ErfcKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = static_cast<T>(erfcf(static_cast<float>(input[i])));
  }
  return;
}
|
||||
// Double-precision specialization: output[k] = erfc(input[k]) for k in [0, count).
// Grid-stride loop over the element range.
template <>
__global__ void ErfcKernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = erfc(input[pos]);
    pos += step;
  }
}
|
||||
|
@ -96,6 +138,13 @@ __global__ void SquareKernel(const T *input, T *output, const size_t count) {
|
|||
}
|
||||
// Element-wise square root: output[k] = sqrtf(input[k]) for k in [0, count).
// Each thread starts at its global thread index and advances by the total number of
// launched threads (grid-stride loop).
template <typename T>
__global__ void SqrtKernel(const T *input, T *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = sqrtf(input[pos]);
    pos += step;
  }
}
|
||||
template <>
|
||||
__global__ void SqrtKernel(const double *input, double *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = sqrt(input[i]);
|
||||
}
|
||||
|
@ -110,6 +159,13 @@ __global__ void SqrtKernel(const half *input, half *output, const size_t count)
|
|||
}
|
||||
// Element-wise reciprocal square root: output[k] = rsqrtf(input[k]) for k in [0, count).
// Each thread starts at its global thread index and advances by the total number of
// launched threads (grid-stride loop).
template <typename T>
__global__ void RsqrtKernel(const T *input, T *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = rsqrtf(input[pos]);
    pos += step;
  }
}
|
||||
template <>
|
||||
__global__ void RsqrtKernel(const double *input, double *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = rsqrt(input[i]);
|
||||
}
|
||||
|
@ -124,6 +180,13 @@ __global__ void RsqrtKernel(const half *input, half *output, const size_t count)
|
|||
}
|
||||
// Element-wise sine: output[k] = sinf(input[k]) for k in [0, count).
// Each thread starts at its global thread index and advances by the total number of
// launched threads (grid-stride loop).
template <typename T>
__global__ void SinKernel(const T *input, T *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = sinf(input[pos]);
    pos += step;
  }
}
|
||||
template <>
|
||||
__global__ void SinKernel(const double *input, double *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = sin(input[i]);
|
||||
}
|
||||
|
@ -139,23 +202,40 @@ __global__ void SinKernel(const half *input, half *output, const size_t count) {
|
|||
// Element-wise arcsine: output[k] = asinf(input[k]) for k in [0, count).
// Computed in float with explicit T <-> float casts, so no implicit conversion of T is
// required. Each element is stored once (the earlier duplicate store was dead).
// Threads walk the range with a grid-stride loop.
template <typename T>
__global__ void AsinKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = static_cast<T>(asinf(static_cast<float>(input[i])));
  }
  return;
}
|
||||
// Double-precision specialization: output[k] = asin(input[k]) for k in [0, count).
// Grid-stride loop over the element range.
template <>
__global__ void AsinKernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = asin(input[pos]);
    pos += step;
  }
}
|
||||
// Element-wise inverse hyperbolic sine: output[k] = asinhf(input[k]) for k in [0, count).
// Computed in float with explicit T <-> float casts, so no implicit conversion of T is
// required. Each element is stored once (the earlier duplicate store was dead).
// Threads walk the range with a grid-stride loop.
template <typename T>
__global__ void AsinhKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = static_cast<T>(asinhf(static_cast<float>(input[i])));
  }
  return;
}
|
||||
// Double-precision specialization: output[k] = asinh(input[k]) for k in [0, count).
// Grid-stride loop over the element range.
template <>
__global__ void AsinhKernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = asinh(input[pos]);
    pos += step;
  }
}
|
||||
// Element-wise cosine: output[k] = cosf(input[k]) for k in [0, count).
// Each thread starts at its global thread index and advances by the total number of
// launched threads (grid-stride loop).
template <typename T>
__global__ void CosKernel(const T *input, T *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = cosf(input[pos]);
    pos += step;
  }
}
|
||||
template <>
|
||||
__global__ void CosKernel(const double *input, double *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = cos(input[i]);
|
||||
}
|
||||
|
@ -171,27 +251,42 @@ __global__ void CosKernel(const half *input, half *output, const size_t count) {
|
|||
// Element-wise arccosine: output[k] = acosf(input[k]) for k in [0, count).
// Computed in float with explicit T <-> float casts, so no implicit conversion of T is
// required. Each element is stored once (the earlier duplicate store was dead).
// Threads walk the range with a grid-stride loop.
template <typename T>
__global__ void ACosKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = static_cast<T>(acosf(static_cast<float>(input[i])));
  }
  return;
}
|
||||
// Double-precision specialization: output[k] = acos(input[k]) for k in [0, count).
// Grid-stride loop over the element range.
template <>
__global__ void ACosKernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = acos(input[pos]);
    pos += step;
  }
}
|
||||
// Element-wise inverse hyperbolic cosine: output[k] = acoshf(input[k]) for k in [0, count).
// Computed in float with explicit T <-> float casts, so no implicit conversion of T is
// required. Each element is stored once (the earlier duplicate store was dead).
// Threads walk the range with a grid-stride loop.
template <typename T>
__global__ void AcoshKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = static_cast<T>(acoshf(static_cast<float>(input[i])));
  }
  return;
}
|
||||
// Double-precision specialization: output[k] = acosh(input[k]) for k in [0, count).
// Grid-stride loop over the element range.
template <>
__global__ void AcoshKernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = acosh(input[pos]);
    pos += step;
  }
}
|
||||
// Element-wise arctangent: output[k] = atanf(input[k]) for k in [0, count).
// Computed in float with explicit T <-> float casts, so no implicit conversion of T is
// required. Each element is stored once (the earlier duplicate store was dead).
// Threads walk the range with a grid-stride loop.
template <typename T>
__global__ void AtanKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = static_cast<T>(atanf(static_cast<float>(input[i])));
  }
  return;
}
|
||||
// Double-precision specialization: output[k] = atan(input[k]) for k in [0, count).
// Grid-stride loop over the element range.
template <>
__global__ void AtanKernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = atan(input[pos]);
    pos += step;
  }
}
|
||||
|
@ -212,6 +307,13 @@ __global__ void AbsKernel(const half *input, half *output, const size_t count) {
|
|||
}
|
||||
// Element-wise floor: output[k] = floorf(input[k]) for k in [0, count).
// Each thread starts at its global thread index and advances by the total number of
// launched threads (grid-stride loop).
template <typename T>
__global__ void FloorKernel(const T *input, T *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = floorf(input[pos]);
    pos += step;
  }
}
|
||||
template <>
|
||||
__global__ void FloorKernel(const double *input, double *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = floor(input[i]);
|
||||
}
|
||||
|
@ -226,6 +328,13 @@ __global__ void FloorKernel(const half *input, half *output, const size_t count)
|
|||
}
|
||||
// Element-wise round-to-integer: output[k] = rintf(input[k]) for k in [0, count).
// Each thread starts at its global thread index and advances by the total number of
// launched threads (grid-stride loop).
template <typename T>
__global__ void RintKernel(const T *input, T *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = rintf(input[pos]);
    pos += step;
  }
}
|
||||
template <>
|
||||
__global__ void RintKernel(const double *input, double *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = rint(input[i]);
|
||||
}
|
||||
|
@ -239,6 +348,20 @@ __global__ void RintKernel(const half *input, half *output, const size_t count)
|
|||
return;
|
||||
}
|
||||
// Element-wise rounding: output[k] = nearbyintf(input[k]) for k in [0, count).
// nearbyintf rounds to the nearest integer value, halfway cases going to the even one.
// The value is computed in float, so this generic version is only exact for types whose
// values survive a round trip through float.
// Each thread starts at its global thread index and advances by the total number of
// launched threads (grid-stride loop).
template <typename T>
__global__ void RoundKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = nearbyintf(input[i]);
  }
  return;
}
// int specialization: an integer is already rounded, so the value is copied unchanged.
// Routing it through nearbyintf would convert to float first, and float (24-bit
// significand) cannot represent every 32-bit int, silently altering values with
// magnitude above 2^24.
template <>
__global__ void RoundKernel(const int *input, int *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = input[i];
  }
  return;
}
|
||||
// Double-precision specialization: output[k] = nearbyint(input[k]) for k in [0, count),
// using the double math function instead of the float one. Grid-stride loop.
template <>
__global__ void RoundKernel(const double *input, double *output, const size_t count) {
  const size_t step = blockDim.x * gridDim.x;
  size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
  while (pos < count) {
    output[pos] = nearbyint(input[pos]);
    pos += step;
  }
}
|
||||
template <typename T>
|
||||
void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
|
@ -348,6 +471,11 @@ void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_strea
|
|||
RintKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
}
|
||||
// Host-side launcher for RoundKernel: enqueues the kernel on `cuda_stream` with the
// grid/block sizes given by GET_BLOCKS(count) / GET_THREADS and no dynamic shared memory.
// The launch is not followed by any synchronization or error query here.
template <typename T>
void Round(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  RoundKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
}
|
||||
|
||||
// double
|
||||
template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
@ -371,6 +499,7 @@ template void Rsqrt<double>(const double *input, double *output, const size_t co
|
|||
template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Rint<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Round<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
||||
|
||||
// float
|
||||
|
@ -395,6 +524,7 @@ template void Rsqrt<float>(const float *input, float *output, const size_t count
|
|||
template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Rint<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Round<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
||||
// half
|
||||
template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
@ -418,3 +548,28 @@ template void Rsqrt<half>(const half *input, half *output, const size_t count, c
|
|||
template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Rint<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Round<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
||||
// int32
|
||||
template void Exponential<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Expm1<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Logarithm<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Log1p<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Erf<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Erfc<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Negative<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Reciprocal<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Square<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Sqrt<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Sin<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Cos<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Asin<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void ACos<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Atan<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Asinh<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Acosh<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Rsqrt<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Abs<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Floor<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Rint<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Round<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -60,5 +60,7 @@ template <typename T>
|
|||
void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Round(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
|
||||
|
|
|
@ -114,5 +114,13 @@ MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
|
|||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(Round, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
UnaryOpGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(Round, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
UnaryOpGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(Round, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Round, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -49,6 +49,7 @@ enum UnaryOptype {
|
|||
UNARY_OP_ABS,
|
||||
UNARY_OP_FLOOR,
|
||||
UNARY_OP_RINT,
|
||||
UNARY_OP_ROUND,
|
||||
UNARY_OP_INVALID_TYPE = 255
|
||||
};
|
||||
|
||||
|
@ -63,7 +64,7 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
|
|||
{"ACos", UNARY_OP_ACOS}, {"Atan", UNARY_OP_ATAN},
|
||||
{"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH},
|
||||
{"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR},
|
||||
{"Rint", UNARY_OP_RINT}};
|
||||
{"Rint", UNARY_OP_RINT}, {"Round", UNARY_OP_ROUND}};
|
||||
|
||||
template <typename T>
|
||||
class UnaryOpGpuKernel : public GpuKernel {
|
||||
|
@ -165,6 +166,10 @@ class UnaryOpGpuKernel : public GpuKernel {
|
|||
Rint(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_ROUND: {
|
||||
Round(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
|
||||
}
|
||||
|
|
|
@ -3919,7 +3919,7 @@ class Round(PrimitiveWithInfer):
|
|||
TypeError: If `input_x` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([0.8, 1.5, 2.3, 2.5, -4.5]), mindspore.float32)
|
||||
|
|
|
@ -41,6 +41,15 @@ class FloorNet(nn.Cell):
|
|||
return self.floor(x)
|
||||
|
||||
|
||||
class RoundNet(nn.Cell):
    """Minimal cell whose forward pass applies ``P.Round`` to its input."""

    def __init__(self):
        super().__init__()
        self.round = P.Round()

    def construct(self, x):
        """Return the result of the Round primitive applied to ``x``."""
        rounded = self.round(x)
        return rounded
|
||||
|
||||
|
||||
class ReciprocalNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ReciprocalNet, self).__init__()
|
||||
|
@ -144,6 +153,20 @@ def test_rint():
|
|||
np.testing.assert_almost_equal(output.asnumpy(), expect_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_round():
    """Check RoundNet against ``np.round`` for float16 and float32 inputs.

    The sample values include the halfway cases 1.5, -2.5 and 4.5. The pytest
    marks match the ones used by the neighbouring tests in this file so the
    test is selected by the same marker filters.
    """
    net = RoundNet()
    samples = [0.9920, -0.4077, 0.9734, -1.0362, 1.5, -2.5, 4.5]

    for dtype in (np.float16, np.float32):
        x = np.array(samples).astype(dtype)
        output = net(Tensor(x))
        expect_output = np.round(x).astype(dtype)
        np.testing.assert_almost_equal(output.asnumpy(), expect_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
    """Minimal cell whose forward pass applies ``ops.Round`` to its input."""

    def __init__(self):
        super().__init__()
        self.round = ops.Round()

    def construct(self, x):
        """Return the result of the Round primitive applied to ``x``."""
        rounded = self.round(x)
        return rounded
|
||||
|
||||
|
||||
def generate_testcases(nptype):
    """Compare ``Net`` with ``np.round`` on the GPU target for dtype ``nptype``.

    The same check is run twice, first in graph mode and then in PyNative mode.
    The sample values include the halfway cases 1.5, -2.5 and 4.5.
    """
    samples = [0.9920, -0.4077, 0.9734, -1.0362, 1.5, -2.5, 4.5]

    for exec_mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=exec_mode, device_target="GPU")
        x = np.array(samples).astype(nptype)
        net = Net()
        output = net(Tensor(x))
        expect = np.round(x).astype(nptype)
        np.testing.assert_almost_equal(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sign_float32():
    """Run the Round GPU test cases with float32 inputs.

    NOTE(review): the name says "sign" but the cases exercise Round; the name is
    left unchanged here so existing test selections keep matching.
    """
    generate_testcases(nptype=np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sign_float16():
    """Run the Round GPU test cases with float16 inputs.

    NOTE(review): the name says "sign" but the cases exercise Round; the name is
    left unchanged here so existing test selections keep matching.
    """
    generate_testcases(nptype=np.float16)
|
Loading…
Reference in New Issue