[mlir][Linalg] Add a Transform dialect NavigationOp op to match a list of ops or an interface.

This operation is a NavigationOp that simplifies the writing of transform IR.
Since there is no way of refering to an interface by name, the current implementation uses
an EnumAttr and depends on the interfaces it supports.
In the future, it would be worthwhile to remove this dependence and generalize.

Differential Revision: https://reviews.llvm.org/D130267
This commit is contained in:
Nicolas Vasilache 2022-07-21 06:44:43 -07:00
parent 4b9dbbdb09
commit 1f77f01c65
24 changed files with 224 additions and 427 deletions

View File

@ -1,6 +1,8 @@
set(LLVM_TARGET_DEFINITIONS LinalgTransformOps.td)
mlir_tablegen(LinalgTransformOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgTransformOps.cpp.inc -gen-op-defs)
mlir_tablegen(LinalgTransformOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LinalgTransformOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRLinalgTransformOpsIncGen)
add_mlir_doc(LinalgTransformOps LinalgStructuredTransformOps Dialects/ -gen-op-doc)

View File

@ -25,6 +25,8 @@ class LinalgOp;
// Linalg Transform Operations
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc"

View File

@ -14,6 +14,7 @@ include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
@ -127,6 +128,52 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
}];
}
def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to match",
[
I32EnumAttrCase<"LinalgOp", 0>,
I32EnumAttrCase<"TilingInterface", 1>
]>{
let cppNamespace = "mlir::transform";
}
def MatchOp : Op<Transform_Dialect, "structured.match",
[MemoryEffectsOpInterface,
NavigationTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let description = [{
Match op with the specified constraints, within the target op.
The following constraints are supported:
- interface: an optional MatchInterfaceEnum specifying an enum
representation for an interface to target.
- ops: an optional StrArrayAttr specifying the concrete name of an op.
Note: either `ops` or `interface` must be specified.
TODO: Extend with regions to allow a limited form of constraints.
#### Return modes
This op traverses the ops nested under `target` and returns the handles to
all the operations that match the requirements.
This op fails if the target is not a handle to exactly one operation.
Otherwise it succeeds.
This operation does not consume the target handle and produces new handles:
it is a navigation op.
}];
let arguments = (ins PDL_Operation:$target,
OptionalAttr<StrArrayAttr>:$ops,
OptionalAttr<MatchInterfaceEnum>:$interface);
// TODO: variadic results when needed.
let results = (outs PDL_Operation:$results);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface, TransformEachOpTrait]> {

View File

@ -19,6 +19,7 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
using namespace mlir::linalg;
@ -271,6 +272,114 @@ LogicalResult transform::InterchangeOp::verify() {
return success();
}
//===---------------------------------------------------------------------===//
// MatchOp
//===---------------------------------------------------------------------===//
LogicalResult transform::MatchOp::verify() {
bool opXorIface = getOps().hasValue() ^ getInterface().hasValue();
if (!opXorIface)
return this->emitOpError(
"requires a either a match_op or a match_interface attribute (but not "
"both)");
return success();
}
DiagnosedSilenceableFailure
transform::MatchOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
llvm::StringSet<> strs;
if (getOps().hasValue())
strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
getOps()->getAsValueRange<StringAttr>().end());
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
if (payloadOps.size() != 1)
return DiagnosedSilenceableFailure(
this->emitOpError("requires exactly one target handle"));
SmallVector<Operation *> res;
auto matchFun = [&](Operation *op) {
if (strs.contains(op->getName().getStringRef()))
res.push_back(op);
// Interfaces cannot be matched by name, just by ID.
// So we specifically encode the interfaces we care about for this op.
if (getInterface().hasValue()) {
auto iface = getInterface().getValue();
if (iface == transform::MatchInterfaceEnum::LinalgOp &&
isa<linalg::LinalgOp>(op))
res.push_back(op);
if (iface == transform::MatchInterfaceEnum::TilingInterface &&
isa<TilingInterface>(op))
res.push_back(op);
}
};
payloadOps.front()->walk(matchFun);
results.set(getResult().cast<OpResult>(), res);
return DiagnosedSilenceableFailure(success());
}
ParseResult transform::MatchOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse 'match_op' or 'interface' clause.
if (succeeded(parser.parseOptionalKeyword("ops"))) {
ArrayAttr opsAttr;
if (parser.parseLBrace() ||
parser.parseCustomAttributeWithFallback(
opsAttr, parser.getBuilder().getType<NoneType>(), "ops",
result.attributes) ||
parser.parseRBrace())
return failure();
} else if (succeeded(parser.parseOptionalKeyword("interface"))) {
if (parser.parseLBrace())
return failure();
StringRef attrStr;
auto loc = parser.getCurrentLocation();
if (parser.parseKeyword(&attrStr))
return failure();
auto interfaceEnum = transform::symbolizeMatchInterfaceEnum(attrStr);
if (!interfaceEnum)
return parser.emitError(loc, "invalid ")
<< "match_interface attribute specification: \"" << attrStr << '"';
transform::MatchInterfaceEnumAttr match_interfaceAttr =
transform::MatchInterfaceEnumAttr::get(parser.getBuilder().getContext(),
interfaceEnum.value());
result.addAttribute("interface", match_interfaceAttr);
if (parser.parseRBrace())
return failure();
} else {
auto loc = parser.getCurrentLocation();
return parser.emitError(loc, "expected ops or interface");
}
OpAsmParser::UnresolvedOperand targetRawOperands[1];
ArrayRef<OpAsmParser::UnresolvedOperand> targetOperands(targetRawOperands);
if (parser.parseKeyword("in") || parser.parseOperand(targetRawOperands[0]) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
result.addTypes(pdlOpType);
if (parser.resolveOperands(targetOperands, pdlOpType, result.operands))
return failure();
return success();
}
void transform::MatchOp::print(OpAsmPrinter &p) {
if ((*this)->getAttr("ops")) {
p << " ops{";
p.printAttributeWithoutType(getOpsAttr());
p << "}";
}
if ((*this)->getAttr("interface")) {
p << " interface{" << stringifyMatchInterfaceEnum(*getInterface()) << "}";
}
p << " in " << getTarget();
p.printOptionalAttrDict((*this)->getAttrs(),
/*elidedAttrs=*/{"ops", "interface"});
}
//===---------------------------------------------------------------------===//
// MultiTileSizesOp
//===---------------------------------------------------------------------===//
@ -873,6 +982,8 @@ public:
};
} // namespace
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"

View File

@ -6,15 +6,10 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["func.func"]} in %arg1
transform.bufferization.one_shot_bufferize %0
{target_is_module = false}
}
pdl.pattern @pdl_target : benefit(1) {
%0 = operation "func.func"
rewrite %0 with "transform.dialect"
}
}
// CHECK-LABEL: func @test_function(
@ -43,15 +38,10 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["func.func"]} in %arg1
transform.bufferization.one_shot_bufferize %0
{target_is_module = false, test_analysis_only = true}
}
pdl.pattern @pdl_target : benefit(1) {
%0 = operation "func.func"
rewrite %0 with "transform.dialect"
}
}
// CHECK-LABEL: func @test_function_analysis(
@ -74,15 +64,10 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["func.func"]} in %arg1
// expected-error @+1 {{bufferization failed}}
transform.bufferization.one_shot_bufferize %0 {target_is_module = false}
}
pdl.pattern @pdl_target : benefit(1) {
%0 = operation "func.func"
rewrite %0 with "transform.dialect"
}
}
func.func @test_unknown_op_failure() -> (tensor<?xf32>) {

View File

@ -2,17 +2,10 @@
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @linalg_generic : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
// This implements a 2D multisize tiling with target sizes [3, 10].
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @linalg_generic in %arg1
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3}
%t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10}
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 }

View File

@ -72,16 +72,9 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1 = transform.structured.promote %0 { use_alloca }
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}
// -----
@ -152,16 +145,9 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1 = transform.structured.promote %0
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}
@ -212,14 +198,7 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match interface{LinalgOp} in %arg1
%1 = transform.structured.promote %0
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.generic"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}

View File

@ -35,15 +35,8 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1, %loops:3 = transform.structured.tile %0 [16, 16, 16]
%2 = transform.structured.promote %1 { operands_to_promote = [0, 2], force_full_tiles = [false, false] }
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}

View File

@ -34,15 +34,9 @@ module {
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_linalg_matmul : benefit(1) {
%0 = operands
%1 = types
%2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
rewrite %2 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_linalg_matmul in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1:2 = transform.structured.tile_to_foreach_thread_op %0 [10, 20] (mapped to dims [1, 0])
}
}

View File

@ -18,25 +18,6 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x
return %0 : tensor<?x1x?x?xf32>
}
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.conv_2d_nhwc_hwcf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%1 = transform.structured.decompose %0
}
}
// -----
// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
@ -59,17 +40,9 @@ func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: t
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.depthwise_conv_2d_nhwc_hwc"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match interface{LinalgOp} in %arg1
%1 = transform.structured.decompose %0
}
}

View File

@ -16,17 +16,9 @@ func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
}
}
@ -53,17 +45,9 @@ func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
transform.loop.peel %loops#0
}
@ -105,17 +89,9 @@ func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf3
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.generic"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [5, 4, 7], tile_interchange = [0, 2, 1]}
}
}

View File

@ -12,17 +12,9 @@ func.func @generalize_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> t
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1
%1 = transform.structured.generalize %0
}
}

View File

@ -20,16 +20,9 @@ func.func @interchange_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_generic : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.generic"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_generic in %arg1
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
transform.structured.interchange %0 { iterator_interchange = [1, 0]}
}
}
@ -44,16 +37,9 @@ func.func @interchange_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %a
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_generic : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_generic in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// expected-error @below {{transform applied to the wrong op kind}}
transform.structured.interchange %0 { iterator_interchange = [1, 0]}
}

View File

@ -6,16 +6,9 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 }
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}
// CHECK-LABEL: @multitile_sizes_static
@ -40,16 +33,9 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 }
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}
// CHECK: #[[$MAP_A:.+]] = affine_map<()[s0] -> ([[A_IMPL:s0 floordiv 2]])>

View File

@ -33,17 +33,9 @@ func.func @static_sizes_output_divisible(%arg0: tensor<24x12xf32>,
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
}
}
@ -60,17 +52,9 @@ func.func @pad(%arg0: tensor<24x12xf32>,
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// expected-error @below {{op expects a padding value of type 'f32', got 0 : i32}}
%1 = transform.structured.pad %0 {padding_values=[0: i32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
}
@ -88,17 +72,9 @@ func.func @pad(%arg0: tensor<24x12xf32>,
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// expected-error @below {{expects a padding that parses to 'f32', got "foo"}}
%1 = transform.structured.pad %0 {padding_values=["foo", 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
}
@ -117,17 +93,9 @@ func.func @pad(%arg0: tensor<24x12xf32>,
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// This error is silenceable and is not reported by this transform
// {{transform.structured.pad failed to apply}}
%1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}

View File

@ -12,17 +12,9 @@ func.func @scalarize(%arg0: tensor<24x12xf32>,
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1, %loops = transform.structured.tile %0 [10, 0, 0]
%2 = transform.structured.scalarize %1
}

View File

@ -20,17 +20,9 @@ func.func @matmul_split(%A : tensor<?x256xf32>, %B: tensor<256x32xf32>, %C: tens
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1:4 = transform.structured.split_reduction %0
{ split_factor = 4, insert_split_dimension = 2, use_scaling_algorithm, use_alloc}
}

View File

@ -19,17 +19,9 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: ten
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
}
}

View File

@ -3,16 +3,9 @@
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @linalg_generic : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.pdl_match @linalg_generic in %arg1
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1:2 = transform.structured.split %0 after 42 { dimension = 0 }
}
}
@ -108,23 +101,10 @@ func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @func_call : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
pdl.pattern @linalg_generic : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.pdl_match @linalg_generic in %arg1
%1 = transform.pdl_match @func_call in %arg1
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1 = transform.structured.match ops{["func.call"]} in %arg1
transform.structured.split %0 after %1 { dimension = 0 }
}
}
@ -171,16 +151,9 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @linalg_generic : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.pdl_match @linalg_generic in %arg1
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1:2 = transform.structured.split %0 after 4 { dimension = 0}
%2:2 = transform.structured.split %1#1 after 16 { dimension = 1 }
}
@ -244,23 +217,10 @@ transform.sequence {
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @func_call : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
pdl.pattern @linalg_generic : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.pdl_match @linalg_generic in %arg1
%1 = transform.pdl_match @func_call in %arg1
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1 = transform.structured.match ops{["func.call"]} in %arg1
// expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
transform.structured.split %0 after %1 { dimension = 0 }
}
@ -286,23 +246,10 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @func_call : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
pdl.pattern @linalg_generic : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.pdl_match @linalg_generic in %arg1
%1 = transform.pdl_match @func_call in %arg1
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1 = transform.structured.match ops{["func.call"]} in %arg1
// expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
transform.structured.split %0 after %1 { dimension = 0 }
}
@ -335,7 +282,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.pdl_match @func_return in %arg1
%0 = transform.structured.match ops{["func.return"]} in %arg1
// expected-error @below {{only applies to structured ops}}
transform.structured.split %0 after 16 { dimension = 1 }
}
@ -359,7 +306,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.pdl_match @linalg_generic in %arg1
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
// expected-error @below {{dimension 1 does not exist in target op}}
transform.structured.split %0 after 16 { dimension = 1 }
}

View File

@ -4,16 +4,9 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1, %loops:3 = transform.structured.tile %0 [4, 4, 4]
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}
// CHECK-LABEL: func @tile_linalg_matmul(
@ -50,23 +43,10 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%1 = pdl_match @func_call in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1 = transform.structured.match ops{["func.call"]} in %arg1
%2, %loops:3 = transform.structured.tile %0 [%1, %1, 4]
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
pdl.pattern @func_call : benefit(1) {
%args = operands
%results = types
%0 = operation "func.call"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}
func.func private @get_dynamic_tile_size() -> index

View File

@ -28,7 +28,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1 = get_closest_isolated_parent %0
%2 = transform.structured.vectorize %1
}
@ -75,17 +75,9 @@ func.func @vectorize_keep_pad(
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1 = get_closest_isolated_parent %0
%2 = transform.structured.vectorize %1
}
@ -134,17 +126,9 @@ func.func @vectorize_pad(
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1 = get_closest_isolated_parent %0
%2 = transform.structured.vectorize %1 {vectorize_padding = true}
}
@ -162,17 +146,9 @@ func.func @vectorize(%arg0: tensor<24x12xf32>,
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// expected-error @below {{op requires isolated-from-above targets}}
%2 = transform.structured.vectorize %0
}

View File

@ -66,16 +66,9 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1 = transform.structured.promote %0 { operands_to_promote = [0, 1, 2], use_full_tiles_by_default }
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}
// -----
@ -136,16 +129,9 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1 = transform.structured.promote %0 { operands_to_promote = [0], use_full_tiles_by_default }
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}
// -----
@ -176,16 +162,9 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1
%1 = transform.structured.promote %0 { operands_to_promote = [1], use_full_tile_buffers = [false, true], alignment = 32}
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.fill"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}
// -----
@ -217,14 +196,7 @@ transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1
%1 = transform.structured.promote %0 { operands_to_promote = [1], use_full_tile_buffers = [false, true], alignment = 32}
}
pdl.pattern @pdl_target : benefit(1) {
%args = operands
%results = types
%0 = operation "linalg.fill"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %0 with "transform.dialect"
}
}

View File

@ -17,16 +17,9 @@ func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) {
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_addi : benefit(1) {
%args = operands
%results = types
%op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %op with "transform.dialect"
}
sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_addi in %arg1
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
// CHECK: = transform.loop.get_parent_for
%1 = transform.loop.get_parent_for %0
%2 = transform.loop.get_parent_for %0 { num_loops = 2 }
@ -47,16 +40,9 @@ func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_addi : benefit(1) {
%args = operands
%results = types
%op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %op with "transform.dialect"
}
sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_addi in %arg1
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
// expected-error @below {{could not find an 'scf.for' parent}}
%1 = transform.loop.get_parent_for %0
}
@ -96,16 +82,9 @@ func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) {
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_addi : benefit(1) {
%args = operands
%results = types
%op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %op with "transform.dialect"
}
sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_addi in %arg1
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
%1 = transform.loop.get_parent_for %0
// CHECK: = transform.loop.outline %{{.*}}
transform.loop.outline %1 {func_name = "foo"}
@ -132,16 +111,9 @@ func.func @loop_outline_op_multi_region() {
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_while : benefit(1) {
%args = operands
%results = types
%op = operation "scf.while"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %op with "transform.dialect"
}
sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_while in %arg1
%0 = transform.structured.match ops{["scf.while"]} in %arg1
// expected-error @below {{failed to outline}}
transform.loop.outline %0 {func_name = "foo"}
}
@ -170,16 +142,9 @@ func.func @loop_peel_op() {
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_addi : benefit(1) {
%args = operands
%results = types
%op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %op with "transform.dialect"
}
sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_addi in %arg1
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
%1 = transform.loop.get_parent_for %0
transform.loop.peel %1
}
@ -213,16 +178,9 @@ func.func @loop_pipeline_op(%A: memref<?xf32>, %result: memref<?xf32>) {
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_addf : benefit(1) {
%args = operands
%results = types
%op = operation "arith.addf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %op with "transform.dialect"
}
sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_addf in %arg1
%0 = transform.structured.match ops{["arith.addf"]} in %arg1
%1 = transform.loop.get_parent_for %0
%2 = transform.loop.pipeline %1
// Verify that the returned handle is usable.
@ -247,16 +205,9 @@ func.func @loop_unroll_op() {
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @match_addi : benefit(1) {
%args = operands
%results = types
%op = operation "arith.addi"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
rewrite %op with "transform.dialect"
}
sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_addi in %arg1
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
%1 = transform.loop.get_parent_for %0
transform.loop.unroll %1 { factor = 4 }
}

View File

@ -7127,6 +7127,14 @@ gentbl_cc_library(
["-gen-op-defs"],
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc",
),
(
["-gen-enum-decls"],
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.h.inc",
),
(
["-gen-enum-defs"],
"include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td",