llvm-project/mlir/test/Dialect/Linalg/transform-op-decompose.mlir

76 lines
3.2 KiB
MLIR

// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
// Exercises transform.structured.decompose on a 2-D NHWC/HWCF convolution whose
// second input/output dimension and first filter dimension are statically 1.
// The expected output contains a linalg.conv_1d_nwc_wcf preceded by slices of
// the three arguments, with its result inserted back into the third argument.
// CHECK-LABEL: @conv_2d_nhwc_hwcf
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @conv_2d_nhwc_hwcf(
    %input: tensor<?x1x?x?xf32>,
    %filter: tensor<1x?x?x?xf32>,
    %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
  // One slice is expected per function argument, in argument order.
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
  %conv = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
      ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
      outs(%init : tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
  // The function must return the re-inserted tensor.
  // CHECK: return %[[RES]]
  return %conv : tensor<?x1x?x?xf32>
}
// Transform script for the test case above: a PDL pattern names
// linalg.conv_2d_nhwc_hwcf operations (any operand and result-type ranges) and
// the sequence passes what pdl_match returns to transform.structured.decompose.
transform.with_pdl_patterns {
^bb0(%root: !pdl.operation):
  pdl.pattern @pdl_target : benefit(1) {
    %operand_range = operands
    %type_range = types
    %conv = pdl.operation "linalg.conv_2d_nhwc_hwcf"(%operand_range : !pdl.range<value>) -> (%type_range : !pdl.range<type>)
    // TODO: this rewrite is unwanted; it is present only because pdl.pattern
    // requires it as its terminator.
    rewrite %conv with "transform.dialect"
  }

  transform.sequence %root {
  ^bb1(%scope: !pdl.operation):
    %matched = pdl_match @pdl_target in %scope
    %decomposed = transform.structured.decompose %matched
  }
}
// -----
// Exercises transform.structured.decompose on a statically shaped 2-D depthwise
// convolution (stride 2) whose second input/output dimension and first filter
// dimension are 1. The expected output contains a
// linalg.depthwise_conv_1d_nwc_wc operating on slices of the input, the filter
// and the init tensor, with its result inserted back into the init tensor and
// returned.
// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
func.func @depthwise_conv_2d_nhwc_hwc(
    %input: tensor<1x1x113x96xf32>,
    %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
  // CHECK: %[[RES:.+]] = linalg.init_tensor
  %init = linalg.init_tensor [1, 1, 56, 96] : tensor<1x1x56x96xf32>
  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
  // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
  // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
  // CHECK-SAME: outs(%[[SLICERES]]
  // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
  %conv = linalg.depthwise_conv_2d_nhwc_hwc
      {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
      ins(%input, %filter : tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
      outs(%init : tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
  // Require the inserted tensor to be the returned value, not merely used
  // somewhere later in the output.
  // CHECK: return %[[INSERTED]]
  return %conv : tensor<1x1x56x96xf32>
}
// Transform script for the test case above: a PDL pattern names
// linalg.depthwise_conv_2d_nhwc_hwc operations (any operand and result-type
// ranges) and the sequence passes what pdl_match returns to
// transform.structured.decompose.
transform.with_pdl_patterns {
^bb0(%root: !pdl.operation):
  pdl.pattern @pdl_target : benefit(1) {
    %operand_range = operands
    %type_range = types
    %conv = pdl.operation "linalg.depthwise_conv_2d_nhwc_hwc"(%operand_range : !pdl.range<value>) -> (%type_range : !pdl.range<type>)
    // TODO: this rewrite is unwanted; it is present only because pdl.pattern
    // requires it as its terminator.
    rewrite %conv with "transform.dialect"
  }

  transform.sequence %root {
  ^bb1(%scope: !pdl.operation):
    %matched = pdl_match @pdl_target in %scope
    %decomposed = transform.structured.decompose %matched
  }
}