[mlir][linalg] Add attribute matcher to structured.match transform op

This is useful for building small test cases and will be utilized in a subsequent commit that adds a fusion example.

Differential Revision: https://reviews.llvm.org/D130344
This commit is contained in:
Matthias Springer 2022-07-22 13:26:48 +02:00
parent bc882ed21f
commit 32c6e0815a
3 changed files with 52 additions and 80 deletions

View File

@ -193,10 +193,14 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
The following constraints are supported:
- interface: an optional MatchInterfaceEnum specifying an enum
representation for an interface to target.
- ops: an optional StrArrayAttr specifying the concrete name of an op.
representation for an interface to target.
- ops: an optional StrArrayAttr specifying the concrete name of an op.
Multiple names can be specified. Matched ops must have one of specified
names.
- attribute: an optional Str specifying the name of an attribute that
matched ops must have.
Note: either `ops` or `interface` must be specified.
Note: Only ops that satisfy all specified constraints are matched.
TODO: Extend with regions to allow a limited form of constraints.
@ -214,12 +218,17 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
let arguments = (ins PDL_Operation:$target,
OptionalAttr<StrArrayAttr>:$ops,
OptionalAttr<MatchInterfaceEnum>:$interface);
OptionalAttr<MatchInterfaceEnum>:$interface,
OptionalAttr<StrAttr>:$attribute);
// TODO: variadic results when needed.
let results = (outs PDL_Operation:$results);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let assemblyFormat = [{
(`ops` `{` $ops^ `}`)?
(`interface` `{` $interface^ `}`)?
(`attribute` `{` $attribute^ `}`)?
`in` $target attr-dict
}];
}
def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",

View File

@ -430,15 +430,6 @@ LogicalResult transform::InterchangeOp::verify() {
// MatchOp
//===---------------------------------------------------------------------===//
LogicalResult transform::MatchOp::verify() {
bool opXorIface = getOps().hasValue() ^ getInterface().hasValue();
if (!opXorIface)
return this->emitOpError(
"requires a either a match_op or a match_interface attribute (but not "
"both)");
return success();
}
DiagnosedSilenceableFailure
transform::MatchOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
@ -453,21 +444,28 @@ transform::MatchOp::apply(transform::TransformResults &results,
this->emitOpError("requires exactly one target handle"));
SmallVector<Operation *> res;
auto matchFun = [&](Operation *op) {
if (strs.contains(op->getName().getStringRef()))
res.push_back(op);
if (getOps().hasValue() && !strs.contains(op->getName().getStringRef()))
return WalkResult::advance();
// Interfaces cannot be matched by name, just by ID.
// So we specifically encode the interfaces we care about for this op.
if (getInterface().hasValue()) {
auto iface = getInterface().getValue();
if (iface == transform::MatchInterfaceEnum::LinalgOp &&
isa<linalg::LinalgOp>(op))
res.push_back(op);
!isa<linalg::LinalgOp>(op))
return WalkResult::advance();
if (iface == transform::MatchInterfaceEnum::TilingInterface &&
isa<TilingInterface>(op))
res.push_back(op);
return WalkResult::advance();
}
if (getAttribute().hasValue() && !op->hasAttr(getAttribute().getValue()))
return WalkResult::advance();
// All constraints are satisfied.
res.push_back(op);
return WalkResult::advance();
};
payloadOps.front()->walk(matchFun);
@ -475,65 +473,6 @@ transform::MatchOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure(success());
}
ParseResult transform::MatchOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse 'match_op' or 'interface' clause.
if (succeeded(parser.parseOptionalKeyword("ops"))) {
ArrayAttr opsAttr;
if (parser.parseLBrace() ||
parser.parseCustomAttributeWithFallback(
opsAttr, parser.getBuilder().getType<NoneType>(), "ops",
result.attributes) ||
parser.parseRBrace())
return failure();
} else if (succeeded(parser.parseOptionalKeyword("interface"))) {
if (parser.parseLBrace())
return failure();
StringRef attrStr;
auto loc = parser.getCurrentLocation();
if (parser.parseKeyword(&attrStr))
return failure();
auto interfaceEnum = transform::symbolizeMatchInterfaceEnum(attrStr);
if (!interfaceEnum)
return parser.emitError(loc, "invalid ")
<< "match_interface attribute specification: \"" << attrStr << '"';
transform::MatchInterfaceEnumAttr match_interfaceAttr =
transform::MatchInterfaceEnumAttr::get(parser.getBuilder().getContext(),
interfaceEnum.value());
result.addAttribute("interface", match_interfaceAttr);
if (parser.parseRBrace())
return failure();
} else {
auto loc = parser.getCurrentLocation();
return parser.emitError(loc, "expected ops or interface");
}
OpAsmParser::UnresolvedOperand targetRawOperands[1];
ArrayRef<OpAsmParser::UnresolvedOperand> targetOperands(targetRawOperands);
if (parser.parseKeyword("in") || parser.parseOperand(targetRawOperands[0]) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
result.addTypes(pdlOpType);
if (parser.resolveOperands(targetOperands, pdlOpType, result.operands))
return failure();
return success();
}
void transform::MatchOp::print(OpAsmPrinter &p) {
if ((*this)->getAttr("ops")) {
p << " ops{";
p.printAttributeWithoutType(getOpsAttr());
p << "}";
}
if ((*this)->getAttr("interface")) {
p << " interface{" << stringifyMatchInterfaceEnum(*getInterface()) << "}";
}
p << " in " << getTarget();
p.printOptionalAttrDict((*this)->getAttrs(),
/*elidedAttrs=*/{"ops", "interface"});
}
//===---------------------------------------------------------------------===//
// MultiTileSizesOp
//===---------------------------------------------------------------------===//

View File

@ -0,0 +1,24 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
func.func @bar() {
// expected-remark @below {{matched op name}}
// expected-remark @below {{matched attr name}}
%0 = arith.constant {my_attr} 0: i32
// expected-remark @below {{matched op name}}
%1 = arith.constant 1 : i32
return
}
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%match_name = transform.structured.match ops{["arith.constant"]} in %arg1
transform.test_print_remark_at_operand %match_name, "matched op name"
transform.test_consume_operand %match_name
%match_attr = transform.structured.match ops{["arith.constant"]} attribute{"my_attr"} in %arg1
transform.test_print_remark_at_operand %match_attr, "matched attr name"
transform.test_consume_operand %match_attr
}
}