Support int32 as an input type for the GPU Abs op, and float64 as an input type for the GPU FloatStatus, IsInf, IsNan, and IsFinite ops.

hezhenhao1 2021-11-04 11:10:02 +08:00
parent 3cb39df189
commit d61b089f6b
5 changed files with 56 additions and 18 deletions
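
For orientation, a minimal usage sketch of the newly supported dtypes; this assumes a MindSpore GPU build and is not part of the commit:

```python
# Illustrative sketch only (not part of this commit): exercising the new
# dtype support on a MindSpore GPU build.
import numpy as np
from mindspore import Tensor, context, ops

context.set_context(device_target="GPU")

# Abs now has an int32 GPU kernel registration.
print(ops.Abs()(Tensor(np.array([-3, 0, 7], np.int32))))        # [3 0 7]

# IsFinite (and FloatStatus/IsInf/IsNan) now accept float64 on GPU.
print(ops.IsFinite()(Tensor(np.array([1.0, np.inf, np.nan]))))  # [ True False False]
```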

View File

@@ -128,9 +128,13 @@ void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t c
 template void CalFloatStatus<float>(const size_t size, const float* input, float* output, cudaStream_t cuda_stream);
 template void CalFloatStatus<half>(const size_t size, const half* input, float* output, cudaStream_t cuda_stream);
+template void CalFloatStatus<double>(const size_t size, const double* input, float* output, cudaStream_t cuda_stream);
 template void CalIsInf<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
 template void CalIsInf<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream);
+template void CalIsInf<double>(const size_t size, const double* input, bool* output, cudaStream_t cuda_stream);
 template void CalIsNan<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
 template void CalIsNan<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream);
+template void CalIsNan<double>(const size_t size, const double* input, bool* output, cudaStream_t cuda_stream);
 template void CalIsFinite<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
 template void CalIsFinite<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream);
+template void CalIsFinite<double>(const size_t size, const double* input, bool* output, cudaStream_t cuda_stream);
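
For reference, the elementwise semantics these instantiations expose, inferred from the expected values in this commit's tests; a hedged NumPy sketch, not the CUDA source:

```python
import numpy as np

def float_status_ref(x):
    # FloatStatus: a single float32 flag, 1.0 if x contains any NaN or
    # +/-Inf, else 0.0 (matching the expectations in the tests below).
    return np.array([float(not np.isfinite(x).all())], np.float32)

def is_inf_ref(x):     # IsInf: elementwise True for +/-Inf
    return np.isinf(x)

def is_nan_ref(x):     # IsNan: elementwise True for NaN
    return np.isnan(x)

def is_finite_ref(x):  # IsFinite: elementwise True for finite values
    return np.isfinite(x)
```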

View File

@@ -22,17 +22,25 @@ MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32)
                       FloatStatusGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
                       FloatStatusGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32),
+                      FloatStatusGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
                       FloatStatusGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
                       FloatStatusGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
+                      FloatStatusGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
                       FloatStatusGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
                       FloatStatusGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
+                      FloatStatusGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
                       FloatStatusGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
                       FloatStatusGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
+                      FloatStatusGpuKernel, double)
 } // namespace kernel
 } // namespace mindspore
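
Each MS_REG_GPU_KERNEL_ONE invocation adds one (op, dtype signature) → kernel binding, so enabling a new dtype is purely additive: no existing registrations change. A toy sketch of the idea (hypothetical names, not MindSpore's actual registry):

```python
import numpy as np

# Toy illustration of per-dtype kernel registration (hypothetical, not
# the actual MindSpore machinery).
KERNELS = {}

def register(op, in_dtype, out_dtype):
    def wrap(fn):
        KERNELS[(op, in_dtype, out_dtype)] = fn
        return fn
    return wrap

@register("IsFinite", "float64", "bool")
def isfinite_f64(x):
    return np.isfinite(x)

# Dispatch selects the kernel whose (op, in, out) signature matches:
x = np.array([1.0, np.inf], dtype=np.float64)
print(KERNELS[("IsFinite", "float64", "bool")](x))  # [ True False]
```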

View File

@@ -110,6 +110,8 @@ MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      UnaryOpGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
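
A quick sanity check of the new int32 Abs registration, written in the same style as this PR's tests (a sketch, not part of the diff):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

# Hedged sketch: int32 Abs on GPU, mirroring this PR's test style.
x = Tensor(np.array([[-1, 2], [-3, 4]], np.int32))
out = ops.Abs()(x)
assert (out.asnumpy() == np.array([[1, 2], [3, 4]], np.int32)).all()
assert out.dtype == ms.int32  # registration maps int32 in -> int32 out
```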

View File

@@ -4110,7 +4110,7 @@ class FloatStatus(PrimitiveWithInfer):
         Tensor, has the shape of `(1,)`, and the dtype is `mindspore.dtype.float32`.

     Raises:
-        TypeError: If dtype of `x` is neither float16 nor float32.
+        TypeError: If dtype of `x` is not in [float16, float32, float64].

     Supported Platforms:
         ``GPU``
@@ -4132,7 +4132,7 @@ class FloatStatus(PrimitiveWithInfer):
         return [1]

     def infer_dtype(self, x_dtype):
-        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name)
+        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16, mstype.float64], self.name)
         return mstype.float32
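
With the widened validator list, float64 inputs now pass infer_dtype, while dtypes outside the list still raise TypeError. A sketch:

```python
import numpy as np
from mindspore import Tensor, ops

status = ops.FloatStatus()

# float64 now passes dtype validation; the output dtype stays float32.
print(status(Tensor(np.array([1.0, np.nan], np.float64))))  # [1.]

# Dtypes outside [float16, float32, float64] are still rejected.
try:
    status(Tensor(np.array([1, 2], np.int32)))
except TypeError as e:
    print("rejected:", e)
```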

View File

@@ -67,17 +67,23 @@ x3 = np.array([[1, 2], [3, 4], [5.0, 88.0]]).astype(np.float32)
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_status():
+@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
+def test_status(dtype):
+    """
+    Feature: ALL To ALL
+    Description: test cases for FloatStatus
+    Expectation: the result matches the expectation
+    """
     ms_status = Net()
-    output1 = ms_status(Tensor(x1))
+    output1 = ms_status(Tensor(x1.astype(dtype)))
     expect1 = 1
     assert output1.asnumpy()[0] == expect1
-    output2 = ms_status(Tensor(x2))
+    output2 = ms_status(Tensor(x2.astype(dtype)))
     expect2 = 1
     assert output2.asnumpy()[0] == expect2
-    output3 = ms_status(Tensor(x3))
+    output3 = ms_status(Tensor(x3.astype(dtype)))
     expect3 = 0
     assert output3.asnumpy()[0] == expect3
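
The parametrize decorator runs each test body once per dtype, yielding three cases per test. A self-contained illustration of the pattern (NumPy only, no GPU required):

```python
import numpy as np
import pytest

# Same pattern as above: pytest generates three test cases from one
# body, one per dtype in the parametrize list.
@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
def test_pattern(dtype):
    x = np.array([1.0, np.inf, np.nan], dtype=dtype)
    assert (np.isfinite(x) == [True, False, False]).all()
```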
@@ -85,17 +91,23 @@ def test_status():
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_nan():
+@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
+def test_nan(dtype):
+    """
+    Feature: ALL To ALL
+    Description: test cases for IsNan
+    Expectation: the result matches the expectation
+    """
     ms_isnan = Netnan()
-    output1 = ms_isnan(Tensor(x1))
+    output1 = ms_isnan(Tensor(x1.astype(dtype)))
     expect1 = [[False, False, True, False]]
     assert (output1.asnumpy() == expect1).all()
-    output2 = ms_isnan(Tensor(x2))
+    output2 = ms_isnan(Tensor(x2.astype(dtype)))
     expect2 = [[False, False, False, False]]
     assert (output2.asnumpy() == expect2).all()
-    output3 = ms_isnan(Tensor(x3))
+    output3 = ms_isnan(Tensor(x3.astype(dtype)))
     expect3 = [[False, False], [False, False], [False, False]]
     assert (output3.asnumpy() == expect3).all()
@@ -103,17 +115,23 @@ def test_nan():
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_inf():
+@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
+def test_inf(dtype):
+    """
+    Feature: ALL To ALL
+    Description: test cases for IsInf
+    Expectation: the result matches the expectation
+    """
     ms_isinf = Netinf()
-    output1 = ms_isinf(Tensor(x1))
+    output1 = ms_isinf(Tensor(x1.astype(dtype)))
     expect1 = [[False, False, False, False]]
     assert (output1.asnumpy() == expect1).all()
-    output2 = ms_isinf(Tensor(x2))
+    output2 = ms_isinf(Tensor(x2.astype(dtype)))
     expect2 = [[True, False, False, False]]
     assert (output2.asnumpy() == expect2).all()
-    output3 = ms_isinf(Tensor(x3))
+    output3 = ms_isinf(Tensor(x3.astype(dtype)))
     expect3 = [[False, False], [False, False], [False, False]]
     assert (output3.asnumpy() == expect3).all()
@@ -121,16 +139,22 @@ def test_inf():
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_finite():
+@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
+def test_finite(dtype):
+    """
+    Feature: ALL To ALL
+    Description: test cases for IsFinite
+    Expectation: the result matches the expectation
+    """
     ms_isfinite = Netfinite()
-    output1 = ms_isfinite(Tensor(x1))
+    output1 = ms_isfinite(Tensor(x1.astype(dtype)))
     expect1 = [[True, True, False, True]]
     assert (output1.asnumpy() == expect1).all()
-    output2 = ms_isfinite(Tensor(x2))
+    output2 = ms_isfinite(Tensor(x2.astype(dtype)))
     expect2 = [[False, True, True, True]]
     assert (output2.asnumpy() == expect2).all()
-    output3 = ms_isfinite(Tensor(x3))
+    output3 = ms_isfinite(Tensor(x3.astype(dtype)))
     expect3 = [[True, True], [True, True], [True, True]]
     assert (output3.asnumpy() == expect3).all()