From f55ac5c07643efa28a5bb621b08c0e5dc2f97f84 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache <ntv@google.com> Date: Tue, 20 Aug 2019 01:59:58 -0700 Subject: [PATCH] Add support for LLVM lowering of binary ops on n-D vector types This CL allows binary operations on n-D vector types to be lowered to LLVMIR by performing an (n-1)-D extractvalue, 1-D vector operation and an (n-1)-D insertvalue. PiperOrigin-RevId: 264339118 --- .../include/mlir/Dialect/LLVMIR/LLVMDialect.h | 5 + mlir/include/mlir/IR/Builders.h | 1 + .../StandardToLLVM/ConvertStandardToLLVM.cpp | 151 ++++++++++++++---- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 5 + mlir/lib/IR/Builders.cpp | 9 ++ mlir/test/LLVMIR/convert-to-llvmir.mlir | 29 +++- 6 files changed, 171 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 7318c0066922..754fb48bb26f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -67,21 +67,26 @@ public: /// Array type utilities. LLVMType getArrayElementType(); unsigned getArrayNumElements(); + bool isArrayTy(); /// Vector type utilities. LLVMType getVectorElementType(); + bool isVectorTy(); /// Function type utilities. LLVMType getFunctionParamType(unsigned argIdx); unsigned getFunctionNumParams(); LLVMType getFunctionResultType(); + bool isFunctionTy(); /// Pointer type utilities. LLVMType getPointerTo(unsigned addrSpace = 0); LLVMType getPointerElementTy(); + bool isPointerTy(); /// Struct type utilities. LLVMType getStructElementType(unsigned i); + bool isStructTy(); /// Utilities used to generate floating point types. 
static LLVMType getDoubleTy(LLVMDialect *dialect); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 3e4815a5f32b..3697f5d50f55 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -137,6 +137,7 @@ public: ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values); ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values); ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values); + ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values); ArrayAttr getF32ArrayAttr(ArrayRef<float> values); ArrayAttr getF64ArrayAttr(ArrayRef<double> values); ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index e33da63f6b79..5e9c8787b673 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -346,58 +346,156 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { auto type = this->lowering.convertType(op->getResult(i)->getType()); results.push_back(rewriter.create<LLVM::ExtractValueOp>( op->getLoc(), type, newOp.getOperation()->getResult(0), - this->getIntegerArrayAttr(rewriter, i))); + rewriter.getIndexArrayAttr(i))); } rewriter.replaceOp(op, results); return this->matchSuccess(); } }; +// Express `linearIndex` in terms of coordinates of `basis`. +// Returns the empty vector when linearIndex is out of the range [0, P] where +// P is the product of all the basis coordinates. +// +// Prerequisites: +// Basis is an array of nonnegative integers (signed type inherited from +// vector shape type). 
+static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
+                                              unsigned linearIndex) {
+  SmallVector<int64_t, 4> res;
+  res.reserve(basis.size());
+  // Peel off coordinates innermost-first, keeping basis elements 64-bit.
+  for (int64_t basisElement : llvm::reverse(basis)) {
+    assert(basisElement > 0 && "expected strictly positive basis element");
+    res.push_back(linearIndex % basisElement);
+    linearIndex /= basisElement;
+  }
+  // A nonzero leftover means linearIndex >= product(basis): out of range.
+  if (linearIndex > 0)
+    return {};
+  std::reverse(res.begin(), res.end());
+  return res;
+}
+
+// Lowers a binary, single-result Standard op to the LLVM dialect `TargetOp`.
+// Operands of LLVM array type (the lowered form of n-D vectors) are unrolled
+// into one `TargetOp` per innermost vector.
+template <typename SourceOp, typename TargetOp>
+struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
+  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+  using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>;
+
+  // Emits a single TargetOp for non-array operand types; otherwise emits an
+  // extractvalue / TargetOp / insertvalue sequence per array position.
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    static_assert(
+        std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value,
+        "expected binary op");
+    static_assert(
+        std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
+        "expected single result op");
+    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+                                  SourceOp>::value,
+                  "expected same operands and result type");
+
+    auto loc = op->getLoc();
+    auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>();
+
+    if (!llvmArrayTy.isArrayTy()) {
+      auto newOp = rewriter.create<TargetOp>(
+          loc, operands[0]->getType(), operands, op->getAttrs());
+      rewriter.replaceOp(op, newOp.getResult());
+      return this->matchSuccess();
+    }
+
+    // Peel nested array types, recording each size, down to the element type.
+    auto llvmTy = llvmArrayTy;
+    SmallVector<int64_t, 4> arraySizes;
+    while (llvmTy.isArrayTy()) {
+      arraySizes.push_back(llvmTy.getArrayNumElements());
+      llvmTy = llvmTy.getArrayElementType();
+    }
+    assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type");
+    auto llvmVectorTy = llvmTy;
+
+    // Visit every position of the n-D array: each `linearIndex` in [0, ub) is
+    // delinearized into coordinates over `arraySizes`, where `ub` is the
+    // product of all array sizes. Every visited position yields one vector op
+    // whose result is inserted into `desc`, which starts out as undef.
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+    unsigned ub = 1;
+    for (auto s : arraySizes)
+      ub *= s;
+    for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
+      auto coords = getCoordinates(arraySizes, linearIndex);
+      // Defensive: an empty result means `linearIndex` is out of bounds.
+      if (coords.empty())
+        break;
+      auto position = rewriter.getIndexArrayAttr(coords);
+      // Extract both operand vectors at `position`, combine them with
+      // TargetOp and insert the result at the same `position` of `desc`.
+      Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmVectorTy, operands[0], position);
+      Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmVectorTy, operands[1], position);
+      Value *newVal = rewriter.create<TargetOp>(
+          loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS},
+          op->getAttrs());
+      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
+                                                  newVal, position);
+    }
+    rewriter.replaceOp(op, desc);
+    return this->matchSuccess();
+  }
+};
+
 // Specific lowerings.
 // FIXME: this should be tablegen'ed.
-struct AddIOpLowering : public OneToOneLLVMOpLowering<AddIOp, LLVM::AddOp> { +struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> { using Super::Super; }; -struct SubIOpLowering : public OneToOneLLVMOpLowering<SubIOp, LLVM::SubOp> { +struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> { using Super::Super; }; -struct MulIOpLowering : public OneToOneLLVMOpLowering<MulIOp, LLVM::MulOp> { +struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> { using Super::Super; }; -struct DivISOpLowering : public OneToOneLLVMOpLowering<DivISOp, LLVM::SDivOp> { +struct DivISOpLowering : public BinaryOpLLVMOpLowering<DivISOp, LLVM::SDivOp> { using Super::Super; }; -struct DivIUOpLowering : public OneToOneLLVMOpLowering<DivIUOp, LLVM::UDivOp> { +struct DivIUOpLowering : public BinaryOpLLVMOpLowering<DivIUOp, LLVM::UDivOp> { using Super::Super; }; -struct RemISOpLowering : public OneToOneLLVMOpLowering<RemISOp, LLVM::SRemOp> { +struct RemISOpLowering : public BinaryOpLLVMOpLowering<RemISOp, LLVM::SRemOp> { using Super::Super; }; -struct RemIUOpLowering : public OneToOneLLVMOpLowering<RemIUOp, LLVM::URemOp> { +struct RemIUOpLowering : public BinaryOpLLVMOpLowering<RemIUOp, LLVM::URemOp> { using Super::Super; }; -struct AndOpLowering : public OneToOneLLVMOpLowering<AndOp, LLVM::AndOp> { +struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> { using Super::Super; }; -struct OrOpLowering : public OneToOneLLVMOpLowering<OrOp, LLVM::OrOp> { +struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> { using Super::Super; }; -struct XOrOpLowering : public OneToOneLLVMOpLowering<XOrOp, LLVM::XOrOp> { +struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> { using Super::Super; }; -struct AddFOpLowering : public OneToOneLLVMOpLowering<AddFOp, LLVM::FAddOp> { +struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> { using Super::Super; }; -struct SubFOpLowering 
: public OneToOneLLVMOpLowering<SubFOp, LLVM::FSubOp> { +struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> { using Super::Super; }; -struct MulFOpLowering : public OneToOneLLVMOpLowering<MulFOp, LLVM::FMulOp> { +struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> { using Super::Super; }; -struct DivFOpLowering : public OneToOneLLVMOpLowering<DivFOp, LLVM::FDivOp> { +struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> { using Super::Super; }; -struct RemFOpLowering : public OneToOneLLVMOpLowering<RemFOp, LLVM::FRemOp> { +struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> { using Super::Super; }; struct SelectOpLowering @@ -516,14 +614,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), structType, memRefDescriptor, allocated, - getIntegerArrayAttr(rewriter, 0)); + rewriter.getIndexArrayAttr(0)); // Store dynamically allocated sizes in the descriptor. Dynamic sizes are // passed in as operands. for (auto indexedSize : llvm::enumerate(operands)) { memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), structType, memRefDescriptor, indexedSize.value(), - getIntegerArrayAttr(rewriter, 1 + indexedSize.index())); + rewriter.getIndexArrayAttr(1 + indexedSize.index())); } // Return the final value of the descriptor. @@ -553,7 +651,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { } auto type = transformed.memref()->getType().cast<LLVM::LLVMType>(); - auto hasStaticShape = type.getUnderlyingType()->isPointerTy(); + auto hasStaticShape = type.isPointerTy(); Type elementPtrType = hasStaticShape ? 
type : type.getStructElementType(0); Value *bufferPtr = extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(), @@ -603,7 +701,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> { // Otherwise target type is dynamic memref, so create a proper descriptor. newDescriptor = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), structType, newDescriptor, buffer, - getIntegerArrayAttr(rewriter, 0)); + rewriter.getIndexArrayAttr(0)); // Fill in the dynamic sizes of the new descriptor. If the size was // dynamic, copy it from the old descriptor. If the size was static, insert @@ -626,11 +724,11 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> { ? rewriter.create<LLVM::ExtractValueOp>( op->getLoc(), getIndexType(), transformed.source(), // NB: dynamic memref - getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++)) + rewriter.getIndexArrayAttr(sourceDynamicDimIdx++)) : createIndexConstant(rewriter, op->getLoc(), sourceSize); newDescriptor = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), structType, newDescriptor, size, - getIntegerArrayAttr(rewriter, targetDynamicDimIdx++)); + rewriter.getIndexArrayAttr(targetDynamicDimIdx++)); } assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() && "source dynamic dimensions were not processed"); @@ -673,7 +771,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> { } rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( op, getIndexType(), transformed.memrefOrTensor(), - getIntegerArrayAttr(rewriter, position)); + rewriter.getIndexArrayAttr(position)); } else { rewriter.replaceOp( op, createIndexConstant(rewriter, op->getLoc(), shape[index])); @@ -739,7 +837,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { if (s == -1) { Value *size = rewriter.create<LLVM::ExtractValueOp>( loc, this->getIndexType(), memRefDescriptor, - this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++)); + 
rewriter.getIndexArrayAttr(dynamicSizeIdx++)); sizes.push_back(size); } else { sizes.push_back(this->createIndexConstant(rewriter, loc, s)); @@ -751,8 +849,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes); Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>( - loc, elementTypePtr, memRefDescriptor, - this->getIntegerArrayAttr(rewriter, 0)); + loc, elementTypePtr, memRefDescriptor, rewriter.getIndexArrayAttr(0)); return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, ArrayRef<Value *>{dataPtr, subscript}, ArrayRef<NamedAttribute>{}); @@ -970,7 +1067,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), packedType, packed, operands[i], - getIntegerArrayAttr(rewriter, i)); + rewriter.getIndexArrayAttr(i)); } rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(), diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 906cf3443474..7a2d4f45211a 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1281,11 +1281,13 @@ LLVMType LLVMType::getArrayElementType() { unsigned LLVMType::getArrayNumElements() { return getUnderlyingType()->getArrayNumElements(); } +bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); } /// Vector type utilities. LLVMType LLVMType::getVectorElementType() { return get(getContext(), getUnderlyingType()->getVectorElementType()); } +bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); } /// Function type utilities. 
LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { @@ -1299,6 +1301,7 @@ LLVMType LLVMType::getFunctionResultType() { getContext(), llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType()); } +bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); } /// Pointer type utilities. LLVMType LLVMType::getPointerTo(unsigned addrSpace) { @@ -1310,11 +1313,13 @@ LLVMType LLVMType::getPointerTo(unsigned addrSpace) { LLVMType LLVMType::getPointerElementTy() { return get(getContext(), getUnderlyingType()->getPointerElementType()); } +bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); } /// Struct type utilities. LLVMType LLVMType::getStructElementType(unsigned i) { return get(getContext(), getUnderlyingType()->getStructElementType(i)); } +bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); } /// Utilities used to generate floating point types. LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 2ade7b9f28a4..067ff7af6443 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -218,6 +218,15 @@ ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) { return getArrayAttr(attrs); } +ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) { + auto attrs = functional::map( + [this](int64_t v) -> Attribute { + return getIntegerAttr(IndexType::get(getContext()), v); + }, + values); + return getArrayAttr(attrs); +} + ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) { auto attrs = functional::map( [this](float v) -> Attribute { return getF32FloatAttr(v); }, values); diff --git a/mlir/test/LLVMIR/convert-to-llvmir.mlir b/mlir/test/LLVMIR/convert-to-llvmir.mlir index 65818b0b02bc..e4c4c61ed732 100644 --- a/mlir/test/LLVMIR/convert-to-llvmir.mlir +++ b/mlir/test/LLVMIR/convert-to-llvmir.mlir @@ -510,6 +510,31 @@ func @fcmp(f32, f32) -> () { %12 = cmpf "ule", %arg0, %arg1 : f32 %13 = cmpf "une", 
%arg0, %arg1 : f32 %14 = cmpf "uno", %arg0, %arg1 : f32 - - return + + return +} + +// CHECK-LABEL: @vec_bin +func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> { + %0 = addf %arg0, %arg0 : vector<2x2x2xf32> + return %0 : vector<2x2x2xf32> + +// CHECK-NEXT: llvm.undef : !llvm<"[2 x [2 x <2 x float>]]"> + +// This block appears 2x2 times +// CHECK-NEXT: llvm.extractvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK-NEXT: llvm.extractvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK-NEXT: llvm.fadd %{{.*}} : !llvm<"<2 x float>"> +// CHECK-NEXT: llvm.insertvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]"> + +// We check the proper indexing of extract/insert in the remaining 3 positions. +// CHECK: llvm.extractvalue %{{.*}}[0 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK: llvm.insertvalue %{{.*}}[0 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK: llvm.extractvalue %{{.*}}[1 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK: llvm.insertvalue %{{.*}}[1 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK: llvm.extractvalue %{{.*}}[1 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]"> +// CHECK: llvm.insertvalue %{{.*}}[1 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]"> + +// And we're done +// CHECK-NEXT: return }