forked from OSchip/llvm-project
Uniformize Vector transforms as patterns on the model of Linalg - NFC
This reorganizes the vector transformations to be more easily testable as patterns and more easily composable into fused passes in the future. PiperOrigin-RevId: 284817474
This commit is contained in:
parent
8ccb350979
commit
ad38e49806
|
@ -19,30 +19,17 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef VECTOR_TRANSFORMS
|
||||
#define VECTOR_TRANSFORMS
|
||||
#ifndef VECTOR_TRANSFORM_PATTERNS
|
||||
#define VECTOR_TRANSFORM_PATTERNS
|
||||
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/VectorOps/VectorOps.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
class HasShape<list<int> shape> :
|
||||
CPred<"hasShape($0, {" # StrJoinInt<shape>.result # "})">;
|
||||
CPred<"$0->getType().cast<ShapedType>().hasStaticShape({" #
|
||||
StrJoinInt<shape>.result # "})">;
|
||||
|
||||
class UnrollVectorOp<list<int> factors> : NativeCodeCall<
|
||||
"unrollSingleResultOpMatchingType($_builder, $0->getDefiningOp(), " #
|
||||
"{" # StrJoinInt<factors>.result # "})">;
|
||||
|
||||
def : Pat<(AddFOp:$op_results $a, $b),
|
||||
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
||||
[(Constraint<HasShape<[4, 2]>> $a)]>;
|
||||
|
||||
def : Pat<(AddFOp:$op_results $a, $b),
|
||||
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
||||
[(Constraint<HasShape<[4, 4]>> $a)]>;
|
||||
|
||||
// TODO(andydavis) Add Constraints on lhs/rhs shapes.
|
||||
def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
|
||||
(UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
|
||||
[(Constraint<HasShape<[4, 4]>> $c)]>;
|
||||
|
||||
#endif // VECTOR_TRANSFORMS
|
||||
#endif // VECTOR_TRANSFORM_PATTERNS
|
||||
|
|
|
@ -220,6 +220,9 @@ public:
|
|||
/// has static shape.
|
||||
bool hasStaticShape() const;
|
||||
|
||||
/// If this has a static shape and the shape is equal to `shape` return true.
|
||||
bool hasStaticShape(ArrayRef<int64_t> shape) const;
|
||||
|
||||
/// If this is a ranked type, return the number of dimensions with dynamic
|
||||
/// size. Otherwise, abort.
|
||||
int64_t getNumDynamicDims() const;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
add_llvm_library(MLIRVectorOps
|
||||
DialectRegistration.cpp
|
||||
VectorOps.cpp
|
||||
VectorToVector.cpp
|
||||
VectorTransforms.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps
|
||||
|
|
|
@ -113,14 +113,6 @@ static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
|
|||
return builder.createOperation(res);
|
||||
}
|
||||
|
||||
// Helper function for Tablegen.
|
||||
static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
|
||||
auto t = v->getType().dyn_cast<ShapedType>();
|
||||
if (!t)
|
||||
return false;
|
||||
return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
|
||||
}
|
||||
|
||||
static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
|
||||
VectorType vt) {
|
||||
auto t = vt.getElementType();
|
||||
|
@ -454,18 +446,3 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
|
|||
return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
|
||||
resultIndex, targetShape, builder);
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace vector {
|
||||
namespace {
|
||||
#include "mlir/Dialect/VectorOps/VectorTransformPatterns.h.inc"
|
||||
} // end namespace
|
||||
} // end namespace vector
|
||||
} // end namespace mlir
|
||||
|
||||
void mlir::populateVectorToVectorConversionPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList &patterns,
|
||||
ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
|
||||
vector::populateWithGenerated(context, &patterns);
|
||||
vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
|
||||
}
|
|
@ -197,6 +197,10 @@ bool ShapedType::hasStaticShape() const {
|
|||
return hasRank() && llvm::none_of(getShape(), isDynamic);
|
||||
}
|
||||
|
||||
bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
|
||||
return hasStaticShape() && getShape() == shape;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -5,3 +5,8 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
|
|||
set(LLVM_TARGET_DEFINITIONS TestLinalgTilePermutePatterns.td)
|
||||
mlir_tablegen(TestLinalgTilePermutePatterns.h.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRTestLinalgTilePermutePatternsIncGen)
|
||||
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TestVectorPatterns.td)
|
||||
mlir_tablegen(TestVectorPatterns.h.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
//===- TestVectorTransformPatterns.td - Test patterns ---*- tablegen ----*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This is the pattern definition file for declarative Vector transformations
|
||||
// tests.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TEST_VECTOR_TRANSFORMS_PATTERNS
|
||||
#define TEST_VECTOR_TRANSFORMS_PATTERNS
|
||||
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/VectorOps/VectorOps.td"
|
||||
include "mlir/Dialect/VectorOps/VectorTransformPatterns.td"
|
||||
|
||||
def : Pat<(AddFOp:$op_results $a, $b),
|
||||
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
||||
[(Constraint<HasShape<[4, 2]>> $a)]>;
|
||||
|
||||
def : Pat<(AddFOp:$op_results $a, $b),
|
||||
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
||||
[(Constraint<HasShape<[4, 4]>> $a)]>;
|
||||
|
||||
// TODO(andydavis) Add Constraints on lhs/rhs shapes.
|
||||
def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
|
||||
(UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
|
||||
[(Constraint<HasShape<[4, 4]>> $c)]>;
|
||||
|
||||
#endif // TEST_VECTOR_TRANSFORMS_PATTERNS
|
|
@ -10,7 +10,7 @@ add_llvm_library(MLIRTestTransforms
|
|||
TestOpaqueLoc.cpp
|
||||
TestMemRefStrideCalculation.cpp
|
||||
TestVectorToLoopsConversion.cpp
|
||||
TestVectorToVectorConversion.cpp
|
||||
TestVectorTransforms.cpp
|
||||
TestVectorizationUtils.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
@ -23,6 +23,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/../DeclarativeTransforms)
|
|||
add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen)
|
||||
add_dependencies(MLIRTestTransforms MLIRTestLinalgTransformPatternsIncGen)
|
||||
add_dependencies(MLIRTestTransforms MLIRTestLinalgTilePermutePatternsIncGen)
|
||||
add_dependencies(MLIRTestTransforms MLIRTestVectorTransformPatternsIncGen)
|
||||
target_link_libraries(MLIRTestTransforms
|
||||
MLIRAffineOps
|
||||
MLIRAnalysis
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering
|
||||
//-------===//
|
||||
//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -18,25 +17,28 @@
|
|||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
||||
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
||||
namespace {
|
||||
#include "TestVectorTransformPatterns.h.inc"
|
||||
|
||||
struct TestVectorToVectorConversion
|
||||
: public FunctionPass<TestVectorToVectorConversion> {
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
populateVectorToVectorConversionPatterns(context, patterns);
|
||||
populateWithGenerated(context, &patterns);
|
||||
populateVectorToVectorCanonicalizationPatterns(patterns, context);
|
||||
applyPatternsGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
static PassRegistration<TestVectorToVectorConversion>
|
Loading…
Reference in New Issue