[mlir] Updated depthwise conv to support kernel dilation

Depthwise convolution should support kernel dilation, and the undilated case
should not be handled as a special case. Updated the op definitions to include
a `dilations` attribute.

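This also updates the tosa.depthwise_conv2d lowering to linalg so that it no
longer rejects dilated kernels and instead forwards the dilation to the linalg
op, matching the new linalg behavior.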

Differential Revision: https://reviews.llvm.org/D103219
This commit is contained in:
Rob Suderman 2021-05-26 18:02:30 -07:00
parent aaac268285
commit 422c7036d5
5 changed files with 86 additions and 24 deletions

View File

@ -155,7 +155,7 @@ ods_def<DepthwiseConvInputNHWCFilterHWCFOp>:
def depthwise_conv_2d_input_nhwc_filter_hwcf
(I: f32(N, IH, IW, CI), K: f32(KH, KW, CI, CO))
-> (O: f32(N, OH, OW, CI, CO))
attr(strides: 2xi64)
attr(strides: 2xi64, dilations: 2xi64)
"""A general depth-wise 2-D convolution operation.
This operation performs depth-wise 2-D convolution over an input `I` and filter
@ -164,7 +164,7 @@ This operation performs depth-wise 2-D convolution over an input `I` and filter
```
O(n, oh, ow, ci, co) = AddFOp<kh, kw>(
O(n, oh, ow, ci, co),
MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci),
MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], ci),
K(kh, kw, ci, co)));
```
@ -186,7 +186,7 @@ Linalg reshape op which collapses `CI` and `CO` into one dimension.
{
O(n, oh, ow, ci, co) = AddFOp<kh, kw>(
O(n, oh, ow, ci, co),
MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci),
MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], ci),
K(kh, kw, ci, co)));
}
@ -194,7 +194,7 @@ ods_def<DepthwiseConvInputNHWCFilterHWCOp>:
def depthwise_conv_2d_input_nhwc_filter_hwc
(I: f32(N, IH, IW, C), K: f32(KH, KW, C))
-> (O: f32(N, OH, OW, C))
attr(strides: 2xi64)
attr(strides: 2xi64, dilations: 2xi64)
"""A depth-wise 2-D convolution operation.
This operation performs depth-wise 2-D convolution over an input `I` and filter
@ -203,7 +203,7 @@ This operation performs depth-wise 2-D convolution over an input `I` and filter
```
O(n, oh, ow, c) = AddFOp<kh, kw>(
O(n, oh, ow, c),
MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c),
K(kh, kw, c)));
```
@ -223,7 +223,7 @@ Note: this op only supports channel multiplier == 1.
{
O(n, oh, ow, c) = AddFOp<kh, kw>(
O(n, oh, ow, c),
MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c),
K(kh, kw, c)));
}

View File

@ -956,9 +956,6 @@ convolutionMatchAndRewriterHelper(Operation *op,
}
if (isa<tosa::DepthwiseConv2DOp>(op)) {
if (llvm::any_of(dilation, [](int64_t d) { return d > 1; }))
return failure();
ShapedType linalgConvTy =
RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
weightShape[2], weightShape[3]},
@ -969,7 +966,7 @@ convolutionMatchAndRewriterHelper(Operation *op,
Value conv = rewriter
.create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
loc, linalgConvTy, ValueRange{input, weight},
ValueRange{biasReshape}, strideAttr)
ValueRange{biasReshape}, dilationAttr, strideAttr)
.getResult(0);
Value reshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);

View File

@ -1189,7 +1189,7 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
// CHECK: linalg.yield %arg3 : f32
// CHECK: } -> tensor<1x5x5x33xf32>
// CHECK: [[DBIAS:%.+]] = linalg.tensor_reshape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
// CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
// CHECK: linalg.tensor_reshape %3 {{\[}}[0], [1], [2], [3, 4]]
%2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>)
return

View File

@ -78,7 +78,7 @@ func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C:
func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
{ strides = dense<1> : tensor<2xi64> }
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
outs(%output : memref<2x3x4x2x3xf32>)
return
@ -103,8 +103,35 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %fil
// -----
// Dilated variant (dilations = 2, strides = 1) of the depthwise convolution.
// The expected linalg.generic below must carry the dilation as the `* 2`
// multiplier on the kernel dimensions (d5, d6) of the input indexing map.
// A 2x2 kernel with dilation 2 spans 3 input elements per dimension, so the
// 4x5 spatial input produces the 2x3 spatial output used here.
func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x2x3x2x3xf32>) {
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
{ dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
outs(%output : memref<2x2x3x2x3xf32>)
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5 * 2, d2 + d6 * 2, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
// CHECK: func @depthwise_conv_2d_input_nhwc_filter_hwcf
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x2x3x2x3xf32>)
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
// -----
func @depthwise_conv_2d_input_nhwc_filter_hwc(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return

View File

@ -6,11 +6,11 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32
%init = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
%fill = linalg.fill(%init, %zero) : tensor<2x3x4x2x3xf32>, f32 -> tensor<2x3x4x2x3xf32>
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<2x3x4x2x3xf32>)
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
{ strides = dense<1> : tensor<2xi64> }
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
outs(%fill : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
return %0 : tensor<2x3x4x2x3xf32>
@ -19,11 +19,11 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref
func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x3x4x2x3xf32>)
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
{ strides = dense<1> : tensor<2xi64> }
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
outs(%output : memref<2x3x4x2x3xf32>)
return
@ -33,10 +33,10 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32
func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
%init = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwc
// CHECK-SAME: {strides = dense<2> : vector<2xi64>}
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
outs(%init: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
return %0: tensor<1x56x56x96xf32>
@ -45,20 +45,58 @@ func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor(%input: tensor<1x113x113x96
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwc_memref
func @depthwise_conv_2d_input_nhwc_filter_hwc_memref(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc
// CHECK-SAME: {strides = dense<2> : vector<2xi64>}
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<1x56x56x96xf32>)
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
}
// Tensor form of the dilated depthwise convolution (dilations = 2,
// strides = 1). A 2x2 kernel with dilation 2 spans 3 input elements per
// dimension, so the 8x9 spatial input produces the 6x7 spatial output.
// The label scopes the expectations below to this function, as is done for
// the neighbouring tests in this file.
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor_dilated
func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor_dilated(%input: tensor<2x8x9x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x6x7x2x3xf32> {
%zero = constant 0.000000e+00 : f32
%init = linalg.init_tensor [2, 6, 7, 2, 3] : tensor<2x6x7x2x3xf32>
%fill = linalg.fill(%init, %zero) : tensor<2x6x7x2x3xf32>, f32 -> tensor<2x6x7x2x3xf32>
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<2x6x7x2x3xf32>)
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
{ dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
outs(%fill : tensor<2x6x7x2x3xf32>) -> tensor<2x6x7x2x3xf32>
return %0 : tensor<2x6x7x2x3xf32>
}
// Memref form of the dilated depthwise convolution (dilations = 2,
// strides = 1). The expectations require the op to be printed with both
// attributes and with the 8x9 -> 6x7 spatial shapes unchanged.
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref_dilated
func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref_dilated(%input: memref<2x8x9x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x6x7x2x3xf32>) {
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x8x9x2xf32>, memref<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x6x7x2x3xf32>)
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
{ dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : memref<2x8x9x2xf32>, memref<2x2x2x3xf32>)
outs(%output : memref<2x6x7x2x3xf32>)
return
}
// -----
func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{missing indexing map required attribute 'strides'}}
linalg.depthwise_conv_2d_input_nhwc_filter_hwc
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
}
// -----
func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{missing indexing map required attribute 'dilations'}}
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<1> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
@ -68,7 +106,7 @@ func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x11
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}}
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2.0> : vector<2xf32>}
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
@ -78,7 +116,7 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memr
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}}
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<3xi64> }
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return