[mlir][linalg] Replace monomorphic contraction ops with polymorphic variants.
* Moves `batch_matmul`, `matmul`, `matvec`, `vecmat`, `dot` to the new mechanism.
* This is not just an NFC change: in addition to using a new code generation mechanism, it also activates symbolic casting, allowing mixed-precision operands and results.
* These definitions were generated from DSL by the tool https://github.com/stellaraccident/mlir-linalgpy/blob/main/mlir_linalg/oplib/core.py (will be upstreamed in a subsequent set of changes).

Reviewed By: nicolasvasilache, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D97719
parent d36a15de1f
commit 6d2fd3d9cd
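Note: with symbolic casting active, the converted ops accept operands whose element types differ from the accumulator/output type; operands are promoted to the output element type before the inner multiply. A minimal usage sketch of `linalg.matmul`, mirroring the i8 -> i32 sign-extension case exercised by the updated test below (function name and shapes are chosen for illustration):

func @matmul_i8_i8_i32(%A : tensor<16x8xi8>, %B : tensor<8x32xi8>,
                       %C : tensor<16x32xi32>) -> tensor<16x32xi32> {
  // The i8 operands are promoted to the i32 accumulator type inside the op
  // body; no dedicated matmul_i8_i8_i32 variant is required.
  %0 = linalg.matmul ins(%A, %B : tensor<16x8xi8>, tensor<8x32xi8>)
                     outs(%C : tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0 : tensor<16x32xi32>
}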
@@ -1,12 +1,12 @@
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: polymorphic_matmul
-  cpp_op_name: PolymorphicMatmulOp
+  name: matmul
+  cpp_op_name: MatmulOp
   doc: |-
-    Type polymorphic matrix multiplication.
+    Performs a matrix multiplication of two 2D inputs.
 
-    This op is presently here to test a new path for generation and will replace
-    the existing 'matmul' op when ready. Do not use.
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
 structured_op: !LinalgStructuredOpConfig
@@ -60,4 +60,249 @@ structured_op: !LinalgStructuredOpConfig
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_matmul
+  cpp_op_name: BatchMatmulOp
+  doc: |-
+    Performs a batched matrix multiplication of two 3D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: B
+    usage: input
+    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: C
+    usage: output
+    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: matvec
+  cpp_op_name: MatvecOp
+  doc: |-
+    Performs a matrix-vector multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: y
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s1)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: x
+    usage: output
+    shape: affine_map<()[s0, s1] -> (s0)>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0)>
+  iterator_types:
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: x
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: x
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: y
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: vecmat
+  cpp_op_name: VecmatOp
+  doc: |-
+    Performs a vector-matrix multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: y
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s1)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s1, s0)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: x
+    usage: output
+    shape: affine_map<()[s0, s1] -> (s0)>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0)>
+  iterator_types:
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: x
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: x
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: y
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: dot
+  cpp_op_name: DotOp
+  doc: |-
+    Performs a dot product of two vectors to a scalar result.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0] -> (s0)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: B
+    usage: input
+    shape: affine_map<()[s0] -> (s0)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: C
+    usage: output
+    shape: affine_map<()[s0] -> ()>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0)[s0] -> (d0)>
+    - affine_map<(d0)[s0] -> (d0)>
+    - affine_map<(d0)[s0] -> ()>
+  iterator_types:
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
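The same symbolic_cast pattern appears in every definition above, so each of these ops can now mix operand and accumulator element types. An illustrative usage sketch (function name, shapes, and type combinations are chosen here for exposition, not taken from the patch):

func @mixed_precision_contractions(%a : tensor<4x8x16xi8>, %b : tensor<4x16x32xi8>,
                                   %acc3 : tensor<4x8x32xi32>,
                                   %m : tensor<8x16xf16>, %v : tensor<16xf16>,
                                   %acc1 : tensor<8xf32>,
                                   %u : tensor<16xf16>, %w : tensor<16xf16>,
                                   %acc0 : tensor<f32>) {
  // batch_matmul: i8 x i8 accumulated into i32.
  %0 = linalg.batch_matmul ins(%a, %b : tensor<4x8x16xi8>, tensor<4x16x32xi8>)
                           outs(%acc3 : tensor<4x8x32xi32>) -> tensor<4x8x32xi32>
  // matvec: f16 x f16 accumulated into f32.
  %1 = linalg.matvec ins(%m, %v : tensor<8x16xf16>, tensor<16xf16>)
                     outs(%acc1 : tensor<8xf32>) -> tensor<8xf32>
  // dot: f16 x f16 reduced into an f32 scalar.
  %2 = linalg.dot ins(%u, %w : tensor<16xf16>, tensor<16xf16>)
                  outs(%acc0 : tensor<f32>) -> tensor<f32>
  return
}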
@@ -1,9 +1,3 @@
-ods_def<MatmulOp>
-implements_interface<LinalgContractionOpInterface> :
-def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
-  C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));
-}
-
 ods_def<MatmulColumnMajorOp>
 implements_interface<LinalgContractionOpInterface> :
 def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
@@ -30,12 +24,6 @@ def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) {
   C(m, n) = std_addi<k>(C(m, n), std_muli(A(m, k), B(k, n)));
 }
 
-ods_def<MatvecOp>
-implements_interface<LinalgContractionOpInterface> :
-def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
-  x(m) = std_addf<n>(x(m), std_mulf(A(m, n), y(n)));
-}
-
 ods_def<MatvecI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
@@ -54,12 +42,6 @@ def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) {
   x(m) = std_addi<n>(x(m), std_muli(A(m, n), y(n)));
 }
 
-ods_def<VecmatOp>
-implements_interface<LinalgContractionOpInterface> :
-def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
-  x(n) = std_addf<m>(x(n), std_mulf(y(m), A(m, n)));
-}
-
 ods_def<VecmatI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
@@ -78,12 +60,6 @@ def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) {
   x(n) = std_addi<m>(x(n), std_muli(y(m), A(m, n)));
 }
 
-ods_def<DotOp>
-implements_interface<LinalgContractionOpInterface> :
-def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
-  C() = std_addf<m>(C(), std_mulf(A(m), B(m)));
-}
-
 ods_def<DotI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
@@ -102,12 +78,6 @@ def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) {
   C() = std_addi<m>(C(), std_muli(A(m), B(m)));
 }
 
-ods_def<BatchMatmulOp>
-implements_interface<LinalgContractionOpInterface> :
-def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
-  C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(b, k, n)));
-}
-
 ods_def<BatchMatmulI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
 
 func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
       outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
@@ -16,7 +16,7 @@ func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>,
 // -----
 
 func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
       outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
@@ -31,7 +31,7 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
 // -----
 // Verifies floating point to integer cast.
 func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
       outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
   return %0: tensor<16x32xi16>
 }
@@ -48,7 +48,7 @@ func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x3
 // -----
 // Verifies sign extension cast.
 func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
       outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
@@ -65,7 +65,7 @@ func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi
 // -----
 // Verifies that different argument types are legal.
 func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
       outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
@@ -82,7 +82,7 @@ func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32x
 // -----
 // Somewhat non-sensical but checks integer truncation cast.
 func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
       outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
   return %0: tensor<16x32xi16>
 }
@@ -99,7 +99,7 @@ func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x3
 // -----
 // Verifies integer to floating point cast.
 func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
       outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
@@ -116,7 +116,7 @@ func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi
 // -----
 // Verifies floating point extension cast.
 func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
       outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
@@ -133,7 +133,7 @@ func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x3
 // -----
 // Verifies floating point truncation.
 func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
       outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
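For reference, a rough hand-written sketch of the generic form that -linalg-generalize-named-ops produces for the sign-extension case above. The cast and arithmetic ops are written with their current arith dialect spelling; at the time of this patch they were emitted as std dialect ops, so the actual CHECK lines differ:

%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                      affine_map<(d0, d1, d2) -> (d2, d1)>,
                                      affine_map<(d0, d1, d2) -> (d0, d1)>],
                     iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B : tensor<16x8xi8>, tensor<8x32xi8>)
    outs(%C : tensor<16x32xi32>) {
^bb0(%a: i8, %b: i8, %c: i32):
  // symbolic_cast with type_var U materializes as a sign extension here.
  %a_ext = arith.extsi %a : i8 to i32
  %b_ext = arith.extsi %b : i8 to i32
  %mul = arith.muli %a_ext, %b_ext : i32
  %sum = arith.addi %c, %mul : i32
  linalg.yield %sum : i32
} -> tensor<16x32xi32>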