forked from OSchip/llvm-project
[mlir] Add structural conversion to async dialect lowering.
Lowering of async dialect uses a fixed type converter and therefore does not support lowering non-standard types. This revision adds a structural conversion so that non-standard types in `!async.value`s can be lowered to LLVM before lowering the async dialect itself. Reviewed By: ezhulenev Differential Revision: https://reviews.llvm.org/D94404
This commit is contained in:
parent
1027a22ccd
commit
195728c75a
|
@ -13,13 +13,29 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
|
class ConversionTarget;
|
||||||
class ModuleOp;
|
class ModuleOp;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class OperationPass;
|
class OperationPass;
|
||||||
|
class MLIRContext;
|
||||||
|
class OwningRewritePatternList;
|
||||||
|
class TypeConverter;
|
||||||
|
|
||||||
/// Create a pass to convert Async operations to the LLVM dialect.
|
/// Create a pass to convert Async operations to the LLVM dialect.
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass();
|
||||||
|
|
||||||
|
/// Populates patterns for async structural type conversions.
|
||||||
|
///
|
||||||
|
/// A "structural" type conversion is one where the underlying ops are
|
||||||
|
/// completely agnostic to the actual types involved and simply need to update
|
||||||
|
/// their types. An example of this is async.execute -- the async.execute op and
|
||||||
|
/// the corresponding async.yield ops need to update their types accordingly to
|
||||||
|
/// the TypeConverter, but otherwise don't care what type conversions are
|
||||||
|
/// happening.
|
||||||
|
void populateAsyncStructuralTypeConversionsAndLegality(
|
||||||
|
MLIRContext *context, TypeConverter &typeConverter,
|
||||||
|
OwningRewritePatternList &patterns, ConversionTarget &target);
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H
|
#endif // MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H
|
||||||
|
|
|
@ -1132,6 +1132,71 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
ExecuteOp newOp =
|
||||||
|
cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||||
|
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
|
||||||
|
newOp.getRegion().end());
|
||||||
|
|
||||||
|
// Set operands and update block argument and result types.
|
||||||
|
newOp->setOperands(operands);
|
||||||
|
if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
|
||||||
|
return failure();
|
||||||
|
for (auto result : newOp.getResults())
|
||||||
|
result.setType(typeConverter->convertType(result.getType()));
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, newOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Dummy pattern to trigger the appropriate type conversion / materialization.
|
||||||
|
class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Dummy pattern to trigger the appropriate type conversion / materialization.
|
||||||
|
class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
|
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
|
||||||
return std::make_unique<ConvertAsyncToLLVMPass>();
|
return std::make_unique<ConvertAsyncToLLVMPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void mlir::populateAsyncStructuralTypeConversionsAndLegality(
|
||||||
|
MLIRContext *context, TypeConverter &typeConverter,
|
||||||
|
OwningRewritePatternList &patterns, ConversionTarget &target) {
|
||||||
|
typeConverter.addConversion([&](TokenType type) { return type; });
|
||||||
|
typeConverter.addConversion([&](ValueType type) {
|
||||||
|
return ValueType::get(typeConverter.convertType(type.getValueType()));
|
||||||
|
});
|
||||||
|
|
||||||
|
patterns
|
||||||
|
.insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
|
||||||
|
typeConverter, context);
|
||||||
|
|
||||||
|
target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
|
||||||
|
[&](Operation *op) { return typeConverter.isLegal(op); });
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue