Add testcases in Ascend back-end for graph kernel

This commit is contained in:
looop5 2020-11-16 11:14:39 +08:00
parent f885f6636f
commit f5f66abd06
11 changed files with 238 additions and 65 deletions

View File

@ -20,8 +20,6 @@ from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class Net(Cell):
def __init__(self):
@ -35,9 +33,6 @@ class Net(Cell):
return self.add(mul_res, square_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
mul_res = input_x * input_x
@ -49,3 +44,20 @@ def test_basic():
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_basic()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_basic_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_basic()

View File

@ -21,8 +21,6 @@ from mindspore.nn import Cell
import mindspore.ops.operations as P
from mindspore.nn.graph_kernels import ReLU
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class Net(Cell):
def __init__(self):
@ -42,9 +40,6 @@ class Net(Cell):
return self.add(add1_res, add1_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
@ -61,3 +56,20 @@ def test_basic():
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_basic()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_basic_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_basic()

View File

@ -25,8 +25,6 @@ from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_graph_kernel=True)
class Net(nn.Cell):
def __init__(self, decay_flag=True):
@ -76,9 +74,6 @@ def CalFusedAdam(beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, we
return param_expect, m_expect, v_expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adam():
np.random.seed(0)
beta1 = np.array([0.9]).astype(np.float32)
@ -105,9 +100,6 @@ def test_adam():
assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adam_weight_decay():
np.random.seed(0)
beta1 = np.array([0.9]).astype(np.float32)
@ -133,3 +125,37 @@ def test_adam_weight_decay():
assert np.allclose(opt.param.data.asnumpy(), param_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
assert np.allclose(opt.m.data.asnumpy(), m_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adam_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_adam()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_adam_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_adam()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adam_weight_decay_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_adam_weight_decay()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_adam_weight_decay_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_adam_weight_decay()

View File

@ -21,8 +21,6 @@ from mindspore.nn import Cell
import mindspore.ops.operations as P
import mindspore.ops.operations._grad_ops as G
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class GeluNet(Cell):
def __init__(self):
@ -48,9 +46,6 @@ def CalGelu(x):
return expect
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu():
np.random.seed(0)
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
@ -64,9 +59,6 @@ def test_gelu():
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu_grad():
np.random.seed(0)
input_dy = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
@ -83,3 +75,37 @@ def test_gelu_grad():
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_gelu()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gelu_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_gelu()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu_grad_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_gelu_grad()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gelu_grad_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_gelu_grad()

View File

@ -20,7 +20,6 @@ from mindspore import Tensor, Parameter
from mindspore.nn import Cell
from mindspore.nn.graph_kernels import LambUpdateWithLR, LambNextMV
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class LambNet(Cell):
def __init__(self, i2, i5, x6):
@ -38,6 +37,7 @@ class LambNet(Cell):
ix1, ix2, ix3), \
self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my)
def LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my):
trust_ratio = np.where(np.greater(x2, gy),
np.where(np.greater(x1, gy), np.divide(x2, x3), se),
@ -47,6 +47,7 @@ def LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my):
next_param = x6 - np.reshape(update_with_lr, x6.shape)
return next_param
def LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3):
m_fp32 = i5.astype(np.float32)
v_fp32 = i2.astype(np.float32)
@ -63,10 +64,7 @@ def tensor_all(*args):
res = [Tensor(a) for a in args]
return res
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_graph_kernel_lamb():
shape = [1, 16]
oshape = [1]
@ -130,3 +128,17 @@ def test_graph_kernel_lamb():
wres, bres))
assert all(cmp_res) and np.allclose(ares, np_res, rtol, atol)
def test_graph_kernel_lamb_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_graph_kernel_lamb()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_graph_kernel_lamb_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_graph_kernel_lamb()

View File

@ -23,8 +23,6 @@ import mindspore.nn as nn
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.operations as P
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class Net(Cell):
def __init__(self):
@ -72,9 +70,6 @@ def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_
return dx, dg, db, mean, var
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
gamma = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
@ -117,9 +112,6 @@ def test_basic():
assert False
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgrad():
np.random.seed(0)
begin_norm_axis = 1
@ -142,3 +134,37 @@ def test_layernormgrad():
assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_basic()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_basic_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_basic()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgrad_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_layernormgrad()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_layernormgrad_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_layernormgrad()

View File

@ -20,8 +20,6 @@ from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations._grad_ops as G
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class MaxmumGradNet(Cell):
def __init__(self):
@ -32,9 +30,6 @@ class MaxmumGradNet(Cell):
return self.maximum_grad(x, y, dy)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maximum_grad():
np.random.seed(0)
input_x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
@ -46,3 +41,20 @@ def test_maximum_grad():
dy = input_dout - dx
assert np.allclose(result[0].asnumpy(), dx, rtol=1.e-4, atol=1.e-8, equal_nan=True)
assert np.allclose(result[1].asnumpy(), dy, rtol=1.e-4, atol=1.e-8, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maximum_grad_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_maximum_grad()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_maximum_grad_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_maximum_grad()

View File

@ -20,8 +20,6 @@ from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations._grad_ops as G
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class MinmumGradNet(Cell):
def __init__(self):
@ -32,9 +30,6 @@ class MinmumGradNet(Cell):
return self.minimum_grad(x, y, dy)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_minimum_grad():
np.random.seed(0)
input_x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
@ -46,3 +41,20 @@ def test_minimum_grad():
dy = input_dout - dx
assert np.allclose(result[0].asnumpy(), dx, rtol=1.e-4, atol=1.e-8, equal_nan=True)
assert np.allclose(result[1].asnumpy(), dy, rtol=1.e-4, atol=1.e-8, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_minimum_grad()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_basic_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_minimum_grad()

View File

@ -20,8 +20,6 @@ from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class Net(Cell):
def __init__(self):
@ -32,9 +30,6 @@ class Net(Cell):
return self.reduce_mean(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reduce_mean():
np.random.seed(0)
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
@ -43,3 +38,20 @@ def test_reduce_mean():
result = net(Tensor(input_x))
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reduce_mean_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_reduce_mean()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_reduce_mean_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_reduce_mean()

View File

@ -20,9 +20,6 @@ from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P
context.set_context(mode=context.GRAPH_MODE,
enable_graph_kernel=True, device_target="GPU")
class Net(Cell):
def __init__(self):
@ -50,9 +47,6 @@ class Net(Cell):
return self.reducemin(resh_res, 1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
@ -72,3 +66,20 @@ def test_basic():
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4,
atol=1.e-7, equal_nan=True)
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_basic()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_basic_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_basic()

View File

@ -20,8 +20,6 @@ from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations._grad_ops as G
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class TanhGradNet(Cell):
def __init__(self):
@ -32,9 +30,6 @@ class TanhGradNet(Cell):
return self.tanh_grad(y, dy)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tanh_grad():
np.random.seed(0)
input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
@ -44,3 +39,20 @@ def test_tanh_grad():
expect = input_dy * (1.0 - input_y * input_y)
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tanh_grad_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_tanh_grad()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tanh_grad_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_tanh_grad()