Static cast size_t -> int64_t instead of vice versa for equals comparisons

These were just introduced by a previous CL moving MemRef getRank to return int64_t. size_t could be smaller than 64 bits and in equals comparisons, signed vs unsigned doesn't matter. In these cases, we know right now that the particular int64_t is not larger than max size_t (because it currently comes directly from a size() call), the alternative cast plus equals comparison is always safe, so we might as well do it that way and no longer require reasoning deeper into the callstack.

    We are already assuming that size() calls fit into int64_t in a number of other cases like the aforementioned getRank() (since exabytes of RAM are rare). If we want to avoid this assumption we will have to come up with a principled way to do it throughout.

--

PiperOrigin-RevId: 250980297
This commit is contained in:
Geoffrey Martin-Noble 2019-05-31 16:41:21 -07:00 committed by Mehdi Amini
parent e7b337acf8
commit c912981bbd
3 changed files with 14 additions and 13 deletions

View File

@ -40,7 +40,7 @@ void linalg::ViewOp::build(Builder *b, OperationState *result, Value *memRef,
ArrayRef<Value *> indexings) {
MemRefType memRefType = memRef->getType().cast<MemRefType>();
result->addOperands({memRef});
assert(indexings.size() == static_cast<size_t>(memRefType.getRank()) &&
assert(static_cast<int64_t>(indexings.size()) == memRefType.getRank() &&
"unexpected number of indexings (must match the memref rank)");
result->addOperands(indexings);
@ -107,7 +107,7 @@ ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
if (!memRefType)
return parser->emitError(parser->getNameLoc(),
"memRef type expected for first type");
if (indexingsInfo.size() != static_cast<size_t>(memRefType.getRank()))
if (static_cast<int64_t>(indexingsInfo.size()) != memRefType.getRank())
return parser->emitError(parser->getNameLoc(),
"expected " + Twine(memRefType.getRank()) +
" indexings");
@ -116,7 +116,7 @@ ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(), "view type expected");
ArrayRef<Type> indexingTypes = ArrayRef<Type>(types).drop_front().drop_back();
if (indexingTypes.size() != static_cast<size_t>(memRefType.getRank()))
if (static_cast<int64_t>(indexingTypes.size()) != memRefType.getRank())
return parser->emitError(parser->getNameLoc(),
"expected " + Twine(memRefType.getRank()) +
" indexing types");

View File

@ -338,7 +338,8 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
auto newMemRefType = MemRefType::get(
newShapeConstants, memrefType.getElementType(),
memrefType.getAffineMaps(), memrefType.getMemorySpace());
assert(newOperands.size() == newMemRefType.getNumDynamicDims());
assert(static_cast<int64_t>(newOperands.size()) ==
newMemRefType.getNumDynamicDims());
// Create and insert the alloc op for the new memref.
auto newAlloc =
@ -1459,15 +1460,15 @@ ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
}
// Check that source/destination index list size matches associated rank.
if (srcIndexInfos.size() !=
static_cast<size_t>(types[0].cast<MemRefType>().getRank()) ||
dstIndexInfos.size() !=
static_cast<size_t>(types[1].cast<MemRefType>().getRank()))
if (static_cast<int64_t>(srcIndexInfos.size()) !=
types[0].cast<MemRefType>().getRank() ||
static_cast<int64_t>(dstIndexInfos.size()) !=
types[1].cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"memref rank not equal to indices count");
if (tagIndexInfos.size() !=
static_cast<size_t>(types[2].cast<MemRefType>().getRank()))
if (static_cast<int64_t>(tagIndexInfos.size()) !=
types[2].cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count");
@ -1546,8 +1547,8 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(),
"expected tag to be of memref type");
if (tagIndexInfos.size() !=
static_cast<size_t>(type.cast<MemRefType>().getRank()))
if (static_cast<int64_t>(tagIndexInfos.size()) !=
type.cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count");

View File

@ -148,7 +148,7 @@ ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
// Extract optional paddingValue.
// At this point, indexInfo may contain the optional paddingValue, pop it out.
if (indexInfo.size() != static_cast<size_t>(memrefType.getRank()))
if (static_cast<int64_t>(indexInfo.size()) != memrefType.getRank())
return parser->emitError(parser->getNameLoc(),
"expected " + Twine(memrefType.getRank()) +
" indices to the memref");