[mlir][Linalg] Add a test for a fused Linalg pass based on DRR to go from matmul to vectors

This revision builds a simple "fused pass" consisting of 2 levels of tiling, memory promotion and vectorization using linalg transformations written as composable pattern rewrites.
This commit is contained in:
Nicolas Vasilache 2020-04-08 14:53:37 -04:00
parent c6e917d2d3
commit 6fb6a4d7f9
6 changed files with 118 additions and 0 deletions

View File

@ -0,0 +1,16 @@
// RUN: mlir-opt %s -linalg-matmul-to-vector | FileCheck %s
func @matmul_perm(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} :
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
return
}
// CHECK-LABEL:func @matmul_perm
// CHECK: vector.contract
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>

View File

@ -5,3 +5,7 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)
set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td)
mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen)

View File

@ -0,0 +1,43 @@
//===- TestLinalgMatmulToVectorPatterns.td - Test patterns -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the pattern definition file for declarative Linalg transformations
// tests.
//
//===----------------------------------------------------------------------===//
#ifndef TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
#define TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
include "mlir/Dialect/Vector/VectorTransformPatterns.td"
//===----------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[768, 264, 768], "L2__with_perm__", [1, 2, 0]>),
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[8, 12, 16], "L1__with_perm__", [1, 0, 2]>),
[(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(PromoteSubviewsLinalgOp),
[(Constraint<HasOperandsOfType<"SubViewOp">>),
(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(MatmulOp:$op $_, $_, $_),
[(VectorizeLinalgOp)],
[(Constraint<And<[
HasLinalgTransformMarker<"L1__with_perm__">,
PreconditionVectorizeLinalgOp]>>)]>;
#endif // TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS

View File

@ -8,6 +8,7 @@ add_llvm_library(MLIRTestTransforms
TestGpuMemoryPromotion.cpp
TestGpuParallelLoopMapping.cpp
TestInlining.cpp
TestLinalgMatmulToVector.cpp
TestLinalgTransforms.cpp
TestLiveness.cpp
TestLoopMapping.cpp
@ -24,6 +25,7 @@ add_llvm_library(MLIRTestTransforms
DEPENDS
MLIRStandardOpsIncGen
MLIRTestLinalgMatmulToVectorPatternsIncGen
MLIRTestLinalgTransformPatternsIncGen
MLIRTestVectorTransformPatternsIncGen
)

View File

@ -0,0 +1,51 @@
//===- TestLinalgMatmulToVector.cpp - Test VectorTransfers lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;
namespace {
#include "TestLinalgMatmulToVectorPatterns.h.inc"
struct DeclarativeTransforms
: public PassWrapper<DeclarativeTransforms, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
auto *context = &getContext();
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
AffineMinOp::getCanonicalizationPatterns(patterns, context);
AffineMaxOp::getCanonicalizationPatterns(patterns, context);
AllocOp::getCanonicalizationPatterns(patterns, context);
SubViewOp::getCanonicalizationPatterns(patterns, context);
ViewOp::getCanonicalizationPatterns(patterns, context);
populateWithGenerated(context, &patterns);
applyPatternsGreedily(getFunction(), patterns);
}
};
} // end anonymous namespace
namespace mlir {
void registerTestLinalgMatmulToVectorPass() {
PassRegistration<DeclarativeTransforms> pass(
"linalg-matmul-to-vector",
"Test declarative transform patterns for matmul 3-D tiling + promotion"
" + vectorization");
}
} // namespace mlir

View File

@ -39,6 +39,7 @@ void registerSimpleParametricTilingPass();
void registerSymbolTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAllReduceLoweringPass();
void registerTestLinalgMatmulToVectorPass();
void registerTestLoopPermutationPass();
void registerTestCallGraphPass();
void registerTestConstantFold();
@ -101,6 +102,7 @@ void registerTestPasses() {
registerSymbolTestPasses();
registerTestAffineDataCopyPass();
registerTestAllReduceLoweringPass();
registerTestLinalgMatmulToVectorPass();
registerTestLoopPermutationPass();
registerTestCallGraphPass();
registerTestConstantFold();