Cast half to int by rounding down

This commit is contained in:
jonwe 2020-12-01 11:48:30 -05:00
parent 5835095e07
commit e2e6bf59b2
2 changed files with 30 additions and 8 deletions

View File

@ -28,35 +28,35 @@ __device__ __forceinline__ void CastBase(const S *input_addr, T *output_addr) {
// half --> integer
__device__ __forceinline__ void CastBase(const half *input_addr, uint64_t *output_addr) {
*output_addr = __half2ull_rn((*input_addr));
*output_addr = __half2ull_rd((*input_addr));
}
__device__ __forceinline__ void CastBase(const half *input_addr, int64_t *output_addr) {
*output_addr = __half2ll_rn((*input_addr));
*output_addr = __half2ll_rd((*input_addr));
}
__device__ __forceinline__ void CastBase(const half *input_addr, uint32_t *output_addr) {
*output_addr = __half2uint_rn((*input_addr));
*output_addr = __half2uint_rd((*input_addr));
}
__device__ __forceinline__ void CastBase(const half *input_addr, int32_t *output_addr) {
*output_addr = __half2int_rn((*input_addr));
*output_addr = __half2int_rd((*input_addr));
}
__device__ __forceinline__ void CastBase(const half *input_addr, uint16_t *output_addr) {
*output_addr = __half2ushort_rn((*input_addr));
*output_addr = __half2ushort_rd((*input_addr));
}
__device__ __forceinline__ void CastBase(const half *input_addr, int16_t *output_addr) {
*output_addr = __half2short_rn((*input_addr));
*output_addr = __half2short_rd((*input_addr));
}
__device__ __forceinline__ void CastBase(const half *input_addr, uint8_t *output_addr) {
*output_addr = static_cast<uint8_t>(__half2ushort_rn((*input_addr)));
*output_addr = static_cast<uint8_t>(__half2ushort_rd((*input_addr)));
}
__device__ __forceinline__ void CastBase(const half *input_addr, int8_t *output_addr) {
*output_addr = static_cast<int8_t>(__half2short_rn((*input_addr)));
*output_addr = static_cast<int8_t>(__half2short_rd((*input_addr)));
}
// integer --> half

View File

@ -597,3 +597,25 @@ def test_cast31():
assert type0 == 'uint16'
type1 = output[1].asnumpy().dtype
assert type1 == 'uint32'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast32():
np.random.seed(10)
x = np.random.rand(*(3, 2)).astype(np.float16)
x0 = Tensor(x)
t0 = mstype.int32
x1 = Tensor(x)
t1 = mstype.float64
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'int32'
expected = x.astype(np.int32)
assert (output[0].asnumpy() == expected).all()
type1 = output[1].asnumpy().dtype
assert type1 == 'float64'