forked from OSchip/llvm-project
[mlir][Linalg] Add a test for a fused Linalg pass based on DRR to go from matmul to vectors
This revision builds a simple "fused pass" consisting of 2 levels of tiling, memory promotion and vectorization using linalg transformations written as composable pattern rewrites.
This commit is contained in:
parent
c6e917d2d3
commit
6fb6a4d7f9
|
@ -0,0 +1,16 @@
|
|||
// RUN: mlir-opt %s -linalg-matmul-to-vector | FileCheck %s
|
||||
|
||||
func @matmul_perm(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||
%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
|
||||
linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} :
|
||||
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
||||
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL:func @matmul_perm
|
||||
// CHECK: vector.contract
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
|
||||
// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
|
|
@ -5,3 +5,7 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
|
|||
set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
|
||||
mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td)
|
||||
mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen)
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
//===- TestLinalgMatmulToVectorPatterns.td - Test patterns -*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the pattern definition file for declarative Linalg transformations
|
||||
// tests.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
|
||||
#define TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
|
||||
|
||||
include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
|
||||
include "mlir/Dialect/Vector/VectorTransformPatterns.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg tiling and permutation patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[768, 264, 768], "L2__with_perm__", [1, 2, 0]>),
|
||||
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(TileLinalgOp<[8, 12, 16], "L1__with_perm__", [1, 0, 2]>),
|
||||
[(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
|
||||
def : Pat<(MatmulOp:$op $_, $_, $_),
|
||||
(PromoteSubviewsLinalgOp),
|
||||
[(Constraint<HasOperandsOfType<"SubViewOp">>),
|
||||
(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to vector contraction patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pattern<(MatmulOp:$op $_, $_, $_),
|
||||
[(VectorizeLinalgOp)],
|
||||
[(Constraint<And<[
|
||||
HasLinalgTransformMarker<"L1__with_perm__">,
|
||||
PreconditionVectorizeLinalgOp]>>)]>;
|
||||
|
||||
#endif // TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
|
|
@ -8,6 +8,7 @@ add_llvm_library(MLIRTestTransforms
|
|||
TestGpuMemoryPromotion.cpp
|
||||
TestGpuParallelLoopMapping.cpp
|
||||
TestInlining.cpp
|
||||
TestLinalgMatmulToVector.cpp
|
||||
TestLinalgTransforms.cpp
|
||||
TestLiveness.cpp
|
||||
TestLoopMapping.cpp
|
||||
|
@ -24,6 +25,7 @@ add_llvm_library(MLIRTestTransforms
|
|||
|
||||
DEPENDS
|
||||
MLIRStandardOpsIncGen
|
||||
MLIRTestLinalgMatmulToVectorPatternsIncGen
|
||||
MLIRTestLinalgTransformPatternsIncGen
|
||||
MLIRTestVectorTransformPatternsIncGen
|
||||
)
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
//===- TestLinalgMatmulToVector.cpp - Test VectorTransfers lowering -------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
using namespace mlir::vector;
|
||||
|
||||
namespace {
|
||||
#include "TestLinalgMatmulToVectorPatterns.h.inc"
|
||||
|
||||
struct DeclarativeTransforms
|
||||
: public PassWrapper<DeclarativeTransforms, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
|
||||
AffineMinOp::getCanonicalizationPatterns(patterns, context);
|
||||
AffineMaxOp::getCanonicalizationPatterns(patterns, context);
|
||||
AllocOp::getCanonicalizationPatterns(patterns, context);
|
||||
SubViewOp::getCanonicalizationPatterns(patterns, context);
|
||||
ViewOp::getCanonicalizationPatterns(patterns, context);
|
||||
populateWithGenerated(context, &patterns);
|
||||
applyPatternsGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
void registerTestLinalgMatmulToVectorPass() {
|
||||
PassRegistration<DeclarativeTransforms> pass(
|
||||
"linalg-matmul-to-vector",
|
||||
"Test declarative transform patterns for matmul 3-D tiling + promotion"
|
||||
" + vectorization");
|
||||
}
|
||||
} // namespace mlir
|
|
@ -39,6 +39,7 @@ void registerSimpleParametricTilingPass();
|
|||
void registerSymbolTestPasses();
|
||||
void registerTestAffineDataCopyPass();
|
||||
void registerTestAllReduceLoweringPass();
|
||||
void registerTestLinalgMatmulToVectorPass();
|
||||
void registerTestLoopPermutationPass();
|
||||
void registerTestCallGraphPass();
|
||||
void registerTestConstantFold();
|
||||
|
@ -101,6 +102,7 @@ void registerTestPasses() {
|
|||
registerSymbolTestPasses();
|
||||
registerTestAffineDataCopyPass();
|
||||
registerTestAllReduceLoweringPass();
|
||||
registerTestLinalgMatmulToVectorPass();
|
||||
registerTestLoopPermutationPass();
|
||||
registerTestCallGraphPass();
|
||||
registerTestConstantFold();
|
||||
|
|
Loading…
Reference in New Issue