[mlir][X86Vector] Add specialized vector.transpose lowering patterns for AVX2

This revision adds an implementation of 2-D vector.transpose for 4x8 and 8x8
vectors for AVX2 and surfaces it to the Linalg level of control.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D113347

parent 703ded8dda
commit 34ff857350
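
A minimal usage sketch, assembled from the pieces this commit adds (the
pass-construction call `createLinalgStrategyLowerVectorsPass` is taken from the
TestVectorTransposeLowering hunk further down); it is an illustration, not an
excerpt of the commit:

    // Enable the generic vector.transpose lowering plus the AVX2-specialized
    // 4x8xf32 and 8x8xf32 patterns, then run the strategy pass on a function.
    LinalgVectorLoweringOptions options =
        LinalgVectorLoweringOptions()
            .enableVectorTransposeLowering()
            .enableAVX2Lowering()
            .setAVX2LoweringOptions(
                x86vector::avx2::LoweringOptions().setTransposeOptions(
                    x86vector::avx2::TransposeLoweringOptions()
                        .lower4x8xf32()
                        .lower8x8xf32()));
    OpPassManager dynamicPM("builtin.func");
    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));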
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/Bufferize.h"

@@ -993,6 +994,12 @@ struct LinalgVectorLoweringOptions {
     transposeLowering = val;
     return *this;
   }
+  /// Enable AVX2-specific lowerings.
+  bool avx2Lowering = false;
+  LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) {
+    avx2Lowering = val;
+    return *this;
+  }
   /// Configure the post staged-patterns late vector.transfer to scf
   /// conversion.

@@ -1009,6 +1016,13 @@ struct LinalgVectorLoweringOptions {
     vectorTransformOptions = options;
     return *this;
   }
+  /// Configure specialized vector lowerings.
+  x86vector::avx2::LoweringOptions avx2LoweringOptions;
+  LinalgVectorLoweringOptions &
+  setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) {
+    avx2LoweringOptions = options;
+    return *this;
+  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -9,13 +9,126 @@
 #ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
 #define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
 
+#include "mlir/IR/Value.h"
+
 namespace mlir {
 
+class ImplicitLocOpBuilder;
 class LLVMConversionTarget;
 class LLVMTypeConverter;
 class RewritePatternSet;
 using OwningRewritePatternList = RewritePatternSet;
 
+namespace x86vector {
+
+/// Helper class to factor out the creation and extraction of masks from nibs
+/// (2-bit and 4-bit fields).
+struct MaskHelper {
+  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
+  /// Meant to be used with instructions such as mm256ShufflePs.
+  template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
+  static char shuffle() {
+    static_assert(b01 <= 0x03, "overflow");
+    static_assert(b23 <= 0x03, "overflow");
+    static_assert(b45 <= 0x03, "overflow");
+    static_assert(b67 <= 0x03, "overflow");
+    return (b67 << 6) + (b45 << 4) + (b23 << 2) + b01;
+  }
+  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
+  static void extractShuffle(char mask, char &b01, char &b23, char &b45,
+                             char &b67) {
+    b67 = (mask & (0x03 << 6)) >> 6;
+    b45 = (mask & (0x03 << 4)) >> 4;
+    b23 = (mask & (0x03 << 2)) >> 2;
+    b01 = mask & 0x03;
+  }
+  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
+  /// Meant to be used with instructions such as mm256Permute2f128Ps.
+  template <unsigned b47, unsigned b03>
+  static char permute() {
+    static_assert(b03 <= 0x0f, "overflow");
+    static_assert(b47 <= 0x0f, "overflow");
+    return (b47 << 4) + b03;
+  }
+  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
+  static void extractPermute(char mask, char &b03, char &b47) {
+    b47 = (mask & (0x0f << 4)) >> 4;
+    b03 = mask & 0x0f;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+/// Helpers extracted from:
+///   - clang/lib/Headers/avxintrin.h
+///   - clang/test/CodeGen/X86/avx-builtins.c
+///   - clang/test/CodeGen/X86/avx2-builtins.c
+///   - clang/test/CodeGen/X86/avx-shuffle-builtins.c
+/// as well as the Intel Intrinsics Guide
+/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html)
+/// make it easier to just implement known-good lowerings.
+/// All intrinsics correspond 1-1 to the Intel definition.
+//===----------------------------------------------------------------------===//
+
+namespace avx2 {
+
+/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
+Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
+
+/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
+Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
+
+/// Result pattern: a a b b | a a b b.
+/// Takes an 8-bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
+///                  0:127       |         128:255
+///          b01 b23 C8 D8       |  b01+4 b23+4 C8+4 D8+4
+Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, char mask);
+
+/// imm[0:1] out of imm[0:3] selects, for the lower 128 bits of the result:
+///        0            1            2            3
+///    a[0:127]    a[128:255]    b[0:127]    b[128:255]
+/// imm[4:5] out of imm[4:7] makes the same selection for the upper 128 bits.
+Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                          char mask);
+
+/// 4x8xf32-specific AVX2 transpose lowering.
+void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
+
+/// 8x8xf32-specific AVX2 transpose lowering.
+void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
+
+/// Structure to control the behavior of specialized AVX2 transpose lowering.
+struct TransposeLoweringOptions {
+  bool lower4x8xf32_ = false;
+  TransposeLoweringOptions &lower4x8xf32(bool lower = true) {
+    lower4x8xf32_ = lower;
+    return *this;
+  }
+  bool lower8x8xf32_ = false;
+  TransposeLoweringOptions &lower8x8xf32(bool lower = true) {
+    lower8x8xf32_ = lower;
+    return *this;
+  }
+};
+
+/// Options for controlling specialized AVX2 lowerings.
+struct LoweringOptions {
+  /// Configure specialized vector lowerings.
+  TransposeLoweringOptions transposeOptions;
+  LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) {
+    transposeOptions = options;
+    return *this;
+  }
+};
+
+/// Insert specialized transpose lowering patterns.
+void populateSpecializedTransposeLoweringPatterns(
+    RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
+    int benefit = 10);
+
+} // namespace avx2
+} // namespace x86vector
+
 /// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
 /// intrinsics.
 void populateX86VectorLegalizeForLLVMExportPatterns(
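To make the nib encodings above concrete, here is a small standalone C++
sketch (not part of the commit) that mirrors `MaskHelper::shuffle` and
`MaskHelper::permute` and checks the exact masks the transposes below use:

    // Four 2-bit fields packed low-to-high, as in MaskHelper::shuffle.
    template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
    constexpr unsigned shuffleMask() {
      static_assert(b01 <= 3 && b23 <= 3 && b45 <= 3 && b67 <= 3, "overflow");
      return (b67 << 6) | (b45 << 4) | (b23 << 2) | b01;
    }
    // Two 4-bit fields packed low-to-high, as in MaskHelper::permute.
    template <unsigned b47, unsigned b03>
    constexpr unsigned permuteMask() {
      static_assert(b03 <= 0xf && b47 <= 0xf, "overflow");
      return (b47 << 4) | b03;
    }
    // shuffle<1, 0, 1, 0>() is 0b01000100 == 0x44, the _MM_SHUFFLE(1, 0, 1, 0)
    // immediate; shuffle<3, 2, 3, 2>() is 0b11101110 == 0xEE.
    static_assert(shuffleMask<1, 0, 1, 0>() == 0x44, "low pairs");
    static_assert(shuffleMask<3, 2, 3, 2>() == 0xEE, "high pairs");
    // permute<2, 0>() == 0x20 puts a[0:127] in the low half of the result and
    // b[0:127] in the high half; permute<3, 1>() == 0x31 picks the 128:255
    // halves instead.
    static_assert(permuteMask<2, 0>() == 0x20, "low 128-bit lanes");
    static_assert(permuteMask<3, 1>() == 0x31, "high 128-bit lanes");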
@@ -51,5 +51,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRTransforms
   MLIRTransformUtils
   MLIRVector
+  MLIRX86VectorTransforms
   MLIRVectorToSCF
 )
@@ -334,6 +334,9 @@ struct LinalgStrategyLowerVectorsPass
     if (options.transposeLowering) {
       vector::populateVectorTransposeLoweringPatterns(
           patterns, options.vectorTransformOptions);
+      if (options.avx2Lowering)
+        x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+            patterns, options.avx2LoweringOptions, /*benefit=*/10);
     }
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   }
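The /*benefit=*/10 argument is what arbitrates between the two registrations
above: MLIR's greedy rewriter tries patterns with a higher PatternBenefit
first. A minimal sketch of the mechanism (TransposeOpLowering is the pattern
defined in AVXTranspose.cpp below; generic patterns use the default benefit
of 1):

    // The AVX2-specialized pattern is registered with benefit 10, so the
    // driver attempts it before the generic transpose lowering and falls
    // back only when it returns failure() (unsupported shape or options).
    patterns.add<TransposeOpLowering>(options, patterns.getContext(),
                                      /*benefit=*/10);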
@ -0,0 +1,208 @@
|
|||
//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements vector.transpose rewrites as AVX patterns for particular
|
||||
// sizes of interest.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
using namespace mlir::x86vector;
|
||||
using namespace mlir::x86vector::avx2;
|
||||
|
||||
Value mlir::x86vector::avx2::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1,
|
||||
Value v2) {
|
||||
return b.create<vector::ShuffleOp>(
|
||||
v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
|
||||
}
|
||||
|
||||
Value mlir::x86vector::avx2::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1,
|
||||
Value v2) {
|
||||
return b.create<vector::ShuffleOp>(
|
||||
v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
|
||||
}
|
||||
/// a a b b a a b b
|
||||
/// Takes an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4):
|
||||
/// 0:127 | 128:255
|
||||
/// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4
|
||||
Value mlir::x86vector::avx2::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
|
||||
Value v2, char mask) {
|
||||
char b01, b23, b45, b67;
|
||||
MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
|
||||
SmallVector<int64_t> shuffleMask{b01, b23, b45 + 8, b67 + 8,
|
||||
b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
|
||||
return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
|
||||
}
|
||||
|
||||
// imm[0:1] out of imm[0:3] is:
|
||||
// 0 1 2 3
|
||||
// a[0:127] or a[128:255] or b[0:127] or b[128:255] |
|
||||
// a[0:127] or a[128:255] or b[0:127] or b[128:255]
|
||||
// 0 1 2 3
|
||||
// imm[0:1] out of imm[4:7].
|
||||
Value mlir::x86vector::avx2::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
|
||||
Value v1, Value v2,
|
||||
char mask) {
|
||||
SmallVector<int64_t> shuffleMask;
|
||||
auto appendToMask = [&](char control) {
|
||||
if (control == 0)
|
||||
llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
|
||||
else if (control == 1)
|
||||
llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
|
||||
else if (control == 2)
|
||||
llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
|
||||
else if (control == 3)
|
||||
llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
|
||||
else
|
||||
llvm_unreachable("control > 3 : overflow");
|
||||
};
|
||||
char b03, b47;
|
||||
MaskHelper::extractPermute(mask, b03, b47);
|
||||
appendToMask(b03);
|
||||
appendToMask(b47);
|
||||
return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
|
||||
}
|
||||
|
||||
/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
|
||||
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
|
||||
MutableArrayRef<Value> vs) {
|
||||
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
|
||||
#ifndef NDEBUG
|
||||
assert(vs.size() == 4 && "expects 4 vectors");
|
||||
assert(llvm::all_of(ValueRange{vs}.getTypes(),
|
||||
[&](Type t) { return t == vt; }) &&
|
||||
"expects all types to be vector<8xf32>");
|
||||
#endif
|
||||
|
||||
Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
|
||||
Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
|
||||
Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
|
||||
Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
|
||||
Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
|
||||
Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
|
||||
Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
|
||||
Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
|
||||
vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>());
|
||||
vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>());
|
||||
vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>());
|
||||
vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>());
|
||||
}
|
||||
|
||||
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
|
||||
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
|
||||
MutableArrayRef<Value> vs) {
|
||||
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
|
||||
(void)vt;
|
||||
assert(vs.size() == 8 && "expects 8 vectors");
|
||||
assert(llvm::all_of(ValueRange{vs}.getTypes(),
|
||||
[&](Type t) { return t == vt; }) &&
|
||||
"expects all types to be vector<8xf32>");
|
||||
|
||||
Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
|
||||
Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
|
||||
Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
|
||||
Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
|
||||
Value T4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
|
||||
Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
|
||||
Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
|
||||
Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
|
||||
Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
|
||||
Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
|
||||
Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
|
||||
Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
|
||||
Value S4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 1, 0>());
|
||||
Value S5 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<3, 2, 3, 2>());
|
||||
Value S6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 1, 0>());
|
||||
Value S7 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<3, 2, 3, 2>());
|
||||
vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>());
|
||||
vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>());
|
||||
vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>());
|
||||
vs[3] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<2, 0>());
|
||||
vs[4] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<3, 1>());
|
||||
vs[5] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<3, 1>());
|
||||
vs[6] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<3, 1>());
|
||||
vs[7] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<3, 1>());
|
||||
}
|
||||
|
||||
/// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and
|
||||
/// depending on the `TransposeLoweringOptions`.
|
||||
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
|
||||
|
||||
TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
|
||||
int benefit)
|
||||
: OpRewritePattern<vector::TransposeOp>(context, benefit),
|
||||
loweringOptions(loweringOptions) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TransposeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
VectorType srcType = op.getVectorType();
|
||||
if (srcType.getRank() != 2)
|
||||
return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");
|
||||
|
||||
SmallVector<int64_t, 4> transp;
|
||||
for (auto attr : op.transp())
|
||||
transp.push_back(attr.cast<IntegerAttr>().getInt());
|
||||
if (transp[0] != 1 && transp[1] != 0)
|
||||
return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");
|
||||
|
||||
int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
|
||||
|
||||
auto applyRewrite = [&]() {
|
||||
ImplicitLocOpBuilder ib(loc, rewriter);
|
||||
SmallVector<Value> vs;
|
||||
for (int64_t i = 0; i < m; ++i)
|
||||
vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
|
||||
if (m == 4)
|
||||
transpose4x8xf32(ib, vs);
|
||||
if (m == 8)
|
||||
transpose8x8xf32(ib, vs);
|
||||
auto flattenedType =
|
||||
VectorType::get({n * m}, op.getVectorType().getElementType());
|
||||
auto transposedType =
|
||||
VectorType::get({n, m}, op.getVectorType().getElementType());
|
||||
Value res = ib.create<arith::ConstantOp>(
|
||||
op.getVectorType(), ib.getZeroAttr(op.getVectorType()));
|
||||
// The transposed form is still 4x8 and needs to be reinterpreted as 8x4
|
||||
// via shape_casts.
|
||||
for (int64_t i = 0; i < m; ++i)
|
||||
res = ib.create<vector::InsertOp>(vs[i], res, i);
|
||||
if (m == 4) {
|
||||
res = ib.create<vector::ShapeCastOp>(flattenedType, res);
|
||||
res = ib.create<vector::ShapeCastOp>(transposedType, res);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
};
|
||||
|
||||
if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
|
||||
return applyRewrite();
|
||||
if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
|
||||
return applyRewrite();
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
LoweringOptions loweringOptions;
|
||||
};
|
||||
|
||||
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
|
||||
RewritePatternSet &patterns, LoweringOptions options, int benefit) {
|
||||
patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
|
||||
}
|
|
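For reference, the "C intrinsics" model the two transpose helpers mirror is
the classic AVX 8x8 f32 transpose. A standalone sketch in real intrinsics (my
reconstruction for illustration, not part of the commit); transpose8x8xf32
above emits exactly this dataflow as vector.shuffle ops:

    #include <immintrin.h>

    // Unpack within 128-bit lanes, shuffle pairs, then swap 128-bit halves.
    void transpose8x8(__m256 r[8]) {
      __m256 t0 = _mm256_unpacklo_ps(r[0], r[1]);
      __m256 t1 = _mm256_unpackhi_ps(r[0], r[1]);
      __m256 t2 = _mm256_unpacklo_ps(r[2], r[3]);
      __m256 t3 = _mm256_unpackhi_ps(r[2], r[3]);
      __m256 t4 = _mm256_unpacklo_ps(r[4], r[5]);
      __m256 t5 = _mm256_unpackhi_ps(r[4], r[5]);
      __m256 t6 = _mm256_unpacklo_ps(r[6], r[7]);
      __m256 t7 = _mm256_unpackhi_ps(r[6], r[7]);
      __m256 s0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0));
      __m256 s1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2));
      __m256 s2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0));
      __m256 s3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2));
      __m256 s4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0));
      __m256 s5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2));
      __m256 s6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0));
      __m256 s7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2));
      r[0] = _mm256_permute2f128_ps(s0, s4, 0x20);  // permute<2, 0>
      r[1] = _mm256_permute2f128_ps(s1, s5, 0x20);
      r[2] = _mm256_permute2f128_ps(s2, s6, 0x20);
      r[3] = _mm256_permute2f128_ps(s3, s7, 0x20);
      r[4] = _mm256_permute2f128_ps(s0, s4, 0x31);  // permute<3, 1>
      r[5] = _mm256_permute2f128_ps(s1, s5, 0x31);
      r[6] = _mm256_permute2f128_ps(s2, s6, 0x31);
      r[7] = _mm256_permute2f128_ps(s3, s7, 0x31);
    }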
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRX86VectorTransforms
+  AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
 
   DEPENDS

@@ -10,4 +11,5 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   MLIRIR
   MLIRLLVMCommonConversion
   MLIRLLVMIR
+  MLIRVector
 )
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s
 
 #matvec_accesses = [
   affine_map<(i, j) -> (i, j)>,
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
-// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
 
 #dotp_accesses = [
   affine_map<(i) -> (i)>,

@@ -149,8 +149,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
 // CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
 // CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
 // CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
-// ... bunch of extract insert to transpose B into Bt
-// CHECK: %[[Bt:.*]] = vector.insert %{{.*}}, %{{.*}} [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
 // CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
 // CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>

@@ -399,28 +398,6 @@ func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) ->
   return %0: vector<16xi32>
 }
 
-// CHECK-LABEL: func @transpose23
-// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>
-// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
-// CHECK: return %[[T11]] : vector<3x2xf32>
-
-func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
-  %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
-  return %0 : vector<3x2xf32>
-}
-
 // CHECK-LABEL: func @nop_shape_cast
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>
 // CHECK: return %[[A]] : vector<16xf32>
@@ -1,65 +0,0 @@
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s

// Tests for lowering 2-D vector.transpose into vector.flat_transpose.
//
// TODO: having ShapeCastOp2DDownCastRewritePattern and
// ShapeCastOp2DUpCastRewritePattern too early in the greedy rewriting
// patterns misses opportunities to fold shape casts!

// No shape cast folding expected.
//
// CHECK-LABEL: func @transpose44_44(
// CHECK-SAME: %[[A:.*]]: vector<4x4xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
// CHECK: %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
//
func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> {
  %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
  return %0 : vector<4x4xf32>
}

// Folds preceding shape cast as expected,
// no following shape cast folding expected.
//
// FIXME: PR49590 - shape_cast not stable.
//
// CHECK-LABEL: func @transpose16_44(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// HECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
// HECK: %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
//
func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> {
  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
  %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
  return %1 : vector<4x4xf32>
}

// No preceding shape cast folding expected,
// but FAILS to fold following cast.
//
// CHECK-LABEL: func @transpose44_16(
// CHECK-SAME: %[[A:.*]]: vector<4x4xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> {
  %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
  return %1 : vector<16xf32>
}

// Folds preceding shape cast as expected,
// but FAILS to fold following cast.
//
// FIXME: PR49590 - shape_cast not stable.
//
// CHECK-LABEL: func @transpose16_16(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// HECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
//
func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> {
  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
  %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
  %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32>
  return %2 : vector<16xf32>
}
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
 
 // CHECK-LABEL: func @maskedload0(
 // CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-to-vector-conversion="unroll" | FileCheck %s
+// RUN: mlir-opt %s -test-vector-to-vector-lowering="unroll" | FileCheck %s
 
 // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
@@ -0,0 +1,101 @@
// RUN: mlir-opt %s -test-vector-transpose-lowering=eltwise=1 | FileCheck %s --check-prefix=ELTWISE
// RUN: mlir-opt %s -test-vector-transpose-lowering=shuffle=1 | FileCheck %s --check-prefix=SHUFFLE
// RUN: mlir-opt %s -test-vector-transpose-lowering=flat=1 | FileCheck %s --check-prefix=FLAT
// RUN: mlir-opt %s -test-vector-transpose-lowering=avx2=1 | FileCheck %s --check-prefix=AVX2

// ELTWISE-LABEL: func @transpose23
// ELTWISE-SAME: %[[A:.*]]: vector<2x3xf32>
// ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
// ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
// ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
// ELTWISE: return %[[T11]] : vector<3x2xf32>
func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
  %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
  return %0 : vector<3x2xf32>
}

// SHUFFLE-LABEL: func @transpose
// FLAT-LABEL: func @transpose(
func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
  // SHUFFLE: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
  //  0 1 2 3       0 4
  //  4 5 6 7  ->   1 5
  //                2 6
  //                3 7
  // SHUFFLE-NEXT: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
  // SHUFFLE-NEXT: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>

  // FLAT: vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32>
  // FLAT: vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32>
  // FLAT: vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
  return %0 : vector<4x2xf32>
}

// AVX2-LABEL: func @transpose4x8
func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
  // AVX2: vector.extract {{.*}}[0]
  // AVX2-NEXT: vector.extract {{.*}}[1]
  // AVX2-NEXT: vector.extract {{.*}}[2]
  // AVX2-NEXT: vector.extract {{.*}}[3]
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.insert {{.*}}[0]
  // AVX2-NEXT: vector.insert {{.*}}[1]
  // AVX2-NEXT: vector.insert {{.*}}[2]
  // AVX2-NEXT: vector.insert {{.*}}[3]
  // AVX2-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
  // AVX2-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
  return %0 : vector<8x4xf32>
}

// AVX2-LABEL: func @transpose8x8
func @transpose8x8xf32(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
  // AVX2: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
  // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32>
  return %0 : vector<8x8xf32>
}
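To exercise one lowering in isolation outside of lit, the RUN lines above
translate directly into a plain invocation; the .mlir file name here is a
placeholder for wherever this new test is checked in:

    mlir-opt vector-transpose-lowering.mlir -test-vector-transpose-lowering=avx2=1

Each flag selects exactly one strategy, matching the Option<bool> flags of
the TestVectorTransposeLowering pass further down.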
@@ -1,14 +0,0 @@
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-shuffle-transpose=1 | FileCheck %s

// CHECK-LABEL: func @transpose
func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
  // CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
  //  0 1 2 3       0 4
  //  4 5 6 7  ->   1 5
  //                2 6
  //                3 7
  // CHECK: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
  // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
  %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
  return %0 : vector<4x2xf32>
}
@@ -1,4 +1,4 @@
-//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
+//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
|
|||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
using namespace mlir::vector;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestVectorToVectorConversion
|
||||
: public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
|
||||
TestVectorToVectorConversion() = default;
|
||||
TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
|
||||
struct TestVectorToVectorLowering
|
||||
: public PassWrapper<TestVectorToVectorLowering, FunctionPass> {
|
||||
TestVectorToVectorLowering() = default;
|
||||
TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) {}
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-to-vector-conversion";
|
||||
return "test-vector-to-vector-lowering";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test conversion patterns between ops in the vector dialect";
|
||||
return "Test lowering patterns between ops in the vector dialect";
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
|
@@ -95,31 +99,22 @@ private:
   }
 };
 
-struct TestVectorContractionConversion
-    : public PassWrapper<TestVectorContractionConversion, FunctionPass> {
+struct TestVectorContractionLowering
+    : public PassWrapper<TestVectorContractionLowering, FunctionPass> {
   StringRef getArgument() const final {
-    return "test-vector-contraction-conversion";
+    return "test-vector-contraction-lowering";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns that lower contract ops in the vector "
+    return "Test lowering patterns that lower contract ops in the vector "
            "dialect";
   }
-  TestVectorContractionConversion() = default;
-  TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
-  }
+  TestVectorContractionLowering() = default;
+  TestVectorContractionLowering(const TestVectorContractionLowering &pass) {}
 
   Option<bool> lowerToFlatMatrix{
       *this, "vector-lower-matrix-intrinsics",
       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
       llvm::cl::init(false)};
-  Option<bool> lowerToFlatTranspose{
-      *this, "vector-flat-transpose",
-      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
-      llvm::cl::init(false)};
-  Option<bool> lowerToShuffleTranspose{
-      *this, "vector-shuffle-transpose",
-      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
-      llvm::cl::init(false)};
   Option<bool> lowerToOuterProduct{
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
@@ -165,31 +160,91 @@ struct TestVectorContractionLowering
       contractLowering = VectorContractLowering::Matmul;
     VectorMultiReductionLowering vectorMultiReductionLowering =
         VectorMultiReductionLowering::InnerParallel;
-    VectorTransposeLowering transposeLowering =
-        VectorTransposeLowering::EltWise;
-    if (lowerToFlatTranspose)
-      transposeLowering = VectorTransposeLowering::Flat;
-    if (lowerToShuffleTranspose)
-      transposeLowering = VectorTransposeLowering::Shuffle;
-    VectorTransformsOptions options{
-        contractLowering, vectorMultiReductionLowering, transposeLowering};
+    VectorTransformsOptions options{contractLowering,
+                                    vectorMultiReductionLowering,
+                                    VectorTransposeLowering()};
     populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns, options);
     populateVectorMaskOpLoweringPatterns(patterns);
-    if (!lowerToShuffleTranspose)
-      populateVectorShapeCastLoweringPatterns(patterns);
-    populateVectorTransposeLoweringPatterns(patterns, options);
+    populateVectorShapeCastLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
 
+struct TestVectorTransposeLowering
+    : public PassWrapper<TestVectorTransposeLowering, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-vector-transpose-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering patterns that lower transpose ops in the vector "
+           "dialect";
+  }
+  TestVectorTransposeLowering() = default;
+  TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) {}
+
+  Option<bool> lowerToEltwise{
+      *this, "eltwise",
+      llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
+      llvm::cl::init(false)};
+  Option<bool> lowerToFlatTranspose{
+      *this, "flat",
+      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
+      llvm::cl::init(false)};
+  Option<bool> lowerToShuffleTranspose{
+      *this, "shuffle",
+      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
+      llvm::cl::init(false)};
+  Option<bool> lowerToAvx2{
+      *this, "avx2",
+      llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
+      llvm::cl::init(false)};
+
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+
+    // Test on one pattern in isolation.
+    // Explicitly disable shape_cast lowering.
+    LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
+                                              .enableVectorTransposeLowering()
+                                              .enableShapeCastLowering(false);
+    if (lowerToEltwise) {
+      options = options.setVectorTransformsOptions(
+          VectorTransformsOptions().setVectorTransposeLowering(
+              VectorTransposeLowering::EltWise));
+    }
+    if (lowerToFlatTranspose) {
+      options = options.setVectorTransformsOptions(
+          VectorTransformsOptions().setVectorTransposeLowering(
+              VectorTransposeLowering::Flat));
+    }
+    if (lowerToShuffleTranspose) {
+      options = options.setVectorTransformsOptions(
+          VectorTransformsOptions().setVectorTransposeLowering(
+              VectorTransposeLowering::Shuffle));
+    }
+    if (lowerToAvx2) {
+      options = options.enableAVX2Lowering().setAVX2LoweringOptions(
+          x86vector::avx2::LoweringOptions().setTransposeOptions(
+              x86vector::avx2::TransposeLoweringOptions()
+                  .lower4x8xf32()
+                  .lower8x8xf32()));
+    }
+
+    OpPassManager dynamicPM("builtin.func");
+    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
+    if (failed(runPipeline(dynamicPM, getFunction())))
+      return signalPassFailure();
+  }
+};
+
 struct TestVectorUnrollingPatterns
     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
   StringRef getArgument() const final {
     return "test-vector-unrolling-patterns";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to unroll contract ops in the vector "
+    return "Test lowering patterns to unroll contract ops in the vector "
            "dialect";
   }
   TestVectorUnrollingPatterns() = default;
@@ -248,7 +303,7 @@ struct TestVectorDistributePatterns
     return "test-vector-distribute-patterns";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to distribute vector ops in the vector "
+    return "Test lowering patterns to distribute vector ops in the vector "
            "dialect";
   }
   TestVectorDistributePatterns() = default;
@@ -302,7 +357,7 @@ struct TestVectorToLoopPatterns
     : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
   StringRef getArgument() const final { return "test-vector-to-forloop"; }
   StringRef getDescription() const final {
-    return "Test conversion patterns to break up a vector op into a for loop";
+    return "Test lowering patterns to break up a vector op into a for loop";
   }
   TestVectorToLoopPatterns() = default;
   TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
@@ -365,7 +420,7 @@ struct TestVectorTransferUnrollingPatterns
     return "test-vector-transfer-unrolling-patterns";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to unroll transfer ops in the vector "
+    return "Test lowering patterns to unroll transfer ops in the vector "
            "dialect";
   }
   void runOnFunction() override {
@@ -391,7 +446,7 @@ struct TestVectorTransferFullPartialSplitPatterns
     return "test-vector-transfer-full-partial-split";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to split "
+    return "Test lowering patterns to split "
            "transfer ops via scf.if + linalg ops";
   }
   TestVectorTransferFullPartialSplitPatterns() = default;
@@ -439,7 +494,7 @@ struct TestVectorTransferLoweringPatterns
     return "test-vector-transfer-lowering-patterns";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to lower transfer ops to other vector ops";
+    return "Test lowering patterns to lower transfer ops to other vector ops";
   }
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
@@ -462,7 +517,7 @@ struct TestVectorMultiReductionLoweringPatterns
     return "test-vector-multi-reduction-lowering-patterns";
   }
   StringRef getDescription() const final {
-    return "Test conversion patterns to lower vector.multi_reduction to other "
+    return "Test lowering patterns to lower vector.multi_reduction to other "
            "vector ops";
   }
   Option<bool> useOuterReductions{
@@ -495,7 +550,7 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
   }
 
   StringRef getDescription() const final {
-    return "Test conversion patterns that reducedes the rank of the vector "
+    return "Test lowering patterns that reduce the rank of the vector "
            "transfer memory and vector operands.";
   }
 
@@ -527,10 +582,12 @@ struct TestVectorReduceToContractPatternsPatterns
 
 namespace mlir {
 namespace test {
-void registerTestVectorConversions() {
-  PassRegistration<TestVectorToVectorConversion>();
+void registerTestVectorLowerings() {
+  PassRegistration<TestVectorToVectorLowering>();
 
-  PassRegistration<TestVectorContractionConversion>();
+  PassRegistration<TestVectorContractionLowering>();
 
+  PassRegistration<TestVectorTransposeLowering>();
+
   PassRegistration<TestVectorUnrollingPatterns>();
@@ -107,7 +107,7 @@ void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
 void registerTestSliceAnalysisPass();
-void registerTestVectorConversions();
+void registerTestVectorLowerings();
 } // namespace test
 } // namespace mlir
@@ -197,7 +197,7 @@ void registerTestPasses() {
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSliceAnalysisPass();
-  mlir::test::registerTestVectorConversions();
+  mlir::test::registerTestVectorLowerings();
 }
 #endif
@@ -1458,6 +1458,7 @@ cc_library(
         ":LLVMCommonConversion",
         ":LLVMDialect",
         ":StandardOps",
+        ":VectorOps",
         ":X86Vector",
         "//llvm:Core",
         "//llvm:Support",

@@ -6401,6 +6402,7 @@ cc_library(
         ":TransformUtils",
         ":VectorOps",
         ":VectorToSCF",
+        ":X86VectorTransforms",
         "//llvm:Support",
     ],
 )
@@ -484,6 +484,7 @@ cc_library(
         "//mlir:Affine",
         "//mlir:Analysis",
         "//mlir:LinalgOps",
+        "//mlir:LinalgTransforms",
         "//mlir:MemRefDialect",
         "//mlir:Pass",
         "//mlir:SCFDialect",