forked from OSchip/llvm-project
Add getRemappedValue to ConversionPatternRewriter
This method is needed for N->1 conversion patterns to retrieve remapped Values used in the original N operations. Closes tensorflow/mlir#237 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/237 from dcaballe:dcaballe/getRemappedValue 1f64fadcf2b203f7b336ff0c5838b116ae3625db PiperOrigin-RevId: 281321881
This commit is contained in:
parent
06fb797b40
commit
dd5a7cb488
|
@ -338,6 +338,10 @@ public:
|
|||
return cast<OpT>(cloneWithoutRegions(op.getOperation()));
|
||||
}
|
||||
|
||||
/// Return the converted value that replaces 'key'. Return 'key' if there is
|
||||
/// no such a converted value.
|
||||
Value *getRemappedValue(Value *key);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// PatternRewriter Hooks
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
|
|
@ -803,6 +803,12 @@ Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
|
|||
return newOp;
|
||||
}
|
||||
|
||||
/// Return the converted value that replaces 'key'. Return 'key' if there is
|
||||
/// no such a converted value.
|
||||
Value *ConversionPatternRewriter::getRemappedValue(Value *key) {
|
||||
return impl->mapping.lookupOrDefault(key);
|
||||
}
|
||||
|
||||
/// PatternRewriter hook for splitting a block into two parts.
|
||||
Block *ConversionPatternRewriter::splitBlock(Block *block,
|
||||
Block::iterator before) {
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
// RUN: mlir-opt %s -test-remapped-value | FileCheck %s
|
||||
|
||||
// Simple test that exercises ConvertPatternRewriter::getRemappedValue.
|
||||
func @remap_input_1_to_1(%arg0: i32) {
|
||||
%0 = "test.one_variadic_out_one_variadic_in1"(%arg0) : (i32) -> i32
|
||||
%1 = "test.one_variadic_out_one_variadic_in1"(%0) : (i32) -> i32
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
// CHECK-LABEL: func @remap_input_1_to_1
|
||||
// CHECK-SAME: (%[[ARG:.*]]: i32)
|
||||
// CHECK-NEXT: %[[VAL:.*]] = "test.one_variadic_out_one_variadic_in1"(%[[ARG]], %[[ARG]])
|
||||
// CHECK-NEXT: "test.one_variadic_out_one_variadic_in1"(%[[VAL]], %[[VAL]])
|
||||
|
|
@ -435,3 +435,66 @@ static mlir::PassRegistration<TestLegalizePatternDriver>
|
|||
return std::make_unique<TestLegalizePatternDriver>(
|
||||
legalizerConversionMode);
|
||||
});
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConversionPatternRewriter::getRemappedValue testing. This method is used
|
||||
// to get the remapped value of a original value that was replaced using
|
||||
// ConversionPatternRewriter.
|
||||
namespace {
|
||||
/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
|
||||
/// a one-operand two-result OneVResOneVOperandOp1 by replicating its original
|
||||
/// operand twice.
|
||||
///
|
||||
/// Example:
|
||||
/// %1 = test.one_variadic_out_one_variadic_in1"(%0)
|
||||
/// is replaced with:
|
||||
/// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0)
|
||||
struct OneVResOneVOperandOp1Converter
|
||||
: public OpConversionPattern<OneVResOneVOperandOp1> {
|
||||
using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto origOps = op.getOperands();
|
||||
assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
|
||||
"One operand expected");
|
||||
Value *origOp = *origOps.begin();
|
||||
SmallVector<Value *, 2> remappedOperands;
|
||||
// Replicate the remapped original operand twice. Note that we don't used
|
||||
// the remapped 'operand' since the goal is testing 'getRemappedValue'.
|
||||
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
|
||||
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
|
||||
|
||||
SmallVector<Type, 1> resultTypes(op.getResultTypes());
|
||||
rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes,
|
||||
remappedOperands);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
struct TestRemappedValue : public mlir::FunctionPass<TestRemappedValue> {
|
||||
void runOnFunction() override {
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
|
||||
// We make OneVResOneVOperandOp1 legal only when it has more that one
|
||||
// operand. This will trigger the conversion that will replace one-operand
|
||||
// OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
|
||||
target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
|
||||
[](Operation *op) -> bool {
|
||||
return std::distance(op->operand_begin(), op->operand_end()) > 1;
|
||||
});
|
||||
|
||||
if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
static PassRegistration<TestRemappedValue> remapped_value_pass(
|
||||
"test-remapped-value",
|
||||
"Test public remapped value mechanism in ConversionPatternRewriter");
|
||||
|
|
Loading…
Reference in New Issue