[mlir] Support dialect-wide canonicalization pattern registration

* Add `hasCanonicalizer` option to Dialect.
* Initialize canonicalizer with dialect-wide canonicalization patterns.
* Add test case to TestDialect.

Dialect-wide canonicalization patterns are useful if a canonicalization pattern does not conceptually associate with any single operation, i.e., it should not be registered as part of an operation's `getCanonicalizationPatterns` function. E.g., this is the case for canonicalization patterns that match an op interface.

Differential Revision: https://reviews.llvm.org/D103226
This commit is contained in:
Matthias Springer 2021-05-27 17:26:45 +09:00
parent 7d418dadf6
commit 108ca7a7e7
9 changed files with 60 additions and 1 deletions

View File

@ -68,6 +68,13 @@ public:
/// These are represented with OpaqueType.
bool allowsUnknownTypes() const { return unknownTypesAllowed; }
/// Register dialect-wide canonicalization patterns. This method should only
/// be used to register canonicalization patterns that do not conceptually
/// belong to any single operation in the dialect. (In that case, use the op's
/// canonicalizer.) E.g., canonicalization patterns for op interfaces should
/// be registered here.
virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
/// Registered hook to materialize a single constant operation from a given
/// attribute value with the desired resultant type. This method should use
/// the provided builder to create the operation without changing the

View File

@ -275,6 +275,9 @@ class Dialect {
// If this dialect overrides the hook for op interface fallback.
bit hasOperationInterfaceFallback = 0;
// If this dialect overrides the hook for canonicalization patterns.
bit hasCanonicalizer = 0;
}
//===----------------------------------------------------------------------===//

View File

@ -51,6 +51,9 @@ public:
// Returns the dialects extra class declaration code.
llvm::Optional<StringRef> getExtraClassDeclaration() const;
/// Returns true if this dialect has a canonicalizer.
bool hasCanonicalizer() const;
// Returns true if this dialect has a constant materializer.
bool hasConstantMaterializer() const;

View File

@ -61,6 +61,10 @@ llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
return value.empty() ? llvm::Optional<StringRef>() : value;
}
bool Dialect::hasCanonicalizer() const {
return def->getValueAsBit("hasCanonicalizer");
}
bool Dialect::hasConstantMaterializer() const {
return def->getValueAsBit("hasConstantMaterializer");
}

View File

@ -35,6 +35,8 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
/// execution.
LogicalResult initialize(MLIRContext *context) override {
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(owningPatterns, context);
patterns = std::move(owningPatterns);

View File

@ -104,4 +104,12 @@ func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
return %1, %2, %3, %4, %5 : index, index, index, index, index
}
}
// CHECK-LABEL: test_dialect_canonicalizer
func @test_dialect_canonicalizer() -> (i32) {
%0 = "test.dialect_canonicalizable"() : () -> (i32)
// CHECK: %[[CST:.*]] = constant 42 : i32
// CHECK: return %[[CST]]
return %0 : i32
}

View File

@ -287,6 +287,23 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
return targetOperandsMutable();
}
//===----------------------------------------------------------------------===//
// TestDialectCanonicalizerOp
//===----------------------------------------------------------------------===//
static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
PatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(42));
return success();
}
void TestDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add(&dialectCanonicalizationPattern);
}
//===----------------------------------------------------------------------===//
// TestFoldToCallOp
//===----------------------------------------------------------------------===//

View File

@ -25,6 +25,7 @@ include "TestInterfaces.td"
def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "::mlir::test";
let hasCanonicalizer = 1;
let hasConstantMaterializer = 1;
let hasOperationAttrVerify = 1;
let hasRegionArgAttrVerify = 1;
@ -966,6 +967,11 @@ def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
let hasFolder = 1;
}
def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
let arguments = (ins);
let results = (outs I32);
}
//===----------------------------------------------------------------------===//
// Test Patterns (Symbol Binding)

View File

@ -107,6 +107,13 @@ static const char *const typeParserDecl = R"(
::mlir::DialectAsmPrinter &os) const override;
)";
/// The code block for the canonicalization pattern registration hook.
static const char *const canonicalizerDecl = R"(
/// Register canonicalization patterns.
void getCanonicalizationPatterns(
::mlir::RewritePatternSet &results) const override;
)";
/// The code block for the constant materializer hook.
static const char *const constantMaterializerDecl = R"(
/// Materialize a single constant operation from a given attribute value with
@ -180,6 +187,8 @@ static void emitDialectDecl(Dialect &dialect,
os << typeParserDecl;
// Add the decls for the various features of the dialect.
if (dialect.hasCanonicalizer())
os << canonicalizerDecl;
if (dialect.hasConstantMaterializer())
os << constantMaterializerDecl;
if (dialect.hasOperationAttrVerify())