forked from mindspore-Ecosystem/mindspore
commit
0138b6594a
|
@ -178,8 +178,8 @@ class SquareGradNet(Cell):
|
|||
self.square = ops.Custom(func, out_shapes, out_types, "aot", bprop, reg)
|
||||
|
||||
def construct(self, x):
|
||||
res = self.square(x)[0]
|
||||
res2 = self.square(res)[0]
|
||||
res = self.square(x)
|
||||
res2 = self.square(res)
|
||||
return res2
|
||||
|
||||
|
||||
|
@ -205,7 +205,7 @@ def test_square_py_bprop():
|
|||
return (dx,)
|
||||
|
||||
try:
|
||||
net = SquareGradNet(func_path + ":CustomSquare", ([3],), (mstype.float32,), bprop=bprop, reg=None)
|
||||
net = SquareGradNet(func_path + ":CustomSquare", (3,), mstype.float32, bprop=bprop, reg=None)
|
||||
dx = ops.GradOperation(sens_param=True)(net)(Tensor(x), Tensor(sens))
|
||||
except Exception as e:
|
||||
if os.path.exists(func_path):
|
||||
|
@ -233,7 +233,7 @@ def test_square_aot_bprop():
|
|||
check_exec_file(cmd_bprop, func_path_bprop, "square_bprop.cu", "square_bprop.so")
|
||||
try:
|
||||
aot_bprop = ops.Custom(func_path_bprop + ":CustomSquareBprop",
|
||||
([3],), (mstype.float32,), "aot", reg_info=None)
|
||||
(3,), mstype.float32, "aot", reg_info=None)
|
||||
except Exception as e:
|
||||
if os.path.exists(func_path_bprop):
|
||||
os.remove(func_path_bprop)
|
||||
|
@ -241,12 +241,12 @@ def test_square_aot_bprop():
|
|||
|
||||
def bprop(x, out, dout):
|
||||
res = aot_bprop(x, out, dout)
|
||||
return res
|
||||
return (res,)
|
||||
|
||||
cmd, func_path = get_file_path_gpu("square.cu", "square.so")
|
||||
check_exec_file(cmd, func_path, "square_bprop.cu", "square_bprop.so")
|
||||
try:
|
||||
net = SquareGradNet(func_path + ":CustomSquare", ([3],), (mstype.float32,), bprop=bprop, reg=None)
|
||||
net = SquareGradNet(func_path + ":CustomSquare", (3,), mstype.float32, bprop=bprop, reg=None)
|
||||
dx = ops.GradOperation(sens_param=True)(net)(Tensor(x), Tensor(sens))
|
||||
except Exception as e:
|
||||
if os.path.exists(func_path):
|
||||
|
|
Loading…
Reference in New Issue