diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 42663685d990..0007feb4ccd3 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -558,7 +558,7 @@ void ConversionPatternRewriterImpl::undoBlockActions( case BlockActionKind::Split: { action.originalBlock->getOperations().splice( action.originalBlock->end(), action.block->getOperations()); - action.block->dropAllUses(); + action.block->dropAllDefinedValueUses(); action.block->erase(); break; } @@ -990,6 +990,21 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, } } + // Check all of the replacements to ensure that the pattern actually replaced + // the root operation. We also mark any other replaced ops as 'dead' so that + // we don't try to legalize them later. + bool replacedRoot = false; + for (unsigned i = curState.numReplacements, + e = rewriterImpl.replacements.size(); + i != e; ++i) { + Operation *replacedOp = rewriterImpl.replacements[i].op; + if (replacedOp == op) + replacedRoot = true; + else + rewriterImpl.deadOps.insert(replacedOp); + } + assert(replacedRoot && "expected pattern to replace the root operation"); + // Recursively legalize each of the new operations. for (unsigned i = curState.numCreatedOperations, e = rewriterImpl.createdOps.size(); diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index 79494c798e1f..2cf981b0db9d 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -19,6 +19,13 @@ func @dropped_region_with_illegal_ops() { }) : () -> () "test.return"() : () -> () } +// CHECK-LABEL: func @replace_non_root_illegal_op +func @replace_non_root_illegal_op() { + // CHECK-NEXT: "test.legal_op_b" + // CHECK-NEXT: test.return + %result = "test.replace_non_root"() : () -> (i32) + "test.return"() : () -> () +} // ----- diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index b26360f3d2cc..73769e72a4a3 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -808,6 +808,7 @@ def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>; def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>; def LegalOpA : TEST_Op<"legal_op_a">, Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>; +def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>; // Check that smaller pattern depths are chosen, i.e. prioritize more direct // mappings. diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 83814eed11c0..2dde6a37675f 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -249,6 +249,26 @@ struct TestUpdateConsumerType : public ConversionPattern { } }; +//===----------------------------------------------------------------------===// +// Non-Root Replacement Rewrite Testing +/// This pattern generates an invalid operation, but replaces it before the +/// pattern is finished. This checks that we don't need to legalize the +/// temporary op. +struct TestNonRootReplacement : public RewritePattern { + TestNonRootReplacement(MLIRContext *ctx) + : RewritePattern("test.replace_non_root", 1, ctx) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + auto resultType = *op->result_type_begin(); + auto illegalOp = rewriter.create(op->getLoc(), resultType); + auto legalOp = rewriter.create(op->getLoc(), resultType); + + rewriter.replaceOp(illegalOp, {legalOp}); + rewriter.replaceOp(op, {illegalOp}); + return matchSuccess(); + } +}; } // namespace namespace { @@ -301,15 +321,15 @@ struct TestLegalizePatternDriver .insert( - &getContext()); + TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, + TestNonRootReplacement>(&getContext()); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); // Define the conversion target used for the test. ConversionTarget target(getContext()); target.addLegalOp(); - target.addLegalOp(); + target.addLegalOp(); target.addIllegalOp(); target.addDynamicallyLegalOp([](TestReturnOp op) { // Don't allow F32 operands.