From 323fd11df7718e68c37f9220a8e1056bb56778cf Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 13 Oct 2020 21:31:40 +0200 Subject: [PATCH] [mlir][nfc] Add a func to compute numElements of a shape in Std -> LLVM. For some reason the variable `cumulativeSizeInBytes` in `getCumulativeSizeInBytes` was actually storing number of elements. I decided to fix it and refactor the function a bit. Differential Revision: https://reviews.llvm.org/D89336 --- .../StandardToLLVM/ConvertStandardToLLVM.h | 4 ++++ .../StandardToLLVM/StandardToLLVM.cpp | 23 +++++++++++-------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index 645f4cd26581..36734f809175 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -470,6 +470,10 @@ protected: Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const; + /// Computes total number of elements for the given shape. + Value getNumElements(Location loc, ArrayRef shape, + ConversionPatternRewriter &rewriter) const; + /// Computes total size in bytes of to store the given shape. Value getCumulativeSizeInBytes(Location loc, Type elementType, ArrayRef shape, diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index e042fc3d1c4e..3fe60f5e88d4 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -979,18 +979,23 @@ Value ConvertToLLVMPattern::getSizeInBytes( return rewriter.create(loc, getIndexType(), gep); } +Value ConvertToLLVMPattern::getNumElements( + Location loc, ArrayRef shape, + ConversionPatternRewriter &rewriter) const { + // Compute the total number of memref elements. + Value numElements = + shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front(); + for (unsigned i = 1, e = shape.size(); i < e; ++i) + numElements = rewriter.create(loc, numElements, shape[i]); + return numElements; +} + Value ConvertToLLVMPattern::getCumulativeSizeInBytes( Location loc, Type elementType, ArrayRef shape, ConversionPatternRewriter &rewriter) const { - // Compute the total number of memref elements. - Value cumulativeSizeInBytes = - shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front(); - for (unsigned i = 1, e = shape.size(); i < e; ++i) - cumulativeSizeInBytes = rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSizeInBytes, shape[i]}); - auto elementSize = this->getSizeInBytes(loc, elementType, rewriter); - return rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSizeInBytes, elementSize}); + Value numElements = this->getNumElements(loc, shape, rewriter); + Value elementSize = this->getSizeInBytes(loc, elementType, rewriter); + return rewriter.create(loc, numElements, elementSize); } /// Creates and populates the memref descriptor struct given all its fields.