[mlir][AVX512] Add while loop-based sparse vector-vector dot product variants.

Differential Revision: https://reviews.llvm.org/D98480
This commit is contained in:
Matthias Springer 2021-03-15 16:52:40 +09:00
parent 0ddd537605
commit 581672be04
1 changed files with 192 additions and 1 deletions

View File

@ -9,10 +9,17 @@
// Each sparse vector is represented by an index memref (A or C) and by a data
// memref (B or D), containing M or N elements.
// There are two implementations:
// There are four different implementations:
// * `memref_dot_simple`: Simple O(N*M) implementation with two for loops.
// * `memref_dot_optimized`: An optimized O(N*M) version of the previous
// implementation, where the second for loop skips over some elements.
// * `memref_dot_while`: An optimized O(N+M) implementation that utilizes a
// single while loop, coiterating over both vectors.
// * `memref_dot_while_branchless`: An optimized O(N+M) implementation that
// consists of a single while loop and has no branches within the loop.
// Output of llvm-mca:
// https://gist.github.com/matthias-springer/72e7ee1b3c467e7aefb6e1fd862e4841
#contraction_accesses = [
affine_map<(i) -> (i)>,
@ -224,6 +231,166 @@ func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
return %r0 : f64
// Vector dot product with a while loop. Both sparse vectors are co-iterated in
// segments of 8 elements; `a` and `b` are the current positions in the first
// and the second vector. A and C are the index memrefs, B and D are the data
// memrefs. Implemented as follows:
// r = 0.0, a = 0, b = 0
// while (a < M && b < N) {
// segA = A[a:a+8], segB = C[b:b+8]
// if (segB[7] < segA[0]) b += 8
// elif (segA[7] < segB[0]) a += 8
// else {
// r += vector_dot(segA, B[a:a+8], segB, D[b:b+8])
// if (segA[7] < segB[7]) a += 8
// elif (segB[7] < segA[7]) b += 8
// else a += 8, b += 8
// }
// }
//
// Arguments:
// * %m_A, %m_B: index and data memref of the first vector (read at `a`).
// * %m_C, %m_D: index and data memref of the second vector (read at `b`).
// * %M, %N: loop bounds for `a` and `b`.
// Returns the accumulated sum %r0.
func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
%m_C : memref<?xi64>, %m_D : memref<?xf64>,
%M : index, %N : index)
-> f64 {
// Helper constants for loops.
%c0 = constant 0 : index
// Lane positions of the first and the last element of an 8-wide segment.
%i0 = constant 0 : i32
%i7 = constant 7 : i32
// Segment width; also the step by which `a` and `b` advance.
%c8 = constant 8 : index
// Initial value of the sum and padding value for out-of-bounds data reads.
%data_zero = constant 0.0 : f64
// Padding for out-of-bounds index reads: 2^63 - 1, the largest signed i64.
// With the signed comparisons below, a padded lane is never smaller than a
// real index.
%index_padding = constant 9223372036854775807 : i64
// Loop state: (running sum, position a, position b).
%r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
: (f64, index, index) -> (f64, index, index) {
// Continue only while both positions are below their bounds.
%cond_i = cmpi "slt", %a1, %M : index
%cond_j = cmpi "slt", %b1, %N : index
%cond = and %cond_i, %cond_j : i1
scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
} do {
^bb0(%r1 : f64, %a1 : index, %b1 : index):
// v_A, v_B, seg*_* could be part of the loop state to avoid a few
// redundant reads.
// Load the current index segment of each vector.
%v_A = vector.transfer_read %m_A[%a1], %index_padding
: memref<?xi64>, vector<8xi64>
%v_C = vector.transfer_read %m_C[%b1], %index_padding
: memref<?xi64>, vector<8xi64>
// First and last index of each segment. Note that "segB" refers to the
// index segment of the second vector, i.e., the one read from %m_C.
%segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64>
%segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
%segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64>
%segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
// Case 1: segment B ends before segment A starts. Skip segment B.
%seg1_done = cmpi "slt", %segB_max, %segA_min : i64
%r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) {
%b3 = addi %b1, %c8 : index
scf.yield %r1, %a1, %b3 : f64, index, index
} else {
// Case 2: segment A ends before segment B starts. Skip segment A.
%seg0_done = cmpi "slt", %segA_max, %segB_min : i64
%r4, %a4, %b4 = scf.if %seg0_done -> (f64, index, index) {
%a5 = addi %a1, %c8 : index
scf.yield %r1, %a5, %b1 : f64, index, index
} else {
// Case 3: the index ranges overlap. The data segments are loaded only
// on this path.
%v_B = vector.transfer_read %m_B[%a1], %data_zero
: memref<?xf64>, vector<8xf64>
%v_D = vector.transfer_read %m_D[%b1], %data_zero
: memref<?xf64>, vector<8xf64>
// @vector_dot is defined elsewhere in this file; presumably it returns
// the dot product of the two 8-element sparse segments.
%subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
: (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
-> f64
%r6 = addf %r1, %subresult : f64
// Advance the segment with the smaller last index; advance both if the
// last indices are equal.
%incr_a = cmpi "slt", %segA_max, %segB_max : i64
%a6, %b6 = scf.if %incr_a -> (index, index) {
%a7 = addi %a1, %c8 : index
scf.yield %a7, %b1 : index, index
} else {
%incr_b = cmpi "slt", %segB_max, %segA_max : i64
%a8, %b8 = scf.if %incr_b -> (index, index) {
%b9 = addi %b1, %c8 : index
scf.yield %a1, %b9 : index, index
} else {
%a10 = addi %a1, %c8 : index
%b10 = addi %b1, %c8 : index
scf.yield %a10, %b10 : index, index
scf.yield %a8, %b8 : index, index
scf.yield %r6, %a6, %b6 : f64, index, index
scf.yield %r4, %a4, %b4 : f64, index, index
// Next loop state: every path above advanced at least one of a and b.
scf.yield %r2, %a2, %b2 : f64, index, index
// Only the accumulated sum is returned; the final positions are unused.
return %r0 : f64
// Vector dot product with a while loop that has no branches (apart from the
// while loop itself). A and C are the index memrefs, B and D are the data
// memrefs. Implemented as follows:
// r = 0.0, a = 0, b = 0
// while (a < M && b < N) {
// segA = A[a:a+8], segB = C[b:b+8]
// r += vector_dot(segA, B[a:a+8], segB, D[b:b+8])
// a += (segA[7] <= segB[7]) * 8
// b += (segB[7] <= segA[7]) * 8
// }
//
// Arguments:
// * %m_A, %m_B: index and data memref of the first vector (read at `a`).
// * %m_C, %m_D: index and data memref of the second vector (read at `b`).
// * %M, %N: loop bounds for `a` and `b`.
// Returns the accumulated sum %r0.
func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64>,
%m_C : memref<?xi64>, %m_D : memref<?xf64>,
%M : index, %N : index)
-> f64 {
// Helper constants for loops.
%c0 = constant 0 : index
// Lane position of the last element of an 8-wide segment.
%i7 = constant 7 : i32
// Segment width; also the step by which `a` and `b` advance.
%c8 = constant 8 : index
// Initial value of the sum and padding value for out-of-bounds data reads.
%data_zero = constant 0.0 : f64
// Padding for out-of-bounds index reads: 2^63 - 1, the largest signed i64.
%index_padding = constant 9223372036854775807 : i64
// Loop state: (running sum, position a, position b).
%r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0)
: (f64, index, index) -> (f64, index, index) {
// Continue only while both positions are below their bounds.
%cond_i = cmpi "slt", %a1, %M : index
%cond_j = cmpi "slt", %b1, %N : index
%cond = and %cond_i, %cond_j : i1
scf.condition(%cond) %r1, %a1, %b1 : f64, index, index
} do {
^bb0(%r1 : f64, %a1 : index, %b1 : index):
// v_A, v_B, seg*_* could be part of the loop state to avoid a few
// redundant reads.
// Load the index and data segments of both vectors unconditionally.
%v_A = vector.transfer_read %m_A[%a1], %index_padding
: memref<?xi64>, vector<8xi64>
%v_B = vector.transfer_read %m_B[%a1], %data_zero
: memref<?xf64>, vector<8xf64>
%v_C = vector.transfer_read %m_C[%b1], %index_padding
: memref<?xi64>, vector<8xi64>
%v_D = vector.transfer_read %m_D[%b1], %data_zero
: memref<?xf64>, vector<8xf64>
// @vector_dot is called in every iteration, including for segments whose
// index ranges do not overlap. It is defined elsewhere in this file;
// presumably it yields 0 when no indices match (TODO confirm).
%subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D)
: (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>)
-> f64
%r2 = addf %r1, %subresult : f64
// Last index of each segment. "segB" refers to the index segment of the
// second vector, i.e., the one read from %m_C.
%segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64>
%segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64>
// a += (segA_max <= segB_max) * 8: the i1 comparison result is widened to
// 0 or 1, cast to index and scaled by the segment width.
%cond_a = cmpi "sle", %segA_max, %segB_max : i64
%cond_a_i64 = zexti %cond_a : i1 to i64
%cond_a_idx = index_cast %cond_a_i64 : i64 to index
%incr_a = muli %cond_a_idx, %c8 : index
%a2 = addi %a1, %incr_a : index
// b += (segB_max <= segA_max) * 8. At least one of the two "sle"
// comparisons holds, so every iteration advances a, b, or both.
%cond_b = cmpi "sle", %segB_max, %segA_max : i64
%cond_b_i64 = zexti %cond_b : i1 to i64
%cond_b_idx = index_cast %cond_b_i64 : i64 to index
%incr_b = muli %cond_b_idx, %c8 : index
%b2 = addi %b1, %incr_b : index
scf.yield %r2, %a2, %b2 : f64, index, index
// Only the accumulated sum is returned; the final positions are unused.
return %r0 : f64
func @entry() -> i32 {
// Initialize large buffers that can be used for multiple test cases of
// different sizes.
@ -256,6 +423,18 @@ func @entry() -> i32 {
vector.print %r1 : f64
// CHECK: 86
%r2 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
: (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
index, index) -> f64
vector.print %r2 : f64
// CHECK: 86
%r6 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M1, %N1)
: (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
index, index) -> f64
vector.print %r6 : f64
// CHECK: 86
// --- Test case 2 ---.
// M and N must be a multiple of 8 if smaller than 128.
// (Because padding kicks in only for out-of-bounds accesses.)
@ -275,6 +454,18 @@ func @entry() -> i32 {
vector.print %r4 : f64
// CHECK: 111
%r5 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
: (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
index, index) -> f64
vector.print %r5 : f64
// CHECK: 111
%r7 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M2, %N2)
: (memref<?xi64>, memref<?xf64>, memref<?xi64>, memref<?xf64>,
index, index) -> f64
vector.print %r7 : f64
// CHECK: 111
// Release all resources.
dealloc %b_A : memref<128xi64>
dealloc %b_B : memref<128xf64>