Add getRemappedValue to ConversionPatternRewriter

This method is needed for N->1 conversion patterns to retrieve remapped
Values used in the original N operations.

Closes tensorflow/mlir#237

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/237 from dcaballe:dcaballe/getRemappedValue 1f64fadcf2b203f7b336ff0c5838b116ae3625db
PiperOrigin-RevId: 281321881
This commit is contained in:
Diego Caballero 2019-11-19 10:15:36 -08:00 committed by A. Unique TensorFlower
parent 06fb797b40
commit dd5a7cb488
4 changed files with 86 additions and 0 deletions

View File

@ -338,6 +338,10 @@ public:
return cast<OpT>(cloneWithoutRegions(op.getOperation()));
}
/// Return the converted value that replaces 'key'. Return 'key' if there is
/// no such a converted value.
Value *getRemappedValue(Value *key);
//===--------------------------------------------------------------------===//
// PatternRewriter Hooks
//===--------------------------------------------------------------------===//

View File

@ -803,6 +803,12 @@ Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
return newOp;
}
/// Return the converted value that replaces 'key'. Return 'key' if there is
/// no such a converted value.
Value *ConversionPatternRewriter::getRemappedValue(Value *key) {
return impl->mapping.lookupOrDefault(key);
}
/// PatternRewriter hook for splitting a block into two parts.
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {

View File

@ -0,0 +1,13 @@
// RUN: mlir-opt %s -test-remapped-value | FileCheck %s
// Simple test that exercises ConvertPatternRewriter::getRemappedValue.
func @remap_input_1_to_1(%arg0: i32) {
%0 = "test.one_variadic_out_one_variadic_in1"(%arg0) : (i32) -> i32
%1 = "test.one_variadic_out_one_variadic_in1"(%0) : (i32) -> i32
"test.return"() : () -> ()
}
// CHECK-LABEL: func @remap_input_1_to_1
// CHECK-SAME: (%[[ARG:.*]]: i32)
// CHECK-NEXT: %[[VAL:.*]] = "test.one_variadic_out_one_variadic_in1"(%[[ARG]], %[[ARG]])
// CHECK-NEXT: "test.one_variadic_out_one_variadic_in1"(%[[VAL]], %[[VAL]])

View File

@ -435,3 +435,66 @@ static mlir::PassRegistration<TestLegalizePatternDriver>
return std::make_unique<TestLegalizePatternDriver>(
legalizerConversionMode);
});
//===----------------------------------------------------------------------===//
// ConversionPatternRewriter::getRemappedValue testing. This method is used
// to get the remapped value of a original value that was replaced using
// ConversionPatternRewriter.
namespace {
/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
/// operand twice.
///
/// Example:
/// %1 = test.one_variadic_out_one_variadic_in1"(%0)
/// is replaced with:
/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
struct OneVResOneVOperandOp1Converter
: public OpConversionPattern<OneVResOneVOperandOp1> {
using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
PatternMatchResult
matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto origOps = op.getOperands();
assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
"One operand expected");
Value *origOp = *origOps.begin();
SmallVector<Value *, 2> remappedOperands;
// Replicate the remapped original operand twice. Note that we don't used
// the remapped 'operand' since the goal is testing 'getRemappedValue'.
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
SmallVector<Type, 1> resultTypes(op.getResultTypes());
rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes,
remappedOperands);
return matchSuccess();
}
};
struct TestRemappedValue : public mlir::FunctionPass<TestRemappedValue> {
void runOnFunction() override {
mlir::OwningRewritePatternList patterns;
patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
mlir::ConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
// We make OneVResOneVOperandOp1 legal only when it has more that one
// operand. This will trigger the conversion that will replace one-operand
// OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
[](Operation *op) -> bool {
return std::distance(op->operand_begin(), op->operand_end()) > 1;
});
if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) {
signalPassFailure();
}
}
};
} // end anonymous namespace
static PassRegistration<TestRemappedValue> remapped_value_pass(
"test-remapped-value",
"Test public remapped value mechanism in ConversionPatternRewriter");