forked from OSchip/llvm-project
Apply all necessary tilings and unrollings to get a micro-kernel
This is the first patch to apply the BLIS matmul optimization pattern on matmul kernels (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf). BLIS implements gemm as three nested loops around a macro-kernel, plus two packing routines. The macro-kernel is implemented in terms of two additional loops around a micro-kernel. The micro-kernel is a loop around a rank-1 (i.e., outer product) update. In this change we create the BLIS micro-kernel by applying a combination of tiling and unrolling. In subsequent changes we will add the extraction of the BLIS macro-kernel and implement the packing transformation. Contributed-by: Roman Gareev <gareevroman@gmail.com> Reviewed-by: Tobias Grosser <tobias@grosser.es> Differential Revision: http://reviews.llvm.org/D21140 llvm-svn: 273397
This commit is contained in:
parent
50b80359c0
commit
42402c9e89
|
@ -13,6 +13,7 @@
|
|||
#define POLLY_SCHEDULE_OPTIMIZER_H
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/Analysis/TargetTransformInfo.h"
|
||||
#include "isl/ctx.h"
|
||||
|
||||
struct isl_schedule;
|
||||
|
@ -37,9 +38,11 @@ public:
|
|||
///
|
||||
/// @param Schedule The schedule object the transformations will be applied
|
||||
/// to.
|
||||
/// @param TTI Target Transform Info.
|
||||
/// @returns The transformed schedule.
|
||||
static __isl_give isl_schedule *
|
||||
optimizeSchedule(__isl_take isl_schedule *Schedule);
|
||||
optimizeSchedule(__isl_take isl_schedule *Schedule,
|
||||
const llvm::TargetTransformInfo *TTI = nullptr);
|
||||
|
||||
/// @brief Apply schedule tree transformations.
|
||||
///
|
||||
|
@ -51,9 +54,11 @@ public:
|
|||
/// - Prevectorization
|
||||
///
|
||||
/// @param Node The schedule object post-transformations will be applied to.
|
||||
/// @param TTI Target Transform Info.
|
||||
/// @returns The transformed schedule.
|
||||
static __isl_give isl_schedule_node *
|
||||
optimizeScheduleNode(__isl_take isl_schedule_node *Node);
|
||||
optimizeScheduleNode(__isl_take isl_schedule_node *Node,
|
||||
const llvm::TargetTransformInfo *TTI = nullptr);
|
||||
|
||||
/// @brief Decide if the @p NewSchedule is profitable for @p S.
|
||||
///
|
||||
|
@ -100,6 +105,32 @@ private:
|
|||
applyRegisterTiling(__isl_take isl_schedule_node *Node,
|
||||
llvm::ArrayRef<int> TileSizes, int DefaultTileSize);
|
||||
|
||||
/// @brief Apply the BLIS matmul optimization pattern
|
||||
///
|
||||
/// Apply the BLIS matmul optimization pattern
|
||||
/// (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
|
||||
/// BLIS implements gemm as three nested loops around a macro-kernel,
|
||||
/// plus two packing routines. The macro-kernel is implemented in terms
|
||||
/// of two additional loops around a micro-kernel. The micro-kernel
|
||||
/// is a loop around a rank-1 (i.e., outer product) update.
|
||||
///
|
||||
/// We create the BLIS micro-kernel by applying a combination of tiling
|
||||
/// and unrolling. In subsequent changes we will add the extraction
|
||||
/// of the BLIS macro-kernel and implement the packing transformation.
|
||||
///
|
||||
/// It is assumed that the Node is successfully checked
|
||||
/// by ScheduleTreeOptimizer::isMatrMultPattern. Consequently
|
||||
/// in case of matmul kernels the application of optimizeMatMulPattern
|
||||
/// can lead to close-to-peak performance. Maybe it can be generalized
|
||||
/// to effectively optimize the whole class of successfully checked
|
||||
/// statements.
|
||||
///
|
||||
/// @param Node the node that contains a band to be optimized.
|
||||
/// @return Modified isl_schedule_node.
|
||||
static __isl_give isl_schedule_node *
|
||||
optimizeMatMulPattern(__isl_take isl_schedule_node *Node,
|
||||
const llvm::TargetTransformInfo *TTI);
|
||||
|
||||
/// @brief Check if this node is a band node we want to tile.
|
||||
///
|
||||
/// We look for innermost band nodes where individual dimensions are marked as
|
||||
|
|
|
@ -53,6 +53,7 @@
|
|||
#include "polly/Options.h"
|
||||
#include "polly/ScopInfo.h"
|
||||
#include "polly/Support/GICHelper.h"
|
||||
#include "llvm/Analysis/TargetTransformInfo.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "isl/aff.h"
|
||||
#include "isl/band.h"
|
||||
|
@ -119,6 +120,20 @@ static cl::opt<bool> FirstLevelTiling("polly-tiling",
|
|||
cl::init(true), cl::ZeroOrMore,
|
||||
cl::cat(PollyCategory));
|
||||
|
||||
static cl::opt<int> LatencyVectorFma(
|
||||
"polly-target-latency-vector-fma",
|
||||
cl::desc("The minimal number of cycles between issuing two "
|
||||
"dependent consecutive vector fused multiply-add "
|
||||
"instructions."),
|
||||
cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));
|
||||
|
||||
static cl::opt<int> ThrougputVectorFma(
|
||||
"polly-target-througput-vector-fma",
|
||||
cl::desc("A throughput of the processor floating-point arithmetic units "
|
||||
"expressed in the number of vector fused multiply-add "
|
||||
"instructions per clock cycle."),
|
||||
cl::Hidden, cl::init(1), cl::ZeroOrMore, cl::cat(PollyCategory));
|
||||
|
||||
static cl::opt<int> FirstLevelDefaultTileSize(
|
||||
"polly-default-tile-size",
|
||||
cl::desc("The default tile size (if not enough were provided by"
|
||||
|
@ -478,6 +493,23 @@ static __isl_give isl_map *circularShiftOutputDims(__isl_take isl_map *IslMap) {
|
|||
return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
|
||||
}
|
||||
|
||||
__isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern(
|
||||
__isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI) {
|
||||
assert(TTI && "The target transform info should be provided.");
|
||||
// Get a micro-kernel.
|
||||
// Nvec - Number of double-precision floating-point numbers that can be hold
|
||||
// by a vector register. Use 2 by default.
|
||||
auto Nvec = TTI->getRegisterBitWidth(true) / 64;
|
||||
if (Nvec == 0)
|
||||
Nvec = 2;
|
||||
int Nr =
|
||||
ceil(sqrt(Nvec * LatencyVectorFma * ThrougputVectorFma) / Nvec) * Nvec;
|
||||
int Mr = ceil(Nvec * LatencyVectorFma * ThrougputVectorFma / Nr);
|
||||
std::vector<int> MicroKernelParams{Mr, Nr};
|
||||
Node = applyRegisterTiling(Node, MicroKernelParams, 1);
|
||||
return Node;
|
||||
}
|
||||
|
||||
bool ScheduleTreeOptimizer::isMatrMultPattern(
|
||||
__isl_keep isl_schedule_node *Node) {
|
||||
auto *PartialSchedule =
|
||||
|
@ -508,16 +540,21 @@ ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
|
|||
if (!isTileableBandNode(Node))
|
||||
return Node;
|
||||
|
||||
if (PMBasedOpts && isMatrMultPattern(Node))
|
||||
if (PMBasedOpts && User && isMatrMultPattern(Node)) {
|
||||
DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
|
||||
const llvm::TargetTransformInfo *TTI;
|
||||
TTI = static_cast<const llvm::TargetTransformInfo *>(User);
|
||||
Node = optimizeMatMulPattern(Node, TTI);
|
||||
}
|
||||
|
||||
return standardBandOpts(Node, User);
|
||||
}
|
||||
|
||||
__isl_give isl_schedule *
|
||||
ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
|
||||
ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule,
|
||||
const llvm::TargetTransformInfo *TTI) {
|
||||
isl_schedule_node *Root = isl_schedule_get_root(Schedule);
|
||||
Root = optimizeScheduleNode(Root);
|
||||
Root = optimizeScheduleNode(Root, TTI);
|
||||
isl_schedule_free(Schedule);
|
||||
auto S = isl_schedule_node_get_schedule(Root);
|
||||
isl_schedule_node_free(Root);
|
||||
|
@ -525,8 +562,9 @@ ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
|
|||
}
|
||||
|
||||
__isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeScheduleNode(
|
||||
__isl_take isl_schedule_node *Node) {
|
||||
Node = isl_schedule_node_map_descendant_bottom_up(Node, optimizeBand, NULL);
|
||||
__isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI) {
|
||||
Node = isl_schedule_node_map_descendant_bottom_up(
|
||||
Node, optimizeBand, const_cast<void *>(static_cast<const void *>(TTI)));
|
||||
return Node;
|
||||
}
|
||||
|
||||
|
@ -714,7 +752,10 @@ bool IslScheduleOptimizer::runOnScop(Scop &S) {
|
|||
isl_printer_free(P);
|
||||
});
|
||||
|
||||
isl_schedule *NewSchedule = ScheduleTreeOptimizer::optimizeSchedule(Schedule);
|
||||
Function &F = S.getFunction();
|
||||
auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
|
||||
isl_schedule *NewSchedule =
|
||||
ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI);
|
||||
isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
|
||||
|
||||
if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) {
|
||||
|
@ -752,6 +793,7 @@ void IslScheduleOptimizer::printScop(raw_ostream &OS, Scop &) const {
|
|||
void IslScheduleOptimizer::getAnalysisUsage(AnalysisUsage &AU) const {
|
||||
ScopPass::getAnalysisUsage(AU);
|
||||
AU.addRequired<DependenceInfo>();
|
||||
AU.addRequired<TargetTransformInfoWrapperPass>();
|
||||
}
|
||||
|
||||
Pass *polly::createIslScheduleOptimizerPass() {
|
||||
|
@ -762,5 +804,6 @@ INITIALIZE_PASS_BEGIN(IslScheduleOptimizer, "polly-opt-isl",
|
|||
"Polly - Optimize schedule of SCoP", false, false);
|
||||
INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
|
||||
INITIALIZE_PASS_DEPENDENCY(ScopInfoRegionPass);
|
||||
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass);
|
||||
INITIALIZE_PASS_END(IslScheduleOptimizer, "polly-opt-isl",
|
||||
"Polly - Optimize schedule of SCoP", false, false)
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -polly-target-througput-vector-fma=1 -polly-target-latency-vector-fma=8 -analyze -polly-ast < %s 2>&1 | FileCheck %s
|
||||
;
|
||||
; /* C := alpha*A*B + beta*C */
|
||||
; for (i = 0; i < _PB_NI; i++)
|
||||
; for (j = 0; j < _PB_NJ; j++)
|
||||
; {
|
||||
; C[i][j] *= beta;
|
||||
; for (k = 0; k < _PB_NK; ++k)
|
||||
; C[i][j] += alpha * A[i][k] * B[k][j];
|
||||
; }
|
||||
;
|
||||
; CHECK: {
|
||||
; CHECK: // 1st level tiling - Tiles
|
||||
; CHECK: for (int c0 = 0; c0 <= 32; c0 += 1)
|
||||
; CHECK: for (int c1 = 0; c1 <= 32; c1 += 1) {
|
||||
; CHECK: // 1st level tiling - Points
|
||||
; CHECK: for (int c2 = 0; c2 <= 31; c2 += 1)
|
||||
; CHECK: for (int c3 = 0; c3 <= 31; c3 += 1)
|
||||
; CHECK: Stmt_bb14(32 * c0 + c2, 32 * c1 + c3);
|
||||
; CHECK: }
|
||||
; CHECK: // Register tiling - Tiles
|
||||
; CHECK: for (int c0 = 0; c0 <= 263; c0 += 1)
|
||||
; CHECK: for (int c1 = 0; c1 <= 131; c1 += 1)
|
||||
; CHECK: for (int c2 = 0; c2 <= 1023; c2 += 1) {
|
||||
; CHECK: // Register tiling - Points
|
||||
; CHECK: // 1st level tiling - Tiles
|
||||
; CHECK: // 1st level tiling - Points
|
||||
; CHECK: {
|
||||
; CHECK: Stmt_bb24(4 * c0, 8 * c1, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0, 8 * c1 + 1, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0, 8 * c1 + 2, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0, 8 * c1 + 3, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0, 8 * c1 + 4, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0, 8 * c1 + 5, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0, 8 * c1 + 6, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0, 8 * c1 + 7, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 1, 8 * c1, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 1, 8 * c1 + 1, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 1, 8 * c1 + 2, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 1, 8 * c1 + 3, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 1, 8 * c1 + 4, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 1, 8 * c1 + 5, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 1, 8 * c1 + 6, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 1, 8 * c1 + 7, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 2, 8 * c1, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 2, 8 * c1 + 1, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 2, 8 * c1 + 2, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 2, 8 * c1 + 3, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 2, 8 * c1 + 4, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 2, 8 * c1 + 5, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 2, 8 * c1 + 6, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 2, 8 * c1 + 7, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 3, 8 * c1, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 3, 8 * c1 + 1, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 3, 8 * c1 + 2, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 3, 8 * c1 + 3, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 3, 8 * c1 + 4, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 3, 8 * c1 + 5, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 3, 8 * c1 + 6, c2);
|
||||
; CHECK: Stmt_bb24(4 * c0 + 3, 8 * c1 + 7, c2);
|
||||
; CHECK: }
|
||||
; CHECK: }
|
||||
; CHECK: }
|
||||
;
|
||||
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
|
||||
target triple = "x86_64-unknown-unknown"
|
||||
|
||||
define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1024 x double]* %arg6, [1056 x double]* %arg7) #0 {
|
||||
bb:
|
||||
br label %bb8
|
||||
|
||||
bb8: ; preds = %bb39, %bb
|
||||
%tmp = phi i32 [ 0, %bb ], [ %tmp40, %bb39 ]
|
||||
%tmp9 = icmp slt i32 %tmp, 1056
|
||||
br i1 %tmp9, label %bb10, label %bb41
|
||||
|
||||
bb10: ; preds = %bb8
|
||||
br label %bb11
|
||||
|
||||
bb11: ; preds = %bb37, %bb10
|
||||
%tmp12 = phi i32 [ 0, %bb10 ], [ %tmp38, %bb37 ]
|
||||
%tmp13 = icmp slt i32 %tmp12, 1056
|
||||
br i1 %tmp13, label %bb14, label %bb39
|
||||
|
||||
bb14: ; preds = %bb11
|
||||
%tmp15 = sext i32 %tmp12 to i64
|
||||
%tmp16 = sext i32 %tmp to i64
|
||||
%tmp17 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp16
|
||||
%tmp18 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp17, i64 0, i64 %tmp15
|
||||
%tmp19 = load double, double* %tmp18, align 8
|
||||
%tmp20 = fmul double %tmp19, %arg4
|
||||
store double %tmp20, double* %tmp18, align 8
|
||||
br label %bb21
|
||||
|
||||
bb21: ; preds = %bb24, %bb14
|
||||
%tmp22 = phi i32 [ 0, %bb14 ], [ %tmp36, %bb24 ]
|
||||
%tmp23 = icmp slt i32 %tmp22, 1024
|
||||
br i1 %tmp23, label %bb24, label %bb37
|
||||
|
||||
bb24: ; preds = %bb21
|
||||
%tmp25 = sext i32 %tmp22 to i64
|
||||
%tmp26 = getelementptr inbounds [1024 x double], [1024 x double]* %arg6, i64 %tmp16
|
||||
%tmp27 = getelementptr inbounds [1024 x double], [1024 x double]* %tmp26, i64 0, i64 %tmp25
|
||||
%tmp28 = load double, double* %tmp27, align 8
|
||||
%tmp29 = fmul double %arg3, %tmp28
|
||||
%tmp30 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp25
|
||||
%tmp31 = getelementptr inbounds [1056 x double], [1056 x double]* %tmp30, i64 0, i64 %tmp15
|
||||
%tmp32 = load double, double* %tmp31, align 8
|
||||
%tmp33 = fmul double %tmp29, %tmp32
|
||||
%tmp34 = load double, double* %tmp18, align 8
|
||||
%tmp35 = fadd double %tmp34, %tmp33
|
||||
store double %tmp35, double* %tmp18, align 8
|
||||
%tmp36 = add nsw i32 %tmp22, 1
|
||||
br label %bb21
|
||||
|
||||
bb37: ; preds = %bb21
|
||||
%tmp38 = add nsw i32 %tmp12, 1
|
||||
br label %bb11
|
||||
|
||||
bb39: ; preds = %bb11
|
||||
%tmp40 = add nsw i32 %tmp, 1
|
||||
br label %bb8
|
||||
|
||||
bb41: ; preds = %bb8
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { nounwind uwtable "target-cpu"="x86-64" "target-features"="+aes,+avx,+cmov,+cx16,+fxsr,+mmx,+pclmul,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave,+xsaveopt" }
|
Loading…
Reference in New Issue