forked from OSchip/llvm-project
[mlir][Linalg] Relax vectorization condition to allow transposed output.
Reviewed By: ThomasRaoux, dcaballe
Differential Revision: https://reviews.llvm.org/D126454
This commit is contained in:
parent
52992f136b
commit
5aefdafccf
|
@@ -441,7 +441,7 @@ static bool isElementwise(Operation *op) {
|
|||
return false;
|
||||
// TODO: relax the restrictions on indexing map.
|
||||
for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
|
||||
- if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
|
||||
+ if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
|
||||
return false;
|
||||
}
|
||||
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
|
||||
|
|
|
@@ -121,6 +121,26 @@ func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
|
|||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
|
||||
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
|
||||
// CHECK: func @generic_interchanged_transpose
|
||||
func.func @generic_interchanged_transpose(%arg0: tensor<12x128x32xf32>) -> tensor<128x12x32xf32> {
|
||||
// CHECK: %[[IN:.+]] = vector.transfer_read
|
||||
// CHECK: vector.transfer_write %[[IN]], {{.+}} permutation_map = #[[MAP]]
|
||||
%0 = linalg.init_tensor [128, 12, 32] : tensor<128x12x32xf32>
|
||||
%1 = linalg.generic {indexing_maps = [#map0, #map1],
|
||||
iterator_types = ["parallel", "parallel", "parallel"]}
|
||||
ins(%arg0 : tensor<12x128x32xf32>)
|
||||
outs(%0 : tensor<128x12x32xf32>) {
|
||||
^bb0(%arg1: f32, %arg2: f32):
|
||||
linalg.yield %arg1 : f32
|
||||
} -> tensor<128x12x32xf32>
|
||||
return %1 : tensor<128x12x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#matmul_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
|
|
Loading…
Reference in New Issue