forked from OSchip/llvm-project
[mlir] Remove the use of FilterTypes for template metaprogramming
This technique results in an explosion in compile time, resulting from a huge number of std::tuple/concat instatiations. This technique is replaced by simpler metaprogramming and results in a signficant reduction in compile time. A local debug/asan build saw a 4x speed up in the processing of ArithmeticOps.h.inc, and given the nature of this change every dialect should see similar reductions in compile time. Differential Revision: https://reviews.llvm.org/D123360
This commit is contained in:
parent
04f3a224bc
commit
31c88660ab
|
@ -1540,36 +1540,24 @@ foldTrait(Operation *op, ArrayRef<Attribute> operands,
|
||||||
// fail to fold this trait.
|
// fail to fold this trait.
|
||||||
return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
|
return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
|
||||||
}
|
}
|
||||||
|
template <typename Trait>
|
||||||
|
static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value,
|
||||||
|
LogicalResult>
|
||||||
|
foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
/// The internal implementation of `foldTraits` below that returns the result of
|
/// Given a tuple type containing a set of traits, return the result of folding
|
||||||
/// folding a set of trait types `Ts` that implement a `foldTrait` method.
|
/// the given operation.
|
||||||
template <typename... Ts>
|
template <typename... Ts>
|
||||||
static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
|
static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
|
||||||
SmallVectorImpl<OpFoldResult> &results,
|
SmallVectorImpl<OpFoldResult> &results) {
|
||||||
std::tuple<Ts...> *) {
|
|
||||||
bool anyFolded = false;
|
bool anyFolded = false;
|
||||||
(void)std::initializer_list<int>{
|
(void)std::initializer_list<int>{
|
||||||
(anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
|
(anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
|
||||||
return success(anyFolded);
|
return success(anyFolded);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given a tuple type containing a set of traits that contain a `foldTrait`
|
|
||||||
/// method, return the result of folding the given operation.
|
|
||||||
template <typename TraitTupleT>
|
|
||||||
static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
|
|
||||||
foldTraits(Operation *op, ArrayRef<Attribute> operands,
|
|
||||||
SmallVectorImpl<OpFoldResult> &results) {
|
|
||||||
return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
|
|
||||||
}
|
|
||||||
/// A variant of the method above that is specialized when there are no traits
|
|
||||||
/// that contain a `foldTrait` method.
|
|
||||||
template <typename TraitTupleT>
|
|
||||||
static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
|
|
||||||
foldTraits(Operation *op, ArrayRef<Attribute> operands,
|
|
||||||
SmallVectorImpl<OpFoldResult> &results) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Trait Verification
|
// Trait Verification
|
||||||
|
|
||||||
|
@ -1587,44 +1575,51 @@ template <typename T>
|
||||||
using detect_has_verify_region_trait =
|
using detect_has_verify_region_trait =
|
||||||
llvm::is_detected<has_verify_region_trait, T>;
|
llvm::is_detected<has_verify_region_trait, T>;
|
||||||
|
|
||||||
/// The internal implementation of `verifyTraits` below that returns the result
|
/// Verify the given trait if it provides a verifier.
|
||||||
/// of verifying the current operation with all of the provided trait types
|
template <typename T>
|
||||||
/// `Ts`.
|
std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult>
|
||||||
|
verifyTrait(Operation *op) {
|
||||||
|
return T::verifyTrait(op);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult>
|
||||||
|
verifyTrait(Operation *) {
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Given a set of traits, return the result of verifying the given operation.
|
||||||
template <typename... Ts>
|
template <typename... Ts>
|
||||||
static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
|
LogicalResult verifyTraits(Operation *op) {
|
||||||
LogicalResult result = success();
|
LogicalResult result = success();
|
||||||
(void)std::initializer_list<int>{
|
(void)std::initializer_list<int>{
|
||||||
(result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
|
(result = succeeded(result) ? verifyTrait<Ts>(op) : failure(), 0)...};
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given a tuple type containing a set of traits that contain a
|
/// Verify the given trait if it provides a region verifier.
|
||||||
/// `verifyTrait` method, return the result of verifying the given operation.
|
template <typename T>
|
||||||
template <typename TraitTupleT>
|
std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult>
|
||||||
static LogicalResult verifyTraits(Operation *op) {
|
verifyRegionTrait(Operation *op) {
|
||||||
return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
|
return T::verifyRegionTrait(op);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
inline std::enable_if_t<!detect_has_verify_region_trait<T>::value,
|
||||||
|
LogicalResult>
|
||||||
|
verifyRegionTrait(Operation *) {
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The internal implementation of `verifyRegionTraits` below that returns the
|
/// Given a set of traits, return the result of verifying the regions of the
|
||||||
/// result of verifying the current operation with all of the provided trait
|
/// given operation.
|
||||||
/// types `Ts`.
|
|
||||||
template <typename... Ts>
|
template <typename... Ts>
|
||||||
static LogicalResult verifyRegionTraitsImpl(Operation *op,
|
LogicalResult verifyRegionTraits(Operation *op) {
|
||||||
std::tuple<Ts...> *) {
|
|
||||||
(void)op;
|
(void)op;
|
||||||
LogicalResult result = success();
|
LogicalResult result = success();
|
||||||
(void)std::initializer_list<int>{
|
(void)std::initializer_list<int>{
|
||||||
(result = succeeded(result) ? Ts::verifyRegionTrait(op) : failure(),
|
(result = succeeded(result) ? verifyRegionTrait<Ts>(op) : failure(),
|
||||||
0)...};
|
0)...};
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given a tuple type containing a set of traits that contain a
|
|
||||||
/// `verifyTrait` method, return the result of verifying the given operation.
|
|
||||||
template <typename TraitTupleT>
|
|
||||||
static LogicalResult verifyRegionTraits(Operation *op) {
|
|
||||||
return verifyRegionTraitsImpl(op, (TraitTupleT *)nullptr);
|
|
||||||
}
|
|
||||||
} // namespace op_definition_impl
|
} // namespace op_definition_impl
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1733,18 +1728,6 @@ private:
|
||||||
decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
|
decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using detect_has_print = llvm::is_detected<has_print, T>;
|
using detect_has_print = llvm::is_detected<has_print, T>;
|
||||||
/// A tuple type containing the traits that have a `foldTrait` function.
|
|
||||||
using FoldableTraitsTupleT = typename detail::FilterTypes<
|
|
||||||
op_definition_impl::detect_has_any_fold_trait,
|
|
||||||
Traits<ConcreteType>...>::type;
|
|
||||||
/// A tuple type containing the traits that have a verify function.
|
|
||||||
using VerifiableTraitsTupleT =
|
|
||||||
typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
|
|
||||||
Traits<ConcreteType>...>::type;
|
|
||||||
/// A tuple type containing the region traits that have a verify function.
|
|
||||||
using VerifiableRegionTraitsTupleT = typename detail::FilterTypes<
|
|
||||||
op_definition_impl::detect_has_verify_region_trait,
|
|
||||||
Traits<ConcreteType>...>::type;
|
|
||||||
|
|
||||||
/// Returns an interface map containing the interfaces registered to this
|
/// Returns an interface map containing the interfaces registered to this
|
||||||
/// operation.
|
/// operation.
|
||||||
|
@ -1794,8 +1777,8 @@ private:
|
||||||
return [](Operation *op, ArrayRef<Attribute> operands,
|
return [](Operation *op, ArrayRef<Attribute> operands,
|
||||||
SmallVectorImpl<OpFoldResult> &results) {
|
SmallVectorImpl<OpFoldResult> &results) {
|
||||||
// In this case, we only need to fold the traits of the operation.
|
// In this case, we only need to fold the traits of the operation.
|
||||||
return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
|
return op_definition_impl::foldTraits<Traits<ConcreteType>...>(
|
||||||
results);
|
op, operands, results);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
/// Return the result of folding a single result operation that defines a
|
/// Return the result of folding a single result operation that defines a
|
||||||
|
@ -1809,7 +1792,7 @@ private:
|
||||||
// If the fold failed or was in-place, try to fold the traits of the
|
// If the fold failed or was in-place, try to fold the traits of the
|
||||||
// operation.
|
// operation.
|
||||||
if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
|
if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
|
||||||
if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
|
if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
|
||||||
op, operands, results)))
|
op, operands, results)))
|
||||||
return success();
|
return success();
|
||||||
return success(static_cast<bool>(result));
|
return success(static_cast<bool>(result));
|
||||||
|
@ -1826,7 +1809,7 @@ private:
|
||||||
// If the fold failed or was in-place, try to fold the traits of the
|
// If the fold failed or was in-place, try to fold the traits of the
|
||||||
// operation.
|
// operation.
|
||||||
if (failed(result) || results.empty()) {
|
if (failed(result) || results.empty()) {
|
||||||
if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
|
if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
|
||||||
op, operands, results)))
|
op, operands, results)))
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1879,7 +1862,7 @@ private:
|
||||||
static_assert(hasNoDataMembers(),
|
static_assert(hasNoDataMembers(),
|
||||||
"Op class shouldn't define new data members");
|
"Op class shouldn't define new data members");
|
||||||
return failure(
|
return failure(
|
||||||
failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
|
failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
|
||||||
failed(cast<ConcreteType>(op).verify()));
|
failed(cast<ConcreteType>(op).verify()));
|
||||||
}
|
}
|
||||||
static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
|
static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
|
||||||
|
@ -1889,9 +1872,10 @@ private:
|
||||||
static LogicalResult verifyRegionInvariants(Operation *op) {
|
static LogicalResult verifyRegionInvariants(Operation *op) {
|
||||||
static_assert(hasNoDataMembers(),
|
static_assert(hasNoDataMembers(),
|
||||||
"Op class shouldn't define new data members");
|
"Op class shouldn't define new data members");
|
||||||
return failure(failed(op_definition_impl::verifyRegionTraits<
|
return failure(
|
||||||
VerifiableRegionTraitsTupleT>(op)) ||
|
failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
|
||||||
failed(cast<ConcreteType>(op).verifyRegions()));
|
op)) ||
|
||||||
|
failed(cast<ConcreteType>(op).verifyRegions()));
|
||||||
}
|
}
|
||||||
static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
|
static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
|
||||||
return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
|
return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
|
||||||
|
|
|
@ -125,23 +125,17 @@ private:
|
||||||
// InterfaceMap
|
// InterfaceMap
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Utility to filter a given sequence of types base upon a predicate.
|
/// Template utility that computes the number of elements within `T` that
|
||||||
template <bool>
|
/// satisfy the given predicate.
|
||||||
struct FilterTypeT {
|
template <template <class> class Pred, size_t N, typename... Ts>
|
||||||
template <class E>
|
struct count_if_t_impl : public std::integral_constant<size_t, N> {};
|
||||||
using type = std::tuple<E>;
|
template <template <class> class Pred, size_t N, typename T, typename... Us>
|
||||||
};
|
struct count_if_t_impl<Pred, N, T, Us...>
|
||||||
template <>
|
: public std::integral_constant<
|
||||||
struct FilterTypeT<false> {
|
size_t,
|
||||||
template <class E>
|
count_if_t_impl<Pred, N + (Pred<T>::value ? 1 : 0), Us...>::value> {};
|
||||||
using type = std::tuple<>;
|
template <template <class> class Pred, typename... Ts>
|
||||||
};
|
using count_if_t = count_if_t_impl<Pred, 0, Ts...>;
|
||||||
template <template <class> class Pred, class... Es>
|
|
||||||
struct FilterTypes {
|
|
||||||
using type = decltype(std::tuple_cat(
|
|
||||||
std::declval<
|
|
||||||
typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...));
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
/// Type trait indicating whether all template arguments are
|
/// Type trait indicating whether all template arguments are
|
||||||
|
@ -171,8 +165,7 @@ class InterfaceMap {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
|
using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
|
||||||
template <typename... Types>
|
template <typename... Types>
|
||||||
using num_interface_types = typename std::tuple_size<
|
using num_interface_types_t = count_if_t<detect_get_interface_id, Types...>;
|
||||||
typename FilterTypes<detect_get_interface_id, Types...>::type>;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
InterfaceMap(InterfaceMap &&) = default;
|
InterfaceMap(InterfaceMap &&) = default;
|
||||||
|
@ -192,20 +185,17 @@ public:
|
||||||
/// types, not all of the types need to be interfaces. The provided types that
|
/// types, not all of the types need to be interfaces. The provided types that
|
||||||
/// do not represent interfaces are not added to the interface map.
|
/// do not represent interfaces are not added to the interface map.
|
||||||
template <typename... Types>
|
template <typename... Types>
|
||||||
static std::enable_if_t<num_interface_types<Types...>::value != 0,
|
static InterfaceMap get() {
|
||||||
InterfaceMap>
|
// TODO: Use constexpr if here in C++17.
|
||||||
get() {
|
constexpr size_t numInterfaces = num_interface_types_t<Types...>::value;
|
||||||
// Filter the provided types for those that are interfaces.
|
if (numInterfaces == 0)
|
||||||
using FilteredTupleType =
|
return InterfaceMap();
|
||||||
typename FilterTypes<detect_get_interface_id, Types...>::type;
|
|
||||||
return getImpl((FilteredTupleType *)nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... Types>
|
std::array<std::pair<TypeID, void *>, numInterfaces> elements;
|
||||||
static std::enable_if_t<num_interface_types<Types...>::value == 0,
|
std::pair<TypeID, void *> *elementIt = elements.data();
|
||||||
InterfaceMap>
|
(void)std::initializer_list<int>{
|
||||||
get() {
|
0, (addModelAndUpdateIterator<Types>(elementIt), 0)...};
|
||||||
return InterfaceMap();
|
return InterfaceMap(elements);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an instance of the concept object for the given interface if it
|
/// Returns an instance of the concept object for the given interface if it
|
||||||
|
@ -235,23 +225,30 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
InterfaceMap() = default;
|
||||||
|
|
||||||
|
/// Assign the interface model of the type to the given opaque element
|
||||||
|
/// iterator and increment it.
|
||||||
|
template <typename T>
|
||||||
|
static inline std::enable_if_t<detect_get_interface_id<T>::value>
|
||||||
|
addModelAndUpdateIterator(std::pair<TypeID, void *> *&elementIt) {
|
||||||
|
*elementIt = {T::getInterfaceID(), new (malloc(sizeof(typename T::ModelT)))
|
||||||
|
typename T::ModelT()};
|
||||||
|
++elementIt;
|
||||||
|
}
|
||||||
|
/// Overload when `T` isn't an interface.
|
||||||
|
template <typename T>
|
||||||
|
static inline std::enable_if_t<!detect_get_interface_id<T>::value>
|
||||||
|
addModelAndUpdateIterator(std::pair<TypeID, void *> *&) {}
|
||||||
|
|
||||||
|
/// Insert the given set of interface models into the interface map.
|
||||||
|
void insert(ArrayRef<std::pair<TypeID, void *>> elements);
|
||||||
|
|
||||||
/// Compare two TypeID instances by comparing the underlying pointer.
|
/// Compare two TypeID instances by comparing the underlying pointer.
|
||||||
static bool compare(TypeID lhs, TypeID rhs) {
|
static bool compare(TypeID lhs, TypeID rhs) {
|
||||||
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
|
return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
|
||||||
}
|
}
|
||||||
|
|
||||||
InterfaceMap() = default;
|
|
||||||
|
|
||||||
void insert(ArrayRef<std::pair<TypeID, void *>> elements);
|
|
||||||
|
|
||||||
template <typename... Ts>
|
|
||||||
static InterfaceMap getImpl(std::tuple<Ts...> *) {
|
|
||||||
std::pair<TypeID, void *> elements[] = {std::make_pair(
|
|
||||||
Ts::getInterfaceID(),
|
|
||||||
new (malloc(sizeof(typename Ts::ModelT))) typename Ts::ModelT())...};
|
|
||||||
return InterfaceMap(elements);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns an instance of the concept object for the given interface id if it
|
/// Returns an instance of the concept object for the given interface id if it
|
||||||
/// was registered to this map, null otherwise.
|
/// was registered to this map, null otherwise.
|
||||||
void *lookup(TypeID id) const {
|
void *lookup(TypeID id) const {
|
||||||
|
|
Loading…
Reference in New Issue