support expanding fused_adam and fused_adam_weight_decay ops

zengzitao 2020-11-05 16:39:05 +08:00
parent 3070e9c78b
commit 53043ae18f
14 changed files with 238 additions and 6 deletions

akg

@@ -1 +1 @@
Subproject commit f308919c39811c2c3e07fb0dcc8054a533c84cbc
Subproject commit 2956e64803cad9b84316cdf2b25d034c5f944ccc

View File

@@ -21,3 +21,5 @@ from .softmax import expand_softmax
from .square import expand_square
from .bias_add import expand_biasadd
from .bias_add_grad import expand_biasaddgrad
from .fused_adam import expand_fusedadam
from .fused_adam_weight_decay import expand_fusedadamweightdecay

View File

@@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for fused_adam"""
from mindspore._extends.graph_kernel.model import model_builder as builder


def expand_fusedadam(expand_info):
    """FusedAdam expander"""
    # get op info.
    input_desc_0 = expand_info['input_desc'][0]
    input_desc_1 = expand_info['input_desc'][1]
    input_desc_2 = expand_info['input_desc'][2]
    input_desc_3 = expand_info['input_desc'][3]
    input_desc_4 = expand_info['input_desc'][4]
    input_desc_5 = expand_info['input_desc'][5]
    input_desc_6 = expand_info['input_desc'][6]
    input_desc_7 = expand_info['input_desc'][7]
    input_desc_8 = expand_info['input_desc'][8]
    input_desc_9 = expand_info['input_desc'][9]
    graph_builder = builder.GraphBuilder()

    # generate a graph.
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
        one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
        beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
        one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
        eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
        lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
        param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
        m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
        v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
        gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
        graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient)

        # compute result: next_m = beta_1 * m + (1 - beta_1) * gradient
        beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
        one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
        next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
        # next_v = beta_2 * v + (1 - beta_2) * gradient^2
        beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
        grad_square = graph_builder.emit('Mul', [gradient, gradient])
        one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
        next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
        # next_para = param - lr * next_m / (sqrt(next_v) + eps)
        sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
        sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
        update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
        update_with_lr = graph_builder.emit('Mul', [lr, update])
        next_para = graph_builder.emit('Sub', [param, update_with_lr])
        # write the results back into param, m and v in place
        param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
        m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True})
        v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True})

        # set graph output.
        graph_scope.set_output(param_result, m_result, v_result)

    graph = graph_builder.get()[0]
    return graph
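For reference, the expanded graph performs the standard Adam moment and parameter update. A minimal NumPy sketch of the same arithmetic (an illustrative stand-in, not the expander itself; note the real op receives `one_sub_beta_1`/`one_sub_beta_2` precomputed as inputs rather than computing `1 - beta` internally):

import numpy as np

def adam_step(param, m, v, grad, beta_1, beta_2, eps, lr):
    """NumPy sketch of the computation the expanded FusedAdam graph emits."""
    next_m = beta_1 * m + (1.0 - beta_1) * grad
    next_v = beta_2 * v + (1.0 - beta_2) * grad * grad
    update = next_m / (np.sqrt(next_v) + eps)
    next_param = param - lr * update
    return next_param, next_m, next_v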

View File

@@ -0,0 +1,76 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for fused_adam_weight_decay"""
from mindspore._extends.graph_kernel.model import model_builder as builder


def expand_fusedadamweightdecay(expand_info):
    """FusedAdamWeightDecay expander"""
    # get op info.
    input_desc_0 = expand_info['input_desc'][0]
    input_desc_1 = expand_info['input_desc'][1]
    input_desc_2 = expand_info['input_desc'][2]
    input_desc_3 = expand_info['input_desc'][3]
    input_desc_4 = expand_info['input_desc'][4]
    input_desc_5 = expand_info['input_desc'][5]
    input_desc_6 = expand_info['input_desc'][6]
    input_desc_7 = expand_info['input_desc'][7]
    input_desc_8 = expand_info['input_desc'][8]
    input_desc_9 = expand_info['input_desc'][9]
    input_desc_10 = expand_info['input_desc'][10]
    graph_builder = builder.GraphBuilder()

    # generate a graph.
    with graph_builder.graph_scope('main') as graph_scope:
        # create tensor input.
        beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
        one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
        beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
        one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
        eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
        lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
        param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
        m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
        v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
        gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
        weight_decay = graph_builder.tensor(input_desc_10['shape'], input_desc_10['data_type'], input_desc_10['format'])
        graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
                              eps, lr, param, m, v, gradient, weight_decay)

        # compute result: moment updates are the same as in FusedAdam
        beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
        one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
        next_m = graph_builder.emit('TensorAdd', [beta_1_mul_m, one_sub_beta_1_mul_grad])
        beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
        grad_square = graph_builder.emit('Mul', [gradient, gradient])
        one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
        next_v = graph_builder.emit('TensorAdd', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
        sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
        sqrt_next_v_add_eps = graph_builder.emit('TensorAdd', [sqrt_next_v, eps])
        update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
        # decoupled weight decay: fold weight_decay * param into the update
        param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param])
        update = graph_builder.emit('TensorAdd', [update, param_with_weight_decay])
        update_with_lr = graph_builder.emit('Mul', [lr, update])
        next_para = graph_builder.emit('Sub', [param, update_with_lr])
        # write the results back into param, m and v in place
        para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
        m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True})
        v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True})

        # set graph output.
        graph_scope.set_output(para_result, m_result, v_result)

    graph = graph_builder.get()[0]
    return graph
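The only difference from the FusedAdam expansion is the decoupled weight-decay term added to the update before the learning rate is applied. Sketched in NumPy (again an illustrative stand-in, not the expander):

import numpy as np

def adam_weight_decay_step(param, m, v, grad, beta_1, beta_2, eps, lr, weight_decay):
    """NumPy sketch of the expanded FusedAdamWeightDecay computation."""
    next_m = beta_1 * m + (1.0 - beta_1) * grad
    next_v = beta_2 * v + (1.0 - beta_2) * grad * grad
    update = next_m / (np.sqrt(next_v) + eps) + weight_decay * param
    next_param = param - lr * update
    return next_param, next_m, next_v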

View File

@@ -154,6 +154,8 @@ class PrimLib:
        'ControlDepend': Prim(CONTROL),
        'Assign': Prim(ELEMWISE),
        'Tanh': Prim(ELEMWISE),
        'ExpandDims': Prim(ELEMWISE),
        'InplaceAssign': Prim(ELEMWISE),
        '@ReduceInit': Prim(ELEMWISE),
    }

View File

@@ -70,6 +70,7 @@ class OpInfer:
    infer_shape_func = {
        # add special infer func here
        'InplaceAssign': lambda inputs, attrs: inputs[2].shape
    }
    infer_dtype_func = {
        # add special infer func here
View File

@@ -560,7 +560,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
  auto trans_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
    auto shape = GetNodeShape(node);
    if (shape.size() != 0 && shape.size() != 1) {
      return node;
      return nullptr;
    } else {
      auto tmp_node = node->cast<CNodePtr>();
      auto transpose_node = tmp_node->input(1);
@@ -635,7 +635,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
      AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
      return new_cnode;
    }
    return node;
    return nullptr;
  };
  auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr {
    auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node);

View File

@@ -702,7 +702,8 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
std::unordered_set<PrimitivePtr> GetExpandOps() {
  std::unordered_set<PrimitivePtr> expand_ops = {
    prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu, prim::kPrimGeluGrad,
    prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu,
    prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay,
  };
  return expand_ops;
}
@@ -729,7 +730,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh};
    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape};
  return fusible_basic_ops;
}
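These expanders only take effect when graph kernel fusion is enabled. For MindSpore of this era that was switched on through the context; a hedged usage sketch (the exact flag set may differ by version):

import mindspore.context as context

# Assumption: enable_graph_kernel was the GPU graph-kernel switch around
# MindSpore 1.0/1.1; verify against your installed version.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
                    enable_graph_kernel=True)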

View File

@@ -174,6 +174,8 @@ inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>
inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("SparseApplyFtrl");
inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAdam");
inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay");
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");

View File

@@ -20,6 +20,7 @@ from .hsigmoid import _hsigmoid_akg
from .hsigmoid_grad import _hsigmoid_grad_akg
from .hswish import _hswish_akg
from .hswish_grad import _hswish_grad_akg
from .inplace_assign import _inplace_assign_akg
from .lessequal import _lessequal_akg
from .logical_and import _logical_and_akg
from .logical_not import _logical_not_akg

View File

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InplaceAssign op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT

op_info = AkgGpuRegOp("InplaceAssign") \
    .fusion_type("ELEMWISE") \
    .input(0, "x") \
    .input(1, "y") \
    .input(2, "z") \
    .output(0, "output") \
    .attr("fake_output", "optional", "bool") \
    .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \
    .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \
    .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \
    .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
    .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
    .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
    .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
    .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
    .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \
    .get_op_info()


@op_info_register(op_info)
def _inplace_assign_akg():
    """InplaceAssign Akg register"""
    return

View File

@@ -82,7 +82,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformCandidateSampler)
from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
                        CheckValid, MakeRefKey, Partial, Depend, identity, CheckBprop, Push, Pull)
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
                        CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,

View File

@@ -65,6 +65,36 @@ class Assign(PrimitiveWithCheck):
        validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name)


class InplaceAssign(PrimitiveWithInfer):
    """
    Inplace assign `Parameter` with a value.

    This primitive can only be used in graph kernel.

    Inputs:
        - **variable** (Parameter) - The `Parameter`.
        - **value** (Tensor) - The value to be assigned.
        - **depend** (Tensor) - The dependent tensor that keeps this op connected in the graph.

    Outputs:
        Tensor, has the same type as the original `variable`.

    Examples:
        >>> def construct(self, x):
        >>>     val = x - 1.0
        >>>     ret = x + 2.0
        >>>     return InplaceAssign()(x, val, ret)
        >>> x = Tensor([2.0], mindspore.float32)
        >>> net = Net()
        >>> net(x)
    """
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])

    def infer_shape(self, x, y, z):
        return z

    def infer_dtype(self, x, y, z):
        return z


class BoundingBoxEncode(PrimitiveWithInfer):
    """
    Encodes bounding boxes locations.
@@ -509,6 +539,7 @@ class PopulationCount(PrimitiveWithInfer):
        validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name)
        return mstype.tensor_type(mstype.uint8)


class Push(PrimitiveWithInfer):
    """
    Pushes the inputs of the corresponding optimizer to parameter server.
@@ -539,6 +570,7 @@ class Push(PrimitiveWithInfer):
    def infer_dtype(self, inputs, shapes):
        return mstype.uint64


class Pull(PrimitiveWithInfer):
    """
    Pulls weight from parameter server.
@@ -563,6 +595,7 @@ class Pull(PrimitiveWithInfer):
    def infer_dtype(self, key_dtype, weight_dtype):
        return mstype.float32


class identity(Primitive):
    """
    Makes an identity primitive, used for pynative mode.
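The InplaceAssign docstring example assumes a surrounding `Net`; a self-contained sketch of how such a cell might look (hypothetical wrapper class, hedged against API drift in this MindSpore generation):

import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import InplaceAssign

class Net(nn.Cell):
    """Hypothetical wrapper for the docstring example above."""
    def __init__(self):
        super(Net, self).__init__()
        self.inplace_assign = InplaceAssign()

    def construct(self, x):
        val = x - 1.0  # value written into x in place
        ret = x + 2.0  # depend tensor that carries the op's output
        return self.inplace_assign(x, val, ret)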

View File

@@ -52,6 +52,7 @@ def CalGelu(x):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu():
    np.random.seed(0)
    input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
    net = GeluNet()
@@ -67,6 +68,7 @@ def test_gelu():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu_grad():
    np.random.seed(0)
    input_dy = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
    input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
    input_y = CalGelu(input_x)