[llvm][STLExtras] Add various type_trait utilities currently present in MLIR

This revision moves several type_trait utilities from MLIR into LLVM. Namely, this revision adds:
is_detected - This matches the experimental std::is_detected
is_invocable - This matches the c++17 std::is_invocable
function_traits - A utility traits class for getting the argument and result types of a callable type

Differential Revision: https://reviews.llvm.org/D78059
This commit is contained in:
River Riddle 2020-04-14 14:52:52 -07:00
parent f52ec5d5c0
commit 8cbe371c28
11 changed files with 196 additions and 108 deletions

View File

@ -75,6 +75,79 @@ template <typename T> struct make_const_ref {
typename std::add_const<T>::type>::type;
};
/// Utilities for detecting if a given trait holds for some set of arguments
/// 'Args'. For example, the given trait could be used to detect if a given type
/// has a copy assignment operator:
/// template<class T>
/// using has_copy_assign_t = decltype(std::declval<T&>()
/// = std::declval<const T&>());
/// bool fooHasCopyAssign = is_detected<has_copy_assign_t, FooClass>::value;
namespace detail {
template <typename...> using void_t = void;
template <class, template <class...> class Op, class... Args> struct detector {
using value_t = std::false_type;
};
template <template <class...> class Op, class... Args>
struct detector<void_t<Op<Args...>>, Op, Args...> {
using value_t = std::true_type;
};
} // end namespace detail
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<void, Op, Args...>::value_t;
/// Check if a Callable type can be invoked with the given set of arg types.
namespace detail {
template <typename Callable, typename... Args>
using is_invocable =
decltype(std::declval<Callable &>()(std::declval<Args>()...));
} // namespace detail
template <typename Callable, typename... Args>
using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
/// This class provides various trait information about a callable object.
/// * To access the number of arguments: Traits::num_args
/// * To access the type of an argument: Traits::arg_t<i>
/// * To access the type of the result: Traits::result_t
template <typename T, bool isClass = std::is_class<T>::value>
struct function_traits : public function_traits<decltype(&T::operator())> {};
/// Overload for class function types.
template <typename ClassType, typename ReturnType, typename... Args>
struct function_traits<ReturnType (ClassType::*)(Args...) const, false> {
/// The number of arguments to this function.
enum { num_args = sizeof...(Args) };
/// The result type of this function.
using result_t = ReturnType;
/// The type of an argument to this function.
template <size_t i>
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};
/// Overload for class function types.
template <typename ClassType, typename ReturnType, typename... Args>
struct function_traits<ReturnType (ClassType::*)(Args...), false>
: function_traits<ReturnType (ClassType::*)(Args...) const> {};
/// Overload for non-class function types.
template <typename ReturnType, typename... Args>
struct function_traits<ReturnType (*)(Args...), false> {
/// The number of arguments to this function.
enum { num_args = sizeof...(Args) };
/// The result type of this function.
using result_t = ReturnType;
/// The type of an argument to this function.
template <size_t i>
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};
/// Overload for non-class function type references.
template <typename ReturnType, typename... Args>
struct function_traits<ReturnType (&)(Args...), false>
: public function_traits<ReturnType (*)(Args...)> {};
//===----------------------------------------------------------------------===//
// Extra additions to <functional>
//===----------------------------------------------------------------------===//

View File

@ -73,6 +73,7 @@ add_llvm_unittest(ADTTests
TinyPtrVectorTest.cpp
TripleTest.cpp
TwineTest.cpp
TypeTraitsTest.cpp
WaymarkingTest.cpp
)

View File

@ -0,0 +1,80 @@
//===- TypeTraitsTest.cpp - type_traits unit tests ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/STLExtras.h"
#include "gtest/gtest.h"
using namespace llvm;
//===----------------------------------------------------------------------===//
// function_traits
//===----------------------------------------------------------------------===//
namespace {
/// Check a callable type of the form `bool(const int &)`.
template <typename CallableT> struct CheckFunctionTraits {
static_assert(
std::is_same<typename function_traits<CallableT>::result_t, bool>::value,
"expected result_t to be `bool`");
static_assert(
std::is_same<typename function_traits<CallableT>::template arg_t<0>,
const int &>::value,
"expected arg_t<0> to be `const int &`");
static_assert(function_traits<CallableT>::num_args == 1,
"expected num_args to be 1");
};
/// Test function pointers.
using FuncType = bool (*)(const int &);
struct CheckFunctionPointer : CheckFunctionTraits<FuncType> {};
static bool func(const int &v);
struct CheckFunctionPointer2 : CheckFunctionTraits<decltype(&func)> {};
/// Test method pointers.
struct Foo {
bool func(const int &v);
};
struct CheckMethodPointer : CheckFunctionTraits<decltype(&Foo::func)> {};
/// Test lambda references.
auto lambdaFunc = [](const int &v) -> bool { return true; };
struct CheckLambda : CheckFunctionTraits<decltype(lambdaFunc)> {};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// is_detected
//===----------------------------------------------------------------------===//
namespace {
struct HasFooMethod {
void foo() {}
};
struct NoFooMethod {};
template <class T> using has_foo_method_t = decltype(std::declval<T &>().foo());
static_assert(is_detected<has_foo_method_t, HasFooMethod>::value,
"expected foo method to be detected");
static_assert(!is_detected<has_foo_method_t, NoFooMethod>::value,
"expected no foo method to be detected");
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// is_invocable
//===----------------------------------------------------------------------===//
static void invocable_fn(int) {}
static_assert(is_invocable<decltype(invocable_fn), int>::value,
"expected function to be invocable");
static_assert(!is_invocable<decltype(invocable_fn), void *>::value,
"expected function not to be invocable");
static_assert(!is_invocable<decltype(invocable_fn), int, int>::value,
"expected function not to be invocable");

View File

@ -46,7 +46,7 @@ public:
/// Note: This inference rules for this overload are very simple: strip
/// pointers and references.
template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
using Traits = FunctionTraits<std::decay_t<CallableT>>;
using Traits = llvm::function_traits<std::decay_t<CallableT>>;
using CaseT = std::remove_cv_t<std::remove_pointer_t<
std::remove_reference_t<typename Traits::template arg_t<0>>>>;
@ -64,20 +64,22 @@ protected:
/// Attempt to dyn_cast the given `value` to `CastT`. This overload is
/// selected if `value` already has a suitable dyn_cast method.
template <typename CastT, typename ValueT>
static auto castValue(
ValueT value,
static auto
castValue(ValueT value,
typename std::enable_if_t<
is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
llvm::is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
nullptr) {
return value.template dyn_cast<CastT>();
}
/// Attempt to dyn_cast the given `value` to `CastT`. This overload is
/// selected if llvm::dyn_cast should be used.
template <typename CastT, typename ValueT>
static auto castValue(
ValueT value,
static auto
castValue(ValueT value,
typename std::enable_if_t<
!is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
!llvm::is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
nullptr) {
return dyn_cast<CastT>(value);
}

View File

@ -140,8 +140,9 @@ using has_operation_or_value_matcher_t =
/// Statically switch to a Value matcher.
template <typename MatcherClass>
typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
MatcherClass, Value>::value,
typename std::enable_if_t<
llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
Value>::value,
bool>
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
return matcher.match(op->getOperand(idx));
@ -149,8 +150,9 @@ matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
/// Statically switch to an Operation matcher.
template <typename MatcherClass>
typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
MatcherClass, Operation *>::value,
typename std::enable_if_t<
llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
Operation *>::value,
bool>
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
if (auto defOp = op->getOperand(idx).getDefiningOp())

View File

@ -1298,16 +1298,16 @@ private:
/// If 'T' is the same interface as 'interfaceID' return the concept
/// instance.
template <typename T>
static typename std::enable_if<is_detected<has_get_interface_id, T>::value,
void *>::type
static typename std::enable_if<
llvm::is_detected<has_get_interface_id, T>::value, void *>::type
lookup(TypeID interfaceID) {
return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr;
}
/// 'T' is known to not be an interface, return nullptr.
template <typename T>
static typename std::enable_if<!is_detected<has_get_interface_id, T>::value,
void *>::type
static typename std::enable_if<
!llvm::is_detected<has_get_interface_id, T>::value, void *>::type
lookup(TypeID) {
return nullptr;
}

View File

@ -71,13 +71,13 @@ using has_is_invalidated = decltype(std::declval<T &>().isInvalidated(
/// Implementation of 'isInvalidated' if the analysis provides a definition.
template <typename AnalysisT>
std::enable_if_t<is_detected<has_is_invalidated, AnalysisT>::value, bool>
std::enable_if_t<llvm::is_detected<has_is_invalidated, AnalysisT>::value, bool>
isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
return analysis.isInvalidated(pa);
}
/// Default implementation of 'isInvalidated'.
template <typename AnalysisT>
std::enable_if_t<!is_detected<has_is_invalidated, AnalysisT>::value, bool>
std::enable_if_t<!llvm::is_detected<has_is_invalidated, AnalysisT>::value, bool>
isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
return !pa.isPreserved<AnalysisT>();
}

View File

@ -88,37 +88,6 @@ inline void interleaveComma(const Container &c, raw_ostream &os) {
interleaveComma(c, os, [&](const T &a) { os << a; });
}
/// Utilities for detecting if a given trait holds for some set of arguments
/// 'Args'. For example, the given trait could be used to detect if a given type
/// has a copy assignment operator:
/// template<class T>
/// using has_copy_assign_t = decltype(std::declval<T&>()
/// = std::declval<const T&>());
/// bool fooHasCopyAssign = is_detected<has_copy_assign_t, FooClass>::value;
namespace detail {
template <typename...> using void_t = void;
template <class, template <class...> class Op, class... Args> struct detector {
using value_t = std::false_type;
};
template <template <class...> class Op, class... Args>
struct detector<void_t<Op<Args...>>, Op, Args...> {
using value_t = std::true_type;
};
} // end namespace detail
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<void, Op, Args...>::value_t;
/// Check if a Callable type can be invoked with the given set of arg types.
namespace detail {
template <typename Callable, typename... Args>
using is_invocable =
decltype(std::declval<Callable &>()(std::declval<Args>()...));
} // namespace detail
template <typename Callable, typename... Args>
using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
//===----------------------------------------------------------------------===//
// Extra additions to <iterator>
//===----------------------------------------------------------------------===//
@ -356,47 +325,6 @@ template <typename ContainerTy> bool has_single_element(ContainerTy &&c) {
return it != e && std::next(it) == e;
}
//===----------------------------------------------------------------------===//
// Extra additions to <type_traits>
//===----------------------------------------------------------------------===//
/// This class provides various trait information about a callable object.
/// * To access the number of arguments: Traits::num_args
/// * To access the type of an argument: Traits::arg_t<i>
/// * To access the type of the result: Traits::result_t<i>
template <typename T, bool isClass = std::is_class<T>::value>
struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};
/// Overload for class function types.
template <typename ClassType, typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (ClassType::*)(Args...) const, false> {
/// The number of arguments to this function.
enum { num_args = sizeof...(Args) };
/// The result type of this function.
using result_t = ReturnType;
/// The type of an argument to this function.
template <size_t i>
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};
/// Overload for non-class function types.
template <typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (*)(Args...), false> {
/// The number of arguments to this function.
enum { num_args = sizeof...(Args) };
/// The result type of this function.
using result_t = ReturnType;
/// The type of an argument to this function.
template <size_t i>
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};
/// Overload for non-class function type references.
template <typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType (&)(Args...), false>
: public FunctionTraits<ReturnType (*)(Args...)> {};
} // end namespace mlir
#endif // MLIR_SUPPORT_STLEXTRAS_H

View File

@ -215,7 +215,7 @@ private:
/// 'ImplTy::getKey' function for the provided arguments.
template <typename ImplTy, typename... Args>
static typename std::enable_if<
is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
typename ImplTy::KeyTy>::type
getKey(Args &&... args) {
return ImplTy::getKey(args...);
@ -224,7 +224,7 @@ private:
/// the 'ImplTy::KeyTy' with the provided arguments.
template <typename ImplTy, typename... Args>
static typename std::enable_if<
!is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
!llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
typename ImplTy::KeyTy>::type
getKey(Args &&... args) {
return typename ImplTy::KeyTy(args...);
@ -238,7 +238,7 @@ private:
/// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
template <typename ImplTy, typename DerivedKey>
static typename std::enable_if<
is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
::llvm::hash_code>::type
getHash(unsigned kind, const DerivedKey &derivedKey) {
return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
@ -246,8 +246,8 @@ private:
/// If there is no 'ImplTy::hashKey' default to using the
/// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
template <typename ImplTy, typename DerivedKey>
static typename std::enable_if<
!is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t,
ImplTy, DerivedKey>::value,
::llvm::hash_code>::type
getHash(unsigned kind, const DerivedKey &derivedKey) {
return llvm::hash_combine(

View File

@ -108,7 +108,7 @@ public:
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT,
typename T = typename FunctionTraits<FnT>::template arg_t<0>>
typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
void addConversion(FnT &&callback) {
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
}
@ -172,7 +172,7 @@ private:
/// different callback forms, that all compose into a single version.
/// With callback of form: `Optional<Type>(T)`
template <typename T, typename FnT>
std::enable_if_t<is_invocable<FnT, T>::value, ConversionCallbackFn>
std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
wrapCallback(FnT &&callback) {
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results) {
@ -187,7 +187,7 @@ private:
}
/// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<> &)`
template <typename T, typename FnT>
std::enable_if_t<!is_invocable<FnT, T>::value, ConversionCallbackFn>
std::enable_if_t<!llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
wrapCallback(FnT &&callback) {
return [callback = std::forward<FnT>(callback)](
Type type,
@ -482,7 +482,8 @@ public:
addDynamicallyLegalOp<OpT2, OpTs...>(callback);
}
template <typename OpT, class Callable>
typename std::enable_if<!is_invocable<Callable, Operation *>::value>::type
typename std::enable_if<
!llvm::is_invocable<Callable, Operation *>::value>::type
addDynamicallyLegalOp(Callable &&callback) {
addDynamicallyLegalOp<OpT>(
[=](Operation *op) { return callback(cast<OpT>(op)); });
@ -514,7 +515,8 @@ public:
markOpRecursivelyLegal<OpT2, OpTs...>(callback);
}
template <typename OpT, class Callable>
typename std::enable_if<!is_invocable<Callable, Operation *>::value>::type
typename std::enable_if<
!llvm::is_invocable<Callable, Operation *>::value>::type
markOpRecursivelyLegal(Callable &&callback) {
markOpRecursivelyLegal<OpT>(
[=](Operation *op) { return callback(cast<OpT>(op)); });

View File

@ -477,8 +477,8 @@ struct SymbolScope {
/// 'walkSymbolUses'.
template <typename CallbackT,
typename std::enable_if_t<!std::is_same<
typename FunctionTraits<CallbackT>::result_t, void>::value> * =
nullptr>
typename llvm::function_traits<CallbackT>::result_t,
void>::value> * = nullptr>
Optional<WalkResult> walk(CallbackT cback) {
if (Region *region = limit.dyn_cast<Region *>())
return walkSymbolUses(*region, cback);
@ -488,8 +488,8 @@ struct SymbolScope {
/// void(SymbolTable::SymbolUse use)
template <typename CallbackT,
typename std::enable_if_t<std::is_same<
typename FunctionTraits<CallbackT>::result_t, void>::value> * =
nullptr>
typename llvm::function_traits<CallbackT>::result_t,
void>::value> * = nullptr>
Optional<WalkResult> walk(CallbackT cback) {
return walk([=](SymbolTable::SymbolUse use, ArrayRef<int>) {
return cback(use), WalkResult::advance();