From a4ad6066b10fce18b8b3b1b4277e78791cdd51a4 Mon Sep 17 00:00:00 2001
From: wenfangpei
Date: Thu, 4 Mar 2021 16:02:58 +0800
Subject: [PATCH] expander lamb_apply_weight_assign

---
 .../graph_kernel/expanders/__init__.py           |  1 +
 .../expanders/lamb_apply_weight_assign.py        | 56 ++++++++++++++++++
 .../graph_kernel/graph_kernel_expander.cc        |  7 ++-
 mindspore/core/base/core_ops.h                   |  2 +
 .../test_lamb_apply_weight_assign.py             | 58 +++++++++++++++++++
 5 files changed, 122 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/_extends/graph_kernel/expanders/lamb_apply_weight_assign.py
 create mode 100644 tests/st/ops/graph_kernel/test_lamb_apply_weight_assign.py

diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py
index 084e9d676f5..5152dce6420 100644
--- a/mindspore/_extends/graph_kernel/expanders/__init__.py
+++ b/mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -46,3 +46,4 @@ from .square import Square
 from .tanh_grad import TanhGrad
 from .tile import Tile
 from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
+from .lamb_apply_weight_assign import LambApplyWeightAssign
diff --git a/mindspore/_extends/graph_kernel/expanders/lamb_apply_weight_assign.py b/mindspore/_extends/graph_kernel/expanders/lamb_apply_weight_assign.py
new file mode 100644
index 00000000000..0e260f5d646
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/lamb_apply_weight_assign.py
@@ -0,0 +1,56 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for LambApplyWeightAssign"""
+from ._utils import Expander, ExpanderInfoValidator as VLD
+
+
+@VLD.check_all_formats_same
+class LambApplyWeightAssign(Expander):
+    """LambApplyWeightAssign expander"""
+
+    def _expand(self, graph_builder):
+        w_norm, g_norm, input_lr, update, input_param = self.inputs
+        # ratio
+        const_zero = graph_builder.value(g_norm.dtype, 0)
+        const_one = graph_builder.value(g_norm.dtype, 1)
+        dtype = update.dtype
+
+        g_norm_greater_res = graph_builder.emit('Greater', [g_norm, const_zero])
+        g_norm_greater_res_float = graph_builder.emit('Cast', [g_norm_greater_res], attrs={'dst_type': dtype})
+
+        w_norm_g_norm = graph_builder.emit('RealDiv', [w_norm, g_norm])
+        # select
+        g_norm_greater_res_neg = graph_builder.emit('Neg', [g_norm_greater_res_float])
+        g_norm_greater_res_f = graph_builder.emit('Add', [g_norm_greater_res_neg, const_one])
+        g_norm_value_1 = graph_builder.emit('Mul', [g_norm_greater_res_float, w_norm_g_norm])
+        g_norm_value = graph_builder.emit('Add', [g_norm_value_1, g_norm_greater_res_f])
+
+        w_norm_greater_res = graph_builder.emit('Greater', [w_norm, const_zero])
+        w_norm_greater_res_float = graph_builder.emit('Cast', [w_norm_greater_res], attrs={'dst_type': dtype})
+
+        # select
+        w_norm_greater_res_neg = graph_builder.emit('Neg', [w_norm_greater_res_float])
+        w_norm_greater_res_f = graph_builder.emit('Add', [w_norm_greater_res_neg, const_one])
+        w_norm_value_1 = graph_builder.emit('Mul', [w_norm_greater_res_float, g_norm_value])
+        ratio = graph_builder.emit('Add', [w_norm_value_1, w_norm_greater_res_f])
+
+        # ratio * input_lr * update
+        update_with_ir = graph_builder.emit('Mul', [update, input_lr])
+        ratio_update_with_ir = graph_builder.emit('Mul', [update_with_ir, ratio])
+
+        # input_param - ratio_update_with_ir
+        next_param = graph_builder.emit('Sub', [input_param, ratio_update_with_ir])
+
+        return [next_param]
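Note (not part of the patch): in _expand above, each "select" is emulated with
Greater/Cast/Neg/Add/Mul, so the expanded graph computes ratio = w_norm / g_norm when
both norms are positive and 1 otherwise, then next_param = input_param - ratio *
input_lr * update. A minimal NumPy sketch of that computation; the helper name
lamb_weight_assign_reference is hypothetical and used only for illustration:

    import numpy as np

    def lamb_weight_assign_reference(w_norm, g_norm, lr, update, param):
        """NumPy equivalent of the expanded graph (illustrative sketch)."""
        # ratio = w_norm / g_norm if both norms are positive, otherwise 1
        ratio = np.where(w_norm > 0, np.where(g_norm > 0, w_norm / g_norm, 1.0), 1.0)
        # weight update: param - ratio * lr * update
        return param - ratio * lr * update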
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index 74e325669b2..0bc97b6389c 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -39,7 +39,8 @@ namespace mindspore {
 namespace opt {
 namespace {
 constexpr size_t kAssignInputIdx = 1;
-constexpr size_t kLambInputIdx = 12;
+constexpr size_t kLambOptimizerInputIdx = 12;
+constexpr size_t kLambWeightInputIdx = 4;
 
 std::vector<PrimitivePtr> GetExpandOps() {
   std::vector<PrimitivePtr> expand_ops = {
@@ -51,6 +52,7 @@
     prim::kPrimSqrtGrad,
     prim::kPrimClipByNormNoDivSum,
     prim::kLambApplyOptimizerAssign,
+    prim::kLambApplyWeightAssign,
 #elif ENABLE_GPU
     prim::kPrimBiasAdd,
     prim::kPrimBiasAddGrad,
@@ -176,7 +178,8 @@ ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
     {prim::kPrimDropout, std::make_shared()},
     {prim::kPrimAssignAdd, std::make_shared(kAssignInputIdx)},
     {prim::kPrimAssignSub, std::make_shared(kAssignInputIdx)},
-    {prim::kLambApplyOptimizerAssign, std::make_shared(kLambInputIdx)},
+    {prim::kLambApplyOptimizerAssign, std::make_shared(kLambOptimizerInputIdx)},
+    {prim::kLambApplyWeightAssign, std::make_shared(kLambWeightInputIdx)},
   };
   for (auto &e : expanders) {
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 9f446640df0..6f4cb345f36 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -305,6 +305,8 @@ inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove");
 inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Normalize");
 inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitive>("CustomExtractFeatures");
 inline const PrimitivePtr kLambApplyOptimizerAssign = std::make_shared<Primitive>("LambApplyOptimizerAssign");
+inline const PrimitivePtr kLambApplyWeightAssign = std::make_shared<Primitive>("LambApplyWeightAssign");
+
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
 inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
diff --git a/tests/st/ops/graph_kernel/test_lamb_apply_weight_assign.py b/tests/st/ops/graph_kernel/test_lamb_apply_weight_assign.py
new file mode 100644
index 00000000000..92447037e89
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_lamb_apply_weight_assign.py
@@ -0,0 +1,58 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.lamb_apply_weight_assign = P.LambApplyWeightAssign()
+
+    def construct(self, w_norm, g_norm, lr, update, param):
+        return self.lamb_apply_weight_assign(w_norm, g_norm, lr, update, param)
+
+
+def get_output(w_norm, g_norm, lr, update, param, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    opt = Net()
+    output = opt(Tensor(w_norm), Tensor(g_norm), Tensor(lr), Tensor(update), Tensor(param))
+    return output
+
+
+def lamb_apply_weight_assign():
+    w_norm = np.array([0.11]).astype(np.float32)
+    g_norm = np.array([1.2]).astype(np.float32)
+    lr = np.array([0.012]).astype(np.float32)
+    update = np.array([0.01, 0.03, 0.05]).astype(np.float32)
+    param = np.array([1, 3, 5]).astype(np.float32)
+
+    expect = get_output(w_norm, g_norm, lr, update, param, False)
+    output = get_output(w_norm, g_norm, lr, update, param, True)
+
+    assert np.allclose(output.asnumpy(), expect.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_lamb_apply_weight_assign_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    lamb_apply_weight_assign()
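Note (not part of the patch): for the inputs used in the test above, both norms are
positive, so the reference sketch given earlier reduces to
next_param = param - (w_norm / g_norm) * lr * update. A quick stand-alone NumPy check
of the value the test is expected to reproduce:

    import numpy as np

    w_norm = np.array([0.11], np.float32)
    g_norm = np.array([1.2], np.float32)
    lr = np.array([0.012], np.float32)
    update = np.array([0.01, 0.03, 0.05], np.float32)
    param = np.array([1, 3, 5], np.float32)

    # same select-free formula as the expander builds in the graph
    ratio = np.where(w_norm > 0, np.where(g_norm > 0, w_norm / g_norm, 1.0), 1.0)
    print(param - ratio * lr * update)  # approx. [0.999989, 2.999967, 4.999945]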