forked from OSchip/llvm-project
Refactor the traversal of operations to Convert in DialectConversion.
This cl changes the way that operations/blocks to convert are collected/traversed so that parent region operations can be legalized before their bodies. Most RewritePatterns for region operations assume that the entry arguments to each region are yet to be converted. Given that the bodies are currently converted first, this makes it difficult to fit these patterns into the same run as one converting types. The operations/blocks to convert are now collected before any legalization has run, which simplifies the conversion logic itself, as legalization may insert new operations, move blocks, etc. PiperOrigin-RevId: 258170158
This commit is contained in:
parent
d2f1ed5137
commit
7d1e1e6721
|
@ -170,10 +170,13 @@ void ArgConverter::applyRewrites() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If mapping is from type to itself, replace the remaining uses and drop
|
// If mapping is 1-1, replace the remaining uses and drop the cast
|
||||||
// the cast operation.
|
// operation.
|
||||||
if (op->getNumOperands() == 1 &&
|
// FIXME(riverriddle) This should check that the result type and operand
|
||||||
op->getResult(0)->getType() == op->getOperand(0)->getType()) {
|
// type are the same, otherwise it should force a conversion to be
|
||||||
|
// materialized. This works around a current limitation with regards to
|
||||||
|
// region entry argument type conversion.
|
||||||
|
if (op->getNumOperands() == 1) {
|
||||||
op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
|
op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
|
||||||
op->destroy();
|
op->destroy();
|
||||||
continue;
|
continue;
|
||||||
|
@ -229,7 +232,7 @@ void ArgConverter::convertSignature(
|
||||||
LogicalResult ArgConverter::convertArguments(Block *block,
|
LogicalResult ArgConverter::convertArguments(Block *block,
|
||||||
BlockAndValueMapping &mapping) {
|
BlockAndValueMapping &mapping) {
|
||||||
unsigned origArgCount = block->getNumArguments();
|
unsigned origArgCount = block->getNumArguments();
|
||||||
if (origArgCount == 0)
|
if (origArgCount == 0 || argMapping.count(block))
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Convert the types of each of the block arguments.
|
// Convert the types of each of the block arguments.
|
||||||
|
@ -440,10 +443,14 @@ struct DialectConversionRewriter final : public PatternRewriter {
|
||||||
for (auto ®ion : repl.op->getRegions())
|
for (auto ®ion : repl.op->getRegions())
|
||||||
argConverter.cancelPendingRewrites(region);
|
argConverter.cancelPendingRewrites(region);
|
||||||
}
|
}
|
||||||
|
|
||||||
repl.op->erase();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// In a second pass, erase all of the replaced operations in reverse. This
|
||||||
|
// allows processing nested operations before their parent region is
|
||||||
|
// destroyed.
|
||||||
|
for (auto &repl : llvm::reverse(replacements))
|
||||||
|
repl.op->erase();
|
||||||
|
|
||||||
argConverter.applyRewrites();
|
argConverter.applyRewrites();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -841,7 +848,7 @@ namespace {
|
||||||
// TypeConverter object is provided, then the types of block arguments will be
|
// TypeConverter object is provided, then the types of block arguments will be
|
||||||
// converted using the appropriate 'convertType' calls.
|
// converted using the appropriate 'convertType' calls.
|
||||||
struct FunctionConverter {
|
struct FunctionConverter {
|
||||||
explicit FunctionConverter(MLIRContext *ctx, ConversionTarget &target,
|
explicit FunctionConverter(ConversionTarget &target,
|
||||||
OwningRewritePatternList &patterns,
|
OwningRewritePatternList &patterns,
|
||||||
TypeConverter *conversion = nullptr)
|
TypeConverter *conversion = nullptr)
|
||||||
: typeConverter(conversion), opLegalizer(target, patterns) {}
|
: typeConverter(conversion), opLegalizer(target, patterns) {}
|
||||||
|
@ -853,22 +860,18 @@ struct FunctionConverter {
|
||||||
convertFunction(FuncOp f,
|
convertFunction(FuncOp f,
|
||||||
TypeConverter::SignatureConversion *signatureConversion);
|
TypeConverter::SignatureConversion *signatureConversion);
|
||||||
|
|
||||||
/// Converts the given region starting from the entry block and following the
|
private:
|
||||||
/// block successors. Returns failure on error, success otherwise. Prints
|
/// Converts a block or operation with the given rewriter.
|
||||||
/// error messages at `loc`.
|
LogicalResult convert(DialectConversionRewriter &rewriter,
|
||||||
LogicalResult convertRegion(DialectConversionRewriter &rewriter,
|
llvm::PointerUnion<Operation *, Block *> &ptr);
|
||||||
Region ®ion, bool convertEntryTypes = true);
|
|
||||||
|
|
||||||
/// Converts a block by traversing its operations sequentially, attempting to
|
/// Recursively collect all of the blocks, and operations, to convert from
|
||||||
/// match a pattern. If there is no match, recurses the operations regions if
|
/// within 'region'.
|
||||||
/// it has any.
|
LogicalResult computeConversionSet(
|
||||||
//
|
Region ®ion,
|
||||||
/// After converting operations, traverses the successor blocks unless they
|
std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert);
|
||||||
/// have been visited already as indicated in `visitedBlocks`.
|
|
||||||
LogicalResult convertBlock(DialectConversionRewriter &rewriter, Block *block,
|
|
||||||
DenseSet<Block *> &visitedBlocks);
|
|
||||||
|
|
||||||
/// Pointer to a specific dialect conversion info.
|
/// Pointer to the type converter.
|
||||||
TypeConverter *typeConverter;
|
TypeConverter *typeConverter;
|
||||||
|
|
||||||
/// The legalizer to use when converting operations.
|
/// The legalizer to use when converting operations.
|
||||||
|
@ -876,84 +879,56 @@ struct FunctionConverter {
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
LogicalResult
|
/// Recursively collect all of the blocks to convert from within 'region'.
|
||||||
FunctionConverter::convertBlock(DialectConversionRewriter &rewriter,
|
LogicalResult FunctionConverter::computeConversionSet(
|
||||||
Block *block,
|
Region ®ion,
|
||||||
DenseSet<Block *> &visitedBlocks) {
|
std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert) {
|
||||||
// First, add the current block to the list of visited blocks.
|
if (region.empty())
|
||||||
visitedBlocks.insert(block);
|
|
||||||
|
|
||||||
if (block->empty())
|
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Preserve the successors before rewriting the operations.
|
// Traverse starting from the entry block.
|
||||||
SmallVector<Block *, 4> successors(block->getSuccessors());
|
SmallVector<Block *, 16> worklist(1, ®ion.front());
|
||||||
|
DenseSet<Block *> visitedBlocks;
|
||||||
|
visitedBlocks.insert(®ion.front());
|
||||||
|
while (!worklist.empty()) {
|
||||||
|
auto *block = worklist.pop_back_val();
|
||||||
|
|
||||||
// Iterate over ops and convert them. Since the conversion may split the
|
// We only need to process blocks if we are changing argument types.
|
||||||
// block, we eagerly take the pointer to the next operation in it. Splitting
|
if (typeConverter)
|
||||||
// moves the operations from one block to another, so this will keep
|
toConvert.emplace_back(block);
|
||||||
// considering the original list of operations independently of the block
|
|
||||||
// within which they are currently located. This relies on iplist node API
|
|
||||||
// to get the next node in the list witout knowing which list it is, iterators
|
|
||||||
// are unsuitable because block splitting invalidates all iterators following
|
|
||||||
// the current one. Any operation inserted by the conversion, independently of
|
|
||||||
// its parent block, will be recursively legalized independently of this
|
|
||||||
// function.
|
|
||||||
Operation *current = &block->front();
|
|
||||||
Operation *next = nullptr;
|
|
||||||
do {
|
|
||||||
next = current->getNextNode();
|
|
||||||
// Traverse any held regions.
|
|
||||||
for (auto ®ion : current->getRegions())
|
|
||||||
if (!region.empty() && failed(convertRegion(rewriter, region)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Legalize the current operation.
|
// Compute the conversion set of each of the nested operations.
|
||||||
(void)opLegalizer.legalize(current, rewriter);
|
for (auto &op : *block) {
|
||||||
} while ((current = next));
|
toConvert.emplace_back(&op);
|
||||||
|
for (auto ®ion : op.getRegions())
|
||||||
|
computeConversionSet(region, toConvert);
|
||||||
|
}
|
||||||
|
|
||||||
// Recurse to children that haven't been visited.
|
// Recurse to children that haven't been visited.
|
||||||
for (Block *succ : successors) {
|
for (Block *succ : block->getSuccessors())
|
||||||
if (visitedBlocks.count(succ))
|
if (visitedBlocks.insert(succ).second)
|
||||||
continue;
|
worklist.push_back(succ);
|
||||||
if (failed(convertBlock(rewriter, succ, visitedBlocks)))
|
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that all blocks in the region were visited.
|
||||||
|
if (llvm::any_of(llvm::drop_begin(region.getBlocks(), 1),
|
||||||
|
[&](Block &block) { return !visitedBlocks.count(&block); }))
|
||||||
|
return emitError(region.getLoc(), "unreachable blocks were not converted");
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts a block or operation with the given rewriter.
|
||||||
LogicalResult
|
LogicalResult
|
||||||
FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
|
FunctionConverter::convert(DialectConversionRewriter &rewriter,
|
||||||
Region ®ion, bool convertEntryTypes) {
|
llvm::PointerUnion<Operation *, Block *> &ptr) {
|
||||||
assert(!region.empty() && "expected non-empty region");
|
// If this is a block, then convert the types of each of the arguments.
|
||||||
|
if (auto *block = ptr.dyn_cast<Block *>()) {
|
||||||
// Create the arguments of each of the blocks in the region. If a type
|
assert(typeConverter && "expected valid type converter");
|
||||||
// converter was not provided, then we don't need to change any of the block
|
return rewriter.argConverter.convertArguments(block, rewriter.mapping);
|
||||||
// types.
|
|
||||||
if (typeConverter) {
|
|
||||||
for (Block &block :
|
|
||||||
llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1)) {
|
|
||||||
if (failed(
|
|
||||||
rewriter.argConverter.convertArguments(&block, rewriter.mapping)))
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the number of blocks before conversion (new blocks may be added due
|
// Otherwise, this is an operation to legalize.
|
||||||
// to splits or moves, but the operations in them will be processed
|
(void)opLegalizer.legalize(ptr.get<Operation *>(), rewriter);
|
||||||
// elsewhere).
|
|
||||||
unsigned numBlocks = std::distance(region.begin(), region.end());
|
|
||||||
|
|
||||||
// Start a DFS-order traversal of the CFG to make sure defs are converted
|
|
||||||
// before uses in dominated blocks.
|
|
||||||
llvm::DenseSet<Block *> visitedBlocks;
|
|
||||||
if (failed(convertBlock(rewriter, ®ion.front(), visitedBlocks)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// If some blocks are not reachable through successor chains, they should have
|
|
||||||
// been removed by the DCE before this.
|
|
||||||
if (visitedBlocks.size() != numBlocks)
|
|
||||||
return emitError(region.getLoc(), "unreachable blocks were not converted");
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -971,12 +946,17 @@ LogicalResult FunctionConverter::convertFunction(
|
||||||
&f.getBody().front(), *signatureConversion, rewriter.mapping);
|
&f.getBody().front(), *signatureConversion, rewriter.mapping);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrite the function body.
|
/// Compute the set of operations and blocks to convert.
|
||||||
if (failed(
|
std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
|
||||||
convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) {
|
if (failed(computeConversionSet(f.getBody(), toConvert)))
|
||||||
// Reset any of the generated rewrites.
|
|
||||||
rewriter.discardRewrites();
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
// Convert each operation/block and discard rewrites on failure.
|
||||||
|
for (auto &it : toConvert) {
|
||||||
|
if (failed(convert(rewriter, it))) {
|
||||||
|
rewriter.discardRewrites();
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise the body conversion succeeded, so apply all rewrites.
|
// Otherwise the body conversion succeeded, so apply all rewrites.
|
||||||
|
@ -1146,7 +1126,7 @@ LogicalResult mlir::applyConversionPatterns(
|
||||||
|
|
||||||
// Build the function converter.
|
// Build the function converter.
|
||||||
auto *ctx = fns.front().getContext();
|
auto *ctx = fns.front().getContext();
|
||||||
FunctionConverter funcConverter(ctx, target, patterns, &converter);
|
FunctionConverter funcConverter(target, patterns, &converter);
|
||||||
|
|
||||||
// Try to convert each of the functions within the module.
|
// Try to convert each of the functions within the module.
|
||||||
SmallVector<NamedAttributeList, 4> argAttrs;
|
SmallVector<NamedAttributeList, 4> argAttrs;
|
||||||
|
@ -1178,6 +1158,6 @@ LogicalResult
|
||||||
mlir::applyConversionPatterns(FuncOp fn, ConversionTarget &target,
|
mlir::applyConversionPatterns(FuncOp fn, ConversionTarget &target,
|
||||||
OwningRewritePatternList &&patterns) {
|
OwningRewritePatternList &&patterns) {
|
||||||
// Convert the body of this function.
|
// Convert the body of this function.
|
||||||
FunctionConverter converter(fn.getContext(), target, patterns);
|
FunctionConverter converter(target, patterns);
|
||||||
return converter.convertFunction(fn, /*signatureConversion=*/nullptr);
|
return converter.convertFunction(fn, /*signatureConversion=*/nullptr);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue