From 6f1d4bb8dfde5023aad26319e14d7e051dfc4d95 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 22 Aug 2019 12:46:30 -0700 Subject: [PATCH] Avoid overflow when lowering linalg.slice linalg.subview used to lower to a slice with a bounded range resulting in correct bounded accesses. However linalg.slice could still index out of bounds. This CL moves the bounding to linalg.slice. LLVM select and cmp ops gain a more idiomatic builder. PiperOrigin-RevId: 264897125 --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 13 +++++- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 41 +++++++++-------- mlir/test/Linalg/llvm.mlir | 46 +++++++++++++------ 3 files changed, 66 insertions(+), 34 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 10533cc72dec..fcba2b7bc95d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -164,6 +164,13 @@ def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>, let llvmBuilder = [{ $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; + let builders = [OpBuilder< + "Builder *b, OperationState *result, ICmpPredicate predicate, Value *lhs, " + "Value *rhs", [{ + LLVMDialect *dialect = &lhs->getType().cast().getDialect(); + build(b, result, LLVMType::getInt1Ty(dialect), + b->getI64IntegerAttr(static_cast(predicate)), lhs, rhs); + }]>]; let parser = [{ return parseCmpOp(parser, result); }]; let printer = [{ printICmpOp(p, *this); }]; } @@ -386,6 +393,11 @@ def LLVM_SelectOp LLVM_Type:$falseValue)>, LLVM_Builder< "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { + let builders = [OpBuilder< + "Builder *b, OperationState *result, Value *condition, Value *lhs, " + "Value *rhs", [{ + build(b, result, lhs->getType(), condition, lhs, rhs); + }]>]; let parser = [{ return parseSelectOp(parser, result); }]; let printer = [{ printSelectOp(p, *this); }]; } @@ -550,5 +562,4 @@ def 
LLVM_fmuladd : LLVM_Op<"intr.fmuladd", [NoSideEffect]>, }]; } - #endif // LLVMIR_OPS diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 1e8f07674afc..b6e0430bee36 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -415,7 +415,8 @@ public: /// 1. A function entry `alloca` operation to allocate a ViewDescriptor. /// 2. A load of the ViewDescriptor from the pointer allocated in 1. /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size -/// and stride corresponding to the +/// and stride corresponding to the region of memory within the bounds of +/// the parent view. /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. /// The linalg.slice op is replaced by the alloca'ed pointer. class SliceOpConversion : public LLVMOpLowering { @@ -446,6 +447,8 @@ public: auto ib = rewriter.getInsertionBlock(); rewriter.setInsertionPointToStart( &op->getParentOfType().getBlocks().front()); + Value *zero = + constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); Value *one = constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); // Alloca with proper alignment. @@ -470,12 +473,10 @@ public: Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView)); for (int i = 0, e = viewType.getRank(); i < e; ++i) { Value *indexing = adaptor.indexings()[i]; - Value *min = - sliceOp.indexing(i)->getType().isa() - ? 
static_cast(extractvalue(int64Ty, indexing, pos(0))) - : indexing; - Value *product = mul(min, strides[i]); - baseOffset = add(baseOffset, product); + Value *min = indexing; + if (sliceOp.indexing(i)->getType().isa()) + min = extractvalue(int64Ty, indexing, pos(0)); + baseOffset = add(baseOffset, mul(min, strides[i])); } desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView)); @@ -485,13 +486,21 @@ public: for (auto en : llvm::enumerate(sliceOp.indexings())) { Value *indexing = en.value(); if (indexing->getType().isa()) { - int i = en.index(); - Value *rangeDescriptor = adaptor.indexings()[i]; + int rank = en.index(); + Value *rangeDescriptor = adaptor.indexings()[rank]; Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0)); Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1)); Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2)); + Value *baseSize = + extractvalue(int64Ty, baseDesc, pos({kSizePosInView, rank})); + // Bound upper by base view upper bound. + max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max, + baseSize); Value *size = sub(max, min); - Value *stride = mul(strides[i], step); + // Bound lower by zero. 
+ size = + llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size); + Value *stride = mul(strides[rank], step); desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims})); desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims})); ++numNewDims; @@ -703,16 +712,8 @@ static void lowerLinalgSubViewOps(FuncOp &f) { ScopedContext scope(b, op.getLoc()); auto *view = op.getView(); SmallVector ranges; - for (auto en : llvm::enumerate(op.getRanges())) { - using edsc::op::operator<; - using linalg::intrinsics::dim; - unsigned rank = en.index(); - auto sliceRange = en.value(); - auto size = dim(view, rank); - ValueHandle ub(sliceRange.max); - auto max = edsc::intrinsics::select(size < ub, size, ub); - ranges.push_back(range(sliceRange.min, max, sliceRange.step)); - } + for (auto sliceRange : op.getRanges()) + ranges.push_back(range(sliceRange.min, sliceRange.max, sliceRange.step)); op.replaceAllUsesWith(slice(view, ranges)); op.erase(); }); diff --git a/mlir/test/Linalg/llvm.mlir b/mlir/test/Linalg/llvm.mlir index ceb140dc5787..ea3d9d0ca8c3 100644 --- a/mlir/test/Linalg/llvm.mlir +++ b/mlir/test/Linalg/llvm.mlir @@ -96,18 +96,30 @@ func @slice(%arg0: !linalg.buffer, %arg1: !linalg.range) { // CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*"> // 3rd load from slice_op // CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*"> +// insert data ptr // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 // CHECK-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64 +// insert offset // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], 
[1 x i64] }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> +// get size[0] from parent view +// CHECK-NEXT: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> +// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64 +// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64 +// compute size[0] bounded by parent view's size[0] // CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64 +// bound below by 0 +// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64 +// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64 +// compute stride[0] as parent view's stride[0] times the range step // CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 +// insert size and stride // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> @@ -131,21 +143,29 @@ func @subview(%arg0: !linalg.view) { return } // CHECK-LABEL: func @subview -// CHECK: llvm.constant(0 : index) : !llvm.i64 +// +// Subview lowers to range + slice op +// CHECK: llvm.alloca %{{.*}} x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*"> +// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }"> +// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }"> +// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*"> +// +// Select occurs in slice op lowering // CHECK: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64 -// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }"> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : 
!llvm<"{ i64, i64, i64 }"> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }"> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> +// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64 +// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64 +// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64 +// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64 +// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64 +// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 +// // CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64 -// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }"> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ i64, i64, i64 }"> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }"> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> +// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64 +// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64 +// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64 +// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64 +// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64 +// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 func @view_with_range_and_index(%arg0: !linalg.view) { %c0 = constant 0 : index