forked from mindspore-Ecosystem/mindspore
add conv2d expander
use data_format instead of format in op_infer
add pad and unpad inside conv2d expander
fix pad value
add limit in conv2d expander
This commit is contained in:
parent 35b2e40a72
commit dd81f47271
mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -52,3 +52,4 @@ from .lamb_apply_weight_assign import LambApplyWeightAssign
 from .softmax_grad_ext import SoftmaxGradExt
 from .square_sum_v1 import SquareSumV1
 from .fused_mul_add import FusedMulAdd
+from .conv2d import Conv2D
mindspore/_extends/graph_kernel/expanders/conv2d.py (new file)
@@ -0,0 +1,174 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for Conv2D"""
from mindspore._extends.graph_kernel.model.op_infer import check_nd, conv_had_pad
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD

M_ALIGN = 16
N_ALIGN = 16
K_ALIGN = 8
K_LIMIT = 4096
MNK_LIMIT = 3 * (10 ** 10)
N0_CHANNEL_ALIGN = 16
N1_CHANNEL_ALIGN = 16
C_CHANNEL_ALIGN = 8
OUT_NHW_ALIGN = 128


@VLD.add_format(DF.NHWC, DF.NHWC)
@VLD.check_attrs('pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
class Conv2D(Expander):
    """
    Conv2D expander

    Currently, only a Conv2D that meets several conditions can be expanded; other cases are skipped.
    Conditions to expand:
        inputs are NHWC format and float16.
        attrs groups and group are both 1.
        attr dilation is all 1.
        N channel of both inputs is >= 16.
        C channel of both inputs is the same and >= 8.
        N * H * W of the output is a multiple of 128 (unless the conv can be rewritten as a MatMul).
    """

    def __init__(self, expand_info):
        super().__init__(expand_info)
        self.dst_type = self.outputs[0]['data_type']
        self.dst_format = self.outputs[0]['format']
        self.has_pad = False
        self.can_optimize_to_matmul = False
        self.shape_0_pad = self.inputs[0]['shape']
        self.shape_1_pad = self.inputs[1]['shape']
        self.m = 0
        self.n = 0
        self.k = 0

    def _optimize_to_matmul(self):
        stride = self.attrs['stride']
        dilation = self.attrs['dilation']
        _, h, w, _ = self.inputs[1]['shape']
        if h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1] and \
                self.m % M_ALIGN == 0 and self.n % N_ALIGN == 0 and self.k % K_ALIGN == 0 and \
                self.k <= K_LIMIT and self.m * self.n * self.k < MNK_LIMIT:
            return True
        return False

    def _check(self):
        type_0 = self.inputs[0]['data_type']
        type_1 = self.inputs[1]['data_type']
        if type_0 != "float16" or type_1 != "float16":
            raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1))

        groups = self.attrs['groups']
        group = self.attrs['group']
        if groups != 1 or group != 1:
            raise GKException("groups and group should both be 1, but got {} and {}.".format(groups, group))

        dilation = self.attrs['dilation']
        check_nd(dilation, 4)
        if dilation != [1, 1, 1, 1]:
            raise GKException("dilation should be all 1, but got {}".format(dilation))

        pad_list = self.attrs['pad_list']
        pad_mode = self.attrs['pad_mode']
        check_nd(pad_list, 4)
        self.has_pad = conv_had_pad(pad_list, pad_mode)

        shape_0 = self.inputs[0]['shape']
        shape_1 = self.inputs[1]['shape']
        stride = self.attrs['stride']
        check_nd(shape_0, 4)
        check_nd(shape_1, 4)
        check_nd(stride, 4)
        n0, h0, w0, c0 = shape_0
        n1, h1, w1, c1 = shape_1
        if n0 < N0_CHANNEL_ALIGN:
            raise GKException("N({}) channel of first input should be >= {}".format(n0, N0_CHANNEL_ALIGN))
        if n1 < N1_CHANNEL_ALIGN:
            raise GKException("N({}) channel of second input should be >= {}".format(n1, N1_CHANNEL_ALIGN))
        if c0 != c1 or c0 < C_CHANNEL_ALIGN:
            raise GKException("C channel of inputs({}, {}) should be same and >= {}".format(c0, c1, C_CHANNEL_ALIGN))
        # n0 pad
        n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
        # h0, w0 pad
        if self.has_pad:
            h0 = h0 + pad_list[0] + pad_list[1]
            w0 = w0 + pad_list[2] + pad_list[3]
        # c0, c1 pad
        c0 = ((c0 + C_CHANNEL_ALIGN - 1) // C_CHANNEL_ALIGN) * C_CHANNEL_ALIGN
        c1 = c0
        # n1 pad
        n1 = ((n1 + N1_CHANNEL_ALIGN - 1) // N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN

        # check if can optimize to matmul
        self.m, self.n, self.k = n0 * h0 * w0, n1, c1
        self.can_optimize_to_matmul = self._optimize_to_matmul()

        out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
        if not self.can_optimize_to_matmul and n0 * out_h * out_w % OUT_NHW_ALIGN != 0:
            raise GKException("N({}) * H({}) * W({}) of Conv2d output should be a multiple of {}"
                              .format(n0, out_h, out_w, OUT_NHW_ALIGN))
        self.shape_0_pad = [n0, h0, w0, c0]
        self.shape_1_pad = [n1, h1, w1, c1]

    def _expand(self, graph_builder):
        input_0 = self.inputs[0]
        input_1 = self.inputs[1]
        n0, _, _, c0 = input_0.shape
        n1, _, _, c1 = input_1.shape
        n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
        n1_p, _, _, c1_p = self.shape_1_pad

        pad_value = 0
        # input0 pad
        input_0_pad_before = [0, 0, 0, 0]
        input_0_pad_after = [0, 0, 0, 0]
        if self.has_pad:
            pad_list = self.attrs['pad_list']
            input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
            input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
        input_0_pad_after[0] = n0_p - n0
        input_0_pad_after[3] = c0_p - c0
        if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
            input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
                                                                     'tail': input_0_pad_after,
                                                                     'pad_val': pad_value})
        # input1 pad
        input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
        if input_1_pad_after != [0, 0, 0, 0]:
            input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
                                                                     'tail': input_1_pad_after,
                                                                     'pad_val': pad_value})
        if self.can_optimize_to_matmul:
            a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
            b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
            c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
                                                            'transpose_b': True,
                                                            'dst_type': self.dst_type})
            result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
                                                               'format': self.dst_format})
        else:
            attrs = self.attrs
            attrs['pad_list'] = [0, 0, 0, 0]
            attrs['dst_type'] = self.dst_type
            result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
        # unpad
        unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]
        if unpad_after != [0, 0, 0, 0]:
            result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after})

        return result
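
The MatMul branch above rests on a standard identity: a 1x1 convolution with unit stride and dilation in NHWC layout is exactly a matrix multiply between the flattened input pixels [N*H*W, C] and the flattened filters [out_channel, C] with transpose_b=True. A minimal NumPy sketch of that equivalence (illustration only, not part of the commit; float32 is used here so the numeric comparison is exact, while the expander itself requires float16 inputs):

import numpy as np

# M = n*h*w = 1024, N = 16, K = 8, so these sizes also satisfy the ALIGN limits above.
n, h, w, c, out_c = 16, 8, 8, 8, 16
data = np.random.rand(n, h, w, c).astype(np.float32)        # NHWC input
weight = np.random.rand(out_c, 1, 1, c).astype(np.float32)  # 1x1 kernels

# Direct 1x1 convolution: each output pixel is a dot product over C.
conv = np.einsum('nhwc,oijc->nhwo', data, weight)

# The rewrite emitted by the expander: Reshape -> MatMul(transpose_b=True) -> Reshape.
a = data.reshape(n * h * w, c)   # [M, K]
b = weight.reshape(out_c, c)     # [N, K]
matmul = (a @ b.T).reshape(n, h, w, out_c)

assert np.allclose(conv, matmul)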
mindspore/_extends/graph_kernel/model/model.py
@@ -190,6 +190,7 @@ class PrimLib:
         'BatchMatMul': Prim(OPAQUE),
         'UnPadAkg': Prim(OPAQUE),
         'PadAkg': Prim(OPAQUE),
+        'Conv2D': Prim(OPAQUE),
     }

     default_primtive = Prim(UNKNOWN)
|
@ -14,7 +14,6 @@
|
|||
# ===========================================================================
|
||||
"""GraphKernel Op Infer"""
|
||||
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from functools import reduce
|
||||
|
@@ -24,6 +23,7 @@ from .model import PrimLib, DataFormat as DF

def infer(op_name, inputs, attrs):
    """infer shape dtype and format"""

    def _create_opinfer():
        if hasattr(sys.modules[__name__], op_name):
            op_cls = getattr(sys.modules[__name__], op_name)
@@ -38,6 +38,7 @@ def infer(op_name, inputs, attrs):
            raise GKException("OpInfo does not support op {}".format(op_name))
        op_cls = getattr(sys.modules[__name__], cls_name)
        return op_cls(op_name, inputs, attrs)

    return _create_opinfer().infer()
@@ -168,7 +169,7 @@ class _Reshape(OpInfer):
         raise GKException("_infer_shape should be implemented by subclass")

     def _infer_format(self):
-        return DF.DEFAULT
+        return DF.DEFAULT if "format" not in self.attrs else self.attrs["format"]


 class Reshape(_Reshape):
@@ -236,3 +237,108 @@ class Select(_Elemwise):

    def _infer_type(self):
        return self.inputs[1].dtype


def check_nd(data, nd):
    if not isinstance(data, (list, tuple)) or len(data) != nd:
        raise GKException("input should be a {}D list or tuple, but got {}.".format(nd, data))


def conv_had_pad(pad_list, pad_mode):
    if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
        raise GKException("pad_list should be a 4D list or tuple, but got {}".format(pad_list))
    if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
        return True
    if pad_mode not in ["VALID", "valid"]:
        for pad in pad_list:
            if pad != 0:
                return True
    return False


class Conv2D(OpInfer):
    """Conv2D infer"""
    def _infer_type(self):
        if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
            return self.attrs["dst_type"]
        return self.inputs[0].dtype

    def _infer_shape(self):
        shape_0 = list(self.inputs[0].shape)
        shape_1 = list(self.inputs[1].shape)
        check_nd(shape_0, 4)
        check_nd(shape_1, 4)

        format_0 = self.inputs[0].data_format
        format_1 = self.inputs[1].data_format
        if format_0 != DF.NHWC or format_1 != DF.NHWC:
            raise GKException("Conv2D's inputs format must be NHWC, but got {} and {}".format(format_0, format_1))

        n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
        pad_list = self.attrs["pad_list"]
        pad_mode = self.attrs["pad_mode"]
        kernel_size = self.attrs["kernel_size"]
        stride = self.attrs["stride"]
        dilation = self.attrs["dilation"]
        check_nd(pad_list, 4)
        check_nd(kernel_size, 2)
        check_nd(stride, 4)
        check_nd(dilation, 4)

        has_pad = conv_had_pad(pad_list, pad_mode)
        if not has_pad:
            pad_list = [0, 0, 0, 0]

        k_h = (kernel_size[0] - 1) * dilation[-2] + 1
        k_w = (kernel_size[1] - 1) * dilation[-1] + 1
        out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
        out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
        return [n, out_h, out_w, out_channel]


class MatMul(OpInfer):
    """MatMul infer"""
    def _infer_type(self):
        if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
            return self.attrs["dst_type"]
        return self.inputs[0].dtype

    def _infer_shape(self):
        shape_0 = list(self.inputs[0].shape)
        shape_1 = list(self.inputs[1].shape)
        if len(shape_0) != 2 or len(shape_1) != 2:
            raise GKException("MatMul's inputs shape must be 2D, but got {}D and {}D".format(len(shape_0), len(shape_1)))
        transpose_a = self.attrs["transpose_a"]
        transpose_b = self.attrs["transpose_b"]
        m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1])
        k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1])
        if k1 != k2:
            raise GKException("MatMul's inputs have different k value: {} vs {}".format(k1, k2))
        output_shape = [m, n]
        return output_shape


class PadAkg(OpInfer):
    """PadAkg infer"""
    def _infer_shape(self):
        shape = list(self.inputs[0].shape)
        n = len(shape)
        pad_before = list(self.attrs["head"])
        pad_after = list(self.attrs["tail"])
        if len(pad_before) != n or len(pad_after) != n:
            raise GKException("Input dimension and pad mismatch: {}d vs {}d vs {}d"
                              .format(n, len(pad_before), len(pad_after)))
        out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)]
        return out_shape


class UnPadAkg(OpInfer):
    """UnPadAkg infer"""
    def _infer_shape(self):
        shape = list(self.inputs[0].shape)
        n = len(shape)
        unpad_after = list(self.attrs["tail"])
        if len(unpad_after) != n:
            raise GKException("Input dimension and pad mismatch: {}d vs {}d".format(n, len(unpad_after)))
        out_shape = [shape[i] - unpad_after[i] for i in range(n)]
        return out_shape
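
To make the shape inference concrete, here is a worked example of the formula in Conv2D._infer_shape with hypothetical attribute values (illustration only, not from the commit):

h, w = 224, 224
pad_list = [1, 1, 1, 1]    # [top, bottom, left, right]
kernel_size = [3, 3]
stride = [1, 1, 2, 2]      # the spatial strides are stride[-2] and stride[-1]
dilation = [1, 1, 1, 1]

k_h = (kernel_size[0] - 1) * dilation[-2] + 1   # effective kernel height: 3
k_w = (kernel_size[1] - 1) * dilation[-1] + 1   # effective kernel width: 3
out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1   # (224 + 2 - 3) // 2 + 1 = 112
out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1   # (224 + 2 - 3) // 2 + 1 = 112
# With n = 32 and out_channel = 64 the inferred output shape is [32, 112, 112, 64].
# PadAkg and UnPadAkg are inverses over shapes: out = in + head + tail, then in = out - tail.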