forked from OSchip/llvm-project
[mlir] Support dialect-wide canonicalization pattern registration
* Add `hasCanonicalizer` option to Dialect. * Initialize canonicalizer with dialect-wide canonicalization patterns. * Add test case to TestDialect. Dialect-wide canonicalization patterns are useful if a canonicalization pattern does not conceptually associate with any single operation, i.e., it should not be registered as part of an operation's `getCanonicalizationPatterns` function. E.g., this is the case for canonicalization patterns that match an op interface. Differential Revision: https://reviews.llvm.org/D103226
This commit is contained in:
parent
7d418dadf6
commit
108ca7a7e7
|
@ -68,6 +68,13 @@ public:
|
|||
/// These are represented with OpaqueType.
|
||||
bool allowsUnknownTypes() const { return unknownTypesAllowed; }
|
||||
|
||||
/// Register dialect-wide canonicalization patterns. This method should only
|
||||
/// be used to register canonicalization patterns that do not conceptually
|
||||
/// belong to any single operation in the dialect. (In that case, use the op's
|
||||
/// canonicalizer.) E.g., canonicalization patterns for op interfaces should
|
||||
/// be registered here.
|
||||
virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
|
||||
|
||||
/// Registered hook to materialize a single constant operation from a given
|
||||
/// attribute value with the desired resultant type. This method should use
|
||||
/// the provided builder to create the operation without changing the
|
||||
|
|
|
@ -275,6 +275,9 @@ class Dialect {
|
|||
|
||||
// If this dialect overrides the hook for op interface fallback.
|
||||
bit hasOperationInterfaceFallback = 0;
|
||||
|
||||
// If this dialect overrides the hook for canonicalization patterns.
|
||||
bit hasCanonicalizer = 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -51,6 +51,9 @@ public:
|
|||
// Returns the dialects extra class declaration code.
|
||||
llvm::Optional<StringRef> getExtraClassDeclaration() const;
|
||||
|
||||
/// Returns true if this dialect has a canonicalizer.
|
||||
bool hasCanonicalizer() const;
|
||||
|
||||
// Returns true if this dialect has a constant materializer.
|
||||
bool hasConstantMaterializer() const;
|
||||
|
||||
|
|
|
@ -61,6 +61,10 @@ llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
|
|||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
}
|
||||
|
||||
bool Dialect::hasCanonicalizer() const {
|
||||
return def->getValueAsBit("hasCanonicalizer");
|
||||
}
|
||||
|
||||
bool Dialect::hasConstantMaterializer() const {
|
||||
return def->getValueAsBit("hasConstantMaterializer");
|
||||
}
|
||||
|
|
|
@ -35,6 +35,8 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
|
|||
/// execution.
|
||||
LogicalResult initialize(MLIRContext *context) override {
|
||||
RewritePatternSet owningPatterns(context);
|
||||
for (auto *dialect : context->getLoadedDialects())
|
||||
dialect->getCanonicalizationPatterns(owningPatterns);
|
||||
for (auto *op : context->getRegisteredOperations())
|
||||
op->getCanonicalizationPatterns(owningPatterns, context);
|
||||
patterns = std::move(owningPatterns);
|
||||
|
|
|
@ -104,4 +104,12 @@ func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
|
|||
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
|
||||
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
|
||||
return %1, %2, %3, %4, %5 : index, index, index, index, index
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_dialect_canonicalizer
|
||||
func @test_dialect_canonicalizer() -> (i32) {
|
||||
%0 = "test.dialect_canonicalizable"() : () -> (i32)
|
||||
// CHECK: %[[CST:.*]] = constant 42 : i32
|
||||
// CHECK: return %[[CST]]
|
||||
return %0 : i32
|
||||
}
|
||||
|
|
|
@ -287,6 +287,23 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
|
|||
return targetOperandsMutable();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestDialectCanonicalizerOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult
|
||||
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
|
||||
rewriter.getI32IntegerAttr(42));
|
||||
return success();
|
||||
}
|
||||
|
||||
void TestDialect::getCanonicalizationPatterns(
|
||||
RewritePatternSet &results) const {
|
||||
results.add(&dialectCanonicalizationPattern);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestFoldToCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -25,6 +25,7 @@ include "TestInterfaces.td"
|
|||
def Test_Dialect : Dialect {
|
||||
let name = "test";
|
||||
let cppNamespace = "::mlir::test";
|
||||
let hasCanonicalizer = 1;
|
||||
let hasConstantMaterializer = 1;
|
||||
let hasOperationAttrVerify = 1;
|
||||
let hasRegionArgAttrVerify = 1;
|
||||
|
@ -966,6 +967,11 @@ def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
|
||||
let arguments = (ins);
|
||||
let results = (outs I32);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Patterns (Symbol Binding)
|
||||
|
||||
|
|
|
@ -107,6 +107,13 @@ static const char *const typeParserDecl = R"(
|
|||
::mlir::DialectAsmPrinter &os) const override;
|
||||
)";
|
||||
|
||||
/// The code block for the canonicalization pattern registration hook.
|
||||
static const char *const canonicalizerDecl = R"(
|
||||
/// Register canonicalization patterns.
|
||||
void getCanonicalizationPatterns(
|
||||
::mlir::RewritePatternSet &results) const override;
|
||||
)";
|
||||
|
||||
/// The code block for the constant materializer hook.
|
||||
static const char *const constantMaterializerDecl = R"(
|
||||
/// Materialize a single constant operation from a given attribute value with
|
||||
|
@ -180,6 +187,8 @@ static void emitDialectDecl(Dialect &dialect,
|
|||
os << typeParserDecl;
|
||||
|
||||
// Add the decls for the various features of the dialect.
|
||||
if (dialect.hasCanonicalizer())
|
||||
os << canonicalizerDecl;
|
||||
if (dialect.hasConstantMaterializer())
|
||||
os << constantMaterializerDecl;
|
||||
if (dialect.hasOperationAttrVerify())
|
||||
|
|
Loading…
Reference in New Issue