forked from mindspore-Ecosystem/mindspore
Refactor GraphKernelExpander (3rd submission)
Rewrote OpInfer as a class with the methods "infer_shape", "infer_type" and "infer_format". The op name is used to find the subclass that implements these functions. Two common classes, "_Elemwise" and "_Reduce", are provided. Op BiasAddGrad now supports the "FractalNZ" format.
commit 7beca18f3c (parent 29b68a82a4)
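Note: the subclass lookup described above resolves an op name to a class defined in the new op_infer module, falling back to a common class chosen by the primitive's iter_type. A minimal, self-contained sketch of that dispatch pattern (the names below are simplified stand-ins, not the real model classes):

import sys

class _Elemwise:
    """Fallback infer for elementwise ops (stand-in)."""
    def __init__(self, name, inputs, attrs):
        self.name, self.inputs, self.attrs = name, inputs, attrs

    def infer(self):
        # shape, dtype and format follow the first input
        t = self.inputs[0]
        return t.shape, t.dtype, t.data_format

class Cast(_Elemwise):
    """Named subclass, found directly via the op name."""
    def infer(self):
        shape, _, fmt = super().infer()
        return shape, self.attrs["dst_type"], fmt

def create_op_infer(op_name, inputs, attrs):
    # 1) exact match: a class named after the op
    if hasattr(sys.modules[__name__], op_name):
        return getattr(sys.modules[__name__], op_name)(op_name, inputs, attrs)
    # 2) fallback: a common class keyed by the op's category
    return _Elemwise(op_name, inputs, attrs)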
@@ -20,24 +20,30 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
 @VLD.add_format(DF.DEFAULT)
 @VLD.add_format(DF.NHWC)
 @VLD.add_format(DF.NCHW)
+@VLD.add_format(DF.FRAC_NZ)
 class BiasAddGrad(Expander):
     """BiasAddGrad expander"""
 
     def _expand(self, graph_builder):
-        input_x = self.inputs[0]
+        x = self.inputs[0]
 
         reduce_axis = ()
-        if input_x.data_format == 'NHWC':
+        if x.data_format == DF.NHWC:
             reduce_axis = (0, 1, 2)
-        elif input_x.data_format == 'NCHW':
+        elif x.data_format == DF.NCHW:
             reduce_axis = (0, 2, 3)
-        # DefaultFormat shape's length should be from 2 to 4
+        elif x.data_format == DF.FRAC_NZ:
+            reduce_axis = (-2, -3)
         else:
-            if len(input_x.shape) == 2:
+            # DefaultFormat shape's length should be from 2 to 4
+            if len(x.shape) == 2:
                 reduce_axis = (0,)
-            elif len(input_x.shape) == 3:
+            elif len(x.shape) == 3:
                 reduce_axis = (0, 1)
             else:
                 reduce_axis = (0, 2, 3)
-        result = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        if x.data_format == DF.FRAC_NZ:
+            out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
+            result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape})
         return result
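Note on the new FRAC_NZ branch: assuming the usual Ascend FractalNZ tiling, a 2-D tensor [M, N] is stored as [..., N1, M1, M0, N0] with N = N1 * N0, so reducing axes (-2, -3) sums away the M tiles and the final Reshape folds x.shape[-4] * x.shape[-1] back into the bias length N. A shapes-only illustration (the concrete sizes are made up):

# x: FractalNZ view of a [48, 32] matrix, tiled 16x16 -> [N1, M1, M0, N0]
x_shape = [2, 3, 16, 16]           # N1=2, M1=3, M0=16, N0=16
# ReduceSum over (-2, -3) with keep_dims=False drops M1 and M0 -> [N1, N0]
reduced = [2, 16]
# the Reshape above folds N1 and N0 back together
out_shape = x_shape[:-4] + [x_shape[-1] * x_shape[-4]]
assert out_shape == [32]           # the bias gradient has length N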
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for Tile"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.op_infer import Tile as TileInfer
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
 
@@ -27,8 +27,9 @@ class Tile(Expander):
         input_x = self.inputs[0]
         multiples = self.attrs['multiples']
 
-        output_shape, _, _, shape_compatible = builder.get_tile_output_shape(self.inputs[0].shape, multiples)
-        if shape_compatible:
+        tile_infer = TileInfer(self.name, self.inputs, self.attrs)
+        output_shape, _, _ = tile_infer.infer()
+        if tile_infer.broadcast_compatible:
             result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
         else:
             result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples})
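Note: Tile can be rewritten as BroadcastTo exactly when every dimension has either input size 1 or multiple 1; TileInfer records this in broadcast_compatible while computing the output shape. A small illustration of the criterion:

# broadcast-compatible: each dim has size 1 or multiple 1
shape, multiples = [1, 3], [4, 1]      # -> BroadcastTo with shape [4, 3]
assert all(s == 1 or m == 1 for s, m in zip(shape, multiples))
# not compatible: dim 0 has size 2 and multiple 2, so a real Tile is emitted
shape, multiples = [2, 3], [2, 1]      # -> Tile with multiples [2, 1]
assert not all(s == 1 or m == 1 for s, m in zip(shape, multiples))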
@@ -15,139 +15,8 @@
 """GraphKernel model builder"""
 
-import copy
-from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy, DataFormat
-
-
-def get_tile_output_shape(shape, multiples):
-    """compute output shape of tile"""
-    if multiples is None:
-        return shape
-    if not isinstance(shape, (list, tuple)):
-        raise TypeError("Input shape of Tile must be of type list or tuple")
-    if not isinstance(multiples, (list, tuple)):
-        raise TypeError("multiples of Tile must be of type list or tuple")
-
-    shape = list(shape)
-    multiples = list(multiples)
-    diff_len = len(multiples) - len(shape)
-    if diff_len < 0:
-        raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
-    if diff_len > 0:
-        for _ in range(diff_len):
-            shape.insert(0, 1)
-
-    shape_compatible = True
-    output_shape = []
-    input_reshape = []
-    output_reshape = []
-    for sh, mul in list(zip(shape, multiples)):
-        dim = sh * mul
-        output_shape.append(dim)
-        if sh == 1 or mul == 1:
-            input_reshape.append(sh)
-            output_reshape.append(dim)
-        else:
-            shape_compatible = False
-            input_reshape.append(1)
-            input_reshape.append(sh)
-            output_reshape.append(mul)
-            output_reshape.append(sh)
-
-    return output_shape, input_reshape, output_reshape, shape_compatible
-
-
-class OpInfer:
-    """Op infer"""
-    @staticmethod
-    def default_reduce_infer(inputs, attrs):
-        """Default reduce infer"""
-        shape = copy.deepcopy(inputs[0].shape)
-        if attrs['keep_dims']:
-            for i in attrs['reduce_axis']:
-                shape[i] = 1
-            return shape
-
-        real_shape = []
-        for i, _ in enumerate(shape):
-            if i not in attrs['reduce_axis'] and i - len(shape) not in attrs['reduce_axis']:
-                real_shape.append(shape[i])
-        return real_shape
-
-    @staticmethod
-    def default_elementwise_infer(inputs, attrs):
-        """Default elementwise infer"""
-        shape = (1,)
-        max_flatten_shape = 1
-        for t in inputs:
-            flatten_shape = 1
-            for s in t.shape:
-                flatten_shape *= s
-            if flatten_shape >= max_flatten_shape:
-                max_flatten_shape = flatten_shape
-                shape = t.shape
-        return shape
-
-    default_infer_shape_func = [
-        None,
-        None,
-        default_elementwise_infer.__func__,
-        lambda inputs, attrs: max([t.shape for t in inputs]),
-        default_reduce_infer.__func__,
-        None,
-        lambda inputs, attrs: [1],  # control op
-    ]
-
-    @staticmethod
-    def default_infer_dtype_func(inputs, attrs):
-        """Infer dtype"""
-        return inputs[0].dtype
-
-    @staticmethod
-    def default_infer_format_func(inputs, attrs):
-        """Infer format"""
-        result = inputs[0].data_format
-        # default_format and other_format results in other_format
-        for input_tensor in inputs[1:]:
-            data_format = input_tensor.data_format
-            if data_format != DataFormat.DEFAULT:
-                if result not in [DataFormat.DEFAULT, data_format]:
-                    raise RuntimeError("Incompatible data format %s and %s" % (data_format, result))
-                result = data_format
-        return result
-
-    infer_shape_func = {
-        # add special infer func here
-        'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
-        'Reshape': lambda inputs, attrs: attrs["shape"],
-        'BroadcastTo': lambda inputs, attrs: attrs["shape"],
-        'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
-        'ExpandDims': lambda inputs, attrs: list(inputs[0].shape).insert(attrs["axis"], 1),
-    }
-    infer_dtype_func = {
-        # add special infer func here
-        'Cast': lambda inputs, attrs: attrs['dst_type'],
-        'Less': lambda inputs, attrs: "bool",
-        'LessEqual': lambda inputs, attrs: "bool",
-        'Equal': lambda inputs, attrs: "bool",
-        'Greater': lambda inputs, attrs: "bool",
-        'GreaterEqual': lambda inputs, attrs: "bool",
-    }
-    infer_format_func = {
-        # add special infer func here
-        'Reshape': lambda inputs, attrs: "DefaultFormat",
-    }
-
-    @classmethod
-    def infer(cls, prim_name, inputs, attrs):
-        prim = PrimLib.primtives[prim_name]
-        infer_shape = cls.infer_shape_func.get(
-            prim_name, cls.default_infer_shape_func[prim.iter_type])
-        infer_dtype = cls.infer_dtype_func.get(
-            prim_name, cls.default_infer_dtype_func)
-        infer_format = cls.infer_format_func.get(
-            prim_name, cls.default_infer_format_func)
-        return infer_shape(inputs, attrs), infer_dtype(inputs, attrs), infer_format(inputs, attrs)
+from . import op_infer
+from .model import Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
 
 
 class GraphBuilder:
@@ -229,7 +98,7 @@ class GraphBuilder:
         if isinstance(inputs, (Tensor, Value)):
             inputs = [inputs]
         tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
-        out_shape, out_dtype, out_format = OpInfer.infer(prim, tensor_inputs, attrs)
+        out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
         output = self.tensor(out_shape, out_dtype, out_format, name)
         self.op(prim, output, inputs, attrs)
         return output
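Note: GraphBuilder.emit now delegates all shape/dtype/format inference to the op_infer module added in the new file below. A hedged usage sketch of the new entry point, assuming this commit is applied and using a SimpleNamespace as a stand-in for model.Tensor (infer() only reads shape, dtype and data_format here):

from types import SimpleNamespace
from mindspore._extends.graph_kernel.model import op_infer

# a minimal stand-in for a model.Tensor input
t = SimpleNamespace(shape=[4, 5], dtype="float32", data_format="DefaultFormat")
out_shape, out_dtype, out_format = op_infer.infer(
    'ReduceSum', [t], {'reduce_axis': (0,), 'keep_dims': False})
print(out_shape, out_dtype, out_format)   # [5] float32 DefaultFormat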
@@ -0,0 +1,275 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""GraphKernel Op Infer"""
+
+
+import copy
+import sys
+from functools import reduce
+from .model import GraphKernelUnsupportedException as GKException
+from .model import PrimLib, DataFormat as DF
+
+
+def infer(op_name, inputs, attrs):
+    """infer shape, dtype and format"""
+    def _create_opinfer():
+        if hasattr(sys.modules[__name__], op_name):
+            op_cls = getattr(sys.modules[__name__], op_name)
+            return op_cls(op_name, inputs, attrs)
+        # common infer
+        class_name_map = {
+            PrimLib.ELEMWISE: "_Elemwise",
+            PrimLib.REDUCE: "_Reduce",
+        }
+        cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None)
+        if not cls_name:
+            raise GKException("OpInfo does not support op {}".format(op_name))
+        op_cls = getattr(sys.modules[__name__], cls_name)
+        return op_cls(op_name, inputs, attrs)
+    return _create_opinfer().infer()
+
+
+class OpInfer:
+    """
+    OpInfer is the base class for inferring operator info in the GraphKernel model builder.
+
+    Three methods should be overridden to define the infer logic of an operator:
+    _infer_shape(), _infer_type() and _infer_format().
+    """
+
+    def __init__(self, name, inputs, attrs):
+        self.name = name
+        self.inputs = inputs
+        self.attrs = attrs
+
+    def infer(self):
+        """Infer shape, type and format by op inputs"""
+        self._check()
+        return self._infer_shape(), self._infer_type(), self._infer_format()
+
+    def _infer_shape(self):
+        return self.inputs[0].shape
+
+    def _infer_type(self):
+        return self.inputs[0].dtype
+
+    def _infer_format(self):
+        return self.inputs[0].data_format
+
+    def _check(self):
+        self._check_shape()
+        self._check_type()
+        self._check_format()
+
+    def _check_shape(self):
+        pass
+
+    def _check_type(self):
+        """check that all dtypes are the same"""
+        dtype = self.inputs[0].dtype
+        for i, t in enumerate(self.inputs[1:]):
+            if t.dtype != dtype:
+                raise GKException(
+                    "Incompatible dtype between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype))
+
+    def _check_format(self):
+        """check that formats are compatible; only DefaultFormat is compatible with others"""
+        result = self.inputs[0].data_format
+        i = 0
+        for j, t in enumerate(self.inputs[1:]):
+            if t.data_format != result:
+                if DF.DEFAULT not in (result, t.data_format):
+                    raise GKException("Incompatible format between input {}({}) and {}({})".format(
+                        i, result, j + 1, t.data_format))
+                if result == DF.DEFAULT:
+                    result = t.data_format
+                    i = j + 1
+
+
+class _Elemwise(OpInfer):
+    """Common infer for elementwise operators"""
+
+    def _infer_shape(self):
+        """returns the input shape with the largest flatten size"""
+        shape = (1,)
+        max_flatten_size = 1
+        for t in self.inputs:
+            flatten_size = reduce(lambda x, y: x * y, t.shape)
+            if flatten_size >= max_flatten_size:
+                max_flatten_size = flatten_size
+                shape = t.shape
+        return shape
+
+    def _infer_format(self):
+        for tensor in self.inputs:
+            if tensor.data_format != DF.DEFAULT:
+                return tensor.data_format
+        return DF.DEFAULT
+
+
+class _Reduce(OpInfer):
+    """Common infer for reduction operators"""
+
+    def _check(self):
+        super()._check()
+        # check that the reduce axis is in the range [-len, len)
+        shape_len = len(self.inputs[0].shape)
+        axis = self.attrs['reduce_axis']
+        if isinstance(axis, int):
+            axis = [axis]
+        if not all([(-shape_len <= i < shape_len) for i in axis]):
+            raise GKException(
+                "reduce_axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis))
+
+    def _infer_shape(self):
+        shape = copy.deepcopy(self.inputs[0].shape)
+        axis = self.attrs['reduce_axis']
+
+        if isinstance(axis, int):
+            axis = [axis]
+        if any([i < 0 for i in axis]):
+            # change the axis to a non-negative number.
+            axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis))
+        self.attrs['reduce_axis'] = sorted(axis)
+
+        if self.attrs['keep_dims']:
+            for i in axis:
+                shape[i] = 1
+            return shape
+
+        real_shape = []
+        for i, s in enumerate(shape):
+            if i not in axis:
+                real_shape.append(s)
+        return real_shape
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class _Reshape(OpInfer):
+    """Common infer for reshape operators; should not be instantiated directly"""
+
+    def _infer_shape(self):
+        raise GKException("_infer_shape should be implemented by subclass")
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class Reshape(_Reshape):
+    def _infer_shape(self):
+        return self.attrs["shape"]
+
+
+class ExpandDims(_Reshape):
+    def _infer_shape(self):
+        # list.insert returns None, so build the new shape explicitly
+        shape = list(self.inputs[0].shape)
+        shape.insert(self.attrs["axis"], 1)
+        return shape
+
+
+class Cast(_Elemwise):
+    def _infer_type(self):
+        return self.attrs["dst_type"]
+
+
+class InplaceAssign(_Elemwise):
+    def _infer_shape(self):
+        return [1] if self.attrs["fake_output"] else self.inputs[2].shape
+
+    def _infer_type(self):
+        return self.inputs[2].dtype
+
+    def _infer_format(self):
+        return DF.DEFAULT if self.attrs["fake_output"] else self.inputs[2].data_format
+
+
+class BroadcastTo(OpInfer):
+    def _infer_shape(self):
+        return self.attrs["shape"]
+
+    def _infer_format(self):
+        return self.inputs[0].data_format
+
+
+class Tile(OpInfer):
+    """Op Tile"""
+
+    def __init__(self, op_name, inputs, attrs):
+        super().__init__(op_name, inputs, attrs)
+        self.input_reshape = None
+        self.output_reshape = None
+        self.broadcast_compatible = True
+
+    def _infer_shape(self):
+        shape = self.inputs[0].shape
+        multiples = self.attrs["multiples"]
+
+        shape = list(shape)
+        multiples = list(multiples)
+        diff_len = len(multiples) - len(shape)
+        if diff_len < 0:
+            raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
+        if diff_len > 0:
+            for _ in range(diff_len):
+                shape.insert(0, 1)
+
+        self.broadcast_compatible = True
+        output_shape = []
+        self.input_reshape = []
+        self.output_reshape = []
+        for sh, mul in list(zip(shape, multiples)):
+            dim = sh * mul
+            output_shape.append(dim)
+            if sh == 1 or mul == 1:
+                self.input_reshape.append(sh)
+                self.output_reshape.append(dim)
+            else:
+                self.broadcast_compatible = False
+                self.input_reshape.append(1)
+                self.input_reshape.append(sh)
+                self.output_reshape.append(mul)
+                self.output_reshape.append(sh)
+
+        return output_shape
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class _CompareOp(_Elemwise):
+    """Compare operators"""
+
+    def _infer_type(self):
+        return "bool"
+
+
+class Less(_CompareOp):
+    pass
+
+
+class LessEqual(_CompareOp):
+    pass
+
+
+class Equal(_CompareOp):
+    pass
+
+
+class Greater(_CompareOp):
+    pass
+
+
+class GreaterEqual(_CompareOp):
+    pass
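Note on _Reduce._infer_shape above: negative axes are normalized and written back to attrs['reduce_axis'] as a sorted, non-negative list before the output shape is built. A shapes-only walk-through of that normalization:

shape = [8, 16, 32]
axis = [-1, 0]                        # mixed negative/positive input
axis = sorted(i + len(shape) if i < 0 else i for i in axis)
assert axis == [0, 2]                 # what gets written back to the attrs
# keep_dims=True: reduced dims become 1 -> [1, 16, 1]
# keep_dims=False: reduced dims are dropped:
out = [s for i, s in enumerate(shape) if i not in axis]
assert out == [16]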