diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 2deb0c9c048d..88669505e233 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -338,6 +338,10 @@ public: return cast(cloneWithoutRegions(op.getOperation())); } + /// Return the converted value that replaces 'key'. Return 'key' if there is + /// no such a converted value. + Value *getRemappedValue(Value *key); + //===--------------------------------------------------------------------===// // PatternRewriter Hooks //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index a2065f16a213..7931932a7894 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -803,6 +803,12 @@ Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) { return newOp; } +/// Return the converted value that replaces 'key'. Return 'key' if there is +/// no such a converted value. +Value *ConversionPatternRewriter::getRemappedValue(Value *key) { + return impl->mapping.lookupOrDefault(key); +} + /// PatternRewriter hook for splitting a block into two parts. Block *ConversionPatternRewriter::splitBlock(Block *block, Block::iterator before) { diff --git a/mlir/test/Transforms/test-legalize-remapped-value.mlir b/mlir/test/Transforms/test-legalize-remapped-value.mlir new file mode 100644 index 000000000000..ff571c93f938 --- /dev/null +++ b/mlir/test/Transforms/test-legalize-remapped-value.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s -test-remapped-value | FileCheck %s + +// Simple test that exercises ConvertPatternRewriter::getRemappedValue. +func @remap_input_1_to_1(%arg0: i32) { + %0 = "test.one_variadic_out_one_variadic_in1"(%arg0) : (i32) -> i32 + %1 = "test.one_variadic_out_one_variadic_in1"(%0) : (i32) -> i32 + "test.return"() : () -> () +} +// CHECK-LABEL: func @remap_input_1_to_1 +// CHECK-SAME: (%[[ARG:.*]]: i32) +// CHECK-NEXT: %[[VAL:.*]] = "test.one_variadic_out_one_variadic_in1"(%[[ARG]], %[[ARG]]) +// CHECK-NEXT: "test.one_variadic_out_one_variadic_in1"(%[[VAL]], %[[VAL]]) + diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 936d76329679..5ef03606dbec 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -435,3 +435,66 @@ static mlir::PassRegistration return std::make_unique( legalizerConversionMode); }); + +//===----------------------------------------------------------------------===// +// ConversionPatternRewriter::getRemappedValue testing. This method is used +// to get the remapped value of a original value that was replaced using +// ConversionPatternRewriter. +namespace { +/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with +/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original +/// operand twice. +/// +/// Example: +/// %1 = test.one_variadic_out_one_variadic_in1"(%0) +/// is replaced with: +/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) +struct OneVResOneVOperandOp1Converter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult + matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto origOps = op.getOperands(); + assert(std::distance(origOps.begin(), origOps.end()) == 1 && + "One operand expected"); + Value *origOp = *origOps.begin(); + SmallVector remappedOperands; + // Replicate the remapped original operand twice. Note that we don't used + // the remapped 'operand' since the goal is testing 'getRemappedValue'. + remappedOperands.push_back(rewriter.getRemappedValue(origOp)); + remappedOperands.push_back(rewriter.getRemappedValue(origOp)); + + SmallVector resultTypes(op.getResultTypes()); + rewriter.replaceOpWithNewOp(op, resultTypes, + remappedOperands); + return matchSuccess(); + } +}; + +struct TestRemappedValue : public mlir::FunctionPass { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + mlir::ConversionTarget target(getContext()); + target.addLegalOp(); + // We make OneVResOneVOperandOp1 legal only when it has more that one + // operand. This will trigger the conversion that will replace one-operand + // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. + target.addDynamicallyLegalOp( + [](Operation *op) -> bool { + return std::distance(op->operand_begin(), op->operand_end()) > 1; + }); + + if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { + signalPassFailure(); + } + } +}; +} // end anonymous namespace + +static PassRegistration remapped_value_pass( + "test-remapped-value", + "Test public remapped value mechanism in ConversionPatternRewriter");