[mlir][X86Vector] Add specialized vector.transpose lowering patterns for AVX2

This revision adds specialized lowerings of 2-D vector.transpose for the
4x8xf32 and 8x8xf32 cases on AVX2, and surfaces them at the Linalg level of
control.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D113347
Nicolas Vasilache 2021-11-10 13:19:41 +00:00
parent 703ded8dda
commit 34ff857350
17 changed files with 556 additions and 156 deletions

View File

@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Bufferize.h"
@@ -993,6 +994,12 @@ struct LinalgVectorLoweringOptions {
transposeLowering = val;
return *this;
}
/// Enable AVX2-specific lowerings.
bool avx2Lowering = false;
LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) {
avx2Lowering = val;
return *this;
}
/// Configure the post staged-patterns late vector.transfer to scf
/// conversion.
@@ -1009,6 +1016,13 @@ struct LinalgVectorLoweringOptions {
vectorTransformOptions = options;
return *this;
}
/// Configure specialized vector lowerings.
x86vector::avx2::LoweringOptions avx2LoweringOptions;
LinalgVectorLoweringOptions &
setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) {
avx2LoweringOptions = options;
return *this;
}
};
//===----------------------------------------------------------------------===//
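// A minimal usage sketch (hypothetical, not part of the patch): the new knobs
// compose with the existing chained setters, mirroring the configuration used
// by the TestVectorTransposeLowering pass further down in this commit.
LinalgVectorLoweringOptions sketchOptions =
    LinalgVectorLoweringOptions()
        .enableVectorTransposeLowering()
        .enableAVX2Lowering()
        .setAVX2LoweringOptions(
            x86vector::avx2::LoweringOptions().setTransposeOptions(
                x86vector::avx2::TransposeLoweringOptions()
                    .lower4x8xf32()
                    .lower8x8xf32()));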

View File

@@ -9,13 +9,126 @@
#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
#define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
#include "mlir/IR/Value.h"
namespace mlir {
class ImplicitLocOpBuilder;
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
namespace x86vector {
/// Helper class to factor out the creation and extraction of masks from
/// nibbles.
struct MaskHelper {
/// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
/// Meant to be used with instructions such as mm256ShufflePs.
template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
static char shuffle() {
static_assert(b01 <= 0x03, "overflow");
static_assert(b23 <= 0x03, "overflow");
static_assert(b45 <= 0x03, "overflow");
static_assert(b67 <= 0x03, "overflow");
return (b67 << 6) + (b45 << 4) + (b23 << 2) + b01;
}
/// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
static void extractShuffle(char mask, char &b01, char &b23, char &b45,
char &b67) {
b67 = (mask & (0x03 << 6)) >> 6;
b45 = (mask & (0x03 << 4)) >> 4;
b23 = (mask & (0x03 << 2)) >> 2;
b01 = mask & 0x03;
}
/// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
/// Meant to be used with instructions such as mm256Permute2f128Ps.
template <unsigned b47, unsigned b03>
static char permute() {
static_assert(b03 <= 0x0f, "overflow");
static_assert(b47 <= 0x0f, "overflow");
return (b47 << 4) + b03;
}
/// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
static void extractPermute(char mask, char &b03, char &b47) {
b47 = (mask & (0x0f << 4)) >> 4;
b03 = mask & 0x0f;
}
};
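// A worked example (hypothetical values, for illustration only):
// shuffle<3, 2, 3, 2>() packs (3 << 6) + (2 << 4) + (3 << 2) + 2 == 0xEE,
// the same immediate as the classic _MM_SHUFFLE(3, 2, 3, 2);
// extractShuffle inverts it.
char m = MaskHelper::shuffle<3, 2, 3, 2>();        // m == (char)0xEE
char b01, b23, b45, b67;
MaskHelper::extractShuffle(m, b01, b23, b45, b67); // b01=2 b23=3 b45=2 b67=3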
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
/// - clang/test/CodeGen/X86/avx-builtins.c
/// - clang/test/CodeGen/X86/avx2-builtins.c
/// - clang/test/CodeGen/X86/avx-shuffle-builtins.c
/// as well as the Intel Intrinsics Guide
/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html)
/// make it easier to implement known-good lowerings directly.
/// All intrinsics correspond 1-1 to the Intel definition.
//===----------------------------------------------------------------------===//
namespace avx2 {
/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
///                              a  a  b  b  a  a  b  b
/// Takes an 8-bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
///                                  0:127    |    128:255
///                       b01  b23  C8  D8    |    b01+4  b23+4  C8+4  D8+4
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, char mask);
// The 8-bit mask selects one 128-bit half of {a, b} for each 128-bit half of
// the result: imm[0:1] (out of imm[0:3]) picks the lower half and imm[4:5]
// (out of imm[4:7]) picks the upper half, with the encoding:
//   0 -> a[0:127], 1 -> a[128:255], 2 -> b[0:127], 3 -> b[128:255]
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
char mask);
/// 4x8xf32-specific AVX2 transpose lowering.
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
/// 8x8xf32-specific AVX2 transpose lowering.
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
/// Structure to control the behavior of specialized AVX2 transpose lowering.
struct TransposeLoweringOptions {
bool lower4x8xf32_ = false;
TransposeLoweringOptions &lower4x8xf32(bool lower = true) {
lower4x8xf32_ = lower;
return *this;
}
bool lower8x8xf32_ = false;
TransposeLoweringOptions &lower8x8xf32(bool lower = true) {
lower8x8xf32_ = lower;
return *this;
}
};
/// Options for controlling specialized AVX2 lowerings.
struct LoweringOptions {
/// Configure specialized vector lowerings.
TransposeLoweringOptions transposeOptions;
LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) {
transposeOptions = options;
return *this;
}
};
/// Insert specialized transpose lowering patterns.
void populateSpecializedTransposeLoweringPatterns(
RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
int benefit = 10);
} // namespace avx2
} // namespace x86vector
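// A hedged wiring sketch (assumes a FuncOp `funcOp` and the greedy rewrite
// driver, as used elsewhere in this commit): register the specialized
// patterns with a benefit that outbids the generic vector.transpose
// lowerings, then run the rewriter.
RewritePatternSet patterns(funcOp.getContext());
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    patterns, x86vector::avx2::LoweringOptions(), /*benefit=*/10);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));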
/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
/// intrinsics.
void populateX86VectorLegalizeForLLVMExportPatterns(

View File

@@ -51,5 +51,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRTransforms
MLIRTransformUtils
MLIRVector
MLIRX86VectorTransforms
MLIRVectorToSCF
)

View File

@@ -334,6 +334,9 @@ struct LinalgStrategyLowerVectorsPass
if (options.transposeLowering) {
vector::populateVectorTransposeLoweringPatterns(
patterns, options.vectorTransformOptions);
if (options.avx2Lowering)
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
patterns, options.avx2LoweringOptions, /*benefit=*/10);
}
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

View File

@@ -0,0 +1,208 @@
//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
Value mlir::x86vector::avx2::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1,
Value v2) {
return b.create<vector::ShuffleOp>(
v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
}
Value mlir::x86vector::avx2::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1,
Value v2) {
return b.create<vector::ShuffleOp>(
v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
}
///                              a  a  b  b  a  a  b  b
/// Takes an 8-bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
///                                  0:127    |    128:255
///                       b01  b23  C8  D8    |    b01+4  b23+4  C8+4  D8+4
Value mlir::x86vector::avx2::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
Value v2, char mask) {
char b01, b23, b45, b67;
MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
SmallVector<int64_t> shuffleMask{b01, b23, b45 + 8, b67 + 8,
b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
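// Worked example (operand names are hypothetical): for
// MaskHelper::shuffle<1, 0, 1, 0>() == 0x44, extractShuffle yields
// b01=0, b23=1, b45=0, b67=1, so shuffleMask becomes
//   {0, 1, 8, 9, 4, 5, 12, 13}
// exactly the "[0, 1, 8, 9, 4, 5, 12, 13]" pattern the AVX2 tests below
// check for.
Value s = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());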
// The 8-bit mask selects one 128-bit half of {a, b} for each 128-bit half of
// the result: imm[0:1] (out of imm[0:3]) picks the lower half and imm[4:5]
// (out of imm[4:7]) picks the upper half, with the encoding:
//   0 -> a[0:127], 1 -> a[128:255], 2 -> b[0:127], 3 -> b[128:255]
Value mlir::x86vector::avx2::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
Value v1, Value v2,
char mask) {
SmallVector<int64_t> shuffleMask;
auto appendToMask = [&](char control) {
if (control == 0)
llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
else if (control == 1)
llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
else if (control == 2)
llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
else if (control == 3)
llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
else
llvm_unreachable("control > 3 : overflow");
};
char b03, b47;
MaskHelper::extractPermute(mask, b03, b47);
appendToMask(b03);
appendToMask(b47);
return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
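// Worked example (operand names are hypothetical): permute<2, 0>() == 0x20
// selects a[0:127] (indices 0-3) then b[0:127] (indices 8-11), giving the
// shuffle mask {0, 1, 2, 3, 8, 9, 10, 11}; permute<3, 1>() == 0x31 selects
// a[128:255] and b[128:255], giving {4, 5, 6, 7, 12, 13, 14, 15} -- the two
// masks checked in the AVX2 tests below.
Value lo128 = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());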
/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
MutableArrayRef<Value> vs) {
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
(void)vt;
assert(vs.size() == 4 && "expects 4 vectors");
assert(llvm::all_of(ValueRange{vs}.getTypes(),
[&](Type t) { return t == vt; }) &&
"expects all types to be vector<8xf32>");
Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>());
vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>());
vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>());
vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>());
}
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
MutableArrayRef<Value> vs) {
auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
(void)vt;
assert(vs.size() == 8 && "expects 8 vectors");
assert(llvm::all_of(ValueRange{vs}.getTypes(),
[&](Type t) { return t == vt; }) &&
"expects all types to be vector<8xf32>");
Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
Value T4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
Value S4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 1, 0>());
Value S5 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<3, 2, 3, 2>());
Value S6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 1, 0>());
Value S7 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<3, 2, 3, 2>());
vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>());
vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>());
vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>());
vs[3] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<2, 0>());
vs[4] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<3, 1>());
vs[5] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<3, 1>());
vs[6] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<3, 1>());
vs[7] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<3, 1>());
}
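// A hedged driver sketch (hypothetical `v` of type vector<8x8xf32>): the
// helper transposes its operands in place, so callers extract the rows first
// and re-insert afterwards, exactly as TransposeOpLowering does below.
SmallVector<Value> rows;
for (int64_t i = 0; i < 8; ++i)
  rows.push_back(ib.create<vector::ExtractOp>(v, i)); // row i : vector<8xf32>
transpose8x8xf32(ib, rows); // rows[j] now holds column j of the input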
/// Rewrite AVX2-specific 2-D vector.transpose for the supported cases,
/// depending on the `TransposeLoweringOptions`.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
int benefit)
: OpRewritePattern<vector::TransposeOp>(context, benefit),
loweringOptions(loweringOptions) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType srcType = op.getVectorType();
if (srcType.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");
SmallVector<int64_t, 4> transp;
for (auto attr : op.transp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
if (transp[0] != 1 && transp[1] != 0)
return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");
int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
auto applyRewrite = [&]() {
ImplicitLocOpBuilder ib(loc, rewriter);
SmallVector<Value> vs;
for (int64_t i = 0; i < m; ++i)
vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
if (m == 4)
transpose4x8xf32(ib, vs);
if (m == 8)
transpose8x8xf32(ib, vs);
auto flattenedType =
VectorType::get({n * m}, op.getVectorType().getElementType());
auto transposedType =
VectorType::get({n, m}, op.getVectorType().getElementType());
Value res = ib.create<arith::ConstantOp>(
op.getVectorType(), ib.getZeroAttr(op.getVectorType()));
// The transposed form is still 4x8 and needs to be reinterpreted as 8x4
// via shape_casts.
for (int64_t i = 0; i < m; ++i)
res = ib.create<vector::InsertOp>(vs[i], res, i);
if (m == 4) {
res = ib.create<vector::ShapeCastOp>(flattenedType, res);
res = ib.create<vector::ShapeCastOp>(transposedType, res);
}
rewriter.replaceOp(op, res);
return success();
};
if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
return applyRewrite();
if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
return applyRewrite();
return failure();
}
private:
LoweringOptions loweringOptions;
};
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
RewritePatternSet &patterns, LoweringOptions options, int benefit) {
patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
}

View File

@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
AVXTranspose.cpp
LegalizeForLLVMExport.cpp
DEPENDS
@@ -10,4 +11,5 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMIR
MLIRVector
)

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s
#matvec_accesses = [
affine_map<(i, j) -> (i, j)>,

View File

@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
#dotp_accesses = [
affine_map<(i) -> (i)>,
@@ -149,8 +149,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
// ... bunch of extract insert to transpose B into Bt
// CHECK: %[[Bt:.*]] = vector.insert %{{.*}}, %{{.*}} [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32>
// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
@@ -399,28 +398,6 @@ func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) ->
return %0: vector<16xi32>
}
// CHECK-LABEL: func @transpose23
// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>
// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
// CHECK: return %[[T11]] : vector<3x2xf32>
func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
%0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
return %0 : vector<3x2xf32>
}
// CHECK-LABEL: func @nop_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: return %[[A]] : vector<16xf32>

View File

@@ -1,65 +0,0 @@
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s
// Tests for lowering 2-D vector.transpose into vector.flat_transpose.
//
// TODO: having ShapeCastOp2DDownCastRewritePattern and
// ShapeCastOp2DUpCastRewritePattern too early in the greedy rewriting
// patterns misses opportunities to fold shape casts!
// No shape cast folding expected.
//
// CHECK-LABEL: func @transpose44_44(
// CHECK-SAME: %[[A:.*]]: vector<4x4xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
// CHECK: %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
//
func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> {
%0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
return %0 : vector<4x4xf32>
}
// Folds preceding shape cast as expected,
// no following shape cast folding expected.
//
// FIXME: PR49590 - shape_cast not stable.
//
// CHECK-LABEL: func @transpose16_44(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// HECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
// HECK: %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
//
func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
return %1 : vector<4x4xf32>
}
// No preceding shape cast folding expected,
// but FAILS to fold following cast.
//
// CHECK-LABEL: func @transpose44_16(
// CHECK-SAME: %[[A:.*]]: vector<4x4xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> {
%0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
%1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
return %1 : vector<16xf32>
}
// Folds preceding shape cast as expected,
// but FAILS to fold following cast.
//
// FIXME: PR49590 - shape_cast not stable.
//
// CHECK-LABEL: func @transpose16_16(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// HECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
//
func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
%2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32>
return %2 : vector<16xf32>
}

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
// CHECK-LABEL: func @maskedload0(
// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-vector-to-vector-conversion="unroll" | FileCheck %s
// RUN: mlir-opt %s -test-vector-to-vector-lowering="unroll" | FileCheck %s
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>

View File

@@ -0,0 +1,101 @@
// RUN: mlir-opt %s -test-vector-transpose-lowering=eltwise=1 | FileCheck %s --check-prefix=ELTWISE
// RUN: mlir-opt %s -test-vector-transpose-lowering=shuffle=1 | FileCheck %s --check-prefix=SHUFFLE
// RUN: mlir-opt %s -test-vector-transpose-lowering=flat=1 | FileCheck %s --check-prefix=FLAT
// RUN: mlir-opt %s -test-vector-transpose-lowering=avx2=1 | FileCheck %s --check-prefix=AVX2
// ELTWISE-LABEL: func @transpose23
// ELTWISE-SAME: %[[A:.*]]: vector<2x3xf32>
// ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
// ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
// ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
// ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
// ELTWISE: return %[[T11]] : vector<3x2xf32>
func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
%0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
return %0 : vector<3x2xf32>
}
// SHUFFLE-LABEL: func @transpose
// FLAT-LABEL: func @transpose(
func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
// SHUFFLE: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
//                 0 4
//   0 1 2 3       1 5
//   4 5 6 7  ->   2 6
//                 3 7
// SHUFFLE-NEXT: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
// SHUFFLE-NEXT: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
// FLAT: vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32>
// FLAT: vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32>
// FLAT: vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
return %0 : vector<4x2xf32>
}
// AVX2-LABEL: func @transpose4x8
func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
// AVX2: vector.extract {{.*}}[0]
// AVX2-NEXT: vector.extract {{.*}}[1]
// AVX2-NEXT: vector.extract {{.*}}[2]
// AVX2-NEXT: vector.extract {{.*}}[3]
// AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.insert {{.*}}[0]
// AVX2-NEXT: vector.insert {{.*}}[1]
// AVX2-NEXT: vector.insert {{.*}}[2]
// AVX2-NEXT: vector.insert {{.*}}[3]
// AVX2-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
// AVX2-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
return %0 : vector<8x4xf32>
}
// AVX2-LABEL: func @transpose8x8
func @transpose8x8xf32(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
// AVX2: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
// AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32>
return %0 : vector<8x8xf32>
}

View File

@@ -1,14 +0,0 @@
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-shuffle-transpose=1 | FileCheck %s
// CHECK-LABEL: func @transpose
func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
// CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32>
//                 0 4
//   0 1 2 3       1 5
//   4 5 6 7  ->   2 6
//                 3 7
// CHECK: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32>
// CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
return %0 : vector<4x2xf32>
}

View File

@@ -1,4 +1,4 @@
//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -11,27 +11,31 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;
namespace {
struct TestVectorToVectorConversion
: public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
TestVectorToVectorConversion() = default;
TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
struct TestVectorToVectorLowering
: public PassWrapper<TestVectorToVectorLowering, FunctionPass> {
TestVectorToVectorLowering() = default;
TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) {}
StringRef getArgument() const final {
return "test-vector-to-vector-conversion";
return "test-vector-to-vector-lowering";
}
StringRef getDescription() const final {
return "Test conversion patterns between ops in the vector dialect";
return "Test lowering patterns between ops in the vector dialect";
}
void getDependentDialects(DialectRegistry &registry) const override {
@@ -95,31 +99,22 @@ private:
}
};
struct TestVectorContractionConversion
: public PassWrapper<TestVectorContractionConversion, FunctionPass> {
struct TestVectorContractionLowering
: public PassWrapper<TestVectorContractionLowering, FunctionPass> {
StringRef getArgument() const final {
return "test-vector-contraction-conversion";
return "test-vector-contraction-lowering";
}
StringRef getDescription() const final {
return "Test conversion patterns that lower contract ops in the vector "
return "Test lowering patterns that lower contract ops in the vector "
"dialect";
}
TestVectorContractionConversion() = default;
TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
}
TestVectorContractionLowering() = default;
TestVectorContractionLowering(const TestVectorContractionLowering &pass) {}
Option<bool> lowerToFlatMatrix{
*this, "vector-lower-matrix-intrinsics",
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
llvm::cl::init(false)};
Option<bool> lowerToFlatTranspose{
*this, "vector-flat-transpose",
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
llvm::cl::init(false)};
Option<bool> lowerToShuffleTranspose{
*this, "vector-shuffle-transpose",
llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
llvm::cl::init(false)};
Option<bool> lowerToOuterProduct{
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
@@ -165,31 +160,91 @@ struct TestVectorContractionConversion
contractLowering = VectorContractLowering::Matmul;
VectorMultiReductionLowering vectorMultiReductionLowering =
VectorMultiReductionLowering::InnerParallel;
VectorTransposeLowering transposeLowering =
VectorTransposeLowering::EltWise;
if (lowerToFlatTranspose)
transposeLowering = VectorTransposeLowering::Flat;
if (lowerToShuffleTranspose)
transposeLowering = VectorTransposeLowering::Shuffle;
VectorTransformsOptions options{
contractLowering, vectorMultiReductionLowering, transposeLowering};
VectorTransformsOptions options{contractLowering,
vectorMultiReductionLowering,
VectorTransposeLowering()};
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, options);
populateVectorMaskOpLoweringPatterns(patterns);
if (!lowerToShuffleTranspose)
populateVectorShapeCastLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns, options);
populateVectorShapeCastLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
struct TestVectorTransposeLowering
: public PassWrapper<TestVectorTransposeLowering, FunctionPass> {
StringRef getArgument() const final {
return "test-vector-transpose-lowering";
}
StringRef getDescription() const final {
return "Test lowering patterns that lower contract ops in the vector "
"dialect";
}
TestVectorTransposeLowering() = default;
TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) {}
Option<bool> lowerToEltwise{
*this, "eltwise",
llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
llvm::cl::init(false)};
Option<bool> lowerToFlatTranspose{
*this, "flat",
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
llvm::cl::init(false)};
Option<bool> lowerToShuffleTranspose{
*this, "shuffle",
llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
llvm::cl::init(false)};
Option<bool> lowerToAvx2{
*this, "avx2",
llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
llvm::cl::init(false)};
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
// Test on one pattern in isolation.
// Explicitly disable shape_cast lowering.
LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
.enableVectorTransposeLowering()
.enableShapeCastLowering(false);
if (lowerToEltwise) {
options = options.setVectorTransformsOptions(
VectorTransformsOptions().setVectorTransposeLowering(
VectorTransposeLowering::EltWise));
}
if (lowerToFlatTranspose) {
options = options.setVectorTransformsOptions(
VectorTransformsOptions().setVectorTransposeLowering(
VectorTransposeLowering::Flat));
}
if (lowerToShuffleTranspose) {
options = options.setVectorTransformsOptions(
VectorTransformsOptions().setVectorTransposeLowering(
VectorTransposeLowering::Shuffle));
}
if (lowerToAvx2) {
options = options.enableAVX2Lowering().setAVX2LoweringOptions(
x86vector::avx2::LoweringOptions().setTransposeOptions(
x86vector::avx2::TransposeLoweringOptions()
.lower4x8xf32()
.lower8x8xf32()));
}
OpPassManager dynamicPM("builtin.func");
dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
if (failed(runPipeline(dynamicPM, getFunction())))
return signalPassFailure();
}
};
struct TestVectorUnrollingPatterns
: public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
StringRef getArgument() const final {
return "test-vector-unrolling-patterns";
}
StringRef getDescription() const final {
return "Test conversion patterns to unroll contract ops in the vector "
return "Test lowering patterns to unroll contract ops in the vector "
"dialect";
}
TestVectorUnrollingPatterns() = default;
@@ -248,7 +303,7 @@ struct TestVectorDistributePatterns
return "test-vector-distribute-patterns";
}
StringRef getDescription() const final {
return "Test conversion patterns to distribute vector ops in the vector "
return "Test lowering patterns to distribute vector ops in the vector "
"dialect";
}
TestVectorDistributePatterns() = default;
@@ -302,7 +357,7 @@ struct TestVectorToLoopPatterns
: public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
StringRef getArgument() const final { return "test-vector-to-forloop"; }
StringRef getDescription() const final {
return "Test conversion patterns to break up a vector op into a for loop";
return "Test lowering patterns to break up a vector op into a for loop";
}
TestVectorToLoopPatterns() = default;
TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {}
@@ -365,7 +420,7 @@ struct TestVectorTransferUnrollingPatterns
return "test-vector-transfer-unrolling-patterns";
}
StringRef getDescription() const final {
return "Test conversion patterns to unroll transfer ops in the vector "
return "Test lowering patterns to unroll transfer ops in the vector "
"dialect";
}
void runOnFunction() override {
@@ -391,7 +446,7 @@ struct TestVectorTransferFullPartialSplitPatterns
return "test-vector-transfer-full-partial-split";
}
StringRef getDescription() const final {
return "Test conversion patterns to split "
return "Test lowering patterns to split "
"transfer ops via scf.if + linalg ops";
}
TestVectorTransferFullPartialSplitPatterns() = default;
@@ -439,7 +494,7 @@ struct TestVectorTransferLoweringPatterns
return "test-vector-transfer-lowering-patterns";
}
StringRef getDescription() const final {
return "Test conversion patterns to lower transfer ops to other vector ops";
return "Test lowering patterns to lower transfer ops to other vector ops";
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
@@ -462,7 +517,7 @@ struct TestVectorMultiReductionLoweringPatterns
return "test-vector-multi-reduction-lowering-patterns";
}
StringRef getDescription() const final {
return "Test conversion patterns to lower vector.multi_reduction to other "
return "Test lowering patterns to lower vector.multi_reduction to other "
"vector ops";
}
Option<bool> useOuterReductions{
@@ -495,7 +550,7 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
}
StringRef getDescription() const final {
return "Test conversion patterns that reducedes the rank of the vector "
return "Test lowering patterns that reducedes the rank of the vector "
"transfer memory and vector operands.";
}
@@ -527,10 +582,12 @@ struct TestVectorReduceToContractPatternsPatterns
namespace mlir {
namespace test {
void registerTestVectorConversions() {
PassRegistration<TestVectorToVectorConversion>();
void registerTestVectorLowerings() {
PassRegistration<TestVectorToVectorLowering>();
PassRegistration<TestVectorContractionConversion>();
PassRegistration<TestVectorContractionLowering>();
PassRegistration<TestVectorTransposeLowering>();
PassRegistration<TestVectorUnrollingPatterns>();

View File

@@ -107,7 +107,7 @@ void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
void registerTestSCFUtilsPass();
void registerTestSliceAnalysisPass();
void registerTestVectorConversions();
void registerTestVectorLowerings();
} // namespace test
} // namespace mlir
@@ -197,7 +197,7 @@ void registerTestPasses() {
mlir::test::registerTestRecursiveTypesPass();
mlir::test::registerTestSCFUtilsPass();
mlir::test::registerTestSliceAnalysisPass();
mlir::test::registerTestVectorConversions();
mlir::test::registerTestVectorLowerings();
}
#endif

View File

@@ -1458,6 +1458,7 @@ cc_library(
":LLVMCommonConversion",
":LLVMDialect",
":StandardOps",
":VectorOps",
":X86Vector",
"//llvm:Core",
"//llvm:Support",
@@ -6401,6 +6402,7 @@ cc_library(
":TransformUtils",
":VectorOps",
":VectorToSCF",
":X86VectorTransforms",
"//llvm:Support",
],
)

View File

@@ -484,6 +484,7 @@ cc_library(
"//mlir:Affine",
"//mlir:Analysis",
"//mlir:LinalgOps",
"//mlir:LinalgTransforms",
"//mlir:MemRefDialect",
"//mlir:Pass",
"//mlir:SCFDialect",