forked from OSchip/llvm-project
[mlir] Introduce an intrinsic for llvm.matrix.multiply
This revision adds the first intrinsic for llvm.matrix.multiply. This uses the more general `LLVM_OneResultOp` for now since the goal is to use the specific Matrix builders that @fhahn has created recently. When piped through: ``` opt -O3 -enable-matrix | llc -O3 -march=x86-64 -mcpu=skylake-avx512 ``` this has been verified to generate ymm instructions. Additional function attribute support will be needed to generate proper zmm instructions but at least things run end to end. Benchmarking will be provided separately with the experimental metaprogramming [ModelBuilder](https://github.com/google/iree/tree/master/experimental/ModelBuilder) tool when ready.
This commit is contained in:
parent
d0e8abc438
commit
cac1ed1f4b
|
@ -787,6 +787,27 @@ def LLVM_experimental_vector_reduce_xor : LLVM_VectorReduction<"xor">;
|
|||
def LLVM_experimental_vector_reduce_v2_fadd : LLVM_VectorReductionV2<"fadd">;
|
||||
def LLVM_experimental_vector_reduce_v2_fmul : LLVM_VectorReductionV2<"fmul">;
|
||||
|
||||
//
|
||||
// LLVM Matrix operations.
|
||||
//
|
||||
|
||||
/// As specified in the LLVM MatrixBuilder:
|
||||
/// Create a llvm.matrix.multiply call, multiplying matrices LHS and RHS.
|
||||
def LLVM_MatrixMultiplyOp
|
||||
: LLVM_OneResultOp<"intr.matrix.multiply">,
|
||||
Arguments<(
|
||||
ins LLVM_Type:$lhs, LLVM_Type:$rhs,
|
||||
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)> {
|
||||
string llvmBuilder = [{
|
||||
llvm::MatrixBuilder<decltype(builder)> mb(builder);
|
||||
$res = mb.CreateMatrixMultiply(
|
||||
$lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(),
|
||||
$rhs_rows.getZExtValue());
|
||||
}];
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict "
|
||||
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic operations.
|
||||
//
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/MatrixBuilder.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -130,6 +130,19 @@ llvm.func @vector_reductions(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">, %a
|
|||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @matrix_intrinsics
|
||||
// 4x16 16x3
|
||||
llvm.func @matrix_intrinsics(%A: !llvm<"<64 x float>">, %B: !llvm<"<48 x float>">)
|
||||
// 4x3
|
||||
-> !llvm<"<12 x float>">
|
||||
{
|
||||
// CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3)
|
||||
%C = llvm.intr.matrix.multiply %A, %B
|
||||
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_rows = 3: i32} :
|
||||
(!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>">
|
||||
llvm.return %C: !llvm<"<12 x float>">
|
||||
}
|
||||
|
||||
// Check that intrinsics are declared with appropriate types.
|
||||
// CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
|
||||
// CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
|
||||
|
@ -153,3 +166,4 @@ llvm.func @vector_reductions(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">, %a
|
|||
// CHECK-DAG: declare float @llvm.cos.f32(float)
|
||||
// CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0
|
||||
// CHECK-DAG: declare float @llvm.copysign.f32(float, float)
|
||||
// CHECK-DAG: declare <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float>, <48 x float>, i32 immarg, i32 immarg, i32 immarg)
|
||||
|
|
Loading…
Reference in New Issue