Add getMemRefType() accessors to LoadOp/StoreOp.

- There are several places where we are casting the type of the memref obtained
  from the load/store op to a memref type, and this will become even more
  common (some upcoming CLs this week). Add a getMemRefType and use it at
  several places where the cast was being used.

PiperOrigin-RevId: 219164326
This commit is contained in:
Uday Bondhugula 2018-10-29 11:39:55 -07:00 committed by jpienaar
parent 582b0761c6
commit bdfd6193b8
3 changed files with 9 additions and 3 deletions

View File

@ -460,6 +460,9 @@ public:
SSAValue *getMemRef() { return getOperand(0); }
const SSAValue *getMemRef() const { return getOperand(0); }
void setMemRef(SSAValue *value) { setOperand(0, value); }
MemRefType *getMemRefType() const {
return cast<MemRefType>(getMemRef()->getType());
}
llvm::iterator_range<Operation::operand_iterator> getIndices() {
return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
@ -580,6 +583,9 @@ public:
SSAValue *getMemRef() { return getOperand(1); }
const SSAValue *getMemRef() const { return getOperand(1); }
void setMemRef(SSAValue *value) { setOperand(1, value); }
MemRefType *getMemRefType() const {
return cast<MemRefType>(getMemRef()->getType());
}
llvm::iterator_range<Operation::operand_iterator> getIndices() {
return {getOperation()->operand_begin() + 2, getOperation()->operand_end()};

View File

@ -157,7 +157,7 @@ static bool isAccessInvariant(MLValue *input, MemRefType *memRefType,
template <typename LoadOrStoreOpPointer>
static bool isContiguousAccess(MLValue *input, LoadOrStoreOpPointer memoryOp) {
auto indicesAsOperandIterators = memoryOp->getIndices();
auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
auto *memRefType = memoryOp->getMemRefType();
SmallVector<MLValue *, 4> indices;
for (auto *it : indicesAsOperandIterators) {
indices.push_back(cast<MLValue>(it));

View File

@ -753,7 +753,7 @@ void LoadOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << *getMemRef()->getType();
*p << " : " << *getMemRefType();
}
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
@ -928,7 +928,7 @@ void StoreOp::print(OpAsmPrinter *p) const {
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
*p << " : " << *getMemRef()->getType();
*p << " : " << *getMemRefType();
}
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {