forked from mindspore-Ecosystem/mindspore
add primitive operator to test_lamb
This commit is contained in:
parent
b052ecf43b
commit
793737ab62
|
@ -33,7 +33,8 @@ class LambNet(Cell):
|
|||
|
||||
def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3,
|
||||
x1, x2, x3, x4, x5, gy, se, my):
|
||||
return self.lamb_next(i1, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0,
|
||||
i1_ = i1 + i3
|
||||
return self.lamb_next(i1_, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0,
|
||||
ix1, ix2, ix3), \
|
||||
self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my)
|
||||
|
||||
|
@ -113,7 +114,8 @@ def test_graph_kernel_lamb():
|
|||
|
||||
context.set_context(enable_graph_kernel=False)
|
||||
|
||||
a3, a0, a1, up = LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0,
|
||||
i1_ = i1 + i3
|
||||
a3, a0, a1, up = LambNextMVNumpy(i1_, i2, i3, i4, i5, i6, i7, i8, i9, ix0,
|
||||
ix1, ix2, ix3)
|
||||
|
||||
np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my)
|
||||
|
|
Loading…
Reference in New Issue