[mlir][Linalg] Add support for tileFuseAndDistribute on tensors.

This extends TileAndFuse to handle distribution on tensors.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D120441
This commit is contained in:
Hanhan Wang 2022-02-25 10:52:08 -08:00
parent 5e33bd804b
commit 748bf4bb28
6 changed files with 179 additions and 78 deletions

View File

@ -638,8 +638,20 @@ struct LinalgPaddingOptions {
struct LinalgTilingAndFusionOptions {
/// Tile sizes used to tile the root operation.
SmallVector<int64_t> tileSizes;
LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
tileSizes.assign(ts.begin(), ts.end());
return *this;
}
/// Tile interchange used to permute the tile loops.
SmallVector<int64_t> tileInterchange;
/// When specified, specifies distribution of generated tile loops to
/// processors.
Optional<LinalgLoopDistributionOptions> tileDistribution = None;
LinalgTilingAndFusionOptions &
setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
tileDistribution = std::move(distributionOptions);
return *this;
}
};
struct LinalgTilingOptions {

View File

@ -245,73 +245,6 @@ FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
OpResult producerOpResult,
OpOperand &consumerOpOperand);
//===----------------------------------------------------------------------===//
// Fusion on tensor utilities
//===----------------------------------------------------------------------===//
/// A struct to manage the tile loop nest specific information.
class TileLoopNest {
public:
TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
/// Tile the root operation using the given `tileSizes` and `tileInterchange`.
LogicalResult tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange);
/// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
/// the fused producer or fails if fusion is not possible.
FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
/// Returns the replacement results for the original untiled root operation.
ValueRange getRootOpReplacementResults();
/// Returns the tiled root operation.
LinalgOp getRootOp() { return rootOp; }
/// Returns the tiled root operation and the fused producers.
SmallVector<LinalgOp> getAllTiledAndFusedOps();
/// Returns the loop ops generated from tiling.
ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
/// Returns true if the tile loop nest has no tile loops.
bool isEmpty();
private:
/// Returns true if the tile loop nest invariants are satisfied:
/// - The `rootOp` has been tiled at least once.
/// - The number of tile loop operations and dimensions match.
/// - The innermost tile loop is the parent of `tiledOp`.
/// - The tile loops are directly nested.
// TODO: relax to support additional control flow, e.g., IfOp.
bool isValid();
/// Searches the block arguments tied to a block argument `bbArg` of the
/// innermost tile loop. Returns the block argument from outermost to
/// innermost or an empty vector if none are found.
SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
/// Returns the iteration argument of the outermost tile loop mapped to a
/// block argument `bbArg` of the innermost tile loop.
OpOperand *getTiedIterArg(BlockArgument bbArg);
/// Returns true if `bbArg` has other used than `sliceOp` and its
/// dependencies. Only if there are no other uses, the producer output
/// iteration argument may reused to pass the producer result after fusion.
bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
LinalgOp rootOp;
SmallVector<scf::ForOp> tileLoopOps;
DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
};
/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
/// `tileSizes` and `tileInterchange` parameters to control the tiling.
FailureOr<TileLoopNest>
tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange);
//===----------------------------------------------------------------------===//
// Distribution utilities
//===----------------------------------------------------------------------===//
@ -396,6 +329,77 @@ void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc,
Value procId, Value nprocs, Value &lb,
Value &ub, Value &step);
//===----------------------------------------------------------------------===//
// Fusion on tensor utilities
//===----------------------------------------------------------------------===//
/// A struct to manage the tile loop nest specific information.
class TileLoopNest {
public:
TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
/// Tile the root operation using the given `tileSizes` and `tileInterchange`,
/// and `tileDistribution`.
LogicalResult
tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange,
Optional<LinalgLoopDistributionOptions> tileDistribution);
/// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
/// the fused producer or fails if fusion is not possible.
FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
/// Returns the replacement results for the original untiled root operation.
ValueRange getRootOpReplacementResults();
/// Returns the tiled root operation.
LinalgOp getRootOp() { return rootOp; }
/// Returns the tiled root operation and the fused producers.
SmallVector<LinalgOp> getAllTiledAndFusedOps();
/// Returns the loop ops generated from tiling.
ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
/// Returns true if the tile loop nest has no tile loops.
bool isEmpty();
private:
/// Returns true if the tile loop nest invariants are satisfied:
/// - The `rootOp` has been tiled at least once.
/// - The number of tile loop operations and dimensions match.
/// - The innermost tile loop is the parent of `tiledOp`.
/// - The tile loops are directly nested.
// TODO: relax to support additional control flow, e.g., IfOp.
bool isValid();
/// Searches the block arguments tied to a block argument `bbArg` of the
/// innermost tile loop. Returns the block argument from outermost to
/// innermost or an empty vector if none are found.
SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
/// Returns the iteration argument of the outermost tile loop mapped to a
/// block argument `bbArg` of the innermost tile loop.
OpOperand *getTiedIterArg(BlockArgument bbArg);
/// Returns true if `bbArg` has other used than `sliceOp` and its
/// dependencies. Only if there are no other uses, the producer output
/// iteration argument may reused to pass the producer result after fusion.
bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
LinalgOp rootOp;
SmallVector<scf::ForOp> tileLoopOps;
DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
};
/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
/// `tileSizes`, `tileInterchange`, and `tileDistribution` parameters to control
/// the tiling.
FailureOr<TileLoopNest> tileConsumerAndFuseProducers(
OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange,
Optional<LinalgLoopDistributionOptions> tileDistribution);
//===----------------------------------------------------------------------===//
// Generic op region utilities
//===----------------------------------------------------------------------===//

View File

@ -269,9 +269,10 @@ bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
});
}
LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange) {
LogicalResult TileLoopNest::tileRootOp(
OpBuilder &b, ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange,
Optional<LinalgLoopDistributionOptions> tileDistribution) {
// Exit if all tile sizes are zero.
if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
return success();
@ -283,6 +284,9 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
tileInterchange.begin(), tileInterchange.end()))
.setTileSizes(tileSizes)
.setLoopType(LinalgTilingLoopType::Loops);
if (tileDistribution)
tilingOptions =
tilingOptions.setDistributionOptions(tileDistribution.getValue());
// TODO: Propagate RewriterBase everywhere.
IRRewriter rewriter(b);
@ -408,10 +412,10 @@ SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
// Tile and fuse entry-points.
//===----------------------------------------------------------------------===//
FailureOr<TileLoopNest>
mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange) {
FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange,
Optional<LinalgLoopDistributionOptions> tileDistribution) {
assert(tileSizes.size() == tileInterchange.size() &&
"expect the number of tile sizes and interchange dims to match");
assert(isPermutation(tileInterchange) &&
@ -446,7 +450,8 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
SmallVector<int64_t> outerTileSizes;
outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
outerTileSizes.append(tileSizes.size() - split, 0);
if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
tileDistribution)))
return failure();
fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
@ -454,7 +459,8 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
SmallVector<int64_t> innerTileSizes;
innerTileSizes.append(split, 0);
innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
tileDistribution)))
return failure();
fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());

View File

@ -613,8 +613,9 @@ LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
op, "expect the tile interchange permutes the root loops");
// Tile `rootOp` and fuse its producers.
FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
rewriter, rootOp, rootTileSizes, rootInterchange);
FailureOr<TileLoopNest> tileLoopNest =
tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
rootInterchange, options.tileDistribution);
if (failed(tileLoopNest))
return rewriter.notifyMatchFailure(
op, "tileConsumerAndFuseProducers failed unexpectedly");

View File

@ -0,0 +1,53 @@
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-fuse-and-distribute-options -split-input-file | FileCheck %s
// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK: #[[ADDMAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: func @fill_matmul_tensors(
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
func @fill_matmul_tensors(
%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
-> tensor<?x?xf32> {
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
// CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y
// CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
// CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
// CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
// CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]
// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[INIT]]) -> (tensor<?x?xf32>) {
// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDX]], %[[C8]]]
// CHECK: %[[LBX:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
// CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]]
// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TC1]]
// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[SLICE]])
// CHECK: %[[sTD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[FILL]]) -> (tensor<?x?xf32>) {
// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
// CHECK: %[[TD2:.*]] = tensor.insert_slice %[[sTD2]] into %[[TC1]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32
%0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
%3 = linalg.fill(%cst, %2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"}
ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%3: tensor<?x?xf32>)
-> tensor<?x?xf32>
// CHECK: return %[[TD0]] : tensor<?x?xf32>
return %4 : tensor<?x?xf32>
}

View File

@ -78,6 +78,10 @@ struct TestLinalgTransforms
*this, "test-tile-and-distribute-options",
llvm::cl::desc("Test tile and distribute options"),
llvm::cl::init(false)};
Option<bool> testTileFuseAndDistributionOptions{
*this, "test-tile-fuse-and-distribute-options",
llvm::cl::desc("Test tile, fuse and distribute options"),
llvm::cl::init(false)};
Option<bool> testVectorTransferForwardingPatterns{
*this, "test-vector-transfer-forwarding-patterns",
llvm::cl::desc(
@ -505,6 +509,21 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
}
}
static void fillTileFuseAndDistributePatterns(MLIRContext *context,
RewritePatternSet &patterns) {
LinalgLoopDistributionOptions cyclicNprocsEqNiters;
cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
patterns.add<LinalgTileAndFuseTensorOpsPattern>(
MatmulOp::getOperationName(), context,
LinalgTilingAndFusionOptions()
.setTileSizes({8, 8, 4})
.setDistributionOptions(cyclicNprocsEqNiters),
LinalgTransformationFilter(
StringAttr::get(context, "tensors_fuse_distribute1"),
StringAttr::get(context, "tensors_after_fuse_distribute1")));
}
static void
applyMatmulToVectorPatterns(FuncOp funcOp,
bool testMatmulToVectorPatterns1dTiling,
@ -698,6 +717,12 @@ void TestLinalgTransforms::runOnOperation() {
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
if (testTileFuseAndDistributionOptions) {
RewritePatternSet patterns(&getContext());
fillTileFuseAndDistributePatterns(&getContext(), patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
if (testPatterns)
return applyPatterns(getOperation());
if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)