Implement lowering of VectorTypeCastOp to LLVM

A VectorTypeCastOp can only be used to lower between statically sized, contiguous memrefs of scalar type and of the matching vector type. The sizes and strides are therefore fully static and easy to determine. A relevant test is added.

This is a step towards solving tensorflow/mlir#189.

PiperOrigin-RevId: 275538981
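For illustration, the test added below exercises exactly this kind of cast: a statically shaped scalar memref is reinterpreted as a rank-1 memref of the matching vector type.

    func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<1xvector<8x8x8xf32>> {
      %0 = vector.type_cast %arg0: memref<8x8x8xf32>, memref<1xvector<8x8x8xf32>>
      return %0 : memref<1xvector<8x8x8xf32>>
    }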
parent 02b3ea6038
commit 2823b68580
@@ -82,6 +82,11 @@ public:
   Value *promoteOneMemRefDescriptor(Location loc, Value *operand,
                                     OpBuilder &builder);
 
+  static constexpr unsigned kPtrPosInMemRefDescriptor = 0;
+  static constexpr unsigned kOffsetPosInMemRefDescriptor = 1;
+  static constexpr unsigned kSizePosInMemRefDescriptor = 2;
+  static constexpr unsigned kStridePosInMemRefDescriptor = 3;
+
 protected:
   /// LLVM IR module used to parse/create types.
   llvm::Module *module;
@@ -156,10 +156,10 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
 //   int64_t sizes[Rank]; // omitted when rank == 0
 //   int64_t strides[Rank]; // omitted when rank == 0
 // };
-static unsigned kPtrPosInMemRefDescriptor = 0;
-static unsigned kOffsetPosInMemRefDescriptor = 1;
-static unsigned kSizePosInMemRefDescriptor = 2;
-static unsigned kStridePosInMemRefDescriptor = 3;
+constexpr unsigned LLVMTypeConverter::kPtrPosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kOffsetPosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kSizePosInMemRefDescriptor;
+constexpr unsigned LLVMTypeConverter::kStridePosInMemRefDescriptor;
 Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
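For reference, the four position constants index the fields of the memref descriptor struct sketched in the comment above. A rough C++ rendering of that layout, shown here for an illustrative rank-3 memref of float (a sketch only; the real type is an LLVM dialect struct built by convertMemRefType, and the element type varies per memref):

    #include <cstdint>

    // Sketch of the lowered memref descriptor for a rank-3 memref of float.
    struct MemRefDescriptor {
      float *ptr;          // field 0: kPtrPosInMemRefDescriptor
      int64_t offset;      // field 1: kOffsetPosInMemRefDescriptor
      int64_t sizes[3];    // field 2: kSizePosInMemRefDescriptor (omitted when rank == 0)
      int64_t strides[3];  // field 3: kStridePosInMemRefDescriptor (omitted when rank == 0)
    };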
@@ -282,7 +282,8 @@ public:
                                        Type elementTypePtr) {
     return builder.create<LLVM::ExtractValueOp>(
         loc, elementTypePtr, memref,
-        builder.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
+        builder.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
   }
 
 protected:
@@ -763,11 +764,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
 
     memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
         op->getLoc(), structType, memRefDescriptor, allocated,
-        rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
     memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
         op->getLoc(), structType, memRefDescriptor,
         createIndexConstant(rewriter, op->getLoc(), offset),
-        rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor));
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
 
     if (type.getRank() == 0)
       // No size/stride descriptor in memref, return the descriptor value.
@@ -798,10 +801,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
       int64_t index = indexedSize.index();
       memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
           op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
-          rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
       memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
           op->getLoc(), structType, memRefDescriptor, strideValues[index],
-          rewriter.getI64ArrayAttr({kStridePosInMemRefDescriptor, index}));
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
     }
 
     // Return the final value of the descriptor.
@@ -896,7 +901,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
     }
 
     auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
-    Type elementPtrType = type.getStructElementType(kPtrPosInMemRefDescriptor);
+    Type elementPtrType =
+        type.getStructElementType(LLVMTypeConverter::kPtrPosInMemRefDescriptor);
     Value *bufferPtr = extractMemRefElementPtr(
         rewriter, op->getLoc(), transformed.memref(), elementPtrType);
     Value *casted = rewriter.create<LLVM::BitcastOp>(
@@ -952,7 +958,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
     if (ShapedType::isDynamic(shape[index]))
       rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
           op, getIndexType(), transformed.memrefOrTensor(),
-          rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
     else
       // Use constant for static size.
      rewriter.replaceOp(
@@ -1015,7 +1022,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
         offset == MemRefType::getDynamicStrideOrOffset()
             ? rewriter.create<LLVM::ExtractValueOp>(
                   loc, indexTy, memRefDescriptor,
-                  rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor))
+                  rewriter.getIndexArrayAttr(
+                      LLVMTypeConverter::kOffsetPosInMemRefDescriptor))
             : this->createIndexConstant(rewriter, loc, offset);
     for (int i = 0, e = indices.size(); i < e; ++i) {
       Value *stride;
@@ -1028,7 +1036,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
         // Use dynamic stride.
         stride = rewriter.create<LLVM::ExtractValueOp>(
             loc, indexTy, memRefDescriptor,
-            rewriter.getIndexArrayAttr({kStridePosInMemRefDescriptor, i}));
+            rewriter.getIndexArrayAttr(
+                {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
       }
       Value *additionalOffset =
           rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
@@ -155,10 +155,112 @@ public:
   }
 };
 
+class VectorTypeCastOpConversion : public LLVMOpLowering {
+public:
+  explicit VectorTypeCastOpConversion(MLIRContext *context,
+                                      LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::VectorTypeCastOp::getOperationName(), context,
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    vector::VectorTypeCastOp castOp = cast<vector::VectorTypeCastOp>(op);
+    MemRefType sourceMemRefType =
+        castOp.getOperand()->getType().cast<MemRefType>();
+    MemRefType targetMemRefType =
+        castOp.getResult()->getType().cast<MemRefType>();
+
+    // Only static shape casts supported atm.
+    if (!sourceMemRefType.hasStaticShape() ||
+        !targetMemRefType.hasStaticShape())
+      return matchFailure();
+
+    Value *sourceMemRef = operands[0];
+    auto llvmSourceDescriptorTy =
+        sourceMemRef->getType().dyn_cast<LLVM::LLVMType>();
+    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
+      return matchFailure();
+
+    auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
+                                      .dyn_cast_or_null<LLVM::LLVMType>();
+    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
+      return matchFailure();
+
+    Type llvmSourceElementTy = llvmSourceDescriptorTy.getStructElementType(
+        LLVMTypeConverter::kPtrPosInMemRefDescriptor);
+    Type llvmTargetElementTy = llvmTargetDescriptorTy.getStructElementType(
+        LLVMTypeConverter::kPtrPosInMemRefDescriptor);
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides =
+        getStridesAndOffset(targetMemRefType, strides, offset);
+    bool isContiguous = (strides.back() == 1);
+    if (isContiguous) {
+      auto sizes = targetMemRefType.getShape();
+      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
+          isContiguous = false;
+          break;
+        }
+      }
+    }
+    // Only contiguous tensors supported atm.
+    if (failed(successStrides) || !isContiguous)
+      return matchFailure();
+
+    auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+
+    // Create descriptor.
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmTargetDescriptorTy);
+    // Set ptr.
+    Value *ptr = rewriter.create<LLVM::ExtractValueOp>(
+        loc, llvmSourceElementTy, sourceMemRef,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        op->getLoc(), llvmTargetDescriptorTy, desc, ptr,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+    // Fill offset 0.
+    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
+    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        op->getLoc(), llvmTargetDescriptorTy, desc, zero,
+        rewriter.getIndexArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    // Fill size and stride descriptors in memref.
+    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
+      int64_t index = indexedSize.index();
+      auto sizeAttr =
+          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
+      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), llvmTargetDescriptorTy, desc, size,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, index}));
+      auto strideAttr =
+          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
+      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          op->getLoc(), llvmTargetDescriptorTy, desc, stride,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, index}));
+    }
+
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>(
+  patterns.insert<ExtractElementOpConversion, OuterProductOpConversion,
+                  VectorTypeCastOpConversion>(
       converter.getDialect()->getContext(), converter);
 }
 
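As a concrete check of the contiguity condition above: for the memref<8x8x8xf32> used in the test, getStridesAndOffset produces offset 0 and row-major strides [64, 8, 1], so strides.back() == 1 holds and strides[0] == strides[1] * sizes[1] (64 == 8 * 8), and the pattern matches. A self-contained sketch of the same arithmetic (hypothetical standalone code mirroring the loop in matchAndRewrite, not part of the commit):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // Shape and row-major strides of memref<8x8x8xf32>.
      std::vector<int64_t> sizes = {8, 8, 8};
      std::vector<int64_t> strides = {64, 8, 1};
      bool isContiguous = (strides.back() == 1);
      // Same bound as the pattern: compare each stride against the
      // product of the next stride and the next size.
      for (int index = 0, e = strides.size() - 2; isContiguous && index < e; ++index)
        isContiguous = (strides[index] == strides[index + 1] * sizes[index + 1]);
      assert(isContiguous); // 64 == 8 * 8; the buffer is dense.
      return 0;
    }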
@@ -190,5 +292,5 @@ OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
 }
 
 static PassRegistration<LowerVectorToLLVMPass>
-    pass("vector-lower-to-llvm-dialect",
+    pass("convert-vector-to-llvm",
          "Lower the operations from the vector dialect into the LLVM dialect");
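With the rename, the pass is invoked as in the updated RUN line of the test below, e.g. (the input file name here is hypothetical):

    mlir-opt vector-to-llvm.mlir -convert-vector-to-llvm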
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s
 
 func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> {
   %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
@@ -47,3 +47,20 @@ func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 {
 // CHECK: llvm.mlir.constant(0 : i32) : !llvm.i32
 // CHECK: llvm.extractelement %{{.*}}, %{{.*}} : !llvm<"<16 x float>">
 // CHECK: llvm.return %{{.*}} : !llvm.float
+
+func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<1xvector<8x8x8xf32>> {
+  %0 = vector.type_cast %arg0: memref<8x8x8xf32>, memref<1xvector<8x8x8xf32>>
+  return %0 : memref<1xvector<8x8x8xf32>>
+}
+// CHECK-LABEL: vector_type_cast
+// CHECK: llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">
+// CHECK: %[[ptr:.*]] = llvm.extractvalue {{.*}}[0 : index] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: %[[bit:.*]] = llvm.bitcast %[[ptr]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*">
+// CHECK: llvm.insertvalue %[[bit]], {{.*}}[0 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">
+// CHECK: llvm.mlir.constant(0 : index
+// CHECK: llvm.insertvalue {{.*}}[1 : index] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index
+// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index
+// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">
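Note how the two descriptor types in the CHECK lines line up with the cast: the source memref<8x8x8xf32> lowers to !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">, while the rank-1 target memref<1xvector<8x8x8xf32>> lowers to !llvm<"{ [8 x [8 x <8 x float>]]*, i64, [1 x i64], [1 x i64] }">. Only the pointer is bitcast; the constants 0, 1, and 1 then fill the offset, the single size, and the single stride of the target descriptor.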