[mlir] Add structural conversion to async dialect lowering.

Lowering of async dialect uses a fixed type converter and therefore does not support lowering non-standard types.

This revision adds a structural conversion so that non-standard types in `!async.value`s can be lowered to LLVM before lowering the async dialect itself.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D94404
This commit is contained in:
Christian Sigg 2021-01-11 20:34:44 +01:00
parent 1027a22ccd
commit 195728c75a
2 changed files with 81 additions and 0 deletions

View File

@ -13,13 +13,29 @@
namespace mlir { namespace mlir {
class ConversionTarget;
class ModuleOp; class ModuleOp;
template <typename T> template <typename T>
class OperationPass; class OperationPass;
class MLIRContext;
class OwningRewritePatternList;
class TypeConverter;
/// Create a pass to convert Async operations to the LLVM dialect. /// Create a pass to convert Async operations to the LLVM dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass(); std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass();
/// Populates patterns for async structural type conversions.
///
/// A "structural" type conversion is one where the underlying ops are
/// completely agnostic to the actual types involved and simply need to update
/// their types. An example of this is async.execute -- the async.execute op and
/// the corresponding async.yield ops need to update their types accordingly to
/// the TypeConverter, but otherwise don't care what type conversions are
/// happening.
void populateAsyncStructuralTypeConversionsAndLegality(
MLIRContext *context, TypeConverter &typeConverter,
OwningRewritePatternList &patterns, ConversionTarget &target);
} // namespace mlir } // namespace mlir
#endif // MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H #endif // MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H

View File

@ -1132,6 +1132,71 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
} }
} // namespace } // namespace
namespace {
class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ExecuteOp newOp =
cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
newOp.getRegion().end());
// Set operands and update block argument and result types.
newOp->setOperands(operands);
if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
return failure();
for (auto result : newOp.getResults())
result.setType(typeConverter->convertType(result.getType()));
rewriter.replaceOp(op, newOp.getResults());
return success();
}
};
// Dummy pattern to trigger the appropriate type conversion / materialization.
class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
return success();
}
};
// Dummy pattern to trigger the appropriate type conversion / materialization.
class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
return success();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() { std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
return std::make_unique<ConvertAsyncToLLVMPass>(); return std::make_unique<ConvertAsyncToLLVMPass>();
} }
void mlir::populateAsyncStructuralTypeConversionsAndLegality(
MLIRContext *context, TypeConverter &typeConverter,
OwningRewritePatternList &patterns, ConversionTarget &target) {
typeConverter.addConversion([&](TokenType type) { return type; });
typeConverter.addConversion([&](ValueType type) {
return ValueType::get(typeConverter.convertType(type.getValueType()));
});
patterns
.insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
typeConverter, context);
target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
}