From e48cb95cb7ac7c1da59bb64fda648e7323552dde Mon Sep 17 00:00:00 2001 From: chenfei Date: Thu, 18 Feb 2021 16:08:51 +0800 Subject: [PATCH] change float status output type to float32 --- .../kernel_compiler/gpu/cuda_impl/float_status_impl.cu | 8 ++++---- .../kernel_compiler/gpu/cuda_impl/float_status_impl.cuh | 2 +- .../kernel_compiler/gpu/math/float_status_gpu_kernel.cc | 2 +- .../kernel_compiler/gpu/math/float_status_gpu_kernel.h | 6 +++--- mindspore/ops/operations/math_ops.py | 5 ++--- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu index 081754c87a..78d6b998cc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu @@ -87,7 +87,7 @@ __global__ void IsFinite(const size_t size, const half* input, bool* out) { } template <typename T> -__global__ void FloatStatus(const size_t size, const T* input, T* out) { +__global__ void FloatStatus(const size_t size, const T* input, float* out) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { if (isinf(input[pos]) != 0 || isnan(input[pos])) { out[0] = 1; @@ -96,7 +96,7 @@ __global__ void FloatStatus(const size_t size, const T* input, T* out) { return; } template <> -__global__ void FloatStatus(const size_t size, const half* input, half* out) { +__global__ void FloatStatus(const size_t size, const half* input, float* out) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) { out[0] = 1; @@ -106,7 +106,7 @@ __global__ void FloatStatus(const size_t size, const half* input, half* out) { } template <typename T> -void CalFloatStatus(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) { +void 
CalFloatStatus(const size_t size, const T* input, float* output, cudaStream_t cuda_stream) { FloatStatus<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output); return; } @@ -127,7 +127,7 @@ void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t c } template void CalFloatStatus<float>(const size_t size, const float* input, float* output, cudaStream_t cuda_stream); -template void CalFloatStatus<half>(const size_t size, const half* input, half* output, cudaStream_t cuda_stream); +template void CalFloatStatus<half>(const size_t size, const half* input, float* output, cudaStream_t cuda_stream); template void CalIsInf<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); template void CalIsInf<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); template void CalIsNan<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh index fbe063e72a..dbf0cff316 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh @@ -18,7 +18,7 @@ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ #include "runtime/device/gpu/cuda_common.h" template <typename T> -void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream); +void CalFloatStatus(const size_t size, const T *input, float *output, cudaStream_t stream); template <typename T> void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream); template <typename T> diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.cc index 313669a647..adcbb40185 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.cc +++ 
b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), FloatStatusGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), +MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), FloatStatusGpuKernel, half) MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), FloatStatusGpuKernel, float) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h index b1e7e4fd0d..3d8870310a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h @@ -46,8 +46,8 @@ class FloatStatusGpuKernel : public GpuKernel { switch (kernel_name_) { case OP_STATUS: { - T *output = GetDeviceAddress<T>(outputs, 0); - FillDeviceArray(outputs[0]->size / sizeof(T), output, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr)); + float *output = GetDeviceAddress<float>(outputs, 0); + FillDeviceArray(outputs[0]->size / sizeof(float), output, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr)); CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr)); break; } @@ -90,7 +90,7 @@ class FloatStatusGpuKernel : public GpuKernel { kernel_name_ = iter->second; } if (kernel_name_ == OP_STATUS) { - output_size_ = sizeof(T); + output_size_ = sizeof(float); } else { output_size_ = input_size_ / sizeof(T) * sizeof(bool); } diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index e05714ad8c..7cfb2db0e9 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ 
-3176,8 +3176,7 @@ class FloatStatus(PrimitiveWithInfer): - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32. Outputs: - Tensor, has the shape of `(1,)`, and has the same dtype of input `mindspore.dtype.float32` or - `mindspore.dtype.float16`. + Tensor, has the shape of `(1,)`, and the dtype is `mindspore.dtype.float32`. Supported Platforms: ``GPU`` @@ -3200,7 +3199,7 @@ class FloatStatus(PrimitiveWithInfer): def infer_dtype(self, x_dtype): validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name) - return x_dtype + return mstype.float32 class NPUAllocFloatStatus(PrimitiveWithInfer):