[mlir][Vector] NFC - Add option to hook vector.transpose lowering to strategies.

This revision also moves some code around to improve overall structure.

Differential Revision: https://reviews.llvm.org/D112437
This commit is contained in:
Nicolas Vasilache 2021-10-25 11:22:22 +00:00
parent 3b1165ba3d
commit d054b80bd3
12 changed files with 484 additions and 443 deletions

View File

@ -15,7 +15,7 @@
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Bufferize.h"
@ -846,6 +846,9 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
: LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
};
//===----------------------------------------------------------------------===//
// Transformation and lowering options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Options to control the application of enabling transformations.
/// Hoisting transformations are always deemed beneficial and must be disabled
/// explicitly.
@ -887,16 +890,10 @@ struct LinalgVectorLoweringOptions {
transferLowering = val;
return *this;
}
/// Trigger full / partial vector.transfer splits.
bool transferPartialRewrite = false;
LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
transferPartialRewrite = val;
return *this;
}
/// Enable lowering of vector.contract.
bool contractionLowering = false;
LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
contractionLowering = val;
/// Enable lowering of vector.transpose.
bool transposeLowering = false;
LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
transposeLowering = val;
return *this;
}
/// Enable lowering of vector.multi_reduce.
@ -905,19 +902,24 @@ struct LinalgVectorLoweringOptions {
multiReductionLowering = val;
return *this;
}
/// Enable lowering of vector.contract.
bool contractionLowering = false;
LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
contractionLowering = val;
return *this;
}
/// Trigger full / partial vector.transfer splits.
bool transferPartialRewrite = false;
LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
transferPartialRewrite = val;
return *this;
}
/// Enable lowering of vector.transfer to scf.
bool transferToSCFConversion = false;
LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
transferToSCFConversion = val;
return *this;
}
/// Configure late vector transformations.
vector::VectorTransformsOptions vectorTransformOptions;
LinalgVectorLoweringOptions &
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
vectorTransformOptions = options;
return *this;
}
/// Configure the post staged-patterns late vector.transfer to scf
/// conversion.
VectorTransferToSCFOptions vectorTransferToSCFOptions;
@ -926,8 +928,18 @@ struct LinalgVectorLoweringOptions {
vectorTransferToSCFOptions = options;
return *this;
}
/// Configure late vector transformations.
vector::VectorTransformsOptions vectorTransformOptions;
LinalgVectorLoweringOptions &
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
vectorTransformOptions = options;
return *this;
}
};
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
/// Trait to check if T provides a `getOperationName` method.
template <typename T, typename... Args>
using has_get_operation_name = decltype(T::getOperationName());

View File

@ -40,76 +40,6 @@ namespace detail {
struct BitmaskEnumStorage;
} // namespace detail
/// Enum to control the lowering of `vector.contract` operations.
enum class VectorContractLowering {
/// Progressively lower to finer grained `vector.contract` and dot-products.
Dot = 0,
/// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
Matmul = 1,
/// Lower to `vector.outerproduct`.
OuterProduct = 2,
};
/// Enum to control the lowering of `vector.multi_reduction` operations.
enum class VectorMultiReductionLowering {
/// Lower multi_reduction into outer-reduction and inner-parallel ops.
InnerParallel = 0,
/// Lower multi_reduction into outer-parallel and inner-reduction ops.
InnerReduction = 1,
};
/// Enum to control the lowering of `vector.transpose` operations.
enum class VectorTransposeLowering {
/// Lower transpose into element-wise extract and inserts.
EltWise = 0,
/// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
/// intrinsics.
Flat = 1,
};
/// Enum to control the splitting of `vector.transfer` operations into
/// in-bounds and out-of-bounds variants.
enum class VectorTransferSplit {
/// Do not split vector transfer operations.
None = 0,
/// Split using in-bounds + out-of-bounds vector.transfer operations.
VectorTransfer = 1,
/// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
/// operations.
LinalgCopy = 2,
/// Do not split vector transfer operation but instead mark it as "in-bounds".
ForceInBounds = 3
};
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
/// Option to control the lowering of vector.contract.
VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
VectorTransformsOptions &
setVectorTransformsOptions(VectorContractLowering opt) {
vectorContractLowering = opt;
return *this;
}
/// Option to control the lowering of vector.multi_reduction.
VectorMultiReductionLowering vectorMultiReductionLowering =
VectorMultiReductionLowering::InnerParallel;
VectorTransformsOptions &
setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
vectorMultiReductionLowering = opt;
return *this;
}
/// Option to control the lowering of vector.transpose.
VectorTransposeLowering vectorTransposeLowering =
VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransposeLowering(VectorTransposeLowering opt) {
vectorTransposeLowering = opt;
return *this;
}
/// Option to control the splitting of vector transfers.
VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
vectorTransferSplit = opt;
return *this;
}
};
/// Return whether `srcType` can be broadcast to `dstVectorType` under the
/// semantics of the `vector.broadcast` op.
enum class BroadcastableToResult {
@ -161,33 +91,6 @@ void populateVectorTransferPermutationMapLoweringPatterns(
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool enableIndexOptimizations);
/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
/// that all reduction dimensions are either innermost or outermost, by adding
/// the proper vector.transpose operations.
/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
/// form, with an **outermost** reduction dimension, unroll the outer dimension
/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
/// tree-reduction (in the future).
/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
/// with an **innermost** reduction dimension, unroll the outer dimension to
/// obtain a sequence of extract + vector.reduction + insert. This can further
/// lower to horizontal reduction ops.
/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
/// reduction (and are thus missing either a parallel or a reduction), we lift
/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
/// the other patterns can kick in, thus fully exiting out of the
/// vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns,
VectorMultiReductionLowering options =
vector::VectorMultiReductionLowering::InnerParallel);
/// Collect a set of patterns to propagate insert_map/extract_map in the ssa
/// chain.
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
@ -212,12 +115,6 @@ public:
/// vectors to low-D vector ops.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
/// Collects patterns to progressively lower vector contraction ops on high-D
/// into low-D reduction and product ops.
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions options = VectorTransformsOptions());
/// Collects patterns to progressively lower vector mask ops into elementary
/// selection and insertion ops.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
@ -227,15 +124,6 @@ void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
/// ops.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
/// Insert TransposeLowering patterns into extraction/insertion.
void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions options = VectorTransformsOptions());
/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVetorReductionToContractPatterns(RewritePatternSet &patterns);
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);

View File

@ -9,11 +9,173 @@
#ifndef DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
#define DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
class RewritePatternSet;
namespace vector {
//===----------------------------------------------------------------------===//
// Vector transformation options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Enum to control the lowering of `vector.transpose` operations.
enum class VectorTransposeLowering {
/// Lower transpose into element-wise extract and inserts.
EltWise = 0,
/// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
/// intrinsics.
Flat = 1,
};
/// Enum to control the lowering of `vector.multi_reduction` operations.
enum class VectorMultiReductionLowering {
/// Lower multi_reduction into outer-reduction and inner-parallel ops.
InnerParallel = 0,
/// Lower multi_reduction into outer-parallel and inner-reduction ops.
InnerReduction = 1,
};
/// Enum to control the lowering of `vector.contract` operations.
enum class VectorContractLowering {
/// Progressively lower to finer grained `vector.contract` and dot-products.
Dot = 0,
/// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
Matmul = 1,
/// Lower to `vector.outerproduct`.
OuterProduct = 2,
};
/// Enum to control the splitting of `vector.transfer` operations into
/// in-bounds and out-of-bounds variants.
enum class VectorTransferSplit {
/// Do not split vector transfer operations.
None = 0,
/// Split using in-bounds + out-of-bounds vector.transfer operations.
VectorTransfer = 1,
/// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
/// operations.
LinalgCopy = 2,
/// Do not split vector transfer operation but instead mark it as "in-bounds".
ForceInBounds = 3
};
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
/// Option to control the lowering of vector.contract.
VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
VectorTransformsOptions &
setVectorTransformsOptions(VectorContractLowering opt) {
vectorContractLowering = opt;
return *this;
}
/// Option to control the lowering of vector.multi_reduction.
VectorMultiReductionLowering vectorMultiReductionLowering =
VectorMultiReductionLowering::InnerParallel;
VectorTransformsOptions &
setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
vectorMultiReductionLowering = opt;
return *this;
}
/// Option to control the lowering of vector.transpose.
VectorTransposeLowering vectorTransposeLowering =
VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransposeLowering(VectorTransposeLowering opt) {
vectorTransposeLowering = opt;
return *this;
}
/// Option to control the splitting of vector transfers.
VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
vectorTransferSplit = opt;
return *this;
}
};
/// Options that control the vector unrolling.
struct UnrollVectorOptions {
using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
/// Callback function that indicates whether vector unrolling should be
/// attempted on the operation.
FilterConstraintFnType filterConstraint = nullptr;
UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
filterConstraint = constraint;
return *this;
}
using NativeShapeFnType =
std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
/// Function that returns the shape of the vector to unroll to for a given
/// operation. The unrolling is aborted if the function returns `llvm::None`.
NativeShapeFnType nativeShape = nullptr;
UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
nativeShape = fn;
return *this;
}
/// Set the native shape to use for unrolling.
UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
return tsShape;
};
return *this;
}
};
//===----------------------------------------------------------------------===//
// Vector transformation exposed as populate functions over rewrite patterns.
//===----------------------------------------------------------------------===//
/// Insert TransposeLowering patterns into extraction/insertion.
void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions options = VectorTransformsOptions());
/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
/// that all reduction dimensions are either innermost or outermost, by adding
/// the proper vector.transpose operations.
/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
/// form, with an **outermost** reduction dimension, unroll the outer dimension
/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
/// tree-reduction (in the future).
/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
/// with an **innermost** reduction dimension, unroll the outer dimension to
/// obtain a sequence of extract + vector.reduction + insert. This can further
/// lower to horizontal reduction ops.
/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
/// reduction (and are thus missing either a parallel or a reduction), we lift
/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
/// the other patterns can kick in, thus fully exiting out of the
/// vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns,
VectorMultiReductionLowering options =
VectorMultiReductionLowering::InnerParallel);
/// Collects patterns to progressively lower vector contraction ops on high-D
/// into low-D reduction and product ops.
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions options = VectorTransformsOptions());
/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops to operate on the largest contigious vector.
/// These patterns are useful when lowering to dialects with 1d vector type
/// such as llvm and it will result fewer memory reads.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
RewritePatternSet &patterns);
/// Populate `patterns` with the following patterns.
///
/// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
@ -52,6 +214,235 @@ namespace vector {
void populateVectorInsertExtractStridedSliceTransforms(
RewritePatternSet &patterns);
/// Collect a set of pattern to unroll vector operations to a smaller shapes.
/// `options` structure controls which operations are unrolled and the target
/// shape.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
/// `numUnrolledInstances` are computed from the `targetShape`. For now it is
/// assumed the unrolling factors divide the vector sizes.
/// 2. ExtractStridedSlice are created to break-up the vector operands.
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
/// result.
/// 4. InsertStridedSlice are inserted to re-assemble the slices into the
/// original vectore shape.
///
/// Example:
///
/// opA(operand0, operand1) // numUnrolledInstances = 3
///
/// operand0 operand1
/// | |
/// fork fork
/// <----------gather all fork ops --------->
/// /|\ /|\
/// f00 f01 f02 f10 f11 f12
/// <---------- clone op 3 times --------->
/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
/// \ | /
/// <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// to combine the ExtractStridedSlice/InsertStridedSlice.
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options);
//===----------------------------------------------------------------------===//
// Finer-grained patterns exposed for more control over individual lowerings.
//===----------------------------------------------------------------------===//
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
using FilterConstraintType =
std::function<LogicalResult(VectorTransferOpInterface op)>;
explicit VectorTransferFullPartialRewriter(
MLIRContext *context,
VectorTransformsOptions options = VectorTransformsOptions(),
FilterConstraintType filter =
[](VectorTransferOpInterface op) { return success(); },
PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
filter(filter) {}
/// Performs the rewrite.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
VectorTransformsOptions options;
FilterConstraintType filter;
};
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
/// %d = vector.shape_cast %%flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToMatmulOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
/// %at = vector.transpose %a, [1, 0]
/// %bRow0 = vector.extract %b[0]
/// %atRow0 = vector.extract %at[0]
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
/// ...
/// %bRowK = vector.extract %b[K]
/// %atRowK = vector.extract %at[K]
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToOuterProductOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
/// %out = arith.constant ... : vector<MxNxelt_type>
/// %bt = vector.transpose %b, [1, 0]
/// %aRow0 = vector.extract %a[0]
/// %btRow0 = vector.extract %bt[0]
/// %c00 = vector.reduce %atRow0, %bRow0
/// %out00 = vector.insert %c00, %out[0, 0]
/// ...
/// %aRowLast = vector.extract %at[M-1]
/// %btRowLast = vector.extract %b[N-1]
/// %cLastLast = vector.reduce %atRowLast, %bRowLast
/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToDotLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
/// Progressive lowering of ContractionOp.
///
/// One:
/// %x = vector.contract with at least one free/batch dimension
/// is replaced by:
/// %a = vector.contract with one less free/batch dimension
/// %b = vector.contract with one less free/batch dimension
/// ..
/// %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context,
FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
// Lower one parallel dimension.
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex, PatternRewriter &rewriter) const;
// Lower one reduction dimension.
Value lowerReduction(vector::ContractionOp op,
PatternRewriter &rewriter) const;
};
} // namespace vector
} // namespace mlir

View File

@ -9,10 +9,8 @@
#ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_
#define DIALECT_VECTOR_VECTORTRANSFORMS_H_
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
class MLIRContext;
@ -26,77 +24,9 @@ class IfOp;
namespace vector {
/// Options that control the vector unrolling.
struct UnrollVectorOptions {
using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
/// Callback function that indicates whether vector unrolling should be
/// attempted on the operation.
FilterConstraintFnType filterConstraint = nullptr;
UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
filterConstraint = constraint;
return *this;
}
using NativeShapeFnType =
std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
/// Function that returns the shape of the vector to unroll to for a given
/// operation. The unrolling is aborted if the function returns `llvm::None`.
NativeShapeFnType nativeShape = nullptr;
UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
nativeShape = fn;
return *this;
}
/// Set the native shape to use for unrolling.
UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
return tsShape;
};
return *this;
}
};
/// Collect a set of pattern to unroll vector operations to a smaller shapes.
/// `options` structure controls which operations are unrolled and the target
/// shape.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
/// `numUnrolledInstances` are computed from the `targetShape`. For now it is
/// assumed the unrolling factors divide the vector sizes.
/// 2. ExtractStridedSlice are created to break-up the vector operands.
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
/// result.
/// 4. InsertStridedSlice are inserted to re-assemble the slices into the
/// original vectore shape.
///
/// Example:
///
/// opA(operand0, operand1) // numUnrolledInstances = 3
///
/// operand0 operand1
/// | |
/// fork fork
/// <----------gather all fork ops --------->
/// /|\ /|\
/// f00 f01 f02 f10 f11 f12
/// <---------- clone op 3 times --------->
/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
/// \ | /
/// <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// to combine the ExtractStridedSlice/InsertStridedSlice.
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
const UnrollVectorOptions &options);
/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops to operate on the largest contigious vector.
/// These patterns are useful when lowering to dialects with 1d vector type
/// such as llvm and it will result fewer memory reads.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
RewritePatternSet &patterns);
//===----------------------------------------------------------------------===//
// Standalone transformations and helpers.
//===----------------------------------------------------------------------===//
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
@ -130,37 +60,11 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
LogicalResult splitFullAndPartialTransfer(
OpBuilder &b, VectorTransferOpInterface xferOp,
VectorTransformsOptions options = VectorTransformsOptions(),
scf::IfOp *ifOp = nullptr);
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
using FilterConstraintType =
std::function<LogicalResult(VectorTransferOpInterface op)>;
explicit VectorTransferFullPartialRewriter(
MLIRContext *context,
VectorTransformsOptions options = VectorTransformsOptions(),
FilterConstraintType filter =
[](VectorTransferOpInterface op) { return success(); },
PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
filter(filter) {}
/// Performs the rewrite.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
VectorTransformsOptions options;
FilterConstraintType filter;
};
struct DistributeOps {
ExtractMapOp extract;
InsertMapOp insert;
@ -188,180 +92,6 @@ distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
void transferOpflowOpt(FuncOp func);
} // namespace vector
//===----------------------------------------------------------------------===//
// Finer-grained patterns exposed for more control over individual lowerings.
//===----------------------------------------------------------------------===//
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
/// %d = vector.shape_cast %%flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToMatmulOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
/// %at = vector.transpose %a, [1, 0]
/// %bRow0 = vector.extract %b[0]
/// %atRow0 = vector.extract %at[0]
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
/// ...
/// %bRowK = vector.extract %b[K]
/// %atRowK = vector.extract %at[K]
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToOuterProductOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
/// %out = arith.constant ... : vector<MxNxelt_type>
/// %bt = vector.transpose %b, [1, 0]
/// %aRow0 = vector.extract %a[0]
/// %btRow0 = vector.extract %bt[0]
/// %c00 = vector.reduce %atRow0, %bRow0
/// %out00 = vector.insert %c00, %out[0, 0]
/// ...
/// %aRowLast = vector.extract %at[M-1]
/// %btRowLast = vector.extract %b[N-1]
/// %cLastLast = vector.reduce %atRowLast, %bRowLast
/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToDotLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
/// Progressive lowering of ContractionOp.
///
/// One:
/// %x = vector.contract with at least one free/batch dimension
/// is replaced by:
/// %a = vector.contract with one less free/batch dimension
/// %b = vector.contract with one less free/batch dimension
/// ..
/// %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context,
FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
// Lower one parallel dimension.
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex, PatternRewriter &rewriter) const;
// Lower one reduction dimension.
Value lowerReduction(vector::ContractionOp op,
PatternRewriter &rewriter) const;
};
} // namespace mlir
#endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_

View File

@ -14,8 +14,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"

View File

@ -21,7 +21,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

View File

@ -22,7 +22,6 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@ -32,6 +31,7 @@
#include "mlir/Transforms/Utils.h"
using namespace mlir;
using namespace mlir::vector;
using namespace linalg;
namespace {
@ -191,7 +191,7 @@ struct LinalgStrategyVectorizePass
}
vector::populateVectorTransferPermutationMapLoweringPatterns(
vectorizationPatterns);
vector::populateVetorReductionToContractPatterns(vectorizationPatterns);
vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(
funcOp.getContext(), /*benefit=*/2);
@ -268,9 +268,14 @@ struct LinalgStrategyLowerVectorsPass
vector::populateVectorTransferLoweringPatterns(patterns,
options.maxTransferRank);
}
if (options.transferPartialRewrite) {
patterns.add<vector::VectorTransferFullPartialRewriter>(
context, options.vectorTransformOptions);
if (options.transposeLowering) {
vector::populateVectorTransposeLoweringPatterns(
patterns, options.vectorTransformOptions);
}
if (options.multiReductionLowering) {
vector::populateVectorMultiReductionLoweringPatterns(
patterns,
options.vectorTransformOptions.vectorMultiReductionLowering);
}
if (options.contractionLowering) {
patterns.add<ContractionOpToOuterProductOpLowering,
@ -278,15 +283,15 @@ struct LinalgStrategyLowerVectorsPass
options.vectorTransformOptions, context);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
if (options.multiReductionLowering) {
vector::populateVectorMultiReductionLoweringPatterns(
patterns,
options.vectorTransformOptions.vectorMultiReductionLowering);
if (options.transferPartialRewrite) {
patterns.add<vector::VectorTransferFullPartialRewriter>(
context, options.vectorTransformOptions);
}
if (options.transferToSCFConversion) {
populateVectorToSCFConversionPatterns(patterns,
options.vectorTransferToSCFOptions);
}
vector::populateVectorShapeCastLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

View File

@ -10,14 +10,9 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"

View File

@ -21,21 +21,10 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/DenseSet.h"
@ -48,6 +37,7 @@
#define DEBUG_TYPE "vector-to-vector"
using namespace mlir;
using namespace mlir::vector;
// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
@ -1978,9 +1968,41 @@ static Value createInBoundsCond(OpBuilder &b,
});
return inBoundsCond;
}
LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
VectorTransferOpInterface xferOp) {
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accomodate for the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
/// // fastpath, direct cast
/// memref.cast %A: memref<A...> to compatibleMemRefType
/// scf.yield %view : compatibleMemRefType, index, index
/// } else {
/// // slowpath, not in-bounds vector.transfer or linalg.copy.
/// memref.cast %alloc: memref<B...> to compatibleMemRefType
/// scf.yield %4 : compatibleMemRefType, index, index
// }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
// TODO: expand support to these 2 cases.
if (!xferOp.permutation_map().isMinorIdentity())
return failure();
@ -3863,7 +3885,7 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
patterns.add<TransposeOpLowering>(options, patterns.getContext());
}
void mlir::vector::populateVetorReductionToContractPatterns(
void mlir::vector::populateVectorReductionToContractPatterns(
RewritePatternSet &patterns) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
CombineContractTranspose>(patterns.getContext());

View File

@ -98,9 +98,8 @@ void TestConvVectorization::runOnOperation() {
VectorTransposeLowering::EltWise};
RewritePatternSet vectorTransferPatterns(context);
// Pattern is not applied because rank-reducing vector transfer is not yet
// supported as can be seen in splitFullAndPartialTransferPrecondition,
// VectorTransforms.cpp
// Pattern is not applied: rank-reducing vector transfer is not yet supported
// (see: splitFullAndPartialTransferPrecondition in VectorTransforms.cpp).
vectorTransferPatterns.add<VectorTransferFullPartialRewriter>(
context, vectorTransformOptions);
(void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));

View File

@ -536,7 +536,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
RewritePatternSet canonicalizationPatterns(funcOp.getContext());
vector::populateVectorTransferPermutationMapLoweringPatterns(
canonicalizationPatterns);
vector::populateVetorReductionToContractPatterns(canonicalizationPatterns);
vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
stage1Patterns.push_back(std::move(canonicalizationPatterns));
}
SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;

View File

@ -14,13 +14,13 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;
namespace {
struct TestVectorToVectorConversion
@ -511,7 +511,7 @@ struct TestVectorReduceToContractPatternsPatterns
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
populateVetorReductionToContractPatterns(patterns);
populateVectorReductionToContractPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};