!37686 Trunc op precision issue fixed

Merge pull request !37686 from Mina/trunc_issue-fix
This commit is contained in:
i-robot 2022-07-28 06:53:42 +00:00 committed by Gitee
commit 4e3c39b572
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 107 additions and 21 deletions

View File

@ -28,10 +28,15 @@ constexpr size_t kTruncOutputsNum = 1;
template <typename T>
void Trunc(const T *in0, T *out0, size_t start, size_t end) {
for (size_t index = start; index < end; index++) {
auto retp = floor(in0[index]);
auto retn = ceil(in0[index]);
int ind = static_cast<int>(in0[index]);
out0[index] = (ind < 0) ? retn : retp;
if (std::is_same_v<T, std::uint8_t>) {
out0[index] = in0[index];
} else {
auto absvalue1 = (in0[index]) * (in0[index]);
auto absvalue = sqrt(absvalue1);
auto retp = floor(absvalue);
out0[index] = (ind < 0) ? -retp : retp;
}
}
}
} // namespace
@ -56,6 +61,14 @@ bool TruncCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
ret = LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
ret = LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
ret = LaunchKernel<double>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt8) {
ret = LaunchKernel<int8_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeUInt8) {
ret = LaunchKernel<uint8_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt32) {
ret = LaunchKernel<int32_t>(inputs, outputs);
} else {
MS_EXCEPTION(TypeError) << "Unsupported input data type for operator [" << kernel_name_
<< "]: " << TypeIdToType(dtype_)->ToString();
@ -75,8 +88,11 @@ bool TruncCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, cons
std::vector<KernelAttr> TruncCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32)};
return support_list;
}

View File

@ -233,6 +233,7 @@ from .math_func import (
rad2deg,
truncate_div,
truncate_mod,
trunc,
gumbel_softmax,
matmul,
baddbmm,

View File

@ -149,6 +149,7 @@ log_matrix_determinant_ = P.LogMatrixDeterminant()
exp2_ = P.Pow()
truncate_div_ = P.TruncateDiv()
truncate_mod_ = P.TruncateMod()
trunc_ = P.Trunc()
sparse_segment_mean_ = SparseSegmentMean()
xlogy_ = P.Xlogy()
@ -2300,6 +2301,32 @@ def truncate_mod(x, y):
return truncate_mod_(x, y)
def trunc(x):
r"""
Returns a new tensor with the truncated integer values of the elements of input.
Args:
- **x** (Tensor) - Input_x is a tensor.
Returns:
Tensor, the same shape and data type as the input.
Raises:
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> x = Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]),mindspore.float32)
>>> trunc = ops.Trunc()
>>> output = trunc(x)
>>> print(output)
[ 3. 0. 0. -3.]
"""
return trunc_(x)
def ldexp(x, other):
"""
Multiplies input by 2**:attr:other.
@ -5241,6 +5268,7 @@ __all__ = [
'rad2deg',
'truncate_div',
'truncate_mod',
'trunc',
'gumbel_softmax',
'matmul',
'baddbmm',

View File

@ -14,12 +14,15 @@
# ============================================================================
import numpy as np
import mindspore.context as context
import pytest
from mindspore import context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@ -32,23 +35,61 @@ class Net(nn.Cell):
return self.trunc(x0)
def test32_net():
x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3]), mstype.float32)
uniq = Net()
output = uniq(x)
print("x:\n", output)
expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1]
print("expected_x:\n", expect_x_result)
assert (output.asnumpy() == expect_x_result).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test16_net():
x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3]), mstype.float16)
x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3, -0.4]), mstype.float16)
uniq = Net()
output = uniq(x)
print("x:\n", output)
expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1]
print("expected_x:\n", expect_x_result)
expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1, -0]
assert (output.asnumpy() == expect_x_result).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test32_net():
x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3, -0.4]), mstype.float32)
uniq = Net()
output = uniq(x)
expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1, -0]
assert (output.asnumpy() == expect_x_result).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def testint8_net():
x = Tensor(np.array([1, -2, 5, 2, 0, -1, 2, -1, -0]), mstype.int8)
uniq = Net()
output = uniq(x)
expect_x_result = [1, -2, 5, 2, 0, -1, 2, -1, -0]
assert (output.asnumpy() == expect_x_result).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def testuint8_net():
x = Tensor(np.array([1, 5, 2, 0]), mstype.uint8)
uniq = Net()
output = uniq(x)
expect_x_result = [1, 5, 2, 0]
assert (output.asnumpy() == expect_x_result).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def testint32_net():
x = Tensor(np.array([1, -2, 5, 2, 0, -1, 2, -1, -0]), mstype.int32)
uniq = Net()
output = uniq(x)
expect_x_result = [1, -2, 5, 2, 0, -1, 2, -1, -0]
assert (output.asnumpy() == expect_x_result).all()