forked from OSchip/llvm-project
Uniformize Vector transforms as patterns on the model of Linalg - NFC
This reorganizes the vector transformations to be more easily testable as patterns and more easily composable into fused passes in the future. PiperOrigin-RevId: 284817474
This commit is contained in:
parent
8ccb350979
commit
ad38e49806
|
@ -19,30 +19,17 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef VECTOR_TRANSFORMS
|
#ifndef VECTOR_TRANSFORM_PATTERNS
|
||||||
#define VECTOR_TRANSFORMS
|
#define VECTOR_TRANSFORM_PATTERNS
|
||||||
|
|
||||||
include "mlir/Dialect/StandardOps/Ops.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Dialect/VectorOps/VectorOps.td"
|
|
||||||
|
|
||||||
class HasShape<list<int> shape> :
|
class HasShape<list<int> shape> :
|
||||||
CPred<"hasShape($0, {" # StrJoinInt<shape>.result # "})">;
|
CPred<"$0->getType().cast<ShapedType>().hasStaticShape({" #
|
||||||
|
StrJoinInt<shape>.result # "})">;
|
||||||
|
|
||||||
class UnrollVectorOp<list<int> factors> : NativeCodeCall<
|
class UnrollVectorOp<list<int> factors> : NativeCodeCall<
|
||||||
"unrollSingleResultOpMatchingType($_builder, $0->getDefiningOp(), " #
|
"unrollSingleResultOpMatchingType($_builder, $0->getDefiningOp(), " #
|
||||||
"{" # StrJoinInt<factors>.result # "})">;
|
"{" # StrJoinInt<factors>.result # "})">;
|
||||||
|
|
||||||
def : Pat<(AddFOp:$op_results $a, $b),
|
#endif // VECTOR_TRANSFORM_PATTERNS
|
||||||
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
|
||||||
[(Constraint<HasShape<[4, 2]>> $a)]>;
|
|
||||||
|
|
||||||
def : Pat<(AddFOp:$op_results $a, $b),
|
|
||||||
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
|
||||||
[(Constraint<HasShape<[4, 4]>> $a)]>;
|
|
||||||
|
|
||||||
// TODO(andydavis) Add Constraints on lhs/rhs shapes.
|
|
||||||
def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
|
|
||||||
(UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
|
|
||||||
[(Constraint<HasShape<[4, 4]>> $c)]>;
|
|
||||||
|
|
||||||
#endif // VECTOR_TRANSFORMS
|
|
||||||
|
|
|
@ -220,6 +220,9 @@ public:
|
||||||
/// has static shape.
|
/// has static shape.
|
||||||
bool hasStaticShape() const;
|
bool hasStaticShape() const;
|
||||||
|
|
||||||
|
/// If this has a static shape and the shape is equal to `shape` return true.
|
||||||
|
bool hasStaticShape(ArrayRef<int64_t> shape) const;
|
||||||
|
|
||||||
/// If this is a ranked type, return the number of dimensions with dynamic
|
/// If this is a ranked type, return the number of dimensions with dynamic
|
||||||
/// size. Otherwise, abort.
|
/// size. Otherwise, abort.
|
||||||
int64_t getNumDynamicDims() const;
|
int64_t getNumDynamicDims() const;
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
add_llvm_library(MLIRVectorOps
|
add_llvm_library(MLIRVectorOps
|
||||||
DialectRegistration.cpp
|
DialectRegistration.cpp
|
||||||
VectorOps.cpp
|
VectorOps.cpp
|
||||||
VectorToVector.cpp
|
VectorTransforms.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps
|
||||||
|
|
|
@ -113,14 +113,6 @@ static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
|
||||||
return builder.createOperation(res);
|
return builder.createOperation(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function for Tablegen.
|
|
||||||
static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
|
|
||||||
auto t = v->getType().dyn_cast<ShapedType>();
|
|
||||||
if (!t)
|
|
||||||
return false;
|
|
||||||
return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
|
static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
|
||||||
VectorType vt) {
|
VectorType vt) {
|
||||||
auto t = vt.getElementType();
|
auto t = vt.getElementType();
|
||||||
|
@ -454,18 +446,3 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
|
||||||
return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
|
return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
|
||||||
resultIndex, targetShape, builder);
|
resultIndex, targetShape, builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace vector {
|
|
||||||
namespace {
|
|
||||||
#include "mlir/Dialect/VectorOps/VectorTransformPatterns.h.inc"
|
|
||||||
} // end namespace
|
|
||||||
} // end namespace vector
|
|
||||||
} // end namespace mlir
|
|
||||||
|
|
||||||
void mlir::populateVectorToVectorConversionPatterns(
|
|
||||||
MLIRContext *context, OwningRewritePatternList &patterns,
|
|
||||||
ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
|
|
||||||
vector::populateWithGenerated(context, &patterns);
|
|
||||||
vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
|
|
||||||
}
|
|
|
@ -197,6 +197,10 @@ bool ShapedType::hasStaticShape() const {
|
||||||
return hasRank() && llvm::none_of(getShape(), isDynamic);
|
return hasRank() && llvm::none_of(getShape(), isDynamic);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
|
||||||
|
return hasStaticShape() && getShape() == shape;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// VectorType
|
// VectorType
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -5,3 +5,8 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
|
||||||
set(LLVM_TARGET_DEFINITIONS TestLinalgTilePermutePatterns.td)
|
set(LLVM_TARGET_DEFINITIONS TestLinalgTilePermutePatterns.td)
|
||||||
mlir_tablegen(TestLinalgTilePermutePatterns.h.inc -gen-rewriters)
|
mlir_tablegen(TestLinalgTilePermutePatterns.h.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(MLIRTestLinalgTilePermutePatternsIncGen)
|
add_public_tablegen_target(MLIRTestLinalgTilePermutePatternsIncGen)
|
||||||
|
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS TestVectorPatterns.td)
|
||||||
|
mlir_tablegen(TestVectorPatterns.h.inc -gen-rewriters)
|
||||||
|
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
//===- TestVectorTransformPatterns.td - Test patterns ---*- tablegen ----*-===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This is the pattern definition file for declarative Vector transformations
|
||||||
|
// tests.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TEST_VECTOR_TRANSFORMS_PATTERNS
|
||||||
|
#define TEST_VECTOR_TRANSFORMS_PATTERNS
|
||||||
|
|
||||||
|
include "mlir/Dialect/StandardOps/Ops.td"
|
||||||
|
include "mlir/Dialect/VectorOps/VectorOps.td"
|
||||||
|
include "mlir/Dialect/VectorOps/VectorTransformPatterns.td"
|
||||||
|
|
||||||
|
def : Pat<(AddFOp:$op_results $a, $b),
|
||||||
|
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
||||||
|
[(Constraint<HasShape<[4, 2]>> $a)]>;
|
||||||
|
|
||||||
|
def : Pat<(AddFOp:$op_results $a, $b),
|
||||||
|
(UnrollVectorOp<[2, 2]> $op_results, $a, $b),
|
||||||
|
[(Constraint<HasShape<[4, 4]>> $a)]>;
|
||||||
|
|
||||||
|
// TODO(andydavis) Add Constraints on lhs/rhs shapes.
|
||||||
|
def : Pat<(Vector_ContractionOp:$op_results $a, $b, $c, $masks, $attr0, $attr1),
|
||||||
|
(UnrollVectorOp<[2, 2, 2]> $op_results, $a, $b, $c),
|
||||||
|
[(Constraint<HasShape<[4, 4]>> $c)]>;
|
||||||
|
|
||||||
|
#endif // TEST_VECTOR_TRANSFORMS_PATTERNS
|
|
@ -10,7 +10,7 @@ add_llvm_library(MLIRTestTransforms
|
||||||
TestOpaqueLoc.cpp
|
TestOpaqueLoc.cpp
|
||||||
TestMemRefStrideCalculation.cpp
|
TestMemRefStrideCalculation.cpp
|
||||||
TestVectorToLoopsConversion.cpp
|
TestVectorToLoopsConversion.cpp
|
||||||
TestVectorToVectorConversion.cpp
|
TestVectorTransforms.cpp
|
||||||
TestVectorizationUtils.cpp
|
TestVectorizationUtils.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
@ -23,6 +23,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/../DeclarativeTransforms)
|
||||||
add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen)
|
add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen)
|
||||||
add_dependencies(MLIRTestTransforms MLIRTestLinalgTransformPatternsIncGen)
|
add_dependencies(MLIRTestTransforms MLIRTestLinalgTransformPatternsIncGen)
|
||||||
add_dependencies(MLIRTestTransforms MLIRTestLinalgTilePermutePatternsIncGen)
|
add_dependencies(MLIRTestTransforms MLIRTestLinalgTilePermutePatternsIncGen)
|
||||||
|
add_dependencies(MLIRTestTransforms MLIRTestVectorTransformPatternsIncGen)
|
||||||
target_link_libraries(MLIRTestTransforms
|
target_link_libraries(MLIRTestTransforms
|
||||||
MLIRAffineOps
|
MLIRAffineOps
|
||||||
MLIRAnalysis
|
MLIRAnalysis
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering
|
//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
|
||||||
//-------===//
|
|
||||||
//
|
//
|
||||||
// Copyright 2019 The MLIR Authors.
|
// Copyright 2019 The MLIR Authors.
|
||||||
//
|
//
|
||||||
|
@ -18,25 +17,28 @@
|
||||||
|
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
#include "mlir/Dialect/VectorOps/VectorOps.h"
|
||||||
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
|
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
using namespace mlir::vector;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
#include "TestVectorTransformPatterns.h.inc"
|
||||||
|
|
||||||
struct TestVectorToVectorConversion
|
struct TestVectorToVectorConversion
|
||||||
: public FunctionPass<TestVectorToVectorConversion> {
|
: public FunctionPass<TestVectorToVectorConversion> {
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
populateVectorToVectorConversionPatterns(context, patterns);
|
populateWithGenerated(context, &patterns);
|
||||||
|
populateVectorToVectorCanonicalizationPatterns(patterns, context);
|
||||||
applyPatternsGreedily(getFunction(), patterns);
|
applyPatternsGreedily(getFunction(), patterns);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
static PassRegistration<TestVectorToVectorConversion>
|
static PassRegistration<TestVectorToVectorConversion>
|
Loading…
Reference in New Issue