[mlir][PDL] Add a PDL Interpreter Dialect

The PDL Interpreter dialect provides a lower level abstraction compared to the PDL dialect, and is targeted towards low level optimization and interpreter code generation. The dialect operations encapsulates low-level pattern match and rewrite "primitives", such as navigating the IR (Operation::getOperand), creating new operations (OpBuilder::create), etc. Many of the operations within this dialect also fuse branching control flow with some form of a predicate comparison operation. This type of fusion reduces the amount of work that an interpreter must do when executing.

An example of this representation is shown below:

```mlir
// The following high level PDL pattern:
pdl.pattern : benefit(1) {
  %resultType = pdl.type
  %inputOperand = pdl.input
  %root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType
  pdl.rewrite %root {
    pdl.replace %root with (%inputOperand)
  }
}

// May be represented in the interpreter dialect as follows:
module {
  func @matcher(%arg0: !pdl.operation) {
    pdl_interp.check_operation_name of %arg0 is "foo.op" -> ^bb2, ^bb1
  ^bb1:
    pdl_interp.return
  ^bb2:
    pdl_interp.check_operand_count of %arg0 is 1 -> ^bb3, ^bb1
  ^bb3:
    pdl_interp.check_result_count of %arg0 is 1 -> ^bb4, ^bb1
  ^bb4:
    %0 = pdl_interp.get_operand 0 of %arg0
    pdl_interp.is_not_null %0 : !pdl.value -> ^bb5, ^bb1
  ^bb5:
    %1 = pdl_interp.get_result 0 of %arg0
    pdl_interp.is_not_null %1 : !pdl.value -> ^bb6, ^bb1
  ^bb6:
    pdl_interp.record_match @rewriters::@rewriter(%0, %arg0 : !pdl.value, !pdl.operation) : benefit(1), loc([%arg0]), root("foo.op") -> ^bb1
  }
  module @rewriters {
    func @rewriter(%arg0: !pdl.value, %arg1: !pdl.operation) {
      pdl_interp.replace %arg1 with(%arg0)
      pdl_interp.return
    }
  }
}
```

Differential Revision: https://reviews.llvm.org/D84579
This commit is contained in:
River Riddle 2020-08-26 05:12:07 -07:00
parent 92c527e5a2
commit d289a97f91
25 changed files with 1387 additions and 69 deletions

View File

@ -6,6 +6,7 @@ add_subdirectory(LLVMIR)
add_subdirectory(OpenACC) add_subdirectory(OpenACC)
add_subdirectory(OpenMP) add_subdirectory(OpenMP)
add_subdirectory(PDL) add_subdirectory(PDL)
add_subdirectory(PDLInterp)
add_subdirectory(Quant) add_subdirectory(Quant)
add_subdirectory(SCF) add_subdirectory(SCF)
add_subdirectory(Shape) add_subdirectory(Shape)

View File

@ -49,7 +49,7 @@ def PDL_Dialect : Dialect {
%resultType = pdl.type %resultType = pdl.type
%inputOperand = pdl.input %inputOperand = pdl.input
%root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType %root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType
pdl.rewrite(%root) { pdl.rewrite %root {
pdl.replace %root with (%inputOperand) pdl.replace %root with (%inputOperand)
} }
} }

View File

@ -51,17 +51,18 @@ def PDL_ApplyConstraintOp
``` ```
}]; }];
let arguments = (ins Variadic<PDL_PositionalValue>:$args, let arguments = (ins StrAttr:$name,
ArrayAttr:$params, Variadic<PDL_PositionalValue>:$args,
StrAttr:$name); OptionalAttr<ArrayAttr>:$constParams);
let assemblyFormat = "$name $params `(` $args `:` type($args) `)` attr-dict"; let assemblyFormat = [{
$name ($constParams^)? `(` $args `:` type($args) `)` attr-dict
}];
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, " OpBuilder<"OpBuilder &builder, OperationState &state, StringRef name, "
"ValueRange args, ArrayRef<Attribute> params, " "ValueRange args = {}, ArrayRef<Attribute> params = {}", [{
"StringRef name", [{ build(builder, state, builder.getStringAttr(name), args,
build(builder, state, args, builder.getArrayAttr(params), params.empty() ? ArrayAttr() : builder.getArrayAttr(params));
builder.getStringAttr(name));
}]>, }]>,
]; ];
} }
@ -135,12 +136,13 @@ def PDL_CreateNativeOp
``` ```
}]; }];
let arguments = (ins StrAttr:$name, Variadic<PDL_PositionalValue>:$arguments, let arguments = (ins StrAttr:$name,
ArrayAttr:$constantParams); Variadic<PDL_PositionalValue>:$args,
OptionalAttr<ArrayAttr>:$constParams);
let results = (outs PDL_PositionalValue:$result); let results = (outs PDL_PositionalValue:$result);
let assemblyFormat = [{ let assemblyFormat = [{
$name $constantParams (`(` $arguments^ `:` type($arguments) `)`)? $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
`:` type($result) attr-dict attr-dict
}]; }];
let verifier = ?; let verifier = ?;
} }
@ -222,7 +224,7 @@ def PDL_OperationOp
`pdl.operation`s are composed of a name, and a set of attribute, operand, `pdl.operation`s are composed of a name, and a set of attribute, operand,
and result type values, that map to what those that would be on a and result type values, that map to what those that would be on a
constructed instance of that operation. The results of a `pdl.operation` are constructed instance of that operation. The results of a `pdl.operation` are
a handle to the operation itself, and a handle to each of the operation a handle to the operation itself, and a handle to each of the operation
result values. result values.
When used within a matching context, the name of the operation may be When used within a matching context, the name of the operation may be
@ -380,16 +382,18 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
rewrite is specified either via a string name (`name`) to an external rewrite is specified either via a string name (`name`) to an external
rewrite function, or via the region body. The rewrite region, if specified, rewrite function, or via the region body. The rewrite region, if specified,
must contain a single block and terminate via the `pdl.rewrite_end` must contain a single block and terminate via the `pdl.rewrite_end`
operation. operation. If the rewrite is external, it also takes a set of constant
parameters and a set of additional positional values defined within the
matcher as arguments.
Example: Example:
```mlir ```mlir
// Specify an external rewrite function: // Specify an external rewrite function:
pdl.rewrite "myExternalRewriter"(%root) pdl.rewrite %root with "myExternalRewriter"(%value : !pdl.value)
// Specify the rewrite inline using PDL: // Specify the rewrite inline using PDL:
pdl.rewrite(%root) { pdl.rewrite %root {
%op = pdl.operation "foo.op"(%arg0, %arg1) %op = pdl.operation "foo.op"(%arg0, %arg1)
pdl.replace %root with %op pdl.replace %root with %op
} }
@ -397,7 +401,9 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
}]; }];
let arguments = (ins PDL_Operation:$root, let arguments = (ins PDL_Operation:$root,
OptionalAttr<StrAttr>:$name); OptionalAttr<StrAttr>:$name,
Variadic<PDL_PositionalValue>:$externalArgs,
OptionalAttr<ArrayAttr>:$externalConstParams);
let regions = (region AnyRegion:$body); let regions = (region AnyRegion:$body);
} }

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,2 @@
add_mlir_dialect(PDLInterpOps pdl_interp)
add_mlir_doc(PDLInterpOps -gen-op-doc PDLInterpOps Dialects/)

View File

@ -0,0 +1,39 @@
//===- PDLInterp.h - PDL Interpreter dialect --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the interpreter dialect for the PDL pattern descriptor
// language.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_
#define MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
namespace pdl_interp {
//===----------------------------------------------------------------------===//
// PDLInterp Dialect
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.h.inc"
//===----------------------------------------------------------------------===//
// PDLInterp Dialect Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.h.inc"
} // end namespace pdl_interp
} // end namespace mlir
#endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_

View File

@ -0,0 +1,926 @@
//===- PDLInterpOps.td - Pattern Interpreter Dialect -------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the PDL interpreter dialect ops.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS
#define MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS
include "mlir/Dialect/PDL/IR/PDLBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
// PDLInterp Dialect
//===----------------------------------------------------------------------===//
def PDLInterp_Dialect : Dialect {
let summary = "Interpreted pattern execution dialect";
let description = [{
The PDL Interpreter dialect provides a lower level abstraction compared to
the PDL dialect, and is targeted towards low level optimization and
interpreter code generation. The dialect operations encapsulates
low-level pattern match and rewrite "primitives", such as navigating the
IR (Operation::getOperand), creating new operations (OpBuilder::create),
etc. Many of the operations within this dialect also fuse branching control
flow with some form of a predicate comparison operation. This type of fusion
reduces the amount of work that an interpreter must do when executing.
}];
let name = "pdl_interp";
let cppNamespace = "mlir::pdl_interp";
let dependentDialects = ["pdl::PDLDialect"];
}
//===----------------------------------------------------------------------===//
// PDLInterp Operations
//===----------------------------------------------------------------------===//
// Generic interpreter operation.
class PDLInterp_Op<string mnemonic, list<OpTrait> traits = []> :
Op<PDLInterp_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// PDLInterp_PredicateOp
// Check operations evaluate a predicate on a positional value and then
// conditionally branch on the result.
class PDLInterp_PredicateOp<string mnemonic, list<OpTrait> traits = []> :
PDLInterp_Op<mnemonic, !listconcat([Terminator], traits)> {
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
}
//===----------------------------------------------------------------------===//
// PDLInterp_SwitchOp
// Switch operations evaluate a predicate on a positional value and then
// conditionally branch on the result.
class PDLInterp_SwitchOp<string mnemonic, list<OpTrait> traits = []> :
PDLInterp_Op<mnemonic, !listconcat([Terminator], traits)> {
let successors = (successor AnySuccessor:$defaultDest,
VariadicSuccessor<AnySuccessor>:$cases);
let verifier = [{
// Verify that the number of case destinations matches the number of case
// values.
size_t numDests = cases().size();
size_t numValues = caseValues().size();
if (numDests != numValues) {
return emitOpError("expected number of cases to match the number of case "
"values, got ")
<< numDests << " but expected " << numValues;
}
return success();
}];
}
//===----------------------------------------------------------------------===//
// pdl_interp::ApplyConstraintOp
//===----------------------------------------------------------------------===//
def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let summary = "Apply a constraint to a set of positional values";
let description = [{
`pdl_interp.apply_constraint` operations apply a generic constraint, that
has been registered with the interpreter, with a given set of positional
values. The constraint may have any number of constant parameters. On
success, this operation branches to the true destination, otherwise the
false destination is taken.
Example:
```mlir
// Apply `myConstraint` to the entities defined by `input`, `attr`, and
// `op`.
pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
```
}];
let arguments = (ins StrAttr:$name,
Variadic<PDL_PositionalValue>:$args,
OptionalAttr<ArrayAttr>:$constParams);
let assemblyFormat = [{
$name ($constParams^)? `(` $args `:` type($args) `)` attr-dict `->`
successors
}];
}
//===----------------------------------------------------------------------===//
// pdl_interp::ApplyRewriteOp
//===----------------------------------------------------------------------===//
def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> {
let summary = "Invoke and apply an externally registered rewrite method";
let description = [{
`pdl_interp.apply_rewrite` operations invoke an external rewriter that has
been registered with the interpreter to perform the rewrite after a
successful match. The rewrite is passed the root operation being matched, a
set of additional positional arguments generated within the matcher, and a
set of constant parameters.
Example:
```mlir
// Rewriter operating solely on the root operation.
pdl_interp.apply_rewrite "rewriter" on %root
// Rewriter operating on the root operation along with additional arguments
// from the matcher.
pdl_interp.apply_rewrite "rewriter"(%value : !pdl.value) on %root
// Rewriter operating on the root operation along with additional arguments
// and constant parameters.
pdl_interp.apply_rewrite "rewriter"[42](%value : !pdl.value) on %root
```
}];
let arguments = (ins StrAttr:$name,
PDL_Operation:$root,
Variadic<PDL_PositionalValue>:$args,
OptionalAttr<ArrayAttr>:$constParams);
let assemblyFormat = [{
$name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `on` $root
attr-dict
}];
}
//===----------------------------------------------------------------------===//
// pdl_interp::AreEqualOp
//===----------------------------------------------------------------------===//
def PDLInterp_AreEqualOp
: PDLInterp_PredicateOp<"are_equal", [NoSideEffect, SameTypeOperands]> {
let summary = "Check if two positional values are equivalent";
let description = [{
`pdl_interp.are_equal` operations compare two positional values for
equality. On success, this operation branches to the true destination,
otherwise the false destination is taken.
Example:
```mlir
pdl_interp.are_equal %result1, %result2 : !pdl.value -> ^matchDest, ^failureDest
```
}];
let arguments = (ins PDL_PositionalValue:$lhs,
PDL_PositionalValue:$rhs);
let assemblyFormat = "operands `:` type($lhs) attr-dict `->` successors";
}
//===----------------------------------------------------------------------===//
// pdl_interp::BranchOp
//===----------------------------------------------------------------------===//
def PDLInterp_BranchOp : PDLInterp_Op<"branch", [NoSideEffect, Terminator]> {
let summary = "General branch operation";
let description = [{
`pdl_interp.branch` operations expose general branch functionality to the
interpreter, and are generally used to branch from one pattern match
sequence to another.
Example:
```mlir
pdl_interp.branch ^dest
```
}];
let successors = (successor AnySuccessor:$dest);
let assemblyFormat = "$dest attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::CheckAttributeOp
//===----------------------------------------------------------------------===//
def PDLInterp_CheckAttributeOp
: PDLInterp_PredicateOp<"check_attribute", [NoSideEffect]> {
let summary = "Check the value of an `Attribute`";
let description = [{
`pdl_interp.check_attribute` operations compare the value of a given
attribute with a constant value. On success, this operation branches to the
true destination, otherwise the false destination is taken.
Example:
```mlir
pdl_interp.check_attribute %attr is 10 -> ^matchDest, ^failureDest
```
}];
let arguments = (ins PDL_Attribute:$attribute, AnyAttr:$constantValue);
let assemblyFormat = [{
$attribute `is` $constantValue attr-dict `->` successors
}];
}
//===----------------------------------------------------------------------===//
// pdl_interp::CheckOperandCountOp
//===----------------------------------------------------------------------===//
def PDLInterp_CheckOperandCountOp
: PDLInterp_PredicateOp<"check_operand_count", [NoSideEffect]> {
let summary = "Check the number of operands of an `Operation`";
let description = [{
`pdl_interp.check_operand_count` operations compare the number of operands
of a given operation value with a constant. On success, this operation
branches to the true destination, otherwise the false destination is taken.
Example:
```mlir
pdl_interp.check_operand_count of %op is 2 -> ^matchDest, ^failureDest
```
}];
let arguments = (ins PDL_Operation:$operation,
Confined<I32Attr, [IntNonNegative]>:$count);
let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors";
}
//===----------------------------------------------------------------------===//
// pdl_interp::CheckOperationNameOp
//===----------------------------------------------------------------------===//
def PDLInterp_CheckOperationNameOp
: PDLInterp_PredicateOp<"check_operation_name", [NoSideEffect]> {
let summary = "Check the OperationName of an `Operation`";
let description = [{
`pdl_interp.check_operation_name` operations compare the name of a given
operation with a known name. On success, this operation branches to the true
destination, otherwise the false destination is taken.
Example:
```mlir
pdl_interp.check_operation_name of %op is "foo.op" -> ^matchDest, ^failureDest
```
}];
let arguments = (ins PDL_Operation:$operation, StrAttr:$name);
let assemblyFormat = "`of` $operation `is` $name attr-dict `->` successors";
}
//===----------------------------------------------------------------------===//
// pdl_interp::CheckResultCountOp
//===----------------------------------------------------------------------===//
def PDLInterp_CheckResultCountOp
: PDLInterp_PredicateOp<"check_result_count", [NoSideEffect]> {
let summary = "Check the number of results of an `Operation`";
let description = [{
`pdl_interp.check_result_count` operations compare the number of results
of a given operation value with a constant. On success, this operation
branches to the true destination, otherwise the false destination is taken.
Example:
```mlir
pdl_interp.check_result_count of %op is 0 -> ^matchDest, ^failureDest
```
}];
let arguments = (ins PDL_Operation:$operation,
Confined<I32Attr, [IntNonNegative]>:$count);
let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors";
}
//===----------------------------------------------------------------------===//
// pdl_interp::CheckTypeOp
//===----------------------------------------------------------------------===//
def PDLInterp_CheckTypeOp
: PDLInterp_PredicateOp<"check_type", [NoSideEffect]> {
let summary = "Compare a type to a known value";
let description = [{
`pdl_interp.check_type` operations compare a type with a statically known
type. On success, this operation branches to the true destination, otherwise
the false destination is taken.
Example:
```mlir
pdl_interp.check_type %type is 0 -> ^matchDest, ^failureDest
```
}];
let arguments = (ins PDL_Type:$value, TypeAttr:$type);
let assemblyFormat = "$value `is` $type attr-dict `->` successors";
}
//===----------------------------------------------------------------------===//
// pdl_interp::CreateAttributeOp
//===----------------------------------------------------------------------===//
def PDLInterp_CreateAttributeOp
: PDLInterp_Op<"create_attribute", [NoSideEffect]> {
let summary = "Create an interpreter handle to a constant `Attribute`";
let description = [{
`pdl_interp.create_attribute` operations generate a handle within the
interpreter for a specific constant attribute value.
Example:
```mlir
pdl_interp.create_attribute 10 : i64
```
}];
let arguments = (ins AnyAttr:$value);
let results = (outs PDL_Attribute:$attribute);
let assemblyFormat = "$value attr-dict";
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, "
"Attribute value", [{
build(builder, state, builder.getType<pdl::AttributeType>(), value);
}]>];
}
//===----------------------------------------------------------------------===//
// pdl_interp::CreateNativeOp
//===----------------------------------------------------------------------===//
def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> {
let summary = "Call a native creation method to construct an `Attribute`, "
"`Operation`, `Type`, or `Value`";
let description = [{
`pdl_interp.create_native` operations invoke a native C++ function, that has
been registered externally with the consumer of PDL, to create an
`Attribute`, `Operation`, `Type`, or `Value`. The native function must
produce a value of the specified return type, and may accept any number of
positional arguments and constant attribute parameters.
Example:
```mlir
%ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute
```
}];
let arguments = (ins StrAttr:$name,
Variadic<PDL_PositionalValue>:$args,
OptionalAttr<ArrayAttr>:$constParams);
let results = (outs PDL_PositionalValue:$result);
let assemblyFormat = [{
$name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
attr-dict
}];
let verifier = ?;
}
//===----------------------------------------------------------------------===//
// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//
def PDLInterp_CreateOperationOp
: PDLInterp_Op<"create_operation", [AttrSizedOperandSegments]> {
let summary = "Create an instance of a specific `Operation`";
let description = [{
`pdl_interp.create_operation` operations create an `Operation` instance with
the specified attributes, operands, and result types.
Example:
```mlir
// Create an instance of a `foo.op` operation.
%op = pdl_interp.create_operation "foo.op"(%arg0) {"attrA" = %attr0} -> %type, %type
```
}];
let arguments = (ins StrAttr:$name,
Variadic<PDL_Value>:$operands,
Variadic<PDL_Attribute>:$attributes,
StrArrayAttr:$attributeNames,
Variadic<PDL_Type>:$types);
let results = (outs PDL_Operation:$operation);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, StringRef name, "
"ValueRange types, ValueRange operands, ValueRange attributes, "
"ArrayAttr attributeNames", [{
build(builder, state, builder.getType<pdl::OperationType>(), name,
operands, attributes, attributeNames, types);
}]>];
let parser = [{ return ::parseCreateOperationOp(parser, result); }];
let printer = [{ ::print(p, *this); }];
}
//===----------------------------------------------------------------------===//
// pdl_interp::CreateTypeOp
//===----------------------------------------------------------------------===//
def PDLInterp_CreateTypeOp : PDLInterp_Op<"create_type", [NoSideEffect]> {
let summary = "Create an interpreter handle to a constant `Type`";
let description = [{
`pdl_interp.create_type` operations generate a handle within the interpreter
for a specific constant type value.
Example:
```mlir
pdl_interp.create_type i64
```
}];
let arguments = (ins TypeAttr:$value);
let results = (outs PDL_Type:$result);
let assemblyFormat = "$value attr-dict";
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, TypeAttr type", [{
build(builder, state, builder.getType<pdl::TypeType>(), type);
}]>
];
}
//===----------------------------------------------------------------------===//
// pdl_interp::EraseOp
//===----------------------------------------------------------------------===//
def PDLInterp_EraseOp : PDLInterp_Op<"erase"> {
let summary = "Mark an operation as `erased`";
let description = [{
`pdl.erase` operations are used to specify that an operation should be
marked as erased. The semantics of this operation correspond with the
`eraseOp` method on a `PatternRewriter`.
Example:
```mlir
pdl_interp.erase %root
```
}];
let arguments = (ins PDL_Operation:$operation);
let assemblyFormat = "$operation attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::FinalizeOp
//===----------------------------------------------------------------------===//
def PDLInterp_FinalizeOp
: PDLInterp_Op<"finalize", [NoSideEffect, Terminator]> {
let summary = "Finalize a pattern match or rewrite sequence";
let description = [{
`pdl_interp.finalize` is used to denote the termination of a match or
rewrite sequence.
Example:
```mlir
pdl_interp.finalize
```
}];
let assemblyFormat = "attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetAttributeOp
//===----------------------------------------------------------------------===//
def PDLInterp_GetAttributeOp : PDLInterp_Op<"get_attribute", [NoSideEffect]> {
let summary = "Get a specified attribute value from an `Operation`";
let description = [{
`pdl_interp.get_attribute` operations try to get a specific attribute from
an operation. If the operation does not have that attribute, a null value is
returned.
Example:
```mlir
%attr = pdl_interp.get_attribute "attr" of %op
```
}];
let arguments = (ins PDL_Operation:$operation,
StrAttr:$name);
let results = (outs PDL_Attribute:$attribute);
let assemblyFormat = "$name `of` $operation attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetAttributeTypeOp
//===----------------------------------------------------------------------===//
def PDLInterp_GetAttributeTypeOp
: PDLInterp_Op<"get_attribute_type", [NoSideEffect]> {
let summary = "Get the result type of a specified `Attribute`";
let description = [{
`pdl_interp.get_attribute_type` operations get the resulting type of a
specific attribute.
Example:
```mlir
%type = pdl_interp.get_attribute_type of %attr
```
}];
let arguments = (ins PDL_Attribute:$value);
let results = (outs PDL_Type:$result);
let assemblyFormat = "`of` $value attr-dict";
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value value", [{
build(builder, state, builder.getType<pdl::TypeType>(), value);
}]>
];
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetDefiningOpOp
//===----------------------------------------------------------------------===//
def PDLInterp_GetDefiningOpOp
: PDLInterp_Op<"get_defining_op", [NoSideEffect]> {
let summary = "Get the defining operation of a `Value`";
let description = [{
`pdl_interp.get_defining_op` operations try to get the defining operation
of a specific value. If the value is not an operation result, null is
returned.
Example:
```mlir
%op = pdl_interp.get_defining_op of %value
```
}];
let arguments = (ins PDL_Value:$value);
let results = (outs PDL_Operation:$operation);
let assemblyFormat = "`of` $value attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetOperandOp
//===----------------------------------------------------------------------===//
def PDLInterp_GetOperandOp : PDLInterp_Op<"get_operand", [NoSideEffect]> {
let summary = "Get a specified operand from an `Operation`";
let description = [{
`pdl_interp.get_operand` operations try to get a specific operand from an
operation If the operation does not have an operand for the given index, a
null value is returned.
Example:
```mlir
%operand = pdl_interp.get_operand 1 of %op
```
}];
let arguments = (ins PDL_Operation:$operation,
Confined<I32Attr, [IntNonNegative]>:$index);
let results = (outs PDL_Value:$value);
let assemblyFormat = "$index `of` $operation attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetResultOp
//===----------------------------------------------------------------------===//
def PDLInterp_GetResultOp : PDLInterp_Op<"get_result", [NoSideEffect]> {
let summary = "Get a specified result from an `Operation`";
let description = [{
`pdl_interp.get_result` operations try to get a specific result from an
operation. If the operation does not have a result for the given index, a
null value is returned.
Example:
```mlir
%result = pdl_interp.get_result 1 of %op
```
}];
let arguments = (ins PDL_Operation:$operation,
Confined<I32Attr, [IntNonNegative]>:$index);
let results = (outs PDL_Value:$value);
let assemblyFormat = "$index `of` $operation attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::GetValueTypeOp
//===----------------------------------------------------------------------===//
// Get a type from the root operation, held in the rewriter context.
def PDLInterp_GetValueTypeOp : PDLInterp_Op<"get_value_type", [NoSideEffect]> {
let summary = "Get the result type of a specified `Value`";
let description = [{
`pdl_interp.get_value_type` operations get the resulting type of a specific
value.
Example:
```mlir
%type = pdl_interp.get_value_type of %value
```
}];
let arguments = (ins PDL_Value:$value);
let results = (outs PDL_Type:$result);
let assemblyFormat = "`of` $value attr-dict";
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value value", [{
build(builder, state, builder.getType<pdl::TypeType>(), value);
}]>
];
}
//===----------------------------------------------------------------------===//
// pdl_interp::InferredTypeOp
//===----------------------------------------------------------------------===//
def PDLInterp_InferredTypeOp : PDLInterp_Op<"inferred_type"> {
let summary = "Generate a handle to a Type that is \"inferred\"";
let description = [{
`pdl_interp.inferred_type` operations generate a handle to a type that
should be inferred. This signals to other operations, such as
`pdl_interp.create_operation`, that this type should be inferred.
Example:
```mlir
pdl_interp.inferred_type
```
}];
let results = (outs PDL_Type:$type);
let assemblyFormat = "attr-dict";
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state", [{
build(builder, state, builder.getType<pdl::TypeType>());
}]>,
];
}
//===----------------------------------------------------------------------===//
// pdl_interp::IsNotNullOp
//===----------------------------------------------------------------------===//
def PDLInterp_IsNotNullOp
: PDLInterp_PredicateOp<"is_not_null", [NoSideEffect]> {
let summary = "Check if a positional value is non-null";
let description = [{
`pdl_interp.is_not_null` operations check that a positional value exists. On
success, this operation branches to the true destination. Otherwise, the
false destination is taken.
Example:
```mlir
pdl_interp.is_not_null %value : !pdl.value -> ^matchDest, ^failureDest
```
}];
let arguments = (ins PDL_PositionalValue:$value);
let assemblyFormat = "$value `:` type($value) attr-dict `->` successors";
}
//===----------------------------------------------------------------------===//
// pdl_interp::RecordMatchOp
//===----------------------------------------------------------------------===//
def PDLInterp_RecordMatchOp
: PDLInterp_Op<"record_match", [AttrSizedOperandSegments, Terminator]> {
let summary = "Record the metadata for a successful pattern match";
let description = [{
`pdl_interp.record_match` operations record a successful pattern match with
the interpreter and branch to the next part of the matcher. The metadata
recorded by these operations correspond to a specific `pdl.pattern`, as well
as what values were used during that match that should be propagated to the
rewriter.
Example:
```mlir
pdl_interp.record_match @rewriters::myRewriter(%root : !pdl.operation) : benefit(1), loc([%root, %op1]), root("foo.op") -> ^nextDest
```
}];
let arguments = (ins Variadic<PDL_PositionalValue>:$inputs,
Variadic<PDL_Operation>:$matchedOps,
SymbolRefAttr:$rewriter,
OptionalAttr<StrAttr>:$rootKind,
OptionalAttr<StrArrayAttr>:$generatedOps,
Confined<I16Attr, [IntNonNegative]>:$benefit);
let successors = (successor AnySuccessor:$dest);
let assemblyFormat = [{
$rewriter (`(` $inputs^ `:` type($inputs) `)`)? `:`
`benefit` `(` $benefit `)` `,`
(`generatedOps` `(` $generatedOps^ `)` `,`)?
`loc` `(` `[` $matchedOps `]` `)`
(`,` `root` `(` $rootKind^ `)`)? attr-dict `->` $dest
}];
}
//===----------------------------------------------------------------------===//
// pdl_interp::ReplaceOp
//===----------------------------------------------------------------------===//
def PDLInterp_ReplaceOp : PDLInterp_Op<"replace"> {
let summary = "Mark an operation as `replace`d";
let description = [{
`pdl_interp.replaced` operations are used to specify that an operation
should be marked as replaced. The semantics of this operation correspond
with the `replaceOp` method on a `PatternRewriter`. The set of replacement
values must match the number of results specified by the operation.
Example:
```mlir
// Replace root node with 2 values:
pdl_interp.replace %root with (%val0, %val1)
```
}];
let arguments = (ins PDL_Operation:$operation,
Variadic<PDL_Value>:$replValues);
let assemblyFormat = "$operation `with` `(` $replValues `)` attr-dict";
}
//===----------------------------------------------------------------------===//
// pdl_interp::SwitchAttributeOp
//===----------------------------------------------------------------------===//
def PDLInterp_SwitchAttributeOp
: PDLInterp_SwitchOp<"switch_attribute", [NoSideEffect]> {
let summary = "Switch on the value of an `Attribute`";
let description = [{
`pdl_interp.switch_attribute` operations compare the value of a given
attribute with a set of constant attributes. If the value matches one of the
provided case values the destination for that case value is taken, otherwise
the default destination is taken.
Example:
```mlir
pdl_interp.switch_attribute %attr to [10, true] -> ^10Dest, ^trueDest, ^defaultDest
```
}];
let arguments = (ins PDL_Attribute:$attribute, ArrayAttr:$caseValues);
let assemblyFormat = [{
$attribute `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
}];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value attribute,"
"ArrayRef<Attribute> caseValues,"
"Block *defaultDest, ArrayRef<Block *> dests", [{
build(builder, state, attribute, builder.getArrayAttr(caseValues),
defaultDest, dests);
}]>];
}
//===----------------------------------------------------------------------===//
// pdl_interp::SwitchOperandCountOp
//===----------------------------------------------------------------------===//
def PDLInterp_SwitchOperandCountOp
: PDLInterp_SwitchOp<"switch_operand_count", [NoSideEffect]> {
let summary = "Switch on the operand count of an `Operation`";
let description = [{
`pdl_interp.switch_operand_count` operations compare the operand count of a
given operation with a set of potential counts. If the value matches one of
the provided case values the destination for that case value is taken,
otherwise the default destination is taken.
Example:
```mlir
pdl_interp.switch_operand_count of %op to [10, 2] -> ^10Dest, ^2Dest, ^defaultDest
```
}];
let arguments = (ins PDL_Operation:$operation, I32ElementsAttr:$caseValues);
let assemblyFormat = [{
`of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
}];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, "
"ArrayRef<int32_t> counts, Block *defaultDest, "
"ArrayRef<Block *> dests", [{
build(builder, state, operation, builder.getI32VectorAttr(counts),
defaultDest, dests);
}]>];
}
//===----------------------------------------------------------------------===//
// pdl_interp::SwitchOperationNameOp
//===----------------------------------------------------------------------===//
def PDLInterp_SwitchOperationNameOp
: PDLInterp_SwitchOp<"switch_operation_name", [NoSideEffect]> {
let summary = "Switch on the OperationName of an `Operation`";
let description = [{
`pdl_interp.switch_operation_name` operations compare the name of a given
operation with a set of known names. If the value matches one of the
provided case values the destination for that case value is taken, otherwise
the default destination is taken.
Example:
```mlir
pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"] -> ^fooDest, ^barDest, ^defaultDest
```
}];
let arguments = (ins PDL_Operation:$operation,
StrArrayAttr:$caseValues);
let assemblyFormat = [{
`of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
}];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, "
"ArrayRef<OperationName> names, "
"Block *defaultDest, ArrayRef<Block *> dests", [{
auto stringNames = llvm::to_vector<8>(llvm::map_range(names,
[](OperationName name) { return name.getStringRef(); }));
build(builder, state, operation, builder.getStrArrayAttr(stringNames),
defaultDest, dests);
}]>,
];
}
//===----------------------------------------------------------------------===//
// pdl_interp::SwitchResultCountOp
//===----------------------------------------------------------------------===//
def PDLInterp_SwitchResultCountOp
: PDLInterp_SwitchOp<"switch_result_count", [NoSideEffect]> {
let summary = "Switch on the result count of an `Operation`";
let description = [{
`pdl_interp.switch_result_count` operations compare the result count of a
given operation with a set of potential counts. If the value matches one of
the provided case values the destination for that case value is taken,
otherwise the default destination is taken.
Example:
```mlir
pdl_interp.switch_result_count of %op to [0, 2] -> ^0Dest, ^2Dest, ^defaultDest
```
}];
let arguments = (ins PDL_Operation:$operation, I32ElementsAttr:$caseValues);
let assemblyFormat = [{
`of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
}];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, "
"ArrayRef<int32_t> counts, Block *defaultDest, "
"ArrayRef<Block *> dests", [{
build(builder, state, operation, builder.getI32VectorAttr(counts),
defaultDest, dests);
}]>];
}
//===----------------------------------------------------------------------===//
// pdl_interp::SwitchTypeOp
//===----------------------------------------------------------------------===//
def PDLInterp_SwitchTypeOp : PDLInterp_SwitchOp<"switch_type", [NoSideEffect]> {
let summary = "Switch on a `Type` value";
let description = [{
`pdl_interp.switch_type` operations compare a type with a set of statically
known types. If the value matches one of the provided case values the
destination for that case value is taken, otherwise the default destination
is taken.
Example:
```mlir
pdl_interp.switch_type %type to [i32, i64] -> ^i32Dest, ^i64Dest, ^defaultDest
```
}];
let arguments = (ins PDL_Type:$value, TypeArrayAttr:$caseValues);
let assemblyFormat = [{
$value `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
}];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value edge, "
"TypeRange types, Block *defaultDest, ArrayRef<Block *> dests", [{
build(builder, state, edge, builder.getTypeArrayAttr(types), defaultDest,
dests);
}]>,
];
let extraClassDeclaration = [{
auto getCaseTypes() { return caseValues().getAsValueRange<TypeAttr>(); }
}];
}
#endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS

View File

@ -217,12 +217,12 @@ private:
public: public:
template <typename AttrTy> template <typename AttrTy>
llvm::iterator_range<attr_value_iterator<AttrTy>> getAsRange() { iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
return llvm::make_range(attr_value_iterator<AttrTy>(begin()), return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
attr_value_iterator<AttrTy>(end())); attr_value_iterator<AttrTy>(end()));
} }
template <typename AttrTy, typename UnderlyingTy> template <typename AttrTy, typename UnderlyingTy = typename AttrTy::ValueType>
auto getAsRange() { auto getAsValueRange() {
return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) { return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) {
return static_cast<UnderlyingTy>(attr.getValue()); return static_cast<UnderlyingTy>(attr.getValue());
}); });
@ -589,6 +589,9 @@ public:
/// Returns the number of elements held by this attribute. /// Returns the number of elements held by this attribute.
int64_t getNumElements() const; int64_t getNumElements() const;
/// Returns the number of elements held by this attribute.
int64_t size() const { return getNumElements(); }
/// Generates a new ElementsAttr by mapping each int value to a new /// Generates a new ElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either an integer or float. /// underlying APInt. The new values can represent either an integer or float.
/// This ElementsAttr should contain integers. /// This ElementsAttr should contain integers.

View File

@ -139,6 +139,7 @@ public:
ArrayAttr getF32ArrayAttr(ArrayRef<float> values); ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
ArrayAttr getF64ArrayAttr(ArrayRef<double> values); ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values); ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
ArrayAttr getTypeArrayAttr(TypeRange values);
// Affine expressions and affine maps. // Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position); AffineExpr getAffineDimExpr(unsigned position);

View File

@ -426,6 +426,12 @@ public:
return parseOptionalAttribute(result, Type(), attrName, attrs); return parseOptionalAttribute(result, Type(), attrName, attrs);
} }
/// Specialized variants of `parseOptionalAttribute` that remove potential
/// ambiguities in syntax.
virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
StringRef attrName,
NamedAttrList &attrs) = 0;
/// Parse an arbitrary attribute of a given type and return it in result. This /// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified /// also adds the attribute to the specified attribute list with the specified
/// name. /// name.

View File

@ -25,6 +25,7 @@
#include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SDBM/SDBMDialect.h" #include "mlir/Dialect/SDBM/SDBMDialect.h"
@ -49,6 +50,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
scf::SCFDialect, scf::SCFDialect,
omp::OpenMPDialect, omp::OpenMPDialect,
pdl::PDLDialect, pdl::PDLDialect,
pdl_interp::PDLInterpDialect,
quant::QuantizationDialect, quant::QuantizationDialect,
spirv::SPIRVDialect, spirv::SPIRVDialect,
StandardOpsDialect, StandardOpsDialect,

View File

@ -6,6 +6,7 @@ add_subdirectory(LLVMIR)
add_subdirectory(OpenACC) add_subdirectory(OpenACC)
add_subdirectory(OpenMP) add_subdirectory(OpenMP)
add_subdirectory(PDL) add_subdirectory(PDL)
add_subdirectory(PDLInterp)
add_subdirectory(Quant) add_subdirectory(Quant)
add_subdirectory(SCF) add_subdirectory(SCF)
add_subdirectory(SDBM) add_subdirectory(SDBM)

View File

@ -76,9 +76,7 @@ static LogicalResult isContraction(Operation *op) {
if (!genericOp) if (!genericOp)
return failure(); return failure();
auto mapRange = auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>();
return success( return success(
genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
llvm::all_of(mapRange, llvm::all_of(mapRange,

View File

@ -446,20 +446,39 @@ static LogicalResult verify(ReplaceOp op) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static ParseResult parseRewriteOp(OpAsmParser &p, OperationState &state) { static ParseResult parseRewriteOp(OpAsmParser &p, OperationState &state) {
// If the first token isn't a '(', this is an external rewrite.
StringAttr nameAttr;
if (failed(p.parseOptionalLParen())) {
if (p.parseAttribute(nameAttr, "name", state.attributes) || p.parseLParen())
return failure();
}
// Parse the root operand. // Parse the root operand.
OpAsmParser::OperandType rootOperand; OpAsmParser::OperandType rootOperand;
if (p.parseOperand(rootOperand) || p.parseRParen() || if (p.parseOperand(rootOperand) ||
p.resolveOperand(rootOperand, p.getBuilder().getType<OperationType>(), p.resolveOperand(rootOperand, p.getBuilder().getType<OperationType>(),
state.operands)) state.operands))
return failure(); return failure();
// Parse an external rewrite.
StringAttr nameAttr;
if (succeeded(p.parseOptionalKeyword("with"))) {
if (p.parseAttribute(nameAttr, "name", state.attributes))
return failure();
// Parse the optional set of constant parameters.
ArrayAttr constantParams;
OptionalParseResult constantParamResult = p.parseOptionalAttribute(
constantParams, "externalConstParams", state.attributes);
if (constantParamResult.hasValue() && failed(*constantParamResult))
return failure();
// Parse the optional additional arguments.
if (succeeded(p.parseOptionalLParen())) {
SmallVector<OpAsmParser::OperandType, 4> arguments;
SmallVector<Type, 4> argumentTypes;
llvm::SMLoc argumentLoc = p.getCurrentLocation();
if (p.parseOperandList(arguments) ||
p.parseColonTypeList(argumentTypes) || p.parseRParen() ||
p.resolveOperands(arguments, argumentTypes, argumentLoc,
state.operands))
return failure();
}
}
// If this isn't an external rewrite, parse the region body. // If this isn't an external rewrite, parse the region body.
Region &rewriteRegion = *state.addRegion(); Region &rewriteRegion = *state.addRegion();
if (!nameAttr) { if (!nameAttr) {
@ -468,27 +487,58 @@ static ParseResult parseRewriteOp(OpAsmParser &p, OperationState &state) {
return failure(); return failure();
RewriteOp::ensureTerminator(rewriteRegion, p.getBuilder(), state.location); RewriteOp::ensureTerminator(rewriteRegion, p.getBuilder(), state.location);
} }
return success();
return p.parseOptionalAttrDictWithKeyword(state.attributes);
} }
static void print(OpAsmPrinter &p, RewriteOp op) { static void print(OpAsmPrinter &p, RewriteOp op) {
p << "pdl.rewrite"; p << "pdl.rewrite " << op.root();
if (Optional<StringRef> name = op.name()) { if (Optional<StringRef> name = op.name()) {
p << " \"" << *name << "\"(" << op.root() << ")"; p << " with \"" << *name << "\"";
return;
if (ArrayAttr constantParams = op.externalConstParamsAttr())
p << constantParams;
OperandRange externalArgs = op.externalArgs();
if (!externalArgs.empty())
p << "(" << externalArgs << " : " << externalArgs.getTypes() << ")";
} else {
p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
} }
p << "(" << op.root() << ")"; p.printOptionalAttrDictWithKeyword(op.getAttrs(),
p.printRegion(op.body(), /*printEntryBlockArgs=*/false, {"name", "externalConstParams"});
/*printBlockTerminators=*/false);
} }
static LogicalResult verify(RewriteOp op) { static LogicalResult verify(RewriteOp op) {
Region &rewriteRegion = op.body(); Region &rewriteRegion = op.body();
if (llvm::hasNItemsOrMore(rewriteRegion, 2)) {
return op.emitOpError() // Handle the case where the rewrite is external.
<< "expected rewrite region when specified to have a single block"; if (op.name()) {
if (!rewriteRegion.empty()) {
return op.emitOpError()
<< "expected rewrite region to be empty when rewrite is external";
}
return success();
} }
// Otherwise, check that the rewrite region only contains a single block.
if (rewriteRegion.empty()) {
return op.emitOpError() << "expected rewrite region to be non-empty if "
"external name is not specified";
}
// Check that no additional arguments were provided.
if (!op.externalArgs().empty()) {
return op.emitOpError() << "expected no external arguments when the "
"rewrite is specified inline";
}
if (op.externalConstParams()) {
return op.emitOpError() << "expected no external constant parameters when "
"the rewrite is specified inline";
}
return success(); return success();
} }

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,15 @@
add_mlir_dialect_library(MLIRPDLInterp
PDLInterp.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/PDLInterp
DEPENDS
MLIRPDLInterpOpsIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRPDL
MLIRInferTypeOpInterface
MLIRSideEffectInterfaces
)

View File

@ -0,0 +1,122 @@
//===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
using namespace mlir::pdl_interp;
//===----------------------------------------------------------------------===//
// PDLInterp Dialect
//===----------------------------------------------------------------------===//
void PDLInterpDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
>();
}
//===----------------------------------------------------------------------===//
// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//
static ParseResult parseCreateOperationOp(OpAsmParser &p,
OperationState &state) {
if (p.parseOptionalAttrDict(state.attributes))
return failure();
Builder &builder = p.getBuilder();
// Parse the operation name.
StringAttr opName;
if (p.parseAttribute(opName, "name", state.attributes))
return failure();
// Parse the operands.
SmallVector<OpAsmParser::OperandType, 4> operands;
if (p.parseLParen() || p.parseOperandList(operands) || p.parseRParen() ||
p.resolveOperands(operands, builder.getType<pdl::ValueType>(),
state.operands))
return failure();
// Parse the attributes.
SmallVector<Attribute, 4> attrNames;
if (succeeded(p.parseOptionalLBrace())) {
SmallVector<OpAsmParser::OperandType, 4> attrOps;
do {
StringAttr nameAttr;
OpAsmParser::OperandType operand;
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
p.parseOperand(operand))
return failure();
attrNames.push_back(nameAttr);
attrOps.push_back(operand);
} while (succeeded(p.parseOptionalComma()));
if (p.parseRBrace() ||
p.resolveOperands(attrOps, builder.getType<pdl::AttributeType>(),
state.operands))
return failure();
}
state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
state.addTypes(builder.getType<pdl::OperationType>());
// Parse the result types.
SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
if (p.parseArrow())
return failure();
if (succeeded(p.parseOptionalLParen())) {
if (p.parseRParen())
return failure();
} else if (p.parseOperandList(opResultTypes) ||
p.resolveOperands(opResultTypes, builder.getType<pdl::TypeType>(),
state.operands)) {
return failure();
}
int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
static_cast<int32_t>(attrNames.size()),
static_cast<int32_t>(opResultTypes.size())};
state.addAttribute("operand_segment_sizes",
builder.getI32VectorAttr(operandSegmentSizes));
return success();
}
static void print(OpAsmPrinter &p, CreateOperationOp op) {
p << "pdl_interp.create_operation ";
p.printOptionalAttrDict(op.getAttrs(),
{"attributeNames", "name", "operand_segment_sizes"});
p << '"' << op.name() << "\"(" << op.operands() << ')';
// Emit the optional attributes.
ArrayAttr attrNames = op.attributeNames();
if (!attrNames.empty()) {
Operation::operand_range attrArgs = op.attributes();
p << " {";
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
[&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
p << '}';
}
// Print the result type constraints of the operation.
auto types = op.types();
if (types.empty())
p << " -> ()";
else
p << " -> " << op.types();
}
//===----------------------------------------------------------------------===//
// TableGen Auto-Generated Op and Interface Definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"

View File

@ -261,6 +261,12 @@ ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
return getArrayAttr(attrs); return getArrayAttr(attrs);
} }
ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
return getArrayAttr(attrs);
}
ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) { ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
auto attrs = llvm::to_vector<8>(llvm::map_range( auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); })); values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));

View File

@ -221,6 +221,9 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
return result; return result;
} }
} }
OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute) {
return parseOptionalAttributeWithToken(Token::l_square, attribute);
}
/// Attribute dictionary. /// Attribute dictionary.
/// ///

View File

@ -1045,15 +1045,37 @@ public:
} }
/// Parse an optional attribute. /// Parse an optional attribute.
OptionalParseResult parseOptionalAttribute(Attribute &result, Type type, /// Template utilities to simplify specifying multiple derived overloads.
StringRef attrName, template <typename AttrT>
NamedAttrList &attrs) override { OptionalParseResult
parseOptionalAttributeAndAddToList(AttrT &result, Type type,
StringRef attrName, NamedAttrList &attrs) {
OptionalParseResult parseResult = OptionalParseResult parseResult =
parser.parseOptionalAttribute(result, type); parser.parseOptionalAttribute(result, type);
if (parseResult.hasValue() && succeeded(*parseResult)) if (parseResult.hasValue() && succeeded(*parseResult))
attrs.push_back(parser.builder.getNamedAttr(attrName, result)); attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return parseResult; return parseResult;
} }
template <typename AttrT>
OptionalParseResult parseOptionalAttributeAndAddToList(AttrT &result,
StringRef attrName,
NamedAttrList &attrs) {
OptionalParseResult parseResult = parser.parseOptionalAttribute(result);
if (parseResult.hasValue() && succeeded(*parseResult))
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return parseResult;
}
OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, attrName, attrs);
}
/// Parse a named dictionary into 'result' if it is present. /// Parse a named dictionary into 'result' if it is present.
ParseResult parseOptionalAttrDict(NamedAttrList &result) override { ParseResult parseOptionalAttrDict(NamedAttrList &result) override {

View File

@ -187,6 +187,22 @@ public:
/// Parse an optional attribute with the provided type. /// Parse an optional attribute with the provided type.
OptionalParseResult parseOptionalAttribute(Attribute &attribute, OptionalParseResult parseOptionalAttribute(Attribute &attribute,
Type type = {}); Type type = {});
OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute);
/// Parse an optional attribute that is demarcated by a specific token.
template <typename AttributeT>
OptionalParseResult parseOptionalAttributeWithToken(Token::Kind kind,
AttributeT &attr,
Type type = {}) {
if (getToken().isNot(kind))
return llvm::None;
if (Attribute parsedAttr = parseAttribute()) {
attr = parsedAttr.cast<ArrayAttr>();
return success();
}
return failure();
}
/// Parse an attribute dictionary. /// Parse an attribute dictionary.
ParseResult parseAttributeDict(NamedAttrList &attributes); ParseResult parseAttributeDict(NamedAttrList &attributes);

View File

@ -9,7 +9,7 @@ pdl.pattern : benefit(1) {
// expected-error@below {{expected at least one argument}} // expected-error@below {{expected at least one argument}}
"pdl.apply_constraint"() {name = "foo", params = []} : () -> () "pdl.apply_constraint"() {name = "foo", params = []} : () -> ()
pdl.rewrite "rewriter"(%op) pdl.rewrite %op with "rewriter"
} }
// ----- // -----
@ -25,14 +25,14 @@ pdl.pattern : benefit(1) {
%attr = pdl.attribute : %type 10 %attr = pdl.attribute : %type 10
%op, %result = pdl.operation "foo.op" {"attr" = %attr} -> %type %op, %result = pdl.operation "foo.op" {"attr" = %attr} -> %type
pdl.rewrite "rewriter"(%op) pdl.rewrite %op with "rewriter"
} }
// ----- // -----
pdl.pattern : benefit(1) { pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op" %op = pdl.operation "foo.op"
pdl.rewrite(%op) { pdl.rewrite %op {
%type = pdl.type %type = pdl.type
// expected-error@below {{expected constant value when specified within a `pdl.rewrite`}} // expected-error@below {{expected constant value when specified within a `pdl.rewrite`}}
@ -44,7 +44,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) { pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op" %op = pdl.operation "foo.op"
pdl.rewrite(%op) { pdl.rewrite %op {
// expected-error@below {{expected constant value when specified within a `pdl.rewrite`}} // expected-error@below {{expected constant value when specified within a `pdl.rewrite`}}
%attr = pdl.attribute %attr = pdl.attribute
} }
@ -57,7 +57,7 @@ pdl.pattern : benefit(1) {
%unused = pdl.attribute %unused = pdl.attribute
%op = pdl.operation "foo.op" %op = pdl.operation "foo.op"
pdl.rewrite "rewriter"(%op) pdl.rewrite %op with "rewriter"
} }
// ----- // -----
@ -71,7 +71,7 @@ pdl.pattern : benefit(1) {
%unused = pdl.input %unused = pdl.input
%op = pdl.operation "foo.op" %op = pdl.operation "foo.op"
pdl.rewrite "rewriter"(%op) pdl.rewrite %op with "rewriter"
} }
// ----- // -----
@ -82,7 +82,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) { pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op" %op = pdl.operation "foo.op"
pdl.rewrite(%op) { pdl.rewrite %op {
// expected-error@below {{must have an operation name when nested within a `pdl.rewrite`}} // expected-error@below {{must have an operation name when nested within a `pdl.rewrite`}}
%newOp = pdl.operation %newOp = pdl.operation
} }
@ -96,14 +96,14 @@ pdl.pattern : benefit(1) {
attributeNames = ["attr"], attributeNames = ["attr"],
operand_segment_sizes = dense<0> : vector<3xi32> operand_segment_sizes = dense<0> : vector<3xi32>
} : () -> (!pdl.operation) } : () -> (!pdl.operation)
pdl.rewrite "rewriter"(%op) pdl.rewrite %op with "rewriter"
} }
// ----- // -----
pdl.pattern : benefit(1) { pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"() %op = pdl.operation "foo.op"()
pdl.rewrite (%op) { pdl.rewrite %op {
%type = pdl.type %type = pdl.type
// expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}} // expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}}
@ -119,7 +119,7 @@ pdl.pattern : benefit(1) {
%unused = pdl.operation "foo.op" %unused = pdl.operation "foo.op"
%op = pdl.operation "foo.op" %op = pdl.operation "foo.op"
pdl.rewrite "rewriter"(%op) pdl.rewrite %op with "rewriter"
} }
// ----- // -----
@ -142,7 +142,7 @@ pdl.pattern : benefit(1) {
"foo.other_op"() : () -> () "foo.other_op"() : () -> ()
%root = pdl.operation "foo.op" %root = pdl.operation "foo.op"
pdl.rewrite "foo"(%root) pdl.rewrite %root with "foo"
} }
// ----- // -----
@ -153,7 +153,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) { pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op" %root = pdl.operation "foo.op"
pdl.rewrite (%root) { pdl.rewrite %root {
%type = pdl.type : i32 %type = pdl.type : i32
%newOp, %newResult = pdl.operation "foo.op" -> %type %newOp, %newResult = pdl.operation "foo.op" -> %type
@ -167,7 +167,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) { pdl.pattern : benefit(1) {
%type = pdl.type : i32 %type = pdl.type : i32
%root, %oldResult = pdl.operation "foo.op" -> %type %root, %oldResult = pdl.operation "foo.op" -> %type
pdl.rewrite (%root) { pdl.rewrite %root {
%newOp, %newResult = pdl.operation "foo.op" -> %type %newOp, %newResult = pdl.operation "foo.op" -> %type
// expected-error@below {{expected no replacement values to be provided when the replacement operation is present}} // expected-error@below {{expected no replacement values to be provided when the replacement operation is present}}
@ -181,7 +181,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) { pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op" %root = pdl.operation "foo.op"
pdl.rewrite (%root) { pdl.rewrite %root {
%type = pdl.type : i32 %type = pdl.type : i32
%newOp, %newResult = pdl.operation "foo.op" -> %type %newOp, %newResult = pdl.operation "foo.op" -> %type
@ -192,6 +192,55 @@ pdl.pattern : benefit(1) {
// ----- // -----
//===----------------------------------------------------------------------===//
// pdl::RewriteOp
//===----------------------------------------------------------------------===//
pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"
// expected-error@below {{expected rewrite region to be non-empty if external name is not specified}}
"pdl.rewrite"(%op) ({}) : (!pdl.operation) -> ()
}
// -----
pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"
// expected-error@below {{expected no external arguments when the rewrite is specified inline}}
"pdl.rewrite"(%op, %op) ({
^bb1:
pdl.rewrite_end
}) : (!pdl.operation, !pdl.operation) -> ()
}
// -----
pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"
// expected-error@below {{expected no external constant parameters when the rewrite is specified inline}}
"pdl.rewrite"(%op) ({
^bb1:
pdl.rewrite_end
}) {externalConstParams = []} : (!pdl.operation) -> ()
}
// -----
pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"
// expected-error@below {{expected rewrite region to be empty when rewrite is external}}
"pdl.rewrite"(%op) ({
^bb1:
pdl.rewrite_end
}) {name = "foo"} : (!pdl.operation) -> ()
}
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// pdl::TypeOp // pdl::TypeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -201,5 +250,5 @@ pdl.pattern : benefit(1) {
%unused = pdl.type %unused = pdl.type
%op = pdl.operation "foo.op" %op = pdl.operation "foo.op"
pdl.rewrite "rewriter"(%op) pdl.rewrite %op with "rewriter"
} }

View File

@ -1,8 +1,6 @@
// RUN: mlir-opt -split-input-file %s | mlir-opt // RUN: mlir-opt -split-input-file %s | mlir-opt
// Verify the printed output can be parsed.
// RUN: mlir-opt %s | mlir-opt
// Verify the generic form can be parsed. // Verify the generic form can be parsed.
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt // RUN: mlir-opt -split-input-file -mlir-print-op-generic %s | mlir-opt
// ----- // -----
@ -15,7 +13,30 @@ pdl.pattern @operations : benefit(1) {
// Operation with input. // Operation with input.
%input = pdl.input %input = pdl.input
%root = pdl.operation(%op0_result, %input) %root = pdl.operation(%op0_result, %input)
pdl.rewrite "rewriter"(%root) pdl.rewrite %root with "rewriter"
}
// -----
pdl.pattern @rewrite_with_args : benefit(1) {
%input = pdl.input
%root = pdl.operation(%input)
pdl.rewrite %root with "rewriter"(%input : !pdl.value)
}
// -----
pdl.pattern @rewrite_with_params : benefit(1) {
%root = pdl.operation
pdl.rewrite %root with "rewriter"["I am param"]
}
// -----
pdl.pattern @rewrite_with_args_and_params : benefit(1) {
%input = pdl.input
%root = pdl.operation(%input)
pdl.rewrite %root with "rewriter"["I am param"](%input : !pdl.value)
} }
// ----- // -----
@ -26,7 +47,7 @@ pdl.pattern @infer_type_from_operation_replace : benefit(1) {
%type1 = pdl.type : i32 %type1 = pdl.type : i32
%type2 = pdl.type %type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2 %root, %results:2 = pdl.operation -> %type1, %type2
pdl.rewrite(%root) { pdl.rewrite %root {
%type3 = pdl.type %type3 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3 %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
pdl.replace %root with %newOp pdl.replace %root with %newOp
@ -41,7 +62,7 @@ pdl.pattern @infer_type_from_result_replace : benefit(1) {
%type1 = pdl.type : i32 %type1 = pdl.type : i32
%type2 = pdl.type %type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2 %root, %results:2 = pdl.operation -> %type1, %type2
pdl.rewrite(%root) { pdl.rewrite %root {
%type3 = pdl.type %type3 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3 %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
pdl.replace %root with (%newResults#0, %newResults#1) pdl.replace %root with (%newResults#0, %newResults#1)
@ -56,7 +77,7 @@ pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
%type1 = pdl.type : i32 %type1 = pdl.type : i32
%type2 = pdl.type %type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2 %root, %results:2 = pdl.operation -> %type1, %type2
pdl.rewrite(%root) { pdl.rewrite %root {
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type2 %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type2
} }
} }

View File

@ -0,0 +1,25 @@
// RUN: mlir-opt -split-input-file %s | mlir-opt
// Verify the printed output can be parsed.
// RUN: mlir-opt %s | mlir-opt
// Verify the generic form can be parsed.
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt
// -----
func @operations(%attribute: !pdl.attribute,
%input: !pdl.value,
%type: !pdl.type) {
// attributes, operands, and results
%op0 = pdl_interp.create_operation "foo.op"(%input) {"attr" = %attribute} -> %type
// attributes, and results
%op1 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute} -> %type
// attributes
%op2 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute, "attr1" = %attribute} -> ()
// operands, and results
%op3 = pdl_interp.create_operation "foo.op"(%input) -> %type
pdl_interp.finalize
}

View File

@ -226,7 +226,7 @@ bool LiteralElement::isValidLiteral(StringRef value) {
// If there is only one character, this must either be punctuation or a // If there is only one character, this must either be punctuation or a
// single character bare identifier. // single character bare identifier.
if (value.size() == 1) if (value.size() == 1)
return isalpha(front) || StringRef("_:,=<>()[]?").contains(front); return isalpha(front) || StringRef("_:,=<>()[]{}?").contains(front);
// Check the punctuation that are larger than a single character. // Check the punctuation that are larger than a single character.
if (value == "->") if (value == "->")
@ -583,6 +583,8 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
.Case("=", "Equal()") .Case("=", "Equal()")
.Case("<", "Less()") .Case("<", "Less()")
.Case(">", "Greater()") .Case(">", "Greater()")
.Case("{", "LBrace()")
.Case("}", "RBrace()")
.Case("(", "LParen()") .Case("(", "LParen()")
.Case(")", "RParen()") .Case(")", "RParen()")
.Case("[", "LSquare()") .Case("[", "LSquare()")