forked from OSchip/llvm-project
[mlir][nfc] Add a func to compute numElements of a shape in Std -> LLVM.
For some reason the variable `cumulativeSizeInBytes` in `getCumulativeSizeInBytes` was actually storing number of elements. I decided to fix it and refactor the function a bit. Differential Revision: https://reviews.llvm.org/D89336
This commit is contained in:
parent
e79ca751fc
commit
323fd11df7
|
@ -470,6 +470,10 @@ protected:
|
|||
Value getSizeInBytes(Location loc, Type type,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
/// Computes total number of elements for the given shape.
|
||||
Value getNumElements(Location loc, ArrayRef<Value> shape,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
|
||||
/// Computes total size in bytes of to store the given shape.
|
||||
Value getCumulativeSizeInBytes(Location loc, Type elementType,
|
||||
ArrayRef<Value> shape,
|
||||
|
|
|
@ -979,18 +979,23 @@ Value ConvertToLLVMPattern::getSizeInBytes(
|
|||
return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::getNumElements(
|
||||
Location loc, ArrayRef<Value> shape,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Compute the total number of memref elements.
|
||||
Value numElements =
|
||||
shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
|
||||
for (unsigned i = 1, e = shape.size(); i < e; ++i)
|
||||
numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
|
||||
return numElements;
|
||||
}
|
||||
|
||||
Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
|
||||
Location loc, Type elementType, ArrayRef<Value> shape,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Compute the total number of memref elements.
|
||||
Value cumulativeSizeInBytes =
|
||||
shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
|
||||
for (unsigned i = 1, e = shape.size(); i < e; ++i)
|
||||
cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
|
||||
loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, shape[i]});
|
||||
auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
|
||||
return rewriter.create<LLVM::MulOp>(
|
||||
loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
|
||||
Value numElements = this->getNumElements(loc, shape, rewriter);
|
||||
Value elementSize = this->getSizeInBytes(loc, elementType, rewriter);
|
||||
return rewriter.create<LLVM::MulOp>(loc, numElements, elementSize);
|
||||
}
|
||||
|
||||
/// Creates and populates the memref descriptor struct given all its fields.
|
||||
|
|
Loading…
Reference in New Issue