forked from OSchip/llvm-project
76 lines
3.2 KiB
MLIR
76 lines
3.2 KiB
MLIR
// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
|
|
|
|
// CHECK-LABEL: @conv_2d_nhwc_hwcf
|
|
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
|
|
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
|
|
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
|
|
func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
|
|
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
|
|
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
|
|
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
|
|
// CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
|
|
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
|
|
%0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
|
|
strides = dense<1> : tensor<2xi64>}
|
|
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
|
|
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
|
|
// CHECK: return %[[RES]]
|
|
return %0 : tensor<?x1x?x?xf32>
|
|
}
|
|
|
|
transform.with_pdl_patterns {
|
|
^bb0(%arg0: !pdl.operation):
|
|
pdl.pattern @pdl_target : benefit(1) {
|
|
%args = operands
|
|
%results = types
|
|
%0 = pdl.operation "linalg.conv_2d_nhwc_hwcf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
|
|
// TODO: we don't want this, but it is the required terminator for pdl.pattern
|
|
rewrite %0 with "transform.dialect"
|
|
}
|
|
|
|
transform.sequence %arg0 {
|
|
^bb1(%arg1: !pdl.operation):
|
|
%0 = pdl_match @pdl_target in %arg1
|
|
%1 = transform.structured.decompose %0
|
|
}
|
|
}
|
|
|
|
// -----
|
|
|
|
// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
|
|
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
|
|
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
|
|
func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
|
|
// CHECK: %[[RES:.+]] = linalg.init_tensor
|
|
%init = linalg.init_tensor [1, 1, 56, 96] : tensor<1x1x56x96xf32>
|
|
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
|
|
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
|
|
// CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
|
|
// CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
|
|
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
|
|
// CHECK-SAME: outs(%[[SLICERES]]
|
|
// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
|
|
%0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
|
|
ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
|
|
outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
|
|
// CHECK: %[[INSERTED]]
|
|
return %0: tensor<1x1x56x96xf32>
|
|
}
|
|
|
|
transform.with_pdl_patterns {
|
|
^bb0(%arg0: !pdl.operation):
|
|
pdl.pattern @pdl_target : benefit(1) {
|
|
%args = operands
|
|
%results = types
|
|
%0 = pdl.operation "linalg.depthwise_conv_2d_nhwc_hwc"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
|
|
// TODO: we don't want this, but it is the required terminator for pdl.pattern
|
|
rewrite %0 with "transform.dialect"
|
|
}
|
|
|
|
transform.sequence %arg0 {
|
|
^bb1(%arg1: !pdl.operation):
|
|
%0 = pdl_match @pdl_target in %arg1
|
|
%1 = transform.structured.decompose %0
|
|
}
|
|
}
|