[mlir:PDL] Fix bugs in PDLPatternModule merging

* Constraints/Rewrites registered before a pattern was added were dropped
* Constraints/Rewrites may be registered multiple times (if different pattern sets depend on them)
* ModuleOp no longer has a terminator, so we shouldn't be removing the terminator from it

Differential Revision: https://reviews.llvm.org/D114816
This commit is contained in:
River Riddle 2021-12-10 19:36:07 +00:00
parent 98f5bd3489
commit 06c3b9c7be
2 changed files with 34 additions and 20 deletions

View File

@ -157,22 +157,21 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
// Ignore the other module if it has no patterns. // Ignore the other module if it has no patterns.
if (!other.pdlModule) if (!other.pdlModule)
return; return;
// Steal the functions of the other module.
for (auto &it : other.constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : other.rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));
// Steal the other state if we have no patterns. // Steal the other state if we have no patterns.
if (!pdlModule) { if (!pdlModule) {
constraintFunctions = std::move(other.constraintFunctions);
rewriteFunctions = std::move(other.rewriteFunctions);
pdlModule = std::move(other.pdlModule); pdlModule = std::move(other.pdlModule);
return; return;
} }
// Steal the functions of the other module.
for (auto &it : constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));
// Merge the pattern operations from the other module into this one. // Merge the pattern operations from the other module into this one.
Block *block = pdlModule->getBody(); Block *block = pdlModule->getBody();
block->getTerminator()->erase();
block->getOperations().splice(block->end(), block->getOperations().splice(block->end(),
other.pdlModule->getBody()->getOperations()); other.pdlModule->getBody()->getOperations());
} }
@ -182,18 +181,20 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
void PDLPatternModule::registerConstraintFunction( void PDLPatternModule::registerConstraintFunction(
StringRef name, PDLConstraintFunction constraintFn) { StringRef name, PDLConstraintFunction constraintFn) {
auto it = constraintFunctions.try_emplace(name, std::move(constraintFn)); // TODO: Is it possible to diagnose when `name` is already registered to
(void)it; // a function that is not equivalent to `constraintFn`?
assert(it.second && // Allow existing mappings in the case multiple patterns depend on the same
"constraint with the given name has already been registered"); // constraint.
constraintFunctions.try_emplace(name, std::move(constraintFn));
} }
void PDLPatternModule::registerRewriteFunction(StringRef name, void PDLPatternModule::registerRewriteFunction(StringRef name,
PDLRewriteFunction rewriteFn) { PDLRewriteFunction rewriteFn) {
auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn)); // TODO: Is it possible to diagnose when `name` is already registered to
(void)it; // a function that is not equivalent to `rewriteFn`?
assert(it.second && "native rewrite function with the given name has " // Allow existing mappings in the case multiple patterns depend on the same
"already been registered"); // rewrite.
rewriteFunctions.try_emplace(name, std::move(rewriteFn));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -87,13 +87,27 @@ struct TestPDLByteCodePass
if (!patternModule || !irModule) if (!patternModule || !irModule)
return; return;
RewritePatternSet patternList(module->getContext());
// Register ahead of time to test when functions are registered without a
// pattern.
patternList.getPDLPatterns().registerConstraintFunction(
"multi_entity_constraint", customMultiEntityConstraint);
patternList.getPDLPatterns().registerConstraintFunction(
"single_entity_constraint", customSingleEntityConstraint);
// Process the pattern module. // Process the pattern module.
patternModule.getOperation()->remove(); patternModule.getOperation()->remove();
PDLPatternModule pdlPattern(patternModule); PDLPatternModule pdlPattern(patternModule);
// Note: This constraint was already registered, but we re-register here to
// ensure that duplication registration is allowed (the duplicate mapping
// will be ignored). This tests that we support separating the registration
// of library functions from the construction of patterns, and also that we
// allow multiple patterns to depend on the same library functions (without
// asserting/crashing).
pdlPattern.registerConstraintFunction("multi_entity_constraint", pdlPattern.registerConstraintFunction("multi_entity_constraint",
customMultiEntityConstraint); customMultiEntityConstraint);
pdlPattern.registerConstraintFunction("single_entity_constraint",
customSingleEntityConstraint);
pdlPattern.registerConstraintFunction("multi_entity_var_constraint", pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
customMultiEntityVariadicConstraint); customMultiEntityVariadicConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate); pdlPattern.registerRewriteFunction("creator", customCreate);
@ -101,8 +115,7 @@ struct TestPDLByteCodePass
customVariadicResultCreate); customVariadicResultCreate);
pdlPattern.registerRewriteFunction("type_creator", customCreateType); pdlPattern.registerRewriteFunction("type_creator", customCreateType);
pdlPattern.registerRewriteFunction("rewriter", customRewriter); pdlPattern.registerRewriteFunction("rewriter", customRewriter);
patternList.add(std::move(pdlPattern));
RewritePatternSet patternList(std::move(pdlPattern));
// Invoke the pattern driver with the provided patterns. // Invoke the pattern driver with the provided patterns.
(void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),