!11413 Add uint8 support for relu gpu

From: @yuan_shen_zhou
Reviewed-by: @liangchenghui,@linqingke
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-01-20 09:21:46 +08:00 committed by Gitee
commit d482053167
7 changed files with 22 additions and 5 deletions

View File

@ -37,3 +37,4 @@ template void CalReLUGrad(int size, int8_t *dy, int8_t *y, int8_t *dx, cudaStrea
template void CalReLUGrad(int size, int16_t *dy, int16_t *y, int16_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, uint8_t *dy, uint8_t *y, uint8_t *dx, cudaStream_t cuda_stream);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,6 +38,7 @@ template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStr
template void CalReLU(int size, int16_t *input_addr, int16_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, uint8_t *input_addr, uint8_t *output_addr, cudaStream_t cuda_stream);
template <typename T>
__global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) {
@ -78,6 +79,7 @@ template void ReluV2(const size_t num, const int8_t *x, int8_t *y, uint32_t *mas
template void ReluV2(const size_t num, const int16_t *x, int16_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const uint8_t *x, uint8_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const double *dy, const uint32_t *mask, double *dx,
cudaStream_t cuda_stream);
@ -91,3 +93,5 @@ template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *ma
cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx,
cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
cudaStream_t cuda_stream);

View File

@ -32,5 +32,7 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutpu
ReLUGpuFwdKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReLUGpuFwdKernel,
int8_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ReLUGpuFwdKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

View File

@ -42,5 +42,8 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
ReluGradGpuFwdKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ReluGradGpuFwdKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -45,5 +45,9 @@ MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
ReluGradV2GpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
ReluGradV2,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8),
ReluGradV2GpuKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,5 +42,8 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32),
ReluV2GpuKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

View File

@ -416,13 +416,13 @@ class ReLUV2(PrimitiveWithInfer):
f"but got a {len(input_shape)}-D tensor whose shape is {input_shape}")
for i in enumerate(input_shape):
if i[0] == 1:
if input_dtype == mstype.uint8 and input_dtype == mstype.int8:
if input_dtype in (mstype.uint8, mstype.int8):
mask_shape.append((input_shape[1] + 31) // 32)
else:
mask_shape.append((input_shape[1] + 15) // 16)
else:
mask_shape.append(i[1])
if input_dtype == mstype.uint8 and input_dtype == mstype.int8:
if input_dtype in (mstype.uint8, mstype.int8):
mask_shape.append(4)
else:
mask_shape.append(2)