!47505 Fix TruncateDiv

Merge pull request !47505 from zhanzhan/truncatediv
This commit is contained in:
i-robot 2023-01-06 03:52:01 +00:00 committed by Gitee
commit 9106d90d81
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 14 additions and 8 deletions

View File

@ -22,6 +22,7 @@
#include <utility>
#include <vector>
#include <map>
#include <cmath>
namespace mindspore {
namespace kernel {
@ -90,7 +91,8 @@ bool TruncateDivCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
}
continue;
}
output_addr[i] = static_cast<T>(dividend / divisor);
double output_trunc = trunc(static_cast<double>(dividend / divisor));
output_addr[i] = static_cast<T>(output_trunc);
}
};
ParallelLaunchAutoSearch(task, output_size, this, &parallel_search_info_);
@ -115,7 +117,8 @@ bool TruncateDivCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
}
continue;
}
output_addr[i] = static_cast<T>(dividend / divisor);
double output_trunc = trunc(static_cast<double>(dividend / divisor));
output_addr[i] = static_cast<T>(output_trunc);
iter.GenNextPos();
}
};

View File

@ -1106,7 +1106,10 @@ struct SquaredDifferenceFunc<Complex<float>> {
template <typename T>
struct TruncateDivFunc {
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
T res = static_cast<T>(static_cast<double>(lhs) / static_cast<double>(rhs));
double lhs_d = static_cast<double>(lhs);
double rhs_d = static_cast<double>(rhs);
double res_d = trunc(lhs_d / rhs_d);
T res = static_cast<T>(res_d);
return res;
}
};
@ -1114,7 +1117,7 @@ struct TruncateDivFunc {
template <>
struct TruncateDivFunc<half> {
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
float res = __half2float(lhs) / __half2float(rhs);
float res = truncf(__half2float(lhs) / __half2float(rhs));
return __float2half_rn(res);
}
};
@ -1125,8 +1128,8 @@ struct TruncateDivFunc<half2> {
float2 l = __half22float2(lhs);
float2 r = __half22float2(rhs);
float2 res;
res.x = l.x / r.x;
res.y = l.y / r.y;
res.x = truncf(l.x / r.x);
res.y = truncf(l.y / r.y);
return __float22half2_rn(res);
}
};

View File

@ -58,14 +58,14 @@ def test_truncatediv_output_diff_types():
truncatediv_op = TruncateDiv()
out = truncatediv_op(input_x, input_y).asnumpy()
exp = np.array([0.33333334, 1.33333334, -1.4])
exp = np.array([0., 1., -1.])
diff = np.abs(out - exp)
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape
out_1 = truncatediv_op(input_x_1, input_y_1).asnumpy()
exp = np.array([0.33333334, 1.33333334, -0.6])
exp = np.array([0., 1., -0.])
diff = np.abs(out_1 - exp)
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)