[mlir][Vector] Introduce 'vector.load' and 'vector.store' ops

This patch adds the 'vector.load' and 'vector.store' ops to the Vector
dialect [1]. These operations model *contiguous* vector loads and stores
from/to memory. Their semantics are similar to those of the
'affine.vector_load' and 'affine.vector_store' counterparts, but without the
affine constraints. The most relevant feature is that these new vector
operations may perform a vector load/store on memrefs with a non-vector
element type, unlike the 'std.load' and 'std.store' ops. This opens up the
representation to more generic vector load/store scenarios: unaligned vector
loads/stores, mixing scalar and vector memory accesses on the same memref,
decoupling memory allocation constraints from memory accesses, etc. [1].
These operations will also facilitate the progressive lowering of both Affine
vector loads/stores and Vector transfer reads/writes when they read/write
contiguous slices from/to memory.
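
Below is a minimal sketch (written by hand for illustration, not taken from
the patch; the function and SSA names are made up) of the kind of mixed
scalar/vector access on a scalar-element memref that the new ops make
expressible:

```mlir
func @mixed_scalar_and_vector_access(%A : memref<100x100xf32>,
                                     %i : index, %j : index) -> (f32, vector<8xf32>) {
  // Scalar access: reads a single f32 element.
  %s = load %A[%i, %j] : memref<100x100xf32>
  // Vector access on the same memref: reads 8 contiguous f32 elements
  // starting at the same element address; no vector element type is needed.
  %v = vector.load %A[%i, %j] : memref<100x100xf32>, vector<8xf32>
  return %s, %v : f32, vector<8xf32>
}
```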

In particular, this patch adds the 'vector.load' and 'vector.store' ops to the
Vector dialect, implements their lowering to the LLVM dialect, and changes the
lowering of 'affine.vector_load' and 'affine.vector_store' ops to the new vector
ops. The lowering of Vector transfer reads/writes will be implemented in the
future, probably as an independent pass. The API of 'vector.maskedload' and
'vector.maskedstore' has also been changed slightly to align it with the
transfer read/write ops and the new vector ops. This will improve reusability
among all these operations. For example, the lowering of 'vector.load',
'vector.store', 'vector.maskedload' and 'vector.maskedstore' to the LLVM dialect
is implemented with a single template conversion pattern.
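
For reference, here is an illustrative sketch (hand-written IR, not copied
from the tests) of how the 'affine.vector_load' lowering changes: the affine
map is still expanded to standard arithmetic, but the access now becomes a
'vector.load' instead of a 'vector.transfer_read':

```mlir
// Before -lower-affine:
%v = affine.vector_load %buf[%i + 7] : memref<100xf32>, vector<8xf32>

// After -lower-affine (previously this produced a vector.transfer_read,
// which also required a padding value):
%c7 = constant 7 : index
%idx = addi %i, %c7 : index
%v = vector.load %buf[%idx] : memref<100xf32>, vector<8xf32>
```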

[1] https://llvm.discourse.group/t/memref-type-and-data-layout/

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96185
Diego Caballero 2021-02-12 19:41:46 +02:00
parent 98754e2909
commit ee66e43a96
8 changed files with 414 additions and 116 deletions


@@ -1320,6 +1320,156 @@ def Vector_TransferWriteOp :
let hasFolder = 1;
}
def Vector_LoadOp : Vector_Op<"load"> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
The 'vector.load' operation reads an n-D slice of memory into an n-D
vector. It takes a 'base' memref, an index for each memref dimension and a
result vector type as arguments. It returns a value of the result vector
type. The 'base' memref and indices determine the start memory address from
which to read. Each index provides an offset for each memref dimension
based on the element type of the memref. The shape of the result vector
type determines the shape of the slice read from the start memory address.
The elements along each dimension of the slice are strided by the memref
strides. Only memrefs with default strides are allowed. These constraints
guarantee that elements read along the first dimension of the slice are
contiguous in memory.
The memref element type can be a scalar or a vector type. If the memref
element type is a scalar, it should match the element type of the result
vector. If the memref element type is a vector, it should match the result
vector type.
Example 1: 1-D vector load on a scalar memref.
```mlir
%result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
```
Example 2: 1-D vector load on a vector memref.
```mlir
%result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
```
Example 3: 2-D vector load on a scalar memref.
```mlir
%result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
```
Example 4: 2-D vector load on a vector memref.
```mlir
%result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
```
Representation-wise, the 'vector.load' operation permits out-of-bounds
reads. Support and implementation of out-of-bounds vector loads are
target-specific. No assumptions should be made on the value of elements
loaded out of bounds. Not all targets may support out-of-bounds vector
loads.
Example 5: Potential out-of-bounds vector load.
```mlir
%result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
```
Example 6: Explicit out-of-bounds vector load.
```mlir
%result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
```
}];
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs AnyVector:$result);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}
VectorType getVectorType() {
return result().getType().cast<VectorType>();
}
}];
let assemblyFormat =
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
}
def Vector_StoreOp : Vector_Op<"store"> {
let summary = "writes an n-D vector to an n-D slice of memory";
let description = [{
The 'vector.store' operation writes an n-D vector to an n-D slice of memory.
It takes the vector value to be stored, a 'base' memref and an index for
each memref dimension. The 'base' memref and indices determine the start
memory address at which to write. Each index provides an offset for each
memref dimension based on the element type of the memref. The shape of the
vector value to store determines the shape of the slice written from the
start memory address. The elements along each dimension of the slice are
strided by the memref strides. Only memrefs with default strides are allowed.
These constraints guarantee that elements written along the first dimension
of the slice are contiguous in memory.
The memref element type can be a scalar or a vector type. If the memref
element type is a scalar, it should match the element type of the value
to store. If the memref element type is a vector, it should match the type
of the value to store.
Example 1: 1-D vector store on a scalar memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
```
Example 2: 1-D vector store on a vector memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
```
Example 3: 2-D vector store on a scalar memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
```
Example 4: 2-D vector store on a vector memref.
```mlir
vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
```
Representation-wise, the 'vector.store' operation permits out-of-bounds
writes. Support and implementation of out-of-bounds vector stores are
target-specific. No assumptions should be made on the memory written out of
bounds. Not all targets may support out-of-bounds vector stores.
Example 5: Potential out-of-bounds vector store.
```mlir
vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
```
Example 6: Explicit out-of-bounds vector store.
```mlir
vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
```
}];
let arguments = (ins AnyVector:$valueToStore,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$base,
Variadic<Index>:$indices);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}
VectorType getVectorType() {
return valueToStore().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
"`:` type($base) `,` type($valueToStore)";
}
def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
@@ -1363,7 +1513,7 @@ def Vector_MaskedLoadOp :
VectorType getPassThruVectorType() {
return pass_thru().getType().cast<VectorType>();
}
VectorType getResultVectorType() {
VectorType getVectorType() {
return result().getType().cast<VectorType>();
}
}];
@@ -1377,7 +1527,7 @@ def Vector_MaskedStoreOp :
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$value)> {
VectorOfRank<[1]>:$valueToStore)> {
let summary = "stores elements from a vector into memory as defined by a mask vector";
@@ -1411,12 +1561,13 @@ def Vector_MaskedStoreOp :
VectorType getMaskVectorType() {
return mask().getType().cast<VectorType>();
}
VectorType getValueVectorType() {
return value().getType().cast<VectorType>();
VectorType getVectorType() {
return valueToStore().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
"type($base) `,` type($mask) `,` type($value)";
let assemblyFormat =
"$base `[` $indices `]` `,` $mask `,` $valueToStore "
"attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
}


@@ -578,8 +578,9 @@ public:
if (!resultOperands)
return failure();
// Build std.load memref[expandedMap.results].
rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, op.getMemRef(),
*resultOperands);
return success();
}
};
@@ -625,8 +626,8 @@ public:
return failure();
// Build std.store valueToStore, memref[expandedMap.results].
rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
op.getMemRef(), *maybeExpandedMap);
rewriter.replaceOpWithNewOp<mlir::StoreOp>(
op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
return success();
}
};
@@ -695,8 +696,8 @@ public:
};
/// Apply the affine map from an 'affine.vector_load' operation to its operands,
/// and feed the results to a newly created 'vector.transfer_read' operation
/// (which replaces the original 'affine.vector_load').
/// and feed the results to a newly created 'vector.load' operation (which
/// replaces the original 'affine.vector_load').
class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
public:
using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
@@ -710,16 +711,16 @@ public:
if (!resultOperands)
return failure();
// Build vector.transfer_read memref[expandedMap.results].
rewriter.replaceOpWithNewOp<TransferReadOp>(
// Build vector.load memref[expandedMap.results].
rewriter.replaceOpWithNewOp<vector::LoadOp>(
op, op.getVectorType(), op.getMemRef(), *resultOperands);
return success();
}
};
/// Apply the affine map from an 'affine.vector_store' operation to its
/// operands, and feed the results to a newly created 'vector.transfer_write'
/// operation (which replaces the original 'affine.vector_store').
/// operands, and feed the results to a newly created 'vector.store' operation
/// (which replaces the original 'affine.vector_store').
class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
public:
using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
@@ -733,7 +734,7 @@ public:
if (!maybeExpandedMap)
return failure();
rewriter.replaceOpWithNewOp<TransferWriteOp>(
rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
return success();
}


@@ -357,64 +357,72 @@ public:
}
};
/// Conversion pattern for a vector.maskedload.
class VectorMaskedLoadOpConversion
: public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
vector::LoadOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
vector::MaskedLoadOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
vector::StoreOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
ptr, align);
}
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
vector::MaskedStoreOpAdaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = load->getLoc();
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
MemRefType memRefType = load.getMemRefType();
// Only 1-D vectors can be lowered to LLVM.
VectorType vectorTy = loadOrStoreOp.getVectorType();
if (vectorTy.getRank() > 1)
return failure();
auto loc = loadOrStoreOp->getLoc();
auto adaptor = LoadOrStoreOpAdaptor(operands);
MemRefType memRefTy = loadOrStoreOp.getMemRefType();
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
return failure();
// Resolve address.
auto vtype = typeConverter->convertType(load.getResultVectorType());
Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
.template cast<VectorType>();
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
adaptor.indices(), rewriter);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
rewriter.getI32IntegerAttr(align));
return success();
}
};
/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion
: public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
public:
using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = store->getLoc();
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
MemRefType memRefType = store.getMemRefType();
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
// Resolve address.
auto vtype = typeConverter->convertType(store.getValueVectorType());
Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
store, adaptor.value(), ptr, adaptor.mask(),
rewriter.getI32IntegerAttr(align));
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
return success();
}
};
@@ -1511,8 +1519,14 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion,
VectorPrintOpConversion,
VectorTypeCastOpConversion,
VectorMaskedLoadOpConversion,
VectorMaskedStoreOpConversion,
VectorLoadStoreConversion<vector::LoadOp,
vector::LoadOpAdaptor>,
VectorLoadStoreConversion<vector::MaskedLoadOp,
vector::MaskedLoadOpAdaptor>,
VectorLoadStoreConversion<vector::StoreOp,
vector::StoreOpAdaptor>,
VectorLoadStoreConversion<vector::MaskedStoreOp,
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion,
VectorScatterOpConversion,
VectorExpandLoadOpConversion,


@@ -2373,6 +2373,67 @@ void TransferWriteOp::getEffects(
SideEffects::DefaultResource::get());
}
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
MemRefType memRefTy) {
auto affineMaps = memRefTy.getAffineMaps();
if (!affineMaps.empty())
return op->emitOpError("base memref should have a default identity layout");
return success();
}
static LogicalResult verify(vector::LoadOp op) {
VectorType resVecTy = op.getVectorType();
MemRefType memRefTy = op.getMemRefType();
if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
return failure();
// Checks for vector memrefs.
Type memElemTy = memRefTy.getElementType();
if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
if (memVecTy != resVecTy)
return op.emitOpError("base memref and result vector types should match");
memElemTy = memVecTy.getElementType();
}
if (resVecTy.getElementType() != memElemTy)
return op.emitOpError("base and result element types should match");
if (llvm::size(op.indices()) != memRefTy.getRank())
return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
return success();
}
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(vector::StoreOp op) {
VectorType valueVecTy = op.getVectorType();
MemRefType memRefTy = op.getMemRefType();
if (failed(verifyLoadStoreMemRefLayout(op, memRefTy)))
return failure();
// Checks for vector memrefs.
Type memElemTy = memRefTy.getElementType();
if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
if (memVecTy != valueVecTy)
return op.emitOpError(
"base memref and valueToStore vector types should match");
memElemTy = memVecTy.getElementType();
}
if (valueVecTy.getElementType() != memElemTy)
return op.emitOpError("base and valueToStore element type should match");
if (llvm::size(op.indices()) != memRefTy.getRank())
return op.emitOpError("requires ") << memRefTy.getRank() << " indices";
return success();
}
//===----------------------------------------------------------------------===//
// MaskedLoadOp
//===----------------------------------------------------------------------===//
@@ -2380,7 +2441,7 @@ void TransferWriteOp::getEffects(
static LogicalResult verify(MaskedLoadOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType passVType = op.getPassThruVectorType();
VectorType resVType = op.getResultVectorType();
VectorType resVType = op.getVectorType();
MemRefType memType = op.getMemRefType();
if (resVType.getElementType() != memType.getElementType())
@@ -2427,15 +2488,15 @@ void MaskedLoadOp::getCanonicalizationPatterns(
static LogicalResult verify(MaskedStoreOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType valueVType = op.getValueVectorType();
VectorType valueVType = op.getVectorType();
MemRefType memType = op.getMemRefType();
if (valueVType.getElementType() != memType.getElementType())
return op.emitOpError("base and value element type should match");
return op.emitOpError("base and valueToStore element type should match");
if (llvm::size(op.indices()) != memType.getRank())
return op.emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected value dim to match mask dim");
return op.emitOpError("expected valueToStore dim to match mask dim");
return success();
}
@@ -2448,7 +2509,7 @@ public:
switch (get1DMaskFormat(store.mask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
store, store.value(), store.base(), store.indices(), false);
store, store.valueToStore(), store.base(), store.indices(), false);
return success();
case MaskFormat::AllFalse:
rewriter.eraseOp(store);


@@ -1,5 +1,6 @@
// RUN: mlir-opt -lower-affine --split-input-file %s | FileCheck %s
// CHECK-LABEL: func @affine_vector_load
func @affine_vector_load(%arg0 : index) {
%0 = alloc() : memref<100xf32>
@@ -10,8 +11,7 @@ func @affine_vector_load(%arg0 : index) {
// CHECK: %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %[[c7:.*]] = constant 7 : index
// CHECK-NEXT: %[[b:.*]] = addi %[[a]], %[[c7]] : index
// CHECK-NEXT: %[[pad:.*]] = constant 0.0
// CHECK-NEXT: vector.transfer_read %[[buf]][%[[b]]], %[[pad]] : memref<100xf32>, vector<8xf32>
// CHECK-NEXT: vector.load %[[buf]][%[[b]]] : memref<100xf32>, vector<8xf32>
return
}
@@ -31,44 +31,7 @@ func @affine_vector_store(%arg0 : index) {
// CHECK-NEXT: %[[b:.*]] = addi %{{.*}}, %[[a]] : index
// CHECK-NEXT: %[[c7:.*]] = constant 7 : index
// CHECK-NEXT: %[[c:.*]] = addi %[[b]], %[[c7]] : index
// CHECK-NEXT: vector.transfer_write %[[val]], %[[buf]][%[[c]]] : vector<4xf32>, memref<100xf32>
return
}
// -----
// CHECK-LABEL: func @affine_vector_load
func @affine_vector_load(%arg0 : index) {
%0 = alloc() : memref<100xf32>
affine.for %i0 = 0 to 16 {
%1 = affine.vector_load %0[%i0 + symbol(%arg0) + 7] : memref<100xf32>, vector<8xf32>
}
// CHECK: %[[buf:.*]] = alloc
// CHECK: %[[a:.*]] = addi %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %[[c7:.*]] = constant 7 : index
// CHECK-NEXT: %[[b:.*]] = addi %[[a]], %[[c7]] : index
// CHECK-NEXT: %[[pad:.*]] = constant 0.0
// CHECK-NEXT: vector.transfer_read %[[buf]][%[[b]]], %[[pad]] : memref<100xf32>, vector<8xf32>
return
}
// -----
// CHECK-LABEL: func @affine_vector_store
func @affine_vector_store(%arg0 : index) {
%0 = alloc() : memref<100xf32>
%1 = constant dense<11.0> : vector<4xf32>
affine.for %i0 = 0 to 16 {
affine.vector_store %1, %0[%i0 - symbol(%arg0) + 7] : memref<100xf32>, vector<4xf32>
}
// CHECK: %[[buf:.*]] = alloc
// CHECK: %[[val:.*]] = constant dense
// CHECK: %[[c_1:.*]] = constant -1 : index
// CHECK-NEXT: %[[a:.*]] = muli %arg0, %[[c_1]] : index
// CHECK-NEXT: %[[b:.*]] = addi %{{.*}}, %[[a]] : index
// CHECK-NEXT: %[[c7:.*]] = constant 7 : index
// CHECK-NEXT: %[[c:.*]] = addi %[[b]], %[[c7]] : index
// CHECK-NEXT: vector.transfer_write %[[val]], %[[buf]][%[[c]]] : vector<4xf32>, memref<100xf32>
// CHECK-NEXT: vector.store %[[val]], %[[buf]][%[[c]]] : memref<100xf32>, vector<4xf32>
return
}
@@ -83,8 +46,7 @@ func @vector_load_2d() {
// CHECK: %[[buf:.*]] = alloc
// CHECK: scf.for %[[i0:.*]] =
// CHECK: scf.for %[[i1:.*]] =
// CHECK-NEXT: %[[pad:.*]] = constant 0.0
// CHECK-NEXT: vector.transfer_read %[[buf]][%[[i0]], %[[i1]]], %[[pad]] : memref<100x100xf32>, vector<2x8xf32>
// CHECK-NEXT: vector.load %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
}
}
return
@@ -103,9 +65,8 @@ func @vector_store_2d() {
// CHECK: %[[val:.*]] = constant dense
// CHECK: scf.for %[[i0:.*]] =
// CHECK: scf.for %[[i1:.*]] =
// CHECK-NEXT: vector.transfer_write %[[val]], %[[buf]][%[[i0]], %[[i1]]] : vector<2x8xf32>, memref<100x100xf32>
// CHECK-NEXT: vector.store %[[val]], %[[buf]][%[[i0]], %[[i1]]] : memref<100x100xf32>, vector<2x8xf32>
}
}
return
}


@@ -23,6 +23,7 @@ func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
// -----
func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2xf32>
return %0 : vector<2xf32>
@@ -1242,6 +1243,33 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
// -----
func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
%0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
return %0 : vector<8xf32>
}
// CHECK-LABEL: func @vector_load_op
// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
// CHECK: llvm.load %[[bcast]] {alignment = 4 : i64} : !llvm.ptr<vector<8xf32>>
func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
%val = constant dense<11.0> : vector<4xf32>
vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
return
}
// CHECK-LABEL: func @vector_store_op
// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0: index
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>


@@ -1198,6 +1198,38 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
// -----
func @store_unsupported_layout(%memref : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
%i : index, %j : index, %value : vector<8xf32>) {
// expected-error@+1 {{'vector.store' op base memref should have a default identity layout}}
vector.store %value, %memref[%i, %j] : memref<200x100xf32, affine_map<(d0, d1) -> (d1, d0)>>,
vector<8xf32>
return
}
// -----
func @vector_memref_mismatch(%memref : memref<200x100xvector<4xf32>>, %i : index,
%j : index, %value : vector<8xf32>) {
// expected-error@+1 {{'vector.store' op base memref and valueToStore vector types should match}}
vector.store %value, %memref[%i, %j] : memref<200x100xvector<4xf32>>, vector<8xf32>
return
}
// -----
func @store_base_type_mismatch(%base : memref<?xf64>, %value : vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.store' op base and valueToStore element type should match}}
vector.store %value, %base[%c0] : memref<?xf64>, vector<16xf32>
return
}
// -----
func @store_memref_index_mismatch(%base : memref<?xf32>, %value : vector<16xf32>) {
// expected-error@+1 {{'vector.store' op requires 1 indices}}
vector.store %value, %base[] : memref<?xf32>, vector<16xf32>
return
}
// -----
func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
@@ -1231,7 +1263,7 @@ func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pa
func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
// expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}}
vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
}
@@ -1239,7 +1271,7 @@ func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>,
func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
// expected-error@+1 {{'vector.maskedstore' op expected valueToStore dim to match mask dim}}
vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
}


@@ -450,6 +450,56 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
return %0 : vector<16xi32>
}
// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
%i : index, %j : index) {
// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<8xf32>
%0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<8xf32>
vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
return
}
// CHECK-LABEL: @vector_load_and_store_1d_vector_memref
func @vector_load_and_store_1d_vector_memref(%memref : memref<200x100xvector<8xf32>>,
%i : index, %j : index) {
// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
%0 = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32>
vector.store %0, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
return
}
// CHECK-LABEL: @vector_load_and_store_out_of_bounds
func @vector_load_and_store_out_of_bounds(%memref : memref<7xf32>) {
%c0 = constant 0 : index
// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<7xf32>, vector<8xf32>
%0 = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<7xf32>, vector<8xf32>
vector.store %0, %memref[%c0] : memref<7xf32>, vector<8xf32>
return
}
// CHECK-LABEL: @vector_load_and_store_2d_scalar_memref
func @vector_load_and_store_2d_scalar_memref(%memref : memref<200x100xf32>,
%i : index, %j : index) {
// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<4x8xf32>
%0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<4x8xf32>
vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
return
}
// CHECK-LABEL: @vector_load_and_store_2d_vector_memref
func @vector_load_and_store_2d_vector_memref(%memref : memref<200x100xvector<4x8xf32>>,
%i : index, %j : index) {
// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
%0 = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
vector.store %0, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
return
}
// CHECK-LABEL: @masked_load_and_store
func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
%c0 = constant 0 : index