Add lowering of vector dialect to LLVM dialect.

This CL is step 3/n towards building a simple, programmable and portable vector abstraction in MLIR that can go all the way down to generating assembly vector code via LLVM's opt and llc tools.

This CL adds support for converting MLIR n-D vector types to (n-1)-D arrays of 1-D LLVM vectors and a conversion VectorToLLVM that lowers the `vector.extractelement` and `vector.outerproduct` operations to the proper mix of `llvm.shufflevector`, `llvm.extractvalue`, `llvm.insertvalue`, `llvm.extractelement` and `llvm.fmul`.

This has been independently verified to produce proper avx2 code.

Input:
```
func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
  %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32>
  return %3 : vector<8xf32>
}
```

Command:
```
mlir-opt vector-to-llvm.mlir -vector-lower-to-llvm-dialect --disable-pass-threading | mlir-opt -lower-to-cfg -lower-to-llvm | mlir-translate --mlir-to-llvmir | opt -O3 | llc -O3 -march=x86-64 -mcpu=haswell -mattr=fma,avx2
```

Output:
```
vec_1d:                                 # @vec_1d
# %bb.0:
        vbroadcastss    %xmm0, %ymm0
        vmulps  %ymm1, %ymm0, %ymm0
        retq
```
PiperOrigin-RevId: 262895929
This commit is contained in:
Nicolas Vasilache 2019-08-12 04:08:26 -07:00 committed by A. Unique TensorFlower
parent 5290e8c36d
commit 252ada4932
8 changed files with 302 additions and 14 deletions

View File

@ -39,11 +39,14 @@ object. For example, on x86-64 CPUs it converts to `!llvm.type<"i64">`.
### Vector Types
LLVM IR only supports *one-dimensional* vectors, unlike MLIR where vectors can
be multi-dimensional. MLIR vectors are converted to LLVM IR vectors of the same
size with element type converted using these conversion rules. Vector types
cannot be nested in either IR.
be multi-dimensional. Vector types cannot be nested in either IR. In the
one-dimensional case, MLIR vectors are converted to LLVM IR vectors of the same
size with element type converted using these conversion rules. In the
n-dimensional case, MLIR vectors are converted to (n-1)-dimensional array types
of one-dimensional vectors.
For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">`.
For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
`vector<4 x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
### Memref Types

View File

@ -0,0 +1,26 @@
//===- VectorToLLVM.h - Pass converting vector to LLVM dialect --*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
#define MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
namespace mlir {
// Forward declaration; only a pointer to the pass base class is needed here.
class ModulePassBase;
/// Creates the module pass that lowers operations from the vector dialect
/// into the LLVM dialect. Returns a newly allocated pass object.
ModulePassBase *createLowerVectorToLLVMPass();
} // namespace mlir
#endif // MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_

View File

@ -5,3 +5,4 @@ add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToSPIRV)
add_subdirectory(StandardToLLVM)
add_subdirectory(StandardToSPIRV)
add_subdirectory(VectorToLLVM)

View File

@ -145,18 +145,20 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
return LLVM::LLVMType::getStructTy(llvmDialect, types);
}
// Convert a 1D vector type to an LLVM vector type.
// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
// n > 1.
// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
// `vector<4 x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
if (type.getRank() != 1) {
auto *mlirContext = llvmDialect->getContext();
emitError(UnknownLoc::get(mlirContext), "only 1D vectors are supported");
auto elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
}
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
return elementType
? LLVM::LLVMType::getVectorTy(elementType, type.getShape().front())
: Type();
auto vectorType =
LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
return vectorType;
}
// Dispatch based on the actual type. Return null type on error.

View File

@ -0,0 +1,15 @@
# Library holding the conversion from the vector dialect to the LLVM dialect.
add_llvm_library(MLIRVectorToLLVM
VectorToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToLLVM
)
# Targets this library both depends on (build ordering) and links against.
set(LIBS
MLIRLLVMIR
MLIRTransforms
LLVMCore
LLVMSupport
)
add_dependencies(MLIRVectorToLLVM ${LIBS})
target_link_libraries(MLIRVectorToLLVM ${LIBS})

View File

@ -0,0 +1,207 @@
//===- VectorToLLVM.cpp - conversion from Vector to LLVM dialect ----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/Conversion/VectorToLLVM/VectorToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/VectorOps/VectorOps.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
/// Returns the LLVM pointer type whose pointee is the converted element type
/// of `containerType`. `T` is any type exposing `getElementType()`; the
/// element type is converted through `lowering` and must convert to an
/// `LLVM::LLVMType` (enforced by the cast).
template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
                                          LLVMTypeConverter &lowering) {
  auto convertedElementType =
      lowering.convertType(containerType.getElementType());
  auto llvmElementType =
      convertedElementType.template cast<LLVM::LLVMType>();
  return llvmElementType.getPointerTo();
}
// Build an ArrayAttr holding one 64-bit integer attribute per entry of
// `position`, in the same order.
static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
  SmallVector<Attribute, 4> indexAttrs;
  indexAttrs.reserve(position.size());
  for (unsigned idx = 0, end = position.size(); idx != end; ++idx)
    indexAttrs.push_back(builder.getI64IntegerAttr(position[idx]));
  return builder.getArrayAttr(indexAttrs);
}
/// Lowers `vector.extractelement` to the LLVM dialect.
///
/// Two cases are handled:
///  - The result is itself a vector: a single `llvm.extractvalue` on the
///    converted aggregate, using the op's full position attribute.
///  - The result is a scalar: when the position has more than one entry, an
///    `llvm.extractvalue` with all but the last entry first yields the
///    innermost 1-D vector; an `llvm.extractelement` with a constant index
///    built from the last entry then reads the scalar.
class ExtractElementOpConversion : public LLVMOpLowering {
public:
  explicit ExtractElementOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::ExtractElementOpOperandAdaptor adaptor(operands);
    auto extractOp = cast<vector::ExtractElementOp>(op);
    auto sourceVectorType = extractOp.vector()->getType().cast<VectorType>();
    auto resultType = extractOp.getResult()->getType();
    auto llvmResultType = lowering.convertType(resultType);
    auto fullPosition = extractOp.position();

    // Vector result: the sub-aggregate is obtained with one extractvalue.
    if (resultType.isa<VectorType>()) {
      auto extractValueOp = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), fullPosition);
      Value *subVector = extractValueOp.getResult();
      rewriter.replaceOp(op, subVector);
      return matchSuccess();
    }

    // Scalar result: first reach the innermost 1-D vector if the source is
    // nested in an array (position has more than one entry).
    Value *current = adaptor.vector();
    auto indices = fullPosition.getValue();
    if (indices.size() > 1) {
      auto innermostVectorType =
          VectorType::get(sourceVectorType.getShape().take_back(),
                          sourceVectorType.getElementType());
      auto leadingIndices =
          ArrayAttr::get(indices.drop_back(), op->getContext());
      auto extractValueOp = rewriter.create<LLVM::ExtractValueOp>(
          loc, lowering.convertType(innermostVectorType), current,
          leadingIndices);
      current = extractValueOp.getResult();
    }

    // The last position entry selects the element within the 1-D LLVM vector;
    // it is materialized as a constant of the converted index type.
    auto lastIndex = indices.back().cast<IntegerAttr>();
    auto llvmIndexType = lowering.convertType(rewriter.getIndexType());
    Value *indexValue =
        rewriter.create<LLVM::ConstantOp>(loc, llvmIndexType, lastIndex)
            .getResult();
    Value *scalar =
        rewriter.create<LLVM::ExtractElementOp>(loc, current, indexValue)
            .getResult();
    rewriter.replaceOp(op, scalar);
    return matchSuccess();
  }
};
/// Lowers `vector.outerproduct` to the LLVM dialect.
///
/// The result aggregate starts as `llvm.undef`. For every element `i` of the
/// converted lhs vector, the lowering emits:
///   1. an `llvm.shufflevector` of lhs with itself whose mask repeats `i`
///      once per rhs element (a broadcast of lhs[i]),
///   2. an `llvm.fmul` of that broadcast with rhs,
///   3. an `llvm.insertvalue` storing the product at position `i`.
class OuterProductOpConversion : public LLVMOpLowering {
public:
  explicit OuterProductOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::OuterProductOpOperandAdaptor adaptor(operands);
    auto *ctx = op->getContext();

    // Both operands are loop-invariant; fetch them once.
    Value *lhs = adaptor.lhs();
    Value *rhs = adaptor.rhs();
    auto lhsLLVMType = lhs->getType().cast<LLVM::LLVMType>();
    auto rhsLLVMType = rhs->getType().cast<LLVM::LLVMType>();
    auto lhsNumElements =
        lhsLLVMType.getUnderlyingType()->getVectorNumElements();
    auto rhsNumElements =
        rhsLLVMType.getUnderlyingType()->getVectorNumElements();

    auto llvmResultType = lowering.convertType(
        cast<vector::OuterProductOp>(op).getResult()->getType());
    Value *result =
        rewriter.create<LLVM::UndefOp>(loc, llvmResultType).getResult();

    for (unsigned row = 0, numRows = lhsNumElements; row < numRows; ++row) {
      // Mask entries are i32 attributes (shufflevector requires i32); every
      // lane selects element `row` of lhs.
      auto rowIndexAttr = rewriter.getI32IntegerAttr(row);
      SmallVector<Attribute, 4> maskEntries(rhsNumElements, rowIndexAttr);
      auto maskAttr = ArrayAttr::get(maskEntries, ctx);
      Value *broadcast =
          rewriter.create<LLVM::ShuffleVectorOp>(loc, lhs, lhs, maskAttr)
              .getResult();
      // vec(lhs[row]) * rhs
      Value *product =
          rewriter.create<LLVM::FMulOp>(loc, broadcast, rhs).getResult();
      // Store the row into the result aggregate at position `row`.
      auto insertOp = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, result, product, positionAttr(rewriter, row));
      result = insertOp.getResult();
    }

    rewriter.replaceOp(op, result);
    return matchSuccess();
  }
};
/// Appends to `patterns` the lowerings from the vector dialect to the LLVM
/// dialect: first the `vector.extractelement` pattern, then the
/// `vector.outerproduct` pattern. Each pattern is constructed with `ctx` and
/// `converter`.
static void
populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                       OwningRewritePatternList &patterns,
                                       MLIRContext *ctx) {
  patterns.insert<ExtractElementOpConversion>(ctx, converter);
  patterns.insert<OuterProductOpConversion>(ctx, converter);
}
namespace {
/// Module pass that lowers vector dialect operations (together with standard
/// operations) to the LLVM dialect.
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
  /// Pass entry point. Marked `override` so that any mismatch with the base
  /// class hook is diagnosed at compile time instead of silently declaring an
  /// unrelated member function.
  void runOnModule() override;
};
} // namespace
void LowerVectorToLLVMPass::runOnModule() {
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
populateVectorToLLVMConversionPatterns(converter, patterns, &getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
if (failed(
applyPartialConversion(getModule(), target, patterns, &converter))) {
signalPassFailure();
}
}
/// Factory for the vector-to-LLVM lowering pass: allocates a new
/// LowerVectorToLLVMPass and returns it through its base-class pointer.
ModulePassBase *mlir::createLowerVectorToLLVMPass() {
return new LowerVectorToLLVMPass();
}
// Static registration of the pass under the name
// `vector-lower-to-llvm-dialect`, with the description shown below.
static PassRegistration<LowerVectorToLLVMPass>
pass("vector-lower-to-llvm-dialect",
"Lower the operations from the vector dialect into the LLVM dialect");

View File

@ -0,0 +1,33 @@
// RUN: mlir-opt %s -vector-lower-to-llvm-dialect | FileCheck %s
// Outer product of a 4-vector and an 8-vector followed by extraction of row 0.
// Only the first unrolled row (index 0) of the outer product is matched: its
// broadcast shuffle, the multiplication and the insertion at position 0.
func @vec_1d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
  %3 = vector.extractelement %2[0 : i32]: vector<4x8xf32>
  return %3 : vector<8xf32>
}
// CHECK-LABEL: vec_1d
// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]">
// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
// CHECK: llvm.extractvalue {{.*}}[0 : i32] : !llvm<"[4 x <8 x float>]">
// CHECK: llvm.return {{.*}} : !llvm<"<8 x float>">
// Outer product of a 4-vector and an 8-vector returned as a 2-D vector.
// Only the first unrolled row (index 0) is matched: its broadcast shuffle,
// the multiplication and the insertion at position 0.
func @vec_2d(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<4x8xf32> {
  %2 = vector.outerproduct %arg0, %arg1 : vector<4xf32>, vector<8xf32>
  return %2 : vector<4x8xf32>
}
// CHECK-LABEL: vec_2d
// CHECK: llvm.undef : !llvm<"[4 x <8 x float>]">
// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<8 x float>">
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <8 x float>]">
// CHECK: llvm.return {{.*}} : !llvm<"[4 x <8 x float>]">
// Extraction of a 2-D sub-vector from a 3-D vector; the expected lowering is
// a single extractvalue on the nested array-of-vectors type.
func @vec_3d(%arg0: vector<4x8x16xf32>) -> vector<8x16xf32> {
%0 = vector.extractelement %arg0[0 : i32]: vector<4x8x16xf32>
return %0 : vector<8x16xf32>
}
// CHECK-LABEL: vec_3d
// CHECK: llvm.extractvalue %{{.*}}[0 : i32] : !llvm<"[4 x [8 x <16 x float>]]">
// CHECK: llvm.return %{{.*}} : !llvm<"[8 x <16 x float>]">

View File

@ -43,6 +43,7 @@ set(LIBS
MLIRTestTransforms
MLIRSupport
MLIRVectorOps
MLIRVectorToLLVM
)
if(MLIR_CUDA_CONVERSIONS_ENABLED)
list(APPEND LIBS