Defines new PDLInterp operations needed for multi-root matching in PDL.

This is commit 1 of 4 for the multi-root matching in PDL, discussed in https://llvm.discourse.group/t/rfc-multi-root-pdl-patterns-for-kernel-matching/4148 (topic flagged for review).

These operations are:
* pdl.get_accepting_ops: Returns a list of operations accepting the given value or a range of values at the specified position. Thus if there are two operations `%op1 = "foo"(%val)` and `%op2 = "bar"(%val)` accepting a value at position 0, `%ops = pdl_interp.get_accepting_ops of %val : !pdl.value at 0` will return both of them. This allows us to traverse upwards from a value to operations accepting the value.
* pdl.choose_op: Iteratively chooses one operation from a range of operations. Therefore, writing `%op = pdl_interp.choose_op from %ops` in the example above will select either `%op1`or `%op2`.

Testing: Added the corresponding test cases to mlir/test/Dialect/PDLInterp/ops.mlir.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D108543
This commit is contained in:
Stanislav Funiak 2021-11-26 17:57:30 +05:30 committed by Uday Bondhugula
parent 632acec737
commit 842b6861c0
3 changed files with 259 additions and 0 deletions

View File

@ -370,6 +370,29 @@ def PDLInterp_CheckTypesOp
let assemblyFormat = "$value `are` $types attr-dict `->` successors";
}
//===----------------------------------------------------------------------===//
// pdl_interp::ContinueOp
//===----------------------------------------------------------------------===//
def PDLInterp_ContinueOp
: PDLInterp_Op<"continue", [NoSideEffect, HasParent<"ForEachOp">,
Terminator]> {
let summary = "Breaks the current iteration";
let description = [{
`pdl_interp.continue` operation breaks the current iteration within the
`pdl_interp.foreach` region and continues with the next iteration from
the beginning of the region.
Example:
```mlir
pdl_interp.continue
```
}];
let assemblyFormat = "attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::CreateAttributeOp
//===----------------------------------------------------------------------===//
@ -513,6 +536,42 @@ def PDLInterp_EraseOp : PDLInterp_Op<"erase"> {
let assemblyFormat = "$operation attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::ExtractOp
//===----------------------------------------------------------------------===//
def PDLInterp_ExtractOp
: PDLInterp_Op<"extract", [NoSideEffect,
TypesMatchWith<
"`range` is a PDL range whose element type matches type of `result`",
"result", "range", "pdl::RangeType::get($_self)">]> {
let summary = "Extract the item at the specified index in a range";
let description = [{
`pdl_interp.extract` operations are used to extract an item from a range
at the specified index. If the index is out of range, returns null.
Example:
```mlir
// Extract the value at index 1 from a range of values.
%ops = pdl_interp.extract 1 of %values : !pdl.value
```
}];
let arguments = (ins PDL_RangeOf<PDL_AnyType>:$range,
Confined<I32Attr, [IntNonNegative]>:$index);
let results = (outs PDL_AnyType:$result);
let assemblyFormat = "$index `of` $range `:` type($result) attr-dict";
let builders = [
OpBuilder<(ins "Value":$range, "unsigned":$index), [{
build($_builder, $_state,
range.getType().cast<pdl::RangeType>().getElementType(),
range, index);
}]>,
];
}
//===----------------------------------------------------------------------===//
// pdl_interp::FinalizeOp
//===----------------------------------------------------------------------===//
@ -533,6 +592,48 @@ def PDLInterp_FinalizeOp
let assemblyFormat = "attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::ForEachOp
//===----------------------------------------------------------------------===//
def PDLInterp_ForEachOp
: PDLInterp_Op<"foreach", [Terminator]> {
let summary = "Iterates over a range of values or ranges";
let description = [{
`pdl_interp.foreach` iteratively selects an element from a range of values
and executes the region until pdl.continue is reached.
In the bytecode interpreter, this operation is implemented by looping over
the values and, for each selection, running the bytecode until we reach
pdl.continue. This may result in multiple matches being reported. Note
that the input range is mutated (popped from).
Example:
```mlir
pdl_interp.foreach %op : !pdl.operation in %ops {
pdl_interp.continue
} -> ^next
```
}];
let arguments = (ins PDL_RangeOf<PDL_AnyType>:$values);
let regions = (region AnyRegion:$region);
let successors = (successor AnySuccessor:$successor);
let builders = [
OpBuilder<(ins "Value":$range, "Block *":$successor, "bool":$initLoop)>
];
let extraClassDeclaration = [{
/// Returns the loop variable.
BlockArgument getLoopVariable() { return region().getArgument(0); }
}];
let parser = [{ return ::parseForEachOp(parser, result); }];
let printer = [{ return ::print(p, *this); }];
let verifier = [{ return ::verify(*this); }];
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetAttributeOp
//===----------------------------------------------------------------------===//
@ -750,6 +851,42 @@ def PDLInterp_GetResultsOp : PDLInterp_Op<"get_results", [NoSideEffect]> {
];
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetUsersOp
//===----------------------------------------------------------------------===//
def PDLInterp_GetUsersOp
: PDLInterp_Op<"get_users", [NoSideEffect]> {
let summary = "Get the users of a `Value`";
let description = [{
`pdl_interp.get_users` extracts the users that accept this value. In the
case of a range, the union of users of the all the values are returned,
similarly to ResultRange::getUsers.
Example:
```mlir
// Get all the users of a single value.
%ops = pdl_interp.get_users of %value : !pdl.value
// Get all the users of the first value in a range.
%ops = pdl_interp.get_users of %values : !pdl.range<value>
```
}];
let arguments = (ins PDL_InstOrRangeOf<PDL_Value>:$value);
let results = (outs PDL_RangeOf<PDL_Operation>:$operations);
let assemblyFormat = "`of` $value `:` type($value) attr-dict";
let builders = [
OpBuilder<(ins "Value":$value), [{
build($_builder, $_state,
pdl::RangeType::get($_builder.getType<pdl::OperationType>()),
value);
}]>,
];
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetValueTypeOp
//===----------------------------------------------------------------------===//

View File

@ -65,6 +65,85 @@ static void printCreateOperationOpAttributes(OpAsmPrinter &p,
p << '}';
}
//===----------------------------------------------------------------------===//
// pdl_interp::ForEachOp
//===----------------------------------------------------------------------===//
void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
Value range, Block *successor, bool initLoop) {
build(builder, state, range, successor);
if (initLoop) {
// Create the block and the loop variable.
auto range_type = range.getType().cast<pdl::RangeType>();
state.regions.front()->emplaceBlock();
state.regions.front()->addArgument(range_type.getElementType());
}
}
static ParseResult parseForEachOp(OpAsmParser &parser, OperationState &result) {
// Parse the loop variable followed by type.
OpAsmParser::OperandType loopVariable;
Type loopVariableType;
if (parser.parseRegionArgument(loopVariable) ||
parser.parseColonType(loopVariableType))
return failure();
// Parse the "in" keyword.
if (parser.parseKeyword("in", " after loop variable"))
return failure();
// Parse the operand (value range).
OpAsmParser::OperandType operandInfo;
if (parser.parseOperand(operandInfo))
return failure();
// Resolve the operand.
Type rangeType = pdl::RangeType::get(loopVariableType);
if (parser.resolveOperand(operandInfo, rangeType, result.operands))
return failure();
// Parse the body region.
Region *body = result.addRegion();
if (parser.parseRegion(*body, {loopVariable}, {loopVariableType}))
return failure();
// Parse the attribute dictionary.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
// Parse the successor.
Block *successor;
if (parser.parseArrow() || parser.parseSuccessor(successor))
return failure();
result.addSuccessors(successor);
return success();
}
static void print(OpAsmPrinter &p, ForEachOp op) {
BlockArgument arg = op.getLoopVariable();
p << ' ' << arg << " : " << arg.getType() << " in " << op.values();
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(op->getAttrs());
p << " -> ";
p.printSuccessor(op.successor());
}
static LogicalResult verify(ForEachOp op) {
// Verify that the operation has exactly one argument.
if (op.region().getNumArguments() != 1)
return op.emitOpError("requires exactly one argument");
// Verify that the loop variable and the operand (value range)
// have compatible types.
BlockArgument arg = op.getLoopVariable();
Type rangeType = pdl::RangeType::get(arg.getType());
if (rangeType != op.values().getType())
return op.emitOpError("operand must be a range of loop variable type");
return success();
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetValueTypeOp
//===----------------------------------------------------------------------===//

View File

@ -23,3 +23,46 @@ func @operations(%attribute: !pdl.attribute,
pdl_interp.finalize
}
// -----
func @extract(%attrs : !pdl.range<attribute>, %ops : !pdl.range<operation>, %types : !pdl.range<type>, %vals: !pdl.range<value>) {
// attribute at index 0
%attr = pdl_interp.extract 0 of %attrs : !pdl.attribute
// operation at index 1
%op = pdl_interp.extract 1 of %ops : !pdl.operation
// type at index 2
%type = pdl_interp.extract 2 of %types : !pdl.type
// value at index 3
%val = pdl_interp.extract 3 of %vals : !pdl.value
pdl_interp.finalize
}
// -----
func @foreach(%ops: !pdl.range<operation>) {
// iterate over a range of operations
pdl_interp.foreach %op : !pdl.operation in %ops {
%val = pdl_interp.get_result 0 of %op
pdl_interp.continue
} -> ^end
^end:
pdl_interp.finalize
}
// -----
func @users(%value: !pdl.value, %values: !pdl.range<value>) {
// all the users of a single value
%ops1 = pdl_interp.get_users of %value : !pdl.value
// all the users of all the values in a range
%ops2 = pdl_interp.get_users of %values : !pdl.range<value>
pdl_interp.finalize
}