[mlir][amx] add a full tile matrix mult example to integration tests

Rationale:
Demonstrates the maximum tile size allowed for the f32 <= bf16 x bf16 op

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D118277
This commit is contained in:
Aart Bik 2022-01-26 11:58:17 -08:00
parent 48a38954c9
commit a5257ae277
1 changed files with 141 additions and 0 deletions

View File

@ -0,0 +1,141 @@
// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
// Note: To run this test, your CPU must support AMX.
// Multiply full size tiles into zero destination.
func @kernel(%arg0: memref<16x32xbf16>,
%arg1: memref<16x32xbf16>,
%arg2: memref<16x16xf32>) {
%0 = arith.constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
%2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16>
%3 = amx.tile_zero : vector<16x16xf32>
%4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, vector<16x16xf32>
return
}
func @entry() -> i32 {
%fu = arith.constant -1.0: f32
%c0 = arith.constant 0: index
%c1 = arith.constant 1: index
%c16 = arith.constant 16: index
%c32 = arith.constant 32: index
// Setup simple test data. Note that bf16 does not seem to work well with the
// tensor type yet, which is why we use vectors and transfer the data into memref.
// TODO: use tensors and bufferization.to_memref instead
%0 = arith.constant dense<[
[ 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
]> : vector<16x32xbf16>
%1 = arith.constant dense<[
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
[ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
]> : vector<16x32xbf16>
// Set up memory.
%a = memref.alloc() : memref<16x32xbf16>
%b = memref.alloc() : memref<16x32xbf16>
%c = memref.alloc() : memref<16x16xf32>
vector.transfer_write %0, %a[%c0, %c0] : vector<16x32xbf16>, memref<16x32xbf16>
vector.transfer_write %1, %b[%c0, %c0] : vector<16x32xbf16>, memref<16x32xbf16>
// Call kernel.
call @kernel(%a, %b, %c) : (memref<16x32xbf16>, memref<16x32xbf16>, memref<16x16xf32>) -> ()
//
// Print and verify the 16x16 result.
//
// CHECK: ( 32.1016, 34.3984, 34.5078, 33.6953, 32.9062, 32.2031, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016 )
// CHECK-NEXT: ( 32.2031, 34.5, 34.6094, 33.7969, 33.0284, 32.3047, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031 )
// CHECK-NEXT: ( 32.2969, 34.5938, 34.7031, 33.8906, 33.1619, 32.3984, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969 )
// CHECK-NEXT: ( 32.3984, 34.6953, 34.8047, 33.9922, 33.2031, 32.5, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984 )
// CHECK-NEXT: ( 32.5, 34.7969, 34.9062, 34.0938, 33.3047, 32.6016, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5 )
// CHECK-NEXT: ( 32.6016, 34.8984, 35.0078, 34.3739, 33.4062, 32.7031, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016 )
// CHECK-NEXT: ( 32.7031, 35, 35.1094, 34.577, 33.5078, 32.8047, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031 )
// CHECK-NEXT: ( 32.7969, 35.0938, 35.2031, 34.3906, 33.6016, 32.8984, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969 )
// CHECK-NEXT: ( 32.7969, 35.0938, 35.2031, 34.3906, 33.6016, 32.8984, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969 )
// CHECK-NEXT: ( 32.7031, 35, 35.4609, 34.2969, 33.5078, 32.8047, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031 )
// CHECK-NEXT: ( 32.6016, 34.8984, 35.3697, 34.1953, 33.4062, 32.7031, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016 )
// CHECK-NEXT: ( 32.5, 34.7969, 34.9062, 34.0938, 33.3047, 32.6016, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5 )
// CHECK-NEXT: ( 32.3984, 34.6953, 34.8047, 33.9922, 33.2031, 32.5, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984 )
// CHECK-NEXT: ( 32.2969, 34.8025, 34.7031, 33.8906, 33.1016, 32.3984, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969 )
// CHECK-NEXT: ( 32.2031, 34.6619, 34.6094, 33.7969, 33.0078, 32.3047, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031 )
// CHECK-NEXT: ( 32.1016, 34.3984, 34.5078, 33.6953, 32.9062, 32.2031, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016 )
//
scf.for %i = %c0 to %c16 step %c1 {
%v = vector.transfer_read %c[%i, %c0], %fu: memref<16x16xf32>, vector<16xf32>
vector.print %v : vector<16xf32>
}
// Release resources.
memref.dealloc %a : memref<16x32xbf16>
memref.dealloc %b : memref<16x32xbf16>
memref.dealloc %c : memref<16x16xf32>
%i0 = arith.constant 0 : i32
return %i0 : i32
}