From f55ac5c07643efa28a5bb621b08c0e5dc2f97f84 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <ntv@google.com>
Date: Tue, 20 Aug 2019 01:59:58 -0700
Subject: [PATCH] Add support for LLVM lowering of binary ops on n-D vector
 types

This CL allows binary operations on n-D vector types to be lowered to LLVMIR by performing an (n-1)-D extractvalue, 1-D vector operation and an (n-1)-D insertvalue.

PiperOrigin-RevId: 264339118
---
 .../include/mlir/Dialect/LLVMIR/LLVMDialect.h |   5 +
 mlir/include/mlir/IR/Builders.h               |   1 +
 .../StandardToLLVM/ConvertStandardToLLVM.cpp  | 151 ++++++++++++++----
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    |   5 +
 mlir/lib/IR/Builders.cpp                      |   9 ++
 mlir/test/LLVMIR/convert-to-llvmir.mlir       |  29 +++-
 6 files changed, 171 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 7318c0066922..754fb48bb26f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -67,21 +67,26 @@ public:
   /// Array type utilities.
   LLVMType getArrayElementType();
   unsigned getArrayNumElements();
+  bool isArrayTy();
 
   /// Vector type utilities.
   LLVMType getVectorElementType();
+  bool isVectorTy();
 
   /// Function type utilities.
   LLVMType getFunctionParamType(unsigned argIdx);
   unsigned getFunctionNumParams();
   LLVMType getFunctionResultType();
+  bool isFunctionTy();
 
   /// Pointer type utilities.
   LLVMType getPointerTo(unsigned addrSpace = 0);
   LLVMType getPointerElementTy();
+  bool isPointerTy();
 
   /// Struct type utilities.
   LLVMType getStructElementType(unsigned i);
+  bool isStructTy();
 
   /// Utilities used to generate floating point types.
   static LLVMType getDoubleTy(LLVMDialect *dialect);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3e4815a5f32b..3697f5d50f55 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -137,6 +137,7 @@ public:
   ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
   ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
   ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
+  ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values);
   ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
   ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
   ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index e33da63f6b79..5e9c8787b673 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -346,58 +346,156 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
       auto type = this->lowering.convertType(op->getResult(i)->getType());
       results.push_back(rewriter.create<LLVM::ExtractValueOp>(
           op->getLoc(), type, newOp.getOperation()->getResult(0),
-          this->getIntegerArrayAttr(rewriter, i)));
+          rewriter.getIndexArrayAttr(i)));
     }
     rewriter.replaceOp(op, results);
     return this->matchSuccess();
   }
 };
 
+// Express `linearIndex` in terms of coordinates of `basis`.
+// Returns the empty vector when linearIndex is out of the range [0, P] where
+// P is the product of all the basis coordinates.
+//
+// Prerequisites:
+//   Basis is an array of nonnegative integers (signed type inherited from
+//   vector shape type).
+static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
+                                              unsigned linearIndex) {
+  SmallVector<int64_t, 4> res;
+  res.reserve(basis.size());
+  for (unsigned basisElement : llvm::reverse(basis)) {
+    res.push_back(linearIndex % basisElement);
+    linearIndex = linearIndex / basisElement;
+  }
+  if (linearIndex > 0)
+    return {};
+  std::reverse(res.begin(), res.end());
+  return res;
+}
+
+// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
+// Ops for binary ops with one result. This supports higher-dimensional vector
+// types.
+template <typename SourceOp, typename TargetOp>
+struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
+  using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
+  using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>;
+
+  // Convert the type of the result to an LLVM type, pass operands as is,
+  // preserve attributes.
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    static_assert(
+        std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value,
+        "expected binary op");
+    static_assert(
+        std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
+        "expected single result op");
+    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+                                  SourceOp>::value,
+                  "expected single result op");
+
+    auto loc = op->getLoc();
+    auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>();
+
+    if (!llvmArrayTy.isArrayTy()) {
+      auto newOp = rewriter.create<TargetOp>(
+          op->getLoc(), operands[0]->getType(), operands, op->getAttrs());
+      rewriter.replaceOp(op, newOp.getResult());
+      return this->matchSuccess();
+    }
+
+    // Unroll iterated array type until we hit a non-array type.
+    auto llvmTy = llvmArrayTy;
+    SmallVector<int64_t, 4> arraySizes;
+    while (llvmTy.isArrayTy()) {
+      arraySizes.push_back(llvmTy.getArrayNumElements());
+      llvmTy = llvmTy.getArrayElementType();
+    }
+    assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type");
+    auto llvmVectorTy = llvmTy;
+
+    // Iteratively extract a position coordinates with basis `arraySize` from a
+    // `linearIndex` that is incremented at each step. This terminates when
+    // `linearIndex` exceeds the range specified by `arraySize`.
+    // This has the effect of fully unrolling the dimensions of the n-D array
+    // type, getting to the underlying vector element.
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+    unsigned ub = 1;
+    for (auto s : arraySizes)
+      ub *= s;
+    for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
+      auto coords = getCoordinates(arraySizes, linearIndex);
+      // Linear index is out of bounds, we are done.
+      if (coords.empty())
+        break;
+
+      auto position = rewriter.getIndexArrayAttr(coords);
+
+      // For this unrolled `position` corresponding to the `linearIndex`^th
+      // element, extract operand vectors
+      Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmVectorTy, operands[0], position);
+      Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmVectorTy, operands[1], position);
+      Value *newVal = rewriter.create<TargetOp>(
+          loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS},
+          op->getAttrs());
+      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
+                                                  newVal, position);
+    }
+    rewriter.replaceOp(op, desc);
+    return this->matchSuccess();
+  }
+};
+
 // Specific lowerings.
 // FIXME: this should be tablegen'ed.
-struct AddIOpLowering : public OneToOneLLVMOpLowering<AddIOp, LLVM::AddOp> {
+struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
   using Super::Super;
 };
-struct SubIOpLowering : public OneToOneLLVMOpLowering<SubIOp, LLVM::SubOp> {
+struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
   using Super::Super;
 };
-struct MulIOpLowering : public OneToOneLLVMOpLowering<MulIOp, LLVM::MulOp> {
+struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
   using Super::Super;
 };
-struct DivISOpLowering : public OneToOneLLVMOpLowering<DivISOp, LLVM::SDivOp> {
+struct DivISOpLowering : public BinaryOpLLVMOpLowering<DivISOp, LLVM::SDivOp> {
   using Super::Super;
 };
-struct DivIUOpLowering : public OneToOneLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
+struct DivIUOpLowering : public BinaryOpLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
   using Super::Super;
 };
-struct RemISOpLowering : public OneToOneLLVMOpLowering<RemISOp, LLVM::SRemOp> {
+struct RemISOpLowering : public BinaryOpLLVMOpLowering<RemISOp, LLVM::SRemOp> {
   using Super::Super;
 };
-struct RemIUOpLowering : public OneToOneLLVMOpLowering<RemIUOp, LLVM::URemOp> {
+struct RemIUOpLowering : public BinaryOpLLVMOpLowering<RemIUOp, LLVM::URemOp> {
   using Super::Super;
 };
-struct AndOpLowering : public OneToOneLLVMOpLowering<AndOp, LLVM::AndOp> {
+struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
   using Super::Super;
 };
-struct OrOpLowering : public OneToOneLLVMOpLowering<OrOp, LLVM::OrOp> {
+struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
   using Super::Super;
 };
-struct XOrOpLowering : public OneToOneLLVMOpLowering<XOrOp, LLVM::XOrOp> {
+struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
   using Super::Super;
 };
-struct AddFOpLowering : public OneToOneLLVMOpLowering<AddFOp, LLVM::FAddOp> {
+struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
   using Super::Super;
 };
-struct SubFOpLowering : public OneToOneLLVMOpLowering<SubFOp, LLVM::FSubOp> {
+struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
   using Super::Super;
 };
-struct MulFOpLowering : public OneToOneLLVMOpLowering<MulFOp, LLVM::FMulOp> {
+struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
   using Super::Super;
 };
-struct DivFOpLowering : public OneToOneLLVMOpLowering<DivFOp, LLVM::FDivOp> {
+struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
   using Super::Super;
 };
-struct RemFOpLowering : public OneToOneLLVMOpLowering<RemFOp, LLVM::FRemOp> {
+struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
   using Super::Super;
 };
 struct SelectOpLowering
@@ -516,14 +614,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
 
     memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
         op->getLoc(), structType, memRefDescriptor, allocated,
-        getIntegerArrayAttr(rewriter, 0));
+        rewriter.getIndexArrayAttr(0));
 
     // Store dynamically allocated sizes in the descriptor.  Dynamic sizes are
     // passed in as operands.
     for (auto indexedSize : llvm::enumerate(operands)) {
       memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
           op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
-          getIntegerArrayAttr(rewriter, 1 + indexedSize.index()));
+          rewriter.getIndexArrayAttr(1 + indexedSize.index()));
     }
 
     // Return the final value of the descriptor.
@@ -553,7 +651,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
     }
 
     auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
-    auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
+    auto hasStaticShape = type.isPointerTy();
     Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
     Value *bufferPtr =
         extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
@@ -603,7 +701,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
     // Otherwise target type is dynamic memref, so create a proper descriptor.
     newDescriptor = rewriter.create<LLVM::InsertValueOp>(
         op->getLoc(), structType, newDescriptor, buffer,
-        getIntegerArrayAttr(rewriter, 0));
+        rewriter.getIndexArrayAttr(0));
 
     // Fill in the dynamic sizes of the new descriptor.  If the size was
     // dynamic, copy it from the old descriptor.  If the size was static, insert
@@ -626,11 +724,11 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
               ? rewriter.create<LLVM::ExtractValueOp>(
                     op->getLoc(), getIndexType(),
                     transformed.source(), // NB: dynamic memref
-                    getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
+                    rewriter.getIndexArrayAttr(sourceDynamicDimIdx++))
               : createIndexConstant(rewriter, op->getLoc(), sourceSize);
       newDescriptor = rewriter.create<LLVM::InsertValueOp>(
           op->getLoc(), structType, newDescriptor, size,
-          getIntegerArrayAttr(rewriter, targetDynamicDimIdx++));
+          rewriter.getIndexArrayAttr(targetDynamicDimIdx++));
     }
     assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
            "source dynamic dimensions were not processed");
@@ -673,7 +771,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
       }
       rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
           op, getIndexType(), transformed.memrefOrTensor(),
-          getIntegerArrayAttr(rewriter, position));
+          rewriter.getIndexArrayAttr(position));
     } else {
       rewriter.replaceOp(
           op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
@@ -739,7 +837,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
       if (s == -1) {
         Value *size = rewriter.create<LLVM::ExtractValueOp>(
             loc, this->getIndexType(), memRefDescriptor,
-            this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++));
+            rewriter.getIndexArrayAttr(dynamicSizeIdx++));
         sizes.push_back(size);
       } else {
         sizes.push_back(this->createIndexConstant(rewriter, loc, s));
@@ -751,8 +849,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
     Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
 
     Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
-        loc, elementTypePtr, memRefDescriptor,
-        this->getIntegerArrayAttr(rewriter, 0));
+        loc, elementTypePtr, memRefDescriptor, rewriter.getIndexArrayAttr(0));
     return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
                                         ArrayRef<Value *>{dataPtr, subscript},
                                         ArrayRef<NamedAttribute>{});
@@ -970,7 +1067,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
     for (unsigned i = 0; i < numArguments; ++i) {
       packed = rewriter.create<LLVM::InsertValueOp>(
           op->getLoc(), packedType, packed, operands[i],
-          getIntegerArrayAttr(rewriter, i));
+          rewriter.getIndexArrayAttr(i));
     }
     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
         op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 906cf3443474..7a2d4f45211a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1281,11 +1281,13 @@ LLVMType LLVMType::getArrayElementType() {
 unsigned LLVMType::getArrayNumElements() {
   return getUnderlyingType()->getArrayNumElements();
 }
+bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); }
 
 /// Vector type utilities.
 LLVMType LLVMType::getVectorElementType() {
   return get(getContext(), getUnderlyingType()->getVectorElementType());
 }
+bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
 
 /// Function type utilities.
 LLVMType LLVMType::getFunctionParamType(unsigned argIdx) {
@@ -1299,6 +1301,7 @@ LLVMType LLVMType::getFunctionResultType() {
       getContext(),
       llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType());
 }
+bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); }
 
 /// Pointer type utilities.
 LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
@@ -1310,11 +1313,13 @@ LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
 LLVMType LLVMType::getPointerElementTy() {
   return get(getContext(), getUnderlyingType()->getPointerElementType());
 }
+bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); }
 
 /// Struct type utilities.
 LLVMType LLVMType::getStructElementType(unsigned i) {
   return get(getContext(), getUnderlyingType()->getStructElementType(i));
 }
+bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); }
 
 /// Utilities used to generate floating point types.
 LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2ade7b9f28a4..067ff7af6443 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -218,6 +218,15 @@ ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
   return getArrayAttr(attrs);
 }
 
+ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
+  auto attrs = functional::map(
+      [this](int64_t v) -> Attribute {
+        return getIntegerAttr(IndexType::get(getContext()), v);
+      },
+      values);
+  return getArrayAttr(attrs);
+}
+
 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
   auto attrs = functional::map(
       [this](float v) -> Attribute { return getF32FloatAttr(v); }, values);
diff --git a/mlir/test/LLVMIR/convert-to-llvmir.mlir b/mlir/test/LLVMIR/convert-to-llvmir.mlir
index 65818b0b02bc..e4c4c61ed732 100644
--- a/mlir/test/LLVMIR/convert-to-llvmir.mlir
+++ b/mlir/test/LLVMIR/convert-to-llvmir.mlir
@@ -510,6 +510,31 @@ func @fcmp(f32, f32) -> () {
   %12 = cmpf "ule", %arg0, %arg1 : f32
   %13 = cmpf "une", %arg0, %arg1 : f32
   %14 = cmpf "uno", %arg0, %arg1 : f32
-  
-  return 
+
+  return
+}
+
+// CHECK-LABEL: @vec_bin
+func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+  %0 = addf %arg0, %arg0 : vector<2x2x2xf32>
+  return %0 : vector<2x2x2xf32>
+
+//  CHECK-NEXT: llvm.undef : !llvm<"[2 x [2 x <2 x float>]]">
+
+// This block appears 2x2 times
+//  CHECK-NEXT: llvm.extractvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]">
+//  CHECK-NEXT: llvm.extractvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]">
+//  CHECK-NEXT: llvm.fadd %{{.*}} : !llvm<"<2 x float>">
+//  CHECK-NEXT: llvm.insertvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]">
+
+// We check the proper indexing of extract/insert in the remaining 3 positions.
+//       CHECK: llvm.extractvalue %{{.*}}[0 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]">
+//       CHECK: llvm.insertvalue %{{.*}}[0 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]">
+//       CHECK: llvm.extractvalue %{{.*}}[1 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]">
+//       CHECK: llvm.insertvalue %{{.*}}[1 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]">
+//       CHECK: llvm.extractvalue %{{.*}}[1 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]">
+//       CHECK: llvm.insertvalue %{{.*}}[1 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]">
+
+// And we're done
+//   CHECK-NEXT: return
 }