Avoid overflow when lowering linalg.slice

linalg.subview used to lower to a slice with a bounded range resulting in correct bounded accesses. However linalg.slice could still index out of bounds. This CL moves the bounding to linalg.slice.

LLVM select and cmp ops gain a more idiomatic builder.

PiperOrigin-RevId: 264897125
This commit is contained in:
Nicolas Vasilache 2019-08-22 12:46:30 -07:00 committed by A. Unique TensorFlower
parent 140b28ec12
commit 6f1d4bb8df
3 changed files with 66 additions and 34 deletions

View File

@ -164,6 +164,13 @@ def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
let llvmBuilder = [{ let llvmBuilder = [{
$res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}]; }];
let builders = [OpBuilder<
"Builder *b, OperationState *result, ICmpPredicate predicate, Value *lhs, "
"Value *rhs", [{
LLVMDialect *dialect = &lhs->getType().cast<LLVMType>().getDialect();
build(b, result, LLVMType::getInt1Ty(dialect),
b->getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
}]>];
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }]; let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
let printer = [{ printICmpOp(p, *this); }]; let printer = [{ printICmpOp(p, *this); }];
} }
@ -386,6 +393,11 @@ def LLVM_SelectOp
LLVM_Type:$falseValue)>, LLVM_Type:$falseValue)>,
LLVM_Builder< LLVM_Builder<
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
let builders = [OpBuilder<
"Builder *b, OperationState *result, Value *condition, Value *lhs, "
"Value *rhs", [{
build(b, result, lhs->getType(), condition, lhs, rhs);
}]>];
let parser = [{ return parseSelectOp(parser, result); }]; let parser = [{ return parseSelectOp(parser, result); }];
let printer = [{ printSelectOp(p, *this); }]; let printer = [{ printSelectOp(p, *this); }];
} }
@ -550,5 +562,4 @@ def LLVM_fmuladd : LLVM_Op<"intr.fmuladd", [NoSideEffect]>,
}]; }];
} }
#endif // LLVMIR_OPS #endif // LLVMIR_OPS

View File

@ -415,7 +415,8 @@ public:
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor. /// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
/// 2. A load of the ViewDescriptor from the pointer allocated in 1. /// 2. A load of the ViewDescriptor from the pointer allocated in 1.
/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
/// and stride corresponding to the /// and stride corresponding to the region of memory within the bounds of
/// the parent view.
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The linalg.slice op is replaced by the alloca'ed pointer. /// The linalg.slice op is replaced by the alloca'ed pointer.
class SliceOpConversion : public LLVMOpLowering { class SliceOpConversion : public LLVMOpLowering {
@ -446,6 +447,8 @@ public:
auto ib = rewriter.getInsertionBlock(); auto ib = rewriter.getInsertionBlock();
rewriter.setInsertionPointToStart( rewriter.setInsertionPointToStart(
&op->getParentOfType<FuncOp>().getBlocks().front()); &op->getParentOfType<FuncOp>().getBlocks().front());
Value *zero =
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
Value *one = Value *one =
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
// Alloca with proper alignment. // Alloca with proper alignment.
@ -470,12 +473,10 @@ public:
Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView)); Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
for (int i = 0, e = viewType.getRank(); i < e; ++i) { for (int i = 0, e = viewType.getRank(); i < e; ++i) {
Value *indexing = adaptor.indexings()[i]; Value *indexing = adaptor.indexings()[i];
Value *min = Value *min = indexing;
sliceOp.indexing(i)->getType().isa<RangeType>() if (sliceOp.indexing(i)->getType().isa<RangeType>())
? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0))) min = extractvalue(int64Ty, indexing, pos(0));
: indexing; baseOffset = add(baseOffset, mul(min, strides[i]));
Value *product = mul(min, strides[i]);
baseOffset = add(baseOffset, product);
} }
desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView)); desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
@ -485,13 +486,21 @@ public:
for (auto en : llvm::enumerate(sliceOp.indexings())) { for (auto en : llvm::enumerate(sliceOp.indexings())) {
Value *indexing = en.value(); Value *indexing = en.value();
if (indexing->getType().isa<RangeType>()) { if (indexing->getType().isa<RangeType>()) {
int i = en.index(); int rank = en.index();
Value *rangeDescriptor = adaptor.indexings()[i]; Value *rangeDescriptor = adaptor.indexings()[rank];
Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0)); Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1)); Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2)); Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
Value *baseSize =
extractvalue(int64Ty, baseDesc, pos({kSizePosInView, rank}));
// Bound upper by base view upper bound.
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
baseSize);
Value *size = sub(max, min); Value *size = sub(max, min);
Value *stride = mul(strides[i], step); // Bound lower by zero.
size =
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
Value *stride = mul(strides[rank], step);
desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims})); desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims})); desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
++numNewDims; ++numNewDims;
@ -703,16 +712,8 @@ static void lowerLinalgSubViewOps(FuncOp &f) {
ScopedContext scope(b, op.getLoc()); ScopedContext scope(b, op.getLoc());
auto *view = op.getView(); auto *view = op.getView();
SmallVector<Value *, 8> ranges; SmallVector<Value *, 8> ranges;
for (auto en : llvm::enumerate(op.getRanges())) { for (auto sliceRange : op.getRanges())
using edsc::op::operator<; ranges.push_back(range(sliceRange.min, sliceRange.max, sliceRange.step));
using linalg::intrinsics::dim;
unsigned rank = en.index();
auto sliceRange = en.value();
auto size = dim(view, rank);
ValueHandle ub(sliceRange.max);
auto max = edsc::intrinsics::select(size < ub, size, ub);
ranges.push_back(range(sliceRange.min, max, sliceRange.step));
}
op.replaceAllUsesWith(slice(view, ranges)); op.replaceAllUsesWith(slice(view, ranges));
op.erase(); op.erase();
}); });

View File

@ -96,18 +96,30 @@ func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*"> // CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
// 3rd load from slice_op // 3rd load from slice_op
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*"> // CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
// insert data ptr
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 // CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64 // CHECK-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
// insert offset
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
// get size[0] from parent view
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
// compute size[0] bounded by parent view's size[0]
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64 // CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
// bound below by 0
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
// compute stride[0] using bounded size
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 // CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// insert size and stride
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
@ -131,21 +143,29 @@ func @subview(%arg0: !linalg.view<?x?xf32>) {
return return
} }
// CHECK-LABEL: func @subview // CHECK-LABEL: func @subview
// CHECK: llvm.constant(0 : index) : !llvm.i64 //
// Subview lowers to range + slice op
// CHECK: llvm.alloca %{{.*}} x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }">
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }">
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
//
// Select occurs in slice op lowering
// CHECK: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> // CHECK: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64 // CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64 // CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
//
// CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> // CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64 // CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64 // CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
func @view_with_range_and_index(%arg0: !linalg.view<?x?xf64>) { func @view_with_range_and_index(%arg0: !linalg.view<?x?xf64>) {
%c0 = constant 0 : index %c0 = constant 0 : index