forked from OSchip/llvm-project
[mlir] Optimize usage of llvm::mapped_iterator
mapped_iterator is a useful abstraction for applying a map function over an existing iterator, but our current usage ends up allocating storage/making indirect calls even with the map function is a known function, which is horribly inefficient. This commit refactors the usage of mapped_iterator to avoid this, and allows for directly referencing the map function when dereferencing. Fixes PR52319 Differential Revision: https://reviews.llvm.org/D113511
This commit is contained in:
parent
919ca9fc04
commit
6de6131f02
|
@ -307,6 +307,32 @@ auto map_range(ContainerTy &&C, FuncTy F) {
|
|||
return make_range(map_iterator(C.begin(), F), map_iterator(C.end(), F));
|
||||
}
|
||||
|
||||
/// A base type of mapped iterator, that is useful for building derived
|
||||
/// iterators that do not need/want to store the map function (as in
|
||||
/// mapped_iterator). These iterators must simply provide a `mapElement` method
|
||||
/// that defines how to map a value of the iterator to the provided reference
|
||||
/// type.
|
||||
template <typename DerivedT, typename ItTy, typename ReferenceTy>
|
||||
class mapped_iterator_base
|
||||
: public iterator_adaptor_base<
|
||||
DerivedT, ItTy,
|
||||
typename std::iterator_traits<ItTy>::iterator_category,
|
||||
std::remove_reference_t<ReferenceTy>,
|
||||
typename std::iterator_traits<ItTy>::difference_type,
|
||||
std::remove_reference_t<ReferenceTy> *, ReferenceTy> {
|
||||
public:
|
||||
using BaseT = mapped_iterator_base<DerivedT, ItTy, ReferenceTy>;
|
||||
|
||||
mapped_iterator_base(ItTy U)
|
||||
: mapped_iterator_base::iterator_adaptor_base(std::move(U)) {}
|
||||
|
||||
ItTy getCurrent() { return this->I; }
|
||||
|
||||
ReferenceTy operator*() const {
|
||||
return static_cast<const DerivedT &>(*this).mapElement(*this->I);
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper to determine if type T has a member called rbegin().
|
||||
template <typename Ty> class has_rbegin_impl {
|
||||
using yes = char[1];
|
||||
|
|
|
@ -47,4 +47,67 @@ TEST(MappedIteratorTest, FunctionPreservesReferences) {
|
|||
EXPECT_EQ(M[1], 42) << "assignment should have modified M";
|
||||
}
|
||||
|
||||
TEST(MappedIteratorTest, CustomIteratorApplyFunctionOnDereference) {
|
||||
struct CustomMapIterator
|
||||
: public llvm::mapped_iterator_base<CustomMapIterator,
|
||||
std::vector<int>::iterator, int> {
|
||||
using BaseT::BaseT;
|
||||
|
||||
/// Map the element to the iterator result type.
|
||||
int mapElement(int X) const { return X + 1; }
|
||||
};
|
||||
|
||||
std::vector<int> V({0});
|
||||
|
||||
CustomMapIterator I(V.begin());
|
||||
|
||||
EXPECT_EQ(*I, 1) << "should have applied function in dereference";
|
||||
}
|
||||
|
||||
TEST(MappedIteratorTest, CustomIteratorApplyFunctionOnArrow) {
|
||||
struct S {
|
||||
int Z = 0;
|
||||
};
|
||||
struct CustomMapIterator
|
||||
: public llvm::mapped_iterator_base<CustomMapIterator,
|
||||
std::vector<int>::iterator, S &> {
|
||||
CustomMapIterator(std::vector<int>::iterator it, S *P) : BaseT(it), P(P) {}
|
||||
|
||||
/// Map the element to the iterator result type.
|
||||
S &mapElement(int X) const { return *(P + X); }
|
||||
|
||||
S *P;
|
||||
};
|
||||
|
||||
std::vector<int> V({0});
|
||||
S Y;
|
||||
|
||||
CustomMapIterator I(V.begin(), &Y);
|
||||
|
||||
I->Z = 42;
|
||||
|
||||
EXPECT_EQ(Y.Z, 42) << "should have applied function during arrow";
|
||||
}
|
||||
|
||||
TEST(MappedIteratorTest, CustomIteratorFunctionPreservesReferences) {
|
||||
struct CustomMapIterator
|
||||
: public llvm::mapped_iterator_base<CustomMapIterator,
|
||||
std::vector<int>::iterator, int &> {
|
||||
CustomMapIterator(std::vector<int>::iterator it, std::map<int, int> &M)
|
||||
: BaseT(it), M(M) {}
|
||||
|
||||
/// Map the element to the iterator result type.
|
||||
int &mapElement(int X) const { return M[X]; }
|
||||
|
||||
std::map<int, int> &M;
|
||||
};
|
||||
std::vector<int> V({1});
|
||||
std::map<int, int> M({{1, 1}});
|
||||
|
||||
auto I = CustomMapIterator(V.begin(), M);
|
||||
*I = 42;
|
||||
|
||||
EXPECT_EQ(M[1], 42) << "assignment should have modified M";
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
|
|
@ -323,24 +323,46 @@ public:
|
|||
|
||||
/// Iterator for walking over APFloat values.
|
||||
class FloatElementIterator final
|
||||
: public llvm::mapped_iterator<IntElementIterator,
|
||||
std::function<APFloat(const APInt &)>> {
|
||||
: public llvm::mapped_iterator_base<FloatElementIterator,
|
||||
IntElementIterator, APFloat> {
|
||||
public:
|
||||
/// Map the element to the iterator result type.
|
||||
APFloat mapElement(const APInt &value) const {
|
||||
return APFloat(*smt, value);
|
||||
}
|
||||
|
||||
private:
|
||||
friend DenseElementsAttr;
|
||||
|
||||
/// Initializes the float element iterator to the specified iterator.
|
||||
FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);
|
||||
FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it)
|
||||
: BaseT(it), smt(&smt) {}
|
||||
|
||||
/// The float semantics to use when constructing the APFloat.
|
||||
const llvm::fltSemantics *smt;
|
||||
};
|
||||
|
||||
/// Iterator for walking over complex APFloat values.
|
||||
class ComplexFloatElementIterator final
|
||||
: public llvm::mapped_iterator<
|
||||
ComplexIntElementIterator,
|
||||
std::function<std::complex<APFloat>(const std::complex<APInt> &)>> {
|
||||
: public llvm::mapped_iterator_base<ComplexFloatElementIterator,
|
||||
ComplexIntElementIterator,
|
||||
std::complex<APFloat>> {
|
||||
public:
|
||||
/// Map the element to the iterator result type.
|
||||
std::complex<APFloat> mapElement(const std::complex<APInt> &value) const {
|
||||
return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())};
|
||||
}
|
||||
|
||||
private:
|
||||
friend DenseElementsAttr;
|
||||
|
||||
/// Initializes the float element iterator to the specified iterator.
|
||||
ComplexFloatElementIterator(const llvm::fltSemantics &smt,
|
||||
ComplexIntElementIterator it);
|
||||
ComplexIntElementIterator it)
|
||||
: BaseT(it), smt(&smt) {}
|
||||
|
||||
/// The float semantics to use when constructing the APFloat.
|
||||
const llvm::fltSemantics *smt;
|
||||
};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -478,24 +500,27 @@ public:
|
|||
typename std::enable_if<std::is_base_of<Attribute, T>::value &&
|
||||
!std::is_same<Attribute, T>::value>::type;
|
||||
template <typename T>
|
||||
using DerivedAttributeElementIterator =
|
||||
llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
|
||||
struct DerivedAttributeElementIterator
|
||||
: public llvm::mapped_iterator_base<DerivedAttributeElementIterator<T>,
|
||||
AttributeElementIterator, T> {
|
||||
using DerivedAttributeElementIterator::BaseT::BaseT;
|
||||
|
||||
/// Map the element to the iterator result type.
|
||||
T mapElement(Attribute attr) const { return attr.cast<T>(); }
|
||||
};
|
||||
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
||||
iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
|
||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||
return {Attribute::getType(),
|
||||
llvm::map_range(getValues<Attribute>(),
|
||||
static_cast<T (*)(Attribute)>(castFn))};
|
||||
using DerivedIterT = DerivedAttributeElementIterator<T>;
|
||||
return {Attribute::getType(), DerivedIterT(value_begin<Attribute>()),
|
||||
DerivedIterT(value_end<Attribute>())};
|
||||
}
|
||||
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
||||
DerivedAttributeElementIterator<T> value_begin() const {
|
||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||
return {value_begin<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
|
||||
return {value_begin<Attribute>()};
|
||||
}
|
||||
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
||||
DerivedAttributeElementIterator<T> value_end() const {
|
||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||
return {value_end<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
|
||||
return {value_end<Attribute>()};
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of bool. The element type of
|
||||
|
|
|
@ -155,20 +155,6 @@ inline raw_ostream &operator<<(raw_ostream &os, const DiagnosticArgument &arg) {
|
|||
class Diagnostic {
|
||||
using NoteVector = std::vector<std::unique_ptr<Diagnostic>>;
|
||||
|
||||
/// This class implements a wrapper iterator around NoteVector::iterator to
|
||||
/// implicitly dereference the unique_ptr.
|
||||
template <typename IteratorTy, typename NotePtrTy = decltype(*IteratorTy()),
|
||||
typename ResultTy = decltype(**IteratorTy())>
|
||||
class NoteIteratorImpl
|
||||
: public llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)> {
|
||||
static ResultTy &unwrap(NotePtrTy note) { return *note; }
|
||||
|
||||
public:
|
||||
NoteIteratorImpl(IteratorTy it)
|
||||
: llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)>(it,
|
||||
&unwrap) {}
|
||||
};
|
||||
|
||||
public:
|
||||
Diagnostic(Location loc, DiagnosticSeverity severity)
|
||||
: loc(loc), severity(severity) {}
|
||||
|
@ -262,15 +248,16 @@ public:
|
|||
/// diagnostic. Notes may not be attached to other notes.
|
||||
Diagnostic &attachNote(Optional<Location> noteLoc = llvm::None);
|
||||
|
||||
using note_iterator = NoteIteratorImpl<NoteVector::iterator>;
|
||||
using const_note_iterator = NoteIteratorImpl<NoteVector::const_iterator>;
|
||||
using note_iterator = llvm::pointee_iterator<NoteVector::iterator>;
|
||||
using const_note_iterator =
|
||||
llvm::pointee_iterator<NoteVector::const_iterator>;
|
||||
|
||||
/// Returns the notes held by this diagnostic.
|
||||
iterator_range<note_iterator> getNotes() {
|
||||
return {notes.begin(), notes.end()};
|
||||
return llvm::make_pointee_range(notes);
|
||||
}
|
||||
iterator_range<const_note_iterator> getNotes() const {
|
||||
return {notes.begin(), notes.end()};
|
||||
return llvm::make_pointee_range(notes);
|
||||
}
|
||||
|
||||
/// Allow a diagnostic to be converted to 'failure'.
|
||||
|
|
|
@ -111,20 +111,16 @@ protected:
|
|||
/// An iterator class that iterates the held interface objects of the given
|
||||
/// derived interface type.
|
||||
template <typename InterfaceT>
|
||||
class iterator : public llvm::mapped_iterator<
|
||||
InterfaceVectorT::const_iterator,
|
||||
const InterfaceT &(*)(const DialectInterface *)> {
|
||||
static const InterfaceT &remapIt(const DialectInterface *interface) {
|
||||
struct iterator
|
||||
: public llvm::mapped_iterator_base<iterator<InterfaceT>,
|
||||
InterfaceVectorT::const_iterator,
|
||||
const InterfaceT &> {
|
||||
using iterator::BaseT::BaseT;
|
||||
|
||||
/// Map the element to the iterator result type.
|
||||
const InterfaceT &mapElement(const DialectInterface *interface) const {
|
||||
return *static_cast<const InterfaceT *>(interface);
|
||||
}
|
||||
|
||||
iterator(InterfaceVectorT::const_iterator it)
|
||||
: llvm::mapped_iterator<
|
||||
InterfaceVectorT::const_iterator,
|
||||
const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {}
|
||||
|
||||
/// Allow access to the constructor.
|
||||
friend DialectInterfaceCollectionBase;
|
||||
};
|
||||
|
||||
/// Iterator access to the held interfaces.
|
||||
|
|
|
@ -124,16 +124,13 @@ private:
|
|||
/// This class implements iteration on the types of a given range of values.
|
||||
template <typename ValueIteratorT>
|
||||
class ValueTypeIterator final
|
||||
: public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
|
||||
static Type unwrap(Value value) { return value.getType(); }
|
||||
|
||||
: public llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>,
|
||||
ValueIteratorT, Type> {
|
||||
public:
|
||||
/// Provide a const dereference method.
|
||||
Type operator*() const { return unwrap(*this->I); }
|
||||
using ValueTypeIterator::BaseT::BaseT;
|
||||
|
||||
/// Initializes the type iterator to the specified value iterator.
|
||||
ValueTypeIterator(ValueIteratorT it)
|
||||
: llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
|
||||
/// Map the element to the iterator result type.
|
||||
Type mapElement(Value value) const { return value.getType(); }
|
||||
};
|
||||
|
||||
/// This class implements iteration on the types of a given range of values.
|
||||
|
|
|
@ -66,36 +66,33 @@ LogicalResult verifyCompatibleShapes(TypeRange types);
|
|||
|
||||
/// Dimensions are compatible if all non-dynamic dims are equal.
|
||||
LogicalResult verifyCompatibleDims(ArrayRef<int64_t> dims);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility Iterators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// An iterator for the element types of an op's operands of shaped types.
|
||||
class OperandElementTypeIterator final
|
||||
: public llvm::mapped_iterator<Operation::operand_iterator,
|
||||
Type (*)(Value)> {
|
||||
: public llvm::mapped_iterator_base<OperandElementTypeIterator,
|
||||
Operation::operand_iterator, Type> {
|
||||
public:
|
||||
/// Initializes the result element type iterator to the specified operand
|
||||
/// iterator.
|
||||
explicit OperandElementTypeIterator(Operation::operand_iterator it);
|
||||
using BaseT::BaseT;
|
||||
|
||||
private:
|
||||
static Type unwrap(Value value);
|
||||
/// Map the element to the iterator result type.
|
||||
Type mapElement(Value value) const;
|
||||
};
|
||||
|
||||
using OperandElementTypeRange = iterator_range<OperandElementTypeIterator>;
|
||||
|
||||
// An iterator for the tensor element types of an op's results of shaped types.
|
||||
class ResultElementTypeIterator final
|
||||
: public llvm::mapped_iterator<Operation::result_iterator,
|
||||
Type (*)(Value)> {
|
||||
: public llvm::mapped_iterator_base<ResultElementTypeIterator,
|
||||
Operation::result_iterator, Type> {
|
||||
public:
|
||||
/// Initializes the result element type iterator to the specified result
|
||||
/// iterator.
|
||||
explicit ResultElementTypeIterator(Operation::result_iterator it);
|
||||
using BaseT::BaseT;
|
||||
|
||||
private:
|
||||
static Type unwrap(Value value);
|
||||
/// Map the element to the iterator result type.
|
||||
Type mapElement(Value value) const;
|
||||
};
|
||||
|
||||
using ResultElementTypeRange = iterator_range<ResultElementTypeIterator>;
|
||||
|
|
|
@ -281,15 +281,16 @@ protected:
|
|||
/// a specific use iterator.
|
||||
template <typename UseIteratorT, typename OperandType>
|
||||
class ValueUserIterator final
|
||||
: public llvm::mapped_iterator<UseIteratorT,
|
||||
Operation *(*)(OperandType &)> {
|
||||
static Operation *unwrap(OperandType &value) { return value.getOwner(); }
|
||||
|
||||
: public llvm::mapped_iterator_base<
|
||||
ValueUserIterator<UseIteratorT, OperandType>, UseIteratorT,
|
||||
Operation *> {
|
||||
public:
|
||||
/// Initializes the user iterator to the specified use iterator.
|
||||
ValueUserIterator(UseIteratorT it)
|
||||
: llvm::mapped_iterator<UseIteratorT, Operation *(*)(OperandType &)>(
|
||||
it, &unwrap) {}
|
||||
using ValueUserIterator::BaseT::BaseT;
|
||||
|
||||
/// Map the element to the iterator result type.
|
||||
Operation *mapElement(OperandType &value) const { return value.getOwner(); }
|
||||
|
||||
/// Provide access to the underlying operation.
|
||||
Operation *operator->() { return **this; }
|
||||
};
|
||||
|
||||
|
|
|
@ -667,27 +667,6 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
|
|||
readBits(getData(), offset + storageWidth, bitWidth)};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FloatElementIterator
|
||||
|
||||
DenseElementsAttr::FloatElementIterator::FloatElementIterator(
|
||||
const llvm::fltSemantics &smt, IntElementIterator it)
|
||||
: llvm::mapped_iterator<IntElementIterator,
|
||||
std::function<APFloat(const APInt &)>>(
|
||||
it, [&](const APInt &val) { return APFloat(smt, val); }) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ComplexFloatElementIterator
|
||||
|
||||
DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
|
||||
const llvm::fltSemantics &smt, ComplexIntElementIterator it)
|
||||
: llvm::mapped_iterator<
|
||||
ComplexIntElementIterator,
|
||||
std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
|
||||
it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
|
||||
return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
|
||||
}) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -151,20 +151,10 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
|
|||
return success();
|
||||
}
|
||||
|
||||
OperandElementTypeIterator::OperandElementTypeIterator(
|
||||
Operation::operand_iterator it)
|
||||
: llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>(
|
||||
it, &unwrap) {}
|
||||
|
||||
Type OperandElementTypeIterator::unwrap(Value value) {
|
||||
Type OperandElementTypeIterator::mapElement(Value value) const {
|
||||
return value.getType().cast<ShapedType>().getElementType();
|
||||
}
|
||||
|
||||
ResultElementTypeIterator::ResultElementTypeIterator(
|
||||
Operation::result_iterator it)
|
||||
: llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value)>(
|
||||
it, &unwrap) {}
|
||||
|
||||
Type ResultElementTypeIterator::unwrap(Value value) {
|
||||
Type ResultElementTypeIterator::mapElement(Value value) const {
|
||||
return value.getType().cast<ShapedType>().getElementType();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue