forked from OSchip/llvm-project
Avoid overflow when lowering linalg.slice
linalg.subview used to lower to a slice with a bounded range, resulting in correctly bounded accesses. However, linalg.slice itself could still index out of bounds. This CL moves the bounding to linalg.slice. The LLVM select and icmp ops gain more idiomatic builders. PiperOrigin-RevId: 264897125
This commit is contained in:
parent
140b28ec12
commit
6f1d4bb8df
|
@ -164,6 +164,13 @@ def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
|
||||||
let llvmBuilder = [{
|
let llvmBuilder = [{
|
||||||
$res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
|
$res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
|
||||||
}];
|
}];
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"Builder *b, OperationState *result, ICmpPredicate predicate, Value *lhs, "
|
||||||
|
"Value *rhs", [{
|
||||||
|
LLVMDialect *dialect = &lhs->getType().cast<LLVMType>().getDialect();
|
||||||
|
build(b, result, LLVMType::getInt1Ty(dialect),
|
||||||
|
b->getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
|
||||||
|
}]>];
|
||||||
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
|
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
|
||||||
let printer = [{ printICmpOp(p, *this); }];
|
let printer = [{ printICmpOp(p, *this); }];
|
||||||
}
|
}
|
||||||
|
@ -386,6 +393,11 @@ def LLVM_SelectOp
|
||||||
LLVM_Type:$falseValue)>,
|
LLVM_Type:$falseValue)>,
|
||||||
LLVM_Builder<
|
LLVM_Builder<
|
||||||
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
|
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"Builder *b, OperationState *result, Value *condition, Value *lhs, "
|
||||||
|
"Value *rhs", [{
|
||||||
|
build(b, result, lhs->getType(), condition, lhs, rhs);
|
||||||
|
}]>];
|
||||||
let parser = [{ return parseSelectOp(parser, result); }];
|
let parser = [{ return parseSelectOp(parser, result); }];
|
||||||
let printer = [{ printSelectOp(p, *this); }];
|
let printer = [{ printSelectOp(p, *this); }];
|
||||||
}
|
}
|
||||||
|
@ -550,5 +562,4 @@ def LLVM_fmuladd : LLVM_Op<"intr.fmuladd", [NoSideEffect]>,
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif // LLVMIR_OPS
|
#endif // LLVMIR_OPS
|
||||||
|
|
|
@ -415,7 +415,8 @@ public:
|
||||||
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
|
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
|
||||||
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
|
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
|
||||||
/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
|
/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
|
||||||
/// and stride corresponding to the
|
/// and stride corresponding to the region of memory within the bounds of
|
||||||
|
/// the parent view.
|
||||||
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
|
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
|
||||||
/// The linalg.slice op is replaced by the alloca'ed pointer.
|
/// The linalg.slice op is replaced by the alloca'ed pointer.
|
||||||
class SliceOpConversion : public LLVMOpLowering {
|
class SliceOpConversion : public LLVMOpLowering {
|
||||||
|
@ -446,6 +447,8 @@ public:
|
||||||
auto ib = rewriter.getInsertionBlock();
|
auto ib = rewriter.getInsertionBlock();
|
||||||
rewriter.setInsertionPointToStart(
|
rewriter.setInsertionPointToStart(
|
||||||
&op->getParentOfType<FuncOp>().getBlocks().front());
|
&op->getParentOfType<FuncOp>().getBlocks().front());
|
||||||
|
Value *zero =
|
||||||
|
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
||||||
Value *one =
|
Value *one =
|
||||||
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
|
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
|
||||||
// Alloca with proper alignment.
|
// Alloca with proper alignment.
|
||||||
|
@ -470,12 +473,10 @@ public:
|
||||||
Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
|
Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
|
||||||
for (int i = 0, e = viewType.getRank(); i < e; ++i) {
|
for (int i = 0, e = viewType.getRank(); i < e; ++i) {
|
||||||
Value *indexing = adaptor.indexings()[i];
|
Value *indexing = adaptor.indexings()[i];
|
||||||
Value *min =
|
Value *min = indexing;
|
||||||
sliceOp.indexing(i)->getType().isa<RangeType>()
|
if (sliceOp.indexing(i)->getType().isa<RangeType>())
|
||||||
? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0)))
|
min = extractvalue(int64Ty, indexing, pos(0));
|
||||||
: indexing;
|
baseOffset = add(baseOffset, mul(min, strides[i]));
|
||||||
Value *product = mul(min, strides[i]);
|
|
||||||
baseOffset = add(baseOffset, product);
|
|
||||||
}
|
}
|
||||||
desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
|
desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
|
||||||
|
|
||||||
|
@ -485,13 +486,21 @@ public:
|
||||||
for (auto en : llvm::enumerate(sliceOp.indexings())) {
|
for (auto en : llvm::enumerate(sliceOp.indexings())) {
|
||||||
Value *indexing = en.value();
|
Value *indexing = en.value();
|
||||||
if (indexing->getType().isa<RangeType>()) {
|
if (indexing->getType().isa<RangeType>()) {
|
||||||
int i = en.index();
|
int rank = en.index();
|
||||||
Value *rangeDescriptor = adaptor.indexings()[i];
|
Value *rangeDescriptor = adaptor.indexings()[rank];
|
||||||
Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
|
Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
|
||||||
Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
|
Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
|
||||||
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
|
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
|
||||||
|
Value *baseSize =
|
||||||
|
extractvalue(int64Ty, baseDesc, pos({kSizePosInView, rank}));
|
||||||
|
// Bound upper by base view upper bound.
|
||||||
|
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
|
||||||
|
baseSize);
|
||||||
Value *size = sub(max, min);
|
Value *size = sub(max, min);
|
||||||
Value *stride = mul(strides[i], step);
|
// Bound lower by zero.
|
||||||
|
size =
|
||||||
|
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
|
||||||
|
Value *stride = mul(strides[rank], step);
|
||||||
desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
|
desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
|
||||||
desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
|
desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
|
||||||
++numNewDims;
|
++numNewDims;
|
||||||
|
@ -703,16 +712,8 @@ static void lowerLinalgSubViewOps(FuncOp &f) {
|
||||||
ScopedContext scope(b, op.getLoc());
|
ScopedContext scope(b, op.getLoc());
|
||||||
auto *view = op.getView();
|
auto *view = op.getView();
|
||||||
SmallVector<Value *, 8> ranges;
|
SmallVector<Value *, 8> ranges;
|
||||||
for (auto en : llvm::enumerate(op.getRanges())) {
|
for (auto sliceRange : op.getRanges())
|
||||||
using edsc::op::operator<;
|
ranges.push_back(range(sliceRange.min, sliceRange.max, sliceRange.step));
|
||||||
using linalg::intrinsics::dim;
|
|
||||||
unsigned rank = en.index();
|
|
||||||
auto sliceRange = en.value();
|
|
||||||
auto size = dim(view, rank);
|
|
||||||
ValueHandle ub(sliceRange.max);
|
|
||||||
auto max = edsc::intrinsics::select(size < ub, size, ub);
|
|
||||||
ranges.push_back(range(sliceRange.min, max, sliceRange.step));
|
|
||||||
}
|
|
||||||
op.replaceAllUsesWith(slice(view, ranges));
|
op.replaceAllUsesWith(slice(view, ranges));
|
||||||
op.erase();
|
op.erase();
|
||||||
});
|
});
|
||||||
|
|
|
@ -96,18 +96,30 @@ func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
|
||||||
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
|
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
|
||||||
// 3rd load from slice_op
|
// 3rd load from slice_op
|
||||||
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
|
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
|
||||||
|
// insert data ptr
|
||||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
// CHECK-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
||||||
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
// CHECK-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
|
// CHECK-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
|
// insert offset
|
||||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
||||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
|
||||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
||||||
|
// get size[0] from parent view
|
||||||
|
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||||
|
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
|
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||||
|
// compute size[0] bounded by parent view's size[0]
|
||||||
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
|
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
|
// bound below by 0
|
||||||
|
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
|
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||||
|
// compute stride[0] using bounded size
|
||||||
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
|
// insert size and stride
|
||||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||||
|
|
||||||
|
@ -131,21 +143,29 @@ func @subview(%arg0: !linalg.view<?x?xf32>) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @subview
|
// CHECK-LABEL: func @subview
|
||||||
// CHECK: llvm.constant(0 : index) : !llvm.i64
|
//
|
||||||
|
// Subview lowers to range + slice op
|
||||||
|
// CHECK: llvm.alloca %{{.*}} x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
|
||||||
|
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }">
|
||||||
|
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }">
|
||||||
|
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
|
||||||
|
//
|
||||||
|
// Select occurs in slice op lowering
|
||||||
// CHECK: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
// CHECK: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||||
// CHECK: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||||
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
|
//
|
||||||
// CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
// CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||||
// CHECK: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||||
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||||
|
|
||||||
func @view_with_range_and_index(%arg0: !linalg.view<?x?xf64>) {
|
func @view_with_range_and_index(%arg0: !linalg.view<?x?xf64>) {
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
|
|
Loading…
Reference in New Issue