expand conv2d when input format is DefaultFormat but attr format is NHWC

This commit is contained in:
looop5 2021-05-28 11:38:32 +08:00
parent f5a23ddf26
commit 68f55e1e93
2 changed files with 15 additions and 6 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ===========================================================================
"""generate json desc for Conv2D"""
from mindspore._extends.graph_kernel.model.op_infer import check_nd, conv_had_pad
from mindspore._extends.graph_kernel.model.op_infer import check_format_any, check_nd, conv_had_pad
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD
@ -29,8 +29,9 @@ C_CHANNEL_ALIGN = 8
OUT_NHW_ALIGN = 128
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NHWC, DF.NHWC)
@VLD.check_attrs('pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
@VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
class Conv2D(Expander):
"""
Conv2D expander
@ -73,6 +74,9 @@ class Conv2D(Expander):
if type_0 != "float16" or type_1 != "float16":
raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1))
formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']]
check_format_any(formats, DF.NHWC)
groups = self.attrs['groups']
group = self.attrs['group']
if groups != 1 or group != 1:

View File

@ -239,6 +239,13 @@ class Select(_Elemwise):
return self.inputs[1].dtype
def check_format_any(formats, checked_format):
if not isinstance(formats, (list, tuple)):
raise GKException("formats {} should be list or tuple, but got {}.".format(formats, type(formats)))
if checked_format not in formats:
raise GKException("Check {} failed in {}".format(checked_format, formats))
def check_nd(data, nd):
if not isinstance(data, (list, tuple)) or len(data) != nd:
raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
@ -269,10 +276,8 @@ class Conv2D(OpInfer):
check_nd(shape_0, 4)
check_nd(shape_1, 4)
format_0 = self.inputs[0].data_format
format_1 = self.inputs[1].data_format
if format_0 != DF.NHWC or format_1 != DF.NHWC:
raise GKException("Conv2D's inputs format must be NHWC, but got {} and {}".format(format_0, format_1))
formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]]
check_format_any(formats, DF.NHWC)
n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
pad_list = self.attrs["pad_list"]