[mlir][Linalg] Relax vectorization condition to allow transposed output.

Reviewed By: ThomasRaoux, dcaballe

Differential Revision: https://reviews.llvm.org/D126454
Hanhan Wang 2022-05-26 19:20:36 -07:00
parent 52992f136b
commit 5aefdafccf
2 changed files with 21 additions and 1 deletion


@@ -441,7 +441,7 @@ static bool isElementwise(Operation *op) {
     return false;
   // TODO: relax the restrictions on indexing map.
   for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
-    if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
+    if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
       return false;
   }
   return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
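
For context (not part of the patch): the only behavioral change is swapping the AffineMap predicate applied to output indexing maps. A transposed output map such as (d0, d1, d2) -> (d1, d0, d2) is a permutation but not the identity, so it was previously rejected even though every output element is still written exactly once. A minimal C++ sketch of that distinction, using only the public mlir::AffineMap API (the function name is illustrative, not from the patch):

    // Sketch only: the predicate difference this patch relies on.
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"

    bool transposedOutputMapIsAccepted() {
      mlir::MLIRContext ctx;
      // Output map of the new test below: (d0, d1, d2) -> (d1, d0, d2).
      auto map = mlir::AffineMap::getPermutationMap({1, 0, 2}, &ctx);
      // isIdentity() rejects the transpose; isPermutation() accepts any
      // bijective reordering of the loop dimensions, which is enough for
      // the vectorizer to emit a transfer_write with a permutation_map.
      return !map.isIdentity() && map.isPermutation(); // true
    }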


@@ -121,6 +121,26 @@ func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
 // -----
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK: func @generic_interchanged_transpose
+func.func @generic_interchanged_transpose(%arg0: tensor<12x128x32xf32>) -> tensor<128x12x32xf32> {
+  // CHECK: %[[IN:.+]] = vector.transfer_read
+  // CHECK: vector.transfer_write %[[IN]], {{.+}} permutation_map = #[[MAP]]
+  %0 = linalg.init_tensor [128, 12, 32] : tensor<128x12x32xf32>
+  %1 = linalg.generic {indexing_maps = [#map0, #map1],
+                       iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<12x128x32xf32>)
+    outs(%0 : tensor<128x12x32xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<128x12x32xf32>
+  return %1 : tensor<128x12x32xf32>
+}
+// -----
 #matmul_trait = {
   args_in = 2,
   args_out = 1,