[mlir][Vector] NFC - Extract rewrites related to insert/extract strided slice in a separate file.

Differential Revision: https://reviews.llvm.org/D112301
This commit is contained in:
Nicolas Vasilache 2021-10-22 09:39:07 +00:00
parent cac8808f15
commit eda2ebd780
9 changed files with 321 additions and 252 deletions

View File

@ -0,0 +1,58 @@
//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
#define MLIR_DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_

namespace mlir {
class RewritePatternSet;

namespace vector {

/// Populate `patterns` with the following patterns.
///
/// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
/// =======================================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
///   1. ExtractOp, positioned by the leading (n-k) offsets, to extract the
///      (unique) k-D subvector of the destination into which to insert the
///      k-D source.
///   2. k-D -> k-D InsertStridedSlice op using the trailing k offsets.
///   3. InsertOp that is the reverse of 1.
///
/// [VectorInsertStridedSliceOpSameRankRewritePattern]
/// ==================================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank (k == n). For each outermost index in the slice:
///     begin    end             stride
///   [offset : offset+size*stride : stride]
///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
///   2. InsertStridedSlice (k-1)-D into (n-1)-D.
///   3. InsertOp that is the reverse of 1., i.e. the updated destination
///      subvector is inserted back in the proper place.
/// At rank 1 the extracted source is a scalar element and is inserted directly,
/// without creating a nested InsertStridedSlice.
///
/// [VectorExtractStridedSliceOpRewritePattern]
/// ===========================================
/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. single offset extract as a direct vector::ShuffleOp.
///   2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
///      InsertOp/InsertElementOp for the n-D case.
void populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns);

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_

View File

@ -24,13 +24,6 @@ namespace scf {
class IfOp; class IfOp;
} // namespace scf } // namespace scf
/// Collect a set of patterns to convert from the Vector dialect to itself.
/// Should be merged with populateVectorToSCFLoweringPattern.
void populateVectorToVectorConversionPatterns(
MLIRContext *context, RewritePatternSet &patterns,
ArrayRef<int64_t> coarseVectorShape = {},
ArrayRef<int64_t> fineVectorShape = {});
namespace vector { namespace vector {
/// Options that control the vector unrolling. /// Options that control the vector unrolling.

View File

@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_VECTORUTILS_H_ #ifndef MLIR_DIALECT_VECTOR_VECTORUTILS_H_
#define MLIR_DIALECT_VECTOR_VECTORUTILS_H_ #define MLIR_DIALECT_VECTOR_VECTORUTILS_H_
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
@ -184,6 +185,11 @@ bool checkSameValueRAW(vector::TransferWriteOp defWrite,
bool checkSameValueWAW(vector::TransferWriteOp write, bool checkSameValueWAW(vector::TransferWriteOp write,
vector::TransferWriteOp priorWrite); vector::TransferWriteOp priorWrite);
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// The first `dropFront` and the last `dropBack` entries of `arrayAttr` are
// skipped; the remaining entries are read as IntegerAttr and returned as
// sign-extended 64-bit values, in their original order.
// The implementation asserts that at least one entry remains, i.e. that
// `arrayAttr.size() > dropFront + dropBack`.
SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront = 0,
unsigned dropBack = 0);
namespace matcher { namespace matcher {
/// Matches vector.transfer_read, vector.transfer_write and ops that return a /// Matches vector.transfer_read, vector.transfer_write and ops that return a

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/MathExtras.h" #include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "mlir/Target/LLVMIR/TypeToLLVM.h"
@ -52,17 +53,6 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
rewriter.getI64ArrayAttr(pos)); rewriter.getI64ArrayAttr(pos));
} }
// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
Value into, int64_t offset) {
auto vectorType = into.getType().cast<VectorType>();
if (vectorType.getRank() > 1)
return rewriter.create<InsertOp>(loc, from, into, offset);
return rewriter.create<vector::InsertElementOp>(
loc, vectorType, from, into,
rewriter.create<arith::ConstantIndexOp>(loc, offset));
}
// Helper that picks the proper sequence for extracting. // Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter, static Value extractOne(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc, LLVMTypeConverter &typeConverter, Location loc,
@ -79,32 +69,6 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
rewriter.getI64ArrayAttr(pos)); rewriter.getI64ArrayAttr(pos));
} }
// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
int64_t offset) {
auto vectorType = vector.getType().cast<VectorType>();
if (vectorType.getRank() > 1)
return rewriter.create<ExtractOp>(loc, vector, offset);
return rewriter.create<vector::ExtractElementOp>(
loc, vectorType.getElementType(), vector,
rewriter.create<arith::ConstantIndexOp>(loc, offset));
}
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront = 0,
unsigned dropBack = 0) {
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
auto range = arrayAttr.getAsRange<IntegerAttr>();
SmallVector<int64_t, 4> res;
res.reserve(arrayAttr.size() - dropFront - dropBack);
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
it != eit; ++it)
res.push_back((*it).getValue().getSExtValue());
return res;
}
// Helper that returns data layout alignment of a memref. // Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) { MemRefType memrefType, unsigned &align) {
@ -813,132 +777,6 @@ public:
} }
}; };
// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
// 1. the proper subvector is extracted from the destination vector
// 2. a new InsertStridedSlice op is created to insert the source in the
// destination subvector
// 3. the destination subvector is inserted back in the proper place
// 4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
if (op.offsets().getValue().empty())
return failure();
auto loc = op.getLoc();
int64_t rankDiff = dstType.getRank() - srcType.getRank();
assert(rankDiff >= 0);
if (rankDiff == 0)
return failure();
int64_t rankRest = dstType.getRank() - rankDiff;
// Extract / insert the subvector of matching rank and InsertStridedSlice
// on it.
Value extracted =
rewriter.create<ExtractOp>(loc, op.dest(),
getI64SubArray(op.offsets(), /*dropFront=*/0,
/*dropBack=*/rankRest));
// A different pattern will kick in for InsertStridedSlice with matching
// ranks.
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
loc, op.source(), extracted,
getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
getI64SubArray(op.strides(), /*dropFront=*/0));
rewriter.replaceOpWithNewOp<InsertOp>(
op, stridedSliceInnerOp.getResult(), op.dest(),
getI64SubArray(op.offsets(), /*dropFront=*/0,
/*dropBack=*/rankRest));
return success();
}
};
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. In this case, we reduce
// 1. the proper subvector is extracted from the destination vector
// 2. a new InsertStridedSlice op is created to insert the source in the
// destination subvector
// 3. the destination subvector is inserted back in the proper place
// 4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpSameRankRewritePattern
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
void initialize() {
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
// bounded as the rank is strictly decreasing.
setHasBoundedRewriteRecursion();
}
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
if (op.offsets().getValue().empty())
return failure();
int64_t rankDiff = dstType.getRank() - srcType.getRank();
assert(rankDiff >= 0);
if (rankDiff != 0)
return failure();
if (srcType == dstType) {
rewriter.replaceOp(op, op.source());
return success();
}
int64_t offset =
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
int64_t size = srcType.getShape().front();
int64_t stride =
op.strides().getValue().front().cast<IntegerAttr>().getInt();
auto loc = op.getLoc();
Value res = op.dest();
// For each slice of the source vector along the most major dimension.
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
// 1. extract the proper subvector (or element) from source
Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
if (extractedSource.getType().isa<VectorType>()) {
// 2. If we have a vector, extract the proper subvector from destination
// Otherwise we are at the element level and no need to recurse.
Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
// smaller rank.
extractedSource = rewriter.create<InsertStridedSliceOp>(
loc, extractedSource, extractedDest,
getI64SubArray(op.offsets(), /* dropFront=*/1),
getI64SubArray(op.strides(), /* dropFront=*/1));
}
// 4. Insert the extractedSource into the res vector.
res = insertOne(rewriter, loc, extractedSource, res, off);
}
rewriter.replaceOp(op, res);
return success();
}
};
/// Returns the strides if the memory underlying `memRefType` has a contiguous /// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout. /// static layout.
static llvm::Optional<SmallVector<int64_t, 4>> static llvm::Optional<SmallVector<int64_t, 4>>
@ -1189,67 +1027,6 @@ private:
} }
}; };
/// Progressive lowering of ExtractStridedSliceOp to either:
/// 1. express single offset extract as a direct shuffle.
/// 2. extract + lower rank strided_slice + insert for the n-D case.
class VectorExtractStridedSliceOpConversion
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
void initialize() {
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
// is bounded as the rank is strictly decreasing.
setHasBoundedRewriteRecursion();
}
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
int64_t offset =
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
int64_t stride =
op.strides().getValue().front().cast<IntegerAttr>().getInt();
auto loc = op.getLoc();
auto elemType = dstType.getElementType();
assert(elemType.isSignlessIntOrIndexOrFloat());
// Single offset can be more efficiently shuffled.
if (op.offsets().getValue().size() == 1) {
SmallVector<int64_t, 4> offsets;
offsets.reserve(size);
for (int64_t off = offset, e = offset + size * stride; off < e;
off += stride)
offsets.push_back(off);
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
op.vector(),
rewriter.getI64ArrayAttr(offsets));
return success();
}
// Extract/insert on a lower ranked extract strided slice op.
Value zero = rewriter.create<arith::ConstantOp>(
loc, elemType, rewriter.getZeroAttr(elemType));
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
Value one = extractOne(rewriter, loc, op.vector(), off);
Value extracted = rewriter.create<ExtractStridedSliceOp>(
loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
getI64SubArray(op.sizes(), /* dropFront=*/1),
getI64SubArray(op.strides(), /* dropFront=*/1));
res = insertOne(rewriter, loc, extracted, res, idx);
}
rewriter.replaceOp(op, res);
return success();
}
};
} // namespace } // namespace
/// Populate the given list with patterns that convert from Vector to LLVM. /// Populate the given list with patterns that convert from Vector to LLVM.
@ -1257,10 +1034,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns, LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions) { bool reassociateFPReductions) {
MLIRContext *ctx = converter.getDialect()->getContext(); MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorFMAOpNDRewritePattern, patterns.add<VectorFMAOpNDRewritePattern>(ctx);
VectorInsertStridedSliceOpDifferentRankRewritePattern, populateVectorInsertExtractStridedSliceTransforms(patterns);
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorExtractStridedSliceOpConversion>(ctx);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion, .add<VectorBitCastOpConversion, VectorShuffleOpConversion,

View File

@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRVector add_mlir_dialect_library(MLIRVector
VectorOps.cpp VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorMultiDimReductionTransforms.cpp VectorMultiDimReductionTransforms.cpp
VectorOps.cpp
VectorTransferOpTransforms.cpp VectorTransferOpTransforms.cpp
VectorTransforms.cpp VectorTransforms.cpp
VectorUtils.cpp VectorUtils.cpp

View File

@ -0,0 +1,236 @@
//===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;
using namespace mlir::vector;
// Inserts `from` into `into` at position `offset` of the leading dimension.
// A destination of rank > 1 is updated with an InsertOp; otherwise the value is
// written with an InsertElementOp whose position is an index constant created
// for `offset`.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto destType = into.getType().cast<VectorType>();
  if (destType.getRank() <= 1) {
    Value position = rewriter.create<arith::ConstantIndexOp>(loc, offset);
    return rewriter.create<vector::InsertElementOp>(loc, destType, from, into,
                                                    position);
  }
  return rewriter.create<InsertOp>(loc, from, into, offset);
}
// Extracts what lives at position `offset` of the leading dimension of
// `vector`. A vector of rank > 1 is read with an ExtractOp; otherwise a single
// element is read with an ExtractElementOp whose position is an index constant
// created for `offset`.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto sourceType = vector.getType().cast<VectorType>();
  if (sourceType.getRank() <= 1) {
    Value position = rewriter.create<arith::ConstantIndexOp>(loc, offset);
    return rewriter.create<vector::ExtractElementOp>(
        loc, sourceType.getElementType(), vector, position);
  }
  return rewriter.create<ExtractOp>(loc, vector, offset);
}
/// Rewrites an InsertStridedSliceOp whose source vector has a lower rank than
/// its destination vector.
///
/// Only the rank mismatch is resolved here: the subvector of the destination
/// that actually receives the source is peeled out, and the remaining same-rank
/// insertion is left to [VectorInsertStridedSliceOpSameRankRewritePattern].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
///   1. ExtractOp, positioned by the leading (n-k) offsets, producing the k-D
///      destination subvector into which the source is inserted.
///   2. InsertStridedSliceOp of the k-D source into that subvector, using the
///      trailing k offsets and all the strides.
///   3. InsertOp that is the reverse of 1.
///
/// The pattern does not apply (returns failure) when the op has no offsets or
/// when both ranks already match.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceType = op.getSourceVectorType();
    auto destType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = destType.getRank() - sourceType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    // Number of trailing offsets that belong to the same-rank insertion.
    int64_t rankRest = destType.getRank() - rankDiff;

    // The leading offsets select the destination subvector; they are used both
    // to extract it and to insert the updated subvector back.
    SmallVector<int64_t, 4> leadingOffsets =
        getI64SubArray(op.offsets(), /*dropFront=*/0, /*dropBack=*/rankRest);
    Value destSubvector =
        rewriter.create<ExtractOp>(loc, op.dest(), leadingOffsets);

    // The same-rank pattern takes over from this newly created op.
    auto sameRankInsert = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), destSubvector,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));

    rewriter.replaceOpWithNewOp<InsertOp>(op, sameRankInsert.getResult(),
                                          op.dest(), leadingOffsets);
    return success();
  }
};
/// Rewrites an InsertStridedSliceOp whose source and destination vectors have
/// the same rank. For each outermost index in the slice:
///     begin    end             stride
///   [offset : offset+size*stride : stride]
///   1. Extract the source subvector (or element) at the running source index.
///   2. If that is still a vector, extract the matching destination subvector
///      and create a lower-rank InsertStridedSliceOp on the two.
///   3. Insert the result back into the accumulated destination at the same
///      outermost offset.
///
/// When source and destination types are identical the op is replaced by its
/// source. The pattern does not apply (returns failure) when the op has no
/// offsets or when the ranks differ.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  void initialize() {
    // The pattern emits new InsertStridedSliceOp, but each of them has a
    // strictly smaller rank, so the recursion is bounded.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceType = op.getSourceVectorType();
    auto destType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = destType.getRank() - sourceType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    // Identical types: the source is used as the result directly.
    if (sourceType == destType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t firstOffset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t numSlices = sourceType.getShape().front();
    int64_t step = op.strides().getValue().front().cast<IntegerAttr>().getInt();
    int64_t endOffset = firstOffset + numSlices * step;

    auto loc = op.getLoc();
    Value result = op.dest();
    int64_t sourceIdx = 0;
    // Walk the source along its most major dimension.
    for (int64_t destIdx = firstOffset; destIdx < endOffset;
         destIdx += step, ++sourceIdx) {
      // Subvector (or scalar element) of the source for this iteration.
      Value piece = extractOne(rewriter, loc, op.source(), sourceIdx);
      if (piece.getType().isa<VectorType>()) {
        // Still a vector: recurse through a lower-rank InsertStridedSliceOp on
        // the matching destination subvector. Scalars need no recursion.
        Value destPiece = extractOne(rewriter, loc, op.dest(), destIdx);
        piece = rewriter.create<InsertStridedSliceOp>(
            loc, piece, destPiece,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // Place the (possibly updated) piece into the accumulated result.
      result = insertOne(rewriter, loc, piece, result, destIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Progressively lowers an ExtractStridedSliceOp:
///   1. With a single offset, the op becomes one vector::ShuffleOp of the
///      input vector with itself, selecting the strided positions.
///   2. Otherwise, each selected outermost position is read with
///      ExtractOp/ExtractElementOp, sliced by a lower-rank
///      ExtractStridedSliceOp, and written with InsertOp/InsertElementOp into
///      a zero-splatted vector of the result type.
class VectorExtractStridedSliceOpRewritePattern
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  void initialize() {
    // The pattern emits new ExtractStridedSliceOp, but each of them has a
    // strictly smaller rank, so the recursion is bounded.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto resultType = op.getType();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t begin =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t count = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t step = op.strides().getValue().front().cast<IntegerAttr>().getInt();
    int64_t end = begin + count * step;

    auto loc = op.getLoc();
    auto elemType = resultType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // A single offset is expressed directly as a shuffle mask.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> mask;
      mask.reserve(count);
      for (int64_t pos = begin; pos < end; pos += step)
        mask.push_back(pos);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, resultType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(mask));
      return success();
    }

    // n-D case: start from a zero splat of the result type and fill it one
    // outermost slice at a time.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value result = rewriter.create<SplatOp>(loc, resultType, zero);
    int64_t resultIdx = 0;
    for (int64_t pos = begin; pos < end; pos += step, ++resultIdx) {
      Value subvector = extractOne(rewriter, loc, op.vector(), pos);
      Value sliced = rewriter.create<ExtractStridedSliceOp>(
          loc, subvector, getI64SubArray(op.offsets(), /* dropFront=*/1),
          getI64SubArray(op.sizes(), /* dropFront=*/1),
          getI64SubArray(op.strides(), /* dropFront=*/1));
      result = insertOne(rewriter, loc, sliced, result, resultIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Adds the InsertStridedSliceOp (different-rank and same-rank) and
/// ExtractStridedSliceOp rewrite patterns to `patterns`, constructing them with
/// the context returned by `patterns.getContext()`.
void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern,
               VectorInsertStridedSliceOpSameRankRewritePattern,
               VectorExtractStridedSliceOpRewritePattern>(context);
}

View File

@ -2204,20 +2204,6 @@ public:
} }
}; };
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront = 0,
unsigned dropBack = 0) {
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
auto range = arrayAttr.getAsRange<IntegerAttr>();
SmallVector<int64_t, 4> res;
res.reserve(arrayAttr.size() - dropFront - dropBack);
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
it != eit; ++it)
res.push_back((*it).getValue().getSExtValue());
return res;
}
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStrideSliceOp). // BroadcastOp(ExtractStrideSliceOp).
class StridedSliceBroadcast final class StridedSliceBroadcast final

View File

@ -1034,10 +1034,11 @@ public:
}; };
/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively on the way from targeting llvm.matrix intrinsics. /// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs /// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into: /// rewrites into:
/// vector.strided_slice from 1-D + vector.insert into 2-D /// vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
class ShapeCastOp2DUpCastRewritePattern class ShapeCastOp2DUpCastRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> { : public OpRewritePattern<vector::ShapeCastOp> {
public: public:

View File

@ -362,3 +362,16 @@ bool mlir::checkSameValueWAW(vector::TransferWriteOp write,
priorWrite.getVectorType() == write.getVectorType() && priorWrite.getVectorType() == write.getVectorType() &&
priorWrite.permutation_map() == write.permutation_map(); priorWrite.permutation_map() == write.permutation_map();
} }
// Returns the entries of `arrayAttr` that remain after skipping the first
// `dropFront` and the last `dropBack` ones. Each remaining entry is cast to
// IntegerAttr and its value is sign-extended to int64_t. At least one entry
// must remain (asserted).
SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                             unsigned dropFront,
                                             unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto kept = arrayAttr.getValue().drop_front(dropFront).drop_back(dropBack);
  SmallVector<int64_t, 4> values;
  values.reserve(arrayAttr.size() - dropFront - dropBack);
  for (Attribute attr : kept)
    values.push_back(attr.cast<IntegerAttr>().getValue().getSExtValue());
  return values;
}