forked from mindspore-Ecosystem/mindspore
DinNoNan gpu kernel supports int8/uint8
This commit is contained in:
parent
f894fa5b86
commit
9a6ced3cc7
|
@ -199,6 +199,10 @@ template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const half *x
|
|||
cudaStream_t stream);
|
||||
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, bool *y,
|
||||
cudaStream_t stream);
|
||||
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int8_t *x0, const int8_t *x1, bool *y,
|
||||
cudaStream_t stream);
|
||||
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const uint8_t *x0, const uint8_t *x1, bool *y,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Element-wise ArithMetic
|
||||
template <typename T, typename Func>
|
||||
|
@ -261,6 +265,10 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const half
|
|||
cudaStream_t stream);
|
||||
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, int *y,
|
||||
cudaStream_t stream);
|
||||
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int8_t *x0, const int8_t *x1, int8_t *y,
|
||||
cudaStream_t stream);
|
||||
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint8_t *x0, const uint8_t *x1, uint8_t *y,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Broadcast comparation
|
||||
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
|
||||
|
@ -333,6 +341,12 @@ template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector
|
|||
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
|
||||
bool *y, cudaStream_t stream);
|
||||
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int8_t *x0,
|
||||
const int8_t *x1, bool *y, cudaStream_t stream);
|
||||
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint8_t *x0,
|
||||
const uint8_t *x1, bool *y, cudaStream_t stream);
|
||||
|
||||
// Broadcast Arithmetic
|
||||
template <typename T, typename Func>
|
||||
|
@ -448,6 +462,12 @@ template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vect
|
|||
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
|
||||
int *y, cudaStream_t stream);
|
||||
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int8_t *x0,
|
||||
const int8_t *x1, int8_t *y, cudaStream_t stream);
|
||||
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
|
||||
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint8_t *x0,
|
||||
const uint8_t *x1, uint8_t *y, cudaStream_t stream);
|
||||
|
||||
// BroadcastTo
|
||||
template <typename T>
|
||||
|
|
|
@ -147,5 +147,15 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
MS_REG_GPU_KERNEL_ONE(
|
||||
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
BroadcastOpGpuKernel, int)
|
||||
|
||||
// int8
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
BroadcastOpGpuKernel, int8_t)
|
||||
|
||||
// uint8
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
BroadcastOpGpuKernel, uint8_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -297,3 +297,43 @@ def test_broadcast_fp16():
|
|||
x2_np_zero = np.zeros_like(x2_np)
|
||||
output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
|
||||
assert np.allclose(output_ms.asnumpy(), x2_np_zero)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_divnonan_int8():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
np.random.seed(42)
|
||||
x1_np_int8 = np.random.randint(1, 100, (10, 20)).astype(np.int8)
|
||||
x2_np_int8 = np.random.randint(1, 100, (10, 20)).astype(np.int8)
|
||||
|
||||
output_ms = P.DivNoNan()(Tensor(x1_np_int8), Tensor(x2_np_int8))
|
||||
output_np = x1_np_int8 // x2_np_int8
|
||||
print(output_ms.asnumpy(), output_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
x2_np_zero = np.zeros_like(x2_np_int8)
|
||||
output_ms = P.DivNoNan()(Tensor(x1_np_int8), Tensor(x2_np_zero))
|
||||
assert np.allclose(output_ms.asnumpy(), x2_np_zero)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_divnonan_uint8():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
np.random.seed(42)
|
||||
x1_np_uint8 = np.random.randint(1, 100, (10, 20)).astype(np.uint8)
|
||||
x2_np_uint8 = np.random.randint(1, 100, (10, 20)).astype(np.uint8)
|
||||
|
||||
output_ms = P.DivNoNan()(Tensor(x1_np_uint8), Tensor(x2_np_uint8))
|
||||
output_np = x1_np_uint8 // x2_np_uint8
|
||||
print(output_ms.asnumpy(), output_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
x2_np_zero = np.zeros_like(x2_np_uint8)
|
||||
output_ms = P.DivNoNan()(Tensor(x1_np_uint8), Tensor(x2_np_zero))
|
||||
assert np.allclose(output_ms.asnumpy(), x2_np_zero)
|
||||
|
|
Loading…
Reference in New Issue