!47505 Fix TruncateDiv
Merge pull request !47505 from zhanzhan/truncatediv
This commit is contained in:
commit
9106d90d81
|
@ -22,6 +22,7 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <cmath>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -90,7 +91,8 @@ bool TruncateDivCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
|
|||
}
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = static_cast<T>(dividend / divisor);
|
||||
double output_trunc = trunc(static_cast<double>(dividend / divisor));
|
||||
output_addr[i] = static_cast<T>(output_trunc);
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, output_size, this, ¶llel_search_info_);
|
||||
|
@ -115,7 +117,8 @@ bool TruncateDivCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
|
|||
}
|
||||
continue;
|
||||
}
|
||||
output_addr[i] = static_cast<T>(dividend / divisor);
|
||||
double output_trunc = trunc(static_cast<double>(dividend / divisor));
|
||||
output_addr[i] = static_cast<T>(output_trunc);
|
||||
iter.GenNextPos();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1106,7 +1106,10 @@ struct SquaredDifferenceFunc<Complex<float>> {
|
|||
template <typename T>
|
||||
struct TruncateDivFunc {
|
||||
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
|
||||
T res = static_cast<T>(static_cast<double>(lhs) / static_cast<double>(rhs));
|
||||
double lhs_d = static_cast<double>(lhs);
|
||||
double rhs_d = static_cast<double>(rhs);
|
||||
double res_d = trunc(lhs_d / rhs_d);
|
||||
T res = static_cast<T>(res_d);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
@ -1114,7 +1117,7 @@ struct TruncateDivFunc {
|
|||
template <>
|
||||
struct TruncateDivFunc<half> {
|
||||
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
|
||||
float res = __half2float(lhs) / __half2float(rhs);
|
||||
float res = truncf(__half2float(lhs) / __half2float(rhs));
|
||||
return __float2half_rn(res);
|
||||
}
|
||||
};
|
||||
|
@ -1125,8 +1128,8 @@ struct TruncateDivFunc<half2> {
|
|||
float2 l = __half22float2(lhs);
|
||||
float2 r = __half22float2(rhs);
|
||||
float2 res;
|
||||
res.x = l.x / r.x;
|
||||
res.y = l.y / r.y;
|
||||
res.x = truncf(l.x / r.x);
|
||||
res.y = truncf(l.y / r.y);
|
||||
return __float22half2_rn(res);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -58,14 +58,14 @@ def test_truncatediv_output_diff_types():
|
|||
|
||||
truncatediv_op = TruncateDiv()
|
||||
out = truncatediv_op(input_x, input_y).asnumpy()
|
||||
exp = np.array([0.33333334, 1.33333334, -1.4])
|
||||
exp = np.array([0., 1., -1.])
|
||||
diff = np.abs(out - exp)
|
||||
err = np.ones(shape=exp.shape) * 1.0e-5
|
||||
assert np.all(diff < err)
|
||||
assert out.shape == exp.shape
|
||||
|
||||
out_1 = truncatediv_op(input_x_1, input_y_1).asnumpy()
|
||||
exp = np.array([0.33333334, 1.33333334, -0.6])
|
||||
exp = np.array([0., 1., -0.])
|
||||
diff = np.abs(out_1 - exp)
|
||||
err = np.ones(shape=exp.shape) * 1.0e-5
|
||||
assert np.all(diff < err)
|
||||
|
|
Loading…
Reference in New Issue