[mlir] Introduce an intrinsic for llvm.matrix.multiply

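This revision adds the first MLIR LLVM-dialect intrinsic op for
llvm.matrix.multiply.
The op is defined with the more general `LLVM_OneResultOp` for now, together
with a custom `llvmBuilder`, since the goal is to emit the call through the
specific Matrix builders (`llvm::MatrixBuilder`) that @fhahn has created
recently.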

When piped through:
```
opt -O3 -enable-matrix | llc -O3 -march=x86-64 -mcpu=skylake-avx512
```
this has been verified to generate ymm instructions.

Additional function attribute support will be needed to generate proper
zmm instructions but at least things run end to end.

Benchmarking will be provided separately, once it is ready, using the
experimental metaprogramming tool
[ModelBuilder](https://github.com/google/iree/tree/master/experimental/ModelBuilder).
This commit is contained in:
Nicolas Vasilache 2020-03-05 17:27:52 -05:00
parent d0e8abc438
commit cac1ed1f4b
3 changed files with 36 additions and 0 deletions

View File

@ -787,6 +787,27 @@ def LLVM_experimental_vector_reduce_xor : LLVM_VectorReduction<"xor">;
def LLVM_experimental_vector_reduce_v2_fadd : LLVM_VectorReductionV2<"fadd">;
def LLVM_experimental_vector_reduce_v2_fmul : LLVM_VectorReductionV2<"fmul">;
//
// LLVM Matrix operations.
//
/// LLVM dialect op `llvm.intr.matrix.multiply`, translated to a call to the
/// `llvm.matrix.multiply` intrinsic through the LLVM MatrixBuilder
/// ("Create a llvm.matrix.multiply call, multiplying matrices LHS and RHS").
///
/// Operands:
///   $lhs, $rhs  - the two matrices to multiply, as LLVM dialect values.
/// Attributes (32-bit integers, forwarded in this order as the three dimension
/// arguments of `CreateMatrixMultiply`):
///   $lhs_rows, $lhs_columns, $rhs_rows
/// Result:
///   $res        - the single result of the op, bound to the value returned by
///                 `CreateMatrixMultiply`.
///
/// NOTE(review): for an (M x K) * (K x N) product the RHS row count is already
/// fixed by `lhs_columns`, so the third dimension passed here is presumably
/// the RHS *column* count and `rhs_rows` looks misnamed. Confirm against the
/// `llvm::MatrixBuilder::CreateMatrixMultiply` signature before renaming, as
/// the attribute name is part of the op's textual interface.
def LLVM_MatrixMultiplyOp
: LLVM_OneResultOp<"intr.matrix.multiply">,
Arguments<(
ins LLVM_Type:$lhs, LLVM_Type:$rhs,
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)> {
// Translation snippet: wraps the current IR builder in a MatrixBuilder and
// emits the multiply, reading each dimension attribute via getZExtValue().
string llvmBuilder = [{
llvm::MatrixBuilder<decltype(builder)> mb(builder);
$res = mb.CreateMatrixMultiply(
$lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(),
$rhs_rows.getZExtValue());
}];
// Textual form:
//   llvm.intr.matrix.multiply %lhs, %rhs {attrs} : (lhs-type, rhs-type) -> res-type
// The dimension attributes are carried in the attribute dictionary.
let assemblyFormat = "$lhs `,` $rhs attr-dict "
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
}
//
// Atomic operations.
//

View File

@ -23,6 +23,7 @@
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/Value.h"
namespace mlir {

View File

@ -130,6 +130,19 @@ llvm.func @vector_reductions(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">, %a
llvm.return
}
// Verifies that llvm.intr.matrix.multiply is translated to a call to the
// llvm.matrix.multiply intrinsic with the three dimension attributes passed
// as trailing i32 arguments.
// CHECK-LABEL: @matrix_intrinsics
// Operand shapes: %A is 4x16 (64 elements), %B is 16x3 (48 elements).
llvm.func @matrix_intrinsics(%A: !llvm<"<64 x float>">, %B: !llvm<"<48 x float>">)
// Result shape: 4x3 (12 elements).
-> !llvm<"<12 x float>">
{
// The attributes are expected to appear in the call in the order
// lhs_rows, lhs_columns, rhs_rows, i.e. 4, 16, 3.
// NOTE(review): %B is 16x3, so the value 3 given as `rhs_rows` is its column
// count; the attribute name looks misleading - confirm and consider renaming.
// CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3)
%C = llvm.intr.matrix.multiply %A, %B
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_rows = 3: i32} :
(!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>">
llvm.return %C: !llvm<"<12 x float>">
}
// Check that intrinsics are declared with appropriate types.
// CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
// CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@ -153,3 +166,4 @@ llvm.func @vector_reductions(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">, %a
// CHECK-DAG: declare float @llvm.cos.f32(float)
// CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0
// CHECK-DAG: declare float @llvm.copysign.f32(float, float)
// CHECK-DAG: declare <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float>, <48 x float>, i32 immarg, i32 immarg, i32 immarg)