Refactor DialectConversion to convert the signatures of blocks when they are moved.

Often we want to ensure that block arguments are converted before operations that use them. This refactors the current implementation to be cleaner/less frequent by triggering conversion when a set of blocks are moved/inlined; or when legalization is successful.

PiperOrigin-RevId: 263795005
This commit is contained in:
River Riddle 2019-08-16 10:16:09 -07:00 committed by A. Unique TensorFlower
parent f79fc1c181
commit 9c29273ddc
1 changed files with 41 additions and 33 deletions

View File

@ -499,11 +499,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
LogicalResult
ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
// Check to see if this block should not be converted:
// * The block is invalid, or there is no type converter.
// * There is no type converter.
// * The block has already been converted.
// * This is an entry block, these are converted explicitly via patterns.
if (!block || !argConverter.typeConverter ||
argConverter.hasBeenConverted(block) || block->isEntryBlock())
if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) ||
block->isEntryBlock())
return success();
// Otherwise, try to convert the block signature.
@ -738,10 +738,6 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
LogicalResult
OperationLegalizer::legalize(Operation *op,
ConversionPatternRewriter &rewriter) {
// Make sure that the signature of the parent block has been converted.
if (failed(rewriter.getImpl().convertBlockSignature(op->getBlock())))
return failure();
LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
<< "\n");
@ -802,6 +798,24 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
return cleanupFailure();
}
// If the pattern moved any blocks, try to legalize their types. This ensures
// that the types of the block arguments are legal for the region they were
// moved into.
for (unsigned i = curState.numBlockActions,
e = rewriterImpl.blockActions.size();
i != e; ++i) {
auto &action = rewriterImpl.blockActions[i];
if (action.kind != ConversionPatternRewriterImpl::BlockActionKind::Move)
continue;
// Convert the block signature.
if (failed(rewriterImpl.convertBlockSignature(action.block))) {
LLVM_DEBUG(llvm::dbgs()
<< "-- FAIL: failed to convert types of moved block.\n");
return cleanupFailure();
}
}
// Recursively legalize each of the new operations.
for (unsigned i = curState.numCreatedOperations,
e = rewriterImpl.createdOps.size();
@ -958,9 +972,9 @@ enum OpConversionMode {
Analysis,
};
// This class converts operations using the given pattern matcher. If a
// TypeConverter object is provided, then the types of block arguments will be
// converted using the appropriate 'convertType' calls.
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
explicit OperationConverter(ConversionTarget &target,
const OwningRewritePatternList &patterns,
@ -981,8 +995,7 @@ private:
LogicalResult computeConversionSet(Region &region,
std::vector<Operation *> &toConvert);
/// Converts the type signatures of the blocks nested within 'op' that have
/// yet to be converted.
/// Converts the type signatures of the blocks nested within 'op'.
LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter,
Operation *op);
@ -1001,18 +1014,14 @@ private:
LogicalResult
OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
Operation *op) {
SmallVector<Region *, 8> worklist;
for (auto &region : op->getRegions())
worklist.push_back(&region);
// Check to see if type signatures need to be converted.
if (!rewriter.getImpl().argConverter.typeConverter)
return success();
while (!worklist.empty()) {
for (auto &block : *worklist.pop_back_val()) {
for (auto &region : op->getRegions()) {
for (auto &block : region)
if (failed(rewriter.getImpl().convertBlockSignature(&block)))
return failure();
for (auto &nestedOp : block)
for (auto &region : nestedOp.getRegions())
worklist.push_back(&region);
}
}
return success();
}
@ -1065,10 +1074,17 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
} else if (mode == OpConversionMode::Analysis) {
/// Analysis conversions don't fail if any operations fail to legalize, they
/// are only interested in the operations that were successfully legalized.
} else {
/// Analysis conversions don't fail if any operations fail to legalize,
/// they are only interested in the operations that were successfully
/// legalized.
if (mode == OpConversionMode::Analysis)
legalizableOps->insert(op);
// If legalization succeeded, convert the types any of the blocks within
// this operation.
if (failed(convertBlockSignatures(rewriter, op)))
return failure();
}
return success();
}
@ -1094,14 +1110,6 @@ OperationConverter::convertOperations(ArrayRef<Operation *> ops,
if (failed(convert(rewriter, op)))
return rewriter.getImpl().discardRewrites(), failure();
// If a type converter was provided, ensure that all blocks have had their
// signatures properly converted.
if (typeConverter) {
for (auto *op : ops)
if (failed(convertBlockSignatures(rewriter, op)))
return rewriter.getImpl().discardRewrites(), failure();
}
// Otherwise, the body conversion succeeded. Apply rewrites if this is not an
// analysis conversion.
if (mode == OpConversionMode::Analysis)