forked from OSchip/llvm-project
[mlir][sparse] added linalg.dot to sparse kernel collection
Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D121315
This commit is contained in:
parent
fc9e07873f
commit
52fb4f53c2
|
@ -3,6 +3,8 @@
|
|||
// RUN: --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
|
||||
// RUN: --sparsification | FileCheck %s
|
||||
|
||||
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
|
||||
|
||||
#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
|
||||
|
||||
// CHECK-LABEL: func @matmul1(
|
||||
|
@ -255,3 +257,64 @@ func @quantized_matmul(%input1: tensor<5x3xi8>,
|
|||
outs(%output : tensor<5x6xi64>) -> tensor<5x6xi64>
|
||||
return %0: tensor<5x6xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sparse_dot(
|
||||
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0:.*]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
|
||||
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
|
||||
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
|
||||
// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1:.*]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
|
||||
// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
|
||||
// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
|
||||
// CHECK-DAG: %[[VAL_11:.*]] = memref.alloc() : memref<f32>
|
||||
// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2:.*]] : memref<f32>
|
||||
// CHECK-DAG: memref.copy %[[VAL_12]], %[[VAL_11]] : memref<f32> to memref<f32>
|
||||
// CHECK-DAG: %[[VAL_13:.*]] = memref.load %[[VAL_11]][] : memref<f32>
|
||||
// CHECK-DAG: %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
|
||||
// CHECK-DAG: %[[VAL_15:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
|
||||
// CHECK-DAG: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
|
||||
// CHECK-DAG: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
|
||||
// CHECK: %[[VAL_18:.*]]:3 = scf.while (%[[VAL_19:.*]] = %[[VAL_14]], %[[VAL_20:.*]] = %[[VAL_16]], %[[VAL_21:.*]] = %[[VAL_13]]) : (index, index, f32) -> (index, index, f32) {
|
||||
// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_19]], %[[VAL_15]] : index
|
||||
// CHECK: %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_17]] : index
|
||||
// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_22]], %[[VAL_23]] : i1
|
||||
// CHECK: scf.condition(%[[VAL_24]]) %[[VAL_19]], %[[VAL_20]], %[[VAL_21]] : index, index, f32
|
||||
// CHECK: } do {
|
||||
// CHECK: ^bb0(%[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index, %[[VAL_27:.*]]: f32):
|
||||
// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_25]]] : memref<?xindex>
|
||||
// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xindex>
|
||||
// CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
|
||||
// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
|
||||
// CHECK: %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
|
||||
// CHECK: %[[VAL_33:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
|
||||
// CHECK: %[[VAL_34:.*]] = arith.andi %[[VAL_32]], %[[VAL_33]] : i1
|
||||
// CHECK: %[[VAL_35:.*]] = scf.if %[[VAL_34]] -> (f32) {
|
||||
// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]]] : memref<?xf32>
|
||||
// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xf32>
|
||||
// CHECK: %[[VAL_38:.*]] = arith.mulf %[[VAL_36]], %[[VAL_37]] : f32
|
||||
// CHECK: %[[VAL_39:.*]] = arith.addf %[[VAL_27]], %[[VAL_38]] : f32
|
||||
// CHECK: scf.yield %[[VAL_39]] : f32
|
||||
// CHECK: } else {
|
||||
// CHECK: scf.yield %[[VAL_27]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
|
||||
// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_25]], %[[VAL_4]] : index
|
||||
// CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_40]], %[[VAL_41]], %[[VAL_25]] : index
|
||||
// CHECK: %[[VAL_43:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
|
||||
// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_26]], %[[VAL_4]] : index
|
||||
// CHECK: %[[VAL_45:.*]] = arith.select %[[VAL_43]], %[[VAL_44]], %[[VAL_26]] : index
|
||||
// CHECK: scf.yield %[[VAL_42]], %[[VAL_45]], %[[VAL_46:.*]] : index, index, f32
|
||||
// CHECK: }
|
||||
// CHECK: memref.store %[[VAL_47:.*]]#2, %[[VAL_11]][] : memref<f32>
|
||||
// CHECK: %[[VAL_48:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<f32>
|
||||
// CHECK: return %[[VAL_48]] : tensor<f32>
|
||||
// CHECK: }
|
||||
func @sparse_dot(%a: tensor<1024xf32, #SparseVector>,
|
||||
%b: tensor<1024xf32, #SparseVector>,
|
||||
%x: tensor<f32>) -> tensor<f32> {
|
||||
%dot = linalg.dot ins(%a, %b: tensor<1024xf32, #SparseVector>,
|
||||
tensor<1024xf32, #SparseVector>)
|
||||
outs(%x: tensor<f32>) -> tensor<f32>
|
||||
return %dot : tensor<f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue