forked from mindspore-Ecosystem/mindspore

!16695 bug fix in lamb

From: @wenfangpei
Reviewed-by: @gaoxiong1, @ckey_dou
Signed-off-by: @ckey_dou

commit 84859aba23

@@ -71,6 +71,9 @@ class LambApplyOptimizerAssign(Expander):
         do_use_weight_decay = graph_builder.emit('Mul', [do_use_weight_mul, do_use_weight])
         update = graph_builder.emit('Add', [do_use_weight_decay, update])
 
+        next_v = graph_builder.emit('Assign', [inputv, next_v])
+        next_m = graph_builder.emit('Assign', [inputm, next_m])
+
         res = [update, next_v, next_m]
 
         return res

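The two added 'Assign' nodes are the actual fix: the expander already computed next_v and next_m, but nothing wrote them back, so the state buffers inputv/inputm kept their stale values. Below is a minimal NumPy sketch of the intended effect, assuming the usual Adam-style moment update behind LambApplyOptimizerAssign (bias correction and the weight-decay branch visible in the hunk are omitted; the helper name is hypothetical):

import numpy as np

def lamb_optimizer_assign_sketch(grad, inputv, inputm, beta_1, one_minus_beta_1,
                                 beta_2, one_minus_beta_2, epsilon):
    # Assumed Adam-style accumulators; a sketch, not the full op semantics.
    next_m = beta_1 * inputm + one_minus_beta_1 * grad          # first moment
    next_v = beta_2 * inputv + one_minus_beta_2 * grad * grad   # second moment
    update = next_m / (np.sqrt(next_v) + epsilon)
    # Effect of the emitted 'Assign' nodes: write the new moments back into
    # the state buffers so later steps (and the test below) can observe them.
    inputm[...] = next_m
    inputv[...] = next_v
    return update, next_v, next_m
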
@@ -52,5 +52,6 @@ class LambApplyWeightAssign(Expander):
 
         # input_param - ratio_update_with_ir
         next_param = graph_builder.emit('Sub', [input_param, ratio_update_with_ir])
+        next_param = graph_builder.emit('Assign', [input_param, next_param])
 
         return [next_param]

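The weight-update expander had the same gap: the subtraction result was returned but never stored, so the added 'Assign' writes next_param back into input_param. A rough NumPy sketch of the resulting behaviour (ratio_update_with_ir stands for the scaled update computed earlier in the expander, which is an assumption about the surrounding code):

import numpy as np

def lamb_weight_assign_sketch(input_param, ratio_update_with_ir):
    # input_param - ratio_update_with_ir, then write the result back in place
    # (the effect of the emitted 'Assign' node).
    next_param = input_param - ratio_update_with_ir
    input_param[...] = next_param
    return next_param
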
@@ -19,27 +19,32 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.common.parameter import Parameter
 
 
 class Net(nn.Cell):
-    def __init__(self):
+    def __init__(self, v, m):
         super(Net, self).__init__()
         self.lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign()
+        self.m = Parameter(m, name='m')
+        self.v = Parameter(v, name='v')
 
-    def construct(self, grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon,
+    def construct(self, grad, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon,
                   steps, do_use_weight, weight_decay_rate):
-        return self.lamb_apply_optimizer_assign(grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2,
+        return self.lamb_apply_optimizer_assign(grad, self.v, self.m, input_param, beta_1, one_minus_beta_1, beta_2,
                                                 one_minus_beta_2, epsilon, steps, do_use_weight, weight_decay_rate)
 
 
 def get_output(grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon, steps,
                do_use_weight, weight_decay_rate, enable_graph_kernel=False):
     context.set_context(enable_graph_kernel=enable_graph_kernel)
-    opt = Net()
-    output = opt(Tensor(grad), Tensor(inputv), Tensor(inputm), Tensor(input_param), Tensor(beta_1),
+    opt = Net(Tensor(inputv), Tensor(inputm))
+    output = opt(Tensor(grad), Tensor(input_param), Tensor(beta_1),
                  Tensor(one_minus_beta_1), Tensor(beta_2), Tensor(one_minus_beta_2), Tensor(epsilon), Tensor(steps),
                  Tensor(do_use_weight), Tensor(weight_decay_rate))
-    return output
+    return [output[0].asnumpy(), opt.v.data.asnumpy(), opt.m.data.asnumpy()]
 
 
 def lamb_apply_optimizer_assign():

@@ -64,9 +69,10 @@ def lamb_apply_optimizer_assign():
     e1, e2, e3 = list(expect)
     o1, o2, o3 = list(output)
 
-    assert np.allclose(o1.asnumpy(), e1.asnumpy())
-    assert np.allclose(o2.asnumpy(), e2.asnumpy())
-    assert np.allclose(o3.asnumpy(), e3.asnumpy())
+    assert np.allclose(o1, e1)
+    assert np.allclose(o2, e2)
+    assert np.allclose(o3, e3)
+
 
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training

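The test is restructured because an 'Assign' only has an observable effect on a Parameter: v and m now live inside the Cell, and get_output returns plain NumPy arrays (the op output plus the updated opt.v and opt.m), which is why the assertions compare arrays directly instead of calling asnumpy(). A hypothetical driver following the same pattern, with made-up shapes and values:

import numpy as np

shape = (2, 3)
grad = np.random.rand(*shape).astype(np.float32)
v = np.random.rand(*shape).astype(np.float32)
m = np.random.rand(*shape).astype(np.float32)
param = np.random.rand(*shape).astype(np.float32)
beta_1, one_minus_beta_1 = np.float32(0.9), np.float32(0.1)
beta_2, one_minus_beta_2 = np.float32(0.999), np.float32(0.001)
epsilon, steps = np.float32(1e-6), np.float32(10)
do_use_weight, weight_decay_rate = np.float32(1.0), np.float32(0.01)

args = (grad, v, m, param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2,
        epsilon, steps, do_use_weight, weight_decay_rate)
expect = get_output(*args, enable_graph_kernel=False)  # baseline path
output = get_output(*args, enable_graph_kernel=True)   # graph-kernel path
for o, e in zip(output, expect):
    assert np.allclose(o, e)
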
@@ -19,22 +19,25 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.common.parameter import Parameter
 
 
 class Net(nn.Cell):
-    def __init__(self):
+    def __init__(self, param):
         super(Net, self).__init__()
         self.lamb_apply_weight_assign = P.LambApplyWeightAssign()
+        self.param = Parameter(param, name='param')
 
-    def construct(self, w_norm, g_norm, lr, update, param):
-        return self.lamb_apply_weight_assign(w_norm, g_norm, lr, update, param)
+    def construct(self, w_norm, g_norm, lr, update):
+        return self.lamb_apply_weight_assign(w_norm, g_norm, lr, update, self.param)
 
 
 def get_output(w_norm, g_norm, lr, update, param, enable_graph_kernel=False):
     context.set_context(enable_graph_kernel=enable_graph_kernel)
-    opt = Net()
-    output = opt(Tensor(w_norm), Tensor(g_norm), Tensor(lr), Tensor(update), Tensor(param))
-    return output
+    opt = Net(Tensor(param))
+    _ = opt(Tensor(w_norm), Tensor(g_norm), Tensor(lr), Tensor(update))
+    return opt.param.data.asnumpy()
 
 
 def lamb_apply_weight_assign():

@@ -47,7 +50,8 @@ def lamb_apply_weight_assign():
     expect = get_output(w_norm, g_norm, lr, update, param, False)
     output = get_output(w_norm, g_norm, lr, update, param, True)
 
-    assert np.allclose(output.asnumpy(), expect.asnumpy())
+    assert np.allclose(output, expect)
+
 
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training

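Same idea for the weight-assign test: the parameter is a Parameter of the Cell, the op writes into it in place, and get_output reads the result back via opt.param.data.asnumpy(), so the final comparison is between NumPy arrays. A short usage sketch with hypothetical values:

import numpy as np

param = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
update = np.full_like(param, 0.1)
w_norm = np.array(2.0, dtype=np.float32)
g_norm = np.array(1.0, dtype=np.float32)
lr = np.array(0.01, dtype=np.float32)

expect = get_output(w_norm, g_norm, lr, update, param, False)  # baseline path
output = get_output(w_norm, g_norm, lr, update, param, True)   # graph-kernel path
assert np.allclose(output, expect)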