forked from OSchip/llvm-project
[mlir][Vector] NFC - Extract rewrites related to insert/extract strided slice in a separate file.
Differential Revision: https://reviews.llvm.org/D112301
This commit is contained in:
parent
cac8808f15
commit
eda2ebd780
|
@ -0,0 +1,58 @@
|
||||||
|
//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
|
||||||
|
#define DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class RewritePatternSet;
|
||||||
|
|
||||||
|
namespace vector {
|
||||||
|
|
||||||
|
/// Populate `patterns` with the following patterns.
|
||||||
|
///
|
||||||
|
/// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
|
||||||
|
/// =======================================================
|
||||||
|
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
||||||
|
/// have different ranks.
|
||||||
|
///
|
||||||
|
/// When ranks are different, InsertStridedSlice needs to extract a properly
|
||||||
|
/// ranked vector from the destination vector into which to insert. This pattern
|
||||||
|
/// only takes care of this extraction part and forwards the rest to
|
||||||
|
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
|
||||||
|
///
|
||||||
|
/// For a k-D source and n-D destination vector (k < n), we emit:
|
||||||
|
/// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
|
||||||
|
/// insert the k-D source.
|
||||||
|
/// 2. k-D -> (n-1)-D InsertStridedSlice op
|
||||||
|
/// 3. InsertOp that is the reverse of 1.
|
||||||
|
///
|
||||||
|
/// [VectorInsertStridedSliceOpSameRankRewritePattern]
|
||||||
|
/// ==================================================
|
||||||
|
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
||||||
|
/// have the same rank. For each outermost index in the slice:
|
||||||
|
/// begin end stride
|
||||||
|
/// [offset : offset+size*stride : stride]
|
||||||
|
/// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
|
||||||
|
/// 2. InsertStridedSlice (k-1)-D into (n-1)-D
|
||||||
|
/// 3. the destination subvector is inserted back in the proper place
|
||||||
|
/// 3. InsertOp that is the reverse of 1.
|
||||||
|
///
|
||||||
|
/// [VectorExtractStridedSliceOpRewritePattern]
|
||||||
|
/// ===========================================
|
||||||
|
/// Progressive lowering of ExtractStridedSliceOp to either:
|
||||||
|
/// 1. single offset extract as a direct vector::ShuffleOp.
|
||||||
|
/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
|
||||||
|
/// InsertOp/InsertElementOp for the n-D case.
|
||||||
|
void populateVectorInsertExtractStridedSliceTransforms(
|
||||||
|
RewritePatternSet &patterns);
|
||||||
|
|
||||||
|
} // namespace vector
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
|
|
@ -24,13 +24,6 @@ namespace scf {
|
||||||
class IfOp;
|
class IfOp;
|
||||||
} // namespace scf
|
} // namespace scf
|
||||||
|
|
||||||
/// Collect a set of patterns to convert from the Vector dialect to itself.
|
|
||||||
/// Should be merged with populateVectorToSCFLoweringPattern.
|
|
||||||
void populateVectorToVectorConversionPatterns(
|
|
||||||
MLIRContext *context, RewritePatternSet &patterns,
|
|
||||||
ArrayRef<int64_t> coarseVectorShape = {},
|
|
||||||
ArrayRef<int64_t> fineVectorShape = {});
|
|
||||||
|
|
||||||
namespace vector {
|
namespace vector {
|
||||||
|
|
||||||
/// Options that control the vector unrolling.
|
/// Options that control the vector unrolling.
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
#ifndef MLIR_DIALECT_VECTOR_VECTORUTILS_H_
|
#ifndef MLIR_DIALECT_VECTOR_VECTORUTILS_H_
|
||||||
#define MLIR_DIALECT_VECTOR_VECTORUTILS_H_
|
#define MLIR_DIALECT_VECTOR_VECTORUTILS_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
@ -184,6 +185,11 @@ bool checkSameValueRAW(vector::TransferWriteOp defWrite,
|
||||||
bool checkSameValueWAW(vector::TransferWriteOp write,
|
bool checkSameValueWAW(vector::TransferWriteOp write,
|
||||||
vector::TransferWriteOp priorWrite);
|
vector::TransferWriteOp priorWrite);
|
||||||
|
|
||||||
|
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
|
||||||
|
SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
|
||||||
|
unsigned dropFront = 0,
|
||||||
|
unsigned dropBack = 0);
|
||||||
|
|
||||||
namespace matcher {
|
namespace matcher {
|
||||||
|
|
||||||
/// Matches vector.transfer_read, vector.transfer_write and ops that return a
|
/// Matches vector.transfer_read, vector.transfer_write and ops that return a
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
|
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/Support/MathExtras.h"
|
#include "mlir/Support/MathExtras.h"
|
||||||
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
|
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
|
||||||
|
@ -52,17 +53,6 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
|
||||||
rewriter.getI64ArrayAttr(pos));
|
rewriter.getI64ArrayAttr(pos));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper that picks the proper sequence for inserting.
|
|
||||||
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
|
|
||||||
Value into, int64_t offset) {
|
|
||||||
auto vectorType = into.getType().cast<VectorType>();
|
|
||||||
if (vectorType.getRank() > 1)
|
|
||||||
return rewriter.create<InsertOp>(loc, from, into, offset);
|
|
||||||
return rewriter.create<vector::InsertElementOp>(
|
|
||||||
loc, vectorType, from, into,
|
|
||||||
rewriter.create<arith::ConstantIndexOp>(loc, offset));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper that picks the proper sequence for extracting.
|
// Helper that picks the proper sequence for extracting.
|
||||||
static Value extractOne(ConversionPatternRewriter &rewriter,
|
static Value extractOne(ConversionPatternRewriter &rewriter,
|
||||||
LLVMTypeConverter &typeConverter, Location loc,
|
LLVMTypeConverter &typeConverter, Location loc,
|
||||||
|
@ -79,32 +69,6 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
|
||||||
rewriter.getI64ArrayAttr(pos));
|
rewriter.getI64ArrayAttr(pos));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper that picks the proper sequence for extracting.
|
|
||||||
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
|
|
||||||
int64_t offset) {
|
|
||||||
auto vectorType = vector.getType().cast<VectorType>();
|
|
||||||
if (vectorType.getRank() > 1)
|
|
||||||
return rewriter.create<ExtractOp>(loc, vector, offset);
|
|
||||||
return rewriter.create<vector::ExtractElementOp>(
|
|
||||||
loc, vectorType.getElementType(), vector,
|
|
||||||
rewriter.create<arith::ConstantIndexOp>(loc, offset));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
|
|
||||||
// TODO: Better support for attribute subtype forwarding + slicing.
|
|
||||||
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
|
|
||||||
unsigned dropFront = 0,
|
|
||||||
unsigned dropBack = 0) {
|
|
||||||
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
|
|
||||||
auto range = arrayAttr.getAsRange<IntegerAttr>();
|
|
||||||
SmallVector<int64_t, 4> res;
|
|
||||||
res.reserve(arrayAttr.size() - dropFront - dropBack);
|
|
||||||
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
|
|
||||||
it != eit; ++it)
|
|
||||||
res.push_back((*it).getValue().getSExtValue());
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper that returns data layout alignment of a memref.
|
// Helper that returns data layout alignment of a memref.
|
||||||
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
|
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
|
||||||
MemRefType memrefType, unsigned &align) {
|
MemRefType memrefType, unsigned &align) {
|
||||||
|
@ -813,132 +777,6 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// When ranks are different, InsertStridedSlice needs to extract a properly
|
|
||||||
// ranked vector from the destination vector into which to insert. This pattern
|
|
||||||
// only takes care of this part and forwards the rest of the conversion to
|
|
||||||
// another pattern that converts InsertStridedSlice for operands of the same
|
|
||||||
// rank.
|
|
||||||
//
|
|
||||||
// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
|
||||||
// have different ranks. In this case:
|
|
||||||
// 1. the proper subvector is extracted from the destination vector
|
|
||||||
// 2. a new InsertStridedSlice op is created to insert the source in the
|
|
||||||
// destination subvector
|
|
||||||
// 3. the destination subvector is inserted back in the proper place
|
|
||||||
// 4. the op is replaced by the result of step 3.
|
|
||||||
// The new InsertStridedSlice from step 2. will be picked up by a
|
|
||||||
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
|
|
||||||
class VectorInsertStridedSliceOpDifferentRankRewritePattern
|
|
||||||
: public OpRewritePattern<InsertStridedSliceOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto srcType = op.getSourceVectorType();
|
|
||||||
auto dstType = op.getDestVectorType();
|
|
||||||
|
|
||||||
if (op.offsets().getValue().empty())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
|
||||||
assert(rankDiff >= 0);
|
|
||||||
if (rankDiff == 0)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
int64_t rankRest = dstType.getRank() - rankDiff;
|
|
||||||
// Extract / insert the subvector of matching rank and InsertStridedSlice
|
|
||||||
// on it.
|
|
||||||
Value extracted =
|
|
||||||
rewriter.create<ExtractOp>(loc, op.dest(),
|
|
||||||
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
|
||||||
/*dropBack=*/rankRest));
|
|
||||||
// A different pattern will kick in for InsertStridedSlice with matching
|
|
||||||
// ranks.
|
|
||||||
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
|
|
||||||
loc, op.source(), extracted,
|
|
||||||
getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
|
|
||||||
getI64SubArray(op.strides(), /*dropFront=*/0));
|
|
||||||
rewriter.replaceOpWithNewOp<InsertOp>(
|
|
||||||
op, stridedSliceInnerOp.getResult(), op.dest(),
|
|
||||||
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
|
||||||
/*dropBack=*/rankRest));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
|
||||||
// have the same rank. In this case, we reduce
|
|
||||||
// 1. the proper subvector is extracted from the destination vector
|
|
||||||
// 2. a new InsertStridedSlice op is created to insert the source in the
|
|
||||||
// destination subvector
|
|
||||||
// 3. the destination subvector is inserted back in the proper place
|
|
||||||
// 4. the op is replaced by the result of step 3.
|
|
||||||
// The new InsertStridedSlice from step 2. will be picked up by a
|
|
||||||
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
|
|
||||||
class VectorInsertStridedSliceOpSameRankRewritePattern
|
|
||||||
: public OpRewritePattern<InsertStridedSliceOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
void initialize() {
|
|
||||||
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
|
|
||||||
// bounded as the rank is strictly decreasing.
|
|
||||||
setHasBoundedRewriteRecursion();
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto srcType = op.getSourceVectorType();
|
|
||||||
auto dstType = op.getDestVectorType();
|
|
||||||
|
|
||||||
if (op.offsets().getValue().empty())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
|
||||||
assert(rankDiff >= 0);
|
|
||||||
if (rankDiff != 0)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (srcType == dstType) {
|
|
||||||
rewriter.replaceOp(op, op.source());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t offset =
|
|
||||||
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
|
|
||||||
int64_t size = srcType.getShape().front();
|
|
||||||
int64_t stride =
|
|
||||||
op.strides().getValue().front().cast<IntegerAttr>().getInt();
|
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
Value res = op.dest();
|
|
||||||
// For each slice of the source vector along the most major dimension.
|
|
||||||
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
|
||||||
off += stride, ++idx) {
|
|
||||||
// 1. extract the proper subvector (or element) from source
|
|
||||||
Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
|
|
||||||
if (extractedSource.getType().isa<VectorType>()) {
|
|
||||||
// 2. If we have a vector, extract the proper subvector from destination
|
|
||||||
// Otherwise we are at the element level and no need to recurse.
|
|
||||||
Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
|
|
||||||
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
|
|
||||||
// smaller rank.
|
|
||||||
extractedSource = rewriter.create<InsertStridedSliceOp>(
|
|
||||||
loc, extractedSource, extractedDest,
|
|
||||||
getI64SubArray(op.offsets(), /* dropFront=*/1),
|
|
||||||
getI64SubArray(op.strides(), /* dropFront=*/1));
|
|
||||||
}
|
|
||||||
// 4. Insert the extractedSource into the res vector.
|
|
||||||
res = insertOne(rewriter, loc, extractedSource, res, off);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, res);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Returns the strides if the memory underlying `memRefType` has a contiguous
|
/// Returns the strides if the memory underlying `memRefType` has a contiguous
|
||||||
/// static layout.
|
/// static layout.
|
||||||
static llvm::Optional<SmallVector<int64_t, 4>>
|
static llvm::Optional<SmallVector<int64_t, 4>>
|
||||||
|
@ -1189,67 +1027,6 @@ private:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Progressive lowering of ExtractStridedSliceOp to either:
|
|
||||||
/// 1. express single offset extract as a direct shuffle.
|
|
||||||
/// 2. extract + lower rank strided_slice + insert for the n-D case.
|
|
||||||
class VectorExtractStridedSliceOpConversion
|
|
||||||
: public OpRewritePattern<ExtractStridedSliceOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
void initialize() {
|
|
||||||
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
|
|
||||||
// is bounded as the rank is strictly decreasing.
|
|
||||||
setHasBoundedRewriteRecursion();
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto dstType = op.getType();
|
|
||||||
|
|
||||||
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
|
|
||||||
|
|
||||||
int64_t offset =
|
|
||||||
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
|
|
||||||
int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
|
|
||||||
int64_t stride =
|
|
||||||
op.strides().getValue().front().cast<IntegerAttr>().getInt();
|
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
|
||||||
auto elemType = dstType.getElementType();
|
|
||||||
assert(elemType.isSignlessIntOrIndexOrFloat());
|
|
||||||
|
|
||||||
// Single offset can be more efficiently shuffled.
|
|
||||||
if (op.offsets().getValue().size() == 1) {
|
|
||||||
SmallVector<int64_t, 4> offsets;
|
|
||||||
offsets.reserve(size);
|
|
||||||
for (int64_t off = offset, e = offset + size * stride; off < e;
|
|
||||||
off += stride)
|
|
||||||
offsets.push_back(off);
|
|
||||||
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
|
|
||||||
op.vector(),
|
|
||||||
rewriter.getI64ArrayAttr(offsets));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract/insert on a lower ranked extract strided slice op.
|
|
||||||
Value zero = rewriter.create<arith::ConstantOp>(
|
|
||||||
loc, elemType, rewriter.getZeroAttr(elemType));
|
|
||||||
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
|
|
||||||
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
|
||||||
off += stride, ++idx) {
|
|
||||||
Value one = extractOne(rewriter, loc, op.vector(), off);
|
|
||||||
Value extracted = rewriter.create<ExtractStridedSliceOp>(
|
|
||||||
loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
|
|
||||||
getI64SubArray(op.sizes(), /* dropFront=*/1),
|
|
||||||
getI64SubArray(op.strides(), /* dropFront=*/1));
|
|
||||||
res = insertOne(rewriter, loc, extracted, res, idx);
|
|
||||||
}
|
|
||||||
rewriter.replaceOp(op, res);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||||
|
@ -1257,10 +1034,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||||
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||||
bool reassociateFPReductions) {
|
bool reassociateFPReductions) {
|
||||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||||
patterns.add<VectorFMAOpNDRewritePattern,
|
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
|
||||||
VectorInsertStridedSliceOpDifferentRankRewritePattern,
|
populateVectorInsertExtractStridedSliceTransforms(patterns);
|
||||||
VectorInsertStridedSliceOpSameRankRewritePattern,
|
|
||||||
VectorExtractStridedSliceOpConversion>(ctx);
|
|
||||||
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
|
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
|
||||||
patterns
|
patterns
|
||||||
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
|
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
add_mlir_dialect_library(MLIRVector
|
add_mlir_dialect_library(MLIRVector
|
||||||
VectorOps.cpp
|
VectorInsertExtractStridedSliceRewritePatterns.cpp
|
||||||
VectorMultiDimReductionTransforms.cpp
|
VectorMultiDimReductionTransforms.cpp
|
||||||
|
VectorOps.cpp
|
||||||
VectorTransferOpTransforms.cpp
|
VectorTransferOpTransforms.cpp
|
||||||
VectorTransforms.cpp
|
VectorTransforms.cpp
|
||||||
VectorUtils.cpp
|
VectorUtils.cpp
|
||||||
|
|
|
@ -0,0 +1,236 @@
|
||||||
|
//===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
|
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
|
||||||
|
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::vector;
|
||||||
|
|
||||||
|
// Helper that picks the proper sequence for inserting.
|
||||||
|
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
|
||||||
|
Value into, int64_t offset) {
|
||||||
|
auto vectorType = into.getType().cast<VectorType>();
|
||||||
|
if (vectorType.getRank() > 1)
|
||||||
|
return rewriter.create<InsertOp>(loc, from, into, offset);
|
||||||
|
return rewriter.create<vector::InsertElementOp>(
|
||||||
|
loc, vectorType, from, into,
|
||||||
|
rewriter.create<arith::ConstantIndexOp>(loc, offset));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper that picks the proper sequence for extracting.
|
||||||
|
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
|
||||||
|
int64_t offset) {
|
||||||
|
auto vectorType = vector.getType().cast<VectorType>();
|
||||||
|
if (vectorType.getRank() > 1)
|
||||||
|
return rewriter.create<ExtractOp>(loc, vector, offset);
|
||||||
|
return rewriter.create<vector::ExtractElementOp>(
|
||||||
|
loc, vectorType.getElementType(), vector,
|
||||||
|
rewriter.create<arith::ConstantIndexOp>(loc, offset));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
||||||
|
/// have different ranks.
|
||||||
|
///
|
||||||
|
/// When ranks are different, InsertStridedSlice needs to extract a properly
|
||||||
|
/// ranked vector from the destination vector into which to insert. This pattern
|
||||||
|
/// only takes care of this extraction part and forwards the rest to
|
||||||
|
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
|
||||||
|
///
|
||||||
|
/// For a k-D source and n-D destination vector (k < n), we emit:
|
||||||
|
/// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
|
||||||
|
/// insert the k-D source.
|
||||||
|
/// 2. k-D -> (n-1)-D InsertStridedSlice op
|
||||||
|
/// 3. InsertOp that is the reverse of 1.
|
||||||
|
class VectorInsertStridedSliceOpDifferentRankRewritePattern
|
||||||
|
: public OpRewritePattern<InsertStridedSliceOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto srcType = op.getSourceVectorType();
|
||||||
|
auto dstType = op.getDestVectorType();
|
||||||
|
|
||||||
|
if (op.offsets().getValue().empty())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||||
|
assert(rankDiff >= 0);
|
||||||
|
if (rankDiff == 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t rankRest = dstType.getRank() - rankDiff;
|
||||||
|
// Extract / insert the subvector of matching rank and InsertStridedSlice
|
||||||
|
// on it.
|
||||||
|
Value extracted =
|
||||||
|
rewriter.create<ExtractOp>(loc, op.dest(),
|
||||||
|
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
||||||
|
/*dropBack=*/rankRest));
|
||||||
|
|
||||||
|
// A different pattern will kick in for InsertStridedSlice with matching
|
||||||
|
// ranks.
|
||||||
|
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
|
||||||
|
loc, op.source(), extracted,
|
||||||
|
getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
|
||||||
|
getI64SubArray(op.strides(), /*dropFront=*/0));
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<InsertOp>(
|
||||||
|
op, stridedSliceInnerOp.getResult(), op.dest(),
|
||||||
|
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
||||||
|
/*dropBack=*/rankRest));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
|
||||||
|
/// have the same rank. For each outermost index in the slice:
|
||||||
|
/// begin end stride
|
||||||
|
/// [offset : offset+size*stride : stride]
|
||||||
|
/// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
|
||||||
|
/// 2. InsertStridedSlice (k-1)-D into (n-1)-D
|
||||||
|
/// 3. the destination subvector is inserted back in the proper place
|
||||||
|
/// 3. InsertOp that is the reverse of 1.
|
||||||
|
class VectorInsertStridedSliceOpSameRankRewritePattern
|
||||||
|
: public OpRewritePattern<InsertStridedSliceOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
void initialize() {
|
||||||
|
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
|
||||||
|
// bounded as the rank is strictly decreasing.
|
||||||
|
setHasBoundedRewriteRecursion();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto srcType = op.getSourceVectorType();
|
||||||
|
auto dstType = op.getDestVectorType();
|
||||||
|
|
||||||
|
if (op.offsets().getValue().empty())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||||
|
assert(rankDiff >= 0);
|
||||||
|
if (rankDiff != 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (srcType == dstType) {
|
||||||
|
rewriter.replaceOp(op, op.source());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t offset =
|
||||||
|
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
|
||||||
|
int64_t size = srcType.getShape().front();
|
||||||
|
int64_t stride =
|
||||||
|
op.strides().getValue().front().cast<IntegerAttr>().getInt();
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
Value res = op.dest();
|
||||||
|
// For each slice of the source vector along the most major dimension.
|
||||||
|
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
||||||
|
off += stride, ++idx) {
|
||||||
|
// 1. extract the proper subvector (or element) from source
|
||||||
|
Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
|
||||||
|
if (extractedSource.getType().isa<VectorType>()) {
|
||||||
|
// 2. If we have a vector, extract the proper subvector from destination
|
||||||
|
// Otherwise we are at the element level and no need to recurse.
|
||||||
|
Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
|
||||||
|
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
|
||||||
|
// smaller rank.
|
||||||
|
extractedSource = rewriter.create<InsertStridedSliceOp>(
|
||||||
|
loc, extractedSource, extractedDest,
|
||||||
|
getI64SubArray(op.offsets(), /* dropFront=*/1),
|
||||||
|
getI64SubArray(op.strides(), /* dropFront=*/1));
|
||||||
|
}
|
||||||
|
// 4. Insert the extractedSource into the res vector.
|
||||||
|
res = insertOne(rewriter, loc, extractedSource, res, off);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, res);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Progressive lowering of ExtractStridedSliceOp to either:
|
||||||
|
/// 1. single offset extract as a direct vector::ShuffleOp.
|
||||||
|
/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
|
||||||
|
/// InsertOp/InsertElementOp for the n-D case.
|
||||||
|
class VectorExtractStridedSliceOpRewritePattern
|
||||||
|
: public OpRewritePattern<ExtractStridedSliceOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
void initialize() {
|
||||||
|
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
|
||||||
|
// is bounded as the rank is strictly decreasing.
|
||||||
|
setHasBoundedRewriteRecursion();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto dstType = op.getType();
|
||||||
|
|
||||||
|
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
|
||||||
|
|
||||||
|
int64_t offset =
|
||||||
|
op.offsets().getValue().front().cast<IntegerAttr>().getInt();
|
||||||
|
int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
|
||||||
|
int64_t stride =
|
||||||
|
op.strides().getValue().front().cast<IntegerAttr>().getInt();
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto elemType = dstType.getElementType();
|
||||||
|
assert(elemType.isSignlessIntOrIndexOrFloat());
|
||||||
|
|
||||||
|
// Single offset can be more efficiently shuffled.
|
||||||
|
if (op.offsets().getValue().size() == 1) {
|
||||||
|
SmallVector<int64_t, 4> offsets;
|
||||||
|
offsets.reserve(size);
|
||||||
|
for (int64_t off = offset, e = offset + size * stride; off < e;
|
||||||
|
off += stride)
|
||||||
|
offsets.push_back(off);
|
||||||
|
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
|
||||||
|
op.vector(),
|
||||||
|
rewriter.getI64ArrayAttr(offsets));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract/insert on a lower ranked extract strided slice op.
|
||||||
|
Value zero = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, elemType, rewriter.getZeroAttr(elemType));
|
||||||
|
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
|
||||||
|
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
||||||
|
off += stride, ++idx) {
|
||||||
|
Value one = extractOne(rewriter, loc, op.vector(), off);
|
||||||
|
Value extracted = rewriter.create<ExtractStridedSliceOp>(
|
||||||
|
loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
|
||||||
|
getI64SubArray(op.sizes(), /* dropFront=*/1),
|
||||||
|
getI64SubArray(op.strides(), /* dropFront=*/1));
|
||||||
|
res = insertOne(rewriter, loc, extracted, res, idx);
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, res);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||||
|
void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
|
||||||
|
RewritePatternSet &patterns) {
|
||||||
|
patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern,
|
||||||
|
VectorInsertStridedSliceOpSameRankRewritePattern,
|
||||||
|
VectorExtractStridedSliceOpRewritePattern>(
|
||||||
|
patterns.getContext());
|
||||||
|
}
|
|
@ -2204,20 +2204,6 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
|
|
||||||
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
|
|
||||||
unsigned dropFront = 0,
|
|
||||||
unsigned dropBack = 0) {
|
|
||||||
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
|
|
||||||
auto range = arrayAttr.getAsRange<IntegerAttr>();
|
|
||||||
SmallVector<int64_t, 4> res;
|
|
||||||
res.reserve(arrayAttr.size() - dropFront - dropBack);
|
|
||||||
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
|
|
||||||
it != eit; ++it)
|
|
||||||
res.push_back((*it).getValue().getSExtValue());
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
|
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
|
||||||
// BroadcastOp(ExtractStrideSliceOp).
|
// BroadcastOp(ExtractStrideSliceOp).
|
||||||
class StridedSliceBroadcast final
|
class StridedSliceBroadcast final
|
||||||
|
|
|
@ -1034,10 +1034,11 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
|
/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
|
||||||
/// vectors progressively on the way from targeting llvm.matrix intrinsics.
|
/// vectors progressively.
|
||||||
/// This iterates over the most major dimension of the 2-D vector and performs
|
/// This iterates over the most major dimension of the 2-D vector and performs
|
||||||
/// rewrites into:
|
/// rewrites into:
|
||||||
/// vector.strided_slice from 1-D + vector.insert into 2-D
|
/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
|
||||||
|
/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
|
||||||
class ShapeCastOp2DUpCastRewritePattern
|
class ShapeCastOp2DUpCastRewritePattern
|
||||||
: public OpRewritePattern<vector::ShapeCastOp> {
|
: public OpRewritePattern<vector::ShapeCastOp> {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -362,3 +362,16 @@ bool mlir::checkSameValueWAW(vector::TransferWriteOp write,
|
||||||
priorWrite.getVectorType() == write.getVectorType() &&
|
priorWrite.getVectorType() == write.getVectorType() &&
|
||||||
priorWrite.permutation_map() == write.permutation_map();
|
priorWrite.permutation_map() == write.permutation_map();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr,
|
||||||
|
unsigned dropFront,
|
||||||
|
unsigned dropBack) {
|
||||||
|
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
|
||||||
|
auto range = arrayAttr.getAsRange<IntegerAttr>();
|
||||||
|
SmallVector<int64_t, 4> res;
|
||||||
|
res.reserve(arrayAttr.size() - dropFront - dropBack);
|
||||||
|
for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
|
||||||
|
it != eit; ++it)
|
||||||
|
res.push_back((*it).getValue().getSExtValue());
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue