forked from OSchip/llvm-project
[mlir][AVX512] Add while loop-based sparse vector-vector dot product variants.
Differential Revision: https://reviews.llvm.org/D98480
This commit is contained in:
parent
0ddd537605
commit
581672be04
|
@ -9,10 +9,17 @@
|
|||
// Each sparse vector is represented by an index memref (A or C) and by a data
|
||||
// memref (B or D), containing M or N elements.
|
||||
//
|
||||
// There are two implementations:
|
||||
// There are four different implementations:
|
||||
// * `memref_dot_simple`: Simple O(N*M) implementation with two for loops.
|
||||
// * `memref_dot_optimized`: An optimized O(N*M) version of the previous
|
||||
// implementation, where the second for loop skips over some elements.
|
||||
// * `memref_dot_while`: An optimized O(N+M) implementation that utilizes a
|
||||
// single while loop, coiterating over both vectors.
|
||||
// * `memref_dot_while_branchless`: An optimized O(N+M) implementation that
|
||||
// consists of a single while loop and has no branches within the loop.
|
||||
//
|
||||
// Output of llvm-mca:
|
||||
// https://gist.github.com/matthias-springer/72e7ee1b3c467e7aefb6e1fd862e4841
|
||||
|
||||
#contraction_accesses = [
|
||||
affine_map<(i) -> (i)>,
|
||||
|
@ -224,6 +231,166 @@ func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
|
|||
return %r0 : f64
|
||||
}
|
||||
|
||||
// Vector dot product with a while loop. Implemented as follows:
|
||||
//
|
||||
// r = 0.0, a = 0, b = 0
|
||||
// while (a < M && b < N) {
|
||||
// segA = A[a:a+8], segB = B[b:b+8]
|
||||
// if (segB[7] < segA[0]) b += 8
|
||||
// elif (segA[7] < segB[0]) a += 8
|
||||
// else {
|
||||
// r += vector_dot(...)
|
||||
// if (segA[7] < segB[7]) a += 8
|
||||
// elif (segB[7] < segA[7]) b += 8
|
||||
// else a += 8, b += 8
|
||||
// }
|
||||
// }
|
||||
func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
|
||||
%m_C : memref<?xi64>, %m_D : memref<?xf64>,
|
||||
%M : index, %N : index)
|
||||
-> f64 {
|
||||
// Helper constants for loops.
|
||||
%c0 = constant 0 : index
|
||||
%i0 = constant 0 : i32
|
||||
%i7 = constant 7 : i32
|
||||
%c8 = constant 8 : index
|
||||
|
||||
%data_zero = constant 0.0 : f64
|
||||
%index_padding = constant 9223372036854775807 : i64
|
||||
|
||||
%r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
|
||||
: (f64, index, index) -> (f64, index, index) {
|
||||
%cond_i = cmpi "slt", %a1, %M : index
|
||||
%cond_j = cmpi "slt", %b1, %N : index
|
||||
%cond = and %cond_i, %cond_j : i1
|
||||
scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
|
||||
} do {
|
||||
^bb0(%r1 : f64, %a1 : index, %b1 : index):
|
||||
// v_A, v_B, seg*_* could be part of the loop state to avoid a few
|
||||
// redundant reads.
|
||||
%v_A = vector.transfer_read %m_A[%a1], %index_padding
|
||||
: memref<?xi64>, vector<8xi64>
|
||||
%v_C = vector.transfer_read %m_C[%b1], %index_padding
|
||||
: memref<?xi64>, vector<8xi64>
|
||||
|
||||
%segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>
|
||||
%segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
|
||||
%segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64>
|
||||
%segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
|
||||
|
||||
%seg1_done = cmpi "slt", %segB_max, %segA_min : i64
|
||||
%r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) {
|
||||
%b3 = addi %b1, %c8 : index
|
||||
scf.yield %r1, %a1, %b3 : f64, index, index
|
||||
} else {
|
||||
%seg0_done = cmpi "slt", %segA_max, %segB_min : i64
|
||||
%r4, %a4, %b4 = scf.if %seg0_done -> (f64, index, index) {
|
||||
%a5 = addi %a1, %c8 : index
|
||||
scf.yield %r1, %a5, %b1 : f64, index, index
|
||||
} else {
|
||||
%v_B = vector.transfer_read %m_B[%a1], %data_zero
|
||||
: memref<?xf64>, vector<8xf64>
|
||||
%v_D = vector.transfer_read %m_D[%b1], %data_zero
|
||||
: memref<?xf64>, vector<8xf64>
|
||||
|
||||
%subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
|
||||
: (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
|
||||
-> f64
|
||||
%r6 = addf %r1, %subresult : f64
|
||||
|
||||
%incr_a = cmpi "slt", %segA_max, %segB_max : i64
|
||||
%a6, %b6 = scf.if %incr_a -> (index, index) {
|
||||
%a7 = addi %a1, %c8 : index
|
||||
scf.yield %a7, %b1 : index, index
|
||||
} else {
|
||||
%incr_b = cmpi "slt", %segB_max, %segA_max : i64
|
||||
%a8, %b8 = scf.if %incr_b -> (index, index) {
|
||||
%b9 = addi %b1, %c8 : index
|
||||
scf.yield %a1, %b9 : index, index
|
||||
} else {
|
||||
%a10 = addi %a1, %c8 : index
|
||||
%b10 = addi %b1, %c8 : index
|
||||
scf.yield %a10, %b10 : index, index
|
||||
}
|
||||
scf.yield %a8, %b8 : index, index
|
||||
}
|
||||
scf.yield %r6, %a6, %b6 : f64, index, index
|
||||
}
|
||||
scf.yield %r4, %a4, %b4 : f64, index, index
|
||||
}
|
||||
scf.yield %r2, %a2, %b2 : f64, index, index
|
||||
}
|
||||
|
||||
return %r0 : f64
|
||||
}
|
||||
|
||||
// Vector dot product with a while loop that has no branches (apart from the
|
||||
// while loop itself). Implemented as follows:
|
||||
//
|
||||
// r = 0.0, a = 0, b = 0
|
||||
// while (a < M && b < N) {
|
||||
// segA = A[a:a+8], segB = B[b:b+8]
|
||||
// r += vector_dot(...)
|
||||
// a += (segA[7] <= segB[7]) * 8
|
||||
// b += (segB[7] <= segA[7]) * 8
|
||||
// }
|
||||
func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
|
||||
%m_C : memref<?xi64>, %m_D : memref<?xf64>,
|
||||
%M : index, %N : index)
|
||||
-> f64 {
|
||||
// Helper constants for loops.
|
||||
%c0 = constant 0 : index
|
||||
%i7 = constant 7 : i32
|
||||
%c8 = constant 8 : index
|
||||
|
||||
%data_zero = constant 0.0 : f64
|
||||
%index_padding = constant 9223372036854775807 : i64
|
||||
|
||||
%r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
|
||||
: (f64, index, index) -> (f64, index, index) {
|
||||
%cond_i = cmpi "slt", %a1, %M : index
|
||||
%cond_j = cmpi "slt", %b1, %N : index
|
||||
%cond = and %cond_i, %cond_j : i1
|
||||
scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
|
||||
} do {
|
||||
^bb0(%r1 : f64, %a1 : index, %b1 : index):
|
||||
// v_A, v_B, seg*_* could be part of the loop state to avoid a few
|
||||
// redundant reads.
|
||||
%v_A = vector.transfer_read %m_A[%a1], %index_padding
|
||||
: memref<?xi64>, vector<8xi64>
|
||||
%v_B = vector.transfer_read %m_B[%a1], %data_zero
|
||||
: memref<?xf64>, vector<8xf64>
|
||||
%v_C = vector.transfer_read %m_C[%b1], %index_padding
|
||||
: memref<?xi64>, vector<8xi64>
|
||||
%v_D = vector.transfer_read %m_D[%b1], %data_zero
|
||||
: memref<?xf64>, vector<8xf64>
|
||||
|
||||
%subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
|
||||
: (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
|
||||
-> f64
|
||||
%r2 = addf %r1, %subresult : f64
|
||||
|
||||
%segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
|
||||
%segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
|
||||
|
||||
%cond_a = cmpi "sle", %segA_max, %segB_max : i64
|
||||
%cond_a_i64 = zexti %cond_a : i1 to i64
|
||||
%cond_a_idx = index_cast %cond_a_i64 : i64 to index
|
||||
%incr_a = muli %cond_a_idx, %c8 : index
|
||||
%a2 = addi %a1, %incr_a : index
|
||||
|
||||
%cond_b = cmpi "sle", %segB_max, %segA_max : i64
|
||||
%cond_b_i64 = zexti %cond_b : i1 to i64
|
||||
%cond_b_idx = index_cast %cond_b_i64 : i64 to index
|
||||
%incr_b = muli %cond_b_idx, %c8 : index
|
||||
%b2 = addi %b1, %incr_b : index
|
||||
|
||||
scf.yield %r2, %a2, %b2 : f64, index, index
|
||||
}
|
||||
|
||||
return %r0 : f64
|
||||
}
|
||||
|
||||
func @entry() -> i32 {
|
||||
// Initialize large buffers that can be used for multiple test cases of
|
||||
// different sizes.
|
||||
|
@ -256,6 +423,18 @@ func @entry() -> i32 {
|
|||
vector.print %r1 : f64
|
||||
// CHECK: 86
|
||||
|
||||
%r2 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
|
||||
: (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
|
||||
index, index) -> f64
|
||||
vector.print %r2 : f64
|
||||
// CHECK: 86
|
||||
|
||||
%r6 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
|
||||
: (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
|
||||
index, index) -> f64
|
||||
vector.print %r6 : f64
|
||||
// CHECK: 86
|
||||
|
||||
// --- Test case 2 ---.
|
||||
// M and N must be a multiple of 8 if smaller than 128.
|
||||
// (Because padding kicks in only for out-of-bounds accesses.)
|
||||
|
@ -275,6 +454,18 @@ func @entry() -> i32 {
|
|||
vector.print %r4 : f64
|
||||
// CHECK: 111
|
||||
|
||||
%r5 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
|
||||
: (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
|
||||
index, index) -> f64
|
||||
vector.print %r5 : f64
|
||||
// CHECK: 111
|
||||
|
||||
%r7 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
|
||||
: (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
|
||||
index, index) -> f64
|
||||
vector.print %r7 : f64
|
||||
// CHECK: 111
|
||||
|
||||
// Release all resources.
|
||||
dealloc %b_A : memref<128xi64>
|
||||
dealloc %b_B : memref<128xf64>
|
||||
|
|
Loading…
Reference in New Issue