forked from OSchip/llvm-project
[mlir] Fix indexed_accessor_range to properly forward the derived class.
Summary: This fixes the return value of helper methods on the base range class. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D72127
This commit is contained in:
parent
21309eafde
commit
0d9ca98c1a
|
@ -598,8 +598,8 @@ public:
|
|||
iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
|
||||
|
||||
private:
|
||||
/// See `detail::indexed_accessor_range_base` for details.
|
||||
static OpResult dereference_iterator(Operation *op, ptrdiff_t index);
|
||||
/// See `indexed_accessor_range` for details.
|
||||
static OpResult dereference(Operation *op, ptrdiff_t index);
|
||||
|
||||
/// Allow access to `dereference_iterator`.
|
||||
friend indexed_accessor_range<ResultRange, Operation *, OpResult, OpResult,
|
||||
|
|
|
@ -222,6 +222,8 @@ public:
|
|||
count(end.getIndex() - begin.getIndex()) {}
|
||||
indexed_accessor_range_base(const iterator_range<iterator> &range)
|
||||
: indexed_accessor_range_base(range.begin(), range.end()) {}
|
||||
indexed_accessor_range_base(BaseT base, ptrdiff_t count)
|
||||
: base(base), count(count) {}
|
||||
|
||||
iterator begin() const { return iterator(base, 0); }
|
||||
iterator end() const { return iterator(base, count); }
|
||||
|
@ -267,8 +269,6 @@ public:
|
|||
}
|
||||
|
||||
protected:
|
||||
indexed_accessor_range_base(BaseT base, ptrdiff_t count)
|
||||
: base(base), count(count) {}
|
||||
indexed_accessor_range_base(const indexed_accessor_range_base &) = default;
|
||||
indexed_accessor_range_base(indexed_accessor_range_base &&) = default;
|
||||
indexed_accessor_range_base &
|
||||
|
@ -286,18 +286,20 @@ protected:
|
|||
/// bases that are offsetable should derive from indexed_accessor_range_base
|
||||
/// instead. Derived range classes are expected to implement the following
|
||||
/// static method:
|
||||
/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index)
|
||||
/// * ReferenceT dereference(const BaseT &base, ptrdiff_t index)
|
||||
/// - Derefence an iterator pointing to a parent base at the given index.
|
||||
template <typename DerivedT, typename BaseT, typename T,
|
||||
typename PointerT = T *, typename ReferenceT = T &>
|
||||
class indexed_accessor_range
|
||||
: public detail::indexed_accessor_range_base<
|
||||
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
|
||||
std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT> {
|
||||
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT> {
|
||||
public:
|
||||
indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
|
||||
: detail::indexed_accessor_range_base<
|
||||
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>(
|
||||
std::make_pair(base, startIndex), count) {}
|
||||
using detail::indexed_accessor_range_base<
|
||||
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
|
||||
std::pair<BaseT, ptrdiff_t>, T, PointerT,
|
||||
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT,
|
||||
ReferenceT>::indexed_accessor_range_base;
|
||||
|
||||
/// Returns the current base of the range.
|
||||
|
@ -306,14 +308,6 @@ public:
|
|||
/// Returns the current start index of the range.
|
||||
ptrdiff_t getStartIndex() const { return this->base.second; }
|
||||
|
||||
protected:
|
||||
indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
|
||||
: detail::indexed_accessor_range_base<
|
||||
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
|
||||
std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>(
|
||||
std::make_pair(base, startIndex), count) {}
|
||||
|
||||
private:
|
||||
/// See `detail::indexed_accessor_range_base` for details.
|
||||
static std::pair<BaseT, ptrdiff_t>
|
||||
offset_base(const std::pair<BaseT, ptrdiff_t> &base, ptrdiff_t index) {
|
||||
|
@ -325,13 +319,8 @@ private:
|
|||
static ReferenceT
|
||||
dereference_iterator(const std::pair<BaseT, ptrdiff_t> &base,
|
||||
ptrdiff_t index) {
|
||||
return DerivedT::dereference_iterator(base.first, base.second + index);
|
||||
return DerivedT::dereference(base.first, base.second + index);
|
||||
}
|
||||
|
||||
/// Allow access to `offset_base` and `dereference_iterator`.
|
||||
friend detail::indexed_accessor_range_base<
|
||||
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
|
||||
std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>;
|
||||
};
|
||||
|
||||
/// Given a container of pairs, return a range over the second elements.
|
||||
|
|
|
@ -152,8 +152,8 @@ OperandRange::OperandRange(Operation *op)
|
|||
ResultRange::ResultRange(Operation *op)
|
||||
: ResultRange(op, /*startIndex=*/0, op->getNumResults()) {}
|
||||
|
||||
/// See `detail::indexed_accessor_range_base` for details.
|
||||
OpResult ResultRange::dereference_iterator(Operation *op, ptrdiff_t index) {
|
||||
/// See `indexed_accessor_range` for details.
|
||||
OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) {
|
||||
return op->getResult(index);
|
||||
}
|
||||
|
||||
|
|
|
@ -10,4 +10,5 @@ add_subdirectory(Dialect)
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Pass)
|
||||
add_subdirectory(SDBM)
|
||||
add_subdirectory(Support)
|
||||
add_subdirectory(TableGen)
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
add_mlir_unittest(MLIRSupportTests
|
||||
IndexedAccessorTest.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(MLIRSupportTests
|
||||
PRIVATE MLIRSupport)
|
|
@ -0,0 +1,49 @@
|
|||
//===- IndexedAccessorTest.cpp - Indexed Accessor Tests -------------------===//
|
||||
//
|
||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
namespace {
|
||||
/// Simple indexed accessor range that wraps an array.
|
||||
template <typename T>
|
||||
struct ArrayIndexedAccessorRange
|
||||
: public indexed_accessor_range<ArrayIndexedAccessorRange<T>, T *, T> {
|
||||
ArrayIndexedAccessorRange(T *data, ptrdiff_t start, ptrdiff_t numElements)
|
||||
: indexed_accessor_range<ArrayIndexedAccessorRange<T>, T *, T>(
|
||||
data, start, numElements) {}
|
||||
using indexed_accessor_range<ArrayIndexedAccessorRange<T>, T *,
|
||||
T>::indexed_accessor_range;
|
||||
|
||||
/// See `indexed_accessor_range` for details.
|
||||
static T &dereference(T *data, ptrdiff_t index) { return data[index]; }
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
template <typename T>
|
||||
static void compareData(ArrayIndexedAccessorRange<T> range,
|
||||
ArrayRef<T> referenceData) {
|
||||
ASSERT_TRUE(referenceData.size() == range.size());
|
||||
ASSERT_TRUE(std::equal(range.begin(), range.end(), referenceData.begin()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
TEST(AccessorRange, SliceTest) {
|
||||
int rawData[] = {0, 1, 2, 3, 4};
|
||||
ArrayRef<int> data = llvm::makeArrayRef(rawData);
|
||||
|
||||
ArrayIndexedAccessorRange<int> range(rawData, /*start=*/0, /*numElements=*/5);
|
||||
compareData(range, data);
|
||||
compareData(range.slice(2, 3), data.slice(2, 3));
|
||||
compareData(range.slice(0, 5), data.slice(0, 5));
|
||||
}
|
||||
} // end anonymous namespace
|
Loading…
Reference in New Issue