forked from mindspore-Ecosystem/mindspore
changed cast to round to zero if casting from float to integral
This commit is contained in:
parent
2b8737a908
commit
299509dbfc
|
@ -28,35 +28,35 @@ __device__ __forceinline__ void CastBase(const S *input_addr, T *output_addr) {
|
|||
|
||||
// half --> integer
|
||||
__device__ __forceinline__ void CastBase(const half *input_addr, uint64_t *output_addr) {
|
||||
*output_addr = __half2ull_rd((*input_addr));
|
||||
*output_addr = __half2ull_rz((*input_addr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void CastBase(const half *input_addr, int64_t *output_addr) {
|
||||
*output_addr = __half2ll_rd((*input_addr));
|
||||
*output_addr = __half2ll_rz((*input_addr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void CastBase(const half *input_addr, uint32_t *output_addr) {
|
||||
*output_addr = __half2uint_rd((*input_addr));
|
||||
*output_addr = __half2uint_rz((*input_addr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void CastBase(const half *input_addr, int32_t *output_addr) {
|
||||
*output_addr = __half2int_rd((*input_addr));
|
||||
*output_addr = __half2int_rz((*input_addr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void CastBase(const half *input_addr, uint16_t *output_addr) {
|
||||
*output_addr = __half2ushort_rd((*input_addr));
|
||||
*output_addr = __half2ushort_rz((*input_addr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void CastBase(const half *input_addr, int16_t *output_addr) {
|
||||
*output_addr = __half2short_rd((*input_addr));
|
||||
*output_addr = __half2short_rz((*input_addr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void CastBase(const half *input_addr, uint8_t *output_addr) {
|
||||
*output_addr = static_cast<uint8_t>(__half2ushort_rd((*input_addr)));
|
||||
*output_addr = static_cast<uint8_t>(__half2ushort_rz((*input_addr)));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void CastBase(const half *input_addr, int8_t *output_addr) {
|
||||
*output_addr = static_cast<int8_t>(__half2short_rd((*input_addr)));
|
||||
*output_addr = static_cast<int8_t>(__half2short_rz((*input_addr)));
|
||||
}
|
||||
|
||||
// integer --> half
|
||||
|
|
|
@ -604,7 +604,7 @@ def test_cast31():
|
|||
@pytest.mark.env_onecard
|
||||
def test_cast32():
|
||||
np.random.seed(10)
|
||||
x = np.random.rand(*(3, 2)).astype(np.float16)
|
||||
x = np.random.uniform(-5, 5, (3, 2)).astype(np.float16)
|
||||
x0 = Tensor(x)
|
||||
t0 = mstype.int32
|
||||
x1 = Tensor(x)
|
||||
|
|
Loading…
Reference in New Issue