Add support for multi-level value mapping to DialectConversion.

When performing A->B->C conversion, an operation may still refer to an operand of A. This makes it necessary to unmap through multiple levels of replacement for a specific value.

PiperOrigin-RevId: 269367859
This commit is contained in:
River Riddle 2019-09-16 10:37:48 -07:00 committed by A. Unique TensorFlower
parent 6934a337f0
commit 9619ba10d4
5 changed files with 148 additions and 14 deletions

View File

@ -32,6 +32,40 @@ using namespace mlir::detail;
#define DEBUG_TYPE "dialect-conversion"
//===----------------------------------------------------------------------===//
// Multi-Level Value Mapper
//===----------------------------------------------------------------------===//
namespace {
/// This class wraps a BlockAndValueMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
struct ConversionValueMapping {
/// Lookup a mapped value within the map. If a mapping for the provided value
/// does not exist then return the provided value.
Value *lookupOrDefault(Value *from) const;
/// Map a value to the one provided.
void map(Value *oldVal, Value *newVal) { mapping.map(oldVal, newVal); }
/// Drop the last mapping for the given value.
void erase(Value *value) { mapping.erase(value); }
private:
/// Current value mappings.
BlockAndValueMapping mapping;
};
} // end anonymous namespace
/// Lookup a mapped value within the map. If a mapping for the provided value
/// does not exist then return the provided value.
Value *ConversionValueMapping::lookupOrDefault(Value *from) const {
// If this value had a valid mapping, unmap that value as well in the case
// that it was also replaced.
while (auto *mappedValue = mapping.lookupOrNull(from))
from = mappedValue;
return from;
}
//===----------------------------------------------------------------------===//
// ArgConverter
//===----------------------------------------------------------------------===//
@ -62,19 +96,19 @@ struct ArgConverter {
bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
/// Attempt to convert the signature of the given block.
LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping);
LogicalResult convertSignature(Block *block, ConversionValueMapping &mapping);
/// Apply the given signature conversion on the given block.
void applySignatureConversion(
Block *block, TypeConverter::SignatureConversion &signatureConversion,
BlockAndValueMapping &mapping);
ConversionValueMapping &mapping);
/// Convert the given block argument given the provided set of new argument
/// values that are to replace it. This function returns the operation used
/// to perform the conversion.
Operation *convertArgument(BlockArgument *origArg,
ArrayRef<Value *> newValues,
BlockAndValueMapping &mapping);
ConversionValueMapping &mapping);
/// A utility function used to create a conversion cast operation with the
/// given input and result types.
@ -195,7 +229,7 @@ void ArgConverter::applyRewrites() {
/// Converts the signature of the given entry block.
LogicalResult ArgConverter::convertSignature(Block *block,
BlockAndValueMapping &mapping) {
ConversionValueMapping &mapping) {
if (auto conversion = typeConverter->convertBlockSignature(block))
return applySignatureConversion(block, *conversion, mapping), success();
return failure();
@ -204,7 +238,7 @@ LogicalResult ArgConverter::convertSignature(Block *block,
/// Apply the given signature conversion on the given block.
void ArgConverter::applySignatureConversion(
Block *block, TypeConverter::SignatureConversion &signatureConversion,
BlockAndValueMapping &mapping) {
ConversionValueMapping &mapping) {
unsigned origArgCount = block->getNumArguments();
auto convertedTypes = signatureConversion.getConvertedTypes();
if (origArgCount == 0 && convertedTypes.empty())
@ -236,7 +270,7 @@ void ArgConverter::applySignatureConversion(
/// to perform the conversion.
Operation *ArgConverter::convertArgument(BlockArgument *origArg,
ArrayRef<Value *> newValues,
BlockAndValueMapping &mapping) {
ConversionValueMapping &mapping) {
// Handle the cases of 1->0 or 1->1 mappings.
if (newValues.size() < 2) {
// Create a temporary producer for the argument during the conversion
@ -394,7 +428,7 @@ struct ConversionPatternRewriterImpl {
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
BlockAndValueMapping mapping;
ConversionValueMapping mapping;
/// Utility used to convert block arguments.
ArgConverter argConverter;
@ -440,6 +474,7 @@ void ConversionPatternRewriterImpl::undoBlockActions(
case BlockActionKind::Split: {
action.originalBlock->getOperations().splice(
action.originalBlock->end(), action.block->getOperations());
action.block->dropAllUses();
action.block->erase();
break;
}
@ -465,10 +500,8 @@ void ConversionPatternRewriterImpl::discardRewrites() {
undoBlockActions();
// Remove any newly created ops.
for (auto *op : createdOps) {
op->dropAllDefinedValueUses();
for (auto *op : llvm::reverse(createdOps))
op->erase();
}
}
void ConversionPatternRewriterImpl::applyRewrites() {
@ -574,6 +607,8 @@ ConversionPatternRewriter::~ConversionPatternRewriter() {}
void ConversionPatternRewriter::replaceOp(
Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead) {
LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName()
<< "\n");
impl->replaceOp(op, newValues, valuesToRemoveIfDead);
}
@ -609,6 +644,7 @@ void ConversionPatternRewriter::inlineRegionBefore(Region &region,
/// PatternRewriter hook for creating a new operation.
Operation *
ConversionPatternRewriter::createOperation(const OperationState &state) {
LLVM_DEBUG(llvm::dbgs() << "** Creating operation : " << state.name << "\n");
auto *result = OpBuilder::createOperation(state);
impl->createdOps.push_back(result);
return result;

View File

@ -1,4 +1,6 @@
// RUN: mlir-opt -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s
// expected-remark@-2 {{op 'module' is legalizable}}
// expected-remark@-3 {{op 'module_terminator' is legalizable}}
// expected-remark@+1 {{op 'func' is legalizable}}
func @test(%arg0: f32) {

View File

@ -0,0 +1,10 @@
// RUN: mlir-opt -test-legalize-patterns -test-legalize-mode=full %s | FileCheck %s
// CHECK-LABEL: func @multi_level_mapping
func @multi_level_mapping() {
// CHECK: "test.type_producer"() : () -> f64
// CHECK: "test.type_consumer"(%{{.*}}) : (f64) -> ()
%result = "test.type_producer"() : () -> i32
"test.type_consumer"(%result) : (i32) -> ()
"test.return"() : () -> ()
}

View File

@ -760,6 +760,10 @@ def TestCastOp : TEST_Op<"cast">,
Arguments<(ins Variadic<AnyType>:$inputs)>, Results<(outs AnyType:$res)>;
def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
Arguments<(ins Variadic<AnyType>:$inputs)>;
def TestTypeProducerOp : TEST_Op<"type_producer">,
Results<(outs AnyType:$output)>;
def TestTypeConsumerOp : TEST_Op<"type_consumer">,
Arguments<(ins AnyType:$input)>;
def TestValidOp : TEST_Op<"valid", [Terminator]>,
Arguments<(ins Variadic<AnyType>:$inputs)>;

View File

@ -56,6 +56,9 @@ static mlir::PassRegistration<TestPatternDriver>
//===----------------------------------------------------------------------===//
namespace {
//===----------------------------------------------------------------------===//
// Region-Block Rewrite Testing
/// This pattern is a simple pattern that inlines the first region of a given
/// operation into the parent region.
struct TestRegionRewriteBlockMovement : public ConversionPattern {
@ -99,6 +102,10 @@ struct TestRegionRewriteUndo : public RewritePattern {
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// Type-Conversion Rewrite Testing
/// This pattern simply erases the given operation.
struct TestDropOp : public ConversionPattern {
TestDropOp(MLIRContext *ctx) : ConversionPattern("test.drop_op", 1, ctx) {}
@ -145,6 +152,62 @@ struct TestSplitReturnType : public ConversionPattern {
return matchFailure();
}
};
//===----------------------------------------------------------------------===//
// Multi-Level Type-Conversion Rewrite Testing
struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is I32, change the type to F32.
if (!(*op->result_type_begin()).isInteger(32))
return matchFailure();
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
return matchSuccess();
}
};
struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is F32, change the type to F64.
if (!(*op->result_type_begin()).isF32())
return matchFailure();
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
return matchSuccess();
}
};
struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 10, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
// Always convert to B16, even though it is not a legal type. This tests
// that values are unmapped correctly.
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
return matchSuccess();
}
};
struct TestUpdateConsumerType : public ConversionPattern {
TestUpdateConsumerType(MLIRContext *ctx)
: ConversionPattern("test.type_consumer", 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
// Verify that the the incoming operand has been successfully remapped to
// F64.
if (!operands[0]->getType().isF64())
return matchFailure();
rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
return matchSuccess();
}
};
} // namespace
namespace {
@ -185,7 +248,7 @@ struct TestTypeConverter : public TypeConverter {
struct TestLegalizePatternDriver
: public ModulePass<TestLegalizePatternDriver> {
/// The mode of conversion to use with the driver.
enum class ConversionMode { Analysis, Partial };
enum class ConversionMode { Analysis, Full, Partial };
TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
@ -193,14 +256,18 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>(
&getContext());
patterns
.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType,
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType>(
&getContext());
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);
// Define the conversion target used for the test.
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addLegalOp<LegalOpA, TestCastOp, TestValidOp>();
target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
@ -211,12 +278,25 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
// Expect the type_producer/type_consumer operations to only operate on f64.
target.addDynamicallyLegalOp<TestTypeProducerOp>(
[](TestTypeProducerOp op) { return op.getType().isF64(); });
target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
return op.getOperand()->getType().isF64();
});
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
(void)applyPartialConversion(getModule(), target, patterns, &converter);
return;
}
// Handle a full conversion.
if (mode == ConversionMode::Full) {
(void)applyFullConversion(getModule(), target, patterns, &converter);
return;
}
// Otherwise, handle an analysis conversion.
assert(mode == ConversionMode::Analysis);
@ -244,6 +324,8 @@ static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
llvm::cl::values(
clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
"analysis", "Perform an analysis conversion"),
clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
"Perform a full conversion"),
clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
"partial", "Perform a partial conversion")));