initial commit

fix ci
This commit is contained in:
Peilin Wang 2021-02-16 18:14:47 -05:00
parent d6b87269a2
commit 073fc6fa93
7 changed files with 184 additions and 48 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -343,6 +343,31 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
return;
}
// double
template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Expm1<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Logarithm<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Log1p<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Erf<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Erfc<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Negative<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Reciprocal<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Square<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Sqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Sin<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Cos<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Asin<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void ACos<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Atan<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Asinh<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Acosh<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Rsqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Zeroslike<double>(double *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
// float
template void Exponential<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Expm1<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Logarithm<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
@ -364,6 +389,8 @@ template void Rsqrt<float>(const float *input, float *output, const size_t count
template void Zeroslike<float>(float *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
// half
template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Expm1<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Logarithm<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,6 +42,8 @@ MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
UnaryOpGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@ -58,6 +60,8 @@ MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddO
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
UnaryOpGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@ -66,6 +70,8 @@ MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
UnaryOpGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@ -78,6 +84,8 @@ MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
UnaryOpGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@ -94,6 +102,8 @@ MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
UnaryOpGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,14 +20,29 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cos():
x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)
def cos(nptype):
np.random.seed(0)
x_np = np.random.rand(2, 3, 4, 4).astype(nptype)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
output_ms = P.Cos()(Tensor(x_np))
output_np = np.cos(x_np)
assert np.allclose(output_ms.asnumpy(), output_np)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cos_float16():
cos(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cos_float32():
cos(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cos_float64():
cos(np.float64)

View File

@ -14,15 +14,14 @@
# ============================================================================
""" test loss """
import numpy as np
import mindspore
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import L1Loss
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
class WeightedLoss(_Loss):
def __init__(self, reduction='mean', weights=1.0):
super(WeightedLoss, self).__init__(reduction)
@ -33,10 +32,13 @@ class WeightedLoss(_Loss):
x = self.abs(base - target)
return self.get_loss(x, self.weights)
def test_WeightedLoss():
def weighted_loss(nptype):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
loss = WeightedLoss()
input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(nptype))
target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(nptype))
output_data = loss(input_data, target_data)
error_range = np.ones(shape=output_data.shape) * 10e-6
@ -50,14 +52,26 @@ def test_WeightedLoss():
diff = test_output - output_data * 3
assert np.all(abs(diff.asnumpy()) < error_range)
loss = WeightedLoss(weights=Tensor(np.array([[0.7, 0.3], [0.7, 0.3]])))
y_true = Tensor(np.array([[0., 1.], [0., 0.]]), mindspore.float32)
y_pred = Tensor(np.array([[1., 1.], [1., 0.]]), mindspore.float32)
loss = WeightedLoss(weights=Tensor(np.array([[0.7, 0.3], [0.7, 0.3]]).astype(nptype)))
y_true = Tensor(np.array([[0., 1.], [0., 0.]]).astype(nptype))
y_pred = Tensor(np.array([[1., 1.], [1., 0.]]).astype(nptype))
test_data = 0.35
output = loss(y_true, y_pred)
diff = test_data - output.asnumpy()
assert np.all(abs(diff) < error_range)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_weighted_loss_float32():
weighted_loss(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_weighted_loss_float64():
weighted_loss(np.float64)
class CustomLoss(_Loss):
def __init__(self, reduction='mean'):
super(CustomLoss, self).__init__(reduction)
@ -67,10 +81,10 @@ class CustomLoss(_Loss):
x = self.abs(base - target)
return self.get_loss(x, weights=2.0)
def test_CustomLoss():
def custom_loss(nptype):
loss = L1Loss()
input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(nptype))
target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(nptype))
output_data = loss(input_data, target_data)
error_range = np.ones(shape=output_data.shape) * 10e-6
@ -78,3 +92,21 @@ def test_CustomLoss():
test_output = customloss(input_data, target_data)
diff = test_output - output_data * 2.0
assert np.all(abs(diff.asnumpy()) < error_range)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_custom_loss_float16():
custom_loss(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_custom_loss_float32():
custom_loss(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_custom_loss_float64():
custom_loss(np.float64)

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -31,12 +31,9 @@ class NetNeg(nn.Cell):
return self.neg(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_neg():
x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
x1_np = np.random.uniform(-2, 2, 1).astype(np.float32)
def neg(nptype):
x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
x1_np = np.random.uniform(-2, 2, 1).astype(nptype)
x0 = Tensor(x0_np)
x1 = Tensor(x1_np)
expect0 = np.negative(x0_np)
@ -45,23 +42,41 @@ def test_neg():
error1 = np.ones(shape=expect1.shape) * 1.0e-5
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
neg = NetNeg()
output0 = neg(x0)
neg_net = NetNeg()
output0 = neg_net(x0)
diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0)
assert output0.shape == expect0.shape
output1 = neg(x1)
output1 = neg_net(x1)
diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1)
assert output1.shape == expect1.shape
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
neg = NetNeg()
output0 = neg(x0)
neg_net = NetNeg()
output0 = neg_net(x0)
diff0 = output0.asnumpy() - expect0
assert np.all(diff0 < error0)
assert output0.shape == expect0.shape
output1 = neg(x1)
output1 = neg_net(x1)
diff1 = output1.asnumpy() - expect1
assert np.all(diff1 < error1)
assert output1.shape == expect1.shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_neg_float16():
neg(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_neg_float32():
neg(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_neg_float64():
neg(np.float64)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,14 +20,29 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sin():
x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)
def sin(nptype):
np.random.seed(0)
x_np = np.random.rand(2, 3, 4, 4).astype(nptype)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
output_ms = P.Sin()(Tensor(x_np))
output_np = np.sin(x_np)
assert np.allclose(output_ms.asnumpy(), output_np)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sin_float16():
sin(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sin_float32():
sin(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sin_float64():
sin(np.float64)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,18 +20,40 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sqrt():
x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)
def sqrt(nptype):
np.random.seed(0)
x_np = np.random.rand(2, 3, 4, 4).astype(nptype)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
output_ms = P.Sqrt()(Tensor(x_np))
output_np = np.sqrt(x_np)
assert np.allclose(output_ms.asnumpy(), output_np)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sqrt_float16():
sqrt(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sqrt_float32():
sqrt(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sqrt_float64():
sqrt(np.float64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_rsqrt():
np.random.seed(0)
x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)
output_ms = P.Rsqrt()(Tensor(x_np))
output_np = 1 / np.sqrt(x_np)
assert np.allclose(output_ms.asnumpy(), output_np)