diff --git a/akg b/akg index f3168164c45..4aac4d95750 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit f3168164c452316c21709f3293ef3b31a3688062 +Subproject commit 4aac4d95750a87e664f175c0fa946a069f8a0c2a diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 46865422012..363401992eb 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -14,7 +14,7 @@ # =========================================================================== """Cost model splitter""" import os -from functools import reduce +from functools import reduce as prod_reduce from mindspore import log as logger from .model import PrimLib, Graph, Tensor, Operator from .model import DataFormat as DF @@ -98,6 +98,7 @@ class GraphSplitByPattern: return str(self) def get_relation(self, op, i): + """Get op relation""" relation = PrimLib.UNKNOWN _, elem_relation = PrimLib.input_relation(op, i) for r in elem_relation: @@ -122,6 +123,7 @@ class GraphSplitByPattern: self.reach_tab.sync(self.unique_id, out.unique_id) def update_stitch_info(self, stitch_info): + """Update stitch info""" if stitch_info.stitch_ops: self.stitch_info.stitch_ops.update(stitch_info.stitch_ops) if stitch_info.stitch_atomic_ops: @@ -180,9 +182,11 @@ class GraphSplitByPattern: return True def dom_op(self): + """Get dom op""" return self.ops[0] def reduce_out_exclude(self, area): + """Check whether op is redcue_out_exclude """ if self.output_excluded: for op in self.output_excluded: if op in area.ops: @@ -260,6 +264,7 @@ class GraphSplitByPattern: self.area_map[op] = area def set_default_mode(self, area): + """Set default mode""" area.mode = self.get_default_mode(area.ops[0]) def limit_area_size(self, dominant, fuse_areas): @@ -267,7 +272,7 @@ class GraphSplitByPattern: limit_size = 200 # an experience number area_sizes = map(lambda area: len(area.ops), fuse_areas) dom_size = len(dominant.ops) - if dom_size + reduce(lambda x, y: x+y, area_sizes) <= limit_size: + if dom_size + prod_reduce(lambda x, y: x + y, area_sizes) <= limit_size: return fuse_areas # fuse the smaller area in priority fuse_areas.sort(key=lambda area: len(area.ops)) @@ -358,8 +363,9 @@ class GraphSplitByPattern: with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f: f.write(subgraphs_str) - def pattern_fuse(self, select=None): + def pattern_fuse(self, fuse_func=None): """fuse Areas by pattern repeatedly""" + del fuse_func raise Exception("pattern_fuse() is not implemented in {}".format(self.__class__.__name__)) def split(self): @@ -566,6 +572,7 @@ class GraphSplitGpu(GraphSplitByPattern): REDUCE_FUSE_DEPTH = 20 def get_default_mode(self, op): + """Get default mode in GPU""" if op.prim == "MatMul": return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" and op.attrs['Akg'] else \ self.Area.MODE_BASIC @@ -696,9 +703,14 @@ class GraphSplitGpu(GraphSplitByPattern): if any(["Reduce" in x.prim for x in dom.ops[1:]]): return False op = dom.ops[0] - reduce_axis = op.attrs["reduce_axis"] + if "reduce_axis" in op.attrs: + reduce_axis = op.attrs["reduce_axis"] + elif "axis" in op.attrs: + reduce_axis = [op.attrs["axis"]] + else: + raise Exception("the operator has no attr reduce_axis or axis") if len(op.inputs[0].shape) - 1 in reduce_axis: - reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis]) + reduce_size = prod_reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis]) return reduce_size >= 1024 return True @@ -753,7 +765,7 @@ class GraphSplitGpu(GraphSplitByPattern): if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a): if _reduce_nums(a.ops) < 2: dom_outs = [op.output for op in dom.ops] - a_ins = [input for op in a.ops for input in op.inputs] + a_ins = [op_input for op in a.ops for op_input in op.inputs] a_outs = [op.output for op in a.ops] a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins] stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins] @@ -832,7 +844,7 @@ class GraphSplitAscend(GraphSplitByPattern): REDUCE_FUSE_DEPTH = 10 def get_default_mode(self, op): - """Get efault mode for op""" + """Get efault mode for Ascend""" def _dtype_same(tensors): dtype = tensors[0].dtype for tensor_ in tensors: diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py index 2c0debb6729..4dcec3e1466 100644 --- a/mindspore/_extends/graph_kernel/model/model.py +++ b/mindspore/_extends/graph_kernel/model/model.py @@ -17,6 +17,9 @@ class Utils: """Model utils""" + def __init__(self): + pass + @staticmethod def get_attr_type(attr): """Get attr type""" @@ -54,6 +57,9 @@ class DataFormat: FRACTAL_Z_C04 = "FRACTAL_Z_C04" NDHWC = "NDHWC" + def __init__(self): + pass + class DataType: """Data Type""" @@ -73,11 +79,8 @@ class DataType: UINT64 = "uint64" BOOL = "bool" - -class Config: - R0 = 8.0 - UB_SIZE = 256 * 1024 - MAX_BLOCK = 32 + def __init__(self): + pass class PrimLib: @@ -90,6 +93,9 @@ class PrimLib: REDUCE = 4 OPAQUE = 5 + def __init__(self): + pass + class Prim: """Prim""" @@ -101,6 +107,7 @@ class PrimLib: self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x) def default_reshape_relation(self, op, input_idx): + """Process reshape relation""" axis_relation, elem_relation = self.unknown_relation(op, input_idx) elem_relation = [PrimLib.RESHAPE] * len(elem_relation) return axis_relation, elem_relation @@ -189,6 +196,8 @@ class PrimLib: 'ReduceSum': Prim(REDUCE), 'ReduceMax': Prim(REDUCE), 'ReduceMin': Prim(REDUCE), + 'Argmax': Prim(REDUCE), + 'Argmin': Prim(REDUCE), 'Assign': Prim(ELEMWISE), 'Sign': Prim(ELEMWISE), 'Sin': Prim(ELEMWISE), @@ -225,6 +234,7 @@ class PrimLib: @classmethod def get_prim(cls, op): + """Get op primtive""" prim = cls.primtives.get(op.prim, None) if prim is None: print('[WARN] primtive is not registered: ' + op.prim) @@ -233,22 +243,27 @@ class PrimLib: @classmethod def input_relation(cls, op, input_idx): + """Get op's input_relation according to input_idx""" return cls.get_prim(op).relation_func(op, input_idx) @classmethod def iter_type(cls, op): + """Get op's iter type""" return cls.get_prim(op).iter_type @classmethod def is_reduce(cls, op): + """Check whether op's iter type is reduce""" return cls.get_prim(op).iter_type == cls.REDUCE @classmethod def calibrate_iter_size(cls, op, iter_size): + """Get calibrate_iter_size""" return cls.get_prim(op).calibrate * iter_size @classmethod def dtype_bytes(cls, dtype): + """Get dtype bytes""" bits, unit = 1, 1 for i in range(len(dtype) - 1, 0, -1): if dtype[i].isdecimal(): @@ -260,6 +275,7 @@ class PrimLib: @classmethod def inplace_reuse(cls, op, input_idx, start_axis=0): + """Check whether op is inplace reuse""" if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype): return False _, elem_relation = cls.get_prim(op).relation_func(op, input_idx) @@ -277,6 +293,8 @@ class Tensor: PARA_OUTPUT = 2 class Buddy: + """Buddy""" + def __init__(self, leader): self.members = [leader] @@ -328,6 +346,7 @@ class Value: return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape))) def get_size(self): + """Get size""" return 1 @@ -365,6 +384,7 @@ class Graph: self.outputs = [] self.stitch_info = stitch_info self.recompute_ops = recompute_ops + self.processor = "" def set_processor(self, processor): """Set processor""" @@ -498,7 +518,7 @@ class AlignShape(GraphVisitor): """Align shape""" def __init__(self): - super().__init__() + super(AlignShape, self).__init__() def visit(self, op): """Visit op node""" @@ -517,7 +537,7 @@ class AddControlBuddy(GraphVisitor): """Add control buddy""" def __init__(self): - super().__init__() + super(AddControlBuddy, self).__init__() self.buddies = {} # {op : [ctrl_op]} def visit(self, op): @@ -536,13 +556,15 @@ class AddControlBuddy(GraphVisitor): def visit_graph(self, graph): """Visit graph nodes""" - super().visit_graph(graph) + super(AddControlBuddy, self).visit_graph(graph) for owner in self.buddies: for op in self.buddies[owner]: owner.add_buddy(op.output) class GraphKernelUnsupportedException(Exception): + """"GraphKernel Unsupported Exception""" + def __init__(self, message): - super().__init__() + super(GraphKernelUnsupportedException, self).__init__() self.message = message diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/axis_normalizer.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/axis_normalizer.cc index 73b5a76a752..632818b5d1c 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/axis_normalizer.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/axis_normalizer.cc @@ -26,7 +26,8 @@ namespace opt { int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast(rank); } bool AxisNormalizer::IsReduce(const AnfNodePtr &node) const { - std::vector node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}; + std::vector node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin, + prim::kPrimArgMax, prim::kPrimArgMin}; return std::any_of(node_with_axis.begin(), node_with_axis.end(), [&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); }); } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc index 0204c9b6669..2d26a864548 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc @@ -68,6 +68,8 @@ std::vector GetClusterableOpList() { #elif ENABLE_GPU prim::kPrimACos, prim::kPrimAcosh, + prim::kPrimArgMax, + prim::kPrimArgMin, prim::kPrimAsin, prim::kPrimAsinh, prim::kPrimAssign, diff --git a/tests/st/ops/graph_kernel/test_argmax.py b/tests/st/ops/graph_kernel/test_argmax.py new file mode 100644 index 00000000000..458aebd45c2 --- /dev/null +++ b/tests/st/ops/graph_kernel/test_argmax.py @@ -0,0 +1,59 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class ArgMax(nn.Cell): + def __init__(self, axis): + super(ArgMax, self).__init__() + self.arg_max = P.Argmax(axis=axis) + + def construct(self, x): + return self.arg_max(x) + + +def get_output(x, axis, enable_graph_kernel=False): + context.set_context(enable_graph_kernel=enable_graph_kernel) + net = ArgMax(axis) + output = net(x) + return output + + +def test_argmax(): + x0 = Tensor(np.random.normal(0, 1, [2, 3, 4, 4]).astype(np.float32)) + axis0 = 3 + expect = get_output(x0, axis0, False) + output = get_output(x0, axis0, True) + assert np.allclose(expect.asnumpy(), output.asnumpy(), 0.0001, 0.0001) + + x1 = Tensor(np.random.normal(0, 1, [2, 3, 1, 4]).astype(np.float32)) + axis1 = 2 + expect = get_output(x1, axis1, False) + output = get_output(x1, axis1, True) + assert np.allclose(expect.asnumpy(), output.asnumpy(), 0.0001, 0.0001) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_argmax_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_argmax()