[mlir][StandardToSPIRV] Use spv.UMod for index re-calculation

Per Vulkan's SPIR-V environment spec: "While the OpSRem and OpSMod
instructions are supported by the Vulkan environment, they require
non-negative values and thus do not enable additional functionality
beyond what OpUMod provides."

The `getOffsetForBitwidth` function is used for lowering std.load
and std.store, whose indices are of `index` type and cannot be
negative. So we should be okay to use spv.UMod directly here to
be exact. Also made the comment explicit about the assumption.

Differential Revision: https://reviews.llvm.org/D83714
This commit is contained in:
Lei Zhang 2020-07-13 16:20:59 -04:00
parent f3056dcc02
commit 0d03b3901d
2 changed files with 13 additions and 10 deletions

View File

@ -126,9 +126,12 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
}
/// Returns the offset of the value in `targetBits` representation. `srcIdx` is
/// an index into a 1-D array with each element having `sourceBits`. When
/// accessing an element in the array treating as having elements of
/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array treating as having elements of
/// `targetBits`, multiple values are loaded in the same time. The method
/// returns the offset where the `srcIdx` locates in the value. For example, if
/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
@ -144,7 +147,7 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
auto srcBitsValue =
builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
auto m = builder.create<spirv::SModOp>(loc, srcIdx, idx);
auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
}

View File

@ -762,7 +762,7 @@ func @load_i8(%arg0: memref<i8>) {
// CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
// CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
// CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
// CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
// CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32
// CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spv.constant 255 : i32
@ -788,7 +788,7 @@ func @load_i16(%arg0: memref<10xi16>, %index : index) {
// CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
// CHECK: %[[TWO2:.+]] = spv.constant 2 : i32
// CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
// CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO2]] : i32
// CHECK: %[[IDX:.+]] = spv.UMod %[[FLAT_IDX]], %[[TWO2]] : i32
// CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
// CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
@ -824,7 +824,7 @@ func @store_i8(%arg0: memref<i8>, %value: i8) {
// CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
// CHECK: %[[FOUR:.+]] = spv.constant 4 : i32
// CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
// CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32
// CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32
// CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[MASK1:.+]] = spv.constant 255 : i32
// CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
@ -850,7 +850,7 @@ func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
// CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32
// CHECK: %[[TWO:.+]] = spv.constant 2 : i32
// CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
// CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO]] : i32
// CHECK: %[[IDX:.+]] = spv.UMod %[[FLAT_IDX]], %[[TWO]] : i32
// CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
// CHECK: %[[MASK1:.+]] = spv.constant 65535 : i32
// CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
@ -907,7 +907,7 @@ func @load_i8(%arg0: memref<i8>) {
// CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
// CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
// CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
// CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
// CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32
// CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spv.constant 255 : i32
@ -934,7 +934,7 @@ func @store_i8(%arg0: memref<i8>, %value: i8) {
// CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
// CHECK: %[[FOUR:.+]] = spv.constant 4 : i32
// CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
// CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32
// CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32
// CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[MASK1:.+]] = spv.constant 255 : i32
// CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32