[MLIR][Linalg] introduce batch-reduce GEMM

commit 3718082e2b (parent 21a9abc1ce)

The batch-reduce GEMM kernel multiplies a sequence of input tensor blocks (which form a batch), and the partial multiplication results are reduced into a single output tensor block. See https://ieeexplore.ieee.org/document/9139809 for more details.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D134163
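In reference semantics the op computes C(m, n) += sum over b, k of A(b, m, k) * B(b, k, n): each block of the batch contributes a partial matmul, and the partials are reduced into one 2D accumulator. A minimal NumPy sketch of that contract (illustration only, not part of the commit; shapes match the generalization test below):

import numpy as np

# Batch-reduce GEMM reference: partial products over the batch
# dimension are summed into a single 2D accumulator.
def batch_reduce_gemm_ref(A, B, C):
    # A: (batch, M, K), B: (batch, K, N), C: (M, N)
    return C + np.einsum("bmk,bkn->mn", A, B)

A = np.random.rand(7, 8, 9).astype(np.float32)
B = np.random.rand(7, 9, 8).astype(np.float32)
C = np.zeros((8, 8), dtype=np.float32)
out = batch_reduce_gemm_ref(A, B, C)
assert out.shape == (8, 8)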
@@ -653,6 +653,76 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: BZp
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_reduce_matmul
+  cpp_class_name: BatchReduceMatmulOp
+  doc: |-
+    Performs a batch-reduce matrix multiplication of two 3D inputs.
+    The partial multiplication results are reduced into a 2D output.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
+  iterator_types:
+  - reduction
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: type
+            fn_name: cast_signed
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: mul
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: B
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec
   cpp_class_name: MatvecOp
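Reading the YAML above: d0 is the batch dimension and d3 the contraction dimension (both reduction iterators), while d1 and d2 index the 2D output (parallel). The three static indexing maps project (d0, d1, d2, d3) onto A, B, and C respectively, so the op denotes the loop nest sketched below in plain Python (an illustration of the semantics, not code from the commit):

# Loop nest implied by the indexing maps:
#   A -> (d0, d1, d3), B -> (d0, d3, d2), C -> (d1, d2)
# with iterator types (reduction, parallel, parallel, reduction).
def batch_reduce_matmul_loops(A, B, C, batch, M, N, K):
    for d0 in range(batch):          # reduction over the batch
        for d1 in range(M):          # parallel
            for d2 in range(N):      # parallel
                for d3 in range(K):  # reduction over the contraction dim
                    # add(C, cast(mul(A, cast(B)))) from the assignments
                    # above, with the casts elided for a uniform element type.
                    C[d1][d2] += A[d0][d1][d3] * B[d0][d3][d2]
    return C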
@@ -150,6 +150,20 @@ def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
                               TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
                                   U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
+@linalg_structured_op
+def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
+                        B=TensorDef(T2, Batch, S.K, S.N),
+                        C=TensorDef(U, S.M, S.N, output=True)):
+  """Performs a batch-reduce matrix multiplication of two 3D inputs.
+  The partial multiplication results are reduced into a 2D output.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.b, D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k] * TypeFn.cast_signed(
+      U, B[D.b, D.k, D.n]))
 
 @linalg_structured_op
 def matvec(A=TensorDef(T1, S.M, S.N),
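Since the Python bindings generate one wrapper per OpDSL definition, the new op should become constructible from Python roughly as follows. This is a hedged sketch: the brgemm function and the exact wrapper signature follow the pattern of the existing linalg.matmul wrapper and are assumptions, not part of this commit:

from mlir.ir import (Context, Location, Module, InsertionPoint,
                     RankedTensorType, F32Type)
from mlir.dialects import func, linalg

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    lhs_t = RankedTensorType.get((8, 128, 256), f32)
    rhs_t = RankedTensorType.get((8, 256, 512), f32)
    acc_t = RankedTensorType.get((128, 512), f32)
    with InsertionPoint(module.body):
        @func.FuncOp.from_py_func(lhs_t, rhs_t, acc_t)
        def brgemm(lhs, rhs, acc):
            # Wrapper generated from the OpDSL definition above;
            # emits a linalg.batch_reduce_matmul on tensors.
            return linalg.batch_reduce_matmul(lhs, rhs, outs=[acc])
    print(module)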
@@ -248,3 +248,27 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi
 // CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
 // CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
 // CHECK: linalg.yield %[[ADD]] : f32
+
+// -----
+
+func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
+  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
+                             outs(%out: memref<8x8xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK: @batch_reduce_gemm
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xf32>, memref<7x9x8xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK: %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
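The generalization this test checks can also be reproduced from the Python bindings, sketched below. The pipeline string and the PassManager.run signature vary slightly across LLVM releases, so treat the details as assumptions rather than the commit's own code:

from mlir.ir import Context, Module
from mlir.passmanager import PassManager

ASM = """
func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
                             outs(%out: memref<8x8xf32>)
  return
}
"""

with Context():
    module = Module.parse(ASM)
    # Rewrite the named op into an equivalent linalg.generic carrying the
    # indexing maps and iterator types checked above.
    pm = PassManager.parse("builtin.module(func.func(linalg-generalize-named-ops))")
    pm.run(module)
    print(module)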
@@ -794,3 +794,23 @@ func.func @conv_interface_wrong_num_operands(
   }) {dilations = dense<1> : tensor<2xi64>, linalg.memoized_indexing_maps = [#map0, #map1, #map2], operand_segment_sizes = array<i32: 2, 1>, strides = dense<1> : tensor<2xi64>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   return %0 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
+func.func @batch_reduce_matmul(%arg0: tensor<8x128x256xf32>, %arg1: tensor<8x256x512xf32>, %arg2: tensor<128x512xf32>) -> tensor<128x512xf32> {
+  // CHECK: %{{.+}} = linalg.batch_reduce_matmul
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<8x128x256xf32>, tensor<8x256x512xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<128x512xf32>) -> tensor<128x512xf32>
+  %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<8x128x256xf32>, tensor<8x256x512xf32>) outs(%arg2: tensor<128x512xf32>) -> tensor<128x512xf32>
+  return %0: tensor<128x512xf32>
+}
+
+// -----
+
+func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // CHECK: linalg.batch_reduce_matmul
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+  linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}
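What this roundtrip test asserts can also be checked programmatically: the op must print and re-parse to the same form, on both tensors and memrefs. A small sketch with the Python bindings (illustration only, not part of the commit):

from mlir.ir import Context, Module

ASM = """
func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
  linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
  return
}
"""

with Context():
    m = Module.parse(ASM)                        # parsing also runs the verifier
    assert str(m) == str(Module.parse(str(m)))   # print/re-parse is stable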