From fd84c20a6ae75597da44616b518781de4f899106 Mon Sep 17 00:00:00 2001
From: wenfangpei
Date: Fri, 21 May 2021 11:14:57 +0800
Subject: [PATCH] bug fix in lamb

---
 .../expanders/lamb_apply_optimizer_assign.py  |  3 +++
 .../expanders/lamb_apply_weight_assign.py     |  1 +
 .../test_lamb_apply_optimizer_assign.py       | 26 ++++++++++++-------
 .../test_lamb_apply_weight_assign.py          | 20 ++++++++------
 4 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/mindspore/_extends/graph_kernel/expanders/lamb_apply_optimizer_assign.py b/mindspore/_extends/graph_kernel/expanders/lamb_apply_optimizer_assign.py
index 8820cbfb419..fd6e67ea052 100644
--- a/mindspore/_extends/graph_kernel/expanders/lamb_apply_optimizer_assign.py
+++ b/mindspore/_extends/graph_kernel/expanders/lamb_apply_optimizer_assign.py
@@ -71,6 +71,9 @@ class LambApplyOptimizerAssign(Expander):
         do_use_weight_decay = graph_builder.emit('Mul', [do_use_weight_mul, do_use_weight])
         update = graph_builder.emit('Add', [do_use_weight_decay, update])
 
+        next_v = graph_builder.emit('Assign', [inputv, next_v])
+        next_m = graph_builder.emit('Assign', [inputm, next_m])
+
         res = [update, next_v, next_m]
 
         return res
diff --git a/mindspore/_extends/graph_kernel/expanders/lamb_apply_weight_assign.py b/mindspore/_extends/graph_kernel/expanders/lamb_apply_weight_assign.py
index 0e260f5d646..13fcbf3abfb 100644
--- a/mindspore/_extends/graph_kernel/expanders/lamb_apply_weight_assign.py
+++ b/mindspore/_extends/graph_kernel/expanders/lamb_apply_weight_assign.py
@@ -52,5 +52,6 @@ class LambApplyWeightAssign(Expander):
 
         # input_param - ratio_update_with_ir
         next_param = graph_builder.emit('Sub', [input_param, ratio_update_with_ir])
+        next_param = graph_builder.emit('Assign', [input_param, next_param])
 
         return [next_param]
diff --git a/tests/st/ops/graph_kernel/test_lamb_apply_optimizer_assign.py b/tests/st/ops/graph_kernel/test_lamb_apply_optimizer_assign.py
index 60ee32dcd59..b3ac5e0edfc 100644
--- a/tests/st/ops/graph_kernel/test_lamb_apply_optimizer_assign.py
+++ b/tests/st/ops/graph_kernel/test_lamb_apply_optimizer_assign.py
@@ -19,27 +19,32 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
-
+from mindspore.common.parameter import Parameter
 
 class Net(nn.Cell):
-    def __init__(self):
+    def __init__(self, v, m):
         super(Net, self).__init__()
         self.lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign()
+        self.m = Parameter(m, name='m')
+        self.v = Parameter(v, name='v')
 
-    def construct(self, grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon,
+    def construct(self, grad, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon,
                   steps, do_use_weight, weight_decay_rate):
-        return self.lamb_apply_optimizer_assign(grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2,
+        return self.lamb_apply_optimizer_assign(grad, self.v, self.m, input_param, beta_1, one_minus_beta_1, beta_2,
                                                 one_minus_beta_2, epsilon, steps, do_use_weight, weight_decay_rate)
 
+
 def get_output(grad, inputv, inputm, input_param, beta_1, one_minus_beta_1, beta_2, one_minus_beta_2, epsilon,
                steps, do_use_weight, weight_decay_rate, enable_graph_kernel=False):
     context.set_context(enable_graph_kernel=enable_graph_kernel)
-    opt = Net()
-    output = opt(Tensor(grad), Tensor(inputv), Tensor(inputm), Tensor(input_param), Tensor(beta_1),
+    opt = Net(Tensor(inputv), Tensor(inputm))
+    output = opt(Tensor(grad), Tensor(input_param), Tensor(beta_1),
                  Tensor(one_minus_beta_1), Tensor(beta_2), Tensor(one_minus_beta_2), Tensor(epsilon),
                  Tensor(steps), Tensor(do_use_weight), Tensor(weight_decay_rate))
-    return output
+
+    return [output[0].asnumpy(), opt.v.data.asnumpy(), opt.m.data.asnumpy()]
+
 
 def lamb_apply_optimizer_assign():
@@ -64,9 +69,10 @@ def lamb_apply_optimizer_assign():
 
     e1, e2, e3 = list(expect)
     o1, o2, o3 = list(output)
-    assert np.allclose(o1.asnumpy(), e1.asnumpy())
-    assert np.allclose(o2.asnumpy(), e2.asnumpy())
-    assert np.allclose(o3.asnumpy(), e3.asnumpy())
+    assert np.allclose(o1, e1)
+    assert np.allclose(o2, e2)
+    assert np.allclose(o3, e3)
+
 
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
diff --git a/tests/st/ops/graph_kernel/test_lamb_apply_weight_assign.py b/tests/st/ops/graph_kernel/test_lamb_apply_weight_assign.py
index 92447037e89..8facbc9ac05 100644
--- a/tests/st/ops/graph_kernel/test_lamb_apply_weight_assign.py
+++ b/tests/st/ops/graph_kernel/test_lamb_apply_weight_assign.py
@@ -19,22 +19,25 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
-
+from mindspore.common.parameter import Parameter
 
 class Net(nn.Cell):
-    def __init__(self):
+    def __init__(self, param):
         super(Net, self).__init__()
         self.lamb_apply_weight_assign = P.LambApplyWeightAssign()
+        self.param = Parameter(param, name='param')
+
+    def construct(self, w_norm, g_norm, lr, update):
+        return self.lamb_apply_weight_assign(w_norm, g_norm, lr, update, self.param)
 
-    def construct(self, w_norm, g_norm, lr, update, param):
-        return self.lamb_apply_weight_assign(w_norm, g_norm, lr, update, param)
 
 def get_output(w_norm, g_norm, lr, update, param, enable_graph_kernel=False):
     context.set_context(enable_graph_kernel=enable_graph_kernel)
-    opt = Net()
-    output = opt(Tensor(w_norm), Tensor(g_norm), Tensor(lr), Tensor(update), Tensor(param))
-    return output
+    opt = Net(Tensor(param))
+    _ = opt(Tensor(w_norm), Tensor(g_norm), Tensor(lr), Tensor(update))
+    return opt.param.data.asnumpy()
+
 
 def lamb_apply_weight_assign():
@@ -47,7 +50,8 @@ def lamb_apply_weight_assign():
 
     expect = get_output(w_norm, g_norm, lr, update, param, False)
    output = get_output(w_norm, g_norm, lr, update, param, True)
-    assert np.allclose(output.asnumpy(), expect.asnumpy())
+    assert np.allclose(output, expect)
+
 
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
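
For reference, the observable effect of the 'Assign' nodes added by this patch is that the expanded
LAMB kernels now write their results back into the Parameters they receive, which is why the updated
tests read the results from opt.v / opt.m / opt.param instead of the op return values. Below is a
minimal sketch of that check for LambApplyWeightAssign, mirroring the test above; the WeightAssignNet
name, the input shapes and values, and the Ascend graph-mode context settings are illustrative
assumptions, not part of the patch:

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.common.parameter import Parameter
    from mindspore.ops import operations as P

    class WeightAssignNet(nn.Cell):
        def __init__(self, param):
            super(WeightAssignNet, self).__init__()
            self.lamb_apply_weight_assign = P.LambApplyWeightAssign()
            self.param = Parameter(param, name='param')

        def construct(self, w_norm, g_norm, lr, update):
            # With the expander fix, the op assigns the new weights into self.param.
            return self.lamb_apply_weight_assign(w_norm, g_norm, lr, update, self.param)

    # Illustrative inputs; the real test compares graph-kernel on/off results.
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_graph_kernel=True)
    param = np.full((2, 2), 0.5).astype(np.float32)
    net = WeightAssignNet(Tensor(param))
    net(Tensor(np.array(2.0, np.float32)), Tensor(np.array(1.0, np.float32)),
        Tensor(np.array(0.01, np.float32)), Tensor(np.full((2, 2), 0.1).astype(np.float32)))
    print(net.param.data.asnumpy())  # holds the updated weights after the in-place Assign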