forked from OSchip/llvm-project
[mlir] Updated depthwise conv to support kernel dilation
Depthwise convolution should support kernel dilation, and the non-dilated case should not be treated as a special case. Updated the op definitions to include a `dilations` attribute. This also updates the tosa.depthwise_conv2d lowering to linalg to support the new linalg behavior. Differential Revision: https://reviews.llvm.org/D103219
This commit is contained in:
parent
aaac268285
commit
422c7036d5
|
@ -155,7 +155,7 @@ ods_def<DepthwiseConvInputNHWCFilterHWCFOp>:
|
|||
def depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
(I: f32(N, IH, IW, CI), K: f32(KH, KW, CI, CO))
|
||||
-> (O: f32(N, OH, OW, CI, CO))
|
||||
attr(strides: 2xi64)
|
||||
attr(strides: 2xi64, dilations: 2xi64)
|
||||
"""A general depth-wise 2-D convolution operation.
|
||||
|
||||
This operation performs depth-wise 2-D convolution over an input `I` and filter
|
||||
|
@ -164,7 +164,7 @@ This operation performs depth-wise 2-D convolution over an input `I` and filter
|
|||
```
|
||||
O(n, oh, ow, ci, co) = AddFOp<kh, kw>(
|
||||
O(n, oh, ow, ci, co),
|
||||
MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci),
|
||||
MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], ci),
|
||||
K(kh, kw, ci, co)));
|
||||
```
|
||||
|
||||
|
@ -186,7 +186,7 @@ Linalg reshape op which collapses `CI` and `CO` into one dimension.
|
|||
{
|
||||
O(n, oh, ow, ci, co) = AddFOp<kh, kw>(
|
||||
O(n, oh, ow, ci, co),
|
||||
MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci),
|
||||
MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], ci),
|
||||
K(kh, kw, ci, co)));
|
||||
}
|
||||
|
||||
|
@ -194,7 +194,7 @@ ods_def<DepthwiseConvInputNHWCFilterHWCOp>:
|
|||
def depthwise_conv_2d_input_nhwc_filter_hwc
|
||||
(I: f32(N, IH, IW, C), K: f32(KH, KW, C))
|
||||
-> (O: f32(N, OH, OW, C))
|
||||
attr(strides: 2xi64)
|
||||
attr(strides: 2xi64, dilations: 2xi64)
|
||||
"""A depth-wise 2-D convolution operation.
|
||||
|
||||
This operation performs depth-wise 2-D convolution over an input `I` and filter
|
||||
|
@ -203,7 +203,7 @@ This operation performs depth-wise 2-D convolution over an input `I` and filter
|
|||
```
|
||||
O(n, oh, ow, c) = AddFOp<kh, kw>(
|
||||
O(n, oh, ow, c),
|
||||
MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
|
||||
MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c),
|
||||
K(kh, kw, c)));
|
||||
```
|
||||
|
||||
|
@ -223,7 +223,7 @@ Note: this op only supports channel multiplier == 1.
|
|||
{
|
||||
O(n, oh, ow, c) = AddFOp<kh, kw>(
|
||||
O(n, oh, ow, c),
|
||||
MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
|
||||
MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c),
|
||||
K(kh, kw, c)));
|
||||
}
|
||||
|
||||
|
|
|
@ -956,9 +956,6 @@ convolutionMatchAndRewriterHelper(Operation *op,
|
|||
}
|
||||
|
||||
if (isa<tosa::DepthwiseConv2DOp>(op)) {
|
||||
if (llvm::any_of(dilation, [](int64_t d) { return d > 1; }))
|
||||
return failure();
|
||||
|
||||
ShapedType linalgConvTy =
|
||||
RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
|
||||
weightShape[2], weightShape[3]},
|
||||
|
@ -969,7 +966,7 @@ convolutionMatchAndRewriterHelper(Operation *op,
|
|||
Value conv = rewriter
|
||||
.create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
|
||||
loc, linalgConvTy, ValueRange{input, weight},
|
||||
ValueRange{biasReshape}, strideAttr)
|
||||
ValueRange{biasReshape}, dilationAttr, strideAttr)
|
||||
.getResult(0);
|
||||
|
||||
Value reshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
|
||||
|
|
|
@ -1189,7 +1189,7 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
|
|||
// CHECK: linalg.yield %arg3 : f32
|
||||
// CHECK: } -> tensor<1x5x5x33xf32>
|
||||
// CHECK: [[DBIAS:%.+]] = linalg.tensor_reshape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
|
||||
// CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
|
||||
// CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
|
||||
// CHECK: linalg.tensor_reshape %3 {{\[}}[0], [1], [2], [3, 4]]
|
||||
%2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>)
|
||||
return
|
||||
|
|
|
@ -78,7 +78,7 @@ func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C:
|
|||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
{ strides = dense<1> : tensor<2xi64> }
|
||||
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
|
||||
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
|
||||
outs(%output : memref<2x3x4x2x3xf32>)
|
||||
return
|
||||
|
@ -103,8 +103,35 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %fil
|
|||
|
||||
// -----
|
||||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x2x3x2x3xf32>) {
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
{ dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
|
||||
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
|
||||
outs(%output : memref<2x2x3x2x3xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5 * 2, d2 + d6 * 2, d3)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
|
||||
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
|
||||
|
||||
// CHECK: func @depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
|
||||
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
|
||||
// CHECK-SAME: outs(%{{.+}} : memref<2x2x3x2x3xf32>)
|
||||
|
||||
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
|
||||
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
|
||||
// CHECK-NEXT: %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_hwc(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
|
||||
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
outs(%output: memref<1x56x56x96xf32>)
|
||||
return
|
||||
|
|
|
@ -6,11 +6,11 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32
|
|||
%init = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
|
||||
%fill = linalg.fill(%init, %zero) : tensor<2x3x4x2x3xf32>, f32 -> tensor<2x3x4x2x3xf32>
|
||||
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
// CHECK-SAME: {strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
|
||||
// CHECK-SAME: outs(%{{.+}} : tensor<2x3x4x2x3xf32>)
|
||||
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
{ strides = dense<1> : tensor<2xi64> }
|
||||
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
|
||||
ins(%input, %filter : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
|
||||
outs(%fill : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
|
||||
return %0 : tensor<2x3x4x2x3xf32>
|
||||
|
@ -19,11 +19,11 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32
|
|||
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref
|
||||
func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
|
||||
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
// CHECK-SAME: {strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
|
||||
// CHECK-SAME: outs(%{{.+}} : memref<2x3x4x2x3xf32>)
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
{ strides = dense<1> : tensor<2xi64> }
|
||||
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
|
||||
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
|
||||
outs(%output : memref<2x3x4x2x3xf32>)
|
||||
return
|
||||
|
@ -33,10 +33,10 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32
|
|||
func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
|
||||
%init = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
|
||||
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwc
|
||||
// CHECK-SAME: {strides = dense<2> : vector<2xi64>}
|
||||
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
|
||||
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
|
||||
// CHECK-SAME: outs(%{{.+}} : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
|
||||
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
|
||||
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
|
||||
ins(%input, %filter: tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
|
||||
outs(%init: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
|
||||
return %0: tensor<1x56x56x96xf32>
|
||||
|
@ -45,20 +45,58 @@ func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor(%input: tensor<1x113x113x96
|
|||
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwc_memref
|
||||
func @depthwise_conv_2d_input_nhwc_filter_hwc_memref(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc
|
||||
// CHECK-SAME: {strides = dense<2> : vector<2xi64>}
|
||||
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
|
||||
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
// CHECK-SAME: outs(%{{.+}} : memref<1x56x56x96xf32>)
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
|
||||
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
outs(%output: memref<1x56x56x96xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor_dilated(%input: tensor<2x8x9x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x6x7x2x3xf32> {
|
||||
%zero = constant 0.000000e+00 : f32
|
||||
%init = linalg.init_tensor [2, 6, 7, 2, 3] : tensor<2x6x7x2x3xf32>
|
||||
%fill = linalg.fill(%init, %zero) : tensor<2x6x7x2x3xf32>, f32 -> tensor<2x6x7x2x3xf32>
|
||||
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
// CHECK-SAME: {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
|
||||
// CHECK-SAME: outs(%{{.+}} : tensor<2x6x7x2x3xf32>)
|
||||
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
{ dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
|
||||
ins(%input, %filter : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
|
||||
outs(%fill : tensor<2x6x7x2x3xf32>) -> tensor<2x6x7x2x3xf32>
|
||||
return %0 : tensor<2x6x7x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref_dilated
|
||||
func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref_dilated(%input: memref<2x8x9x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x6x7x2x3xf32>) {
|
||||
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
// CHECK-SAME: {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x8x9x2xf32>, memref<2x2x2x3xf32>)
|
||||
// CHECK-SAME: outs(%{{.+}} : memref<2x6x7x2x3xf32>)
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
|
||||
{ dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
|
||||
ins(%input, %filter : memref<2x8x9x2xf32>, memref<2x2x2x3xf32>)
|
||||
outs(%output : memref<2x6x7x2x3xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
// expected-error @+1 {{missing indexing map required attribute 'strides'}}
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>}
|
||||
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
outs(%output: memref<1x56x56x96xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
// expected-error @+1 {{missing indexing map required attribute 'dilations'}}
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<1> : vector<2xi64>}
|
||||
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
outs(%output: memref<1x56x56x96xf32>)
|
||||
return
|
||||
|
@ -68,7 +106,7 @@ func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x11
|
|||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
// expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}}
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2.0> : vector<2xf32>}
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
|
||||
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
outs(%output: memref<1x56x56x96xf32>)
|
||||
return
|
||||
|
@ -78,7 +116,7 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memr
|
|||
|
||||
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
|
||||
// expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}}
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<3xi64> }
|
||||
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
|
||||
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
|
||||
outs(%output: memref<1x56x56x96xf32>)
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue