[mlir][linalg] Fix tiling interface implementation offset calculation

The tiling interface implementation was making assumption on the code
generated by makeTiledShape which were wrong. The ExtractSliceOp create
may be combined with other ExtractSliceOp. To solve that we compute
directly the offset using the new utilities.

Differential Revision: https://reviews.llvm.org/D132182
This commit is contained in:
Thomas Raoux 2022-08-18 22:39:29 +00:00
parent 89167e3c5b
commit 06c02d5dbb
4 changed files with 60 additions and 44 deletions

View File

@ -177,18 +177,6 @@ mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
return spec;
}
/// Given a `subsetExtractOp`, a `source` and a `dest`, create a new
/// `ParallelInsertSlice` op of `source` into `dest` at the same subset location
/// as `subsetExtractOp`.
static void
createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
tensor::ExtractSliceOp subsetExtractOp,
Value source, Value dest) {
b.create<tensor::ParallelInsertSliceOp>(
loc, source, dest, subsetExtractOp.getMixedOffsets(),
subsetExtractOp.getMixedSizes(), subsetExtractOp.getMixedStrides());
}
/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
/// than `iterationSize`.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
@ -333,16 +321,21 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
// Create terminator with parallel subset insert operations.
b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
for (auto it : llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
destOperands)) {
createMatchingParallelSubsetInsertOp(
b, loc, cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
std::get<1>(it), std::get<2>(it));
OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
for (auto it :
llvm::zip(llvm::seq(unsigned(0), unsigned(destOperands.size())),
tilingInterfaceOp->getResults(), destOperands)) {
b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
tiledSizes, resultOffsets,
resultSizes)))
return op->emitOpError("output offsets couldn't be calculated");
SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
std::get<2>(it), resultOffsets,
resultSizes, strides);
}
return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
}

View File

@ -161,15 +161,12 @@ struct LinalgOpTilingInterface
}));
OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
Value sliceOpResult =
makeTiledShape(b, loc, outOperand->get(), sizes,
linalgOp.getTiedIndexingMap(outOperand), offsets,
/*ubs*/ {}, subShapeSizes, true);
auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp)
return failure();
resultOffsets = sliceOp.getMixedOffsets();
resultSizes = sliceOp.getMixedSizes();
SliceParameters sliceParams =
computeSliceParameters(b, loc, outOperand->get(), sizes,
linalgOp.getTiedIndexingMap(outOperand), offsets,
/*ubs*/ {}, subShapeSizes, true);
resultOffsets = sliceParams.offsets;
resultSizes = sliceParams.sizes;
return success();
}

View File

@ -59,7 +59,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
// CHECK: %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
// CHECK: scf.yield %[[RESPARTIAL]]
// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][%[[I1]], 0] [2, 16] [1, 1]
// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
// CHECK: %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
// CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
// CHECK-COUNT-2: tensor.extract_slice

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -split-input-file | FileCheck %s
// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -cse -split-input-file | FileCheck %s
// Offset per thread:
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
@ -22,7 +22,7 @@ module {
// CHECK: %[[RES:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[tA]], %[[tB]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[tC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: scf.foreach_thread.perform_concurrently {
// CHECK: scf.foreach_thread.perform_concurrently {
// CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C]]{{.*}} :
// CHECK-SAME: tensor<?x?xf32> into tensor<?x?xf32>
// CHECK-NEXT: }
@ -65,11 +65,9 @@ func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: t
// CHECK-NOT: affine.max
// CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
// CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
// CHECK: %[[LB0_1:.+]] = affine.apply #[[$map2]](%[[IV0]])
// CHECK: %[[LB1_1:.+]] = affine.apply #[[$map3]](%[[IV1]])
// CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
// CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0_1]], %[[LB1_1]]] [10, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
// CHECK: linalg.matmul
// CHECK: scf.foreach_thread.perform_concurrently
// CHECK-NEXT: tensor.parallel_insert_slice
@ -106,8 +104,6 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
// CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
// CHECK: %[[NT0:.+]] = affine.apply #map0()[%[[M]]]
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
// CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
// CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
// CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
// CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
@ -115,8 +111,6 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
// CHECK tensor.extract_slice %[[A]]
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
// CHECK tensor.extract_slice %[[B]]
// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
// CHECK tensor.extract_slice %[[C]]
// CHECK: linalg.matmul
// CHECK: scf.foreach_thread.perform_concurrently
@ -156,11 +150,9 @@ func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf
// CHECK-NOT: affine.min
// CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
// CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
// CHECK: %[[LB0_1:.+]] = affine.apply #[[$map2]](%[[IV0]])
// CHECK: %[[LB1_1:.+]] = affine.apply #[[$map3]](%[[IV1]])
// CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
// CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0_1]], %[[LB1_1]]] [10, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
// CHECK: linalg.matmul
// CHECK: scf.foreach_thread.perform_concurrently
// CHECK-NEXT: tensor.parallel_insert_slice
@ -177,3 +169,37 @@ transform.with_pdl_patterns {
%1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 21]
}
}
// -----
module {
func.func @extract_source(%A: tensor<4xf32>, %B: tensor<16xf32>) -> tensor<4xf32> {
%B1 = tensor.extract_slice %B[10] [4] [1] : tensor<16xf32> to tensor<4xf32>
%result = linalg.generic {indexing_maps = [
affine_map<(d0) -> (d0)>,affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%A : tensor<4xf32>) outs(%B1 : tensor<4xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%2 = arith.addf %arg3, %arg3 : f32
linalg.yield %2 : f32
} -> tensor<4xf32>
return %result : tensor<4xf32>
}
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] (mapped to dims [0])
}
}
}
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-LABEL: extract_source(
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) -> (tensor<4xf32>) {
// CHECK: %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
// CHECK: scf.foreach_thread.perform_concurrently {
// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>