forked from mindspore-Ecosystem/mindspore
!20966 Update graph kernel support for argmax/argmin
Merge pull request !20966 from zichun_ye/argmax_min_kernelgraph
This commit is contained in:
commit
7b20a5adf7
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit f3168164c452316c21709f3293ef3b31a3688062
|
||||
Subproject commit 4aac4d95750a87e664f175c0fa946a069f8a0c2a
|
|
@ -14,7 +14,7 @@
|
|||
# ===========================================================================
|
||||
"""Cost model splitter"""
|
||||
import os
|
||||
from functools import reduce
|
||||
from functools import reduce as prod_reduce
|
||||
from mindspore import log as logger
|
||||
from .model import PrimLib, Graph, Tensor, Operator
|
||||
from .model import DataFormat as DF
|
||||
|
@ -98,6 +98,7 @@ class GraphSplitByPattern:
|
|||
return str(self)
|
||||
|
||||
def get_relation(self, op, i):
|
||||
"""Get op relation"""
|
||||
relation = PrimLib.UNKNOWN
|
||||
_, elem_relation = PrimLib.input_relation(op, i)
|
||||
for r in elem_relation:
|
||||
|
@ -122,6 +123,7 @@ class GraphSplitByPattern:
|
|||
self.reach_tab.sync(self.unique_id, out.unique_id)
|
||||
|
||||
def update_stitch_info(self, stitch_info):
|
||||
"""Update stitch info"""
|
||||
if stitch_info.stitch_ops:
|
||||
self.stitch_info.stitch_ops.update(stitch_info.stitch_ops)
|
||||
if stitch_info.stitch_atomic_ops:
|
||||
|
@ -180,9 +182,11 @@ class GraphSplitByPattern:
|
|||
return True
|
||||
|
||||
def dom_op(self):
|
||||
"""Get dom op"""
|
||||
return self.ops[0]
|
||||
|
||||
def reduce_out_exclude(self, area):
|
||||
"""Check whether op is redcue_out_exclude """
|
||||
if self.output_excluded:
|
||||
for op in self.output_excluded:
|
||||
if op in area.ops:
|
||||
|
@ -260,6 +264,7 @@ class GraphSplitByPattern:
|
|||
self.area_map[op] = area
|
||||
|
||||
def set_default_mode(self, area):
|
||||
"""Set default mode"""
|
||||
area.mode = self.get_default_mode(area.ops[0])
|
||||
|
||||
def limit_area_size(self, dominant, fuse_areas):
|
||||
|
@ -267,7 +272,7 @@ class GraphSplitByPattern:
|
|||
limit_size = 200 # an experience number
|
||||
area_sizes = map(lambda area: len(area.ops), fuse_areas)
|
||||
dom_size = len(dominant.ops)
|
||||
if dom_size + reduce(lambda x, y: x+y, area_sizes) <= limit_size:
|
||||
if dom_size + prod_reduce(lambda x, y: x + y, area_sizes) <= limit_size:
|
||||
return fuse_areas
|
||||
# fuse the smaller area in priority
|
||||
fuse_areas.sort(key=lambda area: len(area.ops))
|
||||
|
@ -358,8 +363,9 @@ class GraphSplitByPattern:
|
|||
with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f:
|
||||
f.write(subgraphs_str)
|
||||
|
||||
def pattern_fuse(self, select=None):
|
||||
def pattern_fuse(self, fuse_func=None):
|
||||
"""fuse Areas by pattern repeatedly"""
|
||||
del fuse_func
|
||||
raise Exception("pattern_fuse() is not implemented in {}".format(self.__class__.__name__))
|
||||
|
||||
def split(self):
|
||||
|
@ -566,6 +572,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
|||
REDUCE_FUSE_DEPTH = 20
|
||||
|
||||
def get_default_mode(self, op):
|
||||
"""Get default mode in GPU"""
|
||||
if op.prim == "MatMul":
|
||||
return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" and op.attrs['Akg'] else \
|
||||
self.Area.MODE_BASIC
|
||||
|
@ -696,9 +703,14 @@ class GraphSplitGpu(GraphSplitByPattern):
|
|||
if any(["Reduce" in x.prim for x in dom.ops[1:]]):
|
||||
return False
|
||||
op = dom.ops[0]
|
||||
reduce_axis = op.attrs["reduce_axis"]
|
||||
if "reduce_axis" in op.attrs:
|
||||
reduce_axis = op.attrs["reduce_axis"]
|
||||
elif "axis" in op.attrs:
|
||||
reduce_axis = [op.attrs["axis"]]
|
||||
else:
|
||||
raise Exception("the operator has no attr reduce_axis or axis")
|
||||
if len(op.inputs[0].shape) - 1 in reduce_axis:
|
||||
reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis])
|
||||
reduce_size = prod_reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis])
|
||||
return reduce_size >= 1024
|
||||
return True
|
||||
|
||||
|
@ -753,7 +765,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
|||
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
|
||||
if _reduce_nums(a.ops) < 2:
|
||||
dom_outs = [op.output for op in dom.ops]
|
||||
a_ins = [input for op in a.ops for input in op.inputs]
|
||||
a_ins = [op_input for op in a.ops for op_input in op.inputs]
|
||||
a_outs = [op.output for op in a.ops]
|
||||
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
|
||||
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
|
||||
|
@ -832,7 +844,7 @@ class GraphSplitAscend(GraphSplitByPattern):
|
|||
REDUCE_FUSE_DEPTH = 10
|
||||
|
||||
def get_default_mode(self, op):
|
||||
"""Get efault mode for op"""
|
||||
"""Get efault mode for Ascend"""
|
||||
def _dtype_same(tensors):
|
||||
dtype = tensors[0].dtype
|
||||
for tensor_ in tensors:
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
|
||||
class Utils:
|
||||
"""Model utils"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_attr_type(attr):
|
||||
"""Get attr type"""
|
||||
|
@ -54,6 +57,9 @@ class DataFormat:
|
|||
FRACTAL_Z_C04 = "FRACTAL_Z_C04"
|
||||
NDHWC = "NDHWC"
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class DataType:
|
||||
"""Data Type"""
|
||||
|
@ -73,11 +79,8 @@ class DataType:
|
|||
UINT64 = "uint64"
|
||||
BOOL = "bool"
|
||||
|
||||
|
||||
class Config:
|
||||
R0 = 8.0
|
||||
UB_SIZE = 256 * 1024
|
||||
MAX_BLOCK = 32
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class PrimLib:
|
||||
|
@ -90,6 +93,9 @@ class PrimLib:
|
|||
REDUCE = 4
|
||||
OPAQUE = 5
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
class Prim:
|
||||
"""Prim"""
|
||||
|
||||
|
@ -101,6 +107,7 @@ class PrimLib:
|
|||
self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
|
||||
|
||||
def default_reshape_relation(self, op, input_idx):
|
||||
"""Process reshape relation"""
|
||||
axis_relation, elem_relation = self.unknown_relation(op, input_idx)
|
||||
elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
|
||||
return axis_relation, elem_relation
|
||||
|
@ -189,6 +196,8 @@ class PrimLib:
|
|||
'ReduceSum': Prim(REDUCE),
|
||||
'ReduceMax': Prim(REDUCE),
|
||||
'ReduceMin': Prim(REDUCE),
|
||||
'Argmax': Prim(REDUCE),
|
||||
'Argmin': Prim(REDUCE),
|
||||
'Assign': Prim(ELEMWISE),
|
||||
'Sign': Prim(ELEMWISE),
|
||||
'Sin': Prim(ELEMWISE),
|
||||
|
@ -225,6 +234,7 @@ class PrimLib:
|
|||
|
||||
@classmethod
|
||||
def get_prim(cls, op):
|
||||
"""Get op primtive"""
|
||||
prim = cls.primtives.get(op.prim, None)
|
||||
if prim is None:
|
||||
print('[WARN] primtive is not registered: ' + op.prim)
|
||||
|
@ -233,22 +243,27 @@ class PrimLib:
|
|||
|
||||
@classmethod
|
||||
def input_relation(cls, op, input_idx):
|
||||
"""Get op's input_relation according to input_idx"""
|
||||
return cls.get_prim(op).relation_func(op, input_idx)
|
||||
|
||||
@classmethod
|
||||
def iter_type(cls, op):
|
||||
"""Get op's iter type"""
|
||||
return cls.get_prim(op).iter_type
|
||||
|
||||
@classmethod
|
||||
def is_reduce(cls, op):
|
||||
"""Check whether op's iter type is reduce"""
|
||||
return cls.get_prim(op).iter_type == cls.REDUCE
|
||||
|
||||
@classmethod
|
||||
def calibrate_iter_size(cls, op, iter_size):
|
||||
"""Get calibrate_iter_size"""
|
||||
return cls.get_prim(op).calibrate * iter_size
|
||||
|
||||
@classmethod
|
||||
def dtype_bytes(cls, dtype):
|
||||
"""Get dtype bytes"""
|
||||
bits, unit = 1, 1
|
||||
for i in range(len(dtype) - 1, 0, -1):
|
||||
if dtype[i].isdecimal():
|
||||
|
@ -260,6 +275,7 @@ class PrimLib:
|
|||
|
||||
@classmethod
|
||||
def inplace_reuse(cls, op, input_idx, start_axis=0):
|
||||
"""Check whether op is inplace reuse"""
|
||||
if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
|
||||
return False
|
||||
_, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
|
||||
|
@ -277,6 +293,8 @@ class Tensor:
|
|||
PARA_OUTPUT = 2
|
||||
|
||||
class Buddy:
|
||||
"""Buddy"""
|
||||
|
||||
def __init__(self, leader):
|
||||
self.members = [leader]
|
||||
|
||||
|
@ -328,6 +346,7 @@ class Value:
|
|||
return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
|
||||
|
||||
def get_size(self):
|
||||
"""Get size"""
|
||||
return 1
|
||||
|
||||
|
||||
|
@ -365,6 +384,7 @@ class Graph:
|
|||
self.outputs = []
|
||||
self.stitch_info = stitch_info
|
||||
self.recompute_ops = recompute_ops
|
||||
self.processor = ""
|
||||
|
||||
def set_processor(self, processor):
|
||||
"""Set processor"""
|
||||
|
@ -498,7 +518,7 @@ class AlignShape(GraphVisitor):
|
|||
"""Align shape"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
super(AlignShape, self).__init__()
|
||||
|
||||
def visit(self, op):
|
||||
"""Visit op node"""
|
||||
|
@ -517,7 +537,7 @@ class AddControlBuddy(GraphVisitor):
|
|||
"""Add control buddy"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
super(AddControlBuddy, self).__init__()
|
||||
self.buddies = {} # {op : [ctrl_op]}
|
||||
|
||||
def visit(self, op):
|
||||
|
@ -536,13 +556,15 @@ class AddControlBuddy(GraphVisitor):
|
|||
|
||||
def visit_graph(self, graph):
|
||||
"""Visit graph nodes"""
|
||||
super().visit_graph(graph)
|
||||
super(AddControlBuddy, self).visit_graph(graph)
|
||||
for owner in self.buddies:
|
||||
for op in self.buddies[owner]:
|
||||
owner.add_buddy(op.output)
|
||||
|
||||
|
||||
class GraphKernelUnsupportedException(Exception):
|
||||
""""GraphKernel Unsupported Exception"""
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__()
|
||||
super(GraphKernelUnsupportedException, self).__init__()
|
||||
self.message = message
|
||||
|
|
|
@ -26,7 +26,8 @@ namespace opt {
|
|||
int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast<int64_t>(rank); }
|
||||
|
||||
bool AxisNormalizer::IsReduce(const AnfNodePtr &node) const {
|
||||
std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin};
|
||||
std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin,
|
||||
prim::kPrimArgMax, prim::kPrimArgMin};
|
||||
return std::any_of(node_with_axis.begin(), node_with_axis.end(),
|
||||
[&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
|
||||
}
|
||||
|
|
|
@ -68,6 +68,8 @@ std::vector<PrimitivePtr> GetClusterableOpList() {
|
|||
#elif ENABLE_GPU
|
||||
prim::kPrimACos,
|
||||
prim::kPrimAcosh,
|
||||
prim::kPrimArgMax,
|
||||
prim::kPrimArgMin,
|
||||
prim::kPrimAsin,
|
||||
prim::kPrimAsinh,
|
||||
prim::kPrimAssign,
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class ArgMax(nn.Cell):
|
||||
def __init__(self, axis):
|
||||
super(ArgMax, self).__init__()
|
||||
self.arg_max = P.Argmax(axis=axis)
|
||||
|
||||
def construct(self, x):
|
||||
return self.arg_max(x)
|
||||
|
||||
|
||||
def get_output(x, axis, enable_graph_kernel=False):
|
||||
context.set_context(enable_graph_kernel=enable_graph_kernel)
|
||||
net = ArgMax(axis)
|
||||
output = net(x)
|
||||
return output
|
||||
|
||||
|
||||
def test_argmax():
|
||||
x0 = Tensor(np.random.normal(0, 1, [2, 3, 4, 4]).astype(np.float32))
|
||||
axis0 = 3
|
||||
expect = get_output(x0, axis0, False)
|
||||
output = get_output(x0, axis0, True)
|
||||
assert np.allclose(expect.asnumpy(), output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
x1 = Tensor(np.random.normal(0, 1, [2, 3, 1, 4]).astype(np.float32))
|
||||
axis1 = 2
|
||||
expect = get_output(x1, axis1, False)
|
||||
output = get_output(x1, axis1, True)
|
||||
assert np.allclose(expect.asnumpy(), output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_argmax_gpu():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
test_argmax()
|
Loading…
Reference in New Issue