!37686 Trunc op precision issue fixed
Merge pull request !37686 from Mina/trunc_issue-fix
This commit is contained in:
commit 4e3c39b572
@@ -28,10 +28,15 @@ constexpr size_t kTruncOutputsNum = 1;
 template <typename T>
 void Trunc(const T *in0, T *out0, size_t start, size_t end) {
   for (size_t index = start; index < end; index++) {
-    auto retp = floor(in0[index]);
-    auto retn = ceil(in0[index]);
     int ind = static_cast<int>(in0[index]);
-    out0[index] = (ind < 0) ? retn : retp;
+    if (std::is_same_v<T, std::uint8_t>) {
+      out0[index] = in0[index];
+    } else {
+      auto absvalue1 = (in0[index]) * (in0[index]);
+      auto absvalue = sqrt(absvalue1);
+      auto retp = floor(absvalue);
+      out0[index] = (ind < 0) ? -retp : retp;
+    }
   }
 }
 }  // namespace
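Review note on the fix itself: the old kernel chose between floor and ceil by the sign of static_cast<int>(in0[index]), but that cast maps every value in (-1, 0) to 0, so an input such as -0.4 took the floor branch and truncated to -1 instead of 0. The rewritten loop floors |x| (recovered as sqrt(x*x)) and restores the sign afterwards, with a uint8 pass-through since unsigned bytes are already integral. A minimal Python sketch of the before/after logic (an illustration, not the kernel code):

    import math

    def trunc_old(x):
        # old logic: choose floor or ceil by the sign of int(x)
        retp, retn = math.floor(x), math.ceil(x)
        return retn if int(x) < 0 else retp   # int(-0.4) == 0 -> wrong branch

    def trunc_fixed(x):
        # new logic: floor |x|, then restore the sign
        retp = math.floor(math.sqrt(x * x))
        return -retp if int(x) < 0 else retp

    print(trunc_old(-0.4), trunc_fixed(-0.4))   # -1 0  (only the fix matches np.trunc)
    print(trunc_old(-1.3), trunc_fixed(-1.3))   # -1 -1 (both agree once |x| >= 1)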
@@ -56,6 +61,14 @@ bool TruncCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
     ret = LaunchKernel<float16>(inputs, outputs);
   } else if (dtype_ == kNumberTypeFloat32) {
     ret = LaunchKernel<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat64) {
+    ret = LaunchKernel<double>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeInt8) {
+    ret = LaunchKernel<int8_t>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeUInt8) {
+    ret = LaunchKernel<uint8_t>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeInt32) {
+    ret = LaunchKernel<int32_t>(inputs, outputs);
   } else {
     MS_EXCEPTION(TypeError) << "Unsupported input data type for operator [" << kernel_name_
                             << "]: " << TypeIdToType(dtype_)->ToString();
@@ -75,8 +88,11 @@ bool TruncCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, cons
 std::vector<KernelAttr> TruncCpuKernelMod::GetOpSupport() {
   static std::vector<KernelAttr> support_list = {
     KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-
-    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
+    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+    KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32)};
 
   return support_list;
 }
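With the extra Launch branches and the widened support list above, the CPU Trunc kernel now also accepts float64, int8, uint8, and int32 in addition to float16/float32. A quick usage sketch (assumes a build containing this patch; values chosen to hit the fixed (-1, 0) range):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, ops

    x = Tensor(np.array([1.7, -0.4, -2.6]), ms.float64)  # float64 now dispatches on CPU
    print(ops.Trunc()(x))  # truncates toward zero; expect [ 1.  0. -2.]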
@@ -233,6 +233,7 @@ from .math_func import (
     rad2deg,
     truncate_div,
     truncate_mod,
+    trunc,
     gumbel_softmax,
     matmul,
     baddbmm,
@@ -149,6 +149,7 @@ log_matrix_determinant_ = P.LogMatrixDeterminant()
 exp2_ = P.Pow()
 truncate_div_ = P.TruncateDiv()
 truncate_mod_ = P.TruncateMod()
+trunc_ = P.Trunc()
 sparse_segment_mean_ = SparseSegmentMean()
 xlogy_ = P.Xlogy()
 
@@ -2300,6 +2301,32 @@ def truncate_mod(x, y):
     return truncate_mod_(x, y)
 
 
+def trunc(x):
+    r"""
+    Returns a new tensor with the truncated integer values of the elements of input.
+
+    Args:
+        x (Tensor): The input tensor.
+
+    Returns:
+        Tensor, the same shape and data type as the input.
+
+    Raises:
+        TypeError: If `x` is not a Tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> x = Tensor(np.array([3.4742, 0.5466, -0.8008, -3.9079]), mindspore.float32)
+        >>> trunc = ops.Trunc()
+        >>> output = trunc(x)
+        >>> print(output)
+        [ 3.  0.  0. -3.]
+    """
+    return trunc_(x)
+
+
 def ldexp(x, other):
     """
     Multiplies input by 2**:attr:`other`.
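Since the new trunc above simply forwards to the Trunc primitive bound to trunc_, its semantics should line up with np.trunc. A hedged sanity check of that equivalence (assumes the patched package exposes ops.trunc):

    import numpy as np
    from mindspore import Tensor, ops

    x_np = np.array([3.4742, 0.5466, -0.8008, -3.9079], np.float32)
    out = ops.trunc(Tensor(x_np))
    assert np.array_equal(out.asnumpy(), np.trunc(x_np))  # both truncate toward zero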
@@ -5241,6 +5268,7 @@ __all__ = [
     'rad2deg',
     'truncate_div',
     'truncate_mod',
+    'trunc',
     'gumbel_softmax',
     'matmul',
     'baddbmm',
@@ -14,12 +14,15 @@
 # ============================================================================
 
 import numpy as np
-import mindspore.context as context
+import pytest
+
+from mindspore import context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
+
 
 
 context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@@ -32,23 +35,61 @@ class Net(nn.Cell):
         return self.trunc(x0)
 
 
-def test32_net():
-    x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3]), mstype.float32)
-    uniq = Net()
-    output = uniq(x)
-    print("x:\n", output)
-    expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1]
-    print("expected_x:\n", expect_x_result)
-
-    assert (output.asnumpy() == expect_x_result).all()
-
-
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
 def test16_net():
-    x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3]), mstype.float16)
+    x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3, -0.4]), mstype.float16)
     uniq = Net()
     output = uniq(x)
-    print("x:\n", output)
-    expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1]
-    print("expected_x:\n", expect_x_result)
+    expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1, -0]
 
     assert (output.asnumpy() == expect_x_result).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test32_net():
+    x = Tensor(np.array([1.2, -2.6, 5.0, 2.8, 0.2, -1.0, 2, -1.3, -0.4]), mstype.float32)
+    uniq = Net()
+    output = uniq(x)
+    expect_x_result = [1., -2., 5., 2., 0., -1., 2, -1, -0]
+
+    assert (output.asnumpy() == expect_x_result).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def testint8_net():
+    x = Tensor(np.array([1, -2, 5, 2, 0, -1, 2, -1, -0]), mstype.int8)
+    uniq = Net()
+    output = uniq(x)
+    expect_x_result = [1, -2, 5, 2, 0, -1, 2, -1, -0]
+
+    assert (output.asnumpy() == expect_x_result).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def testuint8_net():
+    x = Tensor(np.array([1, 5, 2, 0]), mstype.uint8)
+    uniq = Net()
+    output = uniq(x)
+    expect_x_result = [1, 5, 2, 0]
+
+    assert (output.asnumpy() == expect_x_result).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def testint32_net():
+    x = Tensor(np.array([1, -2, 5, 2, 0, -1, 2, -1, -0]), mstype.int32)
+    uniq = Net()
+    output = uniq(x)
+    expect_x_result = [1, -2, 5, 2, 0, -1, 2, -1, -0]
+
+    assert (output.asnumpy() == expect_x_result).all()
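The new cases spell out the expected values per dtype, including the -0.4 input that exercised the old bug. An equivalent parametrized check against np.trunc (a sketch reusing the Net wrapper defined above, not part of this PR) would be:

    @pytest.mark.level0
    @pytest.mark.platform_x86_cpu
    @pytest.mark.env_onecard
    @pytest.mark.parametrize("ms_type, np_type", [(mstype.float16, np.float16),
                                                  (mstype.float32, np.float32),
                                                  (mstype.int32, np.int32)])
    def test_trunc_matches_numpy(ms_type, np_type):
        x_np = np.array([1.2, -2.6, 5.0, 0.2, -1.0, -0.4]).astype(np_type)
        output = Net()(Tensor(x_np, ms_type))
        # np.trunc is the reference: truncate toward zero, elementwise
        assert (output.asnumpy() == np.trunc(x_np)).all()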