Add linalg.range_intersect conversion to LLVM.

This CL adds lowering for linalg.range_intersect into LLVM by computing:
      * new_min <- max (range1.min, range2.min)
      * new_max <- min (range1.max, range2.max)
      * new_step <- range1.step * range2.step

--

PiperOrigin-RevId: 248571810
This commit is contained in:
Nicolas Vasilache 2019-05-16 11:57:36 -07:00 committed by Mehdi Amini
parent a4317d1a59
commit 13dbad87f6
2 changed files with 87 additions and 19 deletions

View File

@ -60,6 +60,8 @@ using call = OperationBuilder<mlir::LLVM::CallOp>;
using gep = ValueBuilder<mlir::LLVM::GEPOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using icmp = ValueBuilder<LLVM::ICmpOp>;
template <typename T>
static llvm::Type *getPtrToElementType(T containerType,
@ -145,8 +147,7 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
// Create an array attribute containing integer attributes with values provided
// in `position`.
static ArrayAttr makePositionAttr(FuncBuilder &builder,
ArrayRef<int> position) {
static ArrayAttr positionAttr(FuncBuilder &builder, ArrayRef<int> position) {
SmallVector<Attribute, 4> attrs;
attrs.reserve(position.size());
for (auto p : position)
@ -203,9 +204,9 @@ public:
allocated = bitcast(elementPtrType, allocated);
Value *desc = undef(bufferDescriptorType);
desc = insertvalue(bufferDescriptorType, desc, allocated,
makePositionAttr(rewriter, 0));
positionAttr(rewriter, 0));
desc = insertvalue(bufferDescriptorType, desc, size,
makePositionAttr(rewriter, 1));
positionAttr(rewriter, 1));
return {desc};
}
};
@ -239,9 +240,8 @@ public:
// Emit MLIR for buffer_dealloc.
edsc::ScopedContext context(rewriter, op->getLoc());
Value *casted =
bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
makePositionAttr(rewriter, 0)));
Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
positionAttr(rewriter, 0)));
call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
return {};
@ -258,7 +258,7 @@ public:
FuncBuilder &rewriter) const override {
auto int64Ty = lowering.convertType(operands[0]->getType());
edsc::ScopedContext context(rewriter, op->getLoc());
return {extractvalue(int64Ty, operands[0], makePositionAttr(rewriter, 1))};
return {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))};
}
};
@ -275,7 +275,7 @@ public:
edsc::ScopedContext context(rewriter, op->getLoc());
return {extractvalue(
indexTy, operands[0],
makePositionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))};
positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))};
}
};
@ -299,7 +299,7 @@ public:
getPtrToElementType(loadOp.getViewType(), lowering));
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {
return makePositionAttr(rewriter, values);
return positionAttr(rewriter, values);
};
// Linearize subscripts as:
@ -349,11 +349,57 @@ public:
// Fill in an aggregate value of the descriptor.
Value *desc = undef(rangeDescriptorTy);
desc = insertvalue(rangeDescriptorTy, desc, operands[0],
makePositionAttr(rewriter, 0));
positionAttr(rewriter, 0));
desc = insertvalue(rangeDescriptorTy, desc, operands[1],
makePositionAttr(rewriter, 1));
positionAttr(rewriter, 1));
desc = insertvalue(rangeDescriptorTy, desc, operands[2],
makePositionAttr(rewriter, 2));
positionAttr(rewriter, 2));
return {desc};
}
};
// RangeIntersectOp creates a new range descriptor.
class RangeIntersectOpConversion : public LLVMOpLowering {
public:
explicit RangeIntersectOpConversion(MLIRContext *context,
LLVMLowering &lowering_)
: LLVMOpLowering(RangeIntersectOp::getOperationName(), context,
lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
auto rangeIntersectOp = cast<RangeIntersectOp>(op);
auto rangeDescriptorTy =
convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering);
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
auto int1Ty = lowering.convertType(rewriter.getIntegerType(1));
edsc::ScopedContext context(rewriter, op->getLoc());
auto min1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 0));
auto min2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 0));
auto max1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1));
auto max2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 1));
auto step1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 2));
auto step2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 2));
// Fill in an aggregate value of the descriptor.
auto SLE =
rewriter.getI64IntegerAttr(static_cast<int64_t>(CmpIPredicate::SLE));
auto SGE =
rewriter.getI64IntegerAttr(static_cast<int64_t>(CmpIPredicate::SGE));
Value *desc = undef(rangeDescriptorTy);
desc = insertvalue(
rangeDescriptorTy, desc,
llvm_select(int64Ty, icmp(int1Ty, SGE, min1, min2), min1, min2),
positionAttr(rewriter, 0));
desc = insertvalue(
rangeDescriptorTy, desc,
llvm_select(int64Ty, icmp(int1Ty, SLE, max1, max2), max1, max2),
positionAttr(rewriter, 1));
// TODO(ntv): this assumes both steps are one for now. Enforce and extend.
desc = insertvalue(rangeDescriptorTy, desc, mul(step1, step2),
positionAttr(rewriter, 2));
return {desc};
}
@ -374,7 +420,7 @@ public:
// Helper function to create an integer array attribute out of a list of
// values.
auto pos = [&rewriter](ArrayRef<int> values) {
return makePositionAttr(rewriter, values);
return positionAttr(rewriter, values);
};
// Helper function to obtain the ptr of the given `view`.
auto getViewPtr = [pos, &rewriter, this](ViewType type,
@ -471,7 +517,7 @@ public:
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
auto pos = [&rewriter](ArrayRef<int> values) {
return makePositionAttr(rewriter, values);
return positionAttr(rewriter, values);
};
// First operand to `view` is the buffer descriptor.
@ -545,10 +591,10 @@ protected:
return ConversionListBuilder<
BufferAllocOpConversion, BufferDeallocOpConversion,
BufferSizeOpConversion, DimOpConversion, DotOpConversion,
LoadOpConversion, RangeOpConversion, SliceOpConversion,
StoreOpConversion, ViewOpConversion>::build(&converterStorage,
llvmDialect->getContext(),
*this);
LoadOpConversion, RangeOpConversion, RangeIntersectOpConversion,
SliceOpConversion, StoreOpConversion,
ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(),
*this);
}
Type convertAdditionalType(Type t) override {

View File

@ -80,3 +80,25 @@ func @dim(%arg0: !linalg.view<?x?xf32>) {
}
// CHECK-LABEL: func @dim(%arg0: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
// CHECK: %0 = llvm.extractvalue %arg0[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.range {
%0 = linalg.range_intersect %arg0, %arg1 : !linalg.range
return %0 : !linalg.range
}
// CHECK-LABEL: func @range_intersect(%arg0: !llvm<"{ i64, i64, i64 }">, %arg1: !llvm<"{ i64, i64, i64 }">) -> !llvm<"{ i64, i64, i64 }"> {
// CHECK: %0 = llvm.extractvalue %arg0[0] : !llvm<"{ i64, i64, i64 }">
// CHECK: %1 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
// CHECK: %2 = llvm.extractvalue %arg0[1] : !llvm<"{ i64, i64, i64 }">
// CHECK: %3 = llvm.extractvalue %arg1[1] : !llvm<"{ i64, i64, i64 }">
// CHECK: %4 = llvm.extractvalue %arg0[2] : !llvm<"{ i64, i64, i64 }">
// CHECK: %5 = llvm.extractvalue %arg1[2] : !llvm<"{ i64, i64, i64 }">
// CHECK: %6 = llvm.undef : !llvm<"{ i64, i64, i64 }">
// CHECK: %7 = llvm.icmp "sge" %0, %1 : !llvm.i64
// CHECK: %8 = llvm.select %7, %0, %1 : !llvm.i1, !llvm.i64
// CHECK: %9 = llvm.insertvalue %8, %6[0] : !llvm<"{ i64, i64, i64 }">
// CHECK: %10 = llvm.icmp "sle" %2, %3 : !llvm.i64
// CHECK: %11 = llvm.select %10, %2, %3 : !llvm.i1, !llvm.i64
// CHECK: %12 = llvm.insertvalue %11, %9[1] : !llvm<"{ i64, i64, i64 }">
// CHECK: %13 = llvm.mul %4, %5 : !llvm.i64
// CHECK: %14 = llvm.insertvalue %13, %12[2] : !llvm<"{ i64, i64, i64 }">
// CHECK: llvm.return %14 : !llvm<"{ i64, i64, i64 }">