forked from OSchip/llvm-project
[mlir][Linalg] Add support for tileFuseAndDistribute on tensors.
This extends TileAndFuse to handle distribution on tensors.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D120441
This commit is contained in:
parent
5e33bd804b
commit
748bf4bb28
|
@ -638,8 +638,20 @@ struct LinalgPaddingOptions {
|
|||
struct LinalgTilingAndFusionOptions {
|
||||
/// Tile sizes used to tile the root operation.
|
||||
SmallVector<int64_t> tileSizes;
|
||||
LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
|
||||
tileSizes.assign(ts.begin(), ts.end());
|
||||
return *this;
|
||||
}
|
||||
/// Tile interchange used to permute the tile loops.
|
||||
SmallVector<int64_t> tileInterchange;
|
||||
/// When specified, specifies distribution of generated tile loops to
|
||||
/// processors.
|
||||
Optional<LinalgLoopDistributionOptions> tileDistribution = None;
|
||||
LinalgTilingAndFusionOptions &
|
||||
setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
|
||||
tileDistribution = std::move(distributionOptions);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
struct LinalgTilingOptions {
|
||||
|
|
|
@ -245,73 +245,6 @@ FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
|
|||
OpResult producerOpResult,
|
||||
OpOperand &consumerOpOperand);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fusion on tensor utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// A class that manages the tile loop nest created by tiling a root operation,
/// together with the producers that are fused into that loop nest.
|
||||
class TileLoopNest {
|
||||
public:
|
||||
/// Creates a loop nest wrapper around `rootOp`. No tile loops are recorded
/// until `tileRootOp` is called.
TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
|
||||
|
||||
/// Tile the root operation using the given `tileSizes` and `tileInterchange`.
/// Returns success without tiling if all tile sizes are zero.
|
||||
LogicalResult tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> tileInterchange);
|
||||
|
||||
/// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
|
||||
/// the fused producer or fails if fusion is not possible.
|
||||
FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
|
||||
|
||||
/// Returns the replacement results for the original untiled root operation.
|
||||
ValueRange getRootOpReplacementResults();
|
||||
|
||||
/// Returns the tiled root operation.
|
||||
LinalgOp getRootOp() { return rootOp; }
|
||||
|
||||
/// Returns the tiled root operation and the fused producers.
|
||||
SmallVector<LinalgOp> getAllTiledAndFusedOps();
|
||||
|
||||
/// Returns the loop ops generated from tiling.
|
||||
ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
|
||||
|
||||
/// Returns true if the tile loop nest has no tile loops.
|
||||
bool isEmpty();
|
||||
|
||||
private:
|
||||
/// Returns true if the tile loop nest invariants are satisfied:
|
||||
/// - The `rootOp` has been tiled at least once.
|
||||
/// - The number of tile loop operations and dimensions match.
|
||||
/// - The innermost tile loop is the parent of `tiledOp`.
|
||||
/// - The tile loops are directly nested.
|
||||
// TODO: relax to support additional control flow, e.g., IfOp.
|
||||
bool isValid();
|
||||
|
||||
/// Searches the block arguments tied to a block argument `bbArg` of the
|
||||
/// innermost tile loop. Returns the block arguments ordered from outermost to
|
||||
/// innermost, or an empty vector if none are found.
|
||||
SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
|
||||
|
||||
/// Returns the iteration argument of the outermost tile loop mapped to a
|
||||
/// block argument `bbArg` of the innermost tile loop.
|
||||
OpOperand *getTiedIterArg(BlockArgument bbArg);
|
||||
|
||||
/// Returns true if `bbArg` has other uses than `sliceOp` and its
|
||||
/// dependencies. Only if there are no other uses, the producer output
|
||||
/// iteration argument may be reused to pass the producer result after fusion.
|
||||
bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
|
||||
|
||||
/// The root operation of the loop nest (returned by `getRootOp`).
LinalgOp rootOp;
|
||||
/// The loop ops generated from tiling (returned by `getLoopOps`).
SmallVector<scf::ForOp> tileLoopOps;
|
||||
/// Per-operation list of loop indices for the tiled root and the fused ops.
/// NOTE(review): the exact meaning of the indices is not visible in this
/// declaration; confirm against the implementation.
DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
|
||||
};
|
||||
|
||||
/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
|
||||
/// `tileSizes` and `tileInterchange` parameters to control the tiling.
|
||||
FailureOr<TileLoopNest>
|
||||
tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
|
||||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> tileInterchange);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Distribution utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -396,6 +329,77 @@ void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc,
|
|||
Value procId, Value nprocs, Value &lb,
|
||||
Value &ub, Value &step);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fusion on tensor utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// A class that manages the tile loop nest created by tiling a root operation,
/// together with the producers that are fused into that loop nest.
|
||||
class TileLoopNest {
|
||||
public:
|
||||
/// Creates a loop nest wrapper around `rootOp`. No tile loops are recorded
/// until `tileRootOp` is called.
TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
|
||||
|
||||
/// Tile the root operation using the given `tileSizes`, `tileInterchange`,
|
||||
/// and `tileDistribution`. When `tileDistribution` holds a value, it is
/// forwarded to the tiling options as the loop distribution options. Returns
/// success without tiling if all tile sizes are zero.
|
||||
LogicalResult
|
||||
tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> tileInterchange,
|
||||
Optional<LinalgLoopDistributionOptions> tileDistribution);
|
||||
|
||||
/// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
|
||||
/// the fused producer or fails if fusion is not possible.
|
||||
FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
|
||||
|
||||
/// Returns the replacement results for the original untiled root operation.
|
||||
ValueRange getRootOpReplacementResults();
|
||||
|
||||
/// Returns the tiled root operation.
|
||||
LinalgOp getRootOp() { return rootOp; }
|
||||
|
||||
/// Returns the tiled root operation and the fused producers.
|
||||
SmallVector<LinalgOp> getAllTiledAndFusedOps();
|
||||
|
||||
/// Returns the loop ops generated from tiling.
|
||||
ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
|
||||
|
||||
/// Returns true if the tile loop nest has no tile loops.
|
||||
bool isEmpty();
|
||||
|
||||
private:
|
||||
/// Returns true if the tile loop nest invariants are satisfied:
|
||||
/// - The `rootOp` has been tiled at least once.
|
||||
/// - The number of tile loop operations and dimensions match.
|
||||
/// - The innermost tile loop is the parent of `tiledOp`.
|
||||
/// - The tile loops are directly nested.
|
||||
// TODO: relax to support additional control flow, e.g., IfOp.
|
||||
bool isValid();
|
||||
|
||||
/// Searches the block arguments tied to a block argument `bbArg` of the
|
||||
/// innermost tile loop. Returns the block arguments ordered from outermost to
|
||||
/// innermost, or an empty vector if none are found.
|
||||
SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
|
||||
|
||||
/// Returns the iteration argument of the outermost tile loop mapped to a
|
||||
/// block argument `bbArg` of the innermost tile loop.
|
||||
OpOperand *getTiedIterArg(BlockArgument bbArg);
|
||||
|
||||
/// Returns true if `bbArg` has other uses than `sliceOp` and its
|
||||
/// dependencies. Only if there are no other uses, the producer output
|
||||
/// iteration argument may be reused to pass the producer result after fusion.
|
||||
bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
|
||||
|
||||
/// The root operation of the loop nest (returned by `getRootOp`).
LinalgOp rootOp;
|
||||
/// The loop ops generated from tiling (returned by `getLoopOps`).
SmallVector<scf::ForOp> tileLoopOps;
|
||||
/// Per-operation list of loop indices for the tiled root and the fused ops.
/// NOTE(review): the exact meaning of the indices is not visible in this
/// declaration; confirm against the implementation.
DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
|
||||
};
|
||||
|
||||
/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
|
||||
/// `tileSizes`, `tileInterchange`, and `tileDistribution` parameters to control
|
||||
/// the tiling.
|
||||
FailureOr<TileLoopNest> tileConsumerAndFuseProducers(
|
||||
OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> tileInterchange,
|
||||
Optional<LinalgLoopDistributionOptions> tileDistribution);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Generic op region utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -269,9 +269,10 @@ bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
|
|||
});
|
||||
}
|
||||
|
||||
LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
|
||||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> tileInterchange) {
|
||||
LogicalResult TileLoopNest::tileRootOp(
|
||||
OpBuilder &b, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> tileInterchange,
|
||||
Optional<LinalgLoopDistributionOptions> tileDistribution) {
|
||||
// Exit if all tile sizes are zero.
|
||||
if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
|
||||
return success();
|
||||
|
@ -283,6 +284,9 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
|
|||
tileInterchange.begin(), tileInterchange.end()))
|
||||
.setTileSizes(tileSizes)
|
||||
.setLoopType(LinalgTilingLoopType::Loops);
|
||||
if (tileDistribution)
|
||||
tilingOptions =
|
||||
tilingOptions.setDistributionOptions(tileDistribution.getValue());
|
||||
|
||||
// TODO: Propagate RewriterBase everywhere.
|
||||
IRRewriter rewriter(b);
|
||||
|
@ -408,10 +412,10 @@ SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
|
|||
// Tile and fuse entry-points.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FailureOr<TileLoopNest>
|
||||
mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
|
||||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> tileInterchange) {
|
||||
FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
|
||||
OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> tileInterchange,
|
||||
Optional<LinalgLoopDistributionOptions> tileDistribution) {
|
||||
assert(tileSizes.size() == tileInterchange.size() &&
|
||||
"expect the number of tile sizes and interchange dims to match");
|
||||
assert(isPermutation(tileInterchange) &&
|
||||
|
@ -446,7 +450,8 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
|
|||
SmallVector<int64_t> outerTileSizes;
|
||||
outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
|
||||
outerTileSizes.append(tileSizes.size() - split, 0);
|
||||
if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
|
||||
if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
|
||||
tileDistribution)))
|
||||
return failure();
|
||||
fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
|
||||
|
||||
|
@ -454,7 +459,8 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
|
|||
SmallVector<int64_t> innerTileSizes;
|
||||
innerTileSizes.append(split, 0);
|
||||
innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
|
||||
if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
|
||||
if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
|
||||
tileDistribution)))
|
||||
return failure();
|
||||
fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
|
||||
|
||||
|
|
|
@ -613,8 +613,9 @@ LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
|
|||
op, "expect the tile interchange permutes the root loops");
|
||||
|
||||
// Tile `rootOp` and fuse its producers.
|
||||
FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
|
||||
rewriter, rootOp, rootTileSizes, rootInterchange);
|
||||
FailureOr<TileLoopNest> tileLoopNest =
|
||||
tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
|
||||
rootInterchange, options.tileDistribution);
|
||||
if (failed(tileLoopNest))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "tileConsumerAndFuseProducers failed unexpectedly");
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-fuse-and-distribute-options -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
|
||||
// CHECK: #[[ADDMAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
|
||||
// CHECK: func @fill_matmul_tensors(
|
||||
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// Test input: a `linalg.fill` of an init tensor feeding a `linalg.matmul` that
// carries the "tensors_fuse_distribute1" transform marker. The CHECK lines
// expect two outer `scf.for` loops whose bounds are derived from
// `gpu.block_id` / `gpu.grid_dim`, the fill fused inside them, and an
// innermost `scf.for` containing the tiled matmul.
func @fill_matmul_tensors(
|
||||
%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
|
||||
-> tensor<?x?xf32> {
|
||||
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
|
||||
// CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y
|
||||
// CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
|
||||
// CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x
|
||||
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
|
||||
// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
|
||||
// CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
|
||||
// CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]
|
||||
// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[INIT]]) -> (tensor<?x?xf32>) {
|
||||
// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDX]], %[[C8]]]
|
||||
// CHECK: %[[LBX:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
|
||||
// CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]]
|
||||
// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
|
||||
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TC1]]
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[SLICE]])
|
||||
// CHECK: %[[sTD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[FILL]]) -> (tensor<?x?xf32>) {
|
||||
// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
|
||||
// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
|
||||
// CHECK: %[[TD2:.*]] = tensor.insert_slice %[[sTD2]] into %[[TC1]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
|
||||
// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
|
||||
// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
|
||||
%2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
|
||||
%3 = linalg.fill(%cst, %2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
|
||||
%4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"}
|
||||
ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%3: tensor<?x?xf32>)
|
||||
-> tensor<?x?xf32>
|
||||
|
||||
// CHECK: return %[[TD0]] : tensor<?x?xf32>
|
||||
return %4 : tensor<?x?xf32>
|
||||
}
|
|
@ -78,6 +78,10 @@ struct TestLinalgTransforms
|
|||
*this, "test-tile-and-distribute-options",
|
||||
llvm::cl::desc("Test tile and distribute options"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testTileFuseAndDistributionOptions{
|
||||
*this, "test-tile-fuse-and-distribute-options",
|
||||
llvm::cl::desc("Test tile, fuse and distribute options"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testVectorTransferForwardingPatterns{
|
||||
*this, "test-vector-transfer-forwarding-patterns",
|
||||
llvm::cl::desc(
|
||||
|
@ -505,6 +509,21 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
|
|||
}
|
||||
}
|
||||
|
||||
/// Adds to `patterns` a tile-and-fuse-on-tensors pattern for matmul ops that
/// also distributes the generated tile loops. The pattern uses tile sizes
/// {8, 8, 4} and is guarded by the "tensors_fuse_distribute1" marker.
static void fillTileFuseAndDistributePatterns(MLIRContext *context,
                                              RewritePatternSet &patterns) {
  // Two loops are distributed cyclically; processor ids and counts are
  // obtained through the GPU block id / grid dim ops.
  LinalgLoopDistributionOptions distribution;
  distribution.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
  distribution.distributionMethod.resize(2, DistributionMethod::Cyclic);

  LinalgTilingAndFusionOptions tileAndFuseOptions;
  tileAndFuseOptions.setTileSizes({8, 8, 4})
      .setDistributionOptions(distribution);

  // Presumably the first attribute is the marker to match and the second the
  // marker left on the rewritten op — confirm against
  // LinalgTransformationFilter.
  LinalgTransformationFilter filter(
      StringAttr::get(context, "tensors_fuse_distribute1"),
      StringAttr::get(context, "tensors_after_fuse_distribute1"));

  patterns.add<LinalgTileAndFuseTensorOpsPattern>(
      MatmulOp::getOperationName(), context, tileAndFuseOptions, filter);
}
|
||||
|
||||
static void
|
||||
applyMatmulToVectorPatterns(FuncOp funcOp,
|
||||
bool testMatmulToVectorPatterns1dTiling,
|
||||
|
@ -698,6 +717,12 @@ void TestLinalgTransforms::runOnOperation() {
|
|||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
return;
|
||||
}
|
||||
if (testTileFuseAndDistributionOptions) {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
fillTileFuseAndDistributePatterns(&getContext(), patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
return;
|
||||
}
|
||||
if (testPatterns)
|
||||
return applyPatterns(getOperation());
|
||||
if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
|
||||
|
|
Loading…
Reference in New Issue