!15647 Add Rint op for CPU and GPU

From: @xcnick
Reviewed-by: @liangchenghui, @tom__chen
Signed-off-by: @liangchenghui
mindspore-ci-bot 2021-04-26 23:22:24 +08:00 committed by Gitee
commit 94ed3b89a3
11 changed files with 141 additions and 7 deletions
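For quick reference, the operator added here can be driven from Python exactly as the new tests below do; a minimal usage sketch (the device target and the printed values are illustrative only, not part of this change set):

import numpy as np
import mindspore.context as context
from mindspore import Tensor, ops

# PyNative mode executes the primitive eagerly; "CPU" and "GPU" both work after this commit.
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
x = Tensor(np.array([-1.7, -0.5, 0.2, 1.5, 2.0]).astype(np.float32))
output = ops.Rint()(x)  # element-wise round to nearest integer, halfway cases to even
print(output)           # expected: [-2. -0.  0.  2.  2.]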


@@ -99,6 +99,16 @@ void Floor(const T *in, T *out, size_t size) {
  CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Rint(const T *in, T *out, size_t size) {
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      out[i] = static_cast<T>(rint(in[i]));
    }
  };
  CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Reciprocal(const T *in, T *out, size_t size) {
  auto task = [&](size_t start, size_t end) {
@@ -240,6 +250,7 @@ static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::k
{prim::kPrimLogicalNot->name(), LOGICALNOT},
{prim::kPrimSign->name(), SIGN},
{prim::kPrimFloor->name(), FLOOR},
{prim::kPrimRint->name(), RINT},
{prim::kPrimReciprocal->name(), RECIPROCAL},
{prim::kPrimGeLU->name(), GELU},
{prim::kPrimAsin->name(), ASIN},
@@ -305,7 +316,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
{ASIN, Asin<T>}, {ACOS, ACos<T>},
{ATAN, Atan<T>}, {SINH, Sinh<T>},
{COSH, Cosh<T>}, {ASINH, Asinh<T>},
-{ACOSH, Acosh<T>}, {ATANH, Atanh<T>}};
+{ACOSH, Acosh<T>}, {ATANH, Atanh<T>},
+{RINT, Rint<T>}};
if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) {
kArithmeticOpFuncMap.at(operate_type_)(input, output, lens);
} else {
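The CPU path reuses the same element-wise ParallelFor template as Floor, with C rint() doing the rounding, i.e. round-half-to-even under the default rounding mode. numpy's np.rint follows the same convention, which is why the CPU test below can compare against it directly; a quick cross-check sketch (not part of the diff):

import numpy as np

x = np.array([-1.6, -0.5, 0.5, 1.5, 2.5], dtype=np.float32)
# Halfway cases go to the nearest even integer: -0.5 -> -0, 0.5 -> 0, 1.5 -> 2, 2.5 -> 2.
print(np.rint(x))  # [-2. -0.  0.  2.  2.]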


@@ -65,6 +65,8 @@ MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAtt
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),


@@ -113,6 +113,7 @@ enum OperateType {
ASINHGRAD,
ACOSHGRAD,
ATAN2,
RINT,
};
class CPUKernel : public kernel::KernelMod {


@@ -225,6 +225,20 @@ __global__ void FloorKernel(const half *input, half *output, const size_t count)
  return;
}
template <typename T>
__global__ void RintKernel(const T *input, T *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = rint(input[i]);
  }
  return;
}
template <>
__global__ void RintKernel(const half *input, half *output, const size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
    output[i] = hrint(input[i]);
  }
  return;
}
template <typename T>
void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
@@ -329,6 +343,11 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
  FloorKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
}
template <typename T>
void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
  RintKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
  return;
}
// double
template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
@@ -351,6 +370,7 @@ template void Acosh<double>(const double *input, double *output, const size_t co
template void Rsqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Rint<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
// float
@@ -374,6 +394,7 @@ template void Acosh<float>(const float *input, float *output, const size_t count
template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Rint<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
// half
template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
@@ -396,3 +417,4 @@ template void Acosh<half>(const half *input, half *output, const size_t count, c
template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Rint<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
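The RintKernel bodies above follow the grid-stride loop used by the other unary kernels in this file, so every index in [0, count) is covered however GET_BLOCKS/GET_THREADS size the launch. A small Python model of that index pattern (illustration only; the function name is made up):

def grid_stride_indices(count, grid_dim, block_dim):
    # Each (block, thread) pair starts at its global id and advances by the
    # total number of launched threads until it runs past `count`.
    for block in range(grid_dim):
        for thread in range(block_dim):
            i = block * block_dim + thread
            while i < count:
                yield i
                i += block_dim * grid_dim

# Every index in [0, count) is visited exactly once:
assert sorted(grid_stride_indices(count=10, grid_dim=2, block_dim=3)) == list(range(10))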


@@ -58,5 +58,7 @@ template <typename T>
void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_


@@ -108,5 +108,11 @@ MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
UnaryOpGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
} // namespace kernel
} // namespace mindspore
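Note that dtype coverage differs per backend: the GPU registrations above add float64, float32, and float16, while the CPU registration earlier in this commit adds float32 only. A sketch exercising all three GPU dtypes (assumes a GPU device is available; not part of the diff):

import numpy as np
import mindspore.context as context
from mindspore import Tensor, ops

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
for nptype in (np.float16, np.float32, np.float64):
    x = Tensor(np.array([-1.5, 0.2, 2.5]).astype(nptype))
    print(nptype.__name__, ops.Rint()(x))  # each should match np.rint on the same input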


@@ -48,6 +48,7 @@ enum UnaryOptype {
UNARY_OP_ACOSH,
UNARY_OP_ABS,
UNARY_OP_FLOOR,
UNARY_OP_RINT,
UNARY_OP_INVALID_TYPE = 255
};
@@ -61,7 +62,8 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
{"Cos", UNARY_OP_COS}, {"Asin", UNARY_OP_ASIN},
{"ACos", UNARY_OP_ACOS}, {"Atan", UNARY_OP_ATAN},
{"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH},
-{"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR}};
+{"Abs", UNARY_OP_ABS}, {"Floor", UNARY_OP_FLOOR},
+{"Rint", UNARY_OP_RINT}};
template <typename T>
class UnaryOpGpuKernel : public GpuKernel {
@@ -159,6 +161,10 @@ class UnaryOpGpuKernel : public GpuKernel {
Floor(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_RINT: {
Rint(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
default: {
MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
}


@@ -395,6 +395,7 @@ inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"
inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs");
inline const PrimitivePtr kPrimRint = std::make_shared<Primitive>("Rint");
inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");


@@ -2705,7 +2705,7 @@ class Rint(PrimitiveWithInfer):
        TypeError: If dtype of `input_x` is neither float16 nor float32.

    Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32)


@@ -50,6 +50,15 @@ class ReciprocalNet(nn.Cell):
        return self.reciprocal(x)


class RintNet(nn.Cell):
    def __init__(self):
        super(RintNet, self).__init__()
        self.rint = P.Rint()

    def construct(self, x):
        return self.rint(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -118,6 +127,23 @@ def test_floor():
    assert np.all(output.asnumpy() == expect_output)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_rint():
    net = RintNet()
    prop = 100 if np.random.random() > 0.5 else -100
    x = np.random.randn(3, 4, 5, 6).astype(np.float16) * prop
    output = net(Tensor(x))
    expect_output = np.rint(x).astype(np.float16)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    x = np.random.randn(3, 4, 5, 6).astype(np.float32) * prop
    output = net(Tensor(x))
    expect_output = np.rint(x).astype(np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -137,7 +163,3 @@ def test_reciprocal():
    diff = output.asnumpy() - expect_output
    error = np.ones(shape=expect_output.shape) * 1.0e-5
    assert np.all(np.abs(diff) < error)
-
-test_square()
-test_floor()
-test_reciprocal()


@@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, ops


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.rint = ops.Rint()

    def construct(self, x):
        return self.rint(x)


def generate_testcases(nptype):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]).astype(nptype)
    net = Net()
    output = net(Tensor(x))
    expect = np.rint(x).astype(nptype)
    np.testing.assert_almost_equal(output.asnumpy(), expect)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]).astype(nptype)
    net = Net()
    output = net(Tensor(x))
    expect = np.rint(x).astype(nptype)
    np.testing.assert_almost_equal(output.asnumpy(), expect)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_rint_float32():
    generate_testcases(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_rint_float16():
    generate_testcases(np.float16)