forked from mindspore-Ecosystem/mindspore
Support int32 as input type for Abs of GPU op, float64 as input type for IsFinite GPU op.
This commit is contained in:
parent
3cb39df189
commit
d61b089f6b
|
@ -128,9 +128,13 @@ void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t c
|
|||
|
||||
template void CalFloatStatus<float>(const size_t size, const float* input, float* output, cudaStream_t cuda_stream);
|
||||
template void CalFloatStatus<half>(const size_t size, const half* input, float* output, cudaStream_t cuda_stream);
|
||||
template void CalFloatStatus<double>(const size_t size, const double* input, float* output, cudaStream_t cuda_stream);
|
||||
template void CalIsInf<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
|
||||
template void CalIsInf<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream);
|
||||
template void CalIsInf<double>(const size_t size, const double* input, bool* output, cudaStream_t cuda_stream);
|
||||
template void CalIsNan<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
|
||||
template void CalIsNan<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream);
|
||||
template void CalIsNan<double>(const size_t size, const double* input, bool* output, cudaStream_t cuda_stream);
|
||||
template void CalIsFinite<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
|
||||
template void CalIsFinite<half>(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream);
|
||||
template void CalIsFinite<double>(const size_t size, const double* input, bool* output, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -22,17 +22,25 @@ MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32)
|
|||
FloatStatusGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
|
||||
FloatStatusGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32),
|
||||
FloatStatusGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
FloatStatusGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
FloatStatusGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
FloatStatusGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
FloatStatusGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
FloatStatusGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
FloatStatusGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
FloatStatusGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
FloatStatusGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
FloatStatusGpuKernel, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -110,6 +110,8 @@ MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
|
|||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
UnaryOpGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
|
|
|
@ -4110,7 +4110,7 @@ class FloatStatus(PrimitiveWithInfer):
|
|||
Tensor, has the shape of `(1,)`, and the dtype is `mindspore.dtype.float32`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
TypeError: If dtype of `x` is not in [float16, float32, float64].
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
@ -4132,7 +4132,7 @@ class FloatStatus(PrimitiveWithInfer):
|
|||
return [1]
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name)
|
||||
validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16, mstype.float64], self.name)
|
||||
return mstype.float32
|
||||
|
||||
|
||||
|
|
|
@ -67,17 +67,23 @@ x3 = np.array([[1, 2], [3, 4], [5.0, 88.0]]).astype(np.float32)
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_status():
|
||||
@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
|
||||
def test_status(dtype):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for FloatStatus
|
||||
Expectation: the result match to expectation
|
||||
"""
|
||||
ms_status = Net()
|
||||
output1 = ms_status(Tensor(x1))
|
||||
output1 = ms_status(Tensor(x1.astype(dtype)))
|
||||
expect1 = 1
|
||||
assert output1.asnumpy()[0] == expect1
|
||||
|
||||
output2 = ms_status(Tensor(x2))
|
||||
output2 = ms_status(Tensor(x2.astype(dtype)))
|
||||
expect2 = 1
|
||||
assert output2.asnumpy()[0] == expect2
|
||||
|
||||
output3 = ms_status(Tensor(x3))
|
||||
output3 = ms_status(Tensor(x3.astype(dtype)))
|
||||
expect3 = 0
|
||||
assert output3.asnumpy()[0] == expect3
|
||||
|
||||
|
@ -85,17 +91,23 @@ def test_status():
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_nan():
|
||||
@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
|
||||
def test_nan(dtype):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for IsNan
|
||||
Expectation: the result match to expectation
|
||||
"""
|
||||
ms_isnan = Netnan()
|
||||
output1 = ms_isnan(Tensor(x1))
|
||||
output1 = ms_isnan(Tensor(x1.astype(dtype)))
|
||||
expect1 = [[False, False, True, False]]
|
||||
assert (output1.asnumpy() == expect1).all()
|
||||
|
||||
output2 = ms_isnan(Tensor(x2))
|
||||
output2 = ms_isnan(Tensor(x2.astype(dtype)))
|
||||
expect2 = [[False, False, False, False]]
|
||||
assert (output2.asnumpy() == expect2).all()
|
||||
|
||||
output3 = ms_isnan(Tensor(x3))
|
||||
output3 = ms_isnan(Tensor(x3.astype(dtype)))
|
||||
expect3 = [[False, False], [False, False], [False, False]]
|
||||
assert (output3.asnumpy() == expect3).all()
|
||||
|
||||
|
@ -103,17 +115,23 @@ def test_nan():
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_inf():
|
||||
@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
|
||||
def test_inf(dtype):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for IsInf
|
||||
Expectation: the result match to expectation
|
||||
"""
|
||||
ms_isinf = Netinf()
|
||||
output1 = ms_isinf(Tensor(x1))
|
||||
output1 = ms_isinf(Tensor(x1.astype(dtype)))
|
||||
expect1 = [[False, False, False, False]]
|
||||
assert (output1.asnumpy() == expect1).all()
|
||||
|
||||
output2 = ms_isinf(Tensor(x2))
|
||||
output2 = ms_isinf(Tensor(x2.astype(dtype)))
|
||||
expect2 = [[True, False, False, False]]
|
||||
assert (output2.asnumpy() == expect2).all()
|
||||
|
||||
output3 = ms_isinf(Tensor(x3))
|
||||
output3 = ms_isinf(Tensor(x3.astype(dtype)))
|
||||
expect3 = [[False, False], [False, False], [False, False]]
|
||||
assert (output3.asnumpy() == expect3).all()
|
||||
|
||||
|
@ -121,16 +139,22 @@ def test_inf():
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_finite():
|
||||
@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
|
||||
def test_finite(dtype):
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for Netfinite
|
||||
Expectation: the result match to expectation
|
||||
"""
|
||||
ms_isfinite = Netfinite()
|
||||
output1 = ms_isfinite(Tensor(x1))
|
||||
output1 = ms_isfinite(Tensor(x1.astype(dtype)))
|
||||
expect1 = [[True, True, False, True]]
|
||||
assert (output1.asnumpy() == expect1).all()
|
||||
|
||||
output2 = ms_isfinite(Tensor(x2))
|
||||
output2 = ms_isfinite(Tensor(x2.astype(dtype)))
|
||||
expect2 = [[False, True, True, True]]
|
||||
assert (output2.asnumpy() == expect2).all()
|
||||
|
||||
output3 = ms_isfinite(Tensor(x3))
|
||||
output3 = ms_isfinite(Tensor(x3.astype(dtype)))
|
||||
expect3 = [[True, True], [True, True], [True, True]]
|
||||
assert (output3.asnumpy() == expect3).all()
|
||||
|
|
Loading…
Reference in New Issue