[mlir:PDL] Fix bugs in PDLPatternModule merging

* Constraints/Rewrites registered before a pattern was added were dropped
* Constraints/Rewrites may be registered multiple times (if different pattern sets depend on them)
* ModuleOp no longer has a terminator, so we shouldn't be removing the terminator from it

Differential Revision: https://reviews.llvm.org/D114816
This commit is contained in:
River Riddle 2021-12-10 19:36:07 +00:00
parent 98f5bd3489
commit 06c3b9c7be
2 changed files with 34 additions and 20 deletions

View File

@ -157,22 +157,21 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
// Ignore the other module if it has no patterns.
if (!other.pdlModule)
return;
// Steal the functions of the other module.
for (auto &it : other.constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : other.rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));
// Steal the other state if we have no patterns.
if (!pdlModule) {
constraintFunctions = std::move(other.constraintFunctions);
rewriteFunctions = std::move(other.rewriteFunctions);
pdlModule = std::move(other.pdlModule);
return;
}
// Steal the functions of the other module.
for (auto &it : constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));
// Merge the pattern operations from the other module into this one.
Block *block = pdlModule->getBody();
block->getTerminator()->erase();
block->getOperations().splice(block->end(),
other.pdlModule->getBody()->getOperations());
}
@ -182,18 +181,20 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
void PDLPatternModule::registerConstraintFunction(
StringRef name, PDLConstraintFunction constraintFn) {
auto it = constraintFunctions.try_emplace(name, std::move(constraintFn));
(void)it;
assert(it.second &&
"constraint with the given name has already been registered");
// TODO: Is it possible to diagnose when `name` is already registered to
// a function that is not equivalent to `constraintFn`?
// Allow existing mappings in the case multiple patterns depend on the same
// constraint.
constraintFunctions.try_emplace(name, std::move(constraintFn));
}
void PDLPatternModule::registerRewriteFunction(StringRef name,
PDLRewriteFunction rewriteFn) {
auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
(void)it;
assert(it.second && "native rewrite function with the given name has "
"already been registered");
// TODO: Is it possible to diagnose when `name` is already registered to
// a function that is not equivalent to `rewriteFn`?
// Allow existing mappings in the case multiple patterns depend on the same
// rewrite.
rewriteFunctions.try_emplace(name, std::move(rewriteFn));
}
//===----------------------------------------------------------------------===//

View File

@ -87,13 +87,27 @@ struct TestPDLByteCodePass
if (!patternModule || !irModule)
return;
RewritePatternSet patternList(module->getContext());
// Register ahead of time to test when functions are registered without a
// pattern.
patternList.getPDLPatterns().registerConstraintFunction(
"multi_entity_constraint", customMultiEntityConstraint);
patternList.getPDLPatterns().registerConstraintFunction(
"single_entity_constraint", customSingleEntityConstraint);
// Process the pattern module.
patternModule.getOperation()->remove();
PDLPatternModule pdlPattern(patternModule);
// Note: This constraint was already registered, but we re-register here to
// ensure that duplication registration is allowed (the duplicate mapping
// will be ignored). This tests that we support separating the registration
// of library functions from the construction of patterns, and also that we
// allow multiple patterns to depend on the same library functions (without
// asserting/crashing).
pdlPattern.registerConstraintFunction("multi_entity_constraint",
customMultiEntityConstraint);
pdlPattern.registerConstraintFunction("single_entity_constraint",
customSingleEntityConstraint);
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
customMultiEntityVariadicConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate);
@ -101,8 +115,7 @@ struct TestPDLByteCodePass
customVariadicResultCreate);
pdlPattern.registerRewriteFunction("type_creator", customCreateType);
pdlPattern.registerRewriteFunction("rewriter", customRewriter);
RewritePatternSet patternList(std::move(pdlPattern));
patternList.add(std::move(pdlPattern));
// Invoke the pattern driver with the provided patterns.
(void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),