forked from OSchip/llvm-project
Avoid overflow when lowering linalg.slice
linalg.subview used to lower to a slice with a bounded range resulting in correct bounded accesses. However linalg.slice could still index out of bounds. This CL moves the bounding to linalg.slice. LLVM select and cmp ops gain a more idiomatic builder. PiperOrigin-RevId: 264897125
This commit is contained in:
parent
140b28ec12
commit
6f1d4bb8df
|
@ -164,6 +164,13 @@ def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
|
|||
let llvmBuilder = [{
|
||||
$res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
|
||||
}];
|
||||
let builders = [OpBuilder<
|
||||
"Builder *b, OperationState *result, ICmpPredicate predicate, Value *lhs, "
|
||||
"Value *rhs", [{
|
||||
LLVMDialect *dialect = &lhs->getType().cast<LLVMType>().getDialect();
|
||||
build(b, result, LLVMType::getInt1Ty(dialect),
|
||||
b->getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
|
||||
}]>];
|
||||
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
|
||||
let printer = [{ printICmpOp(p, *this); }];
|
||||
}
|
||||
|
@ -386,6 +393,11 @@ def LLVM_SelectOp
|
|||
LLVM_Type:$falseValue)>,
|
||||
LLVM_Builder<
|
||||
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
|
||||
let builders = [OpBuilder<
|
||||
"Builder *b, OperationState *result, Value *condition, Value *lhs, "
|
||||
"Value *rhs", [{
|
||||
build(b, result, lhs->getType(), condition, lhs, rhs);
|
||||
}]>];
|
||||
let parser = [{ return parseSelectOp(parser, result); }];
|
||||
let printer = [{ printSelectOp(p, *this); }];
|
||||
}
|
||||
|
@ -550,5 +562,4 @@ def LLVM_fmuladd : LLVM_Op<"intr.fmuladd", [NoSideEffect]>,
|
|||
}];
|
||||
}
|
||||
|
||||
|
||||
#endif // LLVMIR_OPS
|
||||
|
|
|
@ -415,7 +415,8 @@ public:
|
|||
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
|
||||
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
|
||||
/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
|
||||
/// and stride corresponding to the
|
||||
/// and stride corresponding to the region of memory within the bounds of
|
||||
/// the parent view.
|
||||
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
|
||||
/// The linalg.slice op is replaced by the alloca'ed pointer.
|
||||
class SliceOpConversion : public LLVMOpLowering {
|
||||
|
@ -446,6 +447,8 @@ public:
|
|||
auto ib = rewriter.getInsertionBlock();
|
||||
rewriter.setInsertionPointToStart(
|
||||
&op->getParentOfType<FuncOp>().getBlocks().front());
|
||||
Value *zero =
|
||||
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
||||
Value *one =
|
||||
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
|
||||
// Alloca with proper alignment.
|
||||
|
@ -470,12 +473,10 @@ public:
|
|||
Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
|
||||
for (int i = 0, e = viewType.getRank(); i < e; ++i) {
|
||||
Value *indexing = adaptor.indexings()[i];
|
||||
Value *min =
|
||||
sliceOp.indexing(i)->getType().isa<RangeType>()
|
||||
? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0)))
|
||||
: indexing;
|
||||
Value *product = mul(min, strides[i]);
|
||||
baseOffset = add(baseOffset, product);
|
||||
Value *min = indexing;
|
||||
if (sliceOp.indexing(i)->getType().isa<RangeType>())
|
||||
min = extractvalue(int64Ty, indexing, pos(0));
|
||||
baseOffset = add(baseOffset, mul(min, strides[i]));
|
||||
}
|
||||
desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
|
||||
|
||||
|
@ -485,13 +486,21 @@ public:
|
|||
for (auto en : llvm::enumerate(sliceOp.indexings())) {
|
||||
Value *indexing = en.value();
|
||||
if (indexing->getType().isa<RangeType>()) {
|
||||
int i = en.index();
|
||||
Value *rangeDescriptor = adaptor.indexings()[i];
|
||||
int rank = en.index();
|
||||
Value *rangeDescriptor = adaptor.indexings()[rank];
|
||||
Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
|
||||
Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
|
||||
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
|
||||
Value *baseSize =
|
||||
extractvalue(int64Ty, baseDesc, pos({kSizePosInView, rank}));
|
||||
// Bound upper by base view upper bound.
|
||||
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
|
||||
baseSize);
|
||||
Value *size = sub(max, min);
|
||||
Value *stride = mul(strides[i], step);
|
||||
// Bound lower by zero.
|
||||
size =
|
||||
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
|
||||
Value *stride = mul(strides[rank], step);
|
||||
desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
|
||||
desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
|
||||
++numNewDims;
|
||||
|
@ -703,16 +712,8 @@ static void lowerLinalgSubViewOps(FuncOp &f) {
|
|||
ScopedContext scope(b, op.getLoc());
|
||||
auto *view = op.getView();
|
||||
SmallVector<Value *, 8> ranges;
|
||||
for (auto en : llvm::enumerate(op.getRanges())) {
|
||||
using edsc::op::operator<;
|
||||
using linalg::intrinsics::dim;
|
||||
unsigned rank = en.index();
|
||||
auto sliceRange = en.value();
|
||||
auto size = dim(view, rank);
|
||||
ValueHandle ub(sliceRange.max);
|
||||
auto max = edsc::intrinsics::select(size < ub, size, ub);
|
||||
ranges.push_back(range(sliceRange.min, max, sliceRange.step));
|
||||
}
|
||||
for (auto sliceRange : op.getRanges())
|
||||
ranges.push_back(range(sliceRange.min, sliceRange.max, sliceRange.step));
|
||||
op.replaceAllUsesWith(slice(view, ranges));
|
||||
op.erase();
|
||||
});
|
||||
|
|
|
@ -96,18 +96,30 @@ func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
|
|||
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
|
||||
// 3rd load from slice_op
|
||||
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
|
||||
// insert data ptr
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// insert offset
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
||||
// get size[0] from parent view
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||
// compute size[0] bounded by parent view's size[0]
|
||||
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// bound below by 0
|
||||
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||
// compute stride[0] using bounded size
|
||||
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// insert size and stride
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
|
||||
|
@ -131,21 +143,29 @@ func @subview(%arg0: !linalg.view<?x?xf32>) {
|
|||
return
|
||||
}
|
||||
// CHECK-LABEL: func @subview
|
||||
// CHECK: llvm.constant(0 : index) : !llvm.i64
|
||||
//
|
||||
// Subview lowers to range + slice op
|
||||
// CHECK: llvm.alloca %{{.*}} x !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
|
||||
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">
|
||||
//
|
||||
// Select occurs in slice op lowering
|
||||
// CHECK: llvm.extractvalue %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||
//
|
||||
// CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||
// CHECK: llvm.undef : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.icmp "slt" %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
|
||||
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||
|
||||
func @view_with_range_and_index(%arg0: !linalg.view<?x?xf64>) {
|
||||
%c0 = constant 0 : index
|
||||
|
|
Loading…
Reference in New Issue