From ad38e4980642a2d9b0add2923454212eac3cd94f Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 10 Dec 2019 11:54:00 -0800 Subject: [PATCH] Uniformize Vector transforms as patterns on the model of Linalg - NFC This reorganizes the vector transformations to be more easily testable as patterns and more easily composable into fused passes in the future. PiperOrigin-RevId: 284817474 --- .../VectorOps/VectorTransformPatterns.td | 25 +++-------- mlir/include/mlir/IR/StandardTypes.h | 3 ++ mlir/lib/Dialect/VectorOps/CMakeLists.txt | 2 +- ...ectorToVector.cpp => VectorTransforms.cpp} | 23 ---------- mlir/lib/IR/StandardTypes.cpp | 4 ++ .../VectorOps/vector-transforms.mlir} | 0 .../lib/DeclarativeTransforms/CMakeLists.txt | 5 +++ .../TestVectorTransformPatterns.td | 43 +++++++++++++++++++ mlir/test/lib/Transforms/CMakeLists.txt | 3 +- ...onversion.cpp => TestVectorTransforms.cpp} | 12 +++--- 10 files changed, 71 insertions(+), 49 deletions(-) rename mlir/lib/Dialect/VectorOps/{VectorToVector.cpp => VectorTransforms.cpp} (96%) rename mlir/test/{Conversion/VectorConversions/vector-to-vector.mlir => Dialect/VectorOps/vector-transforms.mlir} (100%) create mode 100644 mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td rename mlir/test/lib/Transforms/{TestVectorToVectorConversion.cpp => TestVectorTransforms.cpp} (82%) diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td b/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td index e71679620d6d..86ff9b505d54 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td @@ -19,30 +19,17 @@ // //===----------------------------------------------------------------------===// -#ifndef VECTOR_TRANSFORMS -#define VECTOR_TRANSFORMS +#ifndef VECTOR_TRANSFORM_PATTERNS +#define VECTOR_TRANSFORM_PATTERNS -include "mlir/Dialect/StandardOps/Ops.td" -include "mlir/Dialect/VectorOps/VectorOps.td" +include "mlir/IR/OpBase.td" class HasShape shape> : - CPred<"hasShape($0, {" # StrJoinInt.result # "})">; + CPred<"$0->getType().cast().hasStaticShape({" # + StrJoinInt.result # "})">; class UnrollVectorOp factors> : NativeCodeCall< "unrollSingleResultOpMatchingType($_builder, $0->getDefiningOp(), " # "{" # StrJoinInt.result # "})">; -def : Pat<(AddFOp:$op_results $a, $b), - (UnrollVectorOp<[2, 2]> $op_results, $a, $b), - [(Constraint> $a)]>; - -def : Pat<(AddFOp:$op_results $a, $b), - (UnrollVectorOp<[2, 2]> $op_results, $a, $b), - [(Constraint> $a)]>; - -// TODO(andydavis) Add Constraints on lhs/rhs shapes. -def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1), - (UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c), - [(Constraint> $c)]>; - -#endif // VECTOR_TRANSFORMS +#endif // VECTOR_TRANSFORM_PATTERNS diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h index 23a1ff2177ea..5634f86254fb 100644 --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -220,6 +220,9 @@ public: /// has static shape. bool hasStaticShape() const; + /// If this has a static shape and the shape is equal to `shape` return true. + bool hasStaticShape(ArrayRef shape) const; + /// If this is a ranked type, return the number of dimensions with dynamic /// size. Otherwise, abort. int64_t getNumDynamicDims() const; diff --git a/mlir/lib/Dialect/VectorOps/CMakeLists.txt b/mlir/lib/Dialect/VectorOps/CMakeLists.txt index 754e62de14ec..08d58404b718 100644 --- a/mlir/lib/Dialect/VectorOps/CMakeLists.txt +++ b/mlir/lib/Dialect/VectorOps/CMakeLists.txt @@ -1,7 +1,7 @@ add_llvm_library(MLIRVectorOps DialectRegistration.cpp VectorOps.cpp - VectorToVector.cpp + VectorTransforms.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps diff --git a/mlir/lib/Dialect/VectorOps/VectorToVector.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp similarity index 96% rename from mlir/lib/Dialect/VectorOps/VectorToVector.cpp rename to mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 1acac63602c8..6b13bcf75ca1 100644 --- a/mlir/lib/Dialect/VectorOps/VectorToVector.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -113,14 +113,6 @@ static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder, return builder.createOperation(res); } -// Helper function for Tablegen. -static bool hasShape(Value *v, ArrayRef shape) { - auto t = v->getType().dyn_cast(); - if (!t) - return false; - return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin()); -} - static Value *makeSplatZero(Location loc, PatternRewriter &rewriter, VectorType vt) { auto t = vt.getElementType(); @@ -454,18 +446,3 @@ Value *mlir::vector::unrollSingleResultOpMatchingType( return unrollSingleResultStructuredOp(op, iterationBounds, vectors, resultIndex, targetShape, builder); } - -namespace mlir { -namespace vector { -namespace { -#include "mlir/Dialect/VectorOps/VectorTransformPatterns.h.inc" -} // end namespace -} // end namespace vector -} // end namespace mlir - -void mlir::populateVectorToVectorConversionPatterns( - MLIRContext *context, OwningRewritePatternList &patterns, - ArrayRef coarseVectorShape, ArrayRef fineVectorShape) { - vector::populateWithGenerated(context, &patterns); - vector::populateVectorToVectorCanonicalizationPatterns(patterns, context); -} diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 7054f6d5ca85..8a47c5b0b41c 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -197,6 +197,10 @@ bool ShapedType::hasStaticShape() const { return hasRank() && llvm::none_of(getShape(), isDynamic); } +bool ShapedType::hasStaticShape(ArrayRef shape) const { + return hasStaticShape() && getShape() == shape; +} + //===----------------------------------------------------------------------===// // VectorType //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorConversions/vector-to-vector.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir similarity index 100% rename from mlir/test/Conversion/VectorConversions/vector-to-vector.mlir rename to mlir/test/Dialect/VectorOps/vector-transforms.mlir diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt index 1ee62d82129f..7cddcd65d02b 100644 --- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt +++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt @@ -5,3 +5,8 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen) set(LLVM_TARGET_DEFINITIONS TestLinalgTilePermutePatterns.td) mlir_tablegen(TestLinalgTilePermutePatterns.h.inc -gen-rewriters) add_public_tablegen_target(MLIRTestLinalgTilePermutePatternsIncGen) + + +set(LLVM_TARGET_DEFINITIONS TestVectorPatterns.td) +mlir_tablegen(TestVectorPatterns.h.inc -gen-rewriters) +add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen) diff --git a/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td new file mode 100644 index 000000000000..228a8a018d6e --- /dev/null +++ b/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td @@ -0,0 +1,43 @@ +//===- TestVectorTransformPatterns.td - Test patterns ---*- tablegen ----*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the pattern definition file for declarative Vector transformations +// tests. +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_VECTOR_TRANSFORMS_PATTERNS +#define TEST_VECTOR_TRANSFORMS_PATTERNS + +include "mlir/Dialect/StandardOps/Ops.td" +include "mlir/Dialect/VectorOps/VectorOps.td" +include "mlir/Dialect/VectorOps/VectorTransformPatterns.td" + +def : Pat<(AddFOp:$op_results $a, $b), + (UnrollVectorOp<[2, 2]> $op_results, $a, $b), + [(Constraint> $a)]>; + +def : Pat<(AddFOp:$op_results $a, $b), + (UnrollVectorOp<[2, 2]> $op_results, $a, $b), + [(Constraint> $a)]>; + +// TODO(andydavis) Add Constraints on lhs/rhs shapes. +def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1), + (UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c), + [(Constraint> $c)]>; + +#endif // TEST_VECTOR_TRANSFORMS_PATTERNS diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 8a7933451b88..11d27483dc6c 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -10,7 +10,7 @@ add_llvm_library(MLIRTestTransforms TestOpaqueLoc.cpp TestMemRefStrideCalculation.cpp TestVectorToLoopsConversion.cpp - TestVectorToVectorConversion.cpp + TestVectorTransforms.cpp TestVectorizationUtils.cpp ADDITIONAL_HEADER_DIRS @@ -23,6 +23,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/../DeclarativeTransforms) add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen) add_dependencies(MLIRTestTransforms MLIRTestLinalgTransformPatternsIncGen) add_dependencies(MLIRTestTransforms MLIRTestLinalgTilePermutePatternsIncGen) +add_dependencies(MLIRTestTransforms MLIRTestVectorTransformPatternsIncGen) target_link_libraries(MLIRTestTransforms MLIRAffineOps MLIRAnalysis diff --git a/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp similarity index 82% rename from mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp rename to mlir/test/lib/Transforms/TestVectorTransforms.cpp index 9f9b8a554fe2..909fe2afba68 100644 --- a/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -1,5 +1,4 @@ -//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering -//-------===// +//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// // // Copyright 2019 The MLIR Authors. // @@ -18,25 +17,28 @@ #include +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/Dialect/VectorOps/VectorTransforms.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/Passes.h" using namespace mlir; +using namespace mlir::vector; namespace { +#include "TestVectorTransformPatterns.h.inc" struct TestVectorToVectorConversion : public FunctionPass { void runOnFunction() override { OwningRewritePatternList patterns; auto *context = &getContext(); - populateVectorToVectorConversionPatterns(context, patterns); + populateWithGenerated(context, &patterns); + populateVectorToVectorCanonicalizationPatterns(patterns, context); applyPatternsGreedily(getFunction(), patterns); } }; - } // end anonymous namespace static PassRegistration