Emit LLVM IR equivalent of sizeof when lowering alloc operations

Originally, the lowering of `alloc` operations computed the number of bytes to
allocate based on the properties of the MLIR type. This does
not take into account type legalization that happens when compiling LLVM IR
down to target assembly. This legalization can widen the type, potentially
leading to out-of-bounds accesses to `alloc`ed data due to mismatches between
address computation that takes the widening into account and allocation that
does not. Use the LLVM IR's equivalent of `sizeof` to compute the number of
bytes to be allocated:
  %0 = getelementptr %type* null, %indexType 1
  %1 = ptrtoint %type* %0 to %indexType
adapted from
http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt

PiperOrigin-RevId: 274159900
This commit is contained in:
Alex Zinenko 2019-10-11 06:22:40 -07:00 committed by A. Unique TensorFlower
parent 71b82bcbf6
commit 8c2ea32072
3 changed files with 41 additions and 26 deletions

View File

@ -658,21 +658,24 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
op->getLoc(), getIndexType(),
ArrayRef<Value *>{cumulativeSize, sizes[i]});
// Compute the total amount of bytes to allocate.
// Compute the size of an individual element. This emits the MLIR equivalent
// of the following sizeof(...) implementation in LLVM IR:
// %0 = getelementptr %elementType* null, %indexType 1
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
auto elementType = type.getElementType();
assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) &&
"invalid memref element type");
uint64_t elementSize = 0;
if (auto vectorType = elementType.dyn_cast<VectorType>())
elementSize = vectorType.getNumElements() *
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
else
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
auto convertedPtrType =
lowering.convertType(elementType).cast<LLVM::LLVMType>().getPointerTo();
auto nullPtr =
rewriter.create<LLVM::NullOp>(op->getLoc(), convertedPtrType);
auto one = createIndexConstant(rewriter, op->getLoc(), 1);
auto gep = rewriter.create<LLVM::GEPOp>(op->getLoc(), convertedPtrType,
ArrayRef<Value *>{nullPtr, one});
auto elementSize =
rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), gep);
cumulativeSize = rewriter.create<LLVM::MulOp>(
op->getLoc(), getIndexType(),
ArrayRef<Value *>{
cumulativeSize,
createIndexConstant(rewriter, op->getLoc(), elementSize)});
ArrayRef<Value *>{cumulativeSize, elementSize});
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();

View File

@ -22,8 +22,11 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, i64 }"> {
func @zero_d_alloc() -> memref<f32> {
// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64
// CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, i64 }">
@ -50,8 +53,11 @@ func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
// CHECK-NEXT: %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: llvm.mul %[[M]], %[[c42]] : !llvm.i64
// CHECK-NEXT: %[[sz:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %{{.*}} : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
@ -87,8 +93,11 @@ func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
// CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {
func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK-NEXT: %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
// CHECK-NEXT: llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %{{.*}} : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
// CHECK-NEXT: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
@ -118,13 +127,16 @@ func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
// CHECK-LABEL: func @static_alloc() -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {
func @static_alloc() -> memref<32x18xf32> {
// CHECK-NEXT: %0 = llvm.mlir.constant(32 : index) : !llvm.i64
// CHECK-NEXT: %1 = llvm.mlir.constant(18 : index) : !llvm.i64
// CHECK-NEXT: %2 = llvm.mul %0, %1 : !llvm.i64
// CHECK-NEXT: %3 = llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK-NEXT: %4 = llvm.mul %2, %3 : !llvm.i64
// CHECK-NEXT: %5 = llvm.call @malloc(%4) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: %6 = llvm.bitcast %5 : !llvm<"i8*"> to !llvm<"float*">
// CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm<"float*"> to !llvm.i64
// CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm<"i8*"> to !llvm<"float*">
%0 = alloc() : memref<32x18xf32>
return %0 : memref<32x18xf32>
}

View File

@ -6,7 +6,7 @@ def multiply_transpose(a, b) {
}
# CHECK: define void @main() {
# CHECK: %1 = call i8* @malloc(i64 48)
# CHECK: %1 = call i8* @malloc(i64 mul (i64 ptrtoint (double* getelementptr (double, double* null, i64 1) to i64), i64 6))
def main() {
var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
var b<2, 3> = [1, 2, 3, 4, 5, 6];