[mlir] Optimize usage of llvm::mapped_iterator

mapped_iterator is a useful abstraction for applying a
map function over an existing iterator, but our current
usage ends up allocating storage/making indirect calls
even with the map function is a known function, which
is horribly inefficient. This commit refactors the usage
of mapped_iterator to avoid this, and allows for directly
referencing the map function when dereferencing.

Fixes PR52319

Differential Revision: https://reviews.llvm.org/D113511
This commit is contained in:
River Riddle 2021-11-11 03:26:10 +00:00
parent 919ca9fc04
commit 6de6131f02
10 changed files with 171 additions and 110 deletions

View File

@ -307,6 +307,32 @@ auto map_range(ContainerTy &&C, FuncTy F) {
return make_range(map_iterator(C.begin(), F), map_iterator(C.end(), F));
}
/// A base type of mapped iterator, that is useful for building derived
/// iterators that do not need/want to store the map function (as in
/// mapped_iterator). These iterators must simply provide a `mapElement` method
/// that defines how to map a value of the iterator to the provided reference
/// type.
template <typename DerivedT, typename ItTy, typename ReferenceTy>
class mapped_iterator_base
: public iterator_adaptor_base<
DerivedT, ItTy,
typename std::iterator_traits<ItTy>::iterator_category,
std::remove_reference_t<ReferenceTy>,
typename std::iterator_traits<ItTy>::difference_type,
std::remove_reference_t<ReferenceTy> *, ReferenceTy> {
public:
using BaseT = mapped_iterator_base<DerivedT, ItTy, ReferenceTy>;
mapped_iterator_base(ItTy U)
: mapped_iterator_base::iterator_adaptor_base(std::move(U)) {}
ItTy getCurrent() { return this->I; }
ReferenceTy operator*() const {
return static_cast<const DerivedT &>(*this).mapElement(*this->I);
}
};
/// Helper to determine if type T has a member called rbegin().
template <typename Ty> class has_rbegin_impl {
using yes = char[1];

View File

@ -47,4 +47,67 @@ TEST(MappedIteratorTest, FunctionPreservesReferences) {
EXPECT_EQ(M[1], 42) << "assignment should have modified M";
}
TEST(MappedIteratorTest, CustomIteratorApplyFunctionOnDereference) {
struct CustomMapIterator
: public llvm::mapped_iterator_base<CustomMapIterator,
std::vector<int>::iterator, int> {
using BaseT::BaseT;
/// Map the element to the iterator result type.
int mapElement(int X) const { return X + 1; }
};
std::vector<int> V({0});
CustomMapIterator I(V.begin());
EXPECT_EQ(*I, 1) << "should have applied function in dereference";
}
TEST(MappedIteratorTest, CustomIteratorApplyFunctionOnArrow) {
struct S {
int Z = 0;
};
struct CustomMapIterator
: public llvm::mapped_iterator_base<CustomMapIterator,
std::vector<int>::iterator, S &> {
CustomMapIterator(std::vector<int>::iterator it, S *P) : BaseT(it), P(P) {}
/// Map the element to the iterator result type.
S &mapElement(int X) const { return *(P + X); }
S *P;
};
std::vector<int> V({0});
S Y;
CustomMapIterator I(V.begin(), &Y);
I->Z = 42;
EXPECT_EQ(Y.Z, 42) << "should have applied function during arrow";
}
TEST(MappedIteratorTest, CustomIteratorFunctionPreservesReferences) {
struct CustomMapIterator
: public llvm::mapped_iterator_base<CustomMapIterator,
std::vector<int>::iterator, int &> {
CustomMapIterator(std::vector<int>::iterator it, std::map<int, int> &M)
: BaseT(it), M(M) {}
/// Map the element to the iterator result type.
int &mapElement(int X) const { return M[X]; }
std::map<int, int> &M;
};
std::vector<int> V({1});
std::map<int, int> M({{1, 1}});
auto I = CustomMapIterator(V.begin(), M);
*I = 42;
EXPECT_EQ(M[1], 42) << "assignment should have modified M";
}
} // anonymous namespace

View File

@ -323,24 +323,46 @@ public:
/// Iterator for walking over APFloat values.
class FloatElementIterator final
: public llvm::mapped_iterator<IntElementIterator,
std::function<APFloat(const APInt &)>> {
: public llvm::mapped_iterator_base<FloatElementIterator,
IntElementIterator, APFloat> {
public:
/// Map the element to the iterator result type.
APFloat mapElement(const APInt &value) const {
return APFloat(*smt, value);
}
private:
friend DenseElementsAttr;
/// Initializes the float element iterator to the specified iterator.
FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it);
FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it)
: BaseT(it), smt(&smt) {}
/// The float semantics to use when constructing the APFloat.
const llvm::fltSemantics *smt;
};
/// Iterator for walking over complex APFloat values.
class ComplexFloatElementIterator final
: public llvm::mapped_iterator<
ComplexIntElementIterator,
std::function<std::complex<APFloat>(const std::complex<APInt> &)>> {
: public llvm::mapped_iterator_base<ComplexFloatElementIterator,
ComplexIntElementIterator,
std::complex<APFloat>> {
public:
/// Map the element to the iterator result type.
std::complex<APFloat> mapElement(const std::complex<APInt> &value) const {
return {APFloat(*smt, value.real()), APFloat(*smt, value.imag())};
}
private:
friend DenseElementsAttr;
/// Initializes the float element iterator to the specified iterator.
ComplexFloatElementIterator(const llvm::fltSemantics &smt,
ComplexIntElementIterator it);
ComplexIntElementIterator it)
: BaseT(it), smt(&smt) {}
/// The float semantics to use when constructing the APFloat.
const llvm::fltSemantics *smt;
};
//===--------------------------------------------------------------------===//
@ -478,24 +500,27 @@ public:
typename std::enable_if<std::is_base_of<Attribute, T>::value &&
!std::is_same<Attribute, T>::value>::type;
template <typename T>
using DerivedAttributeElementIterator =
llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
struct DerivedAttributeElementIterator
: public llvm::mapped_iterator_base<DerivedAttributeElementIterator<T>,
AttributeElementIterator, T> {
using DerivedAttributeElementIterator::BaseT::BaseT;
/// Map the element to the iterator result type.
T mapElement(Attribute attr) const { return attr.cast<T>(); }
};
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
return {Attribute::getType(),
llvm::map_range(getValues<Attribute>(),
static_cast<T (*)(Attribute)>(castFn))};
using DerivedIterT = DerivedAttributeElementIterator<T>;
return {Attribute::getType(), DerivedIterT(value_begin<Attribute>()),
DerivedIterT(value_end<Attribute>())};
}
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
DerivedAttributeElementIterator<T> value_begin() const {
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
return {value_begin<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
return {value_begin<Attribute>()};
}
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
DerivedAttributeElementIterator<T> value_end() const {
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
return {value_end<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
return {value_end<Attribute>()};
}
/// Return the held element values as a range of bool. The element type of

View File

@ -155,20 +155,6 @@ inline raw_ostream &operator<<(raw_ostream &os, const DiagnosticArgument &arg) {
class Diagnostic {
using NoteVector = std::vector<std::unique_ptr<Diagnostic>>;
/// This class implements a wrapper iterator around NoteVector::iterator to
/// implicitly dereference the unique_ptr.
template <typename IteratorTy, typename NotePtrTy = decltype(*IteratorTy()),
typename ResultTy = decltype(**IteratorTy())>
class NoteIteratorImpl
: public llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)> {
static ResultTy &unwrap(NotePtrTy note) { return *note; }
public:
NoteIteratorImpl(IteratorTy it)
: llvm::mapped_iterator<IteratorTy, ResultTy (*)(NotePtrTy)>(it,
&unwrap) {}
};
public:
Diagnostic(Location loc, DiagnosticSeverity severity)
: loc(loc), severity(severity) {}
@ -262,15 +248,16 @@ public:
/// diagnostic. Notes may not be attached to other notes.
Diagnostic &attachNote(Optional<Location> noteLoc = llvm::None);
using note_iterator = NoteIteratorImpl<NoteVector::iterator>;
using const_note_iterator = NoteIteratorImpl<NoteVector::const_iterator>;
using note_iterator = llvm::pointee_iterator<NoteVector::iterator>;
using const_note_iterator =
llvm::pointee_iterator<NoteVector::const_iterator>;
/// Returns the notes held by this diagnostic.
iterator_range<note_iterator> getNotes() {
return {notes.begin(), notes.end()};
return llvm::make_pointee_range(notes);
}
iterator_range<const_note_iterator> getNotes() const {
return {notes.begin(), notes.end()};
return llvm::make_pointee_range(notes);
}
/// Allow a diagnostic to be converted to 'failure'.

View File

@ -111,20 +111,16 @@ protected:
/// An iterator class that iterates the held interface objects of the given
/// derived interface type.
template <typename InterfaceT>
class iterator : public llvm::mapped_iterator<
InterfaceVectorT::const_iterator,
const InterfaceT &(*)(const DialectInterface *)> {
static const InterfaceT &remapIt(const DialectInterface *interface) {
struct iterator
: public llvm::mapped_iterator_base<iterator<InterfaceT>,
InterfaceVectorT::const_iterator,
const InterfaceT &> {
using iterator::BaseT::BaseT;
/// Map the element to the iterator result type.
const InterfaceT &mapElement(const DialectInterface *interface) const {
return *static_cast<const InterfaceT *>(interface);
}
iterator(InterfaceVectorT::const_iterator it)
: llvm::mapped_iterator<
InterfaceVectorT::const_iterator,
const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {}
/// Allow access to the constructor.
friend DialectInterfaceCollectionBase;
};
/// Iterator access to the held interfaces.

View File

@ -124,16 +124,13 @@ private:
/// This class implements iteration on the types of a given range of values.
template <typename ValueIteratorT>
class ValueTypeIterator final
: public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
static Type unwrap(Value value) { return value.getType(); }
: public llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>,
ValueIteratorT, Type> {
public:
/// Provide a const dereference method.
Type operator*() const { return unwrap(*this->I); }
using ValueTypeIterator::BaseT::BaseT;
/// Initializes the type iterator to the specified value iterator.
ValueTypeIterator(ValueIteratorT it)
: llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
/// Map the element to the iterator result type.
Type mapElement(Value value) const { return value.getType(); }
};
/// This class implements iteration on the types of a given range of values.

View File

@ -66,36 +66,33 @@ LogicalResult verifyCompatibleShapes(TypeRange types);
/// Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyCompatibleDims(ArrayRef<int64_t> dims);
//===----------------------------------------------------------------------===//
// Utility Iterators
//===----------------------------------------------------------------------===//
// An iterator for the element types of an op's operands of shaped types.
class OperandElementTypeIterator final
: public llvm::mapped_iterator<Operation::operand_iterator,
Type (*)(Value)> {
: public llvm::mapped_iterator_base<OperandElementTypeIterator,
Operation::operand_iterator, Type> {
public:
/// Initializes the result element type iterator to the specified operand
/// iterator.
explicit OperandElementTypeIterator(Operation::operand_iterator it);
using BaseT::BaseT;
private:
static Type unwrap(Value value);
/// Map the element to the iterator result type.
Type mapElement(Value value) const;
};
using OperandElementTypeRange = iterator_range<OperandElementTypeIterator>;
// An iterator for the tensor element types of an op's results of shaped types.
class ResultElementTypeIterator final
: public llvm::mapped_iterator<Operation::result_iterator,
Type (*)(Value)> {
: public llvm::mapped_iterator_base<ResultElementTypeIterator,
Operation::result_iterator, Type> {
public:
/// Initializes the result element type iterator to the specified result
/// iterator.
explicit ResultElementTypeIterator(Operation::result_iterator it);
using BaseT::BaseT;
private:
static Type unwrap(Value value);
/// Map the element to the iterator result type.
Type mapElement(Value value) const;
};
using ResultElementTypeRange = iterator_range<ResultElementTypeIterator>;

View File

@ -281,15 +281,16 @@ protected:
/// a specific use iterator.
template <typename UseIteratorT, typename OperandType>
class ValueUserIterator final
: public llvm::mapped_iterator<UseIteratorT,
Operation *(*)(OperandType &)> {
static Operation *unwrap(OperandType &value) { return value.getOwner(); }
: public llvm::mapped_iterator_base<
ValueUserIterator<UseIteratorT, OperandType>, UseIteratorT,
Operation *> {
public:
/// Initializes the user iterator to the specified use iterator.
ValueUserIterator(UseIteratorT it)
: llvm::mapped_iterator<UseIteratorT, Operation *(*)(OperandType &)>(
it, &unwrap) {}
using ValueUserIterator::BaseT::BaseT;
/// Map the element to the iterator result type.
Operation *mapElement(OperandType &value) const { return value.getOwner(); }
/// Provide access to the underlying operation.
Operation *operator->() { return **this; }
};

View File

@ -667,27 +667,6 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
readBits(getData(), offset + storageWidth, bitWidth)};
}
//===----------------------------------------------------------------------===//
// FloatElementIterator
DenseElementsAttr::FloatElementIterator::FloatElementIterator(
const llvm::fltSemantics &smt, IntElementIterator it)
: llvm::mapped_iterator<IntElementIterator,
std::function<APFloat(const APInt &)>>(
it, [&](const APInt &val) { return APFloat(smt, val); }) {}
//===----------------------------------------------------------------------===//
// ComplexFloatElementIterator
DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
const llvm::fltSemantics &smt, ComplexIntElementIterator it)
: llvm::mapped_iterator<
ComplexIntElementIterator,
std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
}) {}
//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//

View File

@ -151,20 +151,10 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
return success();
}
OperandElementTypeIterator::OperandElementTypeIterator(
Operation::operand_iterator it)
: llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>(
it, &unwrap) {}
Type OperandElementTypeIterator::unwrap(Value value) {
Type OperandElementTypeIterator::mapElement(Value value) const {
return value.getType().cast<ShapedType>().getElementType();
}
ResultElementTypeIterator::ResultElementTypeIterator(
Operation::result_iterator it)
: llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value)>(
it, &unwrap) {}
Type ResultElementTypeIterator::unwrap(Value value) {
Type ResultElementTypeIterator::mapElement(Value value) const {
return value.getType().cast<ShapedType>().getElementType();
}