forked from mindspore-Ecosystem/mindspore
Add float16 dtype support to erf and erfc
This commit is contained in:
parent
6f5be6b876
commit
8132e56417
|
@ -18,7 +18,7 @@
|
|||
template <typename T>
|
||||
__global__ void ErfKernel(T *input, T *output, size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = (T)erf(input[i]);
|
||||
output[i] = static_cast<T>(erf(static_cast<float>(input[i])));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -30,3 +30,4 @@ void Erf(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
|
|||
}
|
||||
|
||||
template void Erf<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Erf<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
template <typename T>
|
||||
__global__ void ErfcKernel(T *input, T *output, size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
output[i] = (T)erfc(input[i]);
|
||||
output[i] = static_cast<T>(erfc(static_cast<float>(input[i])));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -30,3 +30,4 @@ void Erfc(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
|
|||
}
|
||||
|
||||
template void Erfc<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
|
||||
template void Erfc<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -20,5 +20,7 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Erf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ErfGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Erf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ErfGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,5 +20,7 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ErfcGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ErfcGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,10 +37,21 @@ class NetErf(nn.Cell):
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_exp():
|
||||
def test_erf_fp32():
|
||||
erf = NetErf()
|
||||
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
|
||||
x = np.random.rand(3, 8).astype(np.float32)
|
||||
output = erf(Tensor(x, dtype=dtype.float32))
|
||||
expect = special.erf(x)
|
||||
tol = 1e-6
|
||||
assert (np.abs(output.asnumpy() - expect) < tol).all()
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_erf_fp16():
|
||||
erf = NetErf()
|
||||
x = np.random.rand(3, 8).astype(np.float16)
|
||||
output = erf(Tensor(x, dtype=dtype.float16))
|
||||
expect = special.erf(x)
|
||||
tol = 1e-3
|
||||
assert (np.abs(output.asnumpy() - expect) < tol).all()
|
||||
|
|
|
@ -37,10 +37,21 @@ class NetErfc(nn.Cell):
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_exp():
|
||||
def test_erfc_fp32():
|
||||
erfc = NetErfc()
|
||||
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
|
||||
x = np.random.rand(3, 8).astype(np.float32)
|
||||
output = erfc(Tensor(x, dtype=dtype.float32))
|
||||
expect = special.erfc(x)
|
||||
tol = 1e-6
|
||||
assert (np.abs(output.asnumpy() - expect) < tol).all()
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_erfc_fp16():
|
||||
erfc = NetErfc()
|
||||
x = np.random.rand(3, 8).astype(np.float16)
|
||||
output = erfc(Tensor(x, dtype=dtype.float16))
|
||||
expect = special.erfc(x)
|
||||
tol = 1e-3
|
||||
assert (np.abs(output.asnumpy() - expect) < tol).all()
|
||||
|
|
Loading…
Reference in New Issue