From 68f55e1e9349bfada3a0e5beb18fa5144a5bc280 Mon Sep 17 00:00:00 2001 From: looop5 Date: Fri, 28 May 2021 11:38:32 +0800 Subject: [PATCH] expand conv2d when input format is DefaultFormat but attr format is NHWC --- mindspore/_extends/graph_kernel/expanders/conv2d.py | 8 ++++++-- mindspore/_extends/graph_kernel/model/op_infer.py | 13 +++++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mindspore/_extends/graph_kernel/expanders/conv2d.py b/mindspore/_extends/graph_kernel/expanders/conv2d.py index 4194718a6ae..ac93090aaaa 100644 --- a/mindspore/_extends/graph_kernel/expanders/conv2d.py +++ b/mindspore/_extends/graph_kernel/expanders/conv2d.py @@ -13,7 +13,7 @@ # limitations under the License. # =========================================================================== """generate json desc for Conv2D""" -from mindspore._extends.graph_kernel.model.op_infer import check_nd, conv_had_pad +from mindspore._extends.graph_kernel.model.op_infer import check_format_any, check_nd, conv_had_pad from mindspore._extends.graph_kernel.model.model import DataFormat as DF from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException from ._utils import Expander, ExpanderInfoValidator as VLD @@ -29,8 +29,9 @@ C_CHANNEL_ALIGN = 8 OUT_NHW_ALIGN = 128 +@VLD.add_format(DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.NHWC, DF.NHWC) -@VLD.check_attrs('pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation') +@VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation') class Conv2D(Expander): """ Conv2D expander @@ -73,6 +74,9 @@ class Conv2D(Expander): if type_0 != "float16" or type_1 != "float16": raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1)) + formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']] + check_format_any(formats, DF.NHWC) + groups = self.attrs['groups'] group = self.attrs['group'] if groups != 1 or group != 1: diff --git a/mindspore/_extends/graph_kernel/model/op_infer.py b/mindspore/_extends/graph_kernel/model/op_infer.py index 762d8179b95..eb59558c502 100644 --- a/mindspore/_extends/graph_kernel/model/op_infer.py +++ b/mindspore/_extends/graph_kernel/model/op_infer.py @@ -239,6 +239,13 @@ class Select(_Elemwise): return self.inputs[1].dtype +def check_format_any(formats, checked_format): + if not isinstance(formats, (list, tuple)): + raise GKException("formats {} should be list or tuple, but got {}.".format(formats, type(formats))) + if checked_format not in formats: + raise GKException("Check {} failed in {}".format(checked_format, formats)) + + def check_nd(data, nd): if not isinstance(data, (list, tuple)) or len(data) != nd: raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data)) @@ -269,10 +276,8 @@ class Conv2D(OpInfer): check_nd(shape_0, 4) check_nd(shape_1, 4) - format_0 = self.inputs[0].data_format - format_1 = self.inputs[1].data_format - if format_0 != DF.NHWC or format_1 != DF.NHWC: - raise GKException("Conv2D's inputs format must be NHWC, but got {} and {}".format(format_0, format_1)) + formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]] + check_format_any(formats, DF.NHWC) n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0] pad_list = self.attrs["pad_list"]