diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 7318c0066922..754fb48bb26f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -67,21 +67,26 @@ public: /// Array type utilities. LLVMType getArrayElementType(); unsigned getArrayNumElements(); + bool isArrayTy(); /// Vector type utilities. LLVMType getVectorElementType(); + bool isVectorTy(); /// Function type utilities. LLVMType getFunctionParamType(unsigned argIdx); unsigned getFunctionNumParams(); LLVMType getFunctionResultType(); + bool isFunctionTy(); /// Pointer type utilities. LLVMType getPointerTo(unsigned addrSpace = 0); LLVMType getPointerElementTy(); + bool isPointerTy(); /// Struct type utilities. LLVMType getStructElementType(unsigned i); + bool isStructTy(); /// Utilities used to generate floating point types. static LLVMType getDoubleTy(LLVMDialect *dialect); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 3e4815a5f32b..3697f5d50f55 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -137,6 +137,7 @@ public: ArrayAttr getAffineMapArrayAttr(ArrayRef values); ArrayAttr getI32ArrayAttr(ArrayRef values); ArrayAttr getI64ArrayAttr(ArrayRef values); + ArrayAttr getIndexArrayAttr(ArrayRef values); ArrayAttr getF32ArrayAttr(ArrayRef values); ArrayAttr getF64ArrayAttr(ArrayRef values); ArrayAttr getStrArrayAttr(ArrayRef values); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index e33da63f6b79..5e9c8787b673 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -346,58 +346,156 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { auto type = this->lowering.convertType(op->getResult(i)->getType()); results.push_back(rewriter.create( op->getLoc(), type, newOp.getOperation()->getResult(0), - this->getIntegerArrayAttr(rewriter, i))); + rewriter.getIndexArrayAttr(i))); } rewriter.replaceOp(op, results); return this->matchSuccess(); } }; +// Express `linearIndex` in terms of coordinates of `basis`. +// Returns the empty vector when linearIndex is out of the range [0, P] where +// P is the product of all the basis coordinates. +// +// Prerequisites: +// Basis is an array of nonnegative integers (signed type inherited from +// vector shape type). +static SmallVector getCoordinates(ArrayRef basis, + unsigned linearIndex) { + SmallVector res; + res.reserve(basis.size()); + for (unsigned basisElement : llvm::reverse(basis)) { + res.push_back(linearIndex % basisElement); + linearIndex = linearIndex / basisElement; + } + if (linearIndex > 0) + return {}; + std::reverse(res.begin(), res.end()); + return res; +} + +// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect +// Ops for binary ops with one result. This supports higher-dimensional vector +// types. +template +struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern { + using LLVMLegalizationPattern::LLVMLegalizationPattern; + using Super = BinaryOpLLVMOpLowering; + + // Convert the type of the result to an LLVM type, pass operands as is, + // preserve attributes. + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + static_assert( + std::is_base_of::Impl, SourceOp>::value, + "expected binary op"); + static_assert( + std::is_base_of, SourceOp>::value, + "expected single result op"); + static_assert(std::is_base_of, + SourceOp>::value, + "expected single result op"); + + auto loc = op->getLoc(); + auto llvmArrayTy = operands[0]->getType().cast(); + + if (!llvmArrayTy.isArrayTy()) { + auto newOp = rewriter.create( + op->getLoc(), operands[0]->getType(), operands, op->getAttrs()); + rewriter.replaceOp(op, newOp.getResult()); + return this->matchSuccess(); + } + + // Unroll iterated array type until we hit a non-array type. + auto llvmTy = llvmArrayTy; + SmallVector arraySizes; + while (llvmTy.isArrayTy()) { + arraySizes.push_back(llvmTy.getArrayNumElements()); + llvmTy = llvmTy.getArrayElementType(); + } + assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type"); + auto llvmVectorTy = llvmTy; + + // Iteratively extract a position coordinates with basis `arraySize` from a + // `linearIndex` that is incremented at each step. This terminates when + // `linearIndex` exceeds the range specified by `arraySize`. + // This has the effect of fully unrolling the dimensions of the n-D array + // type, getting to the underlying vector element. + Value *desc = rewriter.create(loc, llvmArrayTy); + unsigned ub = 1; + for (auto s : arraySizes) + ub *= s; + for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { + auto coords = getCoordinates(arraySizes, linearIndex); + // Linear index is out of bounds, we are done. + if (coords.empty()) + break; + + auto position = rewriter.getIndexArrayAttr(coords); + + // For this unrolled `position` corresponding to the `linearIndex`^th + // element, extract operand vectors + Value *extractedLHS = rewriter.create( + loc, llvmVectorTy, operands[0], position); + Value *extractedRHS = rewriter.create( + loc, llvmVectorTy, operands[1], position); + Value *newVal = rewriter.create( + loc, llvmVectorTy, ArrayRef{extractedLHS, extractedRHS}, + op->getAttrs()); + desc = rewriter.create(loc, llvmArrayTy, desc, + newVal, position); + } + rewriter.replaceOp(op, desc); + return this->matchSuccess(); + } +}; + // Specific lowerings. // FIXME: this should be tablegen'ed. -struct AddIOpLowering : public OneToOneLLVMOpLowering { +struct AddIOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct SubIOpLowering : public OneToOneLLVMOpLowering { +struct SubIOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct MulIOpLowering : public OneToOneLLVMOpLowering { +struct MulIOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct DivISOpLowering : public OneToOneLLVMOpLowering { +struct DivISOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct DivIUOpLowering : public OneToOneLLVMOpLowering { +struct DivIUOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct RemISOpLowering : public OneToOneLLVMOpLowering { +struct RemISOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct RemIUOpLowering : public OneToOneLLVMOpLowering { +struct RemIUOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct AndOpLowering : public OneToOneLLVMOpLowering { +struct AndOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct OrOpLowering : public OneToOneLLVMOpLowering { +struct OrOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct XOrOpLowering : public OneToOneLLVMOpLowering { +struct XOrOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct AddFOpLowering : public OneToOneLLVMOpLowering { +struct AddFOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct SubFOpLowering : public OneToOneLLVMOpLowering { +struct SubFOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct MulFOpLowering : public OneToOneLLVMOpLowering { +struct MulFOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct DivFOpLowering : public OneToOneLLVMOpLowering { +struct DivFOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; -struct RemFOpLowering : public OneToOneLLVMOpLowering { +struct RemFOpLowering : public BinaryOpLLVMOpLowering { using Super::Super; }; struct SelectOpLowering @@ -516,14 +614,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern { memRefDescriptor = rewriter.create( op->getLoc(), structType, memRefDescriptor, allocated, - getIntegerArrayAttr(rewriter, 0)); + rewriter.getIndexArrayAttr(0)); // Store dynamically allocated sizes in the descriptor. Dynamic sizes are // passed in as operands. for (auto indexedSize : llvm::enumerate(operands)) { memRefDescriptor = rewriter.create( op->getLoc(), structType, memRefDescriptor, indexedSize.value(), - getIntegerArrayAttr(rewriter, 1 + indexedSize.index())); + rewriter.getIndexArrayAttr(1 + indexedSize.index())); } // Return the final value of the descriptor. @@ -553,7 +651,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { } auto type = transformed.memref()->getType().cast(); - auto hasStaticShape = type.getUnderlyingType()->isPointerTy(); + auto hasStaticShape = type.isPointerTy(); Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0); Value *bufferPtr = extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(), @@ -603,7 +701,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { // Otherwise target type is dynamic memref, so create a proper descriptor. newDescriptor = rewriter.create( op->getLoc(), structType, newDescriptor, buffer, - getIntegerArrayAttr(rewriter, 0)); + rewriter.getIndexArrayAttr(0)); // Fill in the dynamic sizes of the new descriptor. If the size was // dynamic, copy it from the old descriptor. If the size was static, insert @@ -626,11 +724,11 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { ? rewriter.create( op->getLoc(), getIndexType(), transformed.source(), // NB: dynamic memref - getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++)) + rewriter.getIndexArrayAttr(sourceDynamicDimIdx++)) : createIndexConstant(rewriter, op->getLoc(), sourceSize); newDescriptor = rewriter.create( op->getLoc(), structType, newDescriptor, size, - getIntegerArrayAttr(rewriter, targetDynamicDimIdx++)); + rewriter.getIndexArrayAttr(targetDynamicDimIdx++)); } assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() && "source dynamic dimensions were not processed"); @@ -673,7 +771,7 @@ struct DimOpLowering : public LLVMLegalizationPattern { } rewriter.replaceOpWithNewOp( op, getIndexType(), transformed.memrefOrTensor(), - getIntegerArrayAttr(rewriter, position)); + rewriter.getIndexArrayAttr(position)); } else { rewriter.replaceOp( op, createIndexConstant(rewriter, op->getLoc(), shape[index])); @@ -739,7 +837,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { if (s == -1) { Value *size = rewriter.create( loc, this->getIndexType(), memRefDescriptor, - this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++)); + rewriter.getIndexArrayAttr(dynamicSizeIdx++)); sizes.push_back(size); } else { sizes.push_back(this->createIndexConstant(rewriter, loc, s)); @@ -751,8 +849,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes); Value *dataPtr = rewriter.create( - loc, elementTypePtr, memRefDescriptor, - this->getIntegerArrayAttr(rewriter, 0)); + loc, elementTypePtr, memRefDescriptor, rewriter.getIndexArrayAttr(0)); return rewriter.create(loc, elementTypePtr, ArrayRef{dataPtr, subscript}, ArrayRef{}); @@ -970,7 +1067,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern { for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( op->getLoc(), packedType, packed, operands[i], - getIntegerArrayAttr(rewriter, i)); + rewriter.getIndexArrayAttr(i)); } rewriter.replaceOpWithNewOp( op, llvm::makeArrayRef(packed), llvm::ArrayRef(), diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 906cf3443474..7a2d4f45211a 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1281,11 +1281,13 @@ LLVMType LLVMType::getArrayElementType() { unsigned LLVMType::getArrayNumElements() { return getUnderlyingType()->getArrayNumElements(); } +bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); } /// Vector type utilities. LLVMType LLVMType::getVectorElementType() { return get(getContext(), getUnderlyingType()->getVectorElementType()); } +bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); } /// Function type utilities. LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { @@ -1299,6 +1301,7 @@ LLVMType LLVMType::getFunctionResultType() { getContext(), llvm::cast(getUnderlyingType())->getReturnType()); } +bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); } /// Pointer type utilities. LLVMType LLVMType::getPointerTo(unsigned addrSpace) { @@ -1310,11 +1313,13 @@ LLVMType LLVMType::getPointerTo(unsigned addrSpace) { LLVMType LLVMType::getPointerElementTy() { return get(getContext(), getUnderlyingType()->getPointerElementType()); } +bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); } /// Struct type utilities. LLVMType LLVMType::getStructElementType(unsigned i) { return get(getContext(), getUnderlyingType()->getStructElementType(i)); } +bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); } /// Utilities used to generate floating point types. LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 2ade7b9f28a4..067ff7af6443 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -218,6 +218,15 @@ ArrayAttr Builder::getI64ArrayAttr(ArrayRef values) { return getArrayAttr(attrs); } +ArrayAttr Builder::getIndexArrayAttr(ArrayRef values) { + auto attrs = functional::map( + [this](int64_t v) -> Attribute { + return getIntegerAttr(IndexType::get(getContext()), v); + }, + values); + return getArrayAttr(attrs); +} + ArrayAttr Builder::getF32ArrayAttr(ArrayRef values) { auto attrs = functional::map( [this](float v) -> Attribute { return getF32FloatAttr(v); }, values); diff --git a/mlir/test/LLVMIR/convert-to-llvmir.mlir b/mlir/test/LLVMIR/convert-to-llvmir.mlir index 65818b0b02bc..e4c4c61ed732 100644 --- a/mlir/test/LLVMIR/convert-to-llvmir.mlir +++ b/mlir/test/LLVMIR/convert-to-llvmir.mlir @@ -510,6 +510,31 @@ func @fcmp(f32, f32) -> () { %12 = cmpf "ule", %arg0, %arg1 : f32 %13 = cmpf "une", %arg0, %arg1 : f32 %14 = cmpf "uno", %arg0, %arg1 : f32 - - return + + return +} + +// CHECK-LABEL: @vec_bin +func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> { + %0 = addf %arg0, %arg0 : vector<2x2x2xf32> + return %0 : vector<2x2x2xf32> + +// CHECK-NEXT: llvm.undef : !llvm<"[2 x [2 x <2 x float>]]"> + +// This block appears 2x2 times +// CHECK-NEXT: llvm.extractvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK-NEXT: llvm.extractvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK-NEXT: llvm.fadd %{{.*}} : !llvm<"<2 x float>"> +// CHECK-NEXT: llvm.insertvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]"> + +// We check the proper indexing of extract/insert in the remaining 3 positions. +// CHECK: llvm.extractvalue %{{.*}}[0 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK: llvm.insertvalue %{{.*}}[0 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK: llvm.extractvalue %{{.*}}[1 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK: llvm.insertvalue %{{.*}}[1 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK: llvm.extractvalue %{{.*}}[1 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK: llvm.insertvalue %{{.*}}[1 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]"> + +// And we're done +// CHECK-NEXT: return }