!28140 Fix aot test case

Merge pull request !28140 from jiaoy1224/aot
This commit is contained in:
i-robot 2021-12-24 10:02:03 +00:00 committed by Gitee
commit 0138b6594a
1 changed files with 6 additions and 6 deletions

View File

@ -178,8 +178,8 @@ class SquareGradNet(Cell):
self.square = ops.Custom(func, out_shapes, out_types, "aot", bprop, reg)
def construct(self, x):
res = self.square(x)[0]
res2 = self.square(res)[0]
res = self.square(x)
res2 = self.square(res)
return res2
@ -205,7 +205,7 @@ def test_square_py_bprop():
return (dx,)
try:
net = SquareGradNet(func_path + ":CustomSquare", ([3],), (mstype.float32,), bprop=bprop, reg=None)
net = SquareGradNet(func_path + ":CustomSquare", (3,), mstype.float32, bprop=bprop, reg=None)
dx = ops.GradOperation(sens_param=True)(net)(Tensor(x), Tensor(sens))
except Exception as e:
if os.path.exists(func_path):
@ -233,7 +233,7 @@ def test_square_aot_bprop():
check_exec_file(cmd_bprop, func_path_bprop, "square_bprop.cu", "square_bprop.so")
try:
aot_bprop = ops.Custom(func_path_bprop + ":CustomSquareBprop",
([3],), (mstype.float32,), "aot", reg_info=None)
(3,), mstype.float32, "aot", reg_info=None)
except Exception as e:
if os.path.exists(func_path_bprop):
os.remove(func_path_bprop)
@ -241,12 +241,12 @@ def test_square_aot_bprop():
def bprop(x, out, dout):
res = aot_bprop(x, out, dout)
return res
return (res,)
cmd, func_path = get_file_path_gpu("square.cu", "square.so")
check_exec_file(cmd, func_path, "square_bprop.cu", "square_bprop.so")
try:
net = SquareGradNet(func_path + ":CustomSquare", ([3],), (mstype.float32,), bprop=bprop, reg=None)
net = SquareGradNet(func_path + ":CustomSquare", (3,), mstype.float32, bprop=bprop, reg=None)
dx = ops.GradOperation(sens_param=True)(net)(Tensor(x), Tensor(sens))
except Exception as e:
if os.path.exists(func_path):