forked from mindspore-Ecosystem/mindspore
add GPU operator: abs and floor
This commit is contained in:
parent
7233d650f0
commit
f3f9fc958a
|
@ -103,6 +103,35 @@ __global__ void ZeroslikeKernel(T *output, size_t count) {
|
|||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void AbsKernel(T *input, T *output, size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = abs(input[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <>
|
||||
__global__ void AbsKernel(half *input, half *output, size_t count) {
|
||||
half zero = 0.0;
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = input[i] < zero ? -input[i] : input[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void FloorKernel(T *input, T *output, size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = floor(input[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <>
|
||||
__global__ void FloorKernel(half *input, half *output, size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = hfloor(input[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
|
||||
ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
|
@ -147,6 +176,16 @@ void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) {
|
|||
ZeroslikeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(output, count);
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
|
||||
AbsKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void Floor(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
|
||||
FloorKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
}
|
||||
|
||||
template void Exponential<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Logarithm<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
|
||||
|
@ -156,6 +195,8 @@ template void Square<float>(float *input, float *output, size_t count, cudaStrea
|
|||
template void Sqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Rsqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Zeroslike<float>(float *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Abs<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Floor<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Exponential<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Logarithm<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Negative<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
|
||||
|
@ -164,3 +205,5 @@ template void Square<half>(half *input, half *output, size_t count, cudaStream_t
|
|||
template void Sqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Rsqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Zeroslike<half>(half *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Abs<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Floor<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -34,5 +34,9 @@ template <typename T>
|
|||
void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Abs(T *input, T *output, size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Floor(T *input, T *output, size_t count, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
|
||||
|
|
|
@ -46,5 +46,13 @@ MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
|
|||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,6 +36,8 @@ enum UnaryOptype {
|
|||
UNARY_OP_SQUARE,
|
||||
UNARY_OP_SQRT,
|
||||
UNARY_OP_RSQRT,
|
||||
UNARY_OP_ABS,
|
||||
UNARY_OP_FLOOR,
|
||||
UNARY_OP_INVALID_TYPE = 255
|
||||
};
|
||||
static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP},
|
||||
|
@ -45,7 +47,9 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY
|
|||
{"ZerosLike", UNARY_OP_ZEROSLIKE},
|
||||
{"Square", UNARY_OP_SQUARE},
|
||||
{"Sqrt", UNARY_OP_SQRT},
|
||||
{"Rsqrt", UNARY_OP_RSQRT}};
|
||||
{"Rsqrt", UNARY_OP_RSQRT},
|
||||
{"Abs", UNARY_OP_ABS},
|
||||
{"Floor", UNARY_OP_FLOOR}};
|
||||
template <typename T>
|
||||
class UnaryOpGpuKernel : public GpuKernel {
|
||||
public:
|
||||
|
@ -100,6 +104,14 @@ class UnaryOpGpuKernel : public GpuKernel {
|
|||
Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
case UNARY_OP_ABS: {
|
||||
Abs(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_FLOOR: {
|
||||
Floor(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue