Push a bunch of 'consts' out of the *Op structure, in prep for removing

OpPointer.

PiperOrigin-RevId: 240044712
This commit is contained in:
Chris Lattner 2019-03-24 13:02:43 -07:00 committed by jpienaar
parent f26c7cd792
commit dd2b2ec542
7 changed files with 103 additions and 121 deletions

View File

@ -250,29 +250,27 @@ void extractForInductionVars(ArrayRef<OpPointer<AffineForOp>> forInsts,
/// that of the for instruction it refers to. /// that of the for instruction it refers to.
class AffineBound { class AffineBound {
public: public:
OpPointer<AffineForOp> getAffineForOp() const { return inst; } OpPointer<AffineForOp> getAffineForOp() { return inst; }
AffineMap getMap() const { return map; } AffineMap getMap() { return map; }
/// Returns an AffineValueMap representing this bound. /// Returns an AffineValueMap representing this bound.
AffineValueMap getAsAffineValueMap(); AffineValueMap getAsAffineValueMap();
unsigned getNumOperands() const { return opEnd - opStart; } unsigned getNumOperands() { return opEnd - opStart; }
Value *getOperand(unsigned idx) const { Value *getOperand(unsigned idx) {
return inst->getInstruction()->getOperand(opStart + idx); return inst->getInstruction()->getOperand(opStart + idx);
} }
using operand_iterator = AffineForOp::operand_iterator; using operand_iterator = AffineForOp::operand_iterator;
using operand_range = AffineForOp::operand_range; using operand_range = AffineForOp::operand_range;
operand_iterator operand_begin() const { operand_iterator operand_begin() {
return const_cast<Instruction *>(inst->getInstruction())->operand_begin() + return inst->getInstruction()->operand_begin() + opStart;
opStart;
} }
operand_iterator operand_end() const { operand_iterator operand_end() {
return const_cast<Instruction *>(inst->getInstruction())->operand_begin() + return inst->getInstruction()->operand_begin() + opEnd;
opEnd;
} }
operand_range getOperands() const { return {operand_begin(), operand_end()}; } operand_range getOperands() { return {operand_begin(), operand_end()}; }
private: private:
// 'for' instruction that contains this bound. // 'for' instruction that contains this bound.

View File

@ -62,32 +62,34 @@ public:
explicit OpPointer() : value(Instruction::getNull<OpType>().value) {} explicit OpPointer() : value(Instruction::getNull<OpType>().value) {}
explicit OpPointer(OpType value) : value(value) {} explicit OpPointer(OpType value) : value(value) {}
OpType &operator*() const { return const_cast<OpType &>(value); } OpType &operator*() { return value; }
OpType *operator->() const { return const_cast<OpType *>(&value); } OpType *operator->() { return &value; }
operator bool() const { return value.getInstruction(); } explicit operator bool() { return value.getInstruction(); }
bool operator==(OpPointer rhs) const { bool operator==(OpPointer rhs) {
return value.getInstruction() == rhs.value.getInstruction(); return value.getInstruction() == rhs.value.getInstruction();
} }
bool operator!=(OpPointer rhs) const { return !(*this == rhs); } bool operator!=(OpPointer rhs) { return !(*this == rhs); }
/// OpPointer can be implicitly converted to OpType*. /// OpPointer can be implicitly converted to OpType*.
/// Return `nullptr` if there is no associated Instruction*. /// Return `nullptr` if there is no associated Instruction*.
operator OpType *() const { operator OpType *() {
if (!value.getInstruction()) if (!value.getInstruction())
return nullptr; return nullptr;
return const_cast<OpType *>(&value); return &value;
} }
operator OpType() { return value; }
/// If the OpType operation includes the OneResult trait, then OpPointer can /// If the OpType operation includes the OneResult trait, then OpPointer can
/// be implicitly converted to an Value*. This yields the value of the /// be implicitly converted to an Value*. This yields the value of the
/// only result. /// only result.
template <typename SFINAE = OpType> template <typename SFINAE = OpType>
operator typename std::enable_if<IsSingleResult<SFINAE>::value, operator typename std::enable_if<IsSingleResult<SFINAE>::value,
Value *>::type() const { Value *>::type() {
return const_cast<Value *>(value.getResult()); return value.getResult();
} }
private: private:
@ -103,23 +105,22 @@ private:
class OpState { class OpState {
public: public:
/// Return the operation that this refers to. /// Return the operation that this refers to.
Instruction *getInstruction() const { return state; }
Instruction *getInstruction() { return state; } Instruction *getInstruction() { return state; }
/// Return the context this operation belongs to. /// Return the context this operation belongs to.
MLIRContext *getContext() { return getInstruction()->getContext(); } MLIRContext *getContext() { return getInstruction()->getContext(); }
/// The source location the operation was defined or derived from. /// The source location the operation was defined or derived from.
Location getLoc() const { return state->getLoc(); } Location getLoc() { return state->getLoc(); }
/// Return all of the attributes on this operation. /// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() const { return state->getAttrs(); } ArrayRef<NamedAttribute> getAttrs() { return state->getAttrs(); }
/// Return an attribute with the specified name. /// Return an attribute with the specified name.
Attribute getAttr(StringRef name) const { return state->getAttr(name); } Attribute getAttr(StringRef name) { return state->getAttr(name); }
/// If the operation has an attribute of the specified type, return it. /// If the operation has an attribute of the specified type, return it.
template <typename AttrClass> AttrClass getAttrOfType(StringRef name) const { template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
return getAttr(name).dyn_cast_or_null<AttrClass>(); return getAttr(name).dyn_cast_or_null<AttrClass>();
} }
@ -133,7 +134,7 @@ public:
} }
/// Return true if there are no users of any results of this operation. /// Return true if there are no users of any results of this operation.
bool use_empty() const { return state->use_empty(); } bool use_empty() { return state->use_empty(); }
/// Remove this operation from its parent block and delete it. /// Remove this operation from its parent block and delete it.
void erase() { state->erase(); } void erase() { state->erase(); }
@ -142,19 +143,19 @@ public:
/// any diagnostic handlers that may be listening. This function always /// any diagnostic handlers that may be listening. This function always
/// returns true. NOTE: This may terminate the containing application, only /// returns true. NOTE: This may terminate the containing application, only
/// use when the IR is in an inconsistent state. /// use when the IR is in an inconsistent state.
bool emitError(const Twine &message) const; bool emitError(const Twine &message);
/// Emit an error with the op name prefixed, like "'dim' op " which is /// Emit an error with the op name prefixed, like "'dim' op " which is
/// convenient for verifiers. This always returns true. /// convenient for verifiers. This always returns true.
bool emitOpError(const Twine &message) const; bool emitOpError(const Twine &message);
/// Emit a warning about this operation, reporting up to any diagnostic /// Emit a warning about this operation, reporting up to any diagnostic
/// handlers that may be listening. /// handlers that may be listening.
void emitWarning(const Twine &message) const; void emitWarning(const Twine &message);
/// Emit a note about this operation, reporting up to any diagnostic /// Emit a note about this operation, reporting up to any diagnostic
/// handlers that may be listening. /// handlers that may be listening.
void emitNote(const Twine &message) const; void emitNote(const Twine &message);
// These are default implementations of customization hooks. // These are default implementations of customization hooks.
public: public:
@ -179,8 +180,7 @@ protected:
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes, /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
/// so we can cast it away here. /// so we can cast it away here.
explicit OpState(Instruction *state) explicit OpState(Instruction *state) : state(state) {}
: state(const_cast<Instruction *>(state)) {}
private: private:
Instruction *state; Instruction *state;
@ -364,9 +364,6 @@ protected:
auto *base = static_cast<OpState *>(concrete); auto *base = static_cast<OpState *>(concrete);
return base->getInstruction(); return base->getInstruction();
} }
Instruction *getInstruction() const {
return const_cast<TraitBase *>(this)->getInstruction();
}
/// Provide default implementations of trait hooks. This allows traits to /// Provide default implementations of trait hooks. This allows traits to
/// provide exactly the overrides they care about. /// provide exactly the overrides they care about.
@ -387,8 +384,8 @@ public:
private: private:
// Disable these. // Disable these.
void getOperand() const {} void getOperand() {}
void setOperand() const {} void setOperand() {}
}; };
/// This class provides the API for ops that are known to have exactly one /// This class provides the API for ops that are known to have exactly one
@ -396,7 +393,7 @@ private:
template <typename ConcreteType> template <typename ConcreteType>
class OneOperand : public TraitBase<ConcreteType, OneOperand> { class OneOperand : public TraitBase<ConcreteType, OneOperand> {
public: public:
Value *getOperand() const { return this->getInstruction()->getOperand(0); } Value *getOperand() { return this->getInstruction()->getOperand(0); }
void setOperand(Value *value) { void setOperand(Value *value) {
this->getInstruction()->setOperand(0, value); this->getInstruction()->setOperand(0, value);
@ -417,7 +414,7 @@ public:
template <typename ConcreteType> template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, NOperands<N>::Impl> { class Impl : public TraitBase<ConcreteType, NOperands<N>::Impl> {
public: public:
Value *getOperand(unsigned i) const { Value *getOperand(unsigned i) {
return this->getInstruction()->getOperand(i); return this->getInstruction()->getOperand(i);
} }
@ -441,10 +438,11 @@ public:
template <typename ConcreteType> template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, AtLeastNOperands<N>::Impl> { class Impl : public TraitBase<ConcreteType, AtLeastNOperands<N>::Impl> {
public: public:
unsigned getNumOperands() const { unsigned getNumOperands() {
return this->getInstruction()->getNumOperands(); return this->getInstruction()->getNumOperands();
} }
Value *getOperand(unsigned i) const {
Value *getOperand(unsigned i) {
return this->getInstruction()->getOperand(i); return this->getInstruction()->getOperand(i);
} }
@ -452,7 +450,6 @@ public:
this->getInstruction()->setOperand(i, value); this->getInstruction()->setOperand(i, value);
} }
// Support non-const operand iteration.
using operand_iterator = Instruction::operand_iterator; using operand_iterator = Instruction::operand_iterator;
operand_iterator operand_begin() { operand_iterator operand_begin() {
return this->getInstruction()->operand_begin(); return this->getInstruction()->operand_begin();
@ -475,11 +472,9 @@ public:
template <typename ConcreteType> template <typename ConcreteType>
class VariadicOperands : public TraitBase<ConcreteType, VariadicOperands> { class VariadicOperands : public TraitBase<ConcreteType, VariadicOperands> {
public: public:
unsigned getNumOperands() const { unsigned getNumOperands() { return this->getInstruction()->getNumOperands(); }
return this->getInstruction()->getNumOperands();
}
Value *getOperand(unsigned i) const { Value *getOperand(unsigned i) {
return this->getInstruction()->getOperand(i); return this->getInstruction()->getOperand(i);
} }
@ -487,7 +482,7 @@ public:
this->getInstruction()->setOperand(i, value); this->getInstruction()->setOperand(i, value);
} }
// Support non-const operand iteration. // Support operand iteration.
using operand_iterator = Instruction::operand_iterator; using operand_iterator = Instruction::operand_iterator;
using operand_range = Instruction::operand_range; using operand_range = Instruction::operand_range;
operand_iterator operand_begin() { operand_iterator operand_begin() {
@ -514,9 +509,9 @@ public:
template <typename ConcreteType> template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> { class OneResult : public TraitBase<ConcreteType, OneResult> {
public: public:
Value *getResult() const { return this->getInstruction()->getResult(0); } Value *getResult() { return this->getInstruction()->getResult(0); }
Type getType() const { return getResult()->getType(); } Type getType() { return getResult()->getType(); }
/// Replace all uses of 'this' value with the new value, updating anything in /// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns /// the IR that uses 'this' to use the other value instead. When this returns
@ -540,11 +535,11 @@ public:
public: public:
static unsigned getNumResults() { return N; } static unsigned getNumResults() { return N; }
Value *getResult(unsigned i) const { Value *getResult(unsigned i) {
return this->getInstruction()->getResult(i); return this->getInstruction()->getResult(i);
} }
Type getType(unsigned i) const { return getResult(i)->getType(); } Type getType(unsigned i) { return getResult(i)->getType(); }
static bool verifyTrait(Instruction *op) { static bool verifyTrait(Instruction *op) {
return impl::verifyNResults(op, N); return impl::verifyNResults(op, N);
@ -562,11 +557,11 @@ public:
template <typename ConcreteType> template <typename ConcreteType>
class Impl : public TraitBase<ConcreteType, AtLeastNResults<N>::Impl> { class Impl : public TraitBase<ConcreteType, AtLeastNResults<N>::Impl> {
public: public:
Value *getResult(unsigned i) const { Value *getResult(unsigned i) {
return this->getInstruction()->getResult(i); return this->getInstruction()->getResult(i);
} }
Type getType(unsigned i) const { return getResult(i)->getType(); } Type getType(unsigned i) { return getResult(i)->getType(); }
static bool verifyTrait(Instruction *op) { static bool verifyTrait(Instruction *op) {
return impl::verifyAtLeastNResults(op, N); return impl::verifyAtLeastNResults(op, N);
@ -579,19 +574,15 @@ public:
template <typename ConcreteType> template <typename ConcreteType>
class VariadicResults : public TraitBase<ConcreteType, VariadicResults> { class VariadicResults : public TraitBase<ConcreteType, VariadicResults> {
public: public:
unsigned getNumResults() const { unsigned getNumResults() { return this->getInstruction()->getNumResults(); }
return this->getInstruction()->getNumResults();
}
Value *getResult(unsigned i) const { Value *getResult(unsigned i) { return this->getInstruction()->getResult(i); }
return this->getInstruction()->getResult(i);
}
void setResult(unsigned i, Value *value) { void setResult(unsigned i, Value *value) {
this->getInstruction()->setResult(i, value); this->getInstruction()->setResult(i, value);
} }
// Support non-const result iteration. // Support result iteration.
using result_iterator = Instruction::result_iterator; using result_iterator = Instruction::result_iterator;
result_iterator result_begin() { result_iterator result_begin() {
return this->getInstruction()->result_begin(); return this->getInstruction()->result_begin();
@ -714,14 +705,14 @@ public:
return impl::verifyIsTerminator(op); return impl::verifyIsTerminator(op);
} }
unsigned getNumSuccessors() const { unsigned getNumSuccessors() {
return this->getInstruction()->getNumSuccessors(); return this->getInstruction()->getNumSuccessors();
} }
unsigned getNumSuccessorOperands(unsigned index) const { unsigned getNumSuccessorOperands(unsigned index) {
return this->getInstruction()->getNumSuccessorOperands(index); return this->getInstruction()->getNumSuccessorOperands(index);
} }
Block *getSuccessor(unsigned index) const { Block *getSuccessor(unsigned index) {
return this->getInstruction()->getSuccessor(index); return this->getInstruction()->getSuccessor(index);
} }
@ -755,7 +746,11 @@ class Op : public OpState,
Traits<ConcreteType>...>::value> { Traits<ConcreteType>...>::value> {
public: public:
/// Return the operation that this refers to. /// Return the operation that this refers to.
Instruction *getInstruction() const { return OpState::getInstruction(); } Instruction *getInstruction() { return OpState::getInstruction(); }
// FIXME: Remove this, this is just a transition to allow using -> and staging
// patches.
ConcreteType *operator->() { return static_cast<ConcreteType *>(this); }
/// Return true if this "op class" can match against the specified operation. /// Return true if this "op class" can match against the specified operation.
/// This hook can be overridden with a more specific implementation in /// This hook can be overridden with a more specific implementation in

View File

@ -240,7 +240,7 @@ class CmpIOp
OpTrait::OneResult, OpTrait::ResultsAreBoolLike, OpTrait::OneResult, OpTrait::ResultsAreBoolLike,
OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> { OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> {
public: public:
CmpIPredicate getPredicate() const { CmpIPredicate getPredicate() {
return (CmpIPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName()) return (CmpIPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
.getInt(); .getInt();
} }
@ -298,7 +298,7 @@ public:
MLIRContext *context); MLIRContext *context);
// The condition operand is the first operand in the list. // The condition operand is the first operand in the list.
Value *getCondition() const { return getOperand(0); } Value *getCondition() { return getOperand(0); }
/// Return the destination if the condition is true. /// Return the destination if the condition is true.
Block *getTrueDest(); Block *getTrueDest();
@ -327,7 +327,7 @@ public:
return {true_operand_begin(), true_operand_end()}; return {true_operand_begin(), true_operand_end()};
} }
unsigned getNumTrueOperands() const; unsigned getNumTrueOperands();
/// Erase the operand at 'index' from the true operand list. /// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index); void eraseTrueOperand(unsigned index);
@ -337,9 +337,6 @@ public:
assert(idx < getNumFalseOperands()); assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx); return getOperand(getFalseDestOperandIndex() + idx);
} }
Value *getFalseOperand(unsigned idx) const {
return const_cast<CondBranchOp *>(this)->getFalseOperand(idx);
}
void setFalseOperand(unsigned idx, Value *value) { void setFalseOperand(unsigned idx, Value *value) {
assert(idx < getNumFalseOperands()); assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value); setOperand(getFalseDestOperandIndex() + idx, value);
@ -353,17 +350,17 @@ public:
return {false_operand_begin(), false_operand_end()}; return {false_operand_begin(), false_operand_end()};
} }
unsigned getNumFalseOperands() const; unsigned getNumFalseOperands();
/// Erase the operand at 'index' from the false operand list. /// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index); void eraseFalseOperand(unsigned index);
private: private:
/// Get the index of the first true destination operand. /// Get the index of the first true destination operand.
unsigned getTrueDestOperandIndex() const { return 1; } unsigned getTrueDestOperandIndex() { return 1; }
/// Get the index of the first false destination operand. /// Get the index of the first false destination operand.
unsigned getFalseDestOperandIndex() const { unsigned getFalseDestOperandIndex() {
return getTrueDestOperandIndex() + getNumTrueOperands(); return getTrueDestOperandIndex() + getNumTrueOperands();
} }
@ -388,7 +385,7 @@ public:
/// attribute's type. /// attribute's type.
static void build(Builder *builder, OperationState *result, Attribute value); static void build(Builder *builder, OperationState *result, Attribute value);
Attribute getValue() const { return getAttr("value"); } Attribute getValue() { return getAttr("value"); }
static StringRef getOperationName() { return "std.constant"; } static StringRef getOperationName() { return "std.constant"; }
@ -414,9 +411,7 @@ public:
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result,
const APFloat &value, FloatType type); const APFloat &value, FloatType type);
APFloat getValue() const { APFloat getValue() { return getAttrOfType<FloatAttr>("value").getValue(); }
return getAttrOfType<FloatAttr>("value").getValue();
}
static bool isClassFor(Instruction *op); static bool isClassFor(Instruction *op);
@ -441,9 +436,7 @@ public:
static void build(Builder *builder, OperationState *result, int64_t value, static void build(Builder *builder, OperationState *result, int64_t value,
Type type); Type type);
int64_t getValue() const { int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); }
return getAttrOfType<IntegerAttr>("value").getInt();
}
static bool isClassFor(Instruction *op); static bool isClassFor(Instruction *op);
@ -462,9 +455,7 @@ public:
/// Build a constant int op producing an index. /// Build a constant int op producing an index.
static void build(Builder *builder, OperationState *result, int64_t value); static void build(Builder *builder, OperationState *result, int64_t value);
int64_t getValue() const { int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); }
return getAttrOfType<IntegerAttr>("value").getInt();
}
static bool isClassFor(Instruction *op); static bool isClassFor(Instruction *op);
@ -486,7 +477,7 @@ private:
class DeallocOp class DeallocOp
: public Op<DeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> { : public Op<DeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
public: public:
Value *getMemRef() const { return getOperand(); } Value *getMemRef() { return getOperand(); }
void setMemRef(Value *value) { setOperand(value); } void setMemRef(Value *value) { setOperand(value); }
static StringRef getOperationName() { return "std.dealloc"; } static StringRef getOperationName() { return "std.dealloc"; }
@ -519,7 +510,7 @@ public:
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context); Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
/// This returns the dimension number that the 'dim' is inspecting. /// This returns the dimension number that the 'dim' is inspecting.
unsigned getIndex() const { unsigned getIndex() {
return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue(); return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
} }
@ -583,9 +574,9 @@ public:
Value *elementsPerStride = nullptr); Value *elementsPerStride = nullptr);
// Returns the source MemRefType for this DMA operation. // Returns the source MemRefType for this DMA operation.
Value *getSrcMemRef() const { return getOperand(0); } Value *getSrcMemRef() { return getOperand(0); }
// Returns the rank (number of indices) of the source MemRefType. // Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() const { unsigned getSrcMemRefRank() {
return getSrcMemRef()->getType().cast<MemRefType>().getRank(); return getSrcMemRef()->getType().cast<MemRefType>().getRank();
} }
// Returns the source memerf indices for this DMA operation. // Returns the source memerf indices for this DMA operation.
@ -595,15 +586,15 @@ public:
} }
// Returns the destination MemRefType for this DMA operations. // Returns the destination MemRefType for this DMA operations.
Value *getDstMemRef() const { return getOperand(1 + getSrcMemRefRank()); } Value *getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
// Returns the rank (number of indices) of the destination MemRefType. // Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() const { unsigned getDstMemRefRank() {
return getDstMemRef()->getType().cast<MemRefType>().getRank(); return getDstMemRef()->getType().cast<MemRefType>().getRank();
} }
unsigned getSrcMemorySpace() const { unsigned getSrcMemorySpace() {
return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace(); return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
} }
unsigned getDstMemorySpace() const { unsigned getDstMemorySpace() {
return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace(); return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
} }
@ -615,21 +606,21 @@ public:
} }
// Returns the number of elements being transferred by this DMA operation. // Returns the number of elements being transferred by this DMA operation.
Value *getNumElements() const { Value *getNumElements() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
} }
// Returns the Tag MemRef for this DMA operation. // Returns the Tag MemRef for this DMA operation.
Value *getTagMemRef() const { Value *getTagMemRef() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
} }
// Returns the rank (number of indices) of the tag MemRefType. // Returns the rank (number of indices) of the tag MemRefType.
unsigned getTagMemRefRank() const { unsigned getTagMemRefRank() {
return getTagMemRef()->getType().cast<MemRefType>().getRank(); return getTagMemRef()->getType().cast<MemRefType>().getRank();
} }
// Returns the tag memref index for this DMA operation. // Returns the tag memref index for this DMA operation.
llvm::iterator_range<Instruction::operand_iterator> getTagIndices() const { llvm::iterator_range<Instruction::operand_iterator> getTagIndices() {
unsigned tagIndexStartPos = unsigned tagIndexStartPos =
1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
return {getInstruction()->operand_begin() + tagIndexStartPos, return {getInstruction()->operand_begin() + tagIndexStartPos,
@ -638,12 +629,12 @@ public:
} }
/// Returns true if this is a DMA from a faster memory space to a slower one. /// Returns true if this is a DMA from a faster memory space to a slower one.
bool isDestMemorySpaceFaster() const { bool isDestMemorySpaceFaster() {
return (getSrcMemorySpace() < getDstMemorySpace()); return (getSrcMemorySpace() < getDstMemorySpace());
} }
/// Returns true if this is a DMA from a slower memory space to a faster one. /// Returns true if this is a DMA from a slower memory space to a faster one.
bool isSrcMemorySpaceFaster() const { bool isSrcMemorySpaceFaster() {
// Assumes that a lower number is for a slower memory space. // Assumes that a lower number is for a slower memory space.
return (getDstMemorySpace() < getSrcMemorySpace()); return (getDstMemorySpace() < getSrcMemorySpace());
} }
@ -651,7 +642,7 @@ public:
/// Given a DMA start operation, returns the operand position of either the /// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher /// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy. Asserts failure if neither is true. /// level of the memory hierarchy. Asserts failure if neither is true.
unsigned getFasterMemPos() const { unsigned getFasterMemPos() {
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1; return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
} }
@ -664,7 +655,7 @@ public:
static void getCanonicalizationPatterns(OwningRewritePatternList &results, static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context); MLIRContext *context);
bool isStrided() const { bool isStrided() {
return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
1 + 1 + getTagMemRefRank(); 1 + 1 + getTagMemRefRank();
} }
@ -708,21 +699,21 @@ public:
static StringRef getOperationName() { return "std.dma_wait"; } static StringRef getOperationName() { return "std.dma_wait"; }
// Returns the Tag MemRef associated with the DMA operation being waited on. // Returns the Tag MemRef associated with the DMA operation being waited on.
Value *getTagMemRef() const { return getOperand(0); } Value *getTagMemRef() { return getOperand(0); }
// Returns the tag memref index for this DMA operation. // Returns the tag memref index for this DMA operation.
llvm::iterator_range<Instruction::operand_iterator> getTagIndices() const { llvm::iterator_range<Instruction::operand_iterator> getTagIndices() {
return {getInstruction()->operand_begin() + 1, return {getInstruction()->operand_begin() + 1,
getInstruction()->operand_begin() + 1 + getTagMemRefRank()}; getInstruction()->operand_begin() + 1 + getTagMemRefRank()};
} }
// Returns the rank (number of indices) of the tag memref. // Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() const { unsigned getTagMemRefRank() {
return getTagMemRef()->getType().cast<MemRefType>().getRank(); return getTagMemRef()->getType().cast<MemRefType>().getRank();
} }
// Returns the number of elements transferred in the associated DMA operation. // Returns the number of elements transferred in the associated DMA operation.
Value *getNumElements() const { return getOperand(1 + getTagMemRefRank()); } Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); }
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
@ -752,7 +743,7 @@ public:
static void build(Builder *builder, OperationState *result, Value *aggregate, static void build(Builder *builder, OperationState *result, Value *aggregate,
ArrayRef<Value *> indices = {}); ArrayRef<Value *> indices = {});
Value *getAggregate() const { return getOperand(0); } Value *getAggregate() { return getOperand(0); }
llvm::iterator_range<Instruction::operand_iterator> getIndices() { llvm::iterator_range<Instruction::operand_iterator> getIndices() {
return {getInstruction()->operand_begin() + 1, return {getInstruction()->operand_begin() + 1,
@ -787,9 +778,9 @@ public:
static void build(Builder *builder, OperationState *result, Value *memref, static void build(Builder *builder, OperationState *result, Value *memref,
ArrayRef<Value *> indices = {}); ArrayRef<Value *> indices = {});
Value *getMemRef() const { return getOperand(0); } Value *getMemRef() { return getOperand(0); }
void setMemRef(Value *value) { setOperand(0, value); } void setMemRef(Value *value) { setOperand(0, value); }
MemRefType getMemRefType() const { MemRefType getMemRefType() {
return getMemRef()->getType().cast<MemRefType>(); return getMemRef()->getType().cast<MemRefType>();
} }
@ -831,9 +822,7 @@ public:
static StringRef getOperationName() { return "std.memref_cast"; } static StringRef getOperationName() { return "std.memref_cast"; }
/// The result of a memref_cast is always a memref. /// The result of a memref_cast is always a memref.
MemRefType getType() const { MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
return getResult()->getType().cast<MemRefType>();
}
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
@ -892,9 +881,9 @@ public:
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
bool verify(); bool verify();
Value *getCondition() const { return getOperand(0); } Value *getCondition() { return getOperand(0); }
Value *getTrueValue() const { return getOperand(1); } Value *getTrueValue() { return getOperand(1); }
Value *getFalseValue() const { return getOperand(2); } Value *getFalseValue() { return getOperand(2); }
Value *fold(); Value *fold();
@ -921,7 +910,7 @@ public:
Value *valueToStore, Value *memref, Value *valueToStore, Value *memref,
ArrayRef<Value *> indices = {}); ArrayRef<Value *> indices = {});
Value *getValueToStore() const { return getOperand(0); } Value *getValueToStore() { return getOperand(0); }
Value *getMemRef() { return getOperand(1); } Value *getMemRef() { return getOperand(1); }
void setMemRef(Value *value) { setOperand(1, value); } void setMemRef(Value *value) { setOperand(1, value); }

View File

@ -67,25 +67,25 @@ void OpState::print(OpAsmPrinter *p) { p->printGenericOp(getInstruction()); }
/// any diagnostic handlers that may be listening. NOTE: This may terminate /// any diagnostic handlers that may be listening. NOTE: This may terminate
/// the containing application, only use when the IR is in an inconsistent /// the containing application, only use when the IR is in an inconsistent
/// state. /// state.
bool OpState::emitError(const Twine &message) const { bool OpState::emitError(const Twine &message) {
return getInstruction()->emitError(message); return getInstruction()->emitError(message);
} }
/// Emit an error with the op name prefixed, like "'dim' op " which is /// Emit an error with the op name prefixed, like "'dim' op " which is
/// convenient for verifiers. /// convenient for verifiers.
bool OpState::emitOpError(const Twine &message) const { bool OpState::emitOpError(const Twine &message) {
return getInstruction()->emitOpError(message); return getInstruction()->emitOpError(message);
} }
/// Emit a warning about this operation, reporting up to any diagnostic /// Emit a warning about this operation, reporting up to any diagnostic
/// handlers that may be listening. /// handlers that may be listening.
void OpState::emitWarning(const Twine &message) const { void OpState::emitWarning(const Twine &message) {
getInstruction()->emitWarning(message); getInstruction()->emitWarning(message);
} }
/// Emit a note about this operation, reporting up to any diagnostic /// Emit a note about this operation, reporting up to any diagnostic
/// handlers that may be listening. /// handlers that may be listening.
void OpState::emitNote(const Twine &message) const { void OpState::emitNote(const Twine &message) {
getInstruction()->emitNote(message); getInstruction()->emitNote(message);
} }

View File

@ -898,7 +898,7 @@ Block *CondBranchOp::getFalseDest() {
return getInstruction()->getSuccessor(falseIndex); return getInstruction()->getSuccessor(falseIndex);
} }
unsigned CondBranchOp::getNumTrueOperands() const { unsigned CondBranchOp::getNumTrueOperands() {
return getInstruction()->getNumSuccessorOperands(trueIndex); return getInstruction()->getNumSuccessorOperands(trueIndex);
} }
@ -906,7 +906,7 @@ void CondBranchOp::eraseTrueOperand(unsigned index) {
getInstruction()->eraseSuccessorOperand(trueIndex, index); getInstruction()->eraseSuccessorOperand(trueIndex, index);
} }
unsigned CondBranchOp::getNumFalseOperands() const { unsigned CondBranchOp::getNumFalseOperands() {
return getInstruction()->getNumSuccessorOperands(falseIndex); return getInstruction()->getNumSuccessorOperands(falseIndex);
} }

View File

@ -135,7 +135,7 @@ LogicalResult mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
forOp->getLoc(), forOp->getConstantLowerBound()); forOp->getLoc(), forOp->getConstantLowerBound());
iv->replaceAllUsesWith(constOp); iv->replaceAllUsesWith(constOp);
} else { } else {
const AffineBound lb = forOp->getLowerBound(); AffineBound lb = forOp->getLowerBound();
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end()); SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst)); FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst));
if (lb.getMap() == builder.getDimIdentityMap()) { if (lb.getMap() == builder.getDimIdentityMap()) {

View File

@ -944,7 +944,7 @@ vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch,
/// element type. /// element type.
/// If `type` is not a valid vector type or if the scalar constant is not a /// If `type` is not a valid vector type or if the scalar constant is not a
/// valid vector element type, returns nullptr. /// valid vector element type, returns nullptr.
static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant, static Value *vectorizeConstant(Instruction *inst, ConstantOp constant,
Type type) { Type type) {
if (!type || !type.isa<VectorType>() || if (!type || !type.isa<VectorType>() ||
!VectorType::isValidElementType(constant.getType())) { !VectorType::isValidElementType(constant.getType())) {