add conv2d expander

use data_format instead of format in op_infer

add pad and unpad inside conv2d expander

fix pad value

add limit in conv2d expander
looop5 2021-04-27 11:42:03 +08:00
parent 35b2e40a72
commit dd81f47271
4 changed files with 284 additions and 2 deletions

mindspore/_extends/graph_kernel/expanders/__init__.py

@@ -52,3 +52,4 @@ from .lamb_apply_weight_assign import LambApplyWeightAssign
from .softmax_grad_ext import SoftmaxGradExt
from .square_sum_v1 import SquareSumV1
from .fused_mul_add import FusedMulAdd
from .conv2d import Conv2D

mindspore/_extends/graph_kernel/expanders/conv2d.py

@@ -0,0 +1,174 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for Conv2D"""
from mindspore._extends.graph_kernel.model.op_infer import check_nd, conv_had_pad
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD
M_ALIGN = 16
N_ALIGN = 16
K_ALIGN = 8
K_LIMIT = 4096
MNK_LIMIT = 3 * (10 ** 10)
N0_CHANNEL_ALIGN = 16
N1_CHANNEL_ALIGN = 16
C_CHANNEL_ALIGN = 8
OUT_NHW_ALIGN = 128


@VLD.add_format(DF.NHWC, DF.NHWC)
@VLD.check_attrs('pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
class Conv2D(Expander):
    """
    Conv2D expander

    Currently, only a Conv2D that meets several conditions can be expanded; other cases will be skipped.
    Conditions to expand:
        inputs are NHWC format and float16.
        attrs groups and group are both 1.
        attr dilation is all 1.
        N channel of the inputs >= 16.
        C channel of the inputs >= 8.
        N * H * W of the output is a multiple of 128 (when not optimized to MatMul).
    """

    def __init__(self, expand_info):
        super().__init__(expand_info)
        self.dst_type = self.outputs[0]['data_type']
        self.dst_format = self.outputs[0]['format']
        self.has_pad = False
        self.can_optimize_to_matmul = False
        self.shape_0_pad = self.inputs[0]['shape']
        self.shape_1_pad = self.inputs[1]['shape']
        self.m = 0
        self.n = 0
        self.k = 0

    def _optimize_to_matmul(self):
        """check whether this Conv2D can be rewritten as a MatMul"""
        stride = self.attrs['stride']
        dilation = self.attrs['dilation']
        _, h, w, _ = self.inputs[1]['shape']
        if h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1] and \
                self.m % M_ALIGN == 0 and self.n % N_ALIGN == 0 and self.k % K_ALIGN == 0 and \
                self.k <= K_LIMIT and self.m * self.n * self.k < MNK_LIMIT:
            return True
        return False
    def _check(self):
        type_0 = self.inputs[0]['data_type']
        type_1 = self.inputs[1]['data_type']
        if type_0 != "float16" or type_1 != "float16":
            raise GKException("inputs' data type should be float16, but got {} and {}".format(type_0, type_1))

        groups = self.attrs['groups']
        group = self.attrs['group']
        if groups != 1 or group != 1:
            raise GKException("groups and group should both be 1, but got {} and {}.".format(groups, group))

        dilation = self.attrs['dilation']
        check_nd(dilation, 4)
        if dilation != [1, 1, 1, 1]:
            raise GKException("dilation should be all 1, but got {}".format(dilation))

        pad_list = self.attrs['pad_list']
        pad_mode = self.attrs['pad_mode']
        check_nd(pad_list, 4)
        self.has_pad = conv_had_pad(pad_list, pad_mode)

        shape_0 = self.inputs[0]['shape']
        shape_1 = self.inputs[1]['shape']
        stride = self.attrs['stride']
        check_nd(shape_0, 4)
        check_nd(shape_1, 4)
        check_nd(stride, 4)
        n0, h0, w0, c0 = shape_0
        n1, h1, w1, c1 = shape_1
        if n0 < N0_CHANNEL_ALIGN:
            raise GKException("N channel ({}) of the first input should be >= {}".format(n0, N0_CHANNEL_ALIGN))
        if n1 < N1_CHANNEL_ALIGN:
            raise GKException("N channel ({}) of the second input should be >= {}".format(n1, N1_CHANNEL_ALIGN))
        if c0 != c1 or c0 < C_CHANNEL_ALIGN:
            raise GKException("C channels of the inputs ({}, {}) should be equal and >= {}"
                              .format(c0, c1, C_CHANNEL_ALIGN))
        # pad n0 up to a multiple of N0_CHANNEL_ALIGN
        n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
        # pad h0, w0 with the conv padding
        if self.has_pad:
            h0 = h0 + pad_list[0] + pad_list[1]
            w0 = w0 + pad_list[2] + pad_list[3]
        # pad c0, c1 up to a multiple of C_CHANNEL_ALIGN
        c0 = ((c0 + C_CHANNEL_ALIGN - 1) // C_CHANNEL_ALIGN) * C_CHANNEL_ALIGN
        c1 = c0
        # pad n1 up to a multiple of N1_CHANNEL_ALIGN
        n1 = ((n1 + N1_CHANNEL_ALIGN - 1) // N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN
        # check whether this Conv2D can be optimized to a MatMul
        self.m, self.n, self.k = n0 * h0 * w0, n1, c1
        self.can_optimize_to_matmul = self._optimize_to_matmul()
        out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
        if not self.can_optimize_to_matmul and n0 * out_h * out_w % OUT_NHW_ALIGN != 0:
            raise GKException("N({}) * H({}) * W({}) of the Conv2D output should be a multiple of {}"
                              .format(n0, out_h, out_w, OUT_NHW_ALIGN))
        self.shape_0_pad = [n0, h0, w0, c0]
        self.shape_1_pad = [n1, h1, w1, c1]
    def _expand(self, graph_builder):
        input_0 = self.inputs[0]
        input_1 = self.inputs[1]
        n0, _, _, c0 = input_0.shape
        n1, _, _, c1 = input_1.shape
        n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
        n1_p, _, _, c1_p = self.shape_1_pad
        pad_value = 0

        # pad input0
        input_0_pad_before = [0, 0, 0, 0]
        input_0_pad_after = [0, 0, 0, 0]
        if self.has_pad:
            pad_list = self.attrs['pad_list']
            input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
            input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
        input_0_pad_after[0] = n0_p - n0
        input_0_pad_after[3] = c0_p - c0
        if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
            input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
                                                                     'tail': input_0_pad_after,
                                                                     'pad_val': pad_value})
        # pad input1
        input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
        if input_1_pad_after != [0, 0, 0, 0]:
            input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
                                                                     'tail': input_1_pad_after,
                                                                     'pad_val': pad_value})
        if self.can_optimize_to_matmul:
            a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
            b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
            c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
                                                            'transpose_b': True,
                                                            'dst_type': self.dst_type})
            result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
                                                               'format': self.dst_format})
        else:
            attrs = self.attrs
            attrs['pad_list'] = [0, 0, 0, 0]
            attrs['dst_type'] = self.dst_type
            result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)

        # unpad the output
        unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]
        if unpad_after != [0, 0, 0, 0]:
            result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after})
        return result
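
The expander pads N up to a multiple of 16 and C up to a multiple of 8 before emitting the inner Conv2D or MatMul, then strips the extra rows with UnPadAkg. A minimal sketch of the round-up arithmetic and the MatMul-rewrite check, with hypothetical shapes (align_up is an illustrative helper, not part of this commit):

def align_up(x, align):
    """round x up to the nearest multiple of align, e.g. align_up(20, 16) == 32"""
    return ((x + align - 1) // align) * align

# Input [20, 14, 14, 3] (NHWC) with a 1x1 kernel [32, 1, 1, 3], stride and dilation all 1:
n0, c0 = align_up(20, N0_CHANNEL_ALIGN), align_up(3, C_CHANNEL_ALIGN)   # -> 32, 8
n1 = align_up(32, N1_CHANNEL_ALIGN)                                     # -> 32
m, n, k = n0 * 14 * 14, n1, c0                                          # -> 6272, 32, 8
# m % M_ALIGN == 0, n % N_ALIGN == 0, k % K_ALIGN == 0, k <= K_LIMIT and
# m * n * k < MNK_LIMIT, so this Conv2D would be rewritten as a
# [m, k] x [k, n]^T MatMul and reshaped back to NHWC.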

mindspore/_extends/graph_kernel/model/model.py

@@ -190,6 +190,7 @@ class PrimLib:
        'BatchMatMul': Prim(OPAQUE),
        'UnPadAkg': Prim(OPAQUE),
        'PadAkg': Prim(OPAQUE),
        'Conv2D': Prim(OPAQUE),
    }

    default_primtive = Prim(UNKNOWN)

mindspore/_extends/graph_kernel/model/op_infer.py

@@ -14,7 +14,6 @@
# ===========================================================================
"""GraphKernel Op Infer"""
import copy
import sys
from functools import reduce
@@ -24,6 +23,7 @@ from .model import PrimLib, DataFormat as DF
def infer(op_name, inputs, attrs):
    """infer shape dtype and format"""
    def _create_opinfer():
        if hasattr(sys.modules[__name__], op_name):
            op_cls = getattr(sys.modules[__name__], op_name)
@@ -38,6 +38,7 @@ def infer(op_name, inputs, attrs):
raise GKException("OpInfo does not support op {}".format(op_name))
op_cls = getattr(sys.modules[__name__], cls_name)
return op_cls(op_name, inputs, attrs)
return _create_opinfer().infer()
@@ -168,7 +169,7 @@ class _Reshape(OpInfer):
raise GKException("_infer_shape should be implemented by subclass")
def _infer_format(self):
return DF.DEFAULT
return DF.DEFAULT if "format" not in self.attrs else self.attrs["format"]
class Reshape(_Reshape):
@@ -236,3 +237,108 @@ class Select(_Elemwise):
    def _infer_type(self):
        return self.inputs[1].dtype


def check_nd(data, nd):
    """check whether data is a list or tuple with nd elements"""
    if not isinstance(data, (list, tuple)) or len(data) != nd:
        raise GKException("input should be a {}D list or tuple, but got {}.".format(nd, data))


def conv_had_pad(pad_list, pad_mode):
    """check whether the conv has padding"""
    if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
        raise GKException("pad_list should be a 4D list or tuple, but got {}".format(pad_list))
    if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
        return True
    if pad_mode not in ["VALID", "valid"]:
        for pad in pad_list:
            if pad != 0:
                return True
    return False
class Conv2D(OpInfer):
    """Conv2D infer"""

    def _infer_type(self):
        if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
            return self.attrs["dst_type"]
        return self.inputs[0].dtype

    def _infer_shape(self):
        shape_0 = list(self.inputs[0].shape)
        shape_1 = list(self.inputs[1].shape)
        check_nd(shape_0, 4)
        check_nd(shape_1, 4)

        format_0 = self.inputs[0].data_format
        format_1 = self.inputs[1].data_format
        if format_0 != DF.NHWC or format_1 != DF.NHWC:
            raise GKException("Conv2D's input formats must be NHWC, but got {} and {}".format(format_0, format_1))

        n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
        pad_list = self.attrs["pad_list"]
        pad_mode = self.attrs["pad_mode"]
        kernel_size = self.attrs["kernel_size"]
        stride = self.attrs["stride"]
        dilation = self.attrs["dilation"]
        check_nd(pad_list, 4)
        check_nd(kernel_size, 2)
        check_nd(stride, 4)
        check_nd(dilation, 4)

        has_pad = conv_had_pad(pad_list, pad_mode)
        if not has_pad:
            pad_list = [0, 0, 0, 0]

        k_h = (kernel_size[0] - 1) * dilation[-2] + 1
        k_w = (kernel_size[1] - 1) * dilation[-1] + 1
        out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
        out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
        return [n, out_h, out_w, out_channel]
class MatMul(OpInfer):
    """MatMul infer"""

    def _infer_type(self):
        if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
            return self.attrs["dst_type"]
        return self.inputs[0].dtype

    def _infer_shape(self):
        shape_0 = list(self.inputs[0].shape)
        shape_1 = list(self.inputs[1].shape)
        if len(shape_0) != 2 or len(shape_1) != 2:
            raise GKException("MatMul's inputs must be 2D, but got {}D and {}D".format(len(shape_0), len(shape_1)))
        transpose_a = self.attrs["transpose_a"]
        transpose_b = self.attrs["transpose_b"]
        m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1])
        k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1])
        if k1 != k2:
            raise GKException("MatMul's inputs have different k values: {} vs {}".format(k1, k2))
        output_shape = [m, n]
        return output_shape
class PadAkg(OpInfer):
    """PadAkg infer"""

    def _infer_shape(self):
        shape = list(self.inputs[0].shape)
        n = len(shape)
        pad_before = list(self.attrs["head"])
        pad_after = list(self.attrs["tail"])
        if len(pad_before) != n or len(pad_after) != n:
            raise GKException("Input dimension and pad mismatch: {}d vs {}d vs {}d"
                              .format(n, len(pad_before), len(pad_after)))
        out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)]
        return out_shape


class UnPadAkg(OpInfer):
    """UnPadAkg infer"""

    def _infer_shape(self):
        shape = list(self.inputs[0].shape)
        n = len(shape)
        unpad_after = list(self.attrs["tail"])
        if len(unpad_after) != n:
            raise GKException("Input dimension and pad mismatch: {}d vs {}d".format(n, len(unpad_after)))
        out_shape = [shape[i] - unpad_after[i] for i in range(n)]
        return out_shape
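
A worked example of the Conv2D shape inference above, with hypothetical values (not part of this commit):

# NHWC input [32, 28, 28, 16], weight [64, 3, 3, 16], pad_list [1, 1, 1, 1],
# pad_mode "SAME", stride [1, 1, 2, 2], dilation [1, 1, 1, 1]:
k_h = (3 - 1) * 1 + 1                 # effective kernel height -> 3
k_w = (3 - 1) * 1 + 1                 # effective kernel width  -> 3
out_h = (28 + 1 + 1 - k_h) // 2 + 1   # -> 14
out_w = (28 + 1 + 1 - k_w) // 2 + 1   # -> 14
# inferred output shape: [32, 14, 14, 64] (N, out_h, out_w, weight's N as out_channel)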