forked from OSchip/llvm-project
[mlir:PDL] Fix bugs in PDLPatternModule merging
* Constraints/Rewrites registered before a pattern was added were dropped * Constraints/Rewrites may be registered multiple times (if different pattern sets depend on them) * ModuleOp no longer has a terminator, so we shouldn't be removing the terminator from it Differential Revision: https://reviews.llvm.org/D114816
This commit is contained in:
parent
98f5bd3489
commit
06c3b9c7be
|
@ -157,22 +157,21 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
|
||||||
// Ignore the other module if it has no patterns.
|
// Ignore the other module if it has no patterns.
|
||||||
if (!other.pdlModule)
|
if (!other.pdlModule)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
// Steal the functions of the other module.
|
||||||
|
for (auto &it : other.constraintFunctions)
|
||||||
|
registerConstraintFunction(it.first(), std::move(it.second));
|
||||||
|
for (auto &it : other.rewriteFunctions)
|
||||||
|
registerRewriteFunction(it.first(), std::move(it.second));
|
||||||
|
|
||||||
// Steal the other state if we have no patterns.
|
// Steal the other state if we have no patterns.
|
||||||
if (!pdlModule) {
|
if (!pdlModule) {
|
||||||
constraintFunctions = std::move(other.constraintFunctions);
|
|
||||||
rewriteFunctions = std::move(other.rewriteFunctions);
|
|
||||||
pdlModule = std::move(other.pdlModule);
|
pdlModule = std::move(other.pdlModule);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Steal the functions of the other module.
|
|
||||||
for (auto &it : constraintFunctions)
|
|
||||||
registerConstraintFunction(it.first(), std::move(it.second));
|
|
||||||
for (auto &it : rewriteFunctions)
|
|
||||||
registerRewriteFunction(it.first(), std::move(it.second));
|
|
||||||
|
|
||||||
// Merge the pattern operations from the other module into this one.
|
// Merge the pattern operations from the other module into this one.
|
||||||
Block *block = pdlModule->getBody();
|
Block *block = pdlModule->getBody();
|
||||||
block->getTerminator()->erase();
|
|
||||||
block->getOperations().splice(block->end(),
|
block->getOperations().splice(block->end(),
|
||||||
other.pdlModule->getBody()->getOperations());
|
other.pdlModule->getBody()->getOperations());
|
||||||
}
|
}
|
||||||
|
@ -182,18 +181,20 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
|
||||||
|
|
||||||
void PDLPatternModule::registerConstraintFunction(
|
void PDLPatternModule::registerConstraintFunction(
|
||||||
StringRef name, PDLConstraintFunction constraintFn) {
|
StringRef name, PDLConstraintFunction constraintFn) {
|
||||||
auto it = constraintFunctions.try_emplace(name, std::move(constraintFn));
|
// TODO: Is it possible to diagnose when `name` is already registered to
|
||||||
(void)it;
|
// a function that is not equivalent to `constraintFn`?
|
||||||
assert(it.second &&
|
// Allow existing mappings in the case multiple patterns depend on the same
|
||||||
"constraint with the given name has already been registered");
|
// constraint.
|
||||||
|
constraintFunctions.try_emplace(name, std::move(constraintFn));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PDLPatternModule::registerRewriteFunction(StringRef name,
|
void PDLPatternModule::registerRewriteFunction(StringRef name,
|
||||||
PDLRewriteFunction rewriteFn) {
|
PDLRewriteFunction rewriteFn) {
|
||||||
auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
|
// TODO: Is it possible to diagnose when `name` is already registered to
|
||||||
(void)it;
|
// a function that is not equivalent to `rewriteFn`?
|
||||||
assert(it.second && "native rewrite function with the given name has "
|
// Allow existing mappings in the case multiple patterns depend on the same
|
||||||
"already been registered");
|
// rewrite.
|
||||||
|
rewriteFunctions.try_emplace(name, std::move(rewriteFn));
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -87,13 +87,27 @@ struct TestPDLByteCodePass
|
||||||
if (!patternModule || !irModule)
|
if (!patternModule || !irModule)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
RewritePatternSet patternList(module->getContext());
|
||||||
|
|
||||||
|
// Register ahead of time to test when functions are registered without a
|
||||||
|
// pattern.
|
||||||
|
patternList.getPDLPatterns().registerConstraintFunction(
|
||||||
|
"multi_entity_constraint", customMultiEntityConstraint);
|
||||||
|
patternList.getPDLPatterns().registerConstraintFunction(
|
||||||
|
"single_entity_constraint", customSingleEntityConstraint);
|
||||||
|
|
||||||
// Process the pattern module.
|
// Process the pattern module.
|
||||||
patternModule.getOperation()->remove();
|
patternModule.getOperation()->remove();
|
||||||
PDLPatternModule pdlPattern(patternModule);
|
PDLPatternModule pdlPattern(patternModule);
|
||||||
|
|
||||||
|
// Note: This constraint was already registered, but we re-register here to
|
||||||
|
// ensure that duplication registration is allowed (the duplicate mapping
|
||||||
|
// will be ignored). This tests that we support separating the registration
|
||||||
|
// of library functions from the construction of patterns, and also that we
|
||||||
|
// allow multiple patterns to depend on the same library functions (without
|
||||||
|
// asserting/crashing).
|
||||||
pdlPattern.registerConstraintFunction("multi_entity_constraint",
|
pdlPattern.registerConstraintFunction("multi_entity_constraint",
|
||||||
customMultiEntityConstraint);
|
customMultiEntityConstraint);
|
||||||
pdlPattern.registerConstraintFunction("single_entity_constraint",
|
|
||||||
customSingleEntityConstraint);
|
|
||||||
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
|
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
|
||||||
customMultiEntityVariadicConstraint);
|
customMultiEntityVariadicConstraint);
|
||||||
pdlPattern.registerRewriteFunction("creator", customCreate);
|
pdlPattern.registerRewriteFunction("creator", customCreate);
|
||||||
|
@ -101,8 +115,7 @@ struct TestPDLByteCodePass
|
||||||
customVariadicResultCreate);
|
customVariadicResultCreate);
|
||||||
pdlPattern.registerRewriteFunction("type_creator", customCreateType);
|
pdlPattern.registerRewriteFunction("type_creator", customCreateType);
|
||||||
pdlPattern.registerRewriteFunction("rewriter", customRewriter);
|
pdlPattern.registerRewriteFunction("rewriter", customRewriter);
|
||||||
|
patternList.add(std::move(pdlPattern));
|
||||||
RewritePatternSet patternList(std::move(pdlPattern));
|
|
||||||
|
|
||||||
// Invoke the pattern driver with the provided patterns.
|
// Invoke the pattern driver with the provided patterns.
|
||||||
(void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
|
(void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
|
||||||
|
|
Loading…
Reference in New Issue