[mlir] Remove the use of FilterTypes for template metaprogramming

This technique results in an explosion in compile time, resulting from a
huge number of std::tuple/concat instatiations. This technique is replaced
by simpler metaprogramming and results in a signficant reduction in
compile time. A local debug/asan build saw a 4x speed up in the processing
of ArithmeticOps.h.inc, and given the nature of this change every dialect
should see similar reductions in compile time.

Differential Revision: https://reviews.llvm.org/D123360
This commit is contained in:
River Riddle 2022-04-07 21:36:40 -07:00
parent 04f3a224bc
commit 31c88660ab
2 changed files with 90 additions and 109 deletions

View File

@ -1540,36 +1540,24 @@ foldTrait(Operation *op, ArrayRef<Attribute> operands,
// fail to fold this trait.
return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
}
template <typename Trait>
static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value,
LogicalResult>
foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) {
return failure();
}
/// The internal implementation of `foldTraits` below that returns the result of
/// folding a set of trait types `Ts` that implement a `foldTrait` method.
/// Given a tuple type containing a set of traits, return the result of folding
/// the given operation.
template <typename... Ts>
static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results,
std::tuple<Ts...> *) {
static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
bool anyFolded = false;
(void)std::initializer_list<int>{
(anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
return success(anyFolded);
}
/// Given a tuple type containing a set of traits that contain a `foldTrait`
/// method, return the result of folding the given operation.
template <typename TraitTupleT>
static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
foldTraits(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
}
/// A variant of the method above that is specialized when there are no traits
/// that contain a `foldTrait` method.
template <typename TraitTupleT>
static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
foldTraits(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return failure();
}
//===----------------------------------------------------------------------===//
// Trait Verification
@ -1587,44 +1575,51 @@ template <typename T>
using detect_has_verify_region_trait =
llvm::is_detected<has_verify_region_trait, T>;
/// The internal implementation of `verifyTraits` below that returns the result
/// of verifying the current operation with all of the provided trait types
/// `Ts`.
/// Verify the given trait if it provides a verifier.
template <typename T>
std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult>
verifyTrait(Operation *op) {
return T::verifyTrait(op);
}
template <typename T>
inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult>
verifyTrait(Operation *) {
return success();
}
/// Given a set of traits, return the result of verifying the given operation.
template <typename... Ts>
static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
LogicalResult verifyTraits(Operation *op) {
LogicalResult result = success();
(void)std::initializer_list<int>{
(result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
(result = succeeded(result) ? verifyTrait<Ts>(op) : failure(), 0)...};
return result;
}
/// Given a tuple type containing a set of traits that contain a
/// `verifyTrait` method, return the result of verifying the given operation.
template <typename TraitTupleT>
static LogicalResult verifyTraits(Operation *op) {
return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
/// Verify the given trait if it provides a region verifier.
template <typename T>
std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult>
verifyRegionTrait(Operation *op) {
return T::verifyRegionTrait(op);
}
template <typename T>
inline std::enable_if_t<!detect_has_verify_region_trait<T>::value,
LogicalResult>
verifyRegionTrait(Operation *) {
return success();
}
/// The internal implementation of `verifyRegionTraits` below that returns the
/// result of verifying the current operation with all of the provided trait
/// types `Ts`.
/// Given a set of traits, return the result of verifying the regions of the
/// given operation.
template <typename... Ts>
static LogicalResult verifyRegionTraitsImpl(Operation *op,
std::tuple<Ts...> *) {
LogicalResult verifyRegionTraits(Operation *op) {
(void)op;
LogicalResult result = success();
(void)std::initializer_list<int>{
(result = succeeded(result) ? Ts::verifyRegionTrait(op) : failure(),
(result = succeeded(result) ? verifyRegionTrait<Ts>(op) : failure(),
0)...};
return result;
}
/// Given a tuple type containing a set of traits that contain a
/// `verifyTrait` method, return the result of verifying the given operation.
template <typename TraitTupleT>
static LogicalResult verifyRegionTraits(Operation *op) {
return verifyRegionTraitsImpl(op, (TraitTupleT *)nullptr);
}
} // namespace op_definition_impl
//===----------------------------------------------------------------------===//
@ -1733,18 +1728,6 @@ private:
decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
template <typename T>
using detect_has_print = llvm::is_detected<has_print, T>;
/// A tuple type containing the traits that have a `foldTrait` function.
using FoldableTraitsTupleT = typename detail::FilterTypes<
op_definition_impl::detect_has_any_fold_trait,
Traits<ConcreteType>...>::type;
/// A tuple type containing the traits that have a verify function.
using VerifiableTraitsTupleT =
typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
Traits<ConcreteType>...>::type;
/// A tuple type containing the region traits that have a verify function.
using VerifiableRegionTraitsTupleT = typename detail::FilterTypes<
op_definition_impl::detect_has_verify_region_trait,
Traits<ConcreteType>...>::type;
/// Returns an interface map containing the interfaces registered to this
/// operation.
@ -1794,8 +1777,8 @@ private:
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
// In this case, we only need to fold the traits of the operation.
return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
results);
return op_definition_impl::foldTraits<Traits<ConcreteType>...>(
op, operands, results);
};
}
/// Return the result of folding a single result operation that defines a
@ -1809,7 +1792,7 @@ private:
// If the fold failed or was in-place, try to fold the traits of the
// operation.
if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
op, operands, results)))
return success();
return success(static_cast<bool>(result));
@ -1826,7 +1809,7 @@ private:
// If the fold failed or was in-place, try to fold the traits of the
// operation.
if (failed(result) || results.empty()) {
if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
op, operands, results)))
return success();
}
@ -1879,7 +1862,7 @@ private:
static_assert(hasNoDataMembers(),
"Op class shouldn't define new data members");
return failure(
failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
failed(cast<ConcreteType>(op).verify()));
}
static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
@ -1889,9 +1872,10 @@ private:
static LogicalResult verifyRegionInvariants(Operation *op) {
static_assert(hasNoDataMembers(),
"Op class shouldn't define new data members");
return failure(failed(op_definition_impl::verifyRegionTraits<
VerifiableRegionTraitsTupleT>(op)) ||
failed(cast<ConcreteType>(op).verifyRegions()));
return failure(
failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
op)) ||
failed(cast<ConcreteType>(op).verifyRegions()));
}
static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);

View File

@ -125,23 +125,17 @@ private:
// InterfaceMap
//===----------------------------------------------------------------------===//
/// Utility to filter a given sequence of types base upon a predicate.
template <bool>
struct FilterTypeT {
template <class E>
using type = std::tuple<E>;
};
template <>
struct FilterTypeT<false> {
template <class E>
using type = std::tuple<>;
};
template <template <class> class Pred, class... Es>
struct FilterTypes {
using type = decltype(std::tuple_cat(
std::declval<
typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...));
};
/// Template utility that computes the number of elements within `T` that
/// satisfy the given predicate.
template <template <class> class Pred, size_t N, typename... Ts>
struct count_if_t_impl : public std::integral_constant<size_t, N> {};
template <template <class> class Pred, size_t N, typename T, typename... Us>
struct count_if_t_impl<Pred, N, T, Us...>
: public std::integral_constant<
size_t,
count_if_t_impl<Pred, N + (Pred<T>::value ? 1 : 0), Us...>::value> {};
template <template <class> class Pred, typename... Ts>
using count_if_t = count_if_t_impl<Pred, 0, Ts...>;
namespace {
/// Type trait indicating whether all template arguments are
@ -171,8 +165,7 @@ class InterfaceMap {
template <typename T>
using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
template <typename... Types>
using num_interface_types = typename std::tuple_size<
typename FilterTypes<detect_get_interface_id, Types...>::type>;
using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>;
public:
InterfaceMap(InterfaceMap &&) = default;
@ -192,20 +185,17 @@ public:
/// types, not all of the types need to be interfaces. The provided types that
/// do not represent interfaces are not added to the interface map.
template <typename... Types>
static std::enable_if_t<num_interface_types<Types...>::value != 0,
InterfaceMap>
get() {
// Filter the provided types for those that are interfaces.
using FilteredTupleType =
typename FilterTypes<detect_get_interface_id, Types...>::type;
return getImpl((FilteredTupleType *)nullptr);
}
static InterfaceMap get() {
// TODO: Use constexpr if here in C++17.
constexpr size_t numInterfaces = num_interface_types_t<Types...>::value;
if (numInterfaces == 0)
return InterfaceMap();
template <typename... Types>
static std::enable_if_t<num_interface_types<Types...>::value == 0,
InterfaceMap>
get() {
return InterfaceMap();
std::array<std::pair<TypeID, void *>, numInterfaces> elements;
std::pair<TypeID, void *> *elementIt = elements.data();
(void)std::initializer_list<int>{
0, (addModelAndUpdateIterator<Types>(elementIt), 0)...};
return InterfaceMap(elements);
}
/// Returns an instance of the concept object for the given interface if it
@ -235,23 +225,30 @@ public:
}
private:
InterfaceMap() = default;
/// Assign the interface model of the type to the given opaque element
/// iterator and increment it.
template <typename T>
static inline std::enable_if_t<detect_get_interface_id<T>::value>
addModelAndUpdateIterator(std::pair<TypeID, void *> *&elementIt) {
*elementIt = {T::getInterfaceID(), new (malloc(sizeof(typename T::ModelT)))
typename T::ModelT()};
++elementIt;
}
/// Overload when `T` isn't an interface.
template <typename T>
static inline std::enable_if_t<!detect_get_interface_id<T>::value>
addModelAndUpdateIterator(std::pair<TypeID, void *> *&) {}
/// Insert the given set of interface models into the interface map.
void insert(ArrayRef<std::pair<TypeID, void *>> elements);
/// Compare two TypeID instances by comparing the underlying pointer.
static bool compare(TypeID lhs, TypeID rhs) {
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
}
InterfaceMap() = default;
void insert(ArrayRef<std::pair<TypeID, void *>> elements);
template <typename... Ts>
static InterfaceMap getImpl(std::tuple<Ts...> *) {
std::pair<TypeID, void *> elements[] = {std::make_pair(
Ts::getInterfaceID(),
new (malloc(sizeof(typename Ts::ModelT))) typename Ts::ModelT())...};
return InterfaceMap(elements);
}
/// Returns an instance of the concept object for the given interface id if it
/// was registered to this map, null otherwise.
void *lookup(TypeID id) const {