change float status output type to float32

parent 048eba4460, commit e48cb95cb7
@@ -87,7 +87,7 @@ __global__ void IsFinite(const size_t size, const half* input, bool* out) {
 }
 
 template <typename T>
-__global__ void FloatStatus(const size_t size, const T* input, T* out) {
+__global__ void FloatStatus(const size_t size, const T* input, float* out) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
     if (isinf(input[pos]) != 0 || isnan(input[pos])) {
       out[0] = 1;
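What the kernel is doing, for readers skimming the diff: a grid-stride loop scans the tensor, and every thread that sees a non-finite value stores the same constant into out[0], so the racing writes are benign. A minimal standalone sketch of the pattern (simplified from the code above; names are illustrative, not MindSpore's):

    // Any-non-finite flag via a grid-stride loop. All racing threads store
    // the same value 1.0f, so the data race is benign by construction.
    // The flag must be zeroed before launch.
    __global__ void AnyNonFinite(size_t size, const float* in, float* flag) {
      for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
           i += blockDim.x * gridDim.x) {
        if (isinf(in[i]) || isnan(in[i])) {
          flag[0] = 1.0f;
        }
      }
    }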
@@ -96,7 +96,7 @@ __global__ void FloatStatus(const size_t size, const T* input, T* out) {
   return;
 }
 template <>
-__global__ void FloatStatus(const size_t size, const half* input, half* out) {
+__global__ void FloatStatus(const size_t size, const half* input, float* out) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
     if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) {
       out[0] = 1;
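The half specialization exists because half needs the intrinsics __hisinf() and __hisnan() from CUDA's cuda_fp16.h; __hisinf() returns -1 or 1 for the two infinities and 0 otherwise, hence the != 0 comparison. A self-contained sketch of the same check, assuming only the CUDA fp16 header:

    #include <cuda_fp16.h>

    // Same scan for half inputs; after this commit the flag type stays
    // float no matter what the input element type is.
    __global__ void AnyNonFiniteHalf(size_t size, const half* in, float* flag) {
      for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
           i += blockDim.x * gridDim.x) {
        if (__hisinf(in[i]) != 0 || __hisnan(in[i])) {
          flag[0] = 1.0f;
        }
      }
    }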
@@ -106,7 +106,7 @@ __global__ void FloatStatus(const size_t size, const half* input, half* out) {
 }
 
 template <typename T>
-void CalFloatStatus(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) {
+void CalFloatStatus(const size_t size, const T* input, float* output, cudaStream_t cuda_stream) {
   FloatStatus<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
   return;
 }
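From the host side, a caller now always hands CalFloatStatus a one-element float buffer, regardless of T. A hedged test-harness sketch (d_in, n, and stream are hypothetical; error checking omitted):

    // Run the status scan over n half elements and read back one float:
    // 0.0f means all finite, 1.0f means at least one inf/nan was found.
    float* d_flag = nullptr;
    cudaMalloc(&d_flag, sizeof(float));
    cudaMemsetAsync(d_flag, 0, sizeof(float), stream);  // clear the flag
    CalFloatStatus(n, d_in, d_flag, stream);            // note: float* output now
    float flag = 0.0f;
    cudaMemcpyAsync(&flag, d_flag, sizeof(float), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);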
@@ -127,7 +127,7 @@ void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t c
 }
 
 template void CalFloatStatus<float>(const size_t size, const float* input, float* output, cudaStream_t cuda_stream);
-template void CalFloatStatus<half>(const size_t size, const half* input, half* output, cudaStream_t cuda_stream);
+template void CalFloatStatus<half>(const size_t size, const half* input, float* output, cudaStream_t cuda_stream);
 template void CalIsInf<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
 template void CalIsInf<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream);
 template void CalIsNan<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
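Since the templates are defined in a .cu file, these explicit instantiations are what other translation units link against, and each one must mirror the declaration in the header exactly. Once the primary template takes float* output, an instantiation still written with half* output matches no template and fails to compile, which is how the toolchain forces the .cu file, the header, and the callers to change together. Roughly:

    // Declaration in the header after this commit:
    template <typename T>
    void CalFloatStatus(const size_t size, const T* input, float* output,
                        cudaStream_t stream);

    // Matching explicit instantiation in the .cu file; emits the half
    // version into this object file so callers elsewhere can link it.
    template void CalFloatStatus<half>(const size_t, const half*, float*,
                                       cudaStream_t);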
@@ -18,7 +18,7 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream);
+void CalFloatStatus(const size_t size, const T *input, float *output, cudaStream_t stream);
 template <typename T>
 void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream);
 template <typename T>
@@ -20,7 +20,7 @@ namespace mindspore {
 namespace kernel {
 MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       FloatStatusGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
                       FloatStatusGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
                       FloatStatusGpuKernel, float)
@@ -46,8 +46,8 @@ class FloatStatusGpuKernel : public GpuKernel {
 
     switch (kernel_name_) {
       case OP_STATUS: {
-        T *output = GetDeviceAddress<T>(outputs, 0);
-        FillDeviceArray(outputs[0]->size / sizeof(T), output, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
+        float *output = GetDeviceAddress<float>(outputs, 0);
+        FillDeviceArray(outputs[0]->size / sizeof(float), output, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
         CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
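The kernel only ever writes 1 into out[0], so the launch path has to clear the flag first; FillDeviceArray(..., 0.0f, ...) does that for the single float. Because IEEE-754 0.0f is all-zero bytes, a plain asynchronous byte memset would express the same initialization (sketch of an equivalent, not the MindSpore helper):

    // Equivalent zeroing of the one-float status flag on the same stream.
    cudaMemsetAsync(output, 0, sizeof(float),
                    reinterpret_cast<cudaStream_t>(stream_ptr));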
@@ -90,7 +90,7 @@ class FloatStatusGpuKernel : public GpuKernel {
       kernel_name_ = iter->second;
     }
     if (kernel_name_ == OP_STATUS) {
-      output_size_ = sizeof(T);
+      output_size_ = sizeof(float);
     } else {
       output_size_ = input_size_ / sizeof(T) * sizeof(bool);
     }
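Concretely, the FloatStatus output is now a fixed 4 bytes while the element-wise predicates (IsInf/IsNan/IsFinite) still produce one bool per input element. A worked example with a hypothetical float16 tensor of 1024 elements:

    #include <cuda_fp16.h>

    constexpr size_t kElems = 1024;                        // hypothetical size
    constexpr size_t input_size = kElems * sizeof(half);   // 2048 bytes
    constexpr size_t status_bytes = sizeof(float);         // 4 bytes (was sizeof(half) == 2)
    constexpr size_t isnan_bytes =
        input_size / sizeof(half) * sizeof(bool);          // 1024 bytes, one bool per element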
@@ -3176,8 +3176,7 @@ class FloatStatus(PrimitiveWithInfer):
         - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
 
     Outputs:
-        Tensor, has the shape of `(1,)`, and has the same dtype of input `mindspore.dtype.float32` or
-        `mindspore.dtype.float16`.
+        Tensor, has the shape of `(1,)`, and the dtype is `mindspore.dtype.float32`.
 
     Supported Platforms:
         ``GPU``
@@ -3200,7 +3199,7 @@ class FloatStatus(PrimitiveWithInfer):
 
     def infer_dtype(self, x_dtype):
         validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name)
-        return x_dtype
+        return mstype.float32
 
 
 class NPUAllocFloatStatus(PrimitiveWithInfer):