forked from OSchip/llvm-project
[mlir][spirv] Allow bitwidth emulation on runtime arrays
Runtime arrays are converted from memrefs with unknown dimensions. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D100335
This commit is contained in:
parent
a3fabc79ae
commit
2eb98d89ac
|
@ -994,13 +994,16 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
|
||||||
bool isBool = srcBits == 1;
|
bool isBool = srcBits == 1;
|
||||||
if (isBool)
|
if (isBool)
|
||||||
srcBits = typeConverter.getOptions().boolNumBits;
|
srcBits = typeConverter.getOptions().boolNumBits;
|
||||||
auto dstType = typeConverter.convertType(memrefType)
|
Type pointeeType = typeConverter.convertType(memrefType)
|
||||||
.cast<spirv::PointerType>()
|
.cast<spirv::PointerType>()
|
||||||
.getPointeeType()
|
.getPointeeType();
|
||||||
.cast<spirv::StructType>()
|
Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
|
||||||
.getElementType(0)
|
Type dstType;
|
||||||
.cast<spirv::ArrayType>()
|
if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
|
||||||
.getElementType();
|
dstType = arrayType.getElementType();
|
||||||
|
else
|
||||||
|
dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
|
||||||
|
|
||||||
int dstBits = dstType.getIntOrFloatBitWidth();
|
int dstBits = dstType.getIntOrFloatBitWidth();
|
||||||
assert(dstBits % srcBits == 0);
|
assert(dstBits % srcBits == 0);
|
||||||
|
|
||||||
|
@ -1136,13 +1139,16 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
|
||||||
bool isBool = srcBits == 1;
|
bool isBool = srcBits == 1;
|
||||||
if (isBool)
|
if (isBool)
|
||||||
srcBits = typeConverter.getOptions().boolNumBits;
|
srcBits = typeConverter.getOptions().boolNumBits;
|
||||||
auto dstType = typeConverter.convertType(memrefType)
|
Type pointeeType = typeConverter.convertType(memrefType)
|
||||||
.cast<spirv::PointerType>()
|
.cast<spirv::PointerType>()
|
||||||
.getPointeeType()
|
.getPointeeType();
|
||||||
.cast<spirv::StructType>()
|
Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
|
||||||
.getElementType(0)
|
Type dstType;
|
||||||
.cast<spirv::ArrayType>()
|
if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
|
||||||
.getElementType();
|
dstType = arrayType.getElementType();
|
||||||
|
else
|
||||||
|
dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
|
||||||
|
|
||||||
int dstBits = dstType.getIntOrFloatBitWidth();
|
int dstBits = dstType.getIntOrFloatBitWidth();
|
||||||
assert(dstBits % srcBits == 0);
|
assert(dstBits % srcBits == 0);
|
||||||
|
|
||||||
|
|
|
@ -905,6 +905,19 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @load_store_unknown_dim
|
||||||
|
// CHECK-SAME: %[[SRC:[a-z0-9]+]]: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>,
|
||||||
|
// CHECK-SAME: %[[DST:[a-z0-9]+]]: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>)
|
||||||
|
func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?xi32>) {
|
||||||
|
// CHECK: %[[AC0:.+]] = spv.AccessChain %[[SRC]]
|
||||||
|
// CHECK: spv.Load "StorageBuffer" %[[AC0]]
|
||||||
|
%0 = memref.load %source[%i] : memref<?xi32>
|
||||||
|
// CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]]
|
||||||
|
// CHECK: spv.Store "StorageBuffer" %[[AC1]]
|
||||||
|
memref.store %0, %dest[%i]: memref<?xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
} // end module
|
} // end module
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
Loading…
Reference in New Issue