forked from OSchip/llvm-project
Allowing replacing non-root operations in DialectConversion.
When dealing with regions, or other patterns that need to generate temporary operations, it is useful to be able to replace other operations than the root op being matched. Before this PR, these operations would still be considered for legalization meaning that the conversion would either fail, erroneously need to mark these ops as legal, or add unnecessary patterns. PiperOrigin-RevId: 274598513
This commit is contained in:
parent
24c392f21c
commit
96de7091bc
|
@ -558,7 +558,7 @@ void ConversionPatternRewriterImpl::undoBlockActions(
|
|||
case BlockActionKind::Split: {
|
||||
action.originalBlock->getOperations().splice(
|
||||
action.originalBlock->end(), action.block->getOperations());
|
||||
action.block->dropAllUses();
|
||||
action.block->dropAllDefinedValueUses();
|
||||
action.block->erase();
|
||||
break;
|
||||
}
|
||||
|
@ -990,6 +990,21 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
|
|||
}
|
||||
}
|
||||
|
||||
// Check all of the replacements to ensure that the pattern actually replaced
|
||||
// the root operation. We also mark any other replaced ops as 'dead' so that
|
||||
// we don't try to legalize them later.
|
||||
bool replacedRoot = false;
|
||||
for (unsigned i = curState.numReplacements,
|
||||
e = rewriterImpl.replacements.size();
|
||||
i != e; ++i) {
|
||||
Operation *replacedOp = rewriterImpl.replacements[i].op;
|
||||
if (replacedOp == op)
|
||||
replacedRoot = true;
|
||||
else
|
||||
rewriterImpl.deadOps.insert(replacedOp);
|
||||
}
|
||||
assert(replacedRoot && "expected pattern to replace the root operation");
|
||||
|
||||
// Recursively legalize each of the new operations.
|
||||
for (unsigned i = curState.numCreatedOperations,
|
||||
e = rewriterImpl.createdOps.size();
|
||||
|
|
|
@ -19,6 +19,13 @@ func @dropped_region_with_illegal_ops() {
|
|||
}) : () -> ()
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
// CHECK-LABEL: func @replace_non_root_illegal_op
|
||||
func @replace_non_root_illegal_op() {
|
||||
// CHECK-NEXT: "test.legal_op_b"
|
||||
// CHECK-NEXT: test.return
|
||||
%result = "test.replace_non_root"() : () -> (i32)
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
|
|
@ -808,6 +808,7 @@ def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>;
|
|||
def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
|
||||
def LegalOpA : TEST_Op<"legal_op_a">,
|
||||
Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
|
||||
def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
|
||||
|
||||
// Check that smaller pattern depths are chosen, i.e. prioritize more direct
|
||||
// mappings.
|
||||
|
|
|
@ -249,6 +249,26 @@ struct TestUpdateConsumerType : public ConversionPattern {
|
|||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Non-Root Replacement Rewrite Testing
|
||||
/// This pattern generates an invalid operation, but replaces it before the
|
||||
/// pattern is finished. This checks that we don't need to legalize the
|
||||
/// temporary op.
|
||||
struct TestNonRootReplacement : public RewritePattern {
|
||||
TestNonRootReplacement(MLIRContext *ctx)
|
||||
: RewritePattern("test.replace_non_root", 1, ctx) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto resultType = *op->result_type_begin();
|
||||
auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
|
||||
auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
|
||||
|
||||
rewriter.replaceOp(illegalOp, {legalOp});
|
||||
rewriter.replaceOp(op, {illegalOp});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
@ -301,15 +321,15 @@ struct TestLegalizePatternDriver
|
|||
.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
|
||||
TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType,
|
||||
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
|
||||
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType>(
|
||||
&getContext());
|
||||
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
|
||||
TestNonRootReplacement>(&getContext());
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
|
||||
converter);
|
||||
|
||||
// Define the conversion target used for the test.
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addLegalOp<LegalOpA, TestCastOp, TestValidOp>();
|
||||
target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
|
||||
target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
|
||||
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
|
||||
// Don't allow F32 operands.
|
||||
|
|
Loading…
Reference in New Issue