forked from OSchip/llvm-project
Add support for multi-level value mapping to DialectConversion.
When performing A->B->C conversion, an operation may still refer to an operand of A. This makes it necessary to unmap through multiple levels of replacement for a specific value. PiperOrigin-RevId: 269367859
This commit is contained in:
parent
6934a337f0
commit
9619ba10d4
|
@ -32,6 +32,40 @@ using namespace mlir::detail;
|
|||
|
||||
#define DEBUG_TYPE "dialect-conversion"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Multi-Level Value Mapper
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class wraps a BlockAndValueMapping to provide recursive lookup
|
||||
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
|
||||
struct ConversionValueMapping {
|
||||
/// Lookup a mapped value within the map. If a mapping for the provided value
|
||||
/// does not exist then return the provided value.
|
||||
Value *lookupOrDefault(Value *from) const;
|
||||
|
||||
/// Map a value to the one provided.
|
||||
void map(Value *oldVal, Value *newVal) { mapping.map(oldVal, newVal); }
|
||||
|
||||
/// Drop the last mapping for the given value.
|
||||
void erase(Value *value) { mapping.erase(value); }
|
||||
|
||||
private:
|
||||
/// Current value mappings.
|
||||
BlockAndValueMapping mapping;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Lookup a mapped value within the map. If a mapping for the provided value
|
||||
/// does not exist then return the provided value.
|
||||
Value *ConversionValueMapping::lookupOrDefault(Value *from) const {
|
||||
// If this value had a valid mapping, unmap that value as well in the case
|
||||
// that it was also replaced.
|
||||
while (auto *mappedValue = mapping.lookupOrNull(from))
|
||||
from = mappedValue;
|
||||
return from;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArgConverter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -62,19 +96,19 @@ struct ArgConverter {
|
|||
bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
|
||||
|
||||
/// Attempt to convert the signature of the given block.
|
||||
LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping);
|
||||
LogicalResult convertSignature(Block *block, ConversionValueMapping &mapping);
|
||||
|
||||
/// Apply the given signature conversion on the given block.
|
||||
void applySignatureConversion(
|
||||
Block *block, TypeConverter::SignatureConversion &signatureConversion,
|
||||
BlockAndValueMapping &mapping);
|
||||
ConversionValueMapping &mapping);
|
||||
|
||||
/// Convert the given block argument given the provided set of new argument
|
||||
/// values that are to replace it. This function returns the operation used
|
||||
/// to perform the conversion.
|
||||
Operation *convertArgument(BlockArgument *origArg,
|
||||
ArrayRef<Value *> newValues,
|
||||
BlockAndValueMapping &mapping);
|
||||
ConversionValueMapping &mapping);
|
||||
|
||||
/// A utility function used to create a conversion cast operation with the
|
||||
/// given input and result types.
|
||||
|
@ -195,7 +229,7 @@ void ArgConverter::applyRewrites() {
|
|||
|
||||
/// Converts the signature of the given entry block.
|
||||
LogicalResult ArgConverter::convertSignature(Block *block,
|
||||
BlockAndValueMapping &mapping) {
|
||||
ConversionValueMapping &mapping) {
|
||||
if (auto conversion = typeConverter->convertBlockSignature(block))
|
||||
return applySignatureConversion(block, *conversion, mapping), success();
|
||||
return failure();
|
||||
|
@ -204,7 +238,7 @@ LogicalResult ArgConverter::convertSignature(Block *block,
|
|||
/// Apply the given signature conversion on the given block.
|
||||
void ArgConverter::applySignatureConversion(
|
||||
Block *block, TypeConverter::SignatureConversion &signatureConversion,
|
||||
BlockAndValueMapping &mapping) {
|
||||
ConversionValueMapping &mapping) {
|
||||
unsigned origArgCount = block->getNumArguments();
|
||||
auto convertedTypes = signatureConversion.getConvertedTypes();
|
||||
if (origArgCount == 0 && convertedTypes.empty())
|
||||
|
@ -236,7 +270,7 @@ void ArgConverter::applySignatureConversion(
|
|||
/// to perform the conversion.
|
||||
Operation *ArgConverter::convertArgument(BlockArgument *origArg,
|
||||
ArrayRef<Value *> newValues,
|
||||
BlockAndValueMapping &mapping) {
|
||||
ConversionValueMapping &mapping) {
|
||||
// Handle the cases of 1->0 or 1->1 mappings.
|
||||
if (newValues.size() < 2) {
|
||||
// Create a temporary producer for the argument during the conversion
|
||||
|
@ -394,7 +428,7 @@ struct ConversionPatternRewriterImpl {
|
|||
|
||||
// Mapping between replaced values that differ in type. This happens when
|
||||
// replacing a value with one of a different type.
|
||||
BlockAndValueMapping mapping;
|
||||
ConversionValueMapping mapping;
|
||||
|
||||
/// Utility used to convert block arguments.
|
||||
ArgConverter argConverter;
|
||||
|
@ -440,6 +474,7 @@ void ConversionPatternRewriterImpl::undoBlockActions(
|
|||
case BlockActionKind::Split: {
|
||||
action.originalBlock->getOperations().splice(
|
||||
action.originalBlock->end(), action.block->getOperations());
|
||||
action.block->dropAllUses();
|
||||
action.block->erase();
|
||||
break;
|
||||
}
|
||||
|
@ -465,10 +500,8 @@ void ConversionPatternRewriterImpl::discardRewrites() {
|
|||
undoBlockActions();
|
||||
|
||||
// Remove any newly created ops.
|
||||
for (auto *op : createdOps) {
|
||||
op->dropAllDefinedValueUses();
|
||||
for (auto *op : llvm::reverse(createdOps))
|
||||
op->erase();
|
||||
}
|
||||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::applyRewrites() {
|
||||
|
@ -574,6 +607,8 @@ ConversionPatternRewriter::~ConversionPatternRewriter() {}
|
|||
void ConversionPatternRewriter::replaceOp(
|
||||
Operation *op, ArrayRef<Value *> newValues,
|
||||
ArrayRef<Value *> valuesToRemoveIfDead) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName()
|
||||
<< "\n");
|
||||
impl->replaceOp(op, newValues, valuesToRemoveIfDead);
|
||||
}
|
||||
|
||||
|
@ -609,6 +644,7 @@ void ConversionPatternRewriter::inlineRegionBefore(Region ®ion,
|
|||
/// PatternRewriter hook for creating a new operation.
|
||||
Operation *
|
||||
ConversionPatternRewriter::createOperation(const OperationState &state) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "** Creating operation : " << state.name << "\n");
|
||||
auto *result = OpBuilder::createOperation(state);
|
||||
impl->createdOps.push_back(result);
|
||||
return result;
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
// RUN: mlir-opt -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s
|
||||
// expected-remark@-2 {{op 'module' is legalizable}}
|
||||
// expected-remark@-3 {{op 'module_terminator' is legalizable}}
|
||||
|
||||
// expected-remark@+1 {{op 'func' is legalizable}}
|
||||
func @test(%arg0: f32) {
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
// RUN: mlir-opt -test-legalize-patterns -test-legalize-mode=full %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @multi_level_mapping
|
||||
func @multi_level_mapping() {
|
||||
// CHECK: "test.type_producer"() : () -> f64
|
||||
// CHECK: "test.type_consumer"(%{{.*}}) : (f64) -> ()
|
||||
%result = "test.type_producer"() : () -> i32
|
||||
"test.type_consumer"(%result) : (i32) -> ()
|
||||
"test.return"() : () -> ()
|
||||
}
|
|
@ -760,6 +760,10 @@ def TestCastOp : TEST_Op<"cast">,
|
|||
Arguments<(ins Variadic<AnyType>:$inputs)>, Results<(outs AnyType:$res)>;
|
||||
def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
|
||||
Arguments<(ins Variadic<AnyType>:$inputs)>;
|
||||
def TestTypeProducerOp : TEST_Op<"type_producer">,
|
||||
Results<(outs AnyType:$output)>;
|
||||
def TestTypeConsumerOp : TEST_Op<"type_consumer">,
|
||||
Arguments<(ins AnyType:$input)>;
|
||||
def TestValidOp : TEST_Op<"valid", [Terminator]>,
|
||||
Arguments<(ins Variadic<AnyType>:$inputs)>;
|
||||
|
||||
|
|
|
@ -56,6 +56,9 @@ static mlir::PassRegistration<TestPatternDriver>
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Region-Block Rewrite Testing
|
||||
|
||||
/// This pattern is a simple pattern that inlines the first region of a given
|
||||
/// operation into the parent region.
|
||||
struct TestRegionRewriteBlockMovement : public ConversionPattern {
|
||||
|
@ -99,6 +102,10 @@ struct TestRegionRewriteUndo : public RewritePattern {
|
|||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type-Conversion Rewrite Testing
|
||||
|
||||
/// This pattern simply erases the given operation.
|
||||
struct TestDropOp : public ConversionPattern {
|
||||
TestDropOp(MLIRContext *ctx) : ConversionPattern("test.drop_op", 1, ctx) {}
|
||||
|
@ -145,6 +152,62 @@ struct TestSplitReturnType : public ConversionPattern {
|
|||
return matchFailure();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Multi-Level Type-Conversion Rewrite Testing
|
||||
struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
|
||||
TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
|
||||
: ConversionPattern("test.type_producer", 1, ctx) {}
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// If the type is I32, change the type to F32.
|
||||
if (!(*op->result_type_begin()).isInteger(32))
|
||||
return matchFailure();
|
||||
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
|
||||
TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
|
||||
: ConversionPattern("test.type_producer", 1, ctx) {}
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// If the type is F32, change the type to F64.
|
||||
if (!(*op->result_type_begin()).isF32())
|
||||
return matchFailure();
|
||||
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
|
||||
TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
|
||||
: ConversionPattern("test.type_producer", 10, ctx) {}
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// Always convert to B16, even though it is not a legal type. This tests
|
||||
// that values are unmapped correctly.
|
||||
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
struct TestUpdateConsumerType : public ConversionPattern {
|
||||
TestUpdateConsumerType(MLIRContext *ctx)
|
||||
: ConversionPattern("test.type_consumer", 1, ctx) {}
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// Verify that the the incoming operand has been successfully remapped to
|
||||
// F64.
|
||||
if (!operands[0]->getType().isF64())
|
||||
return matchFailure();
|
||||
rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
@ -185,7 +248,7 @@ struct TestTypeConverter : public TypeConverter {
|
|||
struct TestLegalizePatternDriver
|
||||
: public ModulePass<TestLegalizePatternDriver> {
|
||||
/// The mode of conversion to use with the driver.
|
||||
enum class ConversionMode { Analysis, Partial };
|
||||
enum class ConversionMode { Analysis, Full, Partial };
|
||||
|
||||
TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
|
||||
|
||||
|
@ -193,14 +256,18 @@ struct TestLegalizePatternDriver
|
|||
TestTypeConverter converter;
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
populateWithGenerated(&getContext(), &patterns);
|
||||
patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
|
||||
TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>(
|
||||
&getContext());
|
||||
patterns
|
||||
.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
|
||||
TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType,
|
||||
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
|
||||
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType>(
|
||||
&getContext());
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
|
||||
converter);
|
||||
|
||||
// Define the conversion target used for the test.
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addLegalOp<LegalOpA, TestCastOp, TestValidOp>();
|
||||
target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
|
||||
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
|
||||
|
@ -211,12 +278,25 @@ struct TestLegalizePatternDriver
|
|||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
||||
|
||||
// Expect the type_producer/type_consumer operations to only operate on f64.
|
||||
target.addDynamicallyLegalOp<TestTypeProducerOp>(
|
||||
[](TestTypeProducerOp op) { return op.getType().isF64(); });
|
||||
target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
|
||||
return op.getOperand()->getType().isF64();
|
||||
});
|
||||
|
||||
// Handle a partial conversion.
|
||||
if (mode == ConversionMode::Partial) {
|
||||
(void)applyPartialConversion(getModule(), target, patterns, &converter);
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle a full conversion.
|
||||
if (mode == ConversionMode::Full) {
|
||||
(void)applyFullConversion(getModule(), target, patterns, &converter);
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, handle an analysis conversion.
|
||||
assert(mode == ConversionMode::Analysis);
|
||||
|
||||
|
@ -244,6 +324,8 @@ static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
|
|||
llvm::cl::values(
|
||||
clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
|
||||
"analysis", "Perform an analysis conversion"),
|
||||
clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
|
||||
"Perform a full conversion"),
|
||||
clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
|
||||
"partial", "Perform a partial conversion")));
|
||||
|
||||
|
|
Loading…
Reference in New Issue