[mlir][nfc] Add a func to compute numElements of a shape in Std -> LLVM.

For some reason the variable `cumulativeSizeInBytes` in
`getCumulativeSizeInBytes` was actually storing number of elements. I decided
to fix it and refactor the function a bit.

Differential Revision: https://reviews.llvm.org/D89336
This commit is contained in:
Alexander Belyaev 2020-10-13 21:31:40 +02:00
parent e79ca751fc
commit 323fd11df7
2 changed files with 18 additions and 9 deletions

View File

@ -470,6 +470,10 @@ protected:
Value getSizeInBytes(Location loc, Type type,
ConversionPatternRewriter &rewriter) const;
/// Computes total number of elements for the given shape.
Value getNumElements(Location loc, ArrayRef<Value> shape,
ConversionPatternRewriter &rewriter) const;
/// Computes total size in bytes of to store the given shape.
Value getCumulativeSizeInBytes(Location loc, Type elementType,
ArrayRef<Value> shape,

View File

@ -979,18 +979,23 @@ Value ConvertToLLVMPattern::getSizeInBytes(
return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
}
Value ConvertToLLVMPattern::getNumElements(
Location loc, ArrayRef<Value> shape,
ConversionPatternRewriter &rewriter) const {
// Compute the total number of memref elements.
Value numElements =
shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
for (unsigned i = 1, e = shape.size(); i < e; ++i)
numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
return numElements;
}
Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
Location loc, Type elementType, ArrayRef<Value> shape,
ConversionPatternRewriter &rewriter) const {
// Compute the total number of memref elements.
Value cumulativeSizeInBytes =
shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
for (unsigned i = 1, e = shape.size(); i < e; ++i)
cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, shape[i]});
auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
return rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
Value numElements = this->getNumElements(loc, shape, rewriter);
Value elementSize = this->getSizeInBytes(loc, elementType, rewriter);
return rewriter.create<LLVM::MulOp>(loc, numElements, elementSize);
}
/// Creates and populates the memref descriptor struct given all its fields.