forked from OSchip/llvm-project
[mlir][linalg] Add attribute matcher to structured.match transform op
This is useful for building small test cases and will be utilized in a subsequent commit that adds a fusion example. Differential Revision: https://reviews.llvm.org/D130344
This commit is contained in:
parent
bc882ed21f
commit
32c6e0815a
|
@ -193,10 +193,14 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
|
|||
|
||||
The following constraints are supported:
|
||||
- interface: an optional MatchInterfaceEnum specifying an enum
|
||||
representation for an interface to target.
|
||||
- ops: an optional StrArrayAttr specifying the concrete name of an op.
|
||||
representation for an interface to target.
|
||||
- ops: an optional StrArrayAttr specifying the concrete name of an op.
|
||||
Multiple names can be specified. Matched ops must have one of specified
|
||||
names.
|
||||
- attribute: an optional Str specifying the name of an attribute that
|
||||
matched ops must have.
|
||||
|
||||
Note: either `ops` or `interface` must be specified.
|
||||
Note: Only ops that satisfy all specified constraints are matched.
|
||||
|
||||
TODO: Extend with regions to allow a limited form of constraints.
|
||||
|
||||
|
@ -214,12 +218,17 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
|
|||
|
||||
let arguments = (ins PDL_Operation:$target,
|
||||
OptionalAttr<StrArrayAttr>:$ops,
|
||||
OptionalAttr<MatchInterfaceEnum>:$interface);
|
||||
OptionalAttr<MatchInterfaceEnum>:$interface,
|
||||
OptionalAttr<StrAttr>:$attribute);
|
||||
// TODO: variadic results when needed.
|
||||
let results = (outs PDL_Operation:$results);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
(`ops` `{` $ops^ `}`)?
|
||||
(`interface` `{` $interface^ `}`)?
|
||||
(`attribute` `{` $attribute^ `}`)?
|
||||
`in` $target attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
|
||||
|
|
|
@ -430,15 +430,6 @@ LogicalResult transform::InterchangeOp::verify() {
|
|||
// MatchOp
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult transform::MatchOp::verify() {
|
||||
bool opXorIface = getOps().hasValue() ^ getInterface().hasValue();
|
||||
if (!opXorIface)
|
||||
return this->emitOpError(
|
||||
"requires a either a match_op or a match_interface attribute (but not "
|
||||
"both)");
|
||||
return success();
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::MatchOp::apply(transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
|
@ -453,21 +444,28 @@ transform::MatchOp::apply(transform::TransformResults &results,
|
|||
this->emitOpError("requires exactly one target handle"));
|
||||
|
||||
SmallVector<Operation *> res;
|
||||
|
||||
auto matchFun = [&](Operation *op) {
|
||||
if (strs.contains(op->getName().getStringRef()))
|
||||
res.push_back(op);
|
||||
if (getOps().hasValue() && !strs.contains(op->getName().getStringRef()))
|
||||
return WalkResult::advance();
|
||||
|
||||
// Interfaces cannot be matched by name, just by ID.
|
||||
// So we specifically encode the interfaces we care about for this op.
|
||||
if (getInterface().hasValue()) {
|
||||
auto iface = getInterface().getValue();
|
||||
if (iface == transform::MatchInterfaceEnum::LinalgOp &&
|
||||
isa<linalg::LinalgOp>(op))
|
||||
res.push_back(op);
|
||||
!isa<linalg::LinalgOp>(op))
|
||||
return WalkResult::advance();
|
||||
if (iface == transform::MatchInterfaceEnum::TilingInterface &&
|
||||
isa<TilingInterface>(op))
|
||||
res.push_back(op);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
if (getAttribute().hasValue() && !op->hasAttr(getAttribute().getValue()))
|
||||
return WalkResult::advance();
|
||||
|
||||
// All constraints are satisfied.
|
||||
res.push_back(op);
|
||||
return WalkResult::advance();
|
||||
};
|
||||
|
||||
payloadOps.front()->walk(matchFun);
|
||||
|
@ -475,65 +473,6 @@ transform::MatchOp::apply(transform::TransformResults &results,
|
|||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
|
||||
ParseResult transform::MatchOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
// Parse 'match_op' or 'interface' clause.
|
||||
if (succeeded(parser.parseOptionalKeyword("ops"))) {
|
||||
ArrayAttr opsAttr;
|
||||
if (parser.parseLBrace() ||
|
||||
parser.parseCustomAttributeWithFallback(
|
||||
opsAttr, parser.getBuilder().getType<NoneType>(), "ops",
|
||||
result.attributes) ||
|
||||
parser.parseRBrace())
|
||||
return failure();
|
||||
} else if (succeeded(parser.parseOptionalKeyword("interface"))) {
|
||||
if (parser.parseLBrace())
|
||||
return failure();
|
||||
StringRef attrStr;
|
||||
auto loc = parser.getCurrentLocation();
|
||||
if (parser.parseKeyword(&attrStr))
|
||||
return failure();
|
||||
auto interfaceEnum = transform::symbolizeMatchInterfaceEnum(attrStr);
|
||||
if (!interfaceEnum)
|
||||
return parser.emitError(loc, "invalid ")
|
||||
<< "match_interface attribute specification: \"" << attrStr << '"';
|
||||
transform::MatchInterfaceEnumAttr match_interfaceAttr =
|
||||
transform::MatchInterfaceEnumAttr::get(parser.getBuilder().getContext(),
|
||||
interfaceEnum.value());
|
||||
result.addAttribute("interface", match_interfaceAttr);
|
||||
if (parser.parseRBrace())
|
||||
return failure();
|
||||
} else {
|
||||
auto loc = parser.getCurrentLocation();
|
||||
return parser.emitError(loc, "expected ops or interface");
|
||||
}
|
||||
|
||||
OpAsmParser::UnresolvedOperand targetRawOperands[1];
|
||||
ArrayRef<OpAsmParser::UnresolvedOperand> targetOperands(targetRawOperands);
|
||||
if (parser.parseKeyword("in") || parser.parseOperand(targetRawOperands[0]) ||
|
||||
parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
|
||||
result.addTypes(pdlOpType);
|
||||
if (parser.resolveOperands(targetOperands, pdlOpType, result.operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
void transform::MatchOp::print(OpAsmPrinter &p) {
|
||||
if ((*this)->getAttr("ops")) {
|
||||
p << " ops{";
|
||||
p.printAttributeWithoutType(getOpsAttr());
|
||||
p << "}";
|
||||
}
|
||||
if ((*this)->getAttr("interface")) {
|
||||
p << " interface{" << stringifyMatchInterfaceEnum(*getInterface()) << "}";
|
||||
}
|
||||
p << " in " << getTarget();
|
||||
p.printOptionalAttrDict((*this)->getAttrs(),
|
||||
/*elidedAttrs=*/{"ops", "interface"});
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// MultiTileSizesOp
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
|
||||
|
||||
func.func @bar() {
|
||||
// expected-remark @below {{matched op name}}
|
||||
// expected-remark @below {{matched attr name}}
|
||||
%0 = arith.constant {my_attr} 0: i32
|
||||
// expected-remark @below {{matched op name}}
|
||||
%1 = arith.constant 1 : i32
|
||||
return
|
||||
}
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%match_name = transform.structured.match ops{["arith.constant"]} in %arg1
|
||||
transform.test_print_remark_at_operand %match_name, "matched op name"
|
||||
transform.test_consume_operand %match_name
|
||||
|
||||
%match_attr = transform.structured.match ops{["arith.constant"]} attribute{"my_attr"} in %arg1
|
||||
transform.test_print_remark_at_operand %match_attr, "matched attr name"
|
||||
transform.test_consume_operand %match_attr
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue