bug fix in lamb

Author: wenfangpei
Date: 2021-05-21 11:14:57 +08:00
Parent: 5ed742a893
Commit: fd84c20a6a
4 changed files with 32 additions and 18 deletions
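
The change makes the graph kernel expanders for LambApplyOptimizerAssign and LambApplyWeightAssign emit an 'Assign' that writes the freshly computed value back into the corresponding input, and it reworks the tests so the mutable state (v, m, param) is held as Parameter objects and results are read back from those parameters. The sketch below illustrates the updated test pattern with LambApplyWeightAssign; the class structure and the operator signature follow the diff, while the shapes and numeric values are placeholder assumptions.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter

class WeightAssignNet(nn.Cell):
    def __init__(self, param):
        super(WeightAssignNet, self).__init__()
        self.lamb_apply_weight_assign = P.LambApplyWeightAssign()
        # Hold the state as a Parameter so the op's in-place update is
        # observable after the call (not passed in as a plain Tensor).
        self.param = Parameter(param, name='param')

    def construct(self, w_norm, g_norm, lr, update):
        return self.lamb_apply_weight_assign(w_norm, g_norm, lr, update, self.param)

# Exercise the graph kernel (expander) path, as the updated tests do.
context.set_context(enable_graph_kernel=True)
net = WeightAssignNet(Tensor(np.ones((2, 2)).astype(np.float32)))
net(Tensor(np.array(1.0).astype(np.float32)),    # w_norm (placeholder)
    Tensor(np.array(1.0).astype(np.float32)),    # g_norm (placeholder)
    Tensor(np.array(0.01).astype(np.float32)),   # lr (placeholder)
    Tensor(np.ones((2, 2)).astype(np.float32)))  # update (placeholder)
print(net.param.data.asnumpy())                  # reflects the in-place update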

View File

@@ -71,6 +71,9 @@ class LambApplyOptimizerAssign(Expander):
         do_use_weight_decay = graph_builder.emit('Mul', [do_use_weight_mul, do_use_weight])
         update = graph_builder.emit('Add', [do_use_weight_decay, update])
+        next_v = graph_builder.emit('Assign', [inputv, next_v])
+        next_m = graph_builder.emit('Assign', [inputm, next_m])
         res = [update, next_v, next_m]
         return res

View File

@@ -52,5 +52,6 @@ class LambApplyWeightAssign(Expander):
         # input_param - ratio_update_with_ir
         next_param = graph_builder.emit('Sub', [input_param, ratio_update_with_ir])
+        next_param = graph_builder.emit('Assign', [input_param, next_param])
         return [next_param]
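
Both expander fixes follow the same pattern: the newly computed value is emitted as an 'Assign' back into the corresponding input (inputv and inputm above, input_param here), and the Assign result is what gets returned, so the expanded graph preserves the in-place update that the fused operator performs; the reworked tests below verify this by reading the state back out of Parameters.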

View File

@@ -19,27 +19,32 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.common.parameter import Parameter
 class Net(nn.Cell):
-    def __init__(self):
+    def __init__(self, v, m):
         super(Net, self).__init__()
         self.lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign()
+        self.m = Parameter(m, name='m')
+        self.v = Parameter(v, name='v')
-    def construct(self, grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon,
+    def construct(self, grad, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon,
                   steps, do_use_weight, weight_decay_rate):
-        return self.lamb_apply_optimizer_assign(grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2,
+        return self.lamb_apply_optimizer_assign(grad, self.v, self.m, input_param, beta_1, one_minus_beta_1, beta_2,
                                                 one_minus_beta_2, epsilon, steps, do_use_weight, weight_decay_rate)
 def get_output(grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon, steps,
                do_use_weight, weight_decay_rate, enable_graph_kernel=False):
     context.set_context(enable_graph_kernel=enable_graph_kernel)
-    opt = Net()
-    output = opt(Tensor(grad), Tensor(inputv), Tensor(inputm), Tensor(input_param), Tensor(beta_1),
+    opt = Net(Tensor(inputv), Tensor(inputm))
+    output = opt(Tensor(grad), Tensor(input_param), Tensor(beta_1),
                  Tensor(one_minus_beta_1), Tensor(beta_2), Tensor(one_minus_beta_2), Tensor(epsilon), Tensor(steps),
                  Tensor(do_use_weight), Tensor(weight_decay_rate))
-    return output
+    return [output[0].asnumpy(), opt.v.data.asnumpy(), opt.m.data.asnumpy()]
 def lamb_apply_optimizer_assign():
@@ -64,9 +69,10 @@ def lamb_apply_optimizer_assign():
     e1, e2, e3 = list(expect)
     o1, o2, o3 = list(output)
-    assert np.allclose(o1.asnumpy(), e1.asnumpy())
-    assert np.allclose(o2.asnumpy(), e2.asnumpy())
-    assert np.allclose(o3.asnumpy(), e3.asnumpy())
+    assert np.allclose(o1, e1)
+    assert np.allclose(o2, e2)
+    assert np.allclose(o3, e3)
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training

View File

@@ -19,22 +19,25 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.common.parameter import Parameter
 class Net(nn.Cell):
-    def __init__(self):
+    def __init__(self, param):
         super(Net, self).__init__()
         self.lamb_apply_weight_assign = P.LambApplyWeightAssign()
+        self.param = Parameter(param, name='param')
-    def construct(self, w_norm, g_norm, lr, update, param):
-        return self.lamb_apply_weight_assign(w_norm, g_norm, lr, update, param)
+    def construct(self, w_norm, g_norm, lr, update):
+        return self.lamb_apply_weight_assign(w_norm, g_norm, lr, update, self.param)
 def get_output(w_norm, g_norm, lr, update, param, enable_graph_kernel=False):
     context.set_context(enable_graph_kernel=enable_graph_kernel)
-    opt = Net()
-    output = opt(Tensor(w_norm), Tensor(g_norm), Tensor(lr), Tensor(update), Tensor(param))
-    return output
+    opt = Net(Tensor(param))
+    _ = opt(Tensor(w_norm), Tensor(g_norm), Tensor(lr), Tensor(update))
+    return opt.param.data.asnumpy()
 def lamb_apply_weight_assign():
@@ -47,7 +50,8 @@ def lamb_apply_weight_assign():
     expect = get_output(w_norm, g_norm, lr, update, param, False)
     output = get_output(w_norm, g_norm, lr, update, param, True)
-    assert np.allclose(output.asnumpy(), expect.asnumpy())
+    assert np.allclose(output, expect)
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training