[mlir] Allow to use constant lambda as callbacks for `TypeConverter`

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D95787
This commit is contained in:
Vladislav Vinogradov 2021-02-02 18:26:31 +00:00 committed by Mehdi Amini
parent d8c373815d
commit 7cc7998497
2 changed files with 18 additions and 19 deletions

View File

@ -103,8 +103,8 @@ public:
/// conversion function to perform the conversion.
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<0>>
void addConversion(FnT &&callback) {
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
}
@ -124,8 +124,8 @@ public:
///
/// This method registers a materialization that will be called when
/// converting an illegal block argument type, to a legal type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
argumentMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@ -133,16 +133,16 @@ public:
/// This method registers a materialization that will be called when
/// converting a legal type to an illegal source type. This is used when
/// conversions to an illegal type must persist beyond the main conversion.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addSourceMaterialization(FnT &&callback) {
sourceMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
/// converting type from an illegal, or source, type to a legal type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
targetMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));

View File

@ -492,8 +492,17 @@ struct TestTypeConverter : public TypeConverter {
TestTypeConverter() {
addConversion(convertType);
addArgumentMaterialization(materializeCast);
addArgumentMaterialization(materializeOneToOneCast);
addSourceMaterialization(materializeCast);
/// Materialize the cast for one-to-one conversion from i64 to f64.
const auto materializeOneToOneCast =
[](OpBuilder &builder, IntegerType resultType, ValueRange inputs,
Location loc) -> Optional<Value> {
if (resultType.getWidth() == 42 && inputs.size() == 1)
return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
return llvm::None;
};
addArgumentMaterialization(materializeOneToOneCast);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
@ -532,16 +541,6 @@ struct TestTypeConverter : public TypeConverter {
return inputs[0];
return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
}
/// Materialize the cast for one-to-one conversion from i64 to f64.
static Optional<Value> materializeOneToOneCast(OpBuilder &builder,
IntegerType resultType,
ValueRange inputs,
Location loc) {
if (resultType.getWidth() == 42 && inputs.size() == 1)
return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
return llvm::None;
}
};
struct TestLegalizePatternDriver