forked from OSchip/llvm-project
Add linalg.range_intersect conversion to LLVM.
This CL adds lowering for linalg.range_intersect into LLVM by computing:
  * new_min <- max(range1.min, range2.min)
  * new_max <- min(range1.max, range2.max)
  * new_step <- range1.step * range2.step

--
PiperOrigin-RevId: 248571810
This commit is contained in:
parent
a4317d1a59
commit
13dbad87f6
|
@ -60,6 +60,8 @@ using call = OperationBuilder<mlir::LLVM::CallOp>;
|
|||
// Short aliases for builder templates instantiated on LLVM dialect operations.
// `llvm_select` and `icmp` are the aliases used by the range_intersect
// lowering to build its compare-and-select sequences.
using gep = ValueBuilder<mlir::LLVM::GEPOp>;
|
||||
using llvm_load = ValueBuilder<LLVM::LoadOp>;
|
||||
using llvm_store = OperationBuilder<LLVM::StoreOp>;
|
||||
using llvm_select = ValueBuilder<LLVM::SelectOp>;
|
||||
using icmp = ValueBuilder<LLVM::ICmpOp>;
|
||||
|
||||
template <typename T>
|
||||
static llvm::Type *getPtrToElementType(T containerType,
|
||||
|
@ -145,8 +147,7 @@ static Type convertLinalgType(Type t, LLVMLowering &lowering) {
|
|||
|
||||
// Create an array attribute containing integer attributes with values provided
|
||||
// in `position`.
|
||||
static ArrayAttr makePositionAttr(FuncBuilder &builder,
|
||||
ArrayRef<int> position) {
|
||||
static ArrayAttr positionAttr(FuncBuilder &builder, ArrayRef<int> position) {
|
||||
SmallVector<Attribute, 4> attrs;
|
||||
attrs.reserve(position.size());
|
||||
for (auto p : position)
|
||||
|
@ -203,9 +204,9 @@ public:
|
|||
allocated = bitcast(elementPtrType, allocated);
|
||||
Value *desc = undef(bufferDescriptorType);
|
||||
desc = insertvalue(bufferDescriptorType, desc, allocated,
|
||||
makePositionAttr(rewriter, 0));
|
||||
positionAttr(rewriter, 0));
|
||||
desc = insertvalue(bufferDescriptorType, desc, size,
|
||||
makePositionAttr(rewriter, 1));
|
||||
positionAttr(rewriter, 1));
|
||||
return {desc};
|
||||
}
|
||||
};
|
||||
|
@ -239,9 +240,8 @@ public:
|
|||
|
||||
// Emit MLIR for buffer_dealloc.
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
Value *casted =
|
||||
bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
|
||||
makePositionAttr(rewriter, 0)));
|
||||
Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
|
||||
positionAttr(rewriter, 0)));
|
||||
call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
|
||||
|
||||
return {};
|
||||
|
@ -258,7 +258,7 @@ public:
|
|||
FuncBuilder &rewriter) const override {
|
||||
auto int64Ty = lowering.convertType(operands[0]->getType());
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
return {extractvalue(int64Ty, operands[0], makePositionAttr(rewriter, 1))};
|
||||
return {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -275,7 +275,7 @@ public:
|
|||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
return {extractvalue(
|
||||
indexTy, operands[0],
|
||||
makePositionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))};
|
||||
positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -299,7 +299,7 @@ public:
|
|||
getPtrToElementType(loadOp.getViewType(), lowering));
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
auto pos = [&rewriter](ArrayRef<int> values) {
|
||||
return makePositionAttr(rewriter, values);
|
||||
return positionAttr(rewriter, values);
|
||||
};
|
||||
|
||||
// Linearize subscripts as:
|
||||
|
@ -349,11 +349,57 @@ public:
|
|||
// Fill in an aggregate value of the descriptor.
|
||||
Value *desc = undef(rangeDescriptorTy);
|
||||
desc = insertvalue(rangeDescriptorTy, desc, operands[0],
|
||||
makePositionAttr(rewriter, 0));
|
||||
positionAttr(rewriter, 0));
|
||||
desc = insertvalue(rangeDescriptorTy, desc, operands[1],
|
||||
makePositionAttr(rewriter, 1));
|
||||
positionAttr(rewriter, 1));
|
||||
desc = insertvalue(rangeDescriptorTy, desc, operands[2],
|
||||
makePositionAttr(rewriter, 2));
|
||||
positionAttr(rewriter, 2));
|
||||
|
||||
return {desc};
|
||||
}
|
||||
};
|
||||
|
||||
// RangeIntersectOp creates a new range descriptor.
|
||||
class RangeIntersectOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
// Forwards the name returned by RangeIntersectOp::getOperationName(), together
// with `context` and `lowering_`, to the LLVMOpLowering base-class
// constructor. The constructor body itself is empty.
explicit RangeIntersectOpConversion(MLIRContext *context,
|
||||
LLVMLowering &lowering_)
|
||||
: LLVMOpLowering(RangeIntersectOp::getOperationName(), context,
|
||||
lowering_) {}
|
||||
|
||||
// Lowers `op` (cast to RangeIntersectOp) by building a fresh range descriptor
// from the two descriptors found in `operands[0]` and `operands[1]`:
//   field 0 <- signed max of the operands' field 0 (min1, min2),
//   field 1 <- signed min of the operands' field 1 (max1, max2),
//   field 2 <- product of the operands' field 2 (step1, step2).
// Returns a one-element vector holding the new descriptor value.
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
auto rangeIntersectOp = cast<RangeIntersectOp>(op);
|
||||
// Converted type of the op's result; used for undef and every insertvalue.
auto rangeDescriptorTy =
|
||||
convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering);
|
||||
// Converted 64-bit integer type (descriptor fields) and 1-bit integer type
// (result of the comparisons below).
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
auto int1Ty = lowering.convertType(rewriter.getIntegerType(1));
|
||||
|
||||
// NOTE(review): presumably makes `rewriter` and the op's location the
// implicit context for the builder helpers used below -- confirm against
// edsc::ScopedContext.
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
// Pull fields 0, 1 and 2 out of each operand descriptor.
auto min1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 0));
|
||||
auto min2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 0));
|
||||
auto max1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1));
|
||||
auto max2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 1));
|
||||
auto step1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 2));
|
||||
auto step2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 2));
|
||||
|
||||
// Comparison predicates, passed to icmp as 64-bit integer attributes.
// Fill in an aggregate value of the descriptor.
|
||||
auto SLE =
|
||||
rewriter.getI64IntegerAttr(static_cast<int64_t>(CmpIPredicate::SLE));
|
||||
auto SGE =
|
||||
rewriter.getI64IntegerAttr(static_cast<int64_t>(CmpIPredicate::SGE));
|
||||
// Start from an undefined descriptor and insert one field at a time.
Value *desc = undef(rangeDescriptorTy);
|
||||
// Field 0: pick min1 when min1 >= min2 (signed), otherwise min2.
desc = insertvalue(
|
||||
rangeDescriptorTy, desc,
|
||||
llvm_select(int64Ty, icmp(int1Ty, SGE, min1, min2), min1, min2),
|
||||
positionAttr(rewriter, 0));
|
||||
// Field 1: pick max1 when max1 <= max2 (signed), otherwise max2.
desc = insertvalue(
|
||||
rangeDescriptorTy, desc,
|
||||
llvm_select(int64Ty, icmp(int1Ty, SLE, max1, max2), max1, max2),
|
||||
positionAttr(rewriter, 1));
|
||||
// Field 2: product of the two steps.
// TODO(ntv): this assumes both steps are one for now. Enforce and extend.
|
||||
desc = insertvalue(rangeDescriptorTy, desc, mul(step1, step2),
|
||||
positionAttr(rewriter, 2));
|
||||
|
||||
return {desc};
|
||||
}
|
||||
|
@ -374,7 +420,7 @@ public:
|
|||
// Helper function to create an integer array attribute out of a list of
|
||||
// values.
|
||||
auto pos = [&rewriter](ArrayRef<int> values) {
|
||||
return makePositionAttr(rewriter, values);
|
||||
return positionAttr(rewriter, values);
|
||||
};
|
||||
// Helper function to obtain the ptr of the given `view`.
|
||||
auto getViewPtr = [pos, &rewriter, this](ViewType type,
|
||||
|
@ -471,7 +517,7 @@ public:
|
|||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
|
||||
auto pos = [&rewriter](ArrayRef<int> values) {
|
||||
return makePositionAttr(rewriter, values);
|
||||
return positionAttr(rewriter, values);
|
||||
};
|
||||
|
||||
// First operand to `view` is the buffer descriptor.
|
||||
|
@ -545,10 +591,10 @@ protected:
|
|||
return ConversionListBuilder<
|
||||
BufferAllocOpConversion, BufferDeallocOpConversion,
|
||||
BufferSizeOpConversion, DimOpConversion, DotOpConversion,
|
||||
LoadOpConversion, RangeOpConversion, SliceOpConversion,
|
||||
StoreOpConversion, ViewOpConversion>::build(&converterStorage,
|
||||
llvmDialect->getContext(),
|
||||
*this);
|
||||
LoadOpConversion, RangeOpConversion, RangeIntersectOpConversion,
|
||||
SliceOpConversion, StoreOpConversion,
|
||||
ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(),
|
||||
*this);
|
||||
}
|
||||
|
||||
Type convertAdditionalType(Type t) override {
|
||||
|
|
|
@ -80,3 +80,25 @@ func @dim(%arg0: !linalg.view<?x?xf32>) {
|
|||
}
|
||||
// CHECK-LABEL: func @dim(%arg0: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
|
||||
// CHECK: %0 = llvm.extractvalue %arg0[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||
|
||||
// Test input for the linalg.range_intersect lowering: applies the op to the
// two range arguments and returns its result.
func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.range {
|
||||
%0 = linalg.range_intersect %arg0, %arg1 : !linalg.range
|
||||
return %0 : !linalg.range
|
||||
}
|
||||
// CHECK-LABEL: func @range_intersect(%arg0: !llvm<"{ i64, i64, i64 }">, %arg1: !llvm<"{ i64, i64, i64 }">) -> !llvm<"{ i64, i64, i64 }"> {
|
||||
// CHECK: %0 = llvm.extractvalue %arg0[0] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %1 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %2 = llvm.extractvalue %arg0[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %3 = llvm.extractvalue %arg1[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %4 = llvm.extractvalue %arg0[2] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %5 = llvm.extractvalue %arg1[2] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %6 = llvm.undef : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %7 = llvm.icmp "sge" %0, %1 : !llvm.i64
|
||||
// CHECK: %8 = llvm.select %7, %0, %1 : !llvm.i1, !llvm.i64
|
||||
// CHECK: %9 = llvm.insertvalue %8, %6[0] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %10 = llvm.icmp "sle" %2, %3 : !llvm.i64
|
||||
// CHECK: %11 = llvm.select %10, %2, %3 : !llvm.i1, !llvm.i64
|
||||
// CHECK: %12 = llvm.insertvalue %11, %9[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %13 = llvm.mul %4, %5 : !llvm.i64
|
||||
// CHECK: %14 = llvm.insertvalue %13, %12[2] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: llvm.return %14 : !llvm<"{ i64, i64, i64 }">
|
||||
|
|
Loading…
Reference in New Issue