From 6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 8 Apr 2020 14:53:37 -0400 Subject: [PATCH] [mlir][Linalg] Add a test for a fused Linalg pass based on DRR to go from matmul to vectors This revision builds a simple "fused pass" consisting of 2 levels of tiling, memory promotion and vectorization using linalg transformations written as composable pattern rewrites. --- .../test/Dialect/Linalg/matmul-to-vector.mlir | 16 ++++++ .../lib/DeclarativeTransforms/CMakeLists.txt | 4 ++ .../TestLinalgMatmulToVectorPatterns.td | 43 ++++++++++++++++ mlir/test/lib/Transforms/CMakeLists.txt | 2 + .../Transforms/TestLinalgMatmulToVector.cpp | 51 +++++++++++++++++++ mlir/tools/mlir-opt/mlir-opt.cpp | 2 + 6 files changed, 118 insertions(+) create mode 100644 mlir/test/Dialect/Linalg/matmul-to-vector.mlir create mode 100644 mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td create mode 100644 mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp diff --git a/mlir/test/Dialect/Linalg/matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir new file mode 100644 index 000000000000..351b2041d8c0 --- /dev/null +++ b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s -linalg-matmul-to-vector | FileCheck %s + +func @matmul_perm(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) { + linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} : + memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]> + return +} + +// CHECK-LABEL:func @matmul_perm +// CHECK: vector.contract +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32> diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt index 9672edb4c493..f06854289abb 100644 --- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt +++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt @@ -5,3 +5,7 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen) set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td) mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters) add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen) + +set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td) +mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters) +add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen) diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td new file mode 100644 index 000000000000..7fa4a3db6128 --- /dev/null +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td @@ -0,0 +1,43 @@ +//===- TestLinalgMatmulToVectorPatterns.td - Test patterns -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the pattern definition file for declarative Linalg transformations +// tests. +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS +#define TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS + +include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td" +include "mlir/Dialect/Vector/VectorTransformPatterns.td" + +//===----------------------------------------------------------------------===// +// Linalg tiling and permutation patterns. +//===----------------------------------------------------------------------===// +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[768, 264, 768], "L2__with_perm__", [1, 2, 0]>), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[8, 12, 16], "L1__with_perm__", [1, 0, 2]>), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (PromoteSubviewsLinalgOp), + [(Constraint>), + (Constraint>)]>; + +//===----------------------------------------------------------------------===// +// Linalg to vector contraction patterns. +//===----------------------------------------------------------------------===// +def : Pattern<(MatmulOp:$op $_, $_, $_), + [(VectorizeLinalgOp)], + [(Constraint, + PreconditionVectorizeLinalgOp]>>)]>; + +#endif // TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 904a47221ac1..23107f223b9c 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_llvm_library(MLIRTestTransforms TestGpuMemoryPromotion.cpp TestGpuParallelLoopMapping.cpp TestInlining.cpp + TestLinalgMatmulToVector.cpp TestLinalgTransforms.cpp TestLiveness.cpp TestLoopMapping.cpp @@ -24,6 +25,7 @@ add_llvm_library(MLIRTestTransforms DEPENDS MLIRStandardOpsIncGen + MLIRTestLinalgMatmulToVectorPatternsIncGen MLIRTestLinalgTransformPatternsIncGen MLIRTestVectorTransformPatternsIncGen ) diff --git a/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp new file mode 100644 index 000000000000..6f49fabc192a --- /dev/null +++ b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp @@ -0,0 +1,51 @@ +//===- TestLinalgMatmulToVector.cpp - Test VectorTransfers lowering -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::linalg; +using namespace mlir::vector; + +namespace { +#include "TestLinalgMatmulToVectorPatterns.h.inc" + +struct DeclarativeTransforms + : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + auto *context = &getContext(); + AffineApplyOp::getCanonicalizationPatterns(patterns, context); + AffineMinOp::getCanonicalizationPatterns(patterns, context); + AffineMaxOp::getCanonicalizationPatterns(patterns, context); + AllocOp::getCanonicalizationPatterns(patterns, context); + SubViewOp::getCanonicalizationPatterns(patterns, context); + ViewOp::getCanonicalizationPatterns(patterns, context); + populateWithGenerated(context, &patterns); + applyPatternsGreedily(getFunction(), patterns); + } +}; +} // end anonymous namespace + +namespace mlir { +void registerTestLinalgMatmulToVectorPass() { + PassRegistration pass( + "linalg-matmul-to-vector", + "Test declarative transform patterns for matmul 3-D tiling + promotion" + " + vectorization"); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index e8b2f3dc49f5..50a929616f27 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -39,6 +39,7 @@ void registerSimpleParametricTilingPass(); void registerSymbolTestPasses(); void registerTestAffineDataCopyPass(); void registerTestAllReduceLoweringPass(); +void registerTestLinalgMatmulToVectorPass(); void registerTestLoopPermutationPass(); void registerTestCallGraphPass(); void registerTestConstantFold(); @@ -101,6 +102,7 @@ void registerTestPasses() { registerSymbolTestPasses(); registerTestAffineDataCopyPass(); registerTestAllReduceLoweringPass(); + registerTestLinalgMatmulToVectorPass(); registerTestLoopPermutationPass(); registerTestCallGraphPass(); registerTestConstantFold();