!18880 [GraphKernel] Enable TensorCore for Bert-Base, SSD on GPU

Merge pull request !18880 from lishanni/bert_bmm
commit 1983ded03f
Author: i-robot (committed by Gitee)
Date: 2021-06-29 06:15:08 +00:00
3 changed files with 46 additions and 27 deletions


@@ -18,14 +18,14 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
 from ._utils import Expander, ExpanderInfoValidator as VLD
 
-M_ALIGN = 16
-N_ALIGN = 16
-K_ALIGN = 8
+M_ALIGN = 32
+N_ALIGN = 32
+K_ALIGN = 16
 K_LIMIT = 800
 MNK_LIMIT = 3 * (10 ** 10)
-N0_CHANNEL_ALIGN = 16
-N1_CHANNEL_ALIGN = 16
-C_CHANNEL_ALIGN = 8
+N0_CHANNEL_ALIGN = 32
+N1_CHANNEL_ALIGN = 32
+C_CHANNEL_ALIGN = 16
 OUT_NHW_ALIGN = 128
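Note: the alignment constants double (16/16/8 to 32/32/16), presumably so the MatMul fast path only fires on shapes that tile cleanly for fp16 TensorCore kernels. A minimal sketch of the gate these constants impose; the helper name is hypothetical, not part of the patch:

M_ALIGN, N_ALIGN, K_ALIGN = 32, 32, 16

def fits_tensorcore_tiling(m, n, k):
    # Hypothetical helper mirroring the modulo checks in _optimize_to_matmul.
    return m % M_ALIGN == 0 and n % N_ALIGN == 0 and k % K_ALIGN == 0

print(fits_tensorcore_tiling(6272, 512, 256))  # True
print(fits_tensorcore_tiling(6272, 512, 24))   # False: 24 % 16 != 0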
@@ -63,8 +63,7 @@ class Conv2D(Expander):
         dilation = self.attrs['dilation']
         _, h, w, _ = self.inputs[1]['shape']
         if h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1] and \
-                self.m % M_ALIGN == 0 and self.n % N_ALIGN == 0 and self.k % K_ALIGN == 0 and \
-                self.k <= K_LIMIT and self.m * self.n * self.k < MNK_LIMIT:
+                self.m % M_ALIGN == 0 and self.n % N_ALIGN == 0 and self.k % K_ALIGN == 0:
             return True
         return False
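The condition above recognizes a 1x1, stride-1, dilation-1 convolution, which is exactly a matrix multiply over the channel axis: m = N*H*W, n = output channels, k = input channels. A NumPy sketch of that equivalence (shapes are illustrative only):

import numpy as np

n, h, w, c, o = 2, 4, 4, 16, 32                          # illustrative shapes
x = np.random.rand(n, h, w, c).astype(np.float32)        # NHWC input
weight = np.random.rand(o, 1, 1, c).astype(np.float32)   # 1x1 kernel

# A 1x1 stride-1 conv is one GEMM: (N*H*W, C) x (C, O).
as_matmul = (x.reshape(n * h * w, c) @ weight.reshape(o, c).T).reshape(n, h, w, o)

# Reference result via a per-pixel dot product over channels.
reference = np.einsum('nhwc,oc->nhwo', x, weight.reshape(o, c))
print(np.allclose(as_matmul, reference))                 # True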
@ -72,7 +71,8 @@ class Conv2D(Expander):
type_0 = self.inputs[0]['data_type']
type_1 = self.inputs[1]['data_type']
if type_0 != "float16" or type_1 != "float16":
raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1))
raise GKException(
"inputs type should be float16, but got {} and {}".format(type_0, type_1))
formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']]
check_format_any(formats, DF.NHWC)
@@ -80,12 +80,14 @@ class Conv2D(Expander):
         groups = self.attrs['groups']
         group = self.attrs['group']
         if groups != 1 or group != 1:
-            raise GKException("groups and group should be both 1, but got {} and {}.".format(groups, group))
+            raise GKException(
+                "groups and group should be both 1, but got {} and {}.".format(groups, group))
 
         dilation = self.attrs['dilation']
         check_nd(dilation, 4)
         if dilation != [1, 1, 1, 1]:
-            raise GKException("dilation should be all 1, but got {}".format(dilation))
+            raise GKException(
+                "dilation should be all 1, but got {}".format(dilation))
 
         pad_list = self.attrs['pad_list']
         pad_mode = self.attrs['pad_mode']
@@ -100,16 +102,16 @@ class Conv2D(Expander):
         check_nd(stride, 4)
         n0, h0, w0, c0 = shape_0
         n1, h1, w1, c1 = shape_1
-        if n0 <= N0_CHANNEL_ALIGN:
-            raise GKException("N({}) channel of first input should > {}".format(n0, N0_CHANNEL_ALIGN))
-        if n1 < N1_CHANNEL_ALIGN:
-            raise GKException("N({}) channel of second input should >= {}".format(n1, N1_CHANNEL_ALIGN))
-        if c0 != c1 or c0 < C_CHANNEL_ALIGN:
-            raise GKException("C channel of inputs({}, {}) should be same and >= {}".format(c0, c1, C_CHANNEL_ALIGN))
-        if stride != [1, 1, 2, 2]:
-            raise GKException("Stride H and W should be [2, 2] but got [{}, {}]".format(stride[2], stride[3]))
+        if (n0 % N0_CHANNEL_ALIGN) != 0:
+            raise GKException("N({}) channel of first input should be multiples of {}".format(n0, N0_CHANNEL_ALIGN))
+        if (n1 % N1_CHANNEL_ALIGN) != 0:
+            raise GKException("O({}) channel of second input should be multiples of {}".format(n1, N1_CHANNEL_ALIGN))
+        if c0 != c1 or (c0 % C_CHANNEL_ALIGN) != 0:
+            raise GKException("C channel of inputs({}, {}) should be same and also be multiples of {}".format(
+                c0, c1, C_CHANNEL_ALIGN))
         # n0 pad
-        n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
+        n0 = ((n0 + N0_CHANNEL_ALIGN - 1) //
+              N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
         # h0, w0 pad
         if self.has_pad:
             h0 = h0 + pad_list[0] + pad_list[1]
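The two-line split of the `n0 = ...` assignment is formatting only; the expression is the usual round-up-to-multiple idiom. Note that with the new exact-multiple checks above, this pad leaves n0, n1 and c0 unchanged, and only the h0/w0 padding still alters shapes. A quick sketch of the idiom (helper name is mine):

def align_up(x, align):
    # Round x up to the nearest multiple of align (align > 0).
    return ((x + align - 1) // align) * align

print(align_up(20, 32))  # 32
print(align_up(32, 32))  # 32 (already aligned, unchanged)
print(align_up(33, 16))  # 48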
@@ -118,16 +120,29 @@ class Conv2D(Expander):
         c0 = ((c0 + C_CHANNEL_ALIGN - 1) // C_CHANNEL_ALIGN) * C_CHANNEL_ALIGN
         c1 = c0
         # n1 pad
-        n1 = ((n1 + N1_CHANNEL_ALIGN - 1) // N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN
+        n1 = ((n1 + N1_CHANNEL_ALIGN - 1) //
+              N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN
         # check if can optimize to matmul
         self.m, self.n, self.k = n0 * h0 * w0, n1, c1
         self.can_optimize_to_matmul = self._optimize_to_matmul()
-        out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
-        if not self.can_optimize_to_matmul and n0 * out_h * out_w % OUT_NHW_ALIGN != 0:
-            raise GKException("N({}) * H({}) * W({}) of Conv2d output should be multiplies of {}"
-                              .format(n0, out_h, out_w, OUT_NHW_ALIGN))
+        # requirements
+        if self.can_optimize_to_matmul:
+            if self.k > K_LIMIT:
+                raise GKException(
+                    "If transformed to MatMul, C0({}) should not be larger than {}".format(self.k, K_LIMIT))
+            if self.m * self.n * self.k >= MNK_LIMIT:
+                raise GKException("If transformed to MatMul, The total size({}) should not be larger than {}".format(
+                    self.m * self.n * self.k, MNK_LIMIT))
+        else:
+            out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
+            if ((n0 * out_h * out_w) % OUT_NHW_ALIGN) != 0:
+                raise GKException("N({}) * H({}) * W({}) of output should be multiplies of {}".format(
+                    n0, out_h, out_w, OUT_NHW_ALIGN))
+            if stride != [1, 1, 2, 2]:
+                raise GKException("Stride H and W should be [2, 2] but got [{}, {}]".format(stride[2], stride[3]))
         self.shape_0_pad = [n0, h0, w0, c0]
         self.shape_1_pad = [n1, h1, w1, c1]
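The validation is now split by path: the MatMul branch enforces K_LIMIT and MNK_LIMIT (previously folded into _optimize_to_matmul), while the plain-conv branch keeps the output-NHW alignment and the stride [2, 2] requirement. Walking the MatMul branch with hypothetical sizes (numbers are illustrative only):

K_LIMIT, MNK_LIMIT = 800, 3 * (10 ** 10)

# Hypothetical conv-as-matmul sizes: batch 32, 14x14 maps, 256 -> 512 channels.
m, n, k = 32 * 14 * 14, 512, 256   # m = N*H*W, n = out channels, k = C
print(k <= K_LIMIT)                # True: 256 <= 800
print(m * n * k < MNK_LIMIT)       # True: ~8.2e8 << 3e10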


@@ -81,7 +81,8 @@ def ssd_model_build():
 def set_graph_kernel_context(device_target, model):
     if device_target == "GPU" and model == "ssd300":
         # Enable graph kernel for default model ssd300 on GPU back-end.
-        context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
+        context.set_context(enable_graph_kernel=True,
+                            graph_kernel_flags="--enable_parallel_fusion --enable_expand_ops=Conv2D")
 
 def set_parameter(model_name):
     if model_name == "ssd_resnet50_fpn":
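For ssd300 on GPU the patch adds --enable_expand_ops=Conv2D, which asks GraphKernel to expand Conv2D so the expander changed above can rewrite eligible convolutions onto the TensorCore MatMul path. A minimal usage sketch, assuming a MindSpore GPU build and the context API as used in the patch:

from mindspore import context

context.set_context(device_target="GPU")
context.set_context(enable_graph_kernel=True,
                    graph_kernel_flags="--enable_parallel_fusion --enable_expand_ops=Conv2D")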


@@ -127,11 +127,14 @@ def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
 def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
     """Add suitable graph kernel context for different configs."""
     if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
         if device_target == 'GPU':
             if cfg.bert_network == 'base':
                 context.set_context(enable_graph_kernel=True,
-                                    graph_kernel_flags="--enable_stitch_fusion=true --enable_parallel_fusion=true")
+                                    graph_kernel_flags="--enable_stitch_fusion=true "
+                                                       "--enable_parallel_fusion=true "
+                                                       "--enable_cluster_ops=BatchMatMul")
             else:
                 context.set_context(enable_graph_kernel=True)
         else:
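The flags value is three adjacent string literals, which Python concatenates at parse time into one space-separated string; --enable_cluster_ops=BatchMatMul is the new piece that lets GraphKernel cluster BatchMatMul for Bert-Base, matching the commit title. A quick check of the concatenation:

flags = ("--enable_stitch_fusion=true "
         "--enable_parallel_fusion=true "
         "--enable_cluster_ops=BatchMatMul")
print(flags)
# --enable_stitch_fusion=true --enable_parallel_fusion=true --enable_cluster_ops=BatchMatMul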