[mlir][linalg] Replace monomorphic contraction ops with polymorphic variants.

* Moves `batch_matmul`, `matmul`, `matvec`, `vecmat`, `dot` to the new mechanism.
* This is not just an NFC change: in addition to using a new code generation mechanism, it also activates symbolic casting, allowing mixed-precision operands and results (see the sketch after this list).
* These definitions were generated from a DSL by the tool https://github.com/stellaraccident/mlir-linalgpy/blob/main/mlir_linalg/oplib/core.py (will be upstreamed in a subsequent set of changes).
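A minimal sketch of what the symbolic casting enables; the function name is illustrative, while the shapes and the i8/i32 combination mirror the new test cases below:

    func @matmul_i8_i8_i32(%A : tensor<16x8xi8>, %B : tensor<8x32xi8>,
                           %C : tensor<16x32xi32>) -> tensor<16x32xi32> {
      // The i8 operands are implicitly promoted (sign-extended) to the i32
      // accumulator type, which previously required a dedicated monomorphic op.
      %0 = linalg.matmul ins(%A, %B : tensor<16x8xi8>, tensor<8x32xi8>)
                         outs(%C : tensor<16x32xi32>) -> tensor<16x32xi32>
      return %0 : tensor<16x32xi32>
    }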

Reviewed By: nicolasvasilache, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D97719
Stella Laurenzo 2021-03-01 21:19:39 -08:00
parent d36a15de1f
commit 6d2fd3d9cd
3 changed files with 259 additions and 44 deletions


@@ -1,12 +1,12 @@
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: polymorphic_matmul
cpp_op_name: PolymorphicMatmulOp
name: matmul
cpp_op_name: MatmulOp
doc: |-
Type polymorphic matrix multiplication.
Performs a matrix multiplication of two 2D inputs.
This op is presently here to test a new path for generation and will replace
the existing 'matmul' op when ready. Do not use.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
@@ -60,4 +60,249 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: batch_matmul
cpp_op_name: BatchMatmulOp
doc: |-
Performs a batched matrix multiplication of two 3D inputs.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !<LinalgTensorDef>
name: A
usage: input
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
element_type_var: T1
- !<LinalgTensorDef>
name: B
usage: input
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
element_type_var: T2
- !<LinalgTensorDef>
name: C
usage: output
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
element_type_var: U
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
iterator_types:
- parallel
- parallel
- parallel
- reduction
assignments:
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
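A usage sketch of the op this definition yields (the function name, shapes, and i8/i32 mix are illustrative, not from the commit):

    func @batch_matmul_example(%A : tensor<4x16x8xi8>, %B : tensor<4x8x32xi8>,
                               %C : tensor<4x16x32xi32>) -> tensor<4x16x32xi32> {
      // Iteration space: d0 = batch, d1 = M, d2 = N, d3 = K (reduction).
      // The i8 operands are promoted to the i32 accumulator type U.
      %0 = linalg.batch_matmul ins(%A, %B : tensor<4x16x8xi8>, tensor<4x8x32xi8>)
                               outs(%C : tensor<4x16x32xi32>) -> tensor<4x16x32xi32>
      return %0 : tensor<4x16x32xi32>
    }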
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matvec
cpp_op_name: MatvecOp
doc: |-
Performs a matrix-vector multiplication.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !<LinalgTensorDef>
name: A
usage: input
shape: affine_map<()[s0, s1] -> (s0, s1)>
element_type_var: T1
- !<LinalgTensorDef>
name: y
usage: input
shape: affine_map<()[s0, s1] -> (s1)>
element_type_var: T2
- !<LinalgTensorDef>
name: x
usage: output
shape: affine_map<()[s0, s1] -> (s0)>
element_type_var: U
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
- affine_map<(d0, d1)[s0, s1] -> (d1)>
- affine_map<(d0, d1)[s0, s1] -> (d0)>
iterator_types:
- parallel
- reduction
assignments:
- !ScalarAssign
arg: x
value: !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: y
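A usage sketch for the polymorphic matvec (names and shapes are illustrative; f16 inputs with an f32 accumulator exercise the floating point extension cast):

    func @matvec_example(%A : tensor<16x8xf16>, %y : tensor<8xf16>,
                         %x : tensor<16xf32>) -> tensor<16xf32> {
      // The f16 operands are extended to the f32 accumulator before the
      // multiply-add: x(d0) += cast(A(d0, d1)) * cast(y(d1)).
      %0 = linalg.matvec ins(%A, %y : tensor<16x8xf16>, tensor<8xf16>)
                         outs(%x : tensor<16xf32>) -> tensor<16xf32>
      return %0 : tensor<16xf32>
    }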
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: vecmat
cpp_op_name: VecmatOp
doc: |-
Performs a vector-matrix multiplication.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !<LinalgTensorDef>
name: y
usage: input
shape: affine_map<()[s0, s1] -> (s1)>
element_type_var: T1
- !<LinalgTensorDef>
name: A
usage: input
shape: affine_map<()[s0, s1] -> (s1, s0)>
element_type_var: T2
- !<LinalgTensorDef>
name: x
usage: output
shape: affine_map<()[s0, s1] -> (s0)>
element_type_var: U
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1] -> (d1)>
- affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
- affine_map<(d0, d1)[s0, s1] -> (d0)>
iterator_types:
- parallel
- reduction
assignments:
- !ScalarAssign
arg: x
value: !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: x
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: y
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
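A usage sketch for vecmat (names and shapes are illustrative); note the vector is the left operand and the shared dimension is the matrix's first:

    func @vecmat_example(%y : tensor<16xi8>, %A : tensor<16x32xi8>,
                         %x : tensor<32xi32>) -> tensor<32xi32> {
      // x(d0) += cast(y(d1)) * cast(A(d1, d0)), reducing over d1.
      %0 = linalg.vecmat ins(%y, %A : tensor<16xi8>, tensor<16x32xi8>)
                         outs(%x : tensor<32xi32>) -> tensor<32xi32>
      return %0 : tensor<32xi32>
    }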
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: dot
cpp_op_name: DotOp
doc: |-
Performs a dot product of two vectors, producing a scalar result.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
implements:
- LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !<LinalgTensorDef>
name: A
usage: input
shape: affine_map<()[s0] -> (s0)>
element_type_var: T1
- !<LinalgTensorDef>
name: B
usage: input
shape: affine_map<()[s0] -> (s0)>
element_type_var: T2
- !<LinalgTensorDef>
name: C
usage: output
shape: affine_map<()[s0] -> ()>
element_type_var: U
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0)[s0] -> (d0)>
- affine_map<(d0)[s0] -> (d0)>
- affine_map<(d0)[s0] -> ()>
iterator_types:
- reduction
assignments:
- !ScalarAssign
arg: C
value: !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: C
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: A
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: B
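A usage sketch for dot (names and shapes are illustrative); the rank-0 output tensor holds the scalar accumulator:

    func @dot_example(%A : tensor<16xi8>, %B : tensor<16xi8>,
                      %C : tensor<i32>) -> tensor<i32> {
      // C() += cast(A(d0)) * cast(B(d0)), reducing over the single dimension.
      %0 = linalg.dot ins(%A, %B : tensor<16xi8>, tensor<16xi8>)
                      outs(%C : tensor<i32>) -> tensor<i32>
      return %0 : tensor<i32>
    }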


@@ -1,9 +1,3 @@
ods_def<MatmulOp>
implements_interface<LinalgContractionOpInterface> :
def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));
}
ods_def<MatmulColumnMajorOp>
implements_interface<LinalgContractionOpInterface> :
def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
@@ -30,12 +24,6 @@ def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) {
C(m, n) = std_addi<k>(C(m, n), std_muli(A(m, k), B(k, n)));
}
ods_def<MatvecOp>
implements_interface<LinalgContractionOpInterface> :
def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
x(m) = std_addf<n>(x(m), std_mulf(A(m, n), y(n)));
}
ods_def<MatvecI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
@@ -54,12 +42,6 @@ def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) {
x(m) = std_addi<n>(x(m), std_muli(A(m, n), y(n)));
}
ods_def<VecmatOp>
implements_interface<LinalgContractionOpInterface> :
def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
x(n) = std_addf<m>(x(n), std_mulf(y(m), A(m, n)));
}
ods_def<VecmatI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
@@ -78,12 +60,6 @@ def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) {
x(n) = std_addi<m>(x(n), std_muli(y(m), A(m, n)));
}
ods_def<DotOp>
implements_interface<LinalgContractionOpInterface> :
def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
C() = std_addf<m>(C(), std_mulf(A(m), B(m)));
}
ods_def<DotI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
@@ -102,12 +78,6 @@ def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) {
C() = std_addi<m>(C(), std_muli(A(m), B(m)));
}
ods_def<BatchMatmulOp>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(b, k, n)));
}
ods_def<BatchMatmulI8I8I32Op>
implements_interface<LinalgContractionOpInterface> :
def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {


@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
@@ -16,7 +16,7 @@ func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>,
// -----
func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
@@ -31,7 +31,7 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
// -----
// Verifies floating point to integer cast.
func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
return %0: tensor<16x32xi16>
}
@@ -48,7 +48,7 @@ func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x3
// -----
// Verifies sign extension cast.
func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
@@ -65,7 +65,7 @@ func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi
// -----
// Verifies that different argument types are legal.
func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
return %0: tensor<16x32xi32>
}
@@ -82,7 +82,7 @@ func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32x
// -----
// Somewhat nonsensical, but checks the integer truncation cast.
func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
return %0: tensor<16x32xi16>
}
@@ -99,7 +99,7 @@ func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x3
// -----
// Verifies integer to floating point cast.
func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
%0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
@@ -116,7 +116,7 @@ func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi
// -----
// Verifies floating point extension cast.
func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
@@ -133,7 +133,7 @@ func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x3
// -----
// Verifies floating point truncation.
func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
%0 = linalg.matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
}
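For reference, a hedged sketch of roughly what the -linalg-generalize-named-ops pass produces for the sign-extension case above, using the std-dialect casts of the time (sexti, muli, addi); the map names and exact structure are illustrative:

    #map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
    #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
    #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
    func @generalized_matmul_i8_i8_i32(%A : tensor<16x8xi8>, %B : tensor<8x32xi8>,
                                       %C : tensor<16x32xi32>) -> tensor<16x32xi32> {
      %0 = linalg.generic {indexing_maps = [#map0, #map1, #map2],
                           iterator_types = ["parallel", "parallel", "reduction"]}
          ins(%A, %B : tensor<16x8xi8>, tensor<8x32xi8>)
          outs(%C : tensor<16x32xi32>) {
      ^bb0(%a: i8, %b: i8, %c: i32):
        // symbolic_cast to the accumulator type U = i32 lowers to a sign extension.
        %a_ext = sexti %a : i8 to i32
        %b_ext = sexti %b : i8 to i32
        %mul = muli %a_ext, %b_ext : i32
        %add = addi %c, %mul : i32
        linalg.yield %add : i32
      } -> tensor<16x32xi32>
      return %0 : tensor<16x32xi32>
    }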