diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index cffc20caa3c2..6e7f644d3aeb 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1324,15 +1324,14 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( VectorType lhsType = op.getLhsType(); Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc(); - // Set up the parallel/reduction structure in right form. - AffineExpr m, n, k; - bindDims(rewriter.getContext(), m, n, k); - // // Two outer parallel, one inner reduction (matmat flavor). // UnrolledOuterProductEmitter e(rewriter, op); if (e.iters({Par(), Par(), Red()})) { + // Set up the parallel/reduction structure in right form. + AffineExpr m, n, k; + bindDims(rewriter.getContext(), m, n, k); // Classical row-major matmul: Just permute the lhs. if (e.layout({{m, k}, {k, n}, {m, n}})) return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); @@ -1367,17 +1366,42 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( // One outer parallel, one inner reduction (matvec flavor) // if (e.iters({Par(), Red()})) { + AffineExpr m, k; + bindDims(rewriter.getContext(), m, k); + // Case mat-vec: transpose. - if (e.layout({{m, n}, {n}, {m}})) + if (e.layout({{m, k}, {k}, {m}})) return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); // Case mat-trans-vec: ready to go. - if (e.layout({{n, m}, {n}, {m}})) + if (e.layout({{k, m}, {k}, {m}})) return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); // Case vec-mat: swap and transpose. - if (e.layout({{n}, {m, n}, {m}})) + if (e.layout({{k}, {m, k}, {m}})) return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0)); // Case vec-mat-trans: swap and ready to go. - if (e.layout({{n}, {n, m}, {m}})) + if (e.layout({{k}, {k, m}, {m}})) + return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); + return failure(); + } + + // + // One outer reduction, one inner parallel (tmatvec flavor) + // + if (e.iters({Red(), Par()})) { + AffineExpr k, m; + bindDims(rewriter.getContext(), k, m); + + // Case mat-vec: transpose. + if (e.layout({{m, k}, {k}, {m}})) + return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); + // Case mat-trans-vec: ready to go. + if (e.layout({{k, m}, {k}, {m}})) + return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); + // Case vec-mat: swap and transpose. + if (e.layout({{k}, {m, k}, {m}})) + return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0)); + // Case vec-mat-trans: swap and ready to go. + if (e.layout({{k}, {k, m}, {m}})) return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); return failure(); } diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir index 6d84f64d37d8..9c163453ee8b 100644 --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -45,6 +45,16 @@ iterator_types = ["parallel", "reduction"] } +#redpar_vecmattrans_accesses = [ + affine_map<(i, j) -> (i)>, + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j)> +] +#redpar_vecmattrans_trait = { + indexing_maps = #redpar_vecmattrans_accesses, + iterator_types = ["reduction", "parallel"] +} + // CHECK-LABEL: func @matvec2x2 // CHECK-SAME: %[[A:.*0]]: memref> // CHECK-SAME: %[[B:.*1]]: memref> @@ -172,3 +182,28 @@ func @vecmattrans2x2(%arg0: memref>, %arg1: memref memref.store %0, %arg2[] : memref> return } + +// CHECK-LABEL: func @redpar_vecmattrans2x2 +// CHECK-SAME: %[[A:.*0]]: memref> +// CHECK-SAME: %[[B:.*1]]: memref> +// CHECK-SAME: %[[C:.*2]]: memref> +// CHECK: %[[T0:.*]] = memref.load %[[A]][] : memref> +// CHECK: %[[T1:.*]] = memref.load %[[B]][] : memref> +// CHECK: %[[T2:.*]] = memref.load %[[C]][] : memref> +// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32> +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind} : vector<2xf32>, f32 +// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32> +// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind} : vector<2xf32>, f32 +// CHECK: memref.store %[[T8]], %[[C]][] : memref> +// CHECK: return +func @redpar_vecmattrans2x2(%arg0: memref>, %arg1: memref>, + %arg2: memref>) { + %A = memref.load %arg0[] : memref> + %x = memref.load %arg1[] : memref> + %b = memref.load %arg2[] : memref> + %0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32> + memref.store %0, %arg2[] : memref> + return +}