forked from OSchip/llvm-project
[mlir][Vector] NFC - Add option to hook vector.transpose lowering to strategies.
This revision also moves some code around to improve overall structure. Differential Revision: https://reviews.llvm.org/D112437
This commit is contained in:
parent
3b1165ba3d
commit
d054b80bd3
|
@ -15,7 +15,7 @@
|
|||
#include "mlir/Dialect/SCF/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/Bufferize.h"
|
||||
|
@ -846,6 +846,9 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
|
|||
: LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transformation and lowering options exposed as auxiliary structs.
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Options to control the application of enabling transformations.
|
||||
/// Hoisting transformations are always deemed beneficial and must be disabled
|
||||
/// explicitly.
|
||||
|
@ -887,16 +890,10 @@ struct LinalgVectorLoweringOptions {
|
|||
transferLowering = val;
|
||||
return *this;
|
||||
}
|
||||
/// Trigger full / partial vector.transfer splits.
|
||||
bool transferPartialRewrite = false;
|
||||
LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
|
||||
transferPartialRewrite = val;
|
||||
return *this;
|
||||
}
|
||||
/// Enable lowering of vector.contract.
|
||||
bool contractionLowering = false;
|
||||
LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
|
||||
contractionLowering = val;
|
||||
/// Enable lowering of vector.transpose.
|
||||
bool transposeLowering = false;
|
||||
LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
|
||||
transposeLowering = val;
|
||||
return *this;
|
||||
}
|
||||
/// Enable lowering of vector.multi_reduce.
|
||||
|
@ -905,19 +902,24 @@ struct LinalgVectorLoweringOptions {
|
|||
multiReductionLowering = val;
|
||||
return *this;
|
||||
}
|
||||
/// Enable lowering of vector.contract.
|
||||
bool contractionLowering = false;
|
||||
LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
|
||||
contractionLowering = val;
|
||||
return *this;
|
||||
}
|
||||
/// Trigger full / partial vector.transfer splits.
|
||||
bool transferPartialRewrite = false;
|
||||
LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
|
||||
transferPartialRewrite = val;
|
||||
return *this;
|
||||
}
|
||||
/// Enable lowering of vector.transfer to scf.
|
||||
bool transferToSCFConversion = false;
|
||||
LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
|
||||
transferToSCFConversion = val;
|
||||
return *this;
|
||||
}
|
||||
/// Configure late vector transformations.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
LinalgVectorLoweringOptions &
|
||||
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
|
||||
vectorTransformOptions = options;
|
||||
return *this;
|
||||
}
|
||||
/// Configure the post staged-patterns late vector.transfer to scf
|
||||
/// conversion.
|
||||
VectorTransferToSCFOptions vectorTransferToSCFOptions;
|
||||
|
@ -926,8 +928,18 @@ struct LinalgVectorLoweringOptions {
|
|||
vectorTransferToSCFOptions = options;
|
||||
return *this;
|
||||
}
|
||||
/// Configure late vector transformations.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
LinalgVectorLoweringOptions &
|
||||
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
|
||||
vectorTransformOptions = options;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transformations exposed as rewrite patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Trait to check if T provides a `getOperationName` method.
|
||||
template <typename T, typename... Args>
|
||||
using has_get_operation_name = decltype(T::getOperationName());
|
||||
|
|
|
@ -40,76 +40,6 @@ namespace detail {
|
|||
struct BitmaskEnumStorage;
|
||||
} // namespace detail
|
||||
|
||||
/// Enum to control the lowering of `vector.contract` operations.
|
||||
enum class VectorContractLowering {
|
||||
/// Progressively lower to finer grained `vector.contract` and dot-products.
|
||||
Dot = 0,
|
||||
/// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
|
||||
Matmul = 1,
|
||||
/// Lower to `vector.outerproduct`.
|
||||
OuterProduct = 2,
|
||||
};
|
||||
/// Enum to control the lowering of `vector.multi_reduction` operations.
|
||||
enum class VectorMultiReductionLowering {
|
||||
/// Lower multi_reduction into outer-reduction and inner-parallel ops.
|
||||
InnerParallel = 0,
|
||||
/// Lower multi_reduction into outer-parallel and inner-reduction ops.
|
||||
InnerReduction = 1,
|
||||
};
|
||||
/// Enum to control the lowering of `vector.transpose` operations.
|
||||
enum class VectorTransposeLowering {
|
||||
/// Lower transpose into element-wise extract and inserts.
|
||||
EltWise = 0,
|
||||
/// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
|
||||
/// intrinsics.
|
||||
Flat = 1,
|
||||
};
|
||||
/// Enum to control the splitting of `vector.transfer` operations into
|
||||
/// in-bounds and out-of-bounds variants.
|
||||
enum class VectorTransferSplit {
|
||||
/// Do not split vector transfer operations.
|
||||
None = 0,
|
||||
/// Split using in-bounds + out-of-bounds vector.transfer operations.
|
||||
VectorTransfer = 1,
|
||||
/// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
|
||||
/// operations.
|
||||
LinalgCopy = 2,
|
||||
/// Do not split vector transfer operation but instead mark it as "in-bounds".
|
||||
ForceInBounds = 3
|
||||
};
|
||||
/// Structure to control the behavior of vector transform patterns.
|
||||
struct VectorTransformsOptions {
|
||||
/// Option to control the lowering of vector.contract.
|
||||
VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
|
||||
VectorTransformsOptions &
|
||||
setVectorTransformsOptions(VectorContractLowering opt) {
|
||||
vectorContractLowering = opt;
|
||||
return *this;
|
||||
}
|
||||
/// Option to control the lowering of vector.multi_reduction.
|
||||
VectorMultiReductionLowering vectorMultiReductionLowering =
|
||||
VectorMultiReductionLowering::InnerParallel;
|
||||
VectorTransformsOptions &
|
||||
setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
|
||||
vectorMultiReductionLowering = opt;
|
||||
return *this;
|
||||
}
|
||||
/// Option to control the lowering of vector.transpose.
|
||||
VectorTransposeLowering vectorTransposeLowering =
|
||||
VectorTransposeLowering::EltWise;
|
||||
VectorTransformsOptions &
|
||||
setVectorTransposeLowering(VectorTransposeLowering opt) {
|
||||
vectorTransposeLowering = opt;
|
||||
return *this;
|
||||
}
|
||||
/// Option to control the splitting of vector transfers.
|
||||
VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
|
||||
VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
|
||||
vectorTransferSplit = opt;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
/// Return whether `srcType` can be broadcast to `dstVectorType` under the
|
||||
/// semantics of the `vector.broadcast` op.
|
||||
enum class BroadcastableToResult {
|
||||
|
@ -161,33 +91,6 @@ void populateVectorTransferPermutationMapLoweringPatterns(
|
|||
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
|
||||
bool enableIndexOptimizations);
|
||||
|
||||
/// Collect a set of patterns to convert vector.multi_reduction op into
|
||||
/// a sequence of vector.reduction ops. The patterns comprise:
|
||||
/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
|
||||
/// that all reduction dimensions are either innermost or outermost, by adding
|
||||
/// the proper vector.transpose operations.
|
||||
/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
|
||||
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
|
||||
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
|
||||
/// back.
|
||||
/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
|
||||
/// form, with an **outermost** reduction dimension, unroll the outer dimension
|
||||
/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
|
||||
/// tree-reduction (in the future).
|
||||
/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
|
||||
/// with an **innermost** reduction dimension, unroll the outer dimension to
|
||||
/// obtain a sequence of extract + vector.reduction + insert. This can further
|
||||
/// lower to horizontal reduction ops.
|
||||
/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
|
||||
/// reduction (and are thus missing either a parallel or a reduction), we lift
|
||||
/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
|
||||
/// the other patterns can kick in, thus fully exiting out of the
|
||||
/// vector.multi_reduction abstraction.
|
||||
void populateVectorMultiReductionLoweringPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
VectorMultiReductionLowering options =
|
||||
vector::VectorMultiReductionLowering::InnerParallel);
|
||||
|
||||
/// Collect a set of patterns to propagate insert_map/extract_map in the ssa
|
||||
/// chain.
|
||||
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
|
||||
|
@ -212,12 +115,6 @@ public:
|
|||
/// vectors to low-D vector ops.
|
||||
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Collects patterns to progressively lower vector contraction ops on high-D
|
||||
/// into low-D reduction and product ops.
|
||||
void populateVectorContractLoweringPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
VectorTransformsOptions options = VectorTransformsOptions());
|
||||
|
||||
/// Collects patterns to progressively lower vector mask ops into elementary
|
||||
/// selection and insertion ops.
|
||||
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
|
||||
|
@ -227,15 +124,6 @@ void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
|
|||
/// ops.
|
||||
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Insert TransposeLowering patterns into extraction/insertion.
|
||||
void populateVectorTransposeLoweringPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
VectorTransformsOptions options = VectorTransformsOptions());
|
||||
|
||||
/// Collect patterns to convert reduction op to vector.contract and fold
|
||||
/// transpose/broadcast ops into the contract.
|
||||
void populateVetorReductionToContractPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Returns the integer type required for subscripts in the vector dialect.
|
||||
IntegerType getVectorSubscriptType(Builder &builder);
|
||||
|
||||
|
|
|
@ -9,11 +9,173 @@
|
|||
#ifndef DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
|
||||
#define DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
|
||||
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
class RewritePatternSet;
|
||||
|
||||
namespace vector {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Vector transformation options exposed as auxiliary structs.
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Enum to control the lowering of `vector.transpose` operations.
|
||||
enum class VectorTransposeLowering {
|
||||
/// Lower transpose into element-wise extract and inserts.
|
||||
EltWise = 0,
|
||||
/// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
|
||||
/// intrinsics.
|
||||
Flat = 1,
|
||||
};
|
||||
/// Enum to control the lowering of `vector.multi_reduction` operations.
|
||||
enum class VectorMultiReductionLowering {
|
||||
/// Lower multi_reduction into outer-reduction and inner-parallel ops.
|
||||
InnerParallel = 0,
|
||||
/// Lower multi_reduction into outer-parallel and inner-reduction ops.
|
||||
InnerReduction = 1,
|
||||
};
|
||||
/// Enum to control the lowering of `vector.contract` operations.
|
||||
enum class VectorContractLowering {
|
||||
/// Progressively lower to finer grained `vector.contract` and dot-products.
|
||||
Dot = 0,
|
||||
/// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
|
||||
Matmul = 1,
|
||||
/// Lower to `vector.outerproduct`.
|
||||
OuterProduct = 2,
|
||||
};
|
||||
/// Enum to control the splitting of `vector.transfer` operations into
|
||||
/// in-bounds and out-of-bounds variants.
|
||||
enum class VectorTransferSplit {
|
||||
/// Do not split vector transfer operations.
|
||||
None = 0,
|
||||
/// Split using in-bounds + out-of-bounds vector.transfer operations.
|
||||
VectorTransfer = 1,
|
||||
/// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
|
||||
/// operations.
|
||||
LinalgCopy = 2,
|
||||
/// Do not split vector transfer operation but instead mark it as "in-bounds".
|
||||
ForceInBounds = 3
|
||||
};
|
||||
/// Structure to control the behavior of vector transform patterns.
|
||||
struct VectorTransformsOptions {
|
||||
/// Option to control the lowering of vector.contract.
|
||||
VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
|
||||
VectorTransformsOptions &
|
||||
setVectorTransformsOptions(VectorContractLowering opt) {
|
||||
vectorContractLowering = opt;
|
||||
return *this;
|
||||
}
|
||||
/// Option to control the lowering of vector.multi_reduction.
|
||||
VectorMultiReductionLowering vectorMultiReductionLowering =
|
||||
VectorMultiReductionLowering::InnerParallel;
|
||||
VectorTransformsOptions &
|
||||
setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
|
||||
vectorMultiReductionLowering = opt;
|
||||
return *this;
|
||||
}
|
||||
/// Option to control the lowering of vector.transpose.
|
||||
VectorTransposeLowering vectorTransposeLowering =
|
||||
VectorTransposeLowering::EltWise;
|
||||
VectorTransformsOptions &
|
||||
setVectorTransposeLowering(VectorTransposeLowering opt) {
|
||||
vectorTransposeLowering = opt;
|
||||
return *this;
|
||||
}
|
||||
/// Option to control the splitting of vector transfers.
|
||||
VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
|
||||
VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
|
||||
vectorTransferSplit = opt;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
/// Options that control the vector unrolling.
|
||||
struct UnrollVectorOptions {
|
||||
using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
|
||||
/// Callback function that indicates whether vector unrolling should be
|
||||
/// attempted on the operation.
|
||||
FilterConstraintFnType filterConstraint = nullptr;
|
||||
UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
|
||||
filterConstraint = constraint;
|
||||
return *this;
|
||||
}
|
||||
|
||||
using NativeShapeFnType =
|
||||
std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
|
||||
/// Function that returns the shape of the vector to unroll to for a given
|
||||
/// operation. The unrolling is aborted if the function returns `llvm::None`.
|
||||
NativeShapeFnType nativeShape = nullptr;
|
||||
UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
|
||||
nativeShape = fn;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Set the native shape to use for unrolling.
|
||||
UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
|
||||
nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
|
||||
return tsShape;
|
||||
};
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Vector transformation exposed as populate functions over rewrite patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Insert TransposeLowering patterns into extraction/insertion.
|
||||
void populateVectorTransposeLoweringPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
VectorTransformsOptions options = VectorTransformsOptions());
|
||||
|
||||
/// Collect a set of patterns to convert vector.multi_reduction op into
|
||||
/// a sequence of vector.reduction ops. The patterns comprise:
|
||||
/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
|
||||
/// that all reduction dimensions are either innermost or outermost, by adding
|
||||
/// the proper vector.transpose operations.
|
||||
/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
|
||||
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
|
||||
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
|
||||
/// back.
|
||||
/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
|
||||
/// form, with an **outermost** reduction dimension, unroll the outer dimension
|
||||
/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
|
||||
/// tree-reduction (in the future).
|
||||
/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
|
||||
/// with an **innermost** reduction dimension, unroll the outer dimension to
|
||||
/// obtain a sequence of extract + vector.reduction + insert. This can further
|
||||
/// lower to horizontal reduction ops.
|
||||
/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
|
||||
/// reduction (and are thus missing either a parallel or a reduction), we lift
|
||||
/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
|
||||
/// the other patterns can kick in, thus fully exiting out of the
|
||||
/// vector.multi_reduction abstraction.
|
||||
void populateVectorMultiReductionLoweringPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
VectorMultiReductionLowering options =
|
||||
VectorMultiReductionLowering::InnerParallel);
|
||||
|
||||
/// Collects patterns to progressively lower vector contraction ops on high-D
|
||||
/// into low-D reduction and product ops.
|
||||
void populateVectorContractLoweringPatterns(
|
||||
RewritePatternSet &patterns,
|
||||
VectorTransformsOptions options = VectorTransformsOptions());
|
||||
|
||||
/// Collect patterns to convert reduction op to vector.contract and fold
|
||||
/// transpose/broadcast ops into the contract.
|
||||
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Collect a set of patterns to reduce the rank of the operands of vector
|
||||
/// transfer ops to operate on the largest contigious vector.
|
||||
/// These patterns are useful when lowering to dialects with 1d vector type
|
||||
/// such as llvm and it will result fewer memory reads.
|
||||
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Populate `patterns` with the following patterns.
|
||||
///
|
||||
/// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
|
||||
|
@ -52,6 +214,235 @@ namespace vector {
|
|||
void populateVectorInsertExtractStridedSliceTransforms(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Collect a set of pattern to unroll vector operations to a smaller shapes.
|
||||
/// `options` structure controls which operations are unrolled and the target
|
||||
/// shape.
|
||||
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
|
||||
/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
|
||||
/// `numUnrolledInstances` are computed from the `targetShape`. For now it is
|
||||
/// assumed the unrolling factors divide the vector sizes.
|
||||
/// 2. ExtractStridedSlice are created to break-up the vector operands.
|
||||
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
|
||||
/// result.
|
||||
/// 4. InsertStridedSlice are inserted to re-assemble the slices into the
|
||||
/// original vectore shape.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// opA(operand0, operand1) // numUnrolledInstances = 3
|
||||
///
|
||||
/// operand0 operand1
|
||||
/// | |
|
||||
/// fork fork
|
||||
/// <----------gather all fork ops --------->
|
||||
/// /|\ /|\
|
||||
/// f00 f01 f02 f10 f11 f12
|
||||
/// <---------- clone op 3 times --------->
|
||||
/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
|
||||
/// \ | /
|
||||
/// <-------------------- join ------------------------->
|
||||
///
|
||||
/// Other local patterns then kick in iteratively (including DCE) and compose
|
||||
/// to combine the ExtractStridedSlice/InsertStridedSlice.
|
||||
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
|
||||
const UnrollVectorOptions &options);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Finer-grained patterns exposed for more control over individual lowerings.
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
|
||||
/// may take an extra filter to perform selection at a finer granularity.
|
||||
struct VectorTransferFullPartialRewriter : public RewritePattern {
|
||||
using FilterConstraintType =
|
||||
std::function<LogicalResult(VectorTransferOpInterface op)>;
|
||||
|
||||
explicit VectorTransferFullPartialRewriter(
|
||||
MLIRContext *context,
|
||||
VectorTransformsOptions options = VectorTransformsOptions(),
|
||||
FilterConstraintType filter =
|
||||
[](VectorTransferOpInterface op) { return success(); },
|
||||
PatternBenefit benefit = 1)
|
||||
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
|
||||
filter(filter) {}
|
||||
|
||||
/// Performs the rewrite.
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
VectorTransformsOptions options;
|
||||
FilterConstraintType filter;
|
||||
};
|
||||
|
||||
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
|
||||
/// semantics to:
|
||||
/// ```
|
||||
/// %flattened_a = vector.shape_cast %a
|
||||
/// %flattened_b = vector.shape_cast %b
|
||||
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
|
||||
/// %d = vector.shape_cast %%flattened_d
|
||||
/// %e = add %c, %d
|
||||
/// ```
|
||||
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
|
||||
//
|
||||
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
|
||||
/// the vector.contract op is a row-major matrix multiply.
|
||||
class ContractionOpToMatmulOpLowering
|
||||
: public OpRewritePattern<vector::ContractionOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
using FilterConstraintType =
|
||||
std::function<LogicalResult(vector::ContractionOp op)>;
|
||||
|
||||
static LogicalResult defaultFilter(vector::ContractionOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
ContractionOpToMatmulOpLowering(
|
||||
vector::VectorTransformsOptions vectorTransformOptions,
|
||||
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// Options to control the vector patterns.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
FilterConstraintType filter;
|
||||
};
|
||||
|
||||
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
|
||||
/// semantics to a reduction_size-unrolled sequence:
|
||||
/// ```
|
||||
/// %at = vector.transpose %a, [1, 0]
|
||||
/// %bRow0 = vector.extract %b[0]
|
||||
/// %atRow0 = vector.extract %at[0]
|
||||
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
|
||||
/// ...
|
||||
/// %bRowK = vector.extract %b[K]
|
||||
/// %atRowK = vector.extract %at[K]
|
||||
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
|
||||
/// ```
|
||||
///
|
||||
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
|
||||
/// the vector.contract op is a row-major matrix multiply.
|
||||
class ContractionOpToOuterProductOpLowering
|
||||
: public OpRewritePattern<vector::ContractionOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
using FilterConstraintType =
|
||||
std::function<LogicalResult(vector::ContractionOp op)>;
|
||||
|
||||
static LogicalResult defaultFilter(vector::ContractionOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
ContractionOpToOuterProductOpLowering(
|
||||
vector::VectorTransformsOptions vectorTransformOptions,
|
||||
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// Options to control the vector patterns.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
FilterConstraintType filter;
|
||||
};
|
||||
|
||||
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
|
||||
/// semantics to an output-size-unrolled sequence:
|
||||
/// ```
|
||||
/// %out = arith.constant ... : vector<MxNxelt_type>
|
||||
/// %bt = vector.transpose %b, [1, 0]
|
||||
/// %aRow0 = vector.extract %a[0]
|
||||
/// %btRow0 = vector.extract %bt[0]
|
||||
/// %c00 = vector.reduce %atRow0, %bRow0
|
||||
/// %out00 = vector.insert %c00, %out[0, 0]
|
||||
/// ...
|
||||
/// %aRowLast = vector.extract %at[M-1]
|
||||
/// %btRowLast = vector.extract %b[N-1]
|
||||
/// %cLastLast = vector.reduce %atRowLast, %bRowLast
|
||||
/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
|
||||
/// ```
|
||||
///
|
||||
/// This only kicks in when VectorTransformsOptions is set to Dot and
|
||||
/// the vector.contract op is a row-major matmul or matvec.
|
||||
class ContractionOpToDotLowering
|
||||
: public OpRewritePattern<vector::ContractionOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
using FilterConstraintType =
|
||||
std::function<LogicalResult(vector::ContractionOp op)>;
|
||||
|
||||
static LogicalResult defaultFilter(vector::ContractionOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
ContractionOpToDotLowering(
|
||||
vector::VectorTransformsOptions vectorTransformOptions,
|
||||
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// Options to control the vector patterns.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
FilterConstraintType filter;
|
||||
};
|
||||
|
||||
/// Progressive lowering of ContractionOp.
|
||||
///
|
||||
/// One:
|
||||
/// %x = vector.contract with at least one free/batch dimension
|
||||
/// is replaced by:
|
||||
/// %a = vector.contract with one less free/batch dimension
|
||||
/// %b = vector.contract with one less free/batch dimension
|
||||
/// ..
|
||||
/// %x = combine %a %b ..
|
||||
/// until a pure contraction is reached (no free/batch dimensions),
|
||||
/// which is replaced by a dot-product.
|
||||
///
|
||||
/// This only kicks in when either VectorTransformsOptions is set
|
||||
/// to Dot or when other contraction patterns fail.
|
||||
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
using FilterConstraintType =
|
||||
std::function<LogicalResult(vector::ContractionOp op)>;
|
||||
|
||||
static LogicalResult defaultFilter(vector::ContractionOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
|
||||
MLIRContext *context,
|
||||
FilterConstraintType constraint = defaultFilter)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// Options to control the vector patterns.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
FilterConstraintType filter;
|
||||
// Lower one parallel dimension.
|
||||
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
|
||||
int64_t rhsIndex, PatternRewriter &rewriter) const;
|
||||
// Lower one reduction dimension.
|
||||
Value lowerReduction(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const;
|
||||
};
|
||||
|
||||
} // namespace vector
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -9,10 +9,8 @@
|
|||
#ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_
|
||||
#define DIALECT_VECTOR_VECTORTRANSFORMS_H_
|
||||
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
|
||||
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
|
@ -26,77 +24,9 @@ class IfOp;
|
|||
|
||||
namespace vector {
|
||||
|
||||
/// Options that control the vector unrolling.
|
||||
struct UnrollVectorOptions {
|
||||
using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
|
||||
/// Callback function that indicates whether vector unrolling should be
|
||||
/// attempted on the operation.
|
||||
FilterConstraintFnType filterConstraint = nullptr;
|
||||
UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
|
||||
filterConstraint = constraint;
|
||||
return *this;
|
||||
}
|
||||
|
||||
using NativeShapeFnType =
|
||||
std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
|
||||
/// Function that returns the shape of the vector to unroll to for a given
|
||||
/// operation. The unrolling is aborted if the function returns `llvm::None`.
|
||||
NativeShapeFnType nativeShape = nullptr;
|
||||
UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
|
||||
nativeShape = fn;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Set the native shape to use for unrolling.
|
||||
UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
|
||||
nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
|
||||
return tsShape;
|
||||
};
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
/// Collect a set of pattern to unroll vector operations to a smaller shapes.
|
||||
/// `options` structure controls which operations are unrolled and the target
|
||||
/// shape.
|
||||
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
|
||||
/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
|
||||
/// `numUnrolledInstances` are computed from the `targetShape`. For now it is
|
||||
/// assumed the unrolling factors divide the vector sizes.
|
||||
/// 2. ExtractStridedSlice are created to break-up the vector operands.
|
||||
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
|
||||
/// result.
|
||||
/// 4. InsertStridedSlice are inserted to re-assemble the slices into the
|
||||
/// original vectore shape.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// opA(operand0, operand1) // numUnrolledInstances = 3
|
||||
///
|
||||
/// operand0 operand1
|
||||
/// | |
|
||||
/// fork fork
|
||||
/// <----------gather all fork ops --------->
|
||||
/// /|\ /|\
|
||||
/// f00 f01 f02 f10 f11 f12
|
||||
/// <---------- clone op 3 times --------->
|
||||
/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
|
||||
/// \ | /
|
||||
/// <-------------------- join ------------------------->
|
||||
///
|
||||
/// Other local patterns then kick in iteratively (including DCE) and compose
|
||||
/// to combine the ExtractStridedSlice/InsertStridedSlice.
|
||||
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
|
||||
const UnrollVectorOptions &options);
|
||||
|
||||
/// Collect a set of patterns to reduce the rank of the operands of vector
|
||||
/// transfer ops to operate on the largest contigious vector.
|
||||
/// These patterns are useful when lowering to dialects with 1d vector type
|
||||
/// such as llvm and it will result fewer memory reads.
|
||||
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Standalone transformations and helpers.
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
|
||||
/// masking) fastpath and a slowpath.
|
||||
/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
|
||||
|
@ -130,37 +60,11 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
|
|||
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
|
||||
/// must be equal. This will be relaxed in the future but requires
|
||||
/// rank-reducing subviews.
|
||||
LogicalResult
|
||||
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
|
||||
LogicalResult splitFullAndPartialTransfer(
|
||||
OpBuilder &b, VectorTransferOpInterface xferOp,
|
||||
VectorTransformsOptions options = VectorTransformsOptions(),
|
||||
scf::IfOp *ifOp = nullptr);
|
||||
|
||||
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
|
||||
/// may take an extra filter to perform selection at a finer granularity.
|
||||
struct VectorTransferFullPartialRewriter : public RewritePattern {
|
||||
using FilterConstraintType =
|
||||
std::function<LogicalResult(VectorTransferOpInterface op)>;
|
||||
|
||||
explicit VectorTransferFullPartialRewriter(
|
||||
MLIRContext *context,
|
||||
VectorTransformsOptions options = VectorTransformsOptions(),
|
||||
FilterConstraintType filter =
|
||||
[](VectorTransferOpInterface op) { return success(); },
|
||||
PatternBenefit benefit = 1)
|
||||
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
|
||||
filter(filter) {}
|
||||
|
||||
/// Performs the rewrite.
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
VectorTransformsOptions options;
|
||||
FilterConstraintType filter;
|
||||
};
|
||||
|
||||
struct DistributeOps {
|
||||
ExtractMapOp extract;
|
||||
InsertMapOp insert;
|
||||
|
@ -188,180 +92,6 @@ distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
|
|||
void transferOpflowOpt(FuncOp func);
|
||||
|
||||
} // namespace vector
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Finer-grained patterns exposed for more control over individual lowerings.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
|
||||
/// semantics to:
|
||||
/// ```
|
||||
/// %flattened_a = vector.shape_cast %a
|
||||
/// %flattened_b = vector.shape_cast %b
|
||||
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
|
||||
/// %d = vector.shape_cast %%flattened_d
|
||||
/// %e = add %c, %d
|
||||
/// ```
|
||||
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
|
||||
//
|
||||
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
|
||||
/// the vector.contract op is a row-major matrix multiply.
|
||||
class ContractionOpToMatmulOpLowering
|
||||
: public OpRewritePattern<vector::ContractionOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
using FilterConstraintType =
|
||||
std::function<LogicalResult(vector::ContractionOp op)>;
|
||||
|
||||
static LogicalResult defaultFilter(vector::ContractionOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
ContractionOpToMatmulOpLowering(
|
||||
vector::VectorTransformsOptions vectorTransformOptions,
|
||||
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// Options to control the vector patterns.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
FilterConstraintType filter;
|
||||
};
|
||||
|
||||
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
|
||||
/// semantics to a reduction_size-unrolled sequence:
|
||||
/// ```
|
||||
/// %at = vector.transpose %a, [1, 0]
|
||||
/// %bRow0 = vector.extract %b[0]
|
||||
/// %atRow0 = vector.extract %at[0]
|
||||
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
|
||||
/// ...
|
||||
/// %bRowK = vector.extract %b[K]
|
||||
/// %atRowK = vector.extract %at[K]
|
||||
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
|
||||
/// ```
|
||||
///
|
||||
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
|
||||
/// the vector.contract op is a row-major matrix multiply.
|
||||
class ContractionOpToOuterProductOpLowering
|
||||
: public OpRewritePattern<vector::ContractionOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
using FilterConstraintType =
|
||||
std::function<LogicalResult(vector::ContractionOp op)>;
|
||||
|
||||
static LogicalResult defaultFilter(vector::ContractionOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
ContractionOpToOuterProductOpLowering(
|
||||
vector::VectorTransformsOptions vectorTransformOptions,
|
||||
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// Options to control the vector patterns.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
FilterConstraintType filter;
|
||||
};
|
||||
|
||||
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
|
||||
/// semantics to an output-size-unrolled sequence:
|
||||
/// ```
|
||||
/// %out = arith.constant ... : vector<MxNxelt_type>
|
||||
/// %bt = vector.transpose %b, [1, 0]
|
||||
/// %aRow0 = vector.extract %a[0]
|
||||
/// %btRow0 = vector.extract %bt[0]
|
||||
/// %c00 = vector.reduce %atRow0, %bRow0
|
||||
/// %out00 = vector.insert %c00, %out[0, 0]
|
||||
/// ...
|
||||
/// %aRowLast = vector.extract %at[M-1]
|
||||
/// %btRowLast = vector.extract %b[N-1]
|
||||
/// %cLastLast = vector.reduce %atRowLast, %bRowLast
|
||||
/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
|
||||
/// ```
|
||||
///
|
||||
/// This only kicks in when VectorTransformsOptions is set to Dot and
|
||||
/// the vector.contract op is a row-major matmul or matvec.
|
||||
class ContractionOpToDotLowering
|
||||
: public OpRewritePattern<vector::ContractionOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
using FilterConstraintType =
|
||||
std::function<LogicalResult(vector::ContractionOp op)>;
|
||||
|
||||
static LogicalResult defaultFilter(vector::ContractionOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
ContractionOpToDotLowering(
|
||||
vector::VectorTransformsOptions vectorTransformOptions,
|
||||
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// Options to control the vector patterns.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
FilterConstraintType filter;
|
||||
};
|
||||
|
||||
/// Progressive lowering of ContractionOp.
|
||||
///
|
||||
/// One:
|
||||
/// %x = vector.contract with at least one free/batch dimension
|
||||
/// is replaced by:
|
||||
/// %a = vector.contract with one less free/batch dimension
|
||||
/// %b = vector.contract with one less free/batch dimension
|
||||
/// ..
|
||||
/// %x = combine %a %b ..
|
||||
/// until a pure contraction is reached (no free/batch dimensions),
|
||||
/// which is replaced by a dot-product.
|
||||
///
|
||||
/// This only kicks in when either VectorTransformsOptions is set
|
||||
/// to Dot or when other contraction patterns fail.
|
||||
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
using FilterConstraintType =
|
||||
std::function<LogicalResult(vector::ContractionOp op)>;
|
||||
|
||||
static LogicalResult defaultFilter(vector::ContractionOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
|
||||
MLIRContext *context,
|
||||
FilterConstraintType constraint = defaultFilter)
|
||||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
/// Options to control the vector patterns.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
FilterConstraintType filter;
|
||||
// Lower one parallel dimension.
|
||||
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
|
||||
int64_t rhsIndex, PatternRewriter &rewriter) const;
|
||||
// Lower one reduction dimension.
|
||||
Value lowerReduction(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_
|
||||
|
|
|
@ -14,8 +14,7 @@
|
|||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
|
||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/SCF/Transforms.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
|
@ -32,6 +31,7 @@
|
|||
#include "mlir/Transforms/Utils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
using namespace linalg;
|
||||
|
||||
namespace {
|
||||
|
@ -191,7 +191,7 @@ struct LinalgStrategyVectorizePass
|
|||
}
|
||||
vector::populateVectorTransferPermutationMapLoweringPatterns(
|
||||
vectorizationPatterns);
|
||||
vector::populateVetorReductionToContractPatterns(vectorizationPatterns);
|
||||
vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
|
||||
vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
|
||||
linalg::LinalgCopyVTWForwardingPattern>(
|
||||
funcOp.getContext(), /*benefit=*/2);
|
||||
|
@ -268,9 +268,14 @@ struct LinalgStrategyLowerVectorsPass
|
|||
vector::populateVectorTransferLoweringPatterns(patterns,
|
||||
options.maxTransferRank);
|
||||
}
|
||||
if (options.transferPartialRewrite) {
|
||||
patterns.add<vector::VectorTransferFullPartialRewriter>(
|
||||
context, options.vectorTransformOptions);
|
||||
if (options.transposeLowering) {
|
||||
vector::populateVectorTransposeLoweringPatterns(
|
||||
patterns, options.vectorTransformOptions);
|
||||
}
|
||||
if (options.multiReductionLowering) {
|
||||
vector::populateVectorMultiReductionLoweringPatterns(
|
||||
patterns,
|
||||
options.vectorTransformOptions.vectorMultiReductionLowering);
|
||||
}
|
||||
if (options.contractionLowering) {
|
||||
patterns.add<ContractionOpToOuterProductOpLowering,
|
||||
|
@ -278,15 +283,15 @@ struct LinalgStrategyLowerVectorsPass
|
|||
options.vectorTransformOptions, context);
|
||||
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
|
||||
}
|
||||
if (options.multiReductionLowering) {
|
||||
vector::populateVectorMultiReductionLoweringPatterns(
|
||||
patterns,
|
||||
options.vectorTransformOptions.vectorMultiReductionLowering);
|
||||
if (options.transferPartialRewrite) {
|
||||
patterns.add<vector::VectorTransferFullPartialRewriter>(
|
||||
context, options.vectorTransformOptions);
|
||||
}
|
||||
if (options.transferToSCFConversion) {
|
||||
populateVectorToSCFConversionPatterns(patterns,
|
||||
options.vectorTransferToSCFOptions);
|
||||
}
|
||||
vector::populateVectorShapeCastLoweringPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
|
|
|
@ -10,14 +10,9 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
|
||||
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
|
|
|
@ -21,21 +21,10 @@
|
|||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/VectorInterfaces.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
@ -48,6 +37,7 @@
|
|||
#define DEBUG_TYPE "vector-to-vector"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
||||
// Helper to find an index in an affine map.
|
||||
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
|
||||
|
@ -1978,9 +1968,41 @@ static Value createInBoundsCond(OpBuilder &b,
|
|||
});
|
||||
return inBoundsCond;
|
||||
}
|
||||
|
||||
LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
|
||||
VectorTransferOpInterface xferOp) {
|
||||
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
|
||||
/// masking) fastpath and a slowpath.
|
||||
/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
|
||||
/// newly created conditional upon function return.
|
||||
/// To accomodate for the fact that the original vector.transfer indexing may be
|
||||
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
|
||||
/// scf.if op returns a view and values of type index.
|
||||
/// At this time, only vector.transfer_read case is implemented.
|
||||
///
|
||||
/// Example (a 2-D vector.transfer_read):
|
||||
/// ```
|
||||
/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
|
||||
/// ```
|
||||
/// is transformed into:
|
||||
/// ```
|
||||
/// %1:3 = scf.if (%inBounds) {
|
||||
/// // fastpath, direct cast
|
||||
/// memref.cast %A: memref<A...> to compatibleMemRefType
|
||||
/// scf.yield %view : compatibleMemRefType, index, index
|
||||
/// } else {
|
||||
/// // slowpath, not in-bounds vector.transfer or linalg.copy.
|
||||
/// memref.cast %alloc: memref<B...> to compatibleMemRefType
|
||||
/// scf.yield %4 : compatibleMemRefType, index, index
|
||||
// }
|
||||
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
|
||||
/// ```
|
||||
/// where `alloc` is a top of the function alloca'ed buffer of one vector.
|
||||
///
|
||||
/// Preconditions:
|
||||
/// 1. `xferOp.permutation_map()` must be a minor identity map
|
||||
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
|
||||
/// must be equal. This will be relaxed in the future but requires
|
||||
/// rank-reducing subviews.
|
||||
static LogicalResult
|
||||
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
|
||||
// TODO: expand support to these 2 cases.
|
||||
if (!xferOp.permutation_map().isMinorIdentity())
|
||||
return failure();
|
||||
|
@ -3863,7 +3885,7 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
|
|||
patterns.add<TransposeOpLowering>(options, patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::vector::populateVetorReductionToContractPatterns(
|
||||
void mlir::vector::populateVectorReductionToContractPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<MultiReduceToContract, CombineContractBroadcast,
|
||||
CombineContractTranspose>(patterns.getContext());
|
||||
|
|
|
@ -98,9 +98,8 @@ void TestConvVectorization::runOnOperation() {
|
|||
VectorTransposeLowering::EltWise};
|
||||
|
||||
RewritePatternSet vectorTransferPatterns(context);
|
||||
// Pattern is not applied because rank-reducing vector transfer is not yet
|
||||
// supported as can be seen in splitFullAndPartialTransferPrecondition,
|
||||
// VectorTransforms.cpp
|
||||
// Pattern is not applied: rank-reducing vector transfer is not yet supported
|
||||
// (see: splitFullAndPartialTransferPrecondition in VectorTransforms.cpp).
|
||||
vectorTransferPatterns.add<VectorTransferFullPartialRewriter>(
|
||||
context, vectorTransformOptions);
|
||||
(void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
|
||||
|
|
|
@ -536,7 +536,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
|
|||
RewritePatternSet canonicalizationPatterns(funcOp.getContext());
|
||||
vector::populateVectorTransferPermutationMapLoweringPatterns(
|
||||
canonicalizationPatterns);
|
||||
vector::populateVetorReductionToContractPatterns(canonicalizationPatterns);
|
||||
vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
|
||||
stage1Patterns.push_back(std::move(canonicalizationPatterns));
|
||||
}
|
||||
SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
|
||||
|
|
|
@ -14,13 +14,13 @@
|
|||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestVectorToVectorConversion
|
||||
|
@ -511,7 +511,7 @@ struct TestVectorReduceToContractPatternsPatterns
|
|||
}
|
||||
void runOnFunction() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVetorReductionToContractPatterns(patterns);
|
||||
populateVectorReductionToContractPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue