add ceil kernel

zhangxuetong 2022-04-29 17:21:51 +08:00
parent 9226c57755
commit e727deddce
8 changed files with 189 additions and 7 deletions

View File

@@ -21,6 +21,7 @@
#include <utility>
#include <algorithm>
#include <memory>
#include "nnacl/fp32/arithmetic_self_fp32.h"
namespace mindspore {
namespace kernel {
@@ -28,6 +29,7 @@ namespace {
constexpr auto kReal = "Real";
constexpr auto kImag = "Imag";
constexpr auto kConj = "Conj";
constexpr auto kCeil = "Ceil";
template <typename T, typename S>
class UnaryOpCpuKernelFunc : public CpuKernelFunc {
@@ -45,6 +47,12 @@ class UnaryOpCpuKernelFunc : public CpuKernelFunc {
std::string kernel_name_;
};
void Ceil(const float *input, float *output, size_t start, size_t end) {
auto input_s = input + start;
auto output_s = output + start;
ElementCeil(input_s, output_s, (end - start));
}
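For reference, `ElementCeil` is the nnacl routine brought in by the `arithmetic_self_fp32.h` include above, and `[start, end)` is the slice of the flattened tensor handed to one worker thread by the parallel launcher. A minimal sketch of the semantics the call is assumed to have; the actual nnacl routine may be SIMD-vectorized and returns a status code:
#include <cmath>
// Hypothetical stand-in for nnacl's ElementCeil: rounds each of `size`
// consecutive floats up to the nearest integer.
int ElementCeilSketch(const float *input, float *output, int size) {
  for (int i = 0; i < size; ++i) {
    output[i] = std::ceil(input[i]);
  }
  return 0;  // the real routine is assumed to return an nnacl status code
}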
template <typename T, typename S>
void Real(const T *input, S *output, size_t start, size_t end) {
if (!std::is_same<S, complex64>::value && !std::is_same<S, complex128>::value) {
@@ -81,6 +89,16 @@ void Conj(const T *input, S *output, size_t start, size_t end) {
template <typename T, typename S>
void UnaryOpCpuKernelFunc<T, S>::GetUnaryOpFunc() {
// Only float input is supported here (currently just Ceil)
if constexpr (std::is_same<T, float>::value) {
static std::map<std::string, UnaryOpFunc> kFloatSupportedMap = {{prim::kPrimCeil->name(), &Ceil}};
auto iter = kFloatSupportedMap.find(kernel_name_);
if (iter != kFloatSupportedMap.end()) {
unary_op_func_ = iter->second;
return;
}
}
if constexpr (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value) {
static std::map<std::string, UnaryOpFunc> kComplexSupportedTypeMap = {{prim::kPrimReal->name(), &Real<T, S>},
{prim::kPrimImag->name(), &Imag<T, S>},
@@ -147,6 +165,9 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpCpuFuncCreator>>>
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeUnaryFunc<double, double>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeUnaryFunc<bool, bool>}}},
{kCeil,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeUnaryFunc<float, float>}}},
{kImag,
{{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
SpecializeUnaryFunc<complex128, double>},
@@ -229,5 +250,7 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Imag,
[]() { return std::make_shared<UnaryOpCpuKernelMod>(kImag); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Conj,
[]() { return std::make_shared<UnaryOpCpuKernelMod>(kConj); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Ceil,
[]() { return std::make_shared<UnaryOpCpuKernelMod>(kCeil); });
} // namespace kernel
} // namespace mindspore

View File

@@ -48,6 +48,7 @@ enum UnaryOptype {
UNARY_OP_ACOSH,
UNARY_OP_ABS,
UNARY_OP_FLOOR,
UNARY_OP_CEIL,
UNARY_OP_RINT,
UNARY_OP_ROUND,
UNARY_OP_SIGN,
@@ -69,10 +70,10 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
{"Asin", UNARY_OP_ASIN}, {"ACos", UNARY_OP_ACOS},
{"Atan", UNARY_OP_ATAN}, {"Asinh", UNARY_OP_ASINH},
{"Acosh", UNARY_OP_ACOSH}, {"Abs", UNARY_OP_ABS},
{"Floor", UNARY_OP_FLOOR}, {"Rint", UNARY_OP_RINT},
{"Round", UNARY_OP_ROUND}, {"Real", UNARY_OP_REAL},
{"Imag", UNARY_OP_IMAG}, {"Sign", UNARY_OP_SIGN},
{"Conj", UNARY_OP_CONJ}};
{"Floor", UNARY_OP_FLOOR}, {"Ceil", UNARY_OP_CEIL},
{"Rint", UNARY_OP_RINT}, {"Round", UNARY_OP_ROUND},
{"Real", UNARY_OP_REAL}, {"Imag", UNARY_OP_IMAG},
{"Sign", UNARY_OP_SIGN}, {"Conj", UNARY_OP_CONJ}};
template <typename T>
class UnaryHelperGpuKernel : public GpuKernelHelperBase {
@@ -118,8 +119,9 @@ class UnaryHelperGpuKernel : public GpuKernelHelperBase {
{UNARY_OP_ASIN, Asin<T>}, {UNARY_OP_ACOS, ACos<T>},
{UNARY_OP_ATAN, Atan<T>}, {UNARY_OP_ASINH, Asinh<T>},
{UNARY_OP_ACOSH, Acosh<T>}, {UNARY_OP_ABS, Abs<T>},
- {UNARY_OP_FLOOR, Floor<T>}, {UNARY_OP_RINT, Rint<T>},
- {UNARY_OP_ROUND, Round<T>}, {UNARY_OP_SIGN, Sign<T>}};
+ {UNARY_OP_FLOOR, Floor<T>}, {UNARY_OP_CEIL, Ceil<T>},
+ {UNARY_OP_RINT, Rint<T>}, {UNARY_OP_ROUND, Round<T>},
+ {UNARY_OP_SIGN, Sign<T>}};
auto iter = func_map.find(unary_op_type_);
if (iter != func_map.end()) {

View File

@@ -495,6 +495,28 @@ __global__ void FloorKernel(const half *input, half *output, const size_t count)
}
return;
}
template <typename T>
__global__ void CeilKernel(const T *input, T *output, const size_t count) {
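// Grid-stride loop: each thread starts at its global index and advances by the
// total number of launched threads, so any grid size covers all `count` elements.
// The generic template uses single-precision ceilf(); double and half inputs are
// handled by the specializations below.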
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = ceilf(input[i]);
}
return;
}
template <>
__global__ void CeilKernel(const double *input, double *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = ceil(input[i]);
}
return;
}
template <>
__global__ void CeilKernel(const half *input, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = hceil(input[i]);
}
return;
}
template <typename T>
__global__ void RintKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -753,6 +775,11 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
return;
}
template <typename T>
void Ceil(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
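// GET_BLOCKS/GET_THREADS are assumed to be the shared CUDA helpers that pick a
// 1-D grid/block configuration for `count` elements, the same launch pattern
// used by Floor, Rint, and the other unary wrappers in this file.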
CeilKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
RintKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
@@ -815,6 +842,8 @@ template CUDA_LIB_EXPORT void Abs<double>(const double *input, double *output, c
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<double>(const double *input, double *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<double>(const double *input, double *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<double>(const double *input, double *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<double>(const double *input, double *output, const size_t count,
@@ -875,6 +904,8 @@ template CUDA_LIB_EXPORT void Abs<float>(const float *input, float *output, cons
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<float>(const float *input, float *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<float>(const float *input, float *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<float>(const float *input, float *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<float>(const float *input, float *output, const size_t count,
@@ -924,6 +955,7 @@ template CUDA_LIB_EXPORT void Rsqrt<half>(const half *input, half *output, const
template CUDA_LIB_EXPORT void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<half>(const half *input, half *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<half>(const half *input, half *output, const size_t count,
cudaStream_t cuda_stream);
@@ -968,6 +1000,7 @@ template CUDA_LIB_EXPORT void Rsqrt<char>(const char *input, char *output, const
template CUDA_LIB_EXPORT void Abs<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<char>(const char *input, char *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<char>(const char *input, char *output, const size_t count,
cudaStream_t cuda_stream);
@@ -1023,6 +1056,8 @@ template CUDA_LIB_EXPORT void Abs<unsigned char>(const unsigned char *input, uns
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<unsigned char>(const unsigned char *input, unsigned char *output,
const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<unsigned char>(const unsigned char *input, unsigned char *output,
@@ -1064,6 +1099,7 @@ template CUDA_LIB_EXPORT void Acosh<int>(const int *input, int *output, const si
template CUDA_LIB_EXPORT void Rsqrt<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Abs<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Sign<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
@@ -1118,6 +1154,8 @@ template CUDA_LIB_EXPORT void Abs<uint32_t>(const uint32_t *input, uint32_t *out
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
@@ -1172,6 +1210,8 @@ template CUDA_LIB_EXPORT void Abs<int16_t>(const int16_t *input, int16_t *output
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<int16_t>(const int16_t *input, int16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<int16_t>(const int16_t *input, int16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<int16_t>(const int16_t *input, int16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<int16_t>(const int16_t *input, int16_t *output, const size_t count,
@@ -1226,6 +1266,8 @@ template CUDA_LIB_EXPORT void Abs<uint16_t>(const uint16_t *input, uint16_t *out
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
@@ -1280,6 +1322,8 @@ template CUDA_LIB_EXPORT void Abs<int64_t>(const int64_t *input, int64_t *output
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<int64_t>(const int64_t *input, int64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<int64_t>(const int64_t *input, int64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<int64_t>(const int64_t *input, int64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<int64_t>(const int64_t *input, int64_t *output, const size_t count,
@@ -1334,6 +1378,8 @@ template CUDA_LIB_EXPORT void Abs<uint64_t>(const uint64_t *input, uint64_t *out
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Floor<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Ceil<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Rint<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Round<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,

View File

@@ -65,6 +65,8 @@ CUDA_LIB_EXPORT void Abs(const T *input, T *output, const size_t count, cudaStre
template <typename T>
CUDA_LIB_EXPORT void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void Ceil(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void Round(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);

View File

@@ -36,6 +36,7 @@ constexpr auto kErfc = "Erfc";
constexpr auto kExp = "Exp";
constexpr auto kExpm1 = "Expm1";
constexpr auto kFloor = "Floor";
constexpr auto kCeil = "Ceil";
constexpr auto kLog = "Log";
constexpr auto kLog1p = "Log1p";
constexpr auto kNeg = "Neg";
@@ -147,6 +148,9 @@ const std::map<std::string, std::vector<std::pair<KernelAttr, UnaryPtrCreatorFun
{kFloor,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateUnaryKernelPtr<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateUnaryKernelPtr<half>}}},
{kCeil,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateUnaryKernelPtr<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateUnaryKernelPtr<half>}}},
{kRint,
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateUnaryKernelPtr<double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateUnaryKernelPtr<float>},
@@ -226,6 +230,8 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Expm1,
[]() { return std::make_shared<UnaryOpGpuKernelMod>(kExpm1); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Floor,
[]() { return std::make_shared<UnaryOpGpuKernelMod>(kFloor); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Ceil,
[]() { return std::make_shared<UnaryOpGpuKernelMod>(kCeil); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Log, []() { return std::make_shared<UnaryOpGpuKernelMod>(kLog); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Log1p,
[]() { return std::make_shared<UnaryOpGpuKernelMod>(kLog1p); });

View File

@@ -3081,7 +3081,7 @@ class Ceil(PrimitiveWithInfer):
TypeError: If dtype of x is not float16 or float32.
Supported Platforms:
- ``Ascend``
+ ``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([1.1, 2.5, -1.5]), mindspore.float32)

View File

@@ -121,3 +121,40 @@ def test_pynative_imag():
output = P.Imag()(ms_x)
expect = np.imag(x)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_graph_ceil():
"""
Feature: Ceil cpu kernel
Description: test cases for ceil in graph mode cpu backend.
Expectation: the result matches numpy ceil.
"""
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array([1.1, -2.1]).astype(np.float32))
np_x = np.array([1.1, -2.1]).astype(np.float32)
output = P.Ceil()(x)
expect = np.ceil(np_x)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_ceil():
"""
Feature: Ceil cpu kernel
Description: test cases for ceil in pynative mode cpu backend.
Expectation: the result matches numpy ceil.
"""
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array([1.1, -2.1]).astype(np.float32))
np_x = np.array([1.1, -2.1]).astype(np.float32)
output = P.Ceil()(x)
expect = np.ceil(np_x)
assert np.allclose(output.asnumpy(), expect)

View File

@@ -0,0 +1,66 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetCeil(nn.Cell):
def __init__(self):
super(NetCeil, self).__init__()
self.ceil = P.Ceil()
def construct(self, x):
return self.ceil(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ceil_fp32():
"""
Feature: Ceil gpu kernel
Description: test the ceil kernel with float32 input.
Expectation: the result matches the numpy ceil benchmark.
"""
ceil = NetCeil()
x = np.random.rand(3, 8).astype(np.float32)
output = ceil(Tensor(x, dtype=dtype.float32))
expect = np.ceil(x)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ceil_fp16():
"""
Feature: Ceil gpu kernel
Description: test the ceil.
Expectation: match to np benchmark.
"""
ceil = NetCeil()
x = np.random.rand(3, 8).astype(np.float16)
output = ceil(Tensor(x, dtype=dtype.float16))
expect = np.ceil(x)
assert np.allclose(output.asnumpy(), expect)