!18880 [GraphKernel] Enable TensorCore for Bert-Base, SSD on GPU
Merge pull request !18880 from lishanni/bert_bmm
This commit is contained in: commit 1983ded03f
@@ -18,14 +18,14 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
 from ._utils import Expander, ExpanderInfoValidator as VLD
 
-M_ALIGN = 16
-N_ALIGN = 16
-K_ALIGN = 8
+M_ALIGN = 32
+N_ALIGN = 32
+K_ALIGN = 16
 K_LIMIT = 800
 MNK_LIMIT = 3 * (10 ** 10)
-N0_CHANNEL_ALIGN = 16
-N1_CHANNEL_ALIGN = 16
-C_CHANNEL_ALIGN = 8
+N0_CHANNEL_ALIGN = 32
+N1_CHANNEL_ALIGN = 32
+C_CHANNEL_ALIGN = 16
 OUT_NHW_ALIGN = 128
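The alignment constants double here because the TensorCore kernels consume larger fp16 tiles; dimensions that are not already multiples get padded up later in this file. A minimal sketch of the ceil-align arithmetic the expander relies on (the helper name align_up is illustrative, not from the source):

def align_up(x, align):
    # Round x up to the nearest multiple of align, e.g. align_up(20, 32) == 32.
    return ((x + align - 1) // align) * align

assert align_up(20, 32) == 32   # padded up
assert align_up(64, 32) == 64   # already aligned, unchanged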
@@ -63,8 +63,7 @@ class Conv2D(Expander):
         dilation = self.attrs['dilation']
         _, h, w, _ = self.inputs[1]['shape']
         if h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1] and \
-                self.m % M_ALIGN == 0 and self.n % N_ALIGN == 0 and self.k % K_ALIGN == 0 and \
-                self.k <= K_LIMIT and self.m * self.n * self.k < MNK_LIMIT:
+                self.m % M_ALIGN == 0 and self.n % N_ALIGN == 0 and self.k % K_ALIGN == 0:
             return True
         return False
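The K_LIMIT and MNK_LIMIT terms move out of this eligibility test and reappear as hard requirements in the -118 hunk below, so an oversized 1x1 conv now raises GKException instead of silently falling back. The test itself encodes the fact that a 1x1 convolution with unit stride and dilation is a matrix multiply; a hedged NumPy sketch of that equivalence (shapes invented for illustration):

import numpy as np

n, h, w, c, cout = 2, 4, 4, 8, 16
x = np.random.rand(n, h, w, c).astype(np.float16)          # NHWC input
weight = np.random.rand(cout, 1, 1, c).astype(np.float16)  # 1x1 kernel, OHWI

# Direct 1x1 convolution: each output pixel is a dot product over C.
conv = np.einsum('nhwc,oc->nhwo', x, weight[:, 0, 0, :])

# Identical to an (N*H*W, C) @ (C, Cout) matmul, i.e. the m, k, n this expander uses.
matmul = (x.reshape(-1, c) @ weight[:, 0, 0, :].T).reshape(n, h, w, cout)
np.testing.assert_allclose(conv, matmul, rtol=1e-2)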
@@ -72,7 +71,8 @@ class Conv2D(Expander):
         type_0 = self.inputs[0]['data_type']
         type_1 = self.inputs[1]['data_type']
         if type_0 != "float16" or type_1 != "float16":
-            raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1))
+            raise GKException(
+                "inputs type should be float16, but got {} and {}".format(type_0, type_1))
 
         formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']]
         check_format_any(formats, DF.NHWC)
@@ -80,12 +80,14 @@ class Conv2D(Expander):
         groups = self.attrs['groups']
         group = self.attrs['group']
         if groups != 1 or group != 1:
-            raise GKException("groups and group should be both 1, but got {} and {}.".format(groups, group))
+            raise GKException(
+                "groups and group should be both 1, but got {} and {}.".format(groups, group))
 
         dilation = self.attrs['dilation']
         check_nd(dilation, 4)
         if dilation != [1, 1, 1, 1]:
-            raise GKException("dilation should be all 1, but got {}".format(dilation))
+            raise GKException(
+                "dilation should be all 1, but got {}".format(dilation))
 
         pad_list = self.attrs['pad_list']
         pad_mode = self.attrs['pad_mode']
@@ -100,16 +102,16 @@ class Conv2D(Expander):
         check_nd(stride, 4)
         n0, h0, w0, c0 = shape_0
         n1, h1, w1, c1 = shape_1
-        if n0 <= N0_CHANNEL_ALIGN:
-            raise GKException("N({}) channel of first input should > {}".format(n0, N0_CHANNEL_ALIGN))
-        if n1 < N1_CHANNEL_ALIGN:
-            raise GKException("N({}) channel of second input should >= {}".format(n1, N1_CHANNEL_ALIGN))
-        if c0 != c1 or c0 < C_CHANNEL_ALIGN:
-            raise GKException("C channel of inputs({}, {}) should be same and >= {}".format(c0, c1, C_CHANNEL_ALIGN))
-        if stride != [1, 1, 2, 2]:
-            raise GKException("Stride H and W should be [2, 2] but got [{}, {}]".format(stride[2], stride[3]))
+        if (n0 % N0_CHANNEL_ALIGN) != 0:
+            raise GKException("N({}) channel of first input should be multiples of {}".format(n0, N0_CHANNEL_ALIGN))
+        if (n1 % N1_CHANNEL_ALIGN) != 0:
+            raise GKException("O({}) channel of second input should be multiples of {}".format(n1, N1_CHANNEL_ALIGN))
+        if c0 != c1 or (c0 % C_CHANNEL_ALIGN) != 0:
+            raise GKException("C channel of inputs({}, {}) should be same and also be multiples of {}".format(
+                c0, c1, C_CHANNEL_ALIGN))
         # n0 pad
-        n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
+        n0 = ((n0 + N0_CHANNEL_ALIGN - 1) //
+              N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
         # h0, w0 pad
         if self.has_pad:
             h0 = h0 + pad_list[0] + pad_list[1]
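The channel guards change from minimum-size thresholds to exact-multiple checks, which is what the TensorCore tiling actually needs; the stride check moves into the non-matmul branch of the -118 hunk below. A quick worked example of the behavioural difference (values invented for illustration):

N0_CHANNEL_ALIGN = 32  # new value; the old constant was 16
n0 = 48
assert n0 > 16                        # old guard: any batch above the minimum passed
assert (n0 % N0_CHANNEL_ALIGN) != 0   # new guard: 48 is not a multiple of 32, so GKException is raised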
@@ -118,16 +120,29 @@ class Conv2D(Expander):
         c0 = ((c0 + C_CHANNEL_ALIGN - 1) // C_CHANNEL_ALIGN) * C_CHANNEL_ALIGN
         c1 = c0
         # n1 pad
-        n1 = ((n1 + N1_CHANNEL_ALIGN - 1) // N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN
+        n1 = ((n1 + N1_CHANNEL_ALIGN - 1) //
+              N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN
 
         # check if can optimize to matmul
         self.m, self.n, self.k = n0 * h0 * w0, n1, c1
         self.can_optimize_to_matmul = self._optimize_to_matmul()
 
-        out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
-        if not self.can_optimize_to_matmul and n0 * out_h * out_w % OUT_NHW_ALIGN != 0:
-            raise GKException("N({}) * H({}) * W({}) of Conv2d output should be multiplies of {}"
-                              .format(n0, out_h, out_w, OUT_NHW_ALIGN))
+        # requirements
+        if self.can_optimize_to_matmul:
+            if self.k > K_LIMIT:
+                raise GKException(
+                    "If transformed to MatMul, C0({}) should not be larger than {}".format(self.k, K_LIMIT))
+            if self.m * self.n * self.k >= MNK_LIMIT:
+                raise GKException("If transformed to MatMul, The total size({}) should not be larger than {}".format(
+                    self.m * self.n * self.k, MNK_LIMIT))
+        else:
+            out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
+            if ((n0 * out_h * out_w) % OUT_NHW_ALIGN) != 0:
+                raise GKException("N({}) * H({}) * W({}) of output should be multiplies of {}".format(
+                    n0, out_h, out_w, OUT_NHW_ALIGN))
+            if stride != [1, 1, 2, 2]:
+                raise GKException("Stride H and W should be [2, 2] but got [{}, {}]".format(stride[2], stride[3]))
 
         self.shape_0_pad = [n0, h0, w0, c0]
         self.shape_1_pad = [n1, h1, w1, c1]
 
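After this change the requirement checks fork: convolutions that rewrite to MatMul are bounded only by K_LIMIT and MNK_LIMIT, while all others must satisfy the output-alignment and stride constraints. A simplified, hedged mirror of that flow (standalone function, not the class; constant values copied from the first hunk, ValueError standing in for GKException):

K_LIMIT, MNK_LIMIT, OUT_NHW_ALIGN = 800, 3 * (10 ** 10), 128

def check_requirements(can_optimize_to_matmul, m, n, k, n0, out_h, out_w, stride):
    if can_optimize_to_matmul:
        # MatMul path: only the K cap and the total-size cap apply.
        if k > K_LIMIT or m * n * k >= MNK_LIMIT:
            raise ValueError("too large to transform to MatMul")
    else:
        # Conv path: output N*H*W must align and stride must be [1, 1, 2, 2].
        if (n0 * out_h * out_w) % OUT_NHW_ALIGN != 0 or stride != [1, 1, 2, 2]:
            raise ValueError("unsupported output shape or stride")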
@@ -81,7 +81,8 @@ def ssd_model_build():
 def set_graph_kernel_context(device_target, model):
     if device_target == "GPU" and model == "ssd300":
         # Enable graph kernel for default model ssd300 on GPU back-end.
-        context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
+        context.set_context(enable_graph_kernel=True,
+                            graph_kernel_flags="--enable_parallel_fusion --enable_expand_ops=Conv2D")
 
 def set_parameter(model_name):
     if model_name == "ssd_resnet50_fpn":
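--enable_expand_ops=Conv2D opts Conv2D into the graph-kernel expander patched above, which is what lets SSD's eligible convolutions reach the TensorCore path. A hedged usage sketch (device setup only; model construction and any mode/device_id arguments depend on the launch environment and are elided):

from mindspore import context

context.set_context(device_target="GPU")
context.set_context(enable_graph_kernel=True,
                    graph_kernel_flags="--enable_parallel_fusion --enable_expand_ops=Conv2D")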
@@ -127,11 +127,14 @@ def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
 
 def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
     """Add suitable graph kernel context for different configs."""
     if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
         if device_target == 'GPU':
+            if cfg.bert_network == 'base':
                 context.set_context(enable_graph_kernel=True,
-                                    graph_kernel_flags="--enable_stitch_fusion=true --enable_parallel_fusion=true")
+                                    graph_kernel_flags="--enable_stitch_fusion=true "
+                                                       "--enable_parallel_fusion=true "
+                                                       "--enable_cluster_ops=BatchMatMul")
             else:
                 context.set_context(enable_graph_kernel=True)
         else:
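--enable_cluster_ops=BatchMatMul adds BatchMatMul to the ops graph kernel may cluster, which routes BERT-Base's batched matmuls to the TensorCore path; stitch and parallel fusion were already enabled. The flag string is now split across adjacent literals, which Python concatenates at compile time, so the result is identical to a single-line string; a quick sanity check:

flags = ("--enable_stitch_fusion=true "
         "--enable_parallel_fusion=true "
         "--enable_cluster_ops=BatchMatMul")
assert flags == "--enable_stitch_fusion=true --enable_parallel_fusion=true --enable_cluster_ops=BatchMatMul"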