add ceil kernel
commit e727deddce
parent 9226c57755
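For reference, a minimal sketch of how the new kernel is exercised from Python once this change is applied (it mirrors the tests added in this commit; the concrete input values and the CPU graph-mode settings are illustrative only):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor, dtype
    from mindspore.ops import operations as P

    # Graph mode on the CPU backend; "GPU" also works once the CUDA kernel is registered.
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

    x = np.array([1.1, 2.5, -1.5]).astype(np.float32)
    output = P.Ceil()(Tensor(x, dtype.float32))
    assert np.allclose(output.asnumpy(), np.ceil(x))  # [ 2.  3. -1.]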
@@ -21,6 +21,7 @@
 #include <utility>
 #include <algorithm>
 #include <memory>
+#include "nnacl/fp32/arithmetic_self_fp32.h"

 namespace mindspore {
 namespace kernel {
@@ -28,6 +29,7 @@ namespace {
 constexpr auto kReal = "Real";
 constexpr auto kImag = "Imag";
 constexpr auto kConj = "Conj";
+constexpr auto kCeil = "Ceil";

 template <typename T, typename S>
 class UnaryOpCpuKernelFunc : public CpuKernelFunc {
@@ -45,6 +47,12 @@ class UnaryOpCpuKernelFunc : public CpuKernelFunc {
   std::string kernel_name_;
 };

+void Ceil(const float *input, float *output, size_t start, size_t end) {
+  auto input_s = input + start;
+  auto output_s = output + start;
+  ElementCeil(input_s, output_s, (end - start));
+}
+
 template <typename T, typename S>
 void Real(const T *input, S *output, size_t start, size_t end) {
   if (!std::is_same<S, complex64>::value && !std::is_same<S, complex128>::value) {
@@ -81,6 +89,16 @@ void Conj(const T *input, S *output, size_t start, size_t end) {

 template <typename T, typename S>
 void UnaryOpCpuKernelFunc<T, S>::GetUnaryOpFunc() {
+  // only support float
+  if constexpr (std::is_same<T, float>::value) {
+    static std::map<std::string, UnaryOpFunc> kFloatSupportedMap = {{prim::kPrimCeil->name(), &Ceil}};
+    auto iter = kFloatSupportedMap.find(kernel_name_);
+    if (iter != kFloatSupportedMap.end()) {
+      unary_op_func_ = iter->second;
+      return;
+    }
+  }
+
   if constexpr (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value) {
     static std::map<std::string, UnaryOpFunc> kComplexSupportedTypeMap = {{prim::kPrimReal->name(), &Real<T, S>},
         {prim::kPrimImag->name(), &Imag<T, S>},
@@ -147,6 +165,9 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpCpuFuncCreator>>>
   {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
    SpecializeUnaryFunc<double, double>},
   {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeUnaryFunc<bool, bool>}}},
+  {kCeil,
+   {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+     SpecializeUnaryFunc<float, float>}}},
   {kImag,
    {{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeFloat64),
      SpecializeUnaryFunc<complex128, double>},
@@ -229,5 +250,7 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Imag,
     []() { return std::make_shared<UnaryOpCpuKernelMod>(kImag); });
 MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Conj,
     []() { return std::make_shared<UnaryOpCpuKernelMod>(kConj); });
+MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Ceil,
+    []() { return std::make_shared<UnaryOpCpuKernelMod>(kCeil); });
 } // namespace kernel
 } // namespace mindspore

@@ -48,6 +48,7 @@ enum UnaryOptype {
   UNARY_OP_ACOSH,
   UNARY_OP_ABS,
   UNARY_OP_FLOOR,
+  UNARY_OP_CEIL,
   UNARY_OP_RINT,
   UNARY_OP_ROUND,
   UNARY_OP_SIGN,
@@ -69,10 +70,10 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
   {"Asin", UNARY_OP_ASIN},   {"ACos", UNARY_OP_ACOS},
   {"Atan", UNARY_OP_ATAN},   {"Asinh", UNARY_OP_ASINH},
   {"Acosh", UNARY_OP_ACOSH}, {"Abs", UNARY_OP_ABS},
-  {"Floor", UNARY_OP_FLOOR}, {"Rint", UNARY_OP_RINT},
-  {"Round", UNARY_OP_ROUND}, {"Real", UNARY_OP_REAL},
-  {"Imag", UNARY_OP_IMAG},   {"Sign", UNARY_OP_SIGN},
-  {"Conj", UNARY_OP_CONJ}};
+  {"Floor", UNARY_OP_FLOOR}, {"Ceil", UNARY_OP_CEIL},
+  {"Rint", UNARY_OP_RINT},   {"Round", UNARY_OP_ROUND},
+  {"Real", UNARY_OP_REAL},   {"Imag", UNARY_OP_IMAG},
+  {"Sign", UNARY_OP_SIGN},   {"Conj", UNARY_OP_CONJ}};

 template <typename T>
 class UnaryHelperGpuKernel : public GpuKernelHelperBase {
@@ -118,8 +119,9 @@ class UnaryHelperGpuKernel : public GpuKernelHelperBase {
       {UNARY_OP_ASIN, Asin<T>},   {UNARY_OP_ACOS, ACos<T>},
       {UNARY_OP_ATAN, Atan<T>},   {UNARY_OP_ASINH, Asinh<T>},
       {UNARY_OP_ACOSH, Acosh<T>}, {UNARY_OP_ABS, Abs<T>},
-      {UNARY_OP_FLOOR, Floor<T>}, {UNARY_OP_RINT, Rint<T>},
-      {UNARY_OP_ROUND, Round<T>}, {UNARY_OP_SIGN, Sign<T>}};
+      {UNARY_OP_FLOOR, Floor<T>}, {UNARY_OP_CEIL, Ceil<T>},
+      {UNARY_OP_RINT, Rint<T>},   {UNARY_OP_ROUND, Round<T>},
+      {UNARY_OP_SIGN, Sign<T>}};

     auto iter = func_map.find(unary_op_type_);
     if (iter != func_map.end()) {

@@ -495,6 +495,28 @@ __global__ void FloorKernel(const half *input, half *output, const size_t count)
   }
   return;
 }
+
+template <typename T>
+__global__ void CeilKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = ceilf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void CeilKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = ceil(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void CeilKernel(const half *input, half *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = hceil(input[i]);
+  }
+  return;
+}
 template <typename T>
 __global__ void RintKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -753,6 +775,11 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
   return;
 }
+template <typename T>
+void Ceil(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  CeilKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
 template <typename T>
 void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
   RintKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
@@ -815,6 +842,8 @@ template CUDA_LIB_EXPORT void Abs<double>(const double *input, double *output, c
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<double>(const double *input, double *output, const size_t count,
     cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<double>(const double *input, double *output, const size_t count,
+    cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<double>(const double *input, double *output, const size_t count,
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<double>(const double *input, double *output, const size_t count,
@@ -875,6 +904,8 @@ template CUDA_LIB_EXPORT void Abs<float>(const float *input, float *output, cons
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<float>(const float *input, float *output, const size_t count,
     cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<float>(const float *input, float *output, const size_t count,
+    cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<float>(const float *input, float *output, const size_t count,
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<float>(const float *input, float *output, const size_t count,
@@ -924,6 +955,7 @@ template CUDA_LIB_EXPORT void Rsqrt<half>(const half *input, half *output, const
 template CUDA_LIB_EXPORT void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<half>(const half *input, half *output, const size_t count,
     cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<half>(const half *input, half *output, const size_t count,
     cudaStream_t cuda_stream);
@@ -968,6 +1000,7 @@ template CUDA_LIB_EXPORT void Rsqrt<char>(const char *input, char *output, const
 template CUDA_LIB_EXPORT void Abs<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<char>(const char *input, char *output, const size_t count,
     cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<char>(const char *input, char *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<char>(const char *input, char *output, const size_t count,
     cudaStream_t cuda_stream);
@@ -1023,6 +1056,8 @@ template CUDA_LIB_EXPORT void Abs<unsigned char>(const unsigned char *input, uns
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<unsigned char>(const unsigned char *input, unsigned char *output,
     const size_t count, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
+    cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<unsigned char>(const unsigned char *input, unsigned char *output, const size_t count,
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<unsigned char>(const unsigned char *input, unsigned char *output,
@@ -1064,6 +1099,7 @@ template CUDA_LIB_EXPORT void Acosh<int>(const int *input, int *output, const si
 template CUDA_LIB_EXPORT void Rsqrt<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Abs<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Sign<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
@@ -1118,6 +1154,8 @@ template CUDA_LIB_EXPORT void Abs<uint32_t>(const uint32_t *input, uint32_t *out
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
     cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
+    cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<uint32_t>(const uint32_t *input, uint32_t *output, const size_t count,
@@ -1172,6 +1210,8 @@ template CUDA_LIB_EXPORT void Abs<int16_t>(const int16_t *input, int16_t *output
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<int16_t>(const int16_t *input, int16_t *output, const size_t count,
     cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<int16_t>(const int16_t *input, int16_t *output, const size_t count,
+    cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<int16_t>(const int16_t *input, int16_t *output, const size_t count,
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<int16_t>(const int16_t *input, int16_t *output, const size_t count,
@@ -1226,6 +1266,8 @@ template CUDA_LIB_EXPORT void Abs<uint16_t>(const uint16_t *input, uint16_t *out
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
     cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
+    cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<uint16_t>(const uint16_t *input, uint16_t *output, const size_t count,
@@ -1280,6 +1322,8 @@ template CUDA_LIB_EXPORT void Abs<int64_t>(const int64_t *input, int64_t *output
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<int64_t>(const int64_t *input, int64_t *output, const size_t count,
     cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<int64_t>(const int64_t *input, int64_t *output, const size_t count,
+    cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<int64_t>(const int64_t *input, int64_t *output, const size_t count,
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<int64_t>(const int64_t *input, int64_t *output, const size_t count,
@@ -1334,6 +1378,8 @@ template CUDA_LIB_EXPORT void Abs<uint64_t>(const uint64_t *input, uint64_t *out
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Floor<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
     cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void Ceil<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
+    cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Rint<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,
     cudaStream_t cuda_stream);
 template CUDA_LIB_EXPORT void Round<uint64_t>(const uint64_t *input, uint64_t *output, const size_t count,

@@ -65,6 +65,8 @@ CUDA_LIB_EXPORT void Abs(const T *input, T *output, const size_t count, cudaStre
 template <typename T>
 CUDA_LIB_EXPORT void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+CUDA_LIB_EXPORT void Ceil(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 CUDA_LIB_EXPORT void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 CUDA_LIB_EXPORT void Round(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);

@@ -36,6 +36,7 @@ constexpr auto kErfc = "Erfc";
 constexpr auto kExp = "Exp";
 constexpr auto kExpm1 = "Expm1";
 constexpr auto kFloor = "Floor";
+constexpr auto kCeil = "Ceil";
 constexpr auto kLog = "Log";
 constexpr auto kLog1p = "Log1p";
 constexpr auto kNeg = "Neg";
@@ -147,6 +148,9 @@ const std::map<std::string, std::vector<std::pair<KernelAttr, UnaryPtrCreatorFun
   {kFloor,
    {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateUnaryKernelPtr<float>},
     {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateUnaryKernelPtr<half>}}},
+  {kCeil,
+   {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateUnaryKernelPtr<float>},
+    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateUnaryKernelPtr<half>}}},
   {kRint,
    {{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateUnaryKernelPtr<double>},
     {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateUnaryKernelPtr<float>},
@@ -226,6 +230,8 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Expm1,
     []() { return std::make_shared<UnaryOpGpuKernelMod>(kExpm1); });
 MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Floor,
     []() { return std::make_shared<UnaryOpGpuKernelMod>(kFloor); });
+MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Ceil,
+    []() { return std::make_shared<UnaryOpGpuKernelMod>(kCeil); });
 MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Log, []() { return std::make_shared<UnaryOpGpuKernelMod>(kLog); });
 MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Log1p,
     []() { return std::make_shared<UnaryOpGpuKernelMod>(kLog1p); });

@@ -3081,7 +3081,7 @@ class Ceil(PrimitiveWithInfer):
         TypeError: If dtype of x is not float16 or float32.

     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> x = Tensor(np.array([1.1, 2.5, -1.5]), mindspore.float32)

@@ -121,3 +121,40 @@ def test_pynative_imag():
     output = P.Imag()(ms_x)
     expect = np.imag(x)
     assert np.allclose(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_graph_ceil():
+    """
+    Feature: ALL TO ALL
+    Description: test cases for ceil in graph mode cpu backend.
+    Expectation: the result match numpy ceil
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    x = Tensor(np.array([1.1, -2.1]).astype(np.float32))
+    np_x = np.array([1.1, -2.1]).astype(np.float32)
+    output = P.Ceil()(x)
+    expect = np.ceil(np_x)
+    assert np.allclose(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_pynative_ceil():
+    """
+    Feature: ALL TO ALL
+    Description: test cases for ceil in pynative mode cpu backend.
+    Expectation: the result match numpy ceil
+    """
+    context.set_context(mode=context.PYNATIVE_MODE)
+    x = Tensor(np.array([1.1, -2.1]).astype(np.float32))
+    np_x = np.array([1.1, -2.1]).astype(np.float32)
+    output = P.Ceil()(x)
+    expect = np.ceil(np_x)
+    assert np.allclose(output.asnumpy(), expect)

@@ -0,0 +1,66 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore import dtype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class NetCeil(nn.Cell):
+    def __init__(self):
+        super(NetCeil, self).__init__()
+        self.ceil = P.Ceil()
+
+    def construct(self, x):
+        return self.ceil(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_ceil_fp32():
+    """
+    Feature: Ceil gpu kernel
+    Description: test the ceil.
+    Expectation: match to np benchmark.
+    """
+    ceil = NetCeil()
+    x = np.random.rand(3, 8).astype(np.float32)
+    output = ceil(Tensor(x, dtype=dtype.float32))
+    expect = np.ceil(x)
+    assert np.allclose(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_ceil_fp16():
+    """
+    Feature: Ceil gpu kernel
+    Description: test the ceil.
+    Expectation: match to np benchmark.
+    """
+    ceil = NetCeil()
+    x = np.random.rand(3, 8).astype(np.float16)
+    output = ceil(Tensor(x, dtype=dtype.float16))
+    expect = np.ceil(x)
+    assert np.allclose(output.asnumpy(), expect)