Refactor the various operand/result/type iterators to use indexed_accessor_range.

This has several benefits:
* The implementation is much cleaner and more efficient.
* The ranges now have support for many useful operations: operator[], slice, drop_front, size, etc.
* Value ranges can now directly query a range for their types via 'getTypes()': e.g:
   void foo(Operation::operand_range operands) {
     auto operandTypes = operands.getTypes();
   }

PiperOrigin-RevId: 284834912
This commit is contained in:
River Riddle 2019-12-10 13:20:50 -08:00 committed by A. Unique TensorFlower
parent b19fed5415
commit 9ed22ae5b8
10 changed files with 279 additions and 263 deletions

View File

@ -68,14 +68,12 @@ class SuccessorRange final
: public detail::indexed_accessor_range_base<SuccessorRange, BlockOperand *,
Block *, Block *, Block *> {
public:
using detail::indexed_accessor_range_base<
SuccessorRange, BlockOperand *, Block *, Block *,
Block *>::indexed_accessor_range_base;
using RangeBaseT::RangeBaseT;
SuccessorRange(Block *block);
private:
/// See `detail::indexed_accessor_range_base` for details.
static BlockOperand *offset_object(BlockOperand *object, ptrdiff_t index) {
static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) {
return object + index;
}
/// See `detail::indexed_accessor_range_base` for details.
@ -83,9 +81,8 @@ private:
return object[index].get();
}
/// Allow access to `offset_object` and `dereference_iterator`.
friend detail::indexed_accessor_range_base<SuccessorRange, BlockOperand *,
Block *, Block *, Block *>;
/// Allow access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
};
} // end namespace mlir

View File

@ -29,15 +29,6 @@
#include "llvm/ADT/Twine.h"
namespace mlir {
class BlockAndValueMapping;
class Location;
class MLIRContext;
class OperandIterator;
class OperandTypeIterator;
struct OperationState;
class ResultIterator;
class ResultTypeIterator;
/// Terminator operations can have Block operands to represent successors.
using BlockOperand = IROperandImpl<Block>;
@ -230,14 +221,14 @@ public:
}
// Support operand iteration.
using operand_iterator = OperandIterator;
using operand_range = llvm::iterator_range<operand_iterator>;
using operand_range = OperandRange;
using operand_iterator = operand_range::iterator;
operand_iterator operand_begin();
operand_iterator operand_end();
operand_iterator operand_begin() { return getOperands().begin(); }
operand_iterator operand_end() { return getOperands().end(); }
/// Returns an iterator on the underlying Value's (Value *).
operand_range getOperands();
operand_range getOperands() { return operand_range(this); }
/// Erase the operand at position `idx`.
void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); }
@ -249,11 +240,11 @@ public:
OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; }
// Support operand type iteration.
using operand_type_iterator = OperandTypeIterator;
using operand_type_range = llvm::iterator_range<operand_type_iterator>;
operand_type_iterator operand_type_begin();
operand_type_iterator operand_type_end();
operand_type_range getOperandTypes();
using operand_type_iterator = operand_range::type_iterator;
using operand_type_range = iterator_range<operand_type_iterator>;
operand_type_iterator operand_type_begin() { return operand_begin(); }
operand_type_iterator operand_type_end() { return operand_end(); }
operand_type_range getOperandTypes() { return getOperands().getTypes(); }
//===--------------------------------------------------------------------===//
// Results
@ -266,14 +257,13 @@ public:
Value *getResult(unsigned idx) { return &getOpResult(idx); }
// Support result iteration.
using result_iterator = ResultIterator;
using result_range = llvm::iterator_range<result_iterator>;
/// Support result iteration.
using result_range = ResultRange;
using result_iterator = result_range::iterator;
result_iterator result_begin();
result_iterator result_end();
result_range getResults();
result_iterator result_begin() { return getResults().begin(); }
result_iterator result_end() { return getResults().end(); }
result_range getResults() { return result_range(this); }
MutableArrayRef<OpResult> getOpResults() {
return {getTrailingObjects<OpResult>(), numResults};
@ -281,12 +271,12 @@ public:
OpResult &getOpResult(unsigned idx) { return getOpResults()[idx]; }
// Support result type iteration.
using result_type_iterator = ResultTypeIterator;
using result_type_range = llvm::iterator_range<result_type_iterator>;
result_type_iterator result_type_begin();
result_type_iterator result_type_end();
result_type_range getResultTypes();
/// Support result type iteration.
using result_type_iterator = result_range::type_iterator;
using result_type_range = iterator_range<result_type_iterator>;
result_type_iterator result_type_begin() { return result_begin(); }
result_type_iterator result_type_end() { return result_end(); }
result_type_range getResultTypes() { return getResults().getTypes(); }
//===--------------------------------------------------------------------===//
// Attributes
@ -657,91 +647,6 @@ inline raw_ostream &operator<<(raw_ostream &os, Operation &op) {
return os;
}
/// This class implements the const/non-const operand iterators for the
/// Operation class in terms of getOperand(idx).
class OperandIterator final
: public indexed_accessor_iterator<OperandIterator, Operation *, Value *,
Value *, Value *> {
public:
/// Initializes the operand iterator to the specified operand index.
OperandIterator(Operation *object, unsigned index)
: indexed_accessor_iterator<OperandIterator, Operation *, Value *,
Value *, Value *>(object, index) {}
Value *operator*() const { return this->base->getOperand(this->index); }
};
/// This class implements the operand type iterators for the Operation
/// class in terms of operand_iterator->getType().
class OperandTypeIterator final
: public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
static Type unwrap(Value *value) { return value->getType(); }
public:
using reference = Type;
/// Provide a const deference method.
Type operator*() const { return unwrap(*I); }
/// Initializes the operand type iterator to the specified operand iterator.
OperandTypeIterator(OperandIterator it)
: llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {
}
};
// Implement the inline operand iterator methods.
inline auto Operation::operand_begin() -> operand_iterator {
return operand_iterator(this, 0);
}
inline auto Operation::operand_end() -> operand_iterator {
return operand_iterator(this, getNumOperands());
}
inline auto Operation::getOperands() -> operand_range {
return {operand_begin(), operand_end()};
}
inline auto Operation::operand_type_begin() -> operand_type_iterator {
return operand_type_iterator(operand_begin());
}
inline auto Operation::operand_type_end() -> operand_type_iterator {
return operand_type_iterator(operand_end());
}
inline auto Operation::getOperandTypes() -> operand_type_range {
return {operand_type_begin(), operand_type_end()};
}
/// This class implements the result iterators for the Operation class
/// in terms of getResult(idx).
class ResultIterator final
: public indexed_accessor_iterator<ResultIterator, Operation *, Value *,
Value *, Value *> {
public:
/// Initializes the result iterator to the specified index.
ResultIterator(Operation *base, unsigned index)
: indexed_accessor_iterator<ResultIterator, Operation *, Value *, Value *,
Value *>(base, index) {}
Value *operator*() const { return this->base->getResult(this->index); }
};
/// This class implements the result type iterators for the Operation
/// class in terms of result_iterator->getType().
class ResultTypeIterator final
: public llvm::mapped_iterator<ResultIterator, Type (*)(Value *)> {
static Type unwrap(Value *value) { return value->getType(); }
public:
using reference = Type;
/// Initializes the result type iterator to the specified result iterator.
ResultTypeIterator(ResultIterator it)
: llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
};
/// This class implements use iterator for the Operation. This iterates over all
/// uses of all results of an Operation.
class UseIterator final
@ -768,75 +673,6 @@ private:
/// The use of the result.
Value::use_iterator use;
};
// Implement the inline result iterator methods.
inline auto Operation::result_begin() -> result_iterator {
return result_iterator(this, 0);
}
inline auto Operation::result_end() -> result_iterator {
return result_iterator(this, getNumResults());
}
inline auto Operation::getResults() -> llvm::iterator_range<result_iterator> {
return {result_begin(), result_end()};
}
inline auto Operation::result_type_begin() -> result_type_iterator {
return result_type_iterator(result_begin());
}
inline auto Operation::result_type_end() -> result_type_iterator {
return result_type_iterator(result_end());
}
inline auto Operation::getResultTypes() -> result_type_range {
return {result_type_begin(), result_type_end()};
}
/// This class provides an abstraction over the different types of ranges over
/// Value*s. In many cases, this prevents the need to explicitly materialize a
/// SmallVector/std::vector. This class should be used in places that are not
/// suitable for a more derived type (e.g. ArrayRef) or a template range
/// parameter.
class ValueRange
: public detail::indexed_accessor_range_base<
ValueRange,
llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>, Value *,
Value *, Value *> {
/// The type representing the owner of this range. This is either a list of
/// values, operands, or results.
using OwnerT = llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>;
public:
using detail::indexed_accessor_range_base<
ValueRange, OwnerT, Value *, Value *,
Value *>::indexed_accessor_range_base;
template <typename Arg,
typename = typename std::enable_if_t<
std::is_constructible<ArrayRef<Value *>, Arg>::value &&
!std::is_convertible<Arg, Value *>::value>>
ValueRange(Arg &&arg)
: ValueRange(ArrayRef<Value *>(std::forward<Arg>(arg))) {}
ValueRange(Value *const &value) : ValueRange(&value, /*count=*/1) {}
ValueRange(const std::initializer_list<Value *> &values)
: ValueRange(ArrayRef<Value *>(values)) {}
ValueRange(ArrayRef<Value *> values = llvm::None);
ValueRange(iterator_range<OperandIterator> values);
ValueRange(iterator_range<ResultIterator> values);
private:
/// See `detail::indexed_accessor_range_base` for details.
static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
/// See `detail::indexed_accessor_range_base` for details.
static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
/// Allow access to `offset_base` and `dereference_iterator`.
friend detail::indexed_accessor_range_base<ValueRange, OwnerT, Value *,
Value *, Value *>;
};
} // end namespace mlir
namespace llvm {

View File

@ -60,6 +60,10 @@ template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
class OwningRewritePatternList;
//===----------------------------------------------------------------------===//
// AbstractOperation
//===----------------------------------------------------------------------===//
enum class OperationProperty {
/// This bit is set for an operation if it is a commutative operation: that
/// is a binary operator (two inputs) where "a op b" and "b op a" produce the
@ -201,6 +205,10 @@ private:
bool (&hasRawTrait)(ClassID *traitID);
};
//===----------------------------------------------------------------------===//
// OperationName
//===----------------------------------------------------------------------===//
class OperationName {
public:
using RepresentationUnion =
@ -251,6 +259,10 @@ inline llvm::hash_code hash_value(OperationName arg) {
return llvm::hash_value(arg.getAsOpaquePointer());
}
//===----------------------------------------------------------------------===//
// OperationState
//===----------------------------------------------------------------------===//
/// This represents an operation in an abstracted form, suitable for use with
/// the builder APIs. This object is a large and heavy weight object meant to
/// be used as a temporary object on the stack. It is generally unwise to put
@ -322,6 +334,10 @@ public:
MLIRContext *getContext() { return location->getContext(); }
};
//===----------------------------------------------------------------------===//
// OperandStorage
//===----------------------------------------------------------------------===//
namespace detail {
/// A utility class holding the information necessary to dynamically resize
/// operands.
@ -445,6 +461,10 @@ private:
};
} // end namespace detail
//===----------------------------------------------------------------------===//
// OpPrintingFlags
//===----------------------------------------------------------------------===//
/// Set of flags used to control the behavior of the various IR print methods
/// (e.g. Operation::Print).
class OpPrintingFlags {
@ -504,6 +524,138 @@ private:
bool printLocalScope : 1;
};
//===----------------------------------------------------------------------===//
// Operation Value-Iterators
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// ValueTypeRange
/// This class implements iteration on the types of a given range of values.
template <typename ValueIteratorT>
class ValueTypeIterator final
: public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value *)> {
static Type unwrap(Value *value) { return value->getType(); }
public:
using reference = Type;
/// Provide a const dereference method.
Type operator*() const { return unwrap(*this->I); }
/// Initializes the type iterator to the specified value iterator.
ValueTypeIterator(ValueIteratorT it)
: llvm::mapped_iterator<ValueIteratorT, Type (*)(Value *)>(it, &unwrap) {}
};
//===----------------------------------------------------------------------===//
// OperandRange
/// This class implements the operand iterators for the Operation class.
class OperandRange final
: public detail::indexed_accessor_range_base<OperandRange, OpOperand *,
Value *, Value *, Value *> {
public:
using RangeBaseT::RangeBaseT;
OperandRange(Operation *op);
/// Returns the types of the values within this range.
using type_iterator = ValueTypeIterator<iterator>;
iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
private:
/// See `detail::indexed_accessor_range_base` for details.
static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
return object + index;
}
/// See `detail::indexed_accessor_range_base` for details.
static Value *dereference_iterator(OpOperand *object, ptrdiff_t index) {
return object[index].get();
}
/// Allow access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
// ResultRange
/// This class implements the result iterators for the Operation class.
class ResultRange final
: public detail::indexed_accessor_range_base<ResultRange, OpResult *,
Value *, Value *, Value *> {
public:
using RangeBaseT::RangeBaseT;
ResultRange(Operation *op);
/// Returns the types of the values within this range.
using type_iterator = ValueTypeIterator<iterator>;
iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
private:
/// See `detail::indexed_accessor_range_base` for details.
static OpResult *offset_base(OpResult *object, ptrdiff_t index) {
return object + index;
}
/// See `detail::indexed_accessor_range_base` for details.
static Value *dereference_iterator(OpResult *object, ptrdiff_t index) {
return &object[index];
}
/// Allow access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
// ValueRange
/// This class provides an abstraction over the different types of ranges over
/// Value*s. In many cases, this prevents the need to explicitly materialize a
/// SmallVector/std::vector. This class should be used in places that are not
/// suitable for a more derived type (e.g. ArrayRef) or a template range
/// parameter.
class ValueRange final
: public detail::indexed_accessor_range_base<
ValueRange,
llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>, Value *,
Value *, Value *> {
public:
using RangeBaseT::RangeBaseT;
template <typename Arg,
typename = typename std::enable_if_t<
std::is_constructible<ArrayRef<Value *>, Arg>::value &&
!std::is_convertible<Arg, Value *>::value>>
ValueRange(Arg &&arg)
: ValueRange(ArrayRef<Value *>(std::forward<Arg>(arg))) {}
ValueRange(Value *const &value) : ValueRange(&value, /*count=*/1) {}
ValueRange(const std::initializer_list<Value *> &values)
: ValueRange(ArrayRef<Value *>(values)) {}
ValueRange(iterator_range<OperandRange::iterator> values)
: ValueRange(OperandRange(values)) {}
ValueRange(iterator_range<ResultRange::iterator> values)
: ValueRange(ResultRange(values)) {}
ValueRange(ArrayRef<Value *> values = llvm::None);
ValueRange(OperandRange values);
ValueRange(ResultRange values);
/// Returns the types of the values within this range.
using type_iterator = ValueTypeIterator<iterator>;
iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
private:
/// The type representing the owner of this range. This is either a list of
/// values, operands, or results.
using OwnerT = llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>;
/// See `detail::indexed_accessor_range_base` for details.
static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
/// See `detail::indexed_accessor_range_base` for details.
static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
/// Allow access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
};
} // end namespace mlir
namespace llvm {

View File

@ -175,9 +175,7 @@ class RegionRange
using OwnerT = llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>;
public:
using detail::indexed_accessor_range_base<
RegionRange, OwnerT, Region *, Region *,
Region *>::indexed_accessor_range_base;
using RangeBaseT::RangeBaseT;
RegionRange(MutableArrayRef<Region> regions = llvm::None);
@ -196,8 +194,7 @@ private:
static Region *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
/// Allow access to `offset_base` and `dereference_iterator`.
friend detail::indexed_accessor_range_base<RegionRange, OwnerT, Region *,
Region *, Region *>;
friend RangeBaseT;
};
} // end namespace mlir

View File

@ -65,13 +65,14 @@ LogicalResult verifyCompatibleShape(Type type1, Type type2);
// An iterator for the element types of an op's operands of shaped types.
class OperandElementTypeIterator final
: public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
: public llvm::mapped_iterator<Operation::operand_iterator,
Type (*)(Value *)> {
public:
using reference = Type;
/// Initializes the result element type iterator to the specified operand
/// iterator.
explicit OperandElementTypeIterator(OperandIterator it);
explicit OperandElementTypeIterator(Operation::operand_iterator it);
private:
static Type unwrap(Value *value);
@ -82,13 +83,14 @@ using OperandElementTypeRange =
// An iterator for the tensor element types of an op's results of shaped types.
class ResultElementTypeIterator final
: public llvm::mapped_iterator<ResultIterator, Type (*)(Value *)> {
: public llvm::mapped_iterator<Operation::result_iterator,
Type (*)(Value *)> {
public:
using reference = Type;
/// Initializes the result element type iterator to the specified result
/// iterator.
explicit ResultElementTypeIterator(ResultIterator it);
explicit ResultElementTypeIterator(Operation::result_iterator it);
private:
static Type unwrap(Value *value);

View File

@ -204,6 +204,9 @@ template <typename DerivedT, typename BaseT, typename T,
typename PointerT = T *, typename ReferenceT = T &>
class indexed_accessor_range_base {
public:
using RangeBaseT =
indexed_accessor_range_base<DerivedT, BaseT, T, PointerT, ReferenceT>;
/// An iterator element of this range.
class iterator : public indexed_accessor_iterator<iterator, BaseT, T,
PointerT, ReferenceT> {
@ -223,11 +226,17 @@ public:
ReferenceT>;
};
indexed_accessor_range_base(iterator begin, iterator end)
: base(DerivedT::offset_base(begin.getBase(), begin.getIndex())),
count(end.getIndex() - begin.getIndex()) {}
indexed_accessor_range_base(const iterator_range<iterator> &range)
: indexed_accessor_range_base(range.begin(), range.end()) {}
iterator begin() const { return iterator(base, 0); }
iterator end() const { return iterator(base, count); }
ReferenceT operator[](unsigned index) const {
assert(index < size() && "invalid index for value range");
return *std::next(begin(), index);
return DerivedT::dereference_iterator(base, index);
}
/// Return the size of this range.
@ -237,22 +246,35 @@ public:
bool empty() const { return size() == 0; }
/// Drop the first N elements, and keep M elements.
DerivedT slice(unsigned n, unsigned m) const {
DerivedT slice(size_t n, size_t m) const {
assert(n + m <= size() && "invalid size specifiers");
return DerivedT(DerivedT::offset_base(base, n), m);
}
/// Drop the first n elements.
DerivedT drop_front(unsigned n = 1) const {
DerivedT drop_front(size_t n = 1) const {
assert(size() >= n && "Dropping more elements than exist");
return slice(n, size() - n);
}
/// Drop the last n elements.
DerivedT drop_back(unsigned n = 1) const {
DerivedT drop_back(size_t n = 1) const {
assert(size() >= n && "Dropping more elements than exist");
return DerivedT(base, size() - n);
}
/// Take the first n elements.
DerivedT take_front(size_t n = 1) const {
return n < size() ? drop_back(size() - n)
: static_cast<const DerivedT &>(*this);
}
/// Allow conversion to SmallVector if necessary.
/// TODO(riverriddle) Remove this when SmallVector accepts different range
/// types in its constructor.
template <typename SVT, unsigned N> operator SmallVector<SVT, N>() const {
return {begin(), end()};
}
protected:
indexed_accessor_range_base(BaseT base, ptrdiff_t count)
: base(base), count(count) {}

View File

@ -2819,7 +2819,7 @@ public:
return matchFailure();
}
SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
for (auto size : enumerate(subViewOp.sizes())) {
for (auto size : llvm::enumerate(subViewOp.sizes())) {
auto defOp = size.value()->getDefiningOp();
assert(defOp);
staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
@ -2865,7 +2865,7 @@ public:
}
SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
for (auto stride : enumerate(subViewOp.strides())) {
for (auto stride : llvm::enumerate(subViewOp.strides())) {
auto defOp = stride.value()->getDefiningOp();
assert(defOp);
assert(baseStrides[stride.index()] > 0);
@ -2916,7 +2916,7 @@ public:
}
auto staticOffset = baseOffset;
for (auto offset : enumerate(subViewOp.offsets())) {
for (auto offset : llvm::enumerate(subViewOp.offsets())) {
auto defOp = offset.value()->getDefiningOp();
assert(defOp);
assert(baseStrides[offset.index()] > 0);

View File

@ -597,9 +597,8 @@ void Operation::setSuccessor(Block *block, unsigned index) {
}
auto Operation::getNonSuccessorOperands() -> operand_range {
return {operand_iterator(this, 0),
operand_iterator(this, hasSuccessors() ? getSuccessorOperandIndex(0)
: getNumOperands())};
return getOperands().take_front(hasSuccessors() ? getSuccessorOperandIndex(0)
: getNumOperands());
}
/// Get the index of the first operand of the successor at the provided
@ -635,9 +634,7 @@ Operation::decomposeSuccessorOperandIndex(unsigned operandIndex) {
auto Operation::getSuccessorOperands(unsigned index) -> operand_range {
unsigned succOperandIndex = getSuccessorOperandIndex(index);
return {operand_iterator(this, succOperandIndex),
operand_iterator(this,
succOperandIndex + getNumSuccessorOperands(index))};
return getOperands().slice(succOperandIndex, getNumSuccessorOperands(index));
}
/// Attempt to fold this operation using the Op's registered foldHook.
@ -745,48 +742,6 @@ Operation *Operation::clone() {
return clone(mapper);
}
//===----------------------------------------------------------------------===//
// ValueRange
//===----------------------------------------------------------------------===//
ValueRange::ValueRange(ArrayRef<Value *> values)
: ValueRange(values.data(), values.size()) {}
ValueRange::ValueRange(llvm::iterator_range<OperandIterator> values)
: ValueRange(nullptr, llvm::size(values)) {
if (!empty()) {
auto begin = values.begin();
base = &begin.getBase()->getOpOperand(begin.getIndex());
}
}
ValueRange::ValueRange(llvm::iterator_range<ResultIterator> values)
: ValueRange(nullptr, llvm::size(values)) {
if (!empty()) {
auto begin = values.begin();
base = &begin.getBase()->getOpResult(begin.getIndex());
}
}
/// See `detail::indexed_accessor_range_base` for details.
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
ptrdiff_t index) {
if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
return operand + index;
if (OpResult *result = owner.dyn_cast<OpResult *>())
return result + index;
return owner.get<Value *const *>() + index;
}
/// See `detail::indexed_accessor_range_base` for details.
Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
// Operands access the held value via 'get'.
if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
return operand[index].get();
// An OpResult is a value, so we can return it directly.
if (OpResult *result = owner.dyn_cast<OpResult *>())
return &result[index];
// Otherwise, this is a raw value array so just index directly.
return owner.get<Value *const *>()[index];
}
//===----------------------------------------------------------------------===//
// OpState trait class.
//===----------------------------------------------------------------------===//
@ -979,7 +934,7 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
auto elementType = getElementTypeOrSelf(op->getResult(0));
// Verify result element type matches first result's element type.
for (auto result : drop_begin(op->getResults(), 1)) {
for (auto result : llvm::drop_begin(op->getResults(), 1)) {
if (getElementTypeOrSelf(result) != elementType)
return op->emitOpError(
"requires the same element type for all operands and results");
@ -1210,7 +1165,7 @@ Value *impl::foldCastOp(Operation *op) {
}
//===----------------------------------------------------------------------===//
// CastOp implementation
// Misc. utils
//===----------------------------------------------------------------------===//
/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
@ -1230,6 +1185,10 @@ void impl::ensureRegionTerminator(
block.push_back(buildTerminatorOp());
}
//===----------------------------------------------------------------------===//
// UseIterator
//===----------------------------------------------------------------------===//
UseIterator::UseIterator(Operation *op, bool end)
: op(op), res(end ? op->result_end() : op->result_begin()) {
// Only initialize current use if there are results/can be uses.

View File

@ -144,3 +144,50 @@ void detail::OperandStorage::grow(ResizableStorage &resizeUtil,
operand.~OpOperand();
resizeUtil.setDynamicStorage(newStorage);
}
//===----------------------------------------------------------------------===//
// Operation Value-Iterators
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// OperandRange
OperandRange::OperandRange(Operation *op)
: OperandRange(op->getOpOperands().data(), op->getNumOperands()) {}
//===----------------------------------------------------------------------===//
// ResultRange
ResultRange::ResultRange(Operation *op)
: ResultRange(op->getOpResults().data(), op->getNumResults()) {}
//===----------------------------------------------------------------------===//
// ValueRange
ValueRange::ValueRange(ArrayRef<Value *> values)
: ValueRange(values.data(), values.size()) {}
ValueRange::ValueRange(OperandRange values)
: ValueRange(values.begin().getBase(), values.size()) {}
ValueRange::ValueRange(ResultRange values)
: ValueRange(values.begin().getBase(), values.size()) {}
/// See `detail::indexed_accessor_range_base` for details.
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
ptrdiff_t index) {
if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
return operand + index;
if (OpResult *result = owner.dyn_cast<OpResult *>())
return result + index;
return owner.get<Value *const *>() + index;
}
/// See `detail::indexed_accessor_range_base` for details.
Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
// Operands access the held value via 'get'.
if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
return operand[index].get();
// An OpResult is a value, so we can return it directly.
if (OpResult *result = owner.dyn_cast<OpResult *>())
return &result[index];
// Otherwise, this is a raw value array so just index directly.
return owner.get<Value *const *>()[index];
}

View File

@ -92,15 +92,19 @@ LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
return success();
}
OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
: llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}
OperandElementTypeIterator::OperandElementTypeIterator(
Operation::operand_iterator it)
: llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value *)>(
it, &unwrap) {}
Type OperandElementTypeIterator::unwrap(Value *value) {
return value->getType().cast<ShapedType>().getElementType();
}
ResultElementTypeIterator::ResultElementTypeIterator(ResultIterator it)
: llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
ResultElementTypeIterator::ResultElementTypeIterator(
Operation::result_iterator it)
: llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value *)>(
it, &unwrap) {}
Type ResultElementTypeIterator::unwrap(Value *value) {
return value->getType().cast<ShapedType>().getElementType();