Refactor GraphKernelExpander (3rd submission)

Rewrote OpInfer as a class with the methods "infer_shape", "infer_type"
and "infer_format". The op name is used to find the subclass that
implements these functions. Two common classes, "_Elemwise" and
"_Reduce", are provided.

Op BiasAddGrad now supports the "FractalNZ" format.
dayschan 2021-02-23 19:23:15 +08:00
parent 29b68a82a4
commit 7beca18f3c
4 changed files with 295 additions and 144 deletions
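
For orientation (not part of the diff): the new dispatch in op_infer resolves an op either to a class named after the op, or to a common class chosen by its PrimLib iter_type. A minimal sketch of a caller, where FakeTensor is a hypothetical stub standing in for the model's Tensor and only carries the three attributes the infer classes read:

# Sketch only: FakeTensor is a hypothetical stub, not part of this commit.
from mindspore._extends.graph_kernel.model import op_infer

class FakeTensor:
    def __init__(self, shape, dtype, data_format):
        self.shape, self.dtype, self.data_format = shape, dtype, data_format

x = FakeTensor([16, 32], "float32", "DefaultFormat")
# 'Cast' has a dedicated subclass in op_infer, so infer() dispatches to it;
# Cast overrides only _infer_type and inherits shape/format from _Elemwise.
shape, dtype, fmt = op_infer.infer('Cast', [x], {'dst_type': 'float16'})
assert (shape, dtype) == ([16, 32], 'float16')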

View File

@@ -20,24 +20,30 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
 @VLD.add_format(DF.DEFAULT)
 @VLD.add_format(DF.NHWC)
 @VLD.add_format(DF.NCHW)
+@VLD.add_format(DF.FRAC_NZ)
 class BiasAddGrad(Expander):
     """BiasAddGrad expander"""

     def _expand(self, graph_builder):
-        input_x = self.inputs[0]
+        x = self.inputs[0]
         reduce_axis = ()
-        if input_x.data_format == 'NHWC':
+        if x.data_format == DF.NHWC:
             reduce_axis = (0, 1, 2)
-        elif input_x.data_format == 'NCHW':
+        elif x.data_format == DF.NCHW:
             reduce_axis = (0, 2, 3)
-        # DefaultFormat shape's length should be from 2 to 4
+        elif x.data_format == DF.FRAC_NZ:
+            reduce_axis = (-2, -3)
         else:
-            if len(input_x.shape) == 2:
+            # DefaultFormat shape's length should be from 2 to 4
+            if len(x.shape) == 2:
                 reduce_axis = (0,)
-            elif len(input_x.shape) == 3:
+            elif len(x.shape) == 3:
                 reduce_axis = (0, 1)
             else:
                 reduce_axis = (0, 2, 3)
-        result = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        if x.data_format == DF.FRAC_NZ:
+            out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
+            result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape})
         return result
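
To see why the FRAC_NZ branch reshapes after reducing, a worked example (shapes illustrative; it assumes the usual FractalNZ convention that a 2D [M, N] tensor is stored as [..., N1, M1, M0, N0] with M = M1*M0 and N = N1*N0):

# BiasAddGrad reduces the M dimension, i.e. axes -3 (M1) and -2 (M0).
x_shape = [8, 4, 16, 16]        # N1=8, M1=4, M0=16, N0=16
# ReduceSum over (-2, -3) with keep_dims=False leaves [N1, N0] = [8, 16]
out_shape = x_shape[:-4] + [x_shape[-1] * x_shape[-4]]
assert out_shape == [128]       # N = N1 * N0: the flat bias gradient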

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """generate json desc for Tile"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.op_infer import Tile as TileInfer
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD

@@ -27,8 +27,9 @@ class Tile(Expander):
         input_x = self.inputs[0]
         multiples = self.attrs['multiples']
-        output_shape, _, _, shape_compatible = builder.get_tile_output_shape(self.inputs[0].shape, multiples)
-        if shape_compatible:
+        tile_infer = TileInfer(self.name, self.inputs, self.attrs)
+        output_shape, _, _ = tile_infer.infer()
+        if tile_infer.broadcast_compatible:
             result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
         else:
             result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples})
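
The broadcast_compatible flag computed inside TileInfer determines when Tile can be lowered to the cheaper BroadcastTo: every tiled dimension must have input size 1 or multiple 1. A hypothetical standalone mirror of that rule, for illustration only:

def tile_is_broadcast_compatible(shape, multiples):
    """True when Tile(shape, multiples) is expressible as a BroadcastTo."""
    shape = [1] * (len(multiples) - len(shape)) + list(shape)
    return all(sh == 1 or mul == 1 for sh, mul in zip(shape, multiples))

assert tile_is_broadcast_compatible([4, 3], [2, 1, 1])     # -> BroadcastTo
assert not tile_is_broadcast_compatible([2, 3], [1, 2])    # -> real Tile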

View File

@@ -15,139 +15,8 @@
 """GraphKernel model builder"""

 import copy
-from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy, DataFormat
-
-
-def get_tile_output_shape(shape, multiples):
-    """compute output shape of tile"""
-    if multiples is None:
-        return shape
-    if not isinstance(shape, (list, tuple)):
-        raise TypeError("Input shape of Tile must be of type list or tuple")
-    if not isinstance(multiples, (list, tuple)):
-        raise TypeError("multiples of Tile must be of type list or tuple")
-
-    shape = list(shape)
-    multiples = list(multiples)
-    diff_len = len(multiples) - len(shape)
-    if diff_len < 0:
-        raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
-    if diff_len > 0:
-        for _ in range(diff_len):
-            shape.insert(0, 1)
-    shape_compatible = True
-    output_shape = []
-    input_reshape = []
-    output_reshape = []
-    for sh, mul in list(zip(shape, multiples)):
-        dim = sh * mul
-        output_shape.append(dim)
-        if sh == 1 or mul == 1:
-            input_reshape.append(sh)
-            output_reshape.append(dim)
-        else:
-            shape_compatible = False
-            input_reshape.append(1)
-            input_reshape.append(sh)
-            output_reshape.append(mul)
-            output_reshape.append(sh)
-    return output_shape, input_reshape, output_reshape, shape_compatible
-
-
-class OpInfer:
-    """Op infer"""
-
-    @staticmethod
-    def default_reduce_infer(inputs, attrs):
-        """Default reduce infer"""
-        shape = copy.deepcopy(inputs[0].shape)
-        if attrs['keep_dims']:
-            for i in attrs['reduce_axis']:
-                shape[i] = 1
-            return shape
-        real_shape = []
-        for i, _ in enumerate(shape):
-            if i not in attrs['reduce_axis'] and i - len(shape) not in attrs['reduce_axis']:
-                real_shape.append(shape[i])
-        return real_shape
-
-    @staticmethod
-    def default_elementwise_infer(inputs, attrs):
-        """Default elementwise infer"""
-        shape = (1,)
-        max_flatten_shape = 1
-        for t in inputs:
-            flatten_shape = 1
-            for s in t.shape:
-                flatten_shape *= s
-            if flatten_shape >= max_flatten_shape:
-                max_flatten_shape = flatten_shape
-                shape = t.shape
-        return shape
-
-    default_infer_shape_func = [
-        None,
-        None,
-        default_elementwise_infer.__func__,
-        lambda inputs, attrs: max([t.shape for t in inputs]),
-        default_reduce_infer.__func__,
-        None,
-        lambda inputs, attrs: [1],  # control op
-    ]
-
-    @staticmethod
-    def default_infer_dtype_func(inputs, attrs):
-        """Infer dtype"""
-        return inputs[0].dtype
-
-    @staticmethod
-    def default_infer_format_func(inputs, attrs):
-        """Infer format"""
-        result = inputs[0].data_format
-        # default_format and other_format results in other_format
-        for input_tensor in inputs[1:]:
-            data_format = input_tensor.data_format
-            if data_format != DataFormat.DEFAULT:
-                if result not in [DataFormat.DEFAULT, data_format]:
-                    raise RuntimeError("Incompatible data format %s and %s" % (data_format, result))
-                result = data_format
-        return result
-
-    infer_shape_func = {
-        # add special infer func here
-        'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
-        'Reshape': lambda inputs, attrs: attrs["shape"],
-        'BroadcastTo': lambda inputs, attrs: attrs["shape"],
-        'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
-        'ExpandDims': lambda inputs, attrs: list(inputs[0].shape).insert(attrs["axis"], 1),
-    }
-    infer_dtype_func = {
-        # add special infer func here
-        'Cast': lambda inputs, attrs: attrs['dst_type'],
-        'Less': lambda inputs, attrs: "bool",
-        'LessEqual': lambda inputs, attrs: "bool",
-        'Equal': lambda inputs, attrs: "bool",
-        'Greater': lambda inputs, attrs: "bool",
-        'GreaterEqual': lambda inputs, attrs: "bool",
-    }
-    infer_format_func = {
-        # add special infer func here
-        'Reshape': lambda inputs, attrs: "DefaultFormat",
-    }
-
-    @classmethod
-    def infer(cls, prim_name, inputs, attrs):
-        prim = PrimLib.primtives[prim_name]
-        infer_shape = cls.infer_shape_func.get(
-            prim_name, cls.default_infer_shape_func[prim.iter_type])
-        infer_dtype = cls.infer_dtype_func.get(
-            prim_name, cls.default_infer_dtype_func)
-        infer_format = cls.infer_format_func.get(
-            prim_name, cls.default_infer_format_func)
-        return infer_shape(inputs, attrs), infer_dtype(inputs, attrs), infer_format(inputs, attrs)
+from . import op_infer
+from .model import Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy


 class GraphBuilder:

@@ -229,7 +98,7 @@ class GraphBuilder:
         if isinstance(inputs, (Tensor, Value)):
             inputs = [inputs]
         tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
-        out_shape, out_dtype, out_format = OpInfer.infer(prim, tensor_inputs, attrs)
+        out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
         output = self.tensor(out_shape, out_dtype, out_format, name)
         self.op(prim, output, inputs, attrs)
         return output

View File

@@ -0,0 +1,275 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""GraphKernel Op Infer"""
+
+import copy
+import sys
+from functools import reduce
+from .model import GraphKernelUnsupportedException as GKException
+from .model import PrimLib, DataFormat as DF
+
+
+def infer(op_name, inputs, attrs):
+    """infer shape dtype and format"""
+    def _create_opinfer():
+        if hasattr(sys.modules[__name__], op_name):
+            op_cls = getattr(sys.modules[__name__], op_name)
+            return op_cls(op_name, inputs, attrs)
+        # common infer
+        class_name_map = {
+            PrimLib.ELEMWISE: "_Elemwise",
+            PrimLib.REDUCE: "_Reduce",
+        }
+        cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None)
+        if not cls_name:
+            raise GKException("OpInfo does not support op {}".format(op_name))
+        op_cls = getattr(sys.modules[__name__], cls_name)
+        return op_cls(op_name, inputs, attrs)
+
+    return _create_opinfer().infer()
+
+
+class OpInfer:
+    """
+    OpInfer is the base class for inferring operator info in GraphKernel model builder.
+
+    There are three methods that should be overridden to define the infer logic of the operator:
+    _infer_shape(), _infer_type() and _infer_format().
+    """
+
+    def __init__(self, name, inputs, attrs):
+        self.name = name
+        self.inputs = inputs
+        self.attrs = attrs
+
+    def infer(self):
+        """Infer shape, type and format by op inputs"""
+        self._check()
+        return self._infer_shape(), self._infer_type(), self._infer_format()
+
+    def _infer_shape(self):
+        return self.inputs[0].shape
+
+    def _infer_type(self):
+        return self.inputs[0].dtype
+
+    def _infer_format(self):
+        return self.inputs[0].data_format
+
+    def _check(self):
+        self._check_shape()
+        self._check_type()
+        self._check_format()
+
+    def _check_shape(self):
+        pass
+
+    def _check_type(self):
+        """check that all dtypes are the same"""
+        dtype = self.inputs[0].dtype
+        for i, t in enumerate(self.inputs[1:]):
+            if t.dtype != dtype:
+                raise GKException(
+                    "Incompatible dtype between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype))
+
+    def _check_format(self):
+        """check that formats are compatible; only DefaultFormat is compatible with others"""
+        result = self.inputs[0].data_format
+        i = 0
+        for j, t in enumerate(self.inputs[1:]):
+            if t.data_format != result:
+                if DF.DEFAULT not in (result, t.data_format):
+                    raise GKException("Incompatible format between input {}({}) and {}({})".format(
+                        i, result, j + 1, t.data_format))
+                if result == DF.DEFAULT:
+                    result = t.data_format
+                    i = j + 1
+
+
+class _Elemwise(OpInfer):
+    """Common infer for elementwise operators"""
+
+    def _infer_shape(self):
+        """returns the input shape with the largest flattened size"""
+        shape = (1,)
+        max_flatten_size = 1
+        for t in self.inputs:
+            flatten_size = reduce(lambda x, y: x * y, t.shape)
+            if flatten_size >= max_flatten_size:
+                max_flatten_size = flatten_size
+                shape = t.shape
+        return shape
+
+    def _infer_format(self):
+        for tensor in self.inputs:
+            if tensor.data_format != DF.DEFAULT:
+                return tensor.data_format
+        return DF.DEFAULT
+
+
+class _Reduce(OpInfer):
+    """Common infer for reduction operators"""
+
+    def _check(self):
+        super()._check()
+        # check reduce axis in the range [-len, len)
+        shape_len = len(self.inputs[0].shape)
+        axis = self.attrs['reduce_axis']
+        if isinstance(axis, int):
+            axis = [axis]
+        if not all([(-shape_len <= i < shape_len) for i in axis]):
+            raise GKException(
+                "reduce_axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis))
+
+    def _infer_shape(self):
+        shape = copy.deepcopy(self.inputs[0].shape)
+        axis = self.attrs['reduce_axis']
+        if isinstance(axis, int):
+            axis = [axis]
+        if any([i < 0 for i in axis]):
+            # change the axis to non-negative numbers.
+            axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis))
+        self.attrs['reduce_axis'] = sorted(axis)
+        if self.attrs['keep_dims']:
+            for i in axis:
+                shape[i] = 1
+            return shape
+        real_shape = []
+        for i, s in enumerate(shape):
+            if i not in axis:
+                real_shape.append(s)
+        return real_shape
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class _Reshape(OpInfer):
+    """Common infer for reshape operators, should not be instantiated"""
+
+    def _infer_shape(self):
+        raise GKException("_infer_shape should be implemented by subclass")
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class Reshape(_Reshape):
+    def _infer_shape(self):
+        return self.attrs["shape"]
+
+
+class ExpandDims(_Reshape):
+    def _infer_shape(self):
+        # list.insert() returns None, so build the new shape first, then return it
+        new_shape = list(self.inputs[0].shape)
+        new_shape.insert(self.attrs["axis"], 1)
+        return new_shape
+
+
+class Cast(_Elemwise):
+    def _infer_type(self):
+        return self.attrs["dst_type"]
+
+
+class InplaceAssign(_Elemwise):
+    def _infer_shape(self):
+        return [1] if self.attrs["fake_output"] else self.inputs[2].shape
+
+    def _infer_type(self):
+        return self.inputs[2].dtype
+
+    def _infer_format(self):
+        return DF.DEFAULT if self.attrs["fake_output"] else self.inputs[2].data_format
+
+
+class BroadcastTo(OpInfer):
+    def _infer_shape(self):
+        return self.attrs["shape"]
+
+    def _infer_format(self):
+        return self.inputs[0].data_format
+
+
+class Tile(OpInfer):
+    """Op Tile"""
+
+    def __init__(self, op_name, inputs, attrs):
+        super().__init__(op_name, inputs, attrs)
+        self.input_reshape = None
+        self.output_reshape = None
+        self.broadcast_compatible = True
+
+    def _infer_shape(self):
+        shape = self.inputs[0].shape
+        multiples = self.attrs["multiples"]
+        shape = list(shape)
+        multiples = list(multiples)
+        diff_len = len(multiples) - len(shape)
+        if diff_len < 0:
+            raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
+        if diff_len > 0:
+            for _ in range(diff_len):
+                shape.insert(0, 1)
+        self.broadcast_compatible = True
+        output_shape = []
+        self.input_reshape = []
+        self.output_reshape = []
+        for sh, mul in list(zip(shape, multiples)):
+            dim = sh * mul
+            output_shape.append(dim)
+            if sh == 1 or mul == 1:
+                self.input_reshape.append(sh)
+                self.output_reshape.append(dim)
+            else:
+                self.broadcast_compatible = False
+                self.input_reshape.append(1)
+                self.input_reshape.append(sh)
+                self.output_reshape.append(mul)
+                self.output_reshape.append(sh)
+        return output_shape
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class _CompareOp(_Elemwise):
+    """Compare operators"""
+
+    def _infer_type(self):
+        return "bool"
+
+
+class Less(_CompareOp):
+    pass
+
+
+class LessEqual(_CompareOp):
+    pass
+
+
+class Equal(_CompareOp):
+    pass
+
+
+class Greater(_CompareOp):
+    pass
+
+
+class GreaterEqual(_CompareOp):
+    pass
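
Ops with no dedicated class fall back to the common infer classes by iter_type. A hedged usage sketch (FakeTensor is a hypothetical stub; it assumes PrimLib registers ReduceSum with iter_type REDUCE):

from mindspore._extends.graph_kernel.model import op_infer
from mindspore._extends.graph_kernel.model.model import DataFormat as DF

class FakeTensor:
    def __init__(self, shape, dtype, data_format):
        self.shape, self.dtype, self.data_format = shape, dtype, data_format

t = FakeTensor([8, 16], "float32", DF.DEFAULT)
# No ReduceSum class is defined in op_infer, so infer() falls back to
# _Reduce, which normalizes reduce_axis and drops the reduced dimension.
shape, dtype, fmt = op_infer.infer('ReduceSum', [t],
                                   {'reduce_axis': (1,), 'keep_dims': False})
assert (shape, dtype, fmt) == ([8], "float32", DF.DEFAULT)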