Add expander for AddN; update akg submodule

dayschan 2021-06-23 10:47:30 +08:00
parent fc884e44b6
commit 149dab39c5
6 changed files with 111 additions and 11 deletions

akg

@@ -1 +1 @@
-Subproject commit 32af460cac1bb7d76bc1fd41f5866107cfffe1b9
+Subproject commit 97dc7e96c2ffedf2e6e38310a903ffa205a6e656


@@ -14,43 +14,44 @@
 # ============================================================================
 """expanders init"""
+from .addn import AddN
 from .assign_add import AssignAdd
 from .batchnorm import BatchNorm
 from .batchnorm_grad import BatchNormGrad
 from .bias_add import BiasAdd
 from .bias_add_grad import BiasAddGrad
 from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
+from .conv2d import Conv2D
 from .dropout_grad import DropoutGrad
 from .expand_dims import ExpandDims
 from .fused_adam import FusedAdam
 from .fused_adam_weight_decay import FusedAdamWeightDecay
+from .fused_mul_add import FusedMulAdd
 from .gelu import GeLU
 from .gelu_grad import GeLUGrad
 from .gkdropout import GkDropout
+from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
+from .lamb_apply_weight_assign import LambApplyWeightAssign
 from .layernorm import LayerNorm
 from .layernorm_grad import LayerNormGrad
 from .logsoftmax import LogSoftmax
 from .logsoftmax_grad import LogSoftmaxGrad
+from .matmul import BatchMatMul, MatMul
 from .maximum_grad import MaximumGrad
 from .minimum_grad import MinimumGrad
 from .reduce_mean import ReduceMean
 from .relu import ReLU
 from .relu_grad import ReluGrad
-from .softmax import Softmax
 from .sigmoid import Sigmoid
-from .sigmoid_grad import SigmoidGrad
 from .sigmoid_cross_entropy_with_logits import SigmoidCrossEntropyWithLogits
 from .sigmoid_cross_entropy_with_logits_grad import SigmoidCrossEntropyWithLogitsGrad
+from .sigmoid_grad import SigmoidGrad
+from .softmax import Softmax
 from .softmax_cross_entropy_with_logits import SoftmaxCrossEntropyWithLogits
+from .softmax_grad_ext import SoftmaxGradExt
 from .sqrt_grad import SqrtGrad
 from .square import Square
+from .square_sum_v1 import SquareSumV1
 from .squeeze import Squeeze
 from .tanh_grad import TanhGrad
 from .tile import Tile
-from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
-from .lamb_apply_weight_assign import LambApplyWeightAssign
-from .softmax_grad_ext import SoftmaxGradExt
-from .square_sum_v1 import SquareSumV1
-from .fused_mul_add import FusedMulAdd
-from .conv2d import Conv2D
-from .matmul import MatMul, BatchMatMul

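The added `from .addn import AddN` line is what actually registers the new expander: expanders are looked up by op name in this package's namespace, so a class that is not exported here is never found. A rough sketch of that lookup (illustrative only, not the actual MindSpore dispatch code):

from mindspore._extends.graph_kernel import expanders

op_name = "AddN"                                  # op name handed over by the backend
expander_cls = getattr(expanders, op_name, None)  # name-based lookup (assumed mechanism)
if expander_cls is None:
    raise RuntimeError("no expander registered for {}".format(op_name))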

@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for addn"""
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD


@VLD.check_all_formats_same
class AddN(Expander):
    """Expand AddN to multiple Adds"""

    def _check(self):
        if len(self.inputs) < 2:
            raise GKException("Inputs num of AddN should be greater than 1 but got {}".format(len(self.inputs)))

    def _expand(self, graph_builder):
        result = self.inputs[0]
        for inp in self.inputs[1:]:
            result = graph_builder.emit('Add', [result, inp])
        return result

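The expander above rewrites AddN as a left fold of two-input Add over its inputs. A minimal NumPy sketch (illustrative only, not part of the commit) of the semantics the expanded graph must preserve:

import numpy as np

# AddN(x1, ..., xn) becomes Add(...Add(Add(x1, x2), x3)..., xn); numerically the
# left fold has to agree with summing all inputs at once.
inputs = [np.random.rand(2, 3).astype(np.float32) for _ in range(4)]
result = inputs[0]
for inp in inputs[1:]:
    result = result + inp  # mirrors graph_builder.emit('Add', [result, inp])
assert np.allclose(result, np.sum(inputs, axis=0), rtol=1e-4, atol=1e-6)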

@@ -53,7 +53,6 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
     prim::kPrimSub,
     prim::kPrimRsqrt,
     prim::kPrimSqrt,
-    prim::kPrimAddN,
     prim::kPrimReciprocal,
     prim::kPrimTanh,
     prim::kPrimReshape,


@@ -44,6 +44,7 @@ constexpr size_t kLambWeightInputIdx = 4;
 std::vector<PrimitivePtr> GetExpandOps() {
   std::vector<PrimitivePtr> expand_ops = {
+    prim::kPrimAddN,
     prim::kPrimSquare,
     prim::kPrimGeLUGrad,
     prim::kPrimAssignAdd,

@@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P


class Net(Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.addn = P.AddN()

    def construct(self, *args):
        return self.addn(*args)


def get_output(*tensors):
    net = Net()
    output = net(tensors)
    return output


def test_basic():
    np.random.seed(0)
    tensors = []
    expect = np.array([0], np.float32)
    for _ in range(10):
        t = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
        expect = t + expect
        tensors.append(Tensor(t))
    output = get_output(*tensors).asnumpy()
    assert np.allclose(expect, output, 1.e-4, 1.e-7)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_graph_kernel=True)
    test_basic()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_basic_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_graph_kernel=True)
    test_basic()
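Beyond the numeric check, the expansion itself can be observed by dumping the IR while running the test. A sketch under the assumption of a graph-kernel-enabled GPU build; the dump path below is just an example:

import mindspore.context as context

# save_graphs dumps the IR produced by each pass; after the graph kernel expander
# runs, the AddN node should appear replaced by a chain of Add nodes.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
                    enable_graph_kernel=True,
                    save_graphs=True, save_graphs_path="./addn_ir")
test_basic()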