forked from OSchip/llvm-project
[mlir][Vector] Introduce 'vector.load' and 'vector.store' ops
This patch adds the 'vector.load' and 'vector.store' ops to the Vector dialect [1]. These operations model *contiguous* vector loads and stores from/to memory. Their semantics are similar to the 'affine.vector_load' and 'affine.vector_store' counterparts, but without the affine constraints. The most relevant feature is that these new vector operations may perform a vector load/store on memrefs with a non-vector element type, unlike 'std.load' and 'std.store' ops. This opens the representation to model more generic vector load/store scenarios: unaligned vector loads/stores, performing scalar and vector memory accesses on the same memref, decoupling memory allocation constraints from memory accesses, etc. [1]. These operations will also facilitate the progressive lowering of both Affine vector loads/stores and Vector transfer reads/writes that read/write contiguous slices from/to memory.

In particular, this patch adds the 'vector.load' and 'vector.store' ops to the Vector dialect, implements their lowering to the LLVM dialect, and changes the lowering of 'affine.vector_load' and 'affine.vector_store' ops to target the new vector ops. The lowering of Vector transfer reads/writes will be implemented in the future, probably as an independent pass.

The API of 'vector.maskedload' and 'vector.maskedstore' has also been changed slightly to align it with the transfer read/write ops and the new vector ops. This will improve reusability among all these operations. For example, the lowering of 'vector.load', 'vector.store', 'vector.maskedload' and 'vector.maskedstore' to the LLVM dialect is implemented with a single template conversion pattern.

[1] https://llvm.discourse.group/t/memref-type-and-data-layout/

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96185
parent 98754e2909
commit ee66e43a96
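As a rough sketch of the lowering path this enables (illustrative IR only, condensed from the tests in this patch; the names %buf, %idx and the index arithmetic are schematic, and the LLVM form assumes the usual -lower-affine and -convert-vector-to-llvm pipelines):

```mlir
// Affine level: a contiguous 1-D vector load expressed through an affine map.
%v = affine.vector_load %buf[%i0 + symbol(%s0) + 7] : memref<100xf32>, vector<8xf32>

// After -lower-affine: the map is expanded into explicit index arithmetic and
// the access becomes a plain contiguous load.
//   %idx = addi ...                 // expanded affine map
//   %v   = vector.load %buf[%idx] : memref<100xf32>, vector<8xf32>

// After -convert-vector-to-llvm (1-D case): compute the element pointer,
// bitcast it to a vector pointer, and load with the memref's alignment.
//   %gep  = llvm.getelementptr %base[%idx] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
//   %vptr = llvm.bitcast %gep : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
//   %v    = llvm.load %vptr {alignment = 4 : i64} : !llvm.ptr<vector<8xf32>>
```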
@@ -1320,6 +1320,156 @@ def Vector_TransferWriteOp :
   let hasFolder = 1;
 }
 
+def Vector_LoadOp : Vector_Op<"load"> {
+  let summary = "reads an n-D slice of memory into an n-D vector";
+  let description = [{
+    The 'vector.load' operation reads an n-D slice of memory into an n-D
+    vector. It takes a 'base' memref, an index for each memref dimension and a
+    result vector type as arguments. It returns a value of the result vector
+    type. The 'base' memref and indices determine the start memory address from
+    which to read. Each index provides an offset for each memref dimension
+    based on the element type of the memref. The shape of the result vector
+    type determines the shape of the slice read from the start memory address.
+    The elements along each dimension of the slice are strided by the memref
+    strides. Only memrefs with default strides are allowed. These constraints
+    guarantee that elements read along the first dimension of the slice are
+    contiguous in memory.
+
+    The memref element type can be a scalar or a vector type. If the memref
+    element type is a scalar, it should match the element type of the result
+    vector. If the memref element type is a vector, it should match the result
+    vector type.
+
+    Example 1: 1-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 1-D vector load on a vector memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+    ```
+
+    Example 3: 2-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+    ```
+
+    Example 4: 2-D vector load on a vector memref.
+    ```mlir
+    %result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+    ```
+
+    Representation-wise, the 'vector.load' operation permits out-of-bounds
+    reads. Support and implementation of out-of-bounds vector loads are
+    target-specific. No assumptions should be made on the value of elements
+    loaded out of bounds. Not all targets may support out-of-bounds vector
+    loads.
+
+    Example 5: Potential out-of-bounds vector load.
+    ```mlir
+    %result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
+    ```
+
+    Example 6: Explicit out-of-bounds vector load.
+    ```mlir
+    %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "the reference to load from",
+                            [MemRead]>:$base,
+                       Variadic<Index>:$indices);
+  let results = (outs AnyVector:$result);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+
+    VectorType getVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+
+  let assemblyFormat =
+      "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+}
+
+def Vector_StoreOp : Vector_Op<"store"> {
+  let summary = "writes an n-D vector to an n-D slice of memory";
+  let description = [{
+    The 'vector.store' operation writes an n-D vector to an n-D slice of memory.
+    It takes the vector value to be stored, a 'base' memref and an index for
+    each memref dimension. The 'base' memref and indices determine the start
+    memory address from which to write. Each index provides an offset for each
+    memref dimension based on the element type of the memref. The shape of the
+    vector value to store determines the shape of the slice written from the
+    start memory address. The elements along each dimension of the slice are
+    strided by the memref strides. Only memrefs with default strides are
+    allowed. These constraints guarantee that elements written along the first
+    dimension of the slice are contiguous in memory.
+
+    The memref element type can be a scalar or a vector type. If the memref
+    element type is a scalar, it should match the element type of the value
+    to store. If the memref element type is a vector, it should match the type
+    of the value to store.
+
+    Example 1: 1-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+    ```
+
+    Example 2: 1-D vector store on a vector memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+    ```
+
+    Example 3: 2-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+    ```
+
+    Example 4: 2-D vector store on a vector memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+    ```
+
+    Representation-wise, the 'vector.store' operation permits out-of-bounds
+    writes. Support and implementation of out-of-bounds vector stores are
+    target-specific. No assumptions should be made on the memory written out of
+    bounds. Not all targets may support out-of-bounds vector stores.
+
+    Example 5: Potential out-of-bounds vector store.
+    ```mlir
+    vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
+    ```
+
+    Example 6: Explicit out-of-bounds vector store.
+    ```mlir
+    vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyVector:$valueToStore,
+                       Arg<AnyMemRef, "the reference to store to",
+                           [MemWrite]>:$base,
+                       Variadic<Index>:$indices);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+
+    VectorType getVectorType() {
+      return valueToStore().getType().cast<VectorType>();
+    }
+  }];
+
+  let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
+                       "`:` type($base) `,` type($valueToStore)";
+}
+
 def Vector_MaskedLoadOp :
   Vector_Op<"maskedload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,

@@ -1363,7 +1513,7 @@ def Vector_MaskedLoadOp :
     VectorType getPassThruVectorType() {
       return pass_thru().getType().cast<VectorType>();
     }
-    VectorType getResultVectorType() {
+    VectorType getVectorType() {
       return result().getType().cast<VectorType>();
     }
   }];

@@ -1377,7 +1527,7 @@ def Vector_MaskedStoreOp :
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
                    VectorOfRankAndType<[1], [I1]>:$mask,
-                   VectorOfRank<[1]>:$value)> {
+                   VectorOfRank<[1]>:$valueToStore)> {
 
   let summary = "stores elements from a vector into memory as defined by a mask vector";
 

@@ -1411,12 +1561,13 @@ def Vector_MaskedStoreOp :
     VectorType getMaskVectorType() {
      return mask().getType().cast<VectorType>();
    }
-    VectorType getValueVectorType() {
-      return value().getType().cast<VectorType>();
+    VectorType getVectorType() {
+      return valueToStore().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
-                       "type($base) `,` type($mask) `,` type($value)";
+  let assemblyFormat =
+      "$base `[` $indices `]` `,` $mask `,` $valueToStore "
+      "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
 }
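For reference, the textual forms of the four memory operations whose definitions are touched above line up as follows (a minimal sketch; %base, %mask, %pass_thru and the chosen types are illustrative and mirror the tests added later in this patch):

```mlir
%c0 = constant 0 : index
%v = vector.load %base[%c0] : memref<?xf32>, vector<16xf32>
vector.store %v, %base[%c0] : memref<?xf32>, vector<16xf32>
%m = vector.maskedload %base[%c0], %mask, %pass_thru
       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
vector.maskedstore %base[%c0], %mask, %m
       : memref<?xf32>, vector<16xi1>, vector<16xf32>
```

Note that renaming the masked ops' operand to $valueToStore does not change the printed syntax; it only unifies the C++ accessors (valueToStore(), getVectorType()) across the four ops, which is what allows the single conversion pattern introduced later in the patch.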
@@ -578,8 +578,9 @@ public:
     if (!resultOperands)
       return failure();
 
-    // Build std.load memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
+    // Build vector.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, op.getMemRef(),
+                                              *resultOperands);
     return success();
   }
 };

@@ -625,8 +626,8 @@ public:
       return failure();
 
     // Build std.store valueToStore, memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
-                                         op.getMemRef(), *maybeExpandedMap);
+    rewriter.replaceOpWithNewOp<mlir::StoreOp>(
+        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
     return success();
   }
 };

@@ -695,8 +696,8 @@ public:
 };
 
 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
-/// and feed the results to a newly created 'vector.transfer_read' operation
-/// (which replaces the original 'affine.vector_load').
+/// and feed the results to a newly created 'vector.load' operation (which
+/// replaces the original 'affine.vector_load').
 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
 public:
   using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;

@@ -710,16 +711,16 @@ public:
     if (!resultOperands)
       return failure();
 
-    // Build vector.transfer_read memref[expandedMap.results].
-    rewriter.replaceOpWithNewOp<TransferReadOp>(
+    // Build vector.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<vector::LoadOp>(
         op, op.getVectorType(), op.getMemRef(), *resultOperands);
     return success();
   }
 };
 
 /// Apply the affine map from an 'affine.vector_store' operation to its
-/// operands, and feed the results to a newly created 'vector.transfer_write'
-/// operation (which replaces the original 'affine.vector_store').
+/// operands, and feed the results to a newly created 'vector.store' operation
+/// (which replaces the original 'affine.vector_store').
 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
 public:
   using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;

@@ -733,7 +734,7 @@ public:
     if (!maybeExpandedMap)
       return failure();
 
-    rewriter.replaceOpWithNewOp<TransferWriteOp>(
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
     return success();
   }
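A condensed before/after view of what the rewritten AffineVectorStoreLowering now produces (schematic, based on the lower-affine test updated below; the loop structure and the exact expanded-map arithmetic are elided):

```mlir
// Before -lower-affine:
affine.for %i0 = 0 to 16 {
  affine.vector_store %val, %buf[%i0 - symbol(%s0) + 7] : memref<100xf32>, vector<4xf32>
}

// After -lower-affine: the affine map becomes muli/addi arithmetic and the
// access is emitted as a contiguous vector.store instead of a
// vector.transfer_write.
//   %cm1 = constant -1 : index
//   %a   = muli %s0, %cm1 : index
//   %b   = addi %i0, %a : index
//   %c7  = constant 7 : index
//   %c   = addi %b, %c7 : index
//   vector.store %val, %buf[%c] : memref<100xf32>, vector<4xf32>
```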
@@ -357,64 +357,72 @@ public:
   }
 };
 
-/// Conversion pattern for a vector.maskedload.
-class VectorMaskedLoadOpConversion
-    : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
+/// Overloaded utility that replaces a vector.load, vector.store,
+/// vector.maskedload and vector.maskedstore with their respective LLVM
+/// counterparts.
+static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
+                                 vector::LoadOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
+}
+
+static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
+                                 vector::MaskedLoadOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
+}
+
+static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
+                                 vector::StoreOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
+                                             ptr, align);
+}
+
+static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
+                                 vector::MaskedStoreOpAdaptor adaptor,
+                                 VectorType vectorTy, Value ptr, unsigned align,
+                                 ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
+}
+
+/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
+/// vector.maskedstore.
+template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
+class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 public:
-  using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
+  matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = load->getLoc();
-    auto adaptor = vector::MaskedLoadOpAdaptor(operands);
-    MemRefType memRefType = load.getMemRefType();
+    // Only 1-D vectors can be lowered to LLVM.
+    VectorType vectorTy = loadOrStoreOp.getVectorType();
+    if (vectorTy.getRank() > 1)
+      return failure();
+
+    auto loc = loadOrStoreOp->getLoc();
+    auto adaptor = LoadOrStoreOpAdaptor(operands);
+    MemRefType memRefTy = loadOrStoreOp.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
+    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
       return failure();
 
     // Resolve address.
-    auto vtype = typeConverter->convertType(load.getResultVectorType());
-    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
+                     .template cast<VectorType>();
+    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                                adaptor.indices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
+    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
 
-    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
-        rewriter.getI32IntegerAttr(align));
-    return success();
-  }
-};
-
-/// Conversion pattern for a vector.maskedstore.
-class VectorMaskedStoreOpConversion
-    : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = store->getLoc();
-    auto adaptor = vector::MaskedStoreOpAdaptor(operands);
-    MemRefType memRefType = store.getMemRefType();
-
-    // Resolve alignment.
-    unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
-      return failure();
-
-    // Resolve address.
-    auto vtype = typeConverter->convertType(store.getValueVectorType());
-    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
-                                               adaptor.indices(), rewriter);
-    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
-
-    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-        store, adaptor.value(), ptr, adaptor.mask(),
-        rewriter.getI32IntegerAttr(align));
+    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
     return success();
   }
 };

@@ -1511,8 +1519,14 @@ void mlir::populateVectorToLLVMConversionPatterns(
                   VectorInsertOpConversion,
                   VectorPrintOpConversion,
                   VectorTypeCastOpConversion,
-                  VectorMaskedLoadOpConversion,
-                  VectorMaskedStoreOpConversion,
+                  VectorLoadStoreConversion<vector::LoadOp,
+                                            vector::LoadOpAdaptor>,
+                  VectorLoadStoreConversion<vector::MaskedLoadOp,
+                                            vector::MaskedLoadOpAdaptor>,
+                  VectorLoadStoreConversion<vector::StoreOp,
+                                            vector::StoreOpAdaptor>,
+                  VectorLoadStoreConversion<vector::MaskedStoreOp,
+                                            vector::MaskedStoreOpAdaptor>,
                   VectorGatherOpConversion,
                   VectorScatterOpConversion,
                   VectorExpandLoadOpConversion,
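For the store side, a 1-D vector.store goes through the same address resolution: a strided GEP to the first accessed element, a bitcast to a vector pointer, and a plain llvm.store with the memref's alignment. The sketch below mirrors the CHECK lines of the conversion test later in this patch; the SSA names and the i64 index values are illustrative:

```mlir
// vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
// lowers roughly to:
%c100 = llvm.mlir.constant(100 : index) : i64
%mul  = llvm.mul %i, %c100 : i64
%add  = llvm.add %mul, %j : i64
%gep  = llvm.getelementptr %base[%add] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
%vptr = llvm.bitcast %gep : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
llvm.store %val, %vptr {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
```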
@@ -2373,6 +2373,67 @@ void TransferWriteOp::getEffects(
                        SideEffects::DefaultResource::get());
 }
 
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+                                                 MemRefType memRefTy) {
+  auto affineMaps = memRefTy.getAffineMaps();
+  if (!affineMaps.empty())
+    return op->emitOpError("base memref should have a default identity layout");
+  return success();
+}
+
+static LogicalResult verify(vector::LoadOp op) {
+  VectorType resVecTy = op.getVectorType();
+  MemRefType memRefTy = op.getMemRefType();
+
+  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
+    return failure();
+
+  // Checks for vector memrefs.
+  Type memElemTy = memRefTy.getElementType();
+  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
+    if (memVecTy != resVecTy)
+      return op.emitOpError("base memref and result vector types should match");
+    memElemTy = memVecTy.getElementType();
+  }
+
+  if (resVecTy.getElementType() != memElemTy)
+    return op.emitOpError("base and result element types should match");
+  if (llvm::size(op.indices()) != memRefTy.getRank())
+    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(vector::StoreOp op) {
+  VectorType valueVecTy = op.getVectorType();
+  MemRefType memRefTy = op.getMemRefType();
+
+  if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
+    return failure();
+
+  // Checks for vector memrefs.
+  Type memElemTy = memRefTy.getElementType();
+  if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
+    if (memVecTy != valueVecTy)
+      return op.emitOpError(
+          "base memref and valueToStore vector types should match");
+    memElemTy = memVecTy.getElementType();
+  }
+
+  if (valueVecTy.getElementType() != memElemTy)
+    return op.emitOpError("base and valueToStore element type should match");
+  if (llvm::size(op.indices()) != memRefTy.getRank())
+    return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//

@@ -2380,7 +2441,7 @@ void TransferWriteOp::getEffects(
 static LogicalResult verify(MaskedLoadOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
-  VectorType resVType = op.getResultVectorType();
+  VectorType resVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (resVType.getElementType() != memType.getElementType())

@@ -2427,15 +2488,15 @@ void MaskedLoadOp::getCanonicalizationPatterns(
 
 static LogicalResult verify(MaskedStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
-  VectorType valueVType = op.getValueVectorType();
+  VectorType valueVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (valueVType.getElementType() != memType.getElementType())
-    return op.emitOpError("base and value element type should match");
+    return op.emitOpError("base and valueToStore element type should match");
   if (llvm::size(op.indices()) != memType.getRank())
     return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
-    return op.emitOpError("expected value dim to match mask dim");
+    return op.emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }

@@ -2448,7 +2509,7 @@ public:
     switch (get1DMaskFormat(store.mask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-          store, store.value(), store.base(), store.indices(), false);
+          store, store.valueToStore(), store.base(), store.indices(), false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.eraseOp(store);
@@ -1,5 +1,6 @@
 // RUN: mlir-opt -lower-affine --split-input-file %s | FileCheck %s
 
+
 // CHECK-LABEL: func @affine_vector_load
 func @affine_vector_load(%arg0 : index) {
   %0 = alloc() : memref<100xf32>

@@ -10,8 +11,7 @@ func @affine_vector_load(%arg0 : index) {
 // CHECK:       %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
 // CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
 // CHECK-NEXT:  %[[b:.*]] = addi %[[a]], %[[c7]] : index
-// CHECK-NEXT:  %[[pad:.*]] = constant 0.0
-// CHECK-NEXT:  vector.transfer_read %[[buf]][%[[b]]], %[[pad]] : memref<100xf32>, vector<8xf32>
+// CHECK-NEXT:  vector.load %[[buf]][%[[b]]] : memref<100xf32>, vector<8xf32>
   return
 }

@@ -31,44 +31,7 @@ func @affine_vector_store(%arg0 : index) {
 // CHECK-NEXT:  %[[b:.*]] = addi %{{.*}}, %[[a]] : index
 // CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
 // CHECK-NEXT:  %[[c:.*]] = addi %[[b]], %[[c7]] : index
-// CHECK-NEXT:  vector.transfer_write %[[val]], %[[buf]][%[[c]]] : vector<4xf32>, memref<100xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @affine_vector_load
-func @affine_vector_load(%arg0 : index) {
-  %0 = alloc() : memref<100xf32>
-  affine.for %i0 = 0 to 16 {
-    %1 = affine.vector_load %0[%i0 + symbol(%arg0) + 7] : memref<100xf32>, vector<8xf32>
-  }
-// CHECK:       %[[buf:.*]] = alloc
-// CHECK:       %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
-// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
-// CHECK-NEXT:  %[[b:.*]] = addi %[[a]], %[[c7]] : index
-// CHECK-NEXT:  %[[pad:.*]] = constant 0.0
-// CHECK-NEXT:  vector.transfer_read %[[buf]][%[[b]]], %[[pad]] : memref<100xf32>, vector<8xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @affine_vector_store
-func @affine_vector_store(%arg0 : index) {
-  %0 = alloc() : memref<100xf32>
-  %1 = constant dense<11.0> : vector<4xf32>
-  affine.for %i0 = 0 to 16 {
-    affine.vector_store %1, %0[%i0 - symbol(%arg0) + 7] : memref<100xf32>, vector<4xf32>
-  }
-// CHECK:       %[[buf:.*]] = alloc
-// CHECK:       %[[val:.*]] = constant dense
-// CHECK:       %[[c_1:.*]] = constant -1 : index
-// CHECK-NEXT:  %[[a:.*]] = muli %arg0, %[[c_1]] : index
-// CHECK-NEXT:  %[[b:.*]] = addi %{{.*}}, %[[a]] : index
-// CHECK-NEXT:  %[[c7:.*]] = constant 7 : index
-// CHECK-NEXT:  %[[c:.*]] = addi %[[b]], %[[c7]] : index
-// CHECK-NEXT:  vector.transfer_write %[[val]], %[[buf]][%[[c]]] : vector<4xf32>, memref<100xf32>
+// CHECK-NEXT:  vector.store %[[val]], %[[buf]][%[[c]]] : memref<100xf32>, vector<4xf32>
   return
 }

@@ -83,8 +46,7 @@ func @vector_load_2d() {
 // CHECK:       %[[buf:.*]] = alloc
 // CHECK:       scf.for %[[i0:.*]] =
 // CHECK:         scf.for %[[i1:.*]] =
-// CHECK-NEXT:      %[[pad:.*]] = constant 0.0
-// CHECK-NEXT:      vector.transfer_read %[[buf]][%[[i0]], %[[i1]]], %[[pad]] : memref<100x100xf32>, vector<2x8xf32>
+// CHECK-NEXT:      vector.load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
     }
   }
   return

@@ -103,9 +65,8 @@ func @vector_store_2d() {
 // CHECK:       %[[val:.*]] = constant dense
 // CHECK:       scf.for %[[i0:.*]] =
 // CHECK:         scf.for %[[i1:.*]] =
-// CHECK-NEXT:      vector.transfer_write %[[val]], %[[buf]][%[[i0]], %[[i1]]] : vector<2x8xf32>, memref<100x100xf32>
+// CHECK-NEXT:      vector.store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
     }
   }
   return
 }
@@ -23,6 +23,7 @@ func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
 
 // -----
 
+
 func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
   %0 = vector.broadcast %arg0 : f32 to vector<2xf32>
   return %0 : vector<2xf32>

@@ -1242,6 +1243,33 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 
 // -----
 
+func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @vector_load_op
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
+// CHECK: llvm.load %[[bcast]] {alignment = 4 : i64} : !llvm.ptr<vector<8xf32>>
+
+func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
+  %val = constant dense<11.0> : vector<4xf32>
+  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_op
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
+// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
+
 func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
@@ -1198,6 +1198,38 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
 
 // -----
 
+func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
+                               %i : index, %j : index, %value : vector<8xf32>) {
+  // expected-error@+1 {{'vector.store' op base memref should have a default identity layout}}
+  vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
+                                         vector<8xf32>
+}
+
+// -----
+
+func @vector_memref_mismatch(%memref : memref<200x100xvector<4xf32>>, %i : index,
+                             %j : index, %value : vector<8xf32>) {
+  // expected-error@+1 {{'vector.store' op base memref and valueToStore vector types should match}}
+  vector.store %value, %memref[%i, %j] : memref<200x100xvector<4xf32>>, vector<8xf32>
+}
+
+// -----
+
+func @store_base_type_mismatch(%base : memref<?xf64>, %value : vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{'vector.store' op base and valueToStore element type should match}}
+  vector.store %value, %base[%c0] : memref<?xf64>, vector<16xf32>
+}
+
+// -----
+
+func @store_memref_index_mismatch(%base : memref<?xf32>, %value : vector<16xf32>) {
+  // expected-error@+1 {{'vector.store' op requires 1 indices}}
+  vector.store %value, %base[] : memref<?xf32>, vector<16xf32>
+}
+
+// -----
+
 func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
   %c0 = constant 0 : index
   // expected-error@+1 {{'vector.maskedload' op base and result element type should match}}

@@ -1231,7 +1263,7 @@ func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pa
 
 func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
-  // expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
+  // expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 

@@ -1239,7 +1271,7 @@ func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>,
 
 func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
-  // expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
+  // expected-error@+1 {{'vector.maskedstore' op expected valueToStore dim to match mask dim}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
 }
 
@@ -450,6 +450,56 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
   return %0 : vector<16xi32>
 }
 
+// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
+func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_1d_vector_memref
+func @vector_load_and_store_1d_vector_memref(%memref : memref<200x100xvector<8xf32>>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_out_of_bounds
+func @vector_load_and_store_out_of_bounds(%memref : memref<7xf32>) {
+  %c0 = constant 0 : index
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<7xf32>, vector<8xf32>
+  %0 = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<7xf32>, vector<8xf32>
+  vector.store %0, %memref[%c0] : memref<7xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_2d_scalar_memref
+func @vector_load_and_store_2d_scalar_memref(%memref : memref<200x100xf32>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<4x8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<4x8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_2d_vector_memref
+func @vector_load_and_store_2d_vector_memref(%memref : memref<200x100xvector<4x8xf32>>,
+                                             %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
+  return
+}
+
 // CHECK-LABEL: @masked_load_and_store
 func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
   %c0 = constant 0 : index