Refactor DialectConversion to support different conversion modes.

Users generally want several different modes of conversion. This cl refactors DialectConversion to provide two:
* Partial (applyPartialConversion)
  - This mode allows for illegal operations to exist in the IR, and does not fail if an operation fails to be legalized.

* Full (applyFullConversion)
  - This mode fails if any operation is not properly legalized to the conversion target. This allows for ensuring that the IR after a conversion only contains operations legal for the target.

PiperOrigin-RevId: 258412243
This commit is contained in:
River Riddle 2019-07-16 11:57:45 -07:00 committed by Mehdi Amini
parent ffc0217bc7
commit 2b9855b5b4
11 changed files with 232 additions and 95 deletions

View File

@ -421,8 +421,7 @@ LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) {
ConversionTarget target(*module.getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
return applyConversionPatterns(module, target, converter,
std::move(patterns));
return applyFullConversion(module, target, converter, std::move(patterns));
}
namespace {

View File

@ -160,8 +160,8 @@ LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
ConversionTarget target(*module.getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
if (failed(applyConversionPatterns(module, target, converter,
std::move(patterns))))
if (failed(
applyFullConversion(module, target, converter, std::move(patterns))))
return failure();
return success();

View File

@ -132,8 +132,8 @@ struct EarlyLoweringPass : public FunctionPass<EarlyLoweringPass> {
OwningRewritePatternList patterns;
RewriteListBuilder<MulOpConversion>::build(patterns, &getContext());
if (failed(applyConversionPatterns(getFunction(), target,
std::move(patterns)))) {
if (failed(applyPartialConversion(getFunction(), target,
std::move(patterns)))) {
emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n");
signalPassFailure();
}

View File

@ -356,8 +356,8 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
LLVM::LLVMDialect, StandardOpsDialect>();
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
if (failed(applyConversionPatterns(getModule(), target, typeConverter,
std::move(toyPatterns)))) {
if (failed(applyPartialConversion(getModule(), target, typeConverter,
std::move(toyPatterns)))) {
emitError(UnknownLoc::get(getModule().getContext()),
"Error lowering Toy\n");
signalPassFailure();

View File

@ -349,31 +349,61 @@ private:
};
//===----------------------------------------------------------------------===//
// Conversion Application
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
/// Convert the given module with the provided conversion patterns and type
/// conversion object. This function returns failure if a type conversion
/// failed.
LLVM_NODISCARD LogicalResult applyConversionPatterns(
/// Apply a partial conversion on the given operations, and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there are unreachable blocks in any of the regions nested
/// within 'ops'.
LLVM_NODISCARD LogicalResult
applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
OwningRewritePatternList &&patterns);
LLVM_NODISCARD LogicalResult
applyPartialConversion(Operation *op, ConversionTarget &target,
OwningRewritePatternList &&patterns);
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
/// fails, or if there are unreachable blocks in any of the regions nested
/// within 'ops'.
LLVM_NODISCARD LogicalResult
applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
OwningRewritePatternList &&patterns);
LLVM_NODISCARD LogicalResult
applyFullConversion(Operation *op, ConversionTarget &target,
OwningRewritePatternList &&patterns);
//===----------------------------------------------------------------------===//
// Op + Type Conversion Entry Points
//===----------------------------------------------------------------------===//
/// Apply a partial conversion on the function operations within the given
/// module. This method returns failure if a type conversion was encountered.
LLVM_NODISCARD LogicalResult applyPartialConversion(
ModuleOp module, ConversionTarget &target, TypeConverter &converter,
OwningRewritePatternList &&patterns);
/// Convert the given functions with the provided conversion patterns. This
/// function returns failure if a type conversion failed.
LLVM_NODISCARD
LogicalResult applyConversionPatterns(MutableArrayRef<FuncOp> fns,
ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns);
/// Apply a partial conversion on the given function operations. This method
/// returns failure if a type conversion was encountered.
LLVM_NODISCARD LogicalResult applyPartialConversion(
MutableArrayRef<FuncOp> fns, ConversionTarget &target,
TypeConverter &converter, OwningRewritePatternList &&patterns);
/// Convert the given function with the provided conversion patterns. This will
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LLVM_NODISCARD
LogicalResult applyConversionPatterns(FuncOp fn, ConversionTarget &target,
OwningRewritePatternList &&patterns);
/// Apply a full conversion on the function operations within the given
/// module. This method returns failure if a type conversion was encountered, or
/// if the conversion of any operations failed.
LLVM_NODISCARD LogicalResult applyFullConversion(
ModuleOp module, ConversionTarget &target, TypeConverter &converter,
OwningRewritePatternList &&patterns);
/// Apply a partial conversion on the given function operations. This method
/// returns failure if a type conversion was encountered, or if the conversion
/// of any operation failed.
LLVM_NODISCARD LogicalResult applyFullConversion(
MutableArrayRef<FuncOp> fns, ConversionTarget &target,
TypeConverter &converter, OwningRewritePatternList &&patterns);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_

View File

@ -282,7 +282,7 @@ void ControlFlowToCFGPass::runOnFunction() {
ConversionTarget target(getContext());
target.addLegalDialect<StandardOpsDialect>();
if (failed(
applyConversionPatterns(getFunction(), target, std::move(patterns))))
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
}

View File

@ -1064,13 +1064,13 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
if (failed(applyConversionPatterns(m, target, *typeConverter,
std::move(patterns))))
if (failed(applyPartialConversion(m, target, *typeConverter,
std::move(patterns))))
signalPassFailure();
}
// Callback for creating a list of patterns. It is called every time in
// runOnModule since applyConversionPatterns consumes the list.
// runOnModule since applyPartialConversion consumes the list.
LLVMPatternListFiller patternListFiller;
// Callback for creating an instance of type converter. The converter

View File

@ -763,8 +763,8 @@ void LowerLinalgToLLVMPass::runOnModule() {
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
if (failed(applyConversionPatterns(module, target, converter,
std::move(patterns)))) {
if (failed(applyPartialConversion(module, target, converter,
std::move(patterns)))) {
signalPassFailure();
}

View File

@ -841,30 +841,48 @@ void OperationLegalizer::computeLegalizationGraphBenefit() {
}
//===----------------------------------------------------------------------===//
// FunctionConverter
// OperationConverter
//===----------------------------------------------------------------------===//
namespace {
// This class converts a single function using the given pattern matcher. If a
enum OpConversionMode {
// In this mode, the conversion will ignore failed conversions to allow
// illegal operations to co-exist in the IR.
Partial,
// In this mode, all operations must be legal for the given target for the
// conversion to succeeed.
Full,
};
// This class converts operations using the given pattern matcher. If a
// TypeConverter object is provided, then the types of block arguments will be
// converted using the appropriate 'convertType' calls.
struct FunctionConverter {
explicit FunctionConverter(ConversionTarget &target,
OwningRewritePatternList &patterns,
TypeConverter *conversion = nullptr)
: typeConverter(conversion), opLegalizer(target, patterns) {}
struct OperationConverter {
explicit OperationConverter(ConversionTarget &target,
OwningRewritePatternList &patterns,
OpConversionMode mode,
TypeConverter *conversion = nullptr)
: typeConverter(conversion), opLegalizer(target, patterns), mode(mode) {}
/// Converts the given function to the conversion target. Returns failure on
/// error, success otherwise. If 'signatureConversion' is provided, the
/// arguments of the entry block are updated accordingly.
/// error, success otherwise.
LogicalResult
convertFunction(FuncOp f,
TypeConverter::SignatureConversion *signatureConversion);
TypeConverter::SignatureConversion &signatureConversion);
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
private:
/// Converts a block or operation with the given rewriter.
LogicalResult convert(DialectConversionRewriter &rewriter,
llvm::PointerUnion<Operation *, Block *> &ptr);
/// Converts a set of blocks/operations with the given rewriter.
LogicalResult
convert(DialectConversionRewriter &rewriter,
std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert);
/// Recursively collect all of the blocks, and operations, to convert from
/// within 'region'.
LogicalResult computeConversionSet(
@ -876,11 +894,14 @@ private:
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
/// The conversion mode to use when legalizing operations.
OpConversionMode mode;
};
} // end anonymous namespace
/// Recursively collect all of the blocks to convert from within 'region'.
LogicalResult FunctionConverter::computeConversionSet(
LogicalResult OperationConverter::computeConversionSet(
Region &region,
std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert) {
if (region.empty())
@ -919,38 +940,30 @@ LogicalResult FunctionConverter::computeConversionSet(
/// Converts a block or operation with the given rewriter.
LogicalResult
FunctionConverter::convert(DialectConversionRewriter &rewriter,
llvm::PointerUnion<Operation *, Block *> &ptr) {
OperationConverter::convert(DialectConversionRewriter &rewriter,
llvm::PointerUnion<Operation *, Block *> &ptr) {
// If this is a block, then convert the types of each of the arguments.
if (auto *block = ptr.dyn_cast<Block *>()) {
assert(typeConverter && "expected valid type converter");
return rewriter.argConverter.convertArguments(block, rewriter.mapping);
}
// Otherwise, this is an operation to legalize.
(void)opLegalizer.legalize(ptr.get<Operation *>(), rewriter);
// Otherwise, legalize the given operation.
auto *op = ptr.get<Operation *>();
auto result = opLegalizer.legalize(op, rewriter);
// Failed conversions are only important if this is a full conversion.
if (mode == OpConversionMode::Full && failed(result))
return op->emitError() << "failed to legalize operation '" << op->getName()
<< "'";
// In any other case, illegal operations are allowed to remain in the IR.
return success();
}
LogicalResult FunctionConverter::convertFunction(
FuncOp f, TypeConverter::SignatureConversion *signatureConversion) {
// If this is an external function, there is nothing else to do.
if (f.isExternal())
return success();
DialectConversionRewriter rewriter(f.getContext(), typeConverter);
// Update the signature of the entry block.
if (signatureConversion) {
rewriter.argConverter.convertSignature(
&f.getBody().front(), *signatureConversion, rewriter.mapping);
}
/// Compute the set of operations and blocks to convert.
std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
if (failed(computeConversionSet(f.getBody(), toConvert)))
return failure();
LogicalResult OperationConverter::convert(
DialectConversionRewriter &rewriter,
std::vector<llvm::PointerUnion<Operation *, Block *>> &toConvert) {
// Convert each operation/block and discard rewrites on failure.
for (auto &it : toConvert) {
if (failed(convert(rewriter, it))) {
@ -964,6 +977,43 @@ LogicalResult FunctionConverter::convertFunction(
return success();
}
LogicalResult OperationConverter::convertFunction(
FuncOp f, TypeConverter::SignatureConversion &signatureConversion) {
// If this is an external function, there is nothing else to do.
if (f.isExternal())
return success();
// Update the signature of the entry block.
DialectConversionRewriter rewriter(f.getContext(), typeConverter);
rewriter.argConverter.convertSignature(&f.getBody().front(),
signatureConversion, rewriter.mapping);
// Compute the set of operations and blocks to convert.
std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
if (failed(computeConversionSet(f.getBody(), toConvert)))
return failure();
return convert(rewriter, toConvert);
}
/// Converts the given top-level operation to the conversion target.
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
/// Compute the set of operations and blocks to convert.
std::vector<llvm::PointerUnion<Operation *, Block *>> toConvert;
for (auto *op : ops) {
toConvert.emplace_back(op);
for (auto &region : op->getRegions())
if (failed(computeConversionSet(region, toConvert)))
return failure();
}
// Rewrite the blocks and operations.
DialectConversionRewriter rewriter(ops.front()->getContext(), typeConverter);
return convert(rewriter, toConvert);
}
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
@ -1102,34 +1152,59 @@ auto ConversionTarget::getOpAction(OperationName op) const
}
//===----------------------------------------------------------------------===//
// Conversion Application
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
/// Convert the given module with the provided conversion patterns and type
/// conversion object. If conversion fails for specific functions, those
/// functions remains unmodified.
/// Apply a partial conversion on the given operations, and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize.
LogicalResult
mlir::applyConversionPatterns(ModuleOp module, ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns) {
SmallVector<FuncOp, 32> allFunctions(module.getOps<FuncOp>());
return applyConversionPatterns(allFunctions, target, converter,
std::move(patterns));
mlir::applyPartialConversion(ArrayRef<Operation *> ops,
ConversionTarget &target,
OwningRewritePatternList &&patterns) {
OperationConverter converter(target, patterns, OpConversionMode::Partial);
return converter.convertOperations(ops);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
OwningRewritePatternList &&patterns) {
return applyPartialConversion(llvm::makeArrayRef(op), target,
std::move(patterns));
}
/// Convert the given functions with the provided conversion patterns.
LogicalResult mlir::applyConversionPatterns(
MutableArrayRef<FuncOp> fns, ConversionTarget &target,
TypeConverter &converter, OwningRewritePatternList &&patterns) {
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method will return failure if the conversion of any
/// operation fails.
LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
ConversionTarget &target,
OwningRewritePatternList &&patterns) {
OperationConverter converter(target, patterns, OpConversionMode::Full);
return converter.convertOperations(ops);
}
LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target,
OwningRewritePatternList &&patterns) {
return applyFullConversion(llvm::makeArrayRef(op), target,
std::move(patterns));
}
//===----------------------------------------------------------------------===//
// Op + Type Conversion Entry Points
//===----------------------------------------------------------------------===//
static LogicalResult applyConversion(MutableArrayRef<FuncOp> fns,
ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns,
OpConversionMode mode) {
if (fns.empty())
return success();
// Build the function converter.
auto *ctx = fns.front().getContext();
FunctionConverter funcConverter(target, patterns, &converter);
OperationConverter funcConverter(target, patterns, mode, &converter);
// Try to convert each of the functions within the module.
SmallVector<NamedAttributeList, 4> argAttrs;
auto *ctx = fns.front().getContext();
for (auto func : fns) {
argAttrs.clear();
func.getAllArgAttrs(argAttrs);
@ -1144,20 +1219,53 @@ LogicalResult mlir::applyConversionPatterns(
func.setAllArgAttrs(conversion->getConvertedArgAttrs());
// Convert the body of this function.
if (failed(funcConverter.convertFunction(func, &*conversion)))
if (failed(funcConverter.convertFunction(func, *conversion)))
return failure();
}
return success();
}
/// Convert the given function with the provided conversion patterns. This will
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
/// Apply a partial conversion on the function operations within the given
/// module. This method returns failure if a type conversion was encountered.
LogicalResult
mlir::applyConversionPatterns(FuncOp fn, ConversionTarget &target,
OwningRewritePatternList &&patterns) {
// Convert the body of this function.
FunctionConverter converter(target, patterns);
return converter.convertFunction(fn, /*signatureConversion=*/nullptr);
mlir::applyPartialConversion(ModuleOp module, ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns) {
SmallVector<FuncOp, 32> allFunctions(module.getOps<FuncOp>());
return applyPartialConversion(allFunctions, target, converter,
std::move(patterns));
}
/// Apply a partial conversion on the given function operations. This method
/// returns failure if a type conversion was encountered.
LogicalResult
mlir::applyPartialConversion(MutableArrayRef<FuncOp> fns,
ConversionTarget &target, TypeConverter &converter,
OwningRewritePatternList &&patterns) {
return applyConversion(fns, target, converter, std::move(patterns),
OpConversionMode::Partial);
}
/// Apply a full conversion on the function operations within the given module.
/// This method returns failure if a type conversion was encountered, or if the
/// conversion of any operations failed.
LogicalResult mlir::applyFullConversion(ModuleOp module,
ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns) {
SmallVector<FuncOp, 32> allFunctions(module.getOps<FuncOp>());
return applyFullConversion(allFunctions, target, converter,
std::move(patterns));
}
/// Apply a full conversion on the given function operations. This method
/// returns failure if a type conversion was encountered, or if the conversion
/// of any operation failed.
LogicalResult mlir::applyFullConversion(MutableArrayRef<FuncOp> fns,
ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns) {
return applyConversion(fns, target, converter, std::move(patterns),
OpConversionMode::Full);
}

View File

@ -521,8 +521,8 @@ class LowerAffinePass : public FunctionPass<LowerAffinePass> {
populateAffineToStdConversionPatterns(patterns, &getContext());
ConversionTarget target(getContext());
target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
if (failed(applyConversionPatterns(getFunction(), target,
std::move(patterns))))
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
}
};

View File

@ -177,8 +177,8 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
TestConversionTarget target(getContext());
if (failed(applyConversionPatterns(getModule(), target, converter,
std::move(patterns))))
if (failed(applyPartialConversion(getModule(), target, converter,
std::move(patterns))))
signalPassFailure();
}
};