From 3e2a0959609ab4e54e9fa9a9e2e31395e87cce35 Mon Sep 17 00:00:00 2001
From: minara
Date: Fri, 8 Jul 2022 18:15:17 +0000
Subject: [PATCH] Fix Trunc precision issue and extend CPU dtype support

---
 .../device/cpu/kernel/trunc_cpu_kernel.cc     | 26 +++++--
 mindspore/core/ops/trunc.cc                   |  2 +-
 .../python/mindspore/ops/function/__init__.py |  1 +
 .../mindspore/ops/function/math_func.py       | 28 +++++++
 tests/st/ops/cpu/test_trunc_op.py             | 73 +++++++++++++++----
 5 files changed, 108 insertions(+), 22 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/trunc_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/trunc_cpu_kernel.cc
index c514f122d73..cc6d1487f05 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/trunc_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/trunc_cpu_kernel.cc
@@ -28,10 +28,15 @@ constexpr size_t kTruncOutputsNum = 1;
 template <typename T>
 void Trunc(const T *in0, T *out0, size_t start, size_t end) {
   for (size_t index = start; index < end; index++) {
-    auto retp = floor(in0[index]);
-    auto retn = ceil(in0[index]);
     int ind = static_cast<int>(in0[index]);
-    out0[index] = (ind < 0) ? retn : retp;
+    if (std::is_integral_v<T>) {  // integral input is already truncated; also avoids overflow in x * x
+      out0[index] = in0[index];
+    } else {
+      auto absvalue1 = (in0[index]) * (in0[index]);
+      auto absvalue = sqrt(absvalue1);
+      auto retp = floor(absvalue);
+      out0[index] = (ind < 0) ? -retp : retp;
+    }
   }
 }
 }  // namespace
@@ -56,6 +61,14 @@ bool TruncCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
     ret = LaunchKernel<float16>(inputs, outputs);
   } else if (dtype_ == kNumberTypeFloat32) {
     ret = LaunchKernel<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat64) {
+    ret = LaunchKernel<double>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeInt8) {
+    ret = LaunchKernel<int8_t>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeUInt8) {
+    ret = LaunchKernel<uint8_t>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeInt32) {
+    ret = LaunchKernel<int32_t>(inputs, outputs);
   } else {
     MS_EXCEPTION(TypeError) << "Unsupported input data type for operator [" << kernel_name_
                             << "]: " << TypeIdToType(dtype_)->ToString();
@@ -75,8 +88,11 @@ bool TruncCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, cons
 std::vector<KernelAttr> TruncCpuKernelMod::GetOpSupport() {
   static std::vector<KernelAttr> support_list = {
     KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-
-    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
+    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32)};
   return support_list;
 }
 
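Note on the kernel change above: for floating-point input, trunc(x) is computed as sign(x) * floor(|x|), with |x| taken as sqrt(x * x); integral input is returned unchanged. A rough NumPy model of that per-element logic, for illustration only (trunc_model and the sample values are invented here, not part of the patch):

    import numpy as np

    def trunc_model(x):
        """Per-element model of the Trunc CPU kernel above (illustration only)."""
        if np.issubdtype(x.dtype, np.integer):
            return x.copy()                       # integral input is already truncated
        flat = x.ravel().astype(np.float64)
        out = np.empty_like(flat)
        for i, v in enumerate(flat):
            absvalue = np.sqrt(v * v)             # |v| via sqrt(v * v), as the kernel does
            retp = np.floor(absvalue)             # drop the fractional part of |v|
            out[i] = -retp if int(v) < 0 else retp  # restore the sign: rounding toward zero
        return out.reshape(x.shape).astype(x.dtype)

    print(trunc_model(np.array([1.2, -2.6, 0.2, -0.4], dtype=np.float32)))  # [ 1. -2.  0.  0.]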
diff --git a/mindspore/core/ops/trunc.cc b/mindspore/core/ops/trunc.cc
index 7f98dc95f0c..4ca683e7c24 100644
--- a/mindspore/core/ops/trunc.cc
+++ b/mindspore/core/ops/trunc.cc
@@ -38,7 +38,7 @@ abstract::ShapePtr TruncInferShape(const PrimitivePtr &primitive, const std::vec
 TypePtr TruncInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(prim);
   CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim->name());
-  std::set<TypePtr> check_list = {kFloat16, kFloat32, kInt8, kInt32, kUInt8};
+  std::set<TypePtr> check_list = {kFloat16, kFloat32, kInt8, kInt32, kUInt8, kFloat64};
   auto input_type = input_args[0]->BuildType();
   CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_type, check_list, prim->name());
   return input_type;
diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py
index 5f48640ccbc..d9601154409 100644
--- a/mindspore/python/mindspore/ops/function/__init__.py
+++ b/mindspore/python/mindspore/ops/function/__init__.py
@@ -225,6 +225,7 @@ from .math_func import (
     rad2deg,
     truncate_div,
     truncate_mod,
+    trunc,
     gumbel_softmax,
     matmul,
     baddbmm,
diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py
index 835085250cc..c11e9a8aa0c 100644
--- a/mindspore/python/mindspore/ops/function/math_func.py
+++ b/mindspore/python/mindspore/ops/function/math_func.py
@@ -148,6 +148,7 @@
 log_matrix_determinant_ = P.LogMatrixDeterminant()
 exp2_ = P.Pow()
 truncate_div_ = P.TruncateDiv()
 truncate_mod_ = P.TruncateMod()
+trunc_ = P.Trunc()
 sparse_segment_mean_ = SparseSegmentMean()
 xlogy_ = P.Xlogy()
@@ -2220,6 +2221,32 @@ def truncate_mod(x, y):
     return truncate_mod_(x, y)
 
 
+def trunc(x):
+    r"""
+    Returns a new tensor with the truncated integer values of the elements of the input.
+
+    Args:
+        x (Tensor): The input tensor.
+
+    Returns:
+        Tensor, with the same shape and data type as `x`.
+
+    Raises:
+        TypeError: If `x` is not a Tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import Tensor, ops
+        >>> x = Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]), mindspore.float32)
+        >>> output = ops.trunc(x)
+        >>> print(output)
+        [ 3. 0. 0. -3.]
+    """
+    return trunc_(x)
+
+
 def ldexp(x, other):
     """
     Multiplies input by 2**:attr:other.
@@ -4797,6 +4824,7 @@ __all__ = [
     'rad2deg',
     'truncate_div',
     'truncate_mod',
+    'trunc',
     'gumbel_softmax',
     'matmul',
     'baddbmm',
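With the wrapper above exported, the operator is reachable as a plain function; a minimal usage sketch, assuming this patch is applied (values taken from the docstring):

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    x = Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]), mindspore.float32)
    output = ops.trunc(x)
    print(output)  # [ 3. 0. 0. -3.] per the docstring above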
diff --git a/tests/st/ops/cpu/test_trunc_op.py b/tests/st/ops/cpu/test_trunc_op.py
index ef556ed7e0c..62448b120df 100644
--- a/tests/st/ops/cpu/test_trunc_op.py
+++ b/tests/st/ops/cpu/test_trunc_op.py
@@ -14,12 +14,15 @@
 # ============================================================================
 
 import numpy as np
-import mindspore.context as context
+import pytest
+
+from mindspore import context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 
+
 context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
 
 
@@ -32,23 +35,61 @@ class Net(nn.Cell):
         return self.trunc(x0)
 
 
-def test32_net():
-    x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3]), mstype.float32)
-    uniq = Net()
-    output = uniq(x)
-    print("x:\n", output)
-    expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1]
-    print("expected_x:\n", expect_x_result)
-
-    assert (output.asnumpy() == expect_x_result).all()
-
-
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
 def test16_net():
-    x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3]), mstype.float16)
+    x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3, -0.4]), mstype.float16)
     uniq = Net()
     output = uniq(x)
-    print("x:\n", output)
-    expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1]
-    print("expected_x:\n", expect_x_result)
+    expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1, -0]
+
+    assert (output.asnumpy() == expect_x_result).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test32_net():
+    x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3, -0.4]), mstype.float32)
+    uniq = Net()
+    output = uniq(x)
+    expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1, -0]
+
+    assert (output.asnumpy() == expect_x_result).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def testint8_net():
+    x = Tensor(np.array([1, -2, 5, 2, 0, -1, 2, -1, -0]), mstype.int8)
+    uniq = Net()
+    output = uniq(x)
+    expect_x_result = [1, -2, 5, 2, 0, -1, 2, -1, -0]
+
+    assert (output.asnumpy() == expect_x_result).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def testuint8_net():
+    x = Tensor(np.array([1, 5, 2, 0]), mstype.uint8)
+    uniq = Net()
+    output = uniq(x)
+    expect_x_result = [1, 5, 2, 0]
+
+    assert (output.asnumpy() == expect_x_result).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def testint32_net():
+    x = Tensor(np.array([1, -2, 5, 2, 0, -1, 2, -1, -0]), mstype.int32)
+    uniq = Net()
+    output = uniq(x)
+    expect_x_result = [1, -2, 5, 2, 0, -1, 2, -1, -0]
 
     assert (output.asnumpy() == expect_x_result).all()
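The tests above assert fixed expected vectors per dtype. One possible additional spot-check against NumPy's reference trunc, sketched here but not part of the patch (random float32 inputs; np.array_equal treats -0.0 and 0.0 as equal):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.common import dtype as mstype
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.trunc = P.Trunc()

        def construct(self, x0):
            return self.trunc(x0)


    # Compare the CPU kernel with np.trunc on random float32 values.
    data = np.random.uniform(-10.0, 10.0, 100).astype(np.float32)
    output = Net()(Tensor(data, mstype.float32))
    assert np.array_equal(output.asnumpy(), np.trunc(data))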