Uniformize Vector transforms as patterns on the model of Linalg - NFC

This reorganizes the vector transformations to be more easily testable as patterns and more easily composable into fused passes in the future.

PiperOrigin-RevId: 284817474
This commit is contained in:
Nicolas Vasilache 2019-12-10 11:54:00 -08:00 committed by A. Unique TensorFlower
parent 8ccb350979
commit ad38e49806
10 changed files with 71 additions and 49 deletions

View File

@ -19,30 +19,17 @@
//
//===----------------------------------------------------------------------===//
#ifndef VECTOR_TRANSFORMS
#define VECTOR_TRANSFORMS
#ifndef VECTOR_TRANSFORM_PATTERNS
#define VECTOR_TRANSFORM_PATTERNS
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/VectorOps/VectorOps.td"
include "mlir/IR/OpBase.td"
class HasShape<list<int> shape> :
CPred<"hasShape($0, {" # StrJoinInt<shape>.result # "})">;
CPred<"$0->getType().cast<ShapedType>().hasStaticShape({" #
StrJoinInt<shape>.result # "})">;
class UnrollVectorOp<list<int> factors> : NativeCodeCall<
"unrollSingleResultOpMatchingType($_builder, $0->getDefiningOp(), " #
"{" # StrJoinInt<factors>.result # "})">;
def : Pat<(AddFOp:$op_results $a, $b),
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
[(Constraint<HasShape<[4, 2]>> $a)]>;
def : Pat<(AddFOp:$op_results $a, $b),
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
[(Constraint<HasShape<[4, 4]>> $a)]>;
// TODO(andydavis) Add Constraints on lhs/rhs shapes.
def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
(UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
[(Constraint<HasShape<[4, 4]>> $c)]>;
#endif // VECTOR_TRANSFORMS
#endif // VECTOR_TRANSFORM_PATTERNS

View File

@ -220,6 +220,9 @@ public:
/// has static shape.
bool hasStaticShape() const;
/// If this has a static shape and the shape is equal to `shape` return true.
bool hasStaticShape(ArrayRef<int64_t> shape) const;
/// If this is a ranked type, return the number of dimensions with dynamic
/// size. Otherwise, abort.
int64_t getNumDynamicDims() const;

View File

@ -1,7 +1,7 @@
add_llvm_library(MLIRVectorOps
DialectRegistration.cpp
VectorOps.cpp
VectorToVector.cpp
VectorTransforms.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps

View File

@ -113,14 +113,6 @@ static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
return builder.createOperation(res);
}
// Helper function for Tablegen.
static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
auto t = v->getType().dyn_cast<ShapedType>();
if (!t)
return false;
return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
}
static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
VectorType vt) {
auto t = vt.getElementType();
@ -454,18 +446,3 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
resultIndex, targetShape, builder);
}
namespace mlir {
namespace vector {
namespace {
#include "mlir/Dialect/VectorOps/VectorTransformPatterns.h.inc"
} // end namespace
} // end namespace vector
} // end namespace mlir
void mlir::populateVectorToVectorConversionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
vector::populateWithGenerated(context, &patterns);
vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
}

View File

@ -197,6 +197,10 @@ bool ShapedType::hasStaticShape() const {
return hasRank() && llvm::none_of(getShape(), isDynamic);
}
bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
return hasStaticShape() && getShape() == shape;
}
//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

View File

@ -5,3 +5,8 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
set(LLVM_TARGET_DEFINITIONS TestLinalgTilePermutePatterns.td)
mlir_tablegen(TestLinalgTilePermutePatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestLinalgTilePermutePatternsIncGen)
set(LLVM_TARGET_DEFINITIONS TestVectorPatterns.td)
mlir_tablegen(TestVectorPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)

View File

@ -0,0 +1,43 @@
//===- TestVectorTransformPatterns.td - Test patterns ---*- tablegen ----*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is the pattern definition file for declarative Vector transformations
// tests.
//
//===----------------------------------------------------------------------===//
#ifndef TEST_VECTOR_TRANSFORMS_PATTERNS
#define TEST_VECTOR_TRANSFORMS_PATTERNS
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/VectorOps/VectorOps.td"
include "mlir/Dialect/VectorOps/VectorTransformPatterns.td"
def : Pat<(AddFOp:$op_results $a, $b),
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
[(Constraint<HasShape<[4, 2]>> $a)]>;
def : Pat<(AddFOp:$op_results $a, $b),
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
[(Constraint<HasShape<[4, 4]>> $a)]>;
// TODO(andydavis) Add Constraints on lhs/rhs shapes.
def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
(UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
[(Constraint<HasShape<[4, 4]>> $c)]>;
#endif // TEST_VECTOR_TRANSFORMS_PATTERNS

View File

@ -10,7 +10,7 @@ add_llvm_library(MLIRTestTransforms
TestOpaqueLoc.cpp
TestMemRefStrideCalculation.cpp
TestVectorToLoopsConversion.cpp
TestVectorToVectorConversion.cpp
TestVectorTransforms.cpp
TestVectorizationUtils.cpp
ADDITIONAL_HEADER_DIRS
@ -23,6 +23,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/../DeclarativeTransforms)
add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen)
add_dependencies(MLIRTestTransforms MLIRTestLinalgTransformPatternsIncGen)
add_dependencies(MLIRTestTransforms MLIRTestLinalgTilePermutePatternsIncGen)
add_dependencies(MLIRTestTransforms MLIRTestVectorTransformPatternsIncGen)
target_link_libraries(MLIRTestTransforms
MLIRAffineOps
MLIRAnalysis

View File

@ -1,5 +1,4 @@
//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering
//-------===//
//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
//
// Copyright 2019 The MLIR Authors.
//
@ -18,25 +17,28 @@
#include <type_traits>
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::vector;
namespace {
#include "TestVectorTransformPatterns.h.inc"
struct TestVectorToVectorConversion
: public FunctionPass<TestVectorToVectorConversion> {
void runOnFunction() override {
OwningRewritePatternList patterns;
auto *context = &getContext();
populateVectorToVectorConversionPatterns(context, patterns);
populateWithGenerated(context, &patterns);
populateVectorToVectorCanonicalizationPatterns(patterns, context);
applyPatternsGreedily(getFunction(), patterns);
}
};
} // end anonymous namespace
static PassRegistration<TestVectorToVectorConversion>