forked from OSchip/llvm-project
[mlir][Vector] Mostly-NFC - Restructure options for lowering to LLVM Matrix Intrinsics
Summary: This revision restructures the calling of vector transforms to make it more flexible to ask for lowering through LLVM matrix intrinsics. This also makes sure we bail out in degenerate cases (i.e. 1) in which LLVM complains about not being able to scalarize. Differential Revision: https://reviews.llvm.org/D76266
This commit is contained in:
parent
7ca473a27b
commit
2fae7878d5
|
@ -26,7 +26,7 @@ void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
|||
OwningRewritePatternList &patterns);
|
||||
|
||||
/// Create a pass to convert vector operations to the LLVMIR dialect.
|
||||
OpPassBase<ModuleOp> *createLowerVectorToLLVMPass();
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> createConvertVectorToLLVMPass();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -24,6 +24,13 @@ class MLIRContext;
|
|||
class OwningRewritePatternList;
|
||||
namespace vector {
|
||||
|
||||
/// Structure to control the behavior of vector transform patterns.
|
||||
struct VectorTransformsOptions {
|
||||
/// Let vector.contract lower to vector.matrix_multiply and LLVM matrix
|
||||
/// intrinsics.
|
||||
bool lowerToLLVMMatrixIntrinsics = false;
|
||||
};
|
||||
|
||||
/// Collect a set of vector-to-vector canonicalization patterns.
|
||||
void populateVectorToVectorCanonicalizationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context);
|
||||
|
@ -50,8 +57,9 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
|
|||
/// OuterproductOpLowering
|
||||
/// These transformation express higher level vector ops in terms of more
|
||||
/// elementary extraction, insertion, reduction, product, and broadcast ops.
|
||||
void populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns,
|
||||
MLIRContext *context);
|
||||
void populateVectorContractLoweringPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context,
|
||||
VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
|
||||
|
||||
/// Returns the integer type required for subscripts in the vector dialect.
|
||||
IntegerType getVectorSubscriptType(Builder &builder);
|
||||
|
|
|
@ -562,6 +562,7 @@ void ConvertLinalgToLLVMPass::runOnModule() {
|
|||
populateLoopToStdConversionPatterns(patterns, &getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false,
|
||||
/*emitCWrappers=*/true);
|
||||
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
||||
populateVectorToLLVMConversionPatterns(converter, patterns);
|
||||
populateLinalgToStandardConversionPatterns(patterns, &getContext());
|
||||
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
|
||||
|
|
|
@ -1150,8 +1150,8 @@ void LowerVectorToLLVMPass::runOnModule() {
|
|||
}
|
||||
}
|
||||
|
||||
OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
|
||||
return new LowerVectorToLLVMPass();
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertVectorToLLVMPass() {
|
||||
return std::make_unique<LowerVectorToLLVMPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<LowerVectorToLLVMPass>
|
||||
|
|
|
@ -42,13 +42,6 @@ using namespace mlir;
|
|||
using llvm::dbgs;
|
||||
using mlir::functional::zipMap;
|
||||
|
||||
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
|
||||
|
||||
static llvm::cl::opt<bool> lowerToLLVMMatrixIntrinsics(
|
||||
"vector-lower-matrix-intrinsics",
|
||||
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
|
||||
llvm::cl::init(false), llvm::cl::cat(clOptionsCategory));
|
||||
|
||||
/// Given a shape with sizes greater than 0 along all dimensions,
|
||||
/// returns the distance, in number of elements, between a slice in a dimension
|
||||
/// and the next slice in the same dimension.
|
||||
|
@ -936,6 +929,11 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
|
|||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
|
||||
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
|
||||
MLIRContext *context)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformsOptions(vectorTransformsOptions) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO(ajcbik): implement masks
|
||||
|
@ -946,33 +944,41 @@ public:
|
|||
// a new pattern.
|
||||
// TODO(ntv, fhahn): once row-major mode is available in LLVM's matrix
|
||||
// intrinsics, use that.
|
||||
if (lowerToLLVMMatrixIntrinsics &&
|
||||
if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics &&
|
||||
isColumnMajorMatmul(op.indexing_maps())) {
|
||||
VectorType lhsType = op.getLhsType();
|
||||
VectorType rhsType = op.getRhsType();
|
||||
Type flattenedLHSType =
|
||||
VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
|
||||
Type flattenedRHSType =
|
||||
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
|
||||
auto lhs = rewriter.create<vector::ShapeCastOp>(
|
||||
op.getLoc(), flattenedLHSType, op.lhs());
|
||||
auto rhs = rewriter.create<vector::ShapeCastOp>(
|
||||
op.getLoc(), flattenedRHSType, op.rhs());
|
||||
|
||||
unsigned lhsRows = op.getLhsType().getShape()[0];
|
||||
unsigned lhsColumns = op.getLhsType().getShape()[1];
|
||||
unsigned rhsColumns = op.getRhsType().getShape()[1];
|
||||
Value mul = rewriter.create<vector::MatmulOp>(
|
||||
op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
|
||||
mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
|
||||
op.acc().getType(), mul);
|
||||
Type elementType = op.getLhsType().getElementType();
|
||||
assert(elementType.isIntOrFloat());
|
||||
if (elementType.isa<IntegerType>())
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
|
||||
return matchSuccess();
|
||||
|
||||
// In cases where matrices are degenerate, scalarization issues occur in
|
||||
// the backend. Avoid all LLVM scalarization issues for now.
|
||||
// For more details, see: https://bugs.llvm.org/show_bug.cgi?id=45227 and
|
||||
// https://bugs.llvm.org/show_bug.cgi?id=45229
|
||||
// TODO(ntv, fhahn): Relax once above bugs are fixed.
|
||||
if (lhsRows != 1 && lhsColumns != 1 && rhsColumns != 1) {
|
||||
Type flattenedLHSType =
|
||||
VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
|
||||
Type flattenedRHSType =
|
||||
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
|
||||
auto lhs = rewriter.create<vector::ShapeCastOp>(
|
||||
op.getLoc(), flattenedLHSType, op.lhs());
|
||||
auto rhs = rewriter.create<vector::ShapeCastOp>(
|
||||
op.getLoc(), flattenedRHSType, op.rhs());
|
||||
|
||||
Value mul = rewriter.create<vector::MatmulOp>(
|
||||
op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
|
||||
mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
|
||||
op.acc().getType(), mul);
|
||||
Type elementType = op.getLhsType().getElementType();
|
||||
assert(elementType.isIntOrFloat());
|
||||
if (elementType.isa<IntegerType>())
|
||||
rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
|
||||
return matchSuccess();
|
||||
}
|
||||
}
|
||||
|
||||
// Find first batch dimension in LHS/RHS, and lower when found.
|
||||
|
@ -1255,6 +1261,8 @@ private:
|
|||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
vector::VectorTransformsOptions vectorTransformsOptions;
|
||||
};
|
||||
|
||||
/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
|
||||
|
@ -1342,8 +1350,10 @@ void mlir::vector::populateVectorSlicesLoweringPatterns(
|
|||
}
|
||||
|
||||
void mlir::vector::populateVectorContractLoweringPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.insert<ContractionOpLowering, ShapeCastOp2DDownCastRewritePattern,
|
||||
OwningRewritePatternList &patterns, MLIRContext *context,
|
||||
VectorTransformsOptions parameters) {
|
||||
patterns.insert<ShapeCastOp2DDownCastRewritePattern,
|
||||
ShapeCastOp2DUpCastRewritePattern, OuterProductOpLowering>(
|
||||
context);
|
||||
patterns.insert<ContractionOpLowering>(parameters, context);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-vector-contraction-conversion -vector-lower-matrix-intrinsics | FileCheck %s --check-prefix=MATRIX
|
||||
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
|
||||
|
||||
#dotp_accesses = [
|
||||
affine_map<(i) -> (i)>,
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
||||
namespace {
|
||||
|
||||
#include "TestVectorTransformPatterns.h.inc"
|
||||
|
@ -44,9 +43,20 @@ struct TestVectorSlicesConversion
|
|||
|
||||
struct TestVectorContractionConversion
|
||||
: public FunctionPass<TestVectorContractionConversion> {
|
||||
TestVectorContractionConversion() = default;
|
||||
TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
|
||||
}
|
||||
|
||||
Option<bool> lowerToLLVMMatrixIntrinsics{
|
||||
*this, "vector-lower-matrix-intrinsics",
|
||||
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
|
||||
llvm::cl::init(false)};
|
||||
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
populateVectorContractLoweringPatterns(patterns, &getContext());
|
||||
VectorTransformsOptions options{
|
||||
/*lowerToLLVMMatrixIntrinsics=*/lowerToLLVMMatrixIntrinsics};
|
||||
populateVectorContractLoweringPatterns(patterns, &getContext(), options);
|
||||
applyPatternsGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue