[mlir] Add "mask" operand to vector.transfer_read/write.

Also factors out out-of-bounds mask generation from vector.transfer_read/write
into a new MaterializeTransferMask pattern.

Differential Revision: https://reviews.llvm.org/D100001

commit 65a3f28939 (parent c0ef93bec8)
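As a quick sketch of the new surface syntax (modeled on the tests added further down; %A, %base, and the constants are placeholder values), the optional mask is an extra trailing operand: it follows the padding value for transfer_read and the indices for transfer_write. Lanes whose mask bit is 0 read the padding value or leave memory untouched.

  %pad = constant 0.0 : f32
  %m   = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
  // Masked read: lanes 0, 1 and 3 of %v get %pad.
  %v = vector.transfer_read %A[%base], %pad, %m : memref<?xf32>, vector<5xf32>
  // Masked write: only lanes 2 and 4 of %v are stored back.
  vector.transfer_write %v, %A[%base], %m : vector<5xf32>, memref<?xf32>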
@@ -68,7 +68,7 @@ void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false, bool enableIndexOptimizations = true);
+    bool reassociateFPReductions = false);

 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(

@@ -88,6 +88,10 @@ void populateVectorSlicesLoweringPatterns(RewritePatternSet &patterns);
 /// `vector.store` and `vector.broadcast`.
 void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);

+/// These patterns materialize masks for various vector ops such as transfers.
+void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
+                                               bool enableIndexOptimizations);
+
 /// An attribute that specifies the combining function for `vector.contract`,
 /// and `vector.reduction`.
 class CombiningKindAttr
@@ -1135,10 +1135,12 @@ def Vector_TransferReadOp :
   Vector_Op<"transfer_read", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
-      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      AttrSizedOperandSegments
     ]>,
     Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map, AnyType:$padding,
+               Optional<VectorOf<[I1]>>:$mask,
                OptionalAttr<BoolArrayAttr>:$in_bounds)>,
     Results<(outs AnyVector:$vector)> {

@@ -1167,13 +1169,19 @@ def Vector_TransferReadOp :
    return type.

    An SSA value `padding` of the same elemental type as the MemRef/Tensor is
-   provided to specify a fallback value in the case of out-of-bounds accesses.
+   provided to specify a fallback value in the case of out-of-bounds accesses
+   and/or masking.
+
+   An optional SSA value `mask` of the same shape as the vector type may be
+   specified to mask out elements. Such elements will be replaced with
+   `padding`. Elements whose corresponding mask element is `0` are masked out.

    An optional boolean array attribute is provided to specify which dimensions
    of the transfer are guaranteed to be within bounds. The absence of this
    `in_bounds` attribute signifies that any dimension of the transfer may be
    out-of-bounds. A `vector.transfer_read` can be lowered to a simple load if
-   all dimensions are specified to be within bounds.
+   all dimensions are specified to be within bounds and no `mask` was
+   specified.

    This operation is called 'read' by opposition to 'load' because the
    super-vector granularity is generally not representable with a single
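A small illustration of the masking semantics just described (a sketch; %src, %i, and the element values are assumed):

  %pad  = constant -42.0 : f32
  %mask = constant dense<[1, 0, 1, 0]> : vector<4xi1>
  // If %src[%i .. %i+3] holds [10., 20., 30., 40.], the result vector is
  // [10., -42., 30., -42.]: lanes 1 and 3 are masked out and take %pad.
  %v = vector.transfer_read %src[%i], %pad, %mask : memref<?xf32>, vector<4xf32>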
@@ -1299,6 +1307,14 @@ def Vector_TransferReadOp :
     //     'getMinorIdentityMap' (resp. zero).
     OpBuilder<(ins "VectorType":$vector, "Value":$source,
       "ValueRange":$indices, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
+    // Builder that does not set mask.
+    OpBuilder<(ins "Type":$vector, "Value":$source,
+      "ValueRange":$indices, "AffineMapAttr":$permutationMap, "Value":$padding,
+      "ArrayAttr":$inBounds)>,
+    // Builder that does not set mask.
+    OpBuilder<(ins "Type":$vector, "Value":$source,
+      "ValueRange":$indices, "AffineMap":$permutationMap, "Value":$padding,
+      "ArrayAttr":$inBounds)>
   ];

   let hasFolder = 1;
@@ -1308,11 +1324,13 @@ def Vector_TransferWriteOp :
   Vector_Op<"transfer_write", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
-      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      AttrSizedOperandSegments
     ]>,
     Arguments<(ins AnyVector:$vector, AnyShaped:$source,
                Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map,
+               Optional<VectorOf<[I1]>>:$mask,
                OptionalAttr<BoolArrayAttr>:$in_bounds)>,
     Results<(outs Optional<AnyRankedTensor>:$result)> {

@@ -1341,11 +1359,16 @@ def Vector_TransferWriteOp :

    The size of the slice is specified by the size of the vector.

+   An optional SSA value `mask` of the same shape as the vector type may be
+   specified to mask out elements. Elements whose corresponding mask element
+   is `0` are masked out.
+
    An optional boolean array attribute is provided to specify which dimensions
    of the transfer are guaranteed to be within bounds. The absence of this
    `in_bounds` attribute signifies that any dimension of the transfer may be
    out-of-bounds. A `vector.transfer_write` can be lowered to a simple store
-   if all dimensions are specified to be within bounds.
+   if all dimensions are specified to be within bounds and no `mask` was
+   specified.

    This operation is called 'write' by opposition to 'store' because the
    super-vector granularity is generally not representable with a single
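And the write-side counterpart (again a sketch with assumed values %val, %dst, %i):

  %mask = constant dense<[1, 1, 0, 1]> : vector<4xi1>
  // Lane 2 is masked out, so the corresponding memory location is left
  // untouched; lanes 0, 1 and 3 of %val are written starting at %dst[%i].
  vector.transfer_write %val, %dst[%i], %mask : vector<4xf32>, memref<?xf32>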
@@ -1391,6 +1414,8 @@ def Vector_TransferWriteOp :
       "AffineMap":$permutationMap)>,
     OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
       "AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>,
+    OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+      "AffineMap":$permutationMap, "Value":$mask, "ArrayAttr":$inBounds)>,
     OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
       "AffineMap":$permutationMap, "ArrayAttr":$inBounds)>,
   ];
@@ -104,66 +104,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
   return res;
 }

-static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
-                                   Location loc, Type targetType, Value value) {
-  if (targetType == value.getType())
-    return value;
-
-  bool targetIsIndex = targetType.isIndex();
-  bool valueIsIndex = value.getType().isIndex();
-  if (targetIsIndex ^ valueIsIndex)
-    return rewriter.create<IndexCastOp>(loc, targetType, value);
-
-  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
-  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
-  assert(targetIntegerType && valueIntegerType &&
-         "unexpected cast between types other than integers and index");
-  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
-
-  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
-    return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
-  return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
-}
-
-// Helper that returns a vector comparison that constructs a mask:
-//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
-//
-// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
-//       much more compact, IR for this operation, but LLVM eventually
-//       generates more elaborate instructions for this intrinsic since it
-//       is very conservative on the boundary conditions.
-static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
-                                   Operation *op, bool enableIndexOptimizations,
-                                   int64_t dim, Value b, Value *off = nullptr) {
-  auto loc = op->getLoc();
-  // If we can assume all indices fit in 32-bit, we perform the vector
-  // comparison in 32-bit to get a higher degree of SIMD parallelism.
-  // Otherwise we perform the vector comparison using 64-bit indices.
-  Value indices;
-  Type idxType;
-  if (enableIndexOptimizations) {
-    indices = rewriter.create<ConstantOp>(
-        loc, rewriter.getI32VectorAttr(
-                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
-    idxType = rewriter.getI32Type();
-  } else {
-    indices = rewriter.create<ConstantOp>(
-        loc, rewriter.getI64VectorAttr(
-                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
-    idxType = rewriter.getI64Type();
-  }
-  // Add in an offset if requested.
-  if (off) {
-    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
-    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
-    indices = rewriter.create<AddIOp>(loc, ov, indices);
-  }
-  // Construct the vector comparison.
-  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
-  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
-  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
-}
-
 // Helper that returns data layout alignment of a memref.
 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                  MemRefType memrefType, unsigned &align) {
@@ -250,7 +190,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
   if (failed(getMemRefAlignment(
           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
-  auto adaptor = TransferWriteOpAdaptor(operands);
+  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                              align);
   return success();

@@ -266,7 +206,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();

-  auto adaptor = TransferWriteOpAdaptor(operands);
+  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
       xferOp, adaptor.vector(), dataPtr, mask,
       rewriter.getI32IntegerAttr(align));

@@ -275,12 +215,12 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,

 static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                   ArrayRef<Value> operands) {
-  return TransferReadOpAdaptor(operands);
+  return TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
 }

 static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                    ArrayRef<Value> operands) {
-  return TransferWriteOpAdaptor(operands);
+  return TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
 }

 namespace {
@@ -618,33 +558,6 @@ private:
   const bool reassociateFPReductions;
 };

-/// Conversion pattern for a vector.create_mask (1-D only).
-class VectorCreateMaskOpConversion
-    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
-public:
-  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
-                                        bool enableIndexOpt)
-      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
-        enableIndexOptimizations(enableIndexOpt) {}
-
-  LogicalResult
-  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto dstType = op.getType();
-    int64_t rank = dstType.getRank();
-    if (rank == 1) {
-      rewriter.replaceOp(
-          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
-                                    dstType.getDimSize(0), operands[0]));
-      return success();
-    }
-    return failure();
-  }
-
-private:
-  const bool enableIndexOptimizations;
-};
-
 class VectorShuffleOpConversion
     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
 public:
@@ -1177,20 +1090,12 @@ public:
   }
 };

-/// Conversion pattern that converts a 1-D vector transfer read/write op in a
-/// sequence of:
-/// 1. Get the source/dst address as an LLVM vector pointer.
-/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-/// 4. Create a mask where offsetVector is compared against memref upper bound.
-/// 5. Rewrite op as a masked read or write.
+/// Conversion pattern that converts a 1-D vector transfer read/write op into a
+/// masked or unmasked read/write.
 template <typename ConcreteOp>
 class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
 public:
-  explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
-                                    bool enableIndexOpt)
-      : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
-        enableIndexOptimizations(enableIndexOpt) {}
+  using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;

   LogicalResult
   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,

@@ -1212,6 +1117,9 @@ public:
     auto strides = computeContiguousStrides(memRefType);
     if (!strides)
       return failure();
+    // Out-of-bounds dims are handled by MaterializeTransferMask.
+    if (xferOp.hasOutOfBoundsDim())
+      return failure();

     auto toLLVMTy = [&](Type t) {
       return this->getTypeConverter()->convertType(t);
@@ -1241,40 +1149,24 @@ public:
 #endif // ifndef NDEBUG
     }

-    // 1. Get the source/dst address as an LLVM vector pointer.
+    // Get the source/dst address as an LLVM vector pointer.
     VectorType vtp = xferOp.getVectorType();
     Value dataPtr = this->getStridedElementPtr(
         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
     Value vectorDataPtr =
         castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));

-    if (xferOp.isDimInBounds(0))
+    // Rewrite as an unmasked read / write.
+    if (!xferOp.mask())
       return replaceTransferOpWithLoadOrStore(rewriter,
                                               *this->getTypeConverter(), loc,
                                               xferOp, operands, vectorDataPtr);

-    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-    // 4. Let dim the memref dimension, compute the vector comparison mask
-    //    (in-bounds mask):
-    //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-    //
-    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
-    //       dimensions here.
-    unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
-    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
-    Value off = xferOp.indices()[lastIndex];
-    Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
-    Value mask = buildVectorComparison(
-        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
-
-    // 5. Rewrite as a masked read / write.
+    // Rewrite as a masked read / write.
     return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
-                                       xferOp, operands, vectorDataPtr, mask);
+                                       xferOp, operands, vectorDataPtr,
+                                       xferOp.mask());
   }
-
-private:
-  const bool enableIndexOptimizations;
 };

 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
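With out-of-bounds handling moved into MaterializeTransferMask, this pattern only has to pick between a plain and a masked memory access. A rough sketch of what an in-bounds masked 1-D read lowers to (mirroring the updated vector-to-llvm tests; the SSA names and the alignment value are illustrative):

  // %m : vector<5xi1>, all dims of the transfer are in bounds.
  %passthru = splat %pad : vector<5xf32>
  %loaded = llvm.intr.masked.load %vecPtr, %m, %passthru {alignment = 4 : i32}
      : (!llvm.ptr<vector<5xf32>>, vector<5xi1>, vector<5xf32>) -> vector<5xf32>
  // Without a mask the same transfer becomes an ordinary llvm.load / llvm.store.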
@@ -1484,17 +1376,13 @@ public:
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions, bool enableIndexOptimizations) {
+    bool reassociateFPReductions) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorFMAOpNDRewritePattern,
                VectorInsertStridedSliceOpDifferentRankRewritePattern,
                VectorInsertStridedSliceOpSameRankRewritePattern,
                VectorExtractStridedSliceOpConversion>(ctx);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
-  patterns.add<VectorCreateMaskOpConversion,
-               VectorTransferConversion<TransferReadOp>,
-               VectorTransferConversion<TransferWriteOp>>(
-      converter, enableIndexOptimizations);
   patterns
       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
            VectorExtractElementOpConversion, VectorExtractOpConversion,

@@ -1508,8 +1396,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
            VectorLoadStoreConversion<vector::MaskedStoreOp,
                                      vector::MaskedStoreOpAdaptor>,
            VectorGatherOpConversion, VectorScatterOpConversion,
-           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
-      converter);
+           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+           VectorTransferConversion<TransferReadOp>,
+           VectorTransferConversion<TransferWriteOp>>(converter);
 }

 void mlir::populateVectorToLLVMMatrixConversionPatterns(
@@ -71,9 +71,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
   // Convert to the LLVM IR dialect.
   LLVMTypeConverter converter(&getContext());
   RewritePatternSet patterns(&getContext());
+  populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
-  populateVectorToLLVMConversionPatterns(
-      converter, patterns, reassociateFPReductions, enableIndexOptimizations);
+  populateVectorToLLVMConversionPatterns(converter, patterns,
+                                         reassociateFPReductions);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);

   // Architecture specific augmentations.
@@ -42,7 +42,7 @@ static LogicalResult replaceTransferOpWithMubuf(
     LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
     Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
     Value &glc, Value &slc) {
-  auto adaptor = TransferWriteOpAdaptor(operands);
+  auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
   rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
                                                    dwordConfig, vindex,
                                                    offsetSizeInBytes, glc, slc);

@@ -62,7 +62,7 @@ public:
   LogicalResult
   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    typename ConcreteOp::Adaptor adaptor(operands);
+    typename ConcreteOp::Adaptor adaptor(operands, xferOp->getAttrDictionary());

     if (xferOp.getVectorType().getRank() > 1 ||
         llvm::size(xferOp.indices()) == 0)
@@ -538,6 +538,8 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
   using namespace mlir::edsc::op;

   TransferReadOp transfer = cast<TransferReadOp>(op);
+  if (transfer.mask())
+    return failure();
   auto memRefType = transfer.getShapedType().dyn_cast<MemRefType>();
   if (!memRefType)
     return failure();

@@ -624,6 +626,8 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
   using namespace edsc::op;

   TransferWriteOp transfer = cast<TransferWriteOp>(op);
+  if (transfer.mask())
+    return failure();
   auto memRefType = transfer.getShapedType().template dyn_cast<MemRefType>();
   if (!memRefType)
     return failure();
@@ -2295,8 +2295,27 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, vectorType, source, indices, permMap, inBounds);
 }

+/// Builder that does not provide a mask.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           Type vectorType, Value source, ValueRange indices,
+                           AffineMap permutationMap, Value padding,
+                           ArrayAttr inBounds) {
+  build(builder, result, vectorType, source, indices, permutationMap, padding,
+        /*mask=*/Value(), inBounds);
+}
+
+/// Builder that does not provide a mask.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           Type vectorType, Value source, ValueRange indices,
+                           AffineMapAttr permutationMap, Value padding,
+                           ArrayAttr inBounds) {
+  build(builder, result, vectorType, source, indices, permutationMap, padding,
+        /*mask=*/Value(), inBounds);
+}
+
 static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
-  SmallVector<StringRef, 2> elidedAttrs;
+  SmallVector<StringRef, 3> elidedAttrs;
+  elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
   if (op.permutation_map().isMinorIdentity())
     elidedAttrs.push_back(op.getPermutationMapAttrName());
   bool elideInBounds = true;
@@ -2316,27 +2335,36 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
 static void print(OpAsmPrinter &p, TransferReadOp op) {
   p << op.getOperationName() << " " << op.source() << "[" << op.indices()
     << "], " << op.padding();
+  if (op.mask())
+    p << ", " << op.mask();
   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
   p << " : " << op.getShapedType() << ", " << op.getVectorType();
 }

 static ParseResult parseTransferReadOp(OpAsmParser &parser,
                                        OperationState &result) {
+  auto &builder = parser.getBuilder();
   llvm::SMLoc typesLoc;
   OpAsmParser::OperandType sourceInfo;
   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
   OpAsmParser::OperandType paddingInfo;
   SmallVector<Type, 2> types;
+  OpAsmParser::OperandType maskInfo;
   // Parsing with support for paddingValue.
   if (parser.parseOperand(sourceInfo) ||
       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseComma() || parser.parseOperand(paddingInfo) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseComma() || parser.parseOperand(paddingInfo))
+    return failure();
+  ParseResult hasMask = parser.parseOptionalComma();
+  if (hasMask.succeeded()) {
+    parser.parseOperand(maskInfo);
+  }
+  if (parser.parseOptionalAttrDict(result.attributes) ||
       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
     return failure();
   if (types.size() != 2)
     return parser.emitError(typesLoc, "requires two types");
-  auto indexType = parser.getBuilder().getIndexType();
+  auto indexType = builder.getIndexType();
   auto shapedType = types[0].dyn_cast<ShapedType>();
   if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
     return parser.emitError(typesLoc, "requires memref or ranked tensor type");
@@ -2349,12 +2377,21 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
   }
-  return failure(
-      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
+  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
       parser.resolveOperands(indexInfo, indexType, result.operands) ||
       parser.resolveOperand(paddingInfo, shapedType.getElementType(),
-                            result.operands) ||
-      parser.addTypeToList(vectorType, result.types));
+                            result.operands))
+    return failure();
+  if (hasMask.succeeded()) {
+    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
+    if (parser.resolveOperand(maskInfo, maskType, result.operands))
+      return failure();
+  }
+  result.addAttribute(
+      TransferReadOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
+                                static_cast<int32_t>(hasMask.succeeded())}));
+  return parser.addTypeToList(vectorType, result.types);
 }

 static LogicalResult verify(TransferReadOp op) {
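The custom assembly this parser accepts places the new optional mask between the padding value and the attribute dictionary; a sketch (with assumed values %src, %i, %j, %pad, and a %mask of type vector<4x8xi1>):

  %v = vector.transfer_read %src[%i, %j], %pad, %mask {in_bounds = [true, false]}
      : memref<?x?xf32>, vector<4x8xf32>

The operand_segment_sizes attribute recorded here is elided again when printing, as printTransferAttrs above shows.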
@@ -2525,7 +2562,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             /*optional*/ ArrayAttr inBounds) {
   Type resultType = source.getType().dyn_cast<RankedTensorType>();
   build(builder, result, resultType, vector, source, indices, permutationMap,
-        inBounds);
+        /*mask=*/Value(), inBounds);
 }

 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
@@ -2534,24 +2571,39 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             /*optional*/ ArrayAttr inBounds) {
   Type resultType = source.getType().dyn_cast<RankedTensorType>();
   build(builder, result, resultType, vector, source, indices, permutationMap,
-        inBounds);
+        /*mask=*/Value(), inBounds);
 }

+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+                            Value vector, Value source, ValueRange indices,
+                            AffineMap permutationMap, /*optional*/ Value mask,
+                            /*optional*/ ArrayAttr inBounds) {
+  Type resultType = source.getType().dyn_cast<RankedTensorType>();
+  build(builder, result, resultType, vector, source, indices, permutationMap,
+        mask, inBounds);
+}
+
 static ParseResult parseTransferWriteOp(OpAsmParser &parser,
                                         OperationState &result) {
+  auto &builder = parser.getBuilder();
   llvm::SMLoc typesLoc;
   OpAsmParser::OperandType vectorInfo, sourceInfo;
   SmallVector<OpAsmParser::OperandType, 8> indexInfo;
   SmallVector<Type, 2> types;
+  OpAsmParser::OperandType maskInfo;
   if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
       parser.parseOperand(sourceInfo) ||
-      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
+    return failure();
+  ParseResult hasMask = parser.parseOptionalComma();
+  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
+    return failure();
+  if (parser.parseOptionalAttrDict(result.attributes) ||
       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
     return failure();
   if (types.size() != 2)
     return parser.emitError(typesLoc, "requires two types");
-  auto indexType = parser.getBuilder().getIndexType();
+  auto indexType = builder.getIndexType();
   VectorType vectorType = types[0].dyn_cast<VectorType>();
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
@@ -2564,17 +2616,28 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
   }
-  return failure(
-      parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
+  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
       parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
-      parser.resolveOperands(indexInfo, indexType, result.operands) ||
-      (shapedType.isa<RankedTensorType>() &&
-       parser.addTypeToList(shapedType, result.types)));
+      parser.resolveOperands(indexInfo, indexType, result.operands))
+    return failure();
+  if (hasMask.succeeded()) {
+    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
+    if (parser.resolveOperand(maskInfo, maskType, result.operands))
+      return failure();
+  }
+  result.addAttribute(
+      TransferWriteOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
+                                static_cast<int32_t>(hasMask.succeeded())}));
+  return failure(shapedType.isa<RankedTensorType>() &&
+                 parser.addTypeToList(shapedType, result.types));
 }

 static void print(OpAsmPrinter &p, TransferWriteOp op) {
   p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "["
     << op.indices() << "]";
+  if (op.mask())
+    p << ", " << op.mask();
   printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
   p << " : " << op.getVectorType() << ", " << op.getShapedType();
 }
@@ -596,6 +596,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
                                   OpBuilder &builder) {
   if (!isIdentitySuffix(readOp.permutation_map()))
     return nullptr;
+  if (readOp.mask())
+    return nullptr;
   auto sourceVectorType = readOp.getVectorType();
   SmallVector<int64_t, 4> strides(targetShape.size(), 1);

@@ -641,6 +643,8 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
   auto writeOp = cast<vector::TransferWriteOp>(op);
   if (!isIdentitySuffix(writeOp.permutation_map()))
     return failure();
+  if (writeOp.mask())
+    return failure();
   VectorType sourceVectorType = writeOp.getVectorType();
   SmallVector<int64_t, 4> strides(targetShape.size(), 1);
   TupleType tupleType = generateExtractSlicesOpResultType(

@@ -722,6 +726,9 @@ public:
     if (ignoreFilter && ignoreFilter(readOp))
       return failure();

+    if (readOp.mask())
+      return failure();
+
     // TODO: Support splitting TransferReadOp with non-identity permutation
     // maps. Repurpose code from MaterializeVectors transformation.
     if (!isIdentitySuffix(readOp.permutation_map()))

@@ -768,6 +775,9 @@ public:
     if (ignoreFilter && ignoreFilter(writeOp))
      return failure();

+    if (writeOp.mask())
+      return failure();
+
     // TODO: Support splitting TransferWriteOp with non-identity permutation
     // maps. Repurpose code from MaterializeVectors transformation.
     if (!isIdentitySuffix(writeOp.permutation_map()))
@@ -2546,6 +2556,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
          "Expected splitFullAndPartialTransferPrecondition to hold");
   auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());

+  if (xferReadOp.mask())
+    return failure();
+
   // TODO: add support for write case.
   if (!xferReadOp)
     return failure();

@@ -2677,6 +2690,8 @@ struct TransferReadExtractPattern
         dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
     if (!extract)
       return failure();
+    if (read.mask())
+      return failure();
     edsc::ScopedContext scope(rewriter, read.getLoc());
     using mlir::edsc::op::operator+;
     using mlir::edsc::op::operator*;

@@ -2712,6 +2727,8 @@ struct TransferWriteInsertPattern
     auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
     if (!insert)
       return failure();
+    if (write.mask())
+      return failure();
     edsc::ScopedContext scope(rewriter, write.getLoc());
     using mlir::edsc::op::operator+;
     using mlir::edsc::op::operator*;

@@ -2742,6 +2759,7 @@ struct TransferWriteInsertPattern
 /// - If the memref's element type is a vector type then it coincides with the
 ///   result type.
 /// - The permutation map doesn't perform permutation (broadcasting is allowed).
+/// - The op has no mask.
 struct TransferReadToVectorLoadLowering
     : public OpRewritePattern<vector::TransferReadOp> {
   TransferReadToVectorLoadLowering(MLIRContext *context)
@@ -2780,7 +2798,8 @@ struct TransferReadToVectorLoadLowering
     //   MaskedLoadOp.
     if (read.hasOutOfBoundsDim())
       return failure();
-
+    if (read.mask())
+      return failure();
     Operation *loadOp;
     if (!broadcastedDims.empty() &&
         unbroadcastedVectorType.getNumElements() == 1) {

@@ -2815,6 +2834,7 @@ struct TransferReadToVectorLoadLowering
 ///   type of the written value.
 /// - The permutation map is the minor identity map (neither permutation nor
 ///   broadcasting is allowed).
+/// - The op has no mask.
 struct TransferWriteToVectorStoreLowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   TransferWriteToVectorStoreLowering(MLIRContext *context)

@@ -2840,6 +2860,8 @@ struct TransferWriteToVectorStoreLowering
     //   MaskedStoreOp.
     if (write.hasOutOfBoundsDim())
       return failure();
+    if (write.mask())
+      return failure();
     rewriter.replaceOpWithNewOp<vector::StoreOp>(
         write, write.vector(), write.source(), write.indices());
     return success();

@@ -2880,6 +2902,8 @@ struct TransferReadPermutationLowering
         map.getPermutationMap(permutation, op.getContext());
     if (permutationMap.isIdentity())
       return failure();
+    if (op.mask())
+      return failure();
     // Calculate the map of the new read by applying the inverse permutation.
     permutationMap = inversePermutation(permutationMap);
     AffineMap newMap = permutationMap.compose(map);

@@ -2914,6 +2938,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {

   LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.mask())
+      return failure();
     AffineMap map = op.permutation_map();
     unsigned numLeadingBroadcast = 0;
     for (auto expr : map.getResults()) {

@@ -3062,6 +3088,9 @@ struct CastAwayTransferReadLeadingOneDim

   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
+    if (read.mask())
+      return failure();
+
     auto shapedType = read.source().getType().cast<ShapedType>();
     if (shapedType.getElementType() != read.getVectorType().getElementType())
       return failure();

@@ -3102,6 +3131,9 @@ struct CastAwayTransferWriteLeadingOneDim

   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
+    if (write.mask())
+      return failure();
+
     auto shapedType = write.source().getType().dyn_cast<ShapedType>();
     if (shapedType.getElementType() != write.getVectorType().getElementType())
       return failure();
@@ -3371,6 +3403,151 @@ struct BubbleUpBitCastForStridedSliceInsert
   }
 };

+static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
+                                   Type targetType, Value value) {
+  if (targetType == value.getType())
+    return value;
+
+  bool targetIsIndex = targetType.isIndex();
+  bool valueIsIndex = value.getType().isIndex();
+  if (targetIsIndex ^ valueIsIndex)
+    return rewriter.create<IndexCastOp>(loc, targetType, value);
+
+  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
+  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
+  assert(targetIntegerType && valueIntegerType &&
+         "unexpected cast between types other than integers and index");
+  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
+
+  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
+    return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
+  return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
+}
+
+// Helper that returns a vector comparison that constructs a mask:
+//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
+//
+// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
+//       much more compact, IR for this operation, but LLVM eventually
+//       generates more elaborate instructions for this intrinsic since it
+//       is very conservative on the boundary conditions.
+static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
+                                   bool enableIndexOptimizations, int64_t dim,
+                                   Value b, Value *off = nullptr) {
+  auto loc = op->getLoc();
+  // If we can assume all indices fit in 32-bit, we perform the vector
+  // comparison in 32-bit to get a higher degree of SIMD parallelism.
+  // Otherwise we perform the vector comparison using 64-bit indices.
+  Value indices;
+  Type idxType;
+  if (enableIndexOptimizations) {
+    indices = rewriter.create<ConstantOp>(
+        loc, rewriter.getI32VectorAttr(
+                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
+    idxType = rewriter.getI32Type();
+  } else {
+    indices = rewriter.create<ConstantOp>(
+        loc, rewriter.getI64VectorAttr(
+                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
+    idxType = rewriter.getI64Type();
+  }
+  // Add in an offset if requested.
+  if (off) {
+    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
+    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
+    indices = rewriter.create<AddIOp>(loc, ov, indices);
+  }
+  // Construct the vector comparison.
+  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
+  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
+}
+
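For a concrete picture of what this helper emits, a sketch of the 32-bit variant for a 4-element transfer (%off and %dim are the offset and bound passed in; the other names are illustrative):

  %cst    = constant dense<[0, 1, 2, 3]> : vector<4xi32>
  %off32  = index_cast %off : index to i32
  %offVec = splat %off32 : vector<4xi32>
  %sum    = addi %offVec, %cst : vector<4xi32>
  %dim32  = index_cast %dim : index to i32
  %bound  = splat %dim32 : vector<4xi32>
  %mask   = cmpi slt, %sum, %bound : vector<4xi32>   // yields vector<4xi1>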
+template <typename ConcreteOp>
+struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
+public:
+  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
+      : mlir::OpRewritePattern<ConcreteOp>(context),
+        enableIndexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(ConcreteOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (!xferOp.hasOutOfBoundsDim())
+      return failure();
+
+    if (xferOp.getVectorType().getRank() > 1 ||
+        llvm::size(xferOp.indices()) == 0)
+      return failure();
+
+    Location loc = xferOp->getLoc();
+    VectorType vtp = xferOp.getVectorType();
+
+    // * Create a vector with linear indices [ 0 .. vector_length - 1 ].
+    // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+    // * Let dim the memref dimension, compute the vector comparison mask
+    //   (in-bounds mask):
+    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+    //
+    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
+    //       dimensions here.
+    unsigned vecWidth = vtp.getNumElements();
+    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
+    Value off = xferOp.indices()[lastIndex];
+    Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
+    Value mask = buildVectorComparison(
+        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
+
+    if (xferOp.mask()) {
+      // Intersect the in-bounds with the mask specified as an op parameter.
+      mask = rewriter.create<AndOp>(loc, mask, xferOp.mask());
+    }
+
+    rewriter.updateRootInPlace(xferOp, [&]() {
+      xferOp.maskMutable().assign(mask);
+      xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
+    });
+
+    return success();
+  }
+
+private:
+  const bool enableIndexOptimizations;
+};
+
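In IR terms, the effect of this pattern is roughly the following (a sketch; %A, %base, %pad are assumed, and %inbounds_mask stands for the comparison built by buildVectorComparison):

  // Before: the single dimension may be out of bounds, %m is the user mask.
  %v = vector.transfer_read %A[%base], %pad, %m : memref<?xf32>, vector<4xf32>

  // After: the transfer is marked in-bounds and carries the intersected mask.
  %new_mask = and %inbounds_mask, %m : vector<4xi1>
  %v = vector.transfer_read %A[%base], %pad, %new_mask {in_bounds = [true]}
      : memref<?xf32>, vector<4xf32>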
+/// Conversion pattern for a vector.create_mask (1-D only).
+class VectorCreateMaskOpConversion
+    : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
+                                        bool enableIndexOpt)
+      : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
+        enableIndexOptimizations(enableIndexOpt) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstType = op.getType();
+    int64_t rank = dstType.getRank();
+    if (rank == 1) {
+      rewriter.replaceOp(
+          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
+                                    dstType.getDimSize(0), op.getOperand(0)));
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  const bool enableIndexOptimizations;
+};
+
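The same comparison helper drives the 1-D create_mask lowering; with 32-bit index optimization enabled, the vector-mask-to-llvm test below checks IR of roughly this shape:

  // Before:
  %0 = vector.create_mask %arg0 : vector<11xi1>
  // After:
  %cst = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>
  %b   = index_cast %arg0 : index to i32
  %bs  = splat %b : vector<11xi32>
  %0   = cmpi slt, %cst, %bs : vector<11xi32>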
+void mlir::vector::populateVectorMaskMaterializationPatterns(
+    RewritePatternSet &patterns, bool enableIndexOptimizations) {
+  patterns.add<VectorCreateMaskOpConversion,
+               MaterializeTransferMask<vector::TransferReadOp>,
+               MaterializeTransferMask<vector::TransferWriteOp>>(
+      patterns.getContext(), enableIndexOptimizations);
+}
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -3,20 +3,19 @@

 // CMP32-LABEL: @genbool_var_1d(
 // CMP32-SAME: %[[ARG:.*]]: index)
-// CMP32: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : index to i64
 // CMP32: %[[T0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>
-// CMP32: %[[T1:.*]] = trunci %[[A]] : i64 to i32
+// CMP32: %[[T1:.*]] = index_cast %[[ARG]] : index to i32
 // CMP32: %[[T2:.*]] = splat %[[T1]] : vector<11xi32>
 // CMP32: %[[T3:.*]] = cmpi slt, %[[T0]], %[[T2]] : vector<11xi32>
 // CMP32: return %[[T3]] : vector<11xi1>

 // CMP64-LABEL: @genbool_var_1d(
 // CMP64-SAME: %[[ARG:.*]]: index)
-// CMP64: %[[A:.*]] = llvm.mlir.cast %[[ARG]] : index to i64
 // CMP64: %[[T0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>
-// CMP64: %[[T1:.*]] = splat %[[A]] : vector<11xi64>
-// CMP64: %[[T2:.*]] = cmpi slt, %[[T0]], %[[T1]] : vector<11xi64>
-// CMP64: return %[[T2]] : vector<11xi1>
+// CMP64: %[[T1:.*]] = index_cast %[[ARG]] : index to i64
+// CMP64: %[[T2:.*]] = splat %[[T1]] : vector<11xi64>
+// CMP64: %[[T3:.*]] = cmpi slt, %[[T0]], %[[T2]] : vector<11xi64>
+// CMP64: return %[[T3]] : vector<11xi1>

 func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
   %0 = vector.create_mask %arg0 : vector<11xi1>
@@ -1049,31 +1049,31 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 // CHECK-LABEL: func @transfer_read_1d
 // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
 // CHECK: %[[c7:.*]] = constant 7.0
-//
-// 1. Bitcast to vector form.
-// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
-// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
-// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
 //
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 // CHECK: %[[linearIndex:.*]] = constant dense
 // CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
 // CHECK-SAME: vector<17xi32>
 //
-// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 // CHECK: %[[otrunc:.*]] = index_cast %[[BASE]] : index to i32
 // CHECK: %[[offsetVec:.*]] = splat %[[otrunc]] : vector<17xi32>
 // CHECK: %[[offsetVec2:.*]] = addi %[[offsetVec]], %[[linearIndex]] : vector<17xi32>
 //
-// 4. Let dim the memref dimension, compute the vector comparison mask:
+// 3. Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
 // CHECK: %[[dtrunc:.*]] = index_cast %[[DIM]] : index to i32
 // CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32>
 // CHECK: %[[mask:.*]] = cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
 //
+// 4. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
+//
 // 5. Rewrite as a masked read.
 // CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
 // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
@@ -1081,26 +1081,26 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 // CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>

 //
-// 1. Bitcast to vector form.
-// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
-// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
-// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
-//
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 // CHECK: %[[linearIndex_b:.*]] = constant dense
 // CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
 // CHECK-SAME: vector<17xi32>
 //
-// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 // CHECK: splat %{{.*}} : vector<17xi32>
 // CHECK: addi
 //
-// 4. Let dim the memref dimension, compute the vector comparison mask:
+// 3. Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
 // CHECK: splat %{{.*}} : vector<17xi32>
 // CHECK: %[[mask_b:.*]] = cmpi slt, {{.*}} : vector<17xi32>
 //
+// 4. Bitcast to vector form.
+// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
+// CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
+//
 // 5. Rewrite as a masked write.
 // CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
 // CHECK-SAME: {alignment = 4 : i32} :
@@ -1182,6 +1182,21 @@ func @transfer_read_1d_inbounds(%A : memref<?xf32>, %base: index) -> vector<17xf

 // -----

+// CHECK-LABEL: func @transfer_read_1d_mask
+// CHECK: %[[mask1:.*]] = constant dense<[false, false, true, false, true]>
+// CHECK: %[[cmpi:.*]] = cmpi slt
+// CHECK: %[[mask2:.*]] = and %[[cmpi]], %[[mask1]]
+// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
+// CHECK: return %[[r]]
+func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32> {
+  %m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
+  %f7 = constant 7.0: f32
+  %f = vector.transfer_read %A[%base], %f7, %m : memref<?xf32>, vector<5xf32>
+  return %f: vector<5xf32>
+}
+
+// -----
+
 func @transfer_read_1d_cast(%A : memref<?xi32>, %base: index) -> vector<12xi8> {
   %c0 = constant 0: i32
   %v = vector.transfer_read %A[%base], %c0 {in_bounds = [true]} :
@@ -11,6 +11,7 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   %c0 = constant 0 : i32
   %vf0 = splat %f0 : vector<4x3xf32>
   %v0 = splat %c0 : vector<4x3xi32>
+  %m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>

   //
   // CHECK: vector.transfer_read

@@ -27,7 +28,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {in_bounds = [false, true]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
   // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
   %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>

+  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<5xf32>
+  %7 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref<?x?xf32>, vector<5xf32>

   // CHECK: vector.transfer_write
   vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>

@@ -39,7 +41,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
   vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
   // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
   vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>

+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : vector<5xf32>, memref<?x?xf32>
+  vector.transfer_write %7, %arg0[%c3, %c3], %m : vector<5xf32>, memref<?x?xf32>
   return
 }

@@ -12,6 +12,14 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) {
   return
 }

+func @transfer_read_mask_1d(%A : memref<?xf32>, %base: index) {
+  %fm42 = constant -42.0: f32
+  %m = constant dense<[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]> : vector<13xi1>
+  %f = vector.transfer_read %A[%base], %fm42, %m : memref<?xf32>, vector<13xf32>
+  vector.print %f: vector<13xf32>
+  return
+}
+
 func @transfer_read_inbounds_4(%A : memref<?xf32>, %base: index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base], %fm42

@@ -21,6 +29,15 @@ func @transfer_read_inbounds_4(%A : memref<?xf32>, %base: index) {
   return
 }

+func @transfer_read_mask_inbounds_4(%A : memref<?xf32>, %base: index) {
+  %fm42 = constant -42.0: f32
+  %m = constant dense<[0, 1, 0, 1]> : vector<4xi1>
+  %f = vector.transfer_read %A[%base], %fm42, %m {in_bounds = [true]}
+      : memref<?xf32>, vector<4xf32>
+  vector.print %f: vector<4xf32>
+  return
+}
+
 func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
   %f0 = constant 0.0 : f32
   %vf0 = splat %f0 : vector<4xf32>

@@ -47,6 +64,8 @@ func @entry() {
   // Read shifted by 2 and pad with -42:
   //   ( 2, 3, 4, -42, ..., -42)
   call @transfer_read_1d(%A, %c2) : (memref<?xf32>, index) -> ()
+  // Read with mask and out-of-bounds access.
+  call @transfer_read_mask_1d(%A, %c2) : (memref<?xf32>, index) -> ()
   // Write into memory shifted by 3
   // memory contains [[ 0, 1, 2, 0, 0, xxx garbage xxx ]]
   call @transfer_write_1d(%A, %c3) : (memref<?xf32>, index) -> ()

@@ -56,9 +75,13 @@ func @entry() {
   // Read in-bounds 4 @ 1, guaranteed to not overflow.
   // Exercises proper alignment.
   call @transfer_read_inbounds_4(%A, %c1) : (memref<?xf32>, index) -> ()
+  // Read in-bounds with mask.
+  call @transfer_read_mask_inbounds_4(%A, %c1) : (memref<?xf32>, index) -> ()
   return
 }

 // CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
+// CHECK: ( -42, -42, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
 // CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 )
 // CHECK: ( 1, 2, 0, 0 )
+// CHECK: ( -42, 2, -42, 0 )