forked from mindspore-Ecosystem/mindspore
expand conv2d when input format is DefaultFormat but attr format is NHWC
This commit is contained in:
parent
f5a23ddf26
commit
68f55e1e93
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for Conv2D"""
|
||||
from mindspore._extends.graph_kernel.model.op_infer import check_nd, conv_had_pad
|
||||
from mindspore._extends.graph_kernel.model.op_infer import check_format_any, check_nd, conv_had_pad
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
@ -29,8 +29,9 @@ C_CHANNEL_ALIGN = 8
|
|||
OUT_NHW_ALIGN = 128
|
||||
|
||||
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
|
||||
@VLD.add_format(DF.NHWC, DF.NHWC)
|
||||
@VLD.check_attrs('pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
|
||||
@VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
|
||||
class Conv2D(Expander):
|
||||
"""
|
||||
Conv2D expander
|
||||
|
@ -73,6 +74,9 @@ class Conv2D(Expander):
|
|||
if type_0 != "float16" or type_1 != "float16":
|
||||
raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1))
|
||||
|
||||
formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']]
|
||||
check_format_any(formats, DF.NHWC)
|
||||
|
||||
groups = self.attrs['groups']
|
||||
group = self.attrs['group']
|
||||
if groups != 1 or group != 1:
|
||||
|
|
|
@ -239,6 +239,13 @@ class Select(_Elemwise):
|
|||
return self.inputs[1].dtype
|
||||
|
||||
|
||||
def check_format_any(formats, checked_format):
|
||||
if not isinstance(formats, (list, tuple)):
|
||||
raise GKException("formats {} should be list or tuple, but got {}.".format(formats, type(formats)))
|
||||
if checked_format not in formats:
|
||||
raise GKException("Check {} failed in {}".format(checked_format, formats))
|
||||
|
||||
|
||||
def check_nd(data, nd):
|
||||
if not isinstance(data, (list, tuple)) or len(data) != nd:
|
||||
raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
|
||||
|
@ -269,10 +276,8 @@ class Conv2D(OpInfer):
|
|||
check_nd(shape_0, 4)
|
||||
check_nd(shape_1, 4)
|
||||
|
||||
format_0 = self.inputs[0].data_format
|
||||
format_1 = self.inputs[1].data_format
|
||||
if format_0 != DF.NHWC or format_1 != DF.NHWC:
|
||||
raise GKException("Conv2D's inputs format must be NHWC, but got {} and {}".format(format_0, format_1))
|
||||
formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]]
|
||||
check_format_any(formats, DF.NHWC)
|
||||
|
||||
n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
|
||||
pad_list = self.attrs["pad_list"]
|
||||
|
|
Loading…
Reference in New Issue