forked from OSchip/llvm-project
Emit LLVM IR equivalent of sizeof when lowering alloc operations
Originally, the lowering of `alloc` operations has been computing the number of bytes to allocate when lowering based on the properties of MLIR type. This does not take into account type legalization that happens when compiling LLVM IR down to target assembly. This legalization can widen the type, potentially leading to out-of-bounds accesses to `alloc`ed data due to mismatches between address computation that takes the widening into account and allocation that does not. Use the LLVM IR's equivalent of `sizeof` to compute the number of bytes to be allocated: %0 = getelementptr %type* null, %indexType 0 %1 = ptrtoint %type* %0 to %indexType adapted from http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt PiperOrigin-RevId: 274159900
This commit is contained in:
parent
71b82bcbf6
commit
8c2ea32072
|
@ -658,21 +658,24 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
op->getLoc(), getIndexType(),
|
||||
ArrayRef<Value *>{cumulativeSize, sizes[i]});
|
||||
|
||||
// Compute the total amount of bytes to allocate.
|
||||
// Compute the size of an individual element. This emits the MLIR equivalent
|
||||
// of the following sizeof(...) implementation in LLVM IR:
|
||||
// %0 = getelementptr %elementType* null, %indexType 1
|
||||
// %1 = ptrtoint %elementType* %0 to %indexType
|
||||
// which is a common pattern of getting the size of a type in bytes.
|
||||
auto elementType = type.getElementType();
|
||||
assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) &&
|
||||
"invalid memref element type");
|
||||
uint64_t elementSize = 0;
|
||||
if (auto vectorType = elementType.dyn_cast<VectorType>())
|
||||
elementSize = vectorType.getNumElements() *
|
||||
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
|
||||
else
|
||||
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
|
||||
auto convertedPtrType =
|
||||
lowering.convertType(elementType).cast<LLVM::LLVMType>().getPointerTo();
|
||||
auto nullPtr =
|
||||
rewriter.create<LLVM::NullOp>(op->getLoc(), convertedPtrType);
|
||||
auto one = createIndexConstant(rewriter, op->getLoc(), 1);
|
||||
auto gep = rewriter.create<LLVM::GEPOp>(op->getLoc(), convertedPtrType,
|
||||
ArrayRef<Value *>{nullPtr, one});
|
||||
auto elementSize =
|
||||
rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), gep);
|
||||
cumulativeSize = rewriter.create<LLVM::MulOp>(
|
||||
op->getLoc(), getIndexType(),
|
||||
ArrayRef<Value *>{
|
||||
cumulativeSize,
|
||||
createIndexConstant(rewriter, op->getLoc(), elementSize)});
|
||||
ArrayRef<Value *>{cumulativeSize, elementSize});
|
||||
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
|
|
|
@ -22,8 +22,11 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
|
|||
// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, i64 }"> {
|
||||
func @zero_d_alloc() -> memref<f32> {
|
||||
// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.mlir.constant(4 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
|
||||
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
|
||||
// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
|
||||
// CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*">
|
||||
// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
|
||||
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, i64 }">
|
||||
|
@ -50,8 +53,11 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
|
|||
// CHECK-NEXT: %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.mul %[[M]], %[[c42]] : !llvm.i64
|
||||
// CHECK-NEXT: %[[sz:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
|
||||
// CHECK-NEXT: llvm.mlir.constant(4 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
|
||||
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
|
||||
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64
|
||||
// CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm<"i8*">
|
||||
// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
|
||||
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
|
||||
|
@ -87,8 +93,11 @@ func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
|
|||
// CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {
|
||||
func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
|
||||
// CHECK-NEXT: %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
|
||||
// CHECK-NEXT: llvm.mlir.constant(4 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
|
||||
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
|
||||
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64
|
||||
// CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm<"i8*">
|
||||
// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
|
||||
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||
|
@ -118,13 +127,16 @@ func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
|
|||
|
||||
// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {
|
||||
func @static_alloc() -> memref<32x18xf32> {
|
||||
// CHECK-NEXT: %0 = llvm.mlir.constant(32 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %1 = llvm.mlir.constant(18 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %2 = llvm.mul %0, %1 : !llvm.i64
|
||||
// CHECK-NEXT: %3 = llvm.mlir.constant(4 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %4 = llvm.mul %2, %3 : !llvm.i64
|
||||
// CHECK-NEXT: %5 = llvm.call @malloc(%4) : (!llvm.i64) -> !llvm<"i8*">
|
||||
// CHECK-NEXT: %6 = llvm.bitcast %5 : !llvm<"i8*"> to !llvm<"float*">
|
||||
// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
|
||||
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
|
||||
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
|
||||
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
|
||||
// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
|
||||
// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
|
||||
// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
|
||||
%0 = alloc() : memref<32x18xf32>
|
||||
return %0 : memref<32x18xf32>
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ def multiply_transpose(a, b) {
|
|||
}
|
||||
|
||||
# CHECK: define void @main() {
|
||||
# CHECK: %1 = call i8* @malloc(i64 48)
|
||||
# CHECK: %1 = call i8* @malloc(i64 mul (i64 ptrtoint (double* getelementptr (double, double* null, i64 1) to i64), i64 6))
|
||||
def main() {
|
||||
var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
|
||||
var b<2, 3> = [1, 2, 3, 4, 5, 6];
|
||||
|
|
Loading…
Reference in New Issue