Allowing replacing non-root operations in DialectConversion.

When dealing with regions, or other patterns that need to generate temporary operations, it is useful to be able to replace other operations than the root op being matched. Before this PR, these operations would still be considered for legalization meaning that the conversion would either fail, erroneously need to mark these ops as legal, or add unnecessary patterns.

PiperOrigin-RevId: 274598513
This commit is contained in:
River Riddle 2019-10-14 09:50:54 -07:00 committed by A. Unique TensorFlower
parent 24c392f21c
commit 96de7091bc
4 changed files with 47 additions and 4 deletions

View File

@ -558,7 +558,7 @@ void ConversionPatternRewriterImpl::undoBlockActions(
case BlockActionKind::Split: {
action.originalBlock->getOperations().splice(
action.originalBlock->end(), action.block->getOperations());
action.block->dropAllUses();
action.block->dropAllDefinedValueUses();
action.block->erase();
break;
}
@ -990,6 +990,21 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
}
}
// Check all of the replacements to ensure that the pattern actually replaced
// the root operation. We also mark any other replaced ops as 'dead' so that
// we don't try to legalize them later.
bool replacedRoot = false;
for (unsigned i = curState.numReplacements,
e = rewriterImpl.replacements.size();
i != e; ++i) {
Operation *replacedOp = rewriterImpl.replacements[i].op;
if (replacedOp == op)
replacedRoot = true;
else
rewriterImpl.deadOps.insert(replacedOp);
}
assert(replacedRoot && "expected pattern to replace the root operation");
// Recursively legalize each of the new operations.
for (unsigned i = curState.numCreatedOperations,
e = rewriterImpl.createdOps.size();

View File

@ -19,6 +19,13 @@ func @dropped_region_with_illegal_ops() {
}) : () -> ()
"test.return"() : () -> ()
}
// CHECK-LABEL: func @replace_non_root_illegal_op
func @replace_non_root_illegal_op() {
// CHECK-NEXT: "test.legal_op_b"
// CHECK-NEXT: test.return
%result = "test.replace_non_root"() : () -> (i32)
"test.return"() : () -> ()
}
// -----

View File

@ -808,6 +808,7 @@ def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>;
def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
def LegalOpA : TEST_Op<"legal_op_a">,
Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
// Check that smaller pattern depths are chosen, i.e. prioritize more direct
// mappings.

View File

@ -249,6 +249,26 @@ struct TestUpdateConsumerType : public ConversionPattern {
}
};
//===----------------------------------------------------------------------===//
// Non-Root Replacement Rewrite Testing
/// This pattern generates an invalid operation, but replaces it before the
/// pattern is finished. This checks that we don't need to legalize the
/// temporary op.
struct TestNonRootReplacement : public RewritePattern {
TestNonRootReplacement(MLIRContext *ctx)
: RewritePattern("test.replace_non_root", 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
auto resultType = *op->result_type_begin();
auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
rewriter.replaceOp(illegalOp, {legalOp});
rewriter.replaceOp(op, {illegalOp});
return matchSuccess();
}
};
} // namespace
namespace {
@ -301,15 +321,15 @@ struct TestLegalizePatternDriver
.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType,
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType>(
&getContext());
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement>(&getContext());
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);
// Define the conversion target used for the test.
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addLegalOp<LegalOpA, TestCastOp, TestValidOp>();
target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
// Don't allow F32 operands.