forked from mindspore-Ecosystem/mindspore
!15647 Add rint op for cpu and gpu
From: @xcnick Reviewed-by: @liangchenghui,@tom__chen Signed-off-by: @liangchenghui
This commit is contained in:
commit
94ed3b89a3
|
@ -99,6 +99,16 @@ void Floor(const T *in, T *out, size_t size) {
|
|||
CPUKernelUtils::ParallelFor(task, size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Rint(const T *in, T *out, size_t size) {
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(rint(in[i]));
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Reciprocal(const T *in, T *out, size_t size) {
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
|
@ -240,6 +250,7 @@ static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::k
|
|||
{prim::kPrimLogicalNot->name(), LOGICALNOT},
|
||||
{prim::kPrimSign->name(), SIGN},
|
||||
{prim::kPrimFloor->name(), FLOOR},
|
||||
{prim::kPrimRint->name(), RINT},
|
||||
{prim::kPrimReciprocal->name(), RECIPROCAL},
|
||||
{prim::kPrimGeLU->name(), GELU},
|
||||
{prim::kPrimAsin->name(), ASIN},
|
||||
|
@ -305,7 +316,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
|
|||
{ASIN, Asin<T>}, {ACOS, ACos<T>},
|
||||
{ATAN, Atan<T>}, {SINH, Sinh<T>},
|
||||
{COSH, Cosh<T>}, {ASINH, Asinh<T>},
|
||||
{ACOSH, Acosh<T>}, {ATANH, Atanh<T>}};
|
||||
{ACOSH, Acosh<T>}, {ATANH, Atanh<T>},
|
||||
{RINT, Rint<T>}};
|
||||
if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) {
|
||||
kArithmeticOpFuncMap.at(operate_type_)(input, output, lens);
|
||||
} else {
|
||||
|
|
|
@ -65,6 +65,8 @@ MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAtt
|
|||
ArithmeticSelfCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCPUKernel);
|
||||
MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
|
|
@ -113,6 +113,7 @@ enum OperateType {
|
|||
ASINHGRAD,
|
||||
ACOSHGRAD,
|
||||
ATAN2,
|
||||
RINT,
|
||||
};
|
||||
|
||||
class CPUKernel : public kernel::KernelMod {
|
||||
|
|
|
@ -225,6 +225,20 @@ __global__ void FloorKernel(const half *input, half *output, const size_t count)
|
|||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void RintKernel(const T *input, T *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = rint(input[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <>
|
||||
__global__ void RintKernel(const half *input, half *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = hrint(input[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
|
@ -329,6 +343,11 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
|
|||
FloorKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
RintKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
|
||||
return;
|
||||
}
|
||||
|
||||
// double
|
||||
template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
@ -351,6 +370,7 @@ template void Acosh<double>(const double *input, double *output, const size_t co
|
|||
template void Rsqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Rint<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
||||
|
||||
// float
|
||||
|
@ -374,6 +394,7 @@ template void Acosh<float>(const float *input, float *output, const size_t count
|
|||
template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Rint<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
||||
// half
|
||||
template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
@ -396,3 +417,4 @@ template void Acosh<half>(const half *input, half *output, const size_t count, c
|
|||
template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template void Rint<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -58,5 +58,7 @@ template <typename T>
|
|||
void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
|
||||
|
|
|
@ -108,5 +108,11 @@ MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu
|
|||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
UnaryOpGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -48,6 +48,7 @@ enum UnaryOptype {
|
|||
UNARY_OP_ACOSH,
|
||||
UNARY_OP_ABS,
|
||||
UNARY_OP_FLOOR,
|
||||
UNARY_OP_RINT,
|
||||
UNARY_OP_INVALID_TYPE = 255
|
||||
};
|
||||
|
||||
|
@ -61,7 +62,8 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
|
|||
{"Cos", UNARY_OP_COS}, {"Asin", UNARY_OP_ASIN},
|
||||
{"ACos", UNARY_OP_ACOS}, {"Atan", UNARY_OP_ATAN},
|
||||
{"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH},
|
||||
{"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR}};
|
||||
{"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR},
|
||||
{"Rint", UNARY_OP_RINT}};
|
||||
|
||||
template <typename T>
|
||||
class UnaryOpGpuKernel : public GpuKernel {
|
||||
|
@ -159,6 +161,10 @@ class UnaryOpGpuKernel : public GpuKernel {
|
|||
Floor(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_RINT: {
|
||||
Rint(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
|
||||
}
|
||||
|
|
|
@ -395,6 +395,7 @@ inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"
|
|||
inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
|
||||
inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
|
||||
inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs");
|
||||
inline const PrimitivePtr kPrimRint = std::make_shared<Primitive>("Rint");
|
||||
inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
|
||||
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
|
||||
inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
|
||||
|
|
|
@ -2705,7 +2705,7 @@ class Rint(PrimitiveWithInfer):
|
|||
TypeError: If dtype of `input_x` is neither float16 nor float32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32)
|
||||
|
|
|
@ -50,6 +50,15 @@ class ReciprocalNet(nn.Cell):
|
|||
return self.reciprocal(x)
|
||||
|
||||
|
||||
class RintNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(RintNet, self).__init__()
|
||||
self.rint = P.Rint()
|
||||
|
||||
def construct(self, x):
|
||||
return self.rint(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -118,6 +127,23 @@ def test_floor():
|
|||
assert np.all(output.asnumpy() == expect_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_rint():
|
||||
net = RintNet()
|
||||
prop = 100 if np.random.random() > 0.5 else -100
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.float16) * prop
|
||||
output = net(Tensor(x))
|
||||
expect_output = np.rint(x).astype(np.float16)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), expect_output)
|
||||
|
||||
x = np.random.randn(3, 4, 5, 6).astype(np.float32) * prop
|
||||
output = net(Tensor(x))
|
||||
expect_output = np.rint(x).astype(np.float32)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), expect_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -137,7 +163,3 @@ def test_reciprocal():
|
|||
diff = output.asnumpy() - expect_output
|
||||
error = np.ones(shape=expect_output.shape) * 1.0e-5
|
||||
assert np.all(np.abs(diff) < error)
|
||||
|
||||
test_square()
|
||||
test_floor()
|
||||
test_reciprocal()
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.rint = ops.Rint()
|
||||
|
||||
def construct(self, x):
|
||||
return self.rint(x)
|
||||
|
||||
|
||||
def generate_testcases(nptype):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]).astype(nptype)
|
||||
net = Net()
|
||||
output = net(Tensor(x))
|
||||
expect = np.rint(x).astype(nptype)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), expect)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]).astype(nptype)
|
||||
net = Net()
|
||||
output = net(Tensor(x))
|
||||
expect = np.rint(x).astype(nptype)
|
||||
np.testing.assert_almost_equal(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_sign_float32():
|
||||
generate_testcases(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_sign_float16():
|
||||
generate_testcases(np.float16)
|
Loading…
Reference in New Issue