From 01defcc8d74e65f3d304274bc4ede44d838ff22b Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Fri, 22 Jan 2021 11:32:50 -0800 Subject: [PATCH] [mlir][Linalg] Extend tile+fuse to work on Linalg operation on tensors. Differential Revision: https://reviews.llvm.org/D93086 --- .../Linalg/Analysis/DependenceAnalysis.h | 5 +- .../Linalg/Analysis/DependenceAnalysis.cpp | 37 +++-- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 128 ++++++++++++++---- mlir/test/Dialect/Linalg/fusion-sequence.mlir | 114 +++++++++++++++- .../Transforms/TestLinalgFusionTransforms.cpp | 15 +- 5 files changed, 256 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h index 5ffe4c6c9461..fecaeff1c8df 100644 --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -247,8 +247,9 @@ private: // Uses std::pair to keep operations and view together and avoid usage errors // related to src/dst and producer/consumer terminology in the context of // dependences. - void addDependenceElem(DependenceType dt, OpOperand *indexingOpView, - OpOperand *dependentOpView); + void addDependenceElem(DependenceType dt, + LinalgDependenceGraphElem::OpView indexingOpView, + LinalgDependenceGraphElem::OpView dependentOpView); /// Implementation detail for findCoveringxxx. SmallVector diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index f80a00bf64d4..59004867a333 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -113,18 +113,21 @@ LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases, } } -void LinalgDependenceGraph::addDependenceElem(DependenceType dt, - OpOperand *indexingOpView, - OpOperand *dependentOpView) { +void LinalgDependenceGraph::addDependenceElem( + DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView, + LinalgDependenceGraphElem::OpView dependentOpView) { LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t (" - << indexingOpView->get() << " @" - << indexingOpView->getOperandNumber() << ") -> \n\t\t(" - << dependentOpView->get() << " @" - << dependentOpView->getOperandNumber() << ")"); - dependencesFromGraphs[dt][indexingOpView->getOwner()].push_back( - LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt}); - dependencesIntoGraphs[dt][dependentOpView->getOwner()].push_back( - LinalgDependenceGraphElem{indexingOpView, dependentOpView, dt}); + << LinalgDependenceGraphElem::getValue(indexingOpView) + << " @) -> \n\t\t(" + << LinalgDependenceGraphElem::getValue(dependentOpView) + << " @)"); + dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)] + .push_back( + LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt}); + dependencesIntoGraphs[dt] + [LinalgDependenceGraphElem::getOwner(dependentOpView)] + .push_back(LinalgDependenceGraphElem{ + indexingOpView, dependentOpView, dt}); } LinalgDependenceGraph::dependence_range @@ -158,6 +161,18 @@ LinalgDependenceGraph::getDependencesInto( } void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { + if (src.hasTensorSemantics() && dst.hasTensorSemantics()) { + for (OpOperand &dstOpOperand : dst.getInputOpOperands()) { + // Check if the operand is defined by the src. + auto definingOp = dstOpOperand.get().getDefiningOp(); + if (definingOp && definingOp == src) + addDependenceElem(DependenceType::RAW, dstOpOperand.get(), + &dstOpOperand); + } + return; + } + assert(src.hasBufferSemantics() && dst.hasBufferSemantics() && + "unhandled dependence tracking for mixed buffer/tensor operations"); for (OpOperand *srcOpOperand : src.getOutputBuffersOpOperands()) { // W // RAW graph for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 5d37e8f9d782..714bb0f97777 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -348,13 +349,15 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, return true; } +/// For `consumer` with buffer semantics, find the Linalg operation on buffers +/// that is the last writer of `consumerOpOperand`. For now the fusable +/// dependence is returned as an instance of the `dependenceGraph`. static Optional findFusableProducer(OpOperand &consumerOpOperand, const LinalgDependenceGraph &dependenceGraph) { - LinalgOp consumerOp = cast(consumerOpOperand.getOwner()); - // Note that buffer semantics implies that the dependence will only be from - // OpOperand -> OpOperand. - assert(consumerOp.hasBufferSemantics() && "revisit usage of shaped operand"); + LinalgOp consumerOp = dyn_cast(consumerOpOperand.getOwner()); + if (!consumerOp) + return {}; // Only consider RAW and WAW atm. for (auto depType : { @@ -378,18 +381,21 @@ findFusableProducer(OpOperand &consumerOpOperand, LLVM_DEBUG(llvm::dbgs() << "\n" << LinalgDependenceGraph::getDependenceTypeStr(depType) - << "producer: " << *dependence.getDependentOp() << " view: " - << dependence.getDependentValue() << " output index: " - << (dependence.getDependentOpViewOperandNum().getValue() - - producer.getNumInputs()) - << "\n"); + << "producer: " << *dependence.getDependentOp() + << " view: " << dependence.getDependentValue() << "\n"); - // Simple fusability checks. - if (!isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(), - producer)) - continue; - - return dependence; + // If the producer and consumer have tensor semantics, the only dependence + // between them is through a RAW dependence and they are fusable by + // construction. For buffer semantics need additional checks. + if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() && + isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(), + producer)) + return dependence; + if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) { + assert(dependence.dependenceType == + LinalgDependenceGraph::DependenceType::RAW); + return dependence; + } } } return {}; @@ -439,6 +445,10 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand, /// Walk back use-def chain through scf::For yields. /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp + +// TODO(ravishankarm, ntv): This can be moved into the dependence graphs +// dependence tracking since the dependence tracking is similar to what is done +// w.r.t to buffers. static void getProducerOfTensor(Value tensor, OpResult &opResult) { if (!tensor.getType().isa()) return; @@ -722,6 +732,45 @@ collectFusableLoops(ArrayRef ops, return fusableLoops; } +// /// For `consumer` with tensor semantics, find the Linalg operation on +// tensors +// /// producer the operand at position `consumerIdx`. This is a simple use-def +// /// chain using the SSA value, but returned as an element of the +// /// `LinalgDependenceGraphElem` to use the same analysis for both tensors and +// /// buffers. +// static Optional +// findFusableProducerForTensorOp(OpOperand &consumerOpOperand) { +// // For now only looking for cases where the operand is produced by another +// // Linalg structured operation. +// LinalgOp consumer = cast(consumerOpOperand.getOwner()); +// if (!consumer || !consumer.hasTensorSemantics()) +// return llvm::None; +// unsigned consumerIdx = consumerOpOperand.getOperandNumber(); +// Value value = consumerOpOperand.get(); +// if (auto linalgOp = value.getDefiningOp()) { +// return LinalgDependenceGraph::LinalgDependenceGraphElem{ +// &(linalgOp +// .getOutputOpOperands()[value.cast().getResultNumber()]), +// &(consumer.getInputOpOperands()[consumerIdx]), +// LinalgDependenceGraph::DependenceType::RAW}; +// } +// return llvm::None; +// } + +// static Optional +// findFusableProducer(OpOperand &consumerOpOperand, +// const LinalgDependenceGraph &dependenceGraph) { +// LinalgOp consumer = cast(consumerOpOperand.getOwner()); +// if (!consumer) +// return llvm::None; +// if (consumer.hasBufferSemantics()) +// return findFusableProducerForBufferOp(consumerOpOperand, +// dependenceGraph); +// if (consumer.hasTensorSemantics()) +// return findFusableProducerForTensorOp(consumerOpOperand); +// return llvm::None; +// } + /// Find all dependences that are fusable. FusableOpDependencesTy mlir::linalg::findAllFusableDependences( ArrayRef ops, const LinalgDependenceGraph &dependenceGraph) { @@ -798,7 +847,7 @@ static Optional tileRootOperation( /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of /// `tiledOp`. static SmallVector -fuseOperations(OpBuilder &builder, LinalgOp tiledOp, +fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp, ArrayRef fusionCandidates, const FusableOpDependencesTy &fusableDependences, const std::set &fusedLoops) { @@ -812,9 +861,33 @@ fuseOperations(OpBuilder &builder, LinalgOp tiledOp, } SmallVector fusedOps(fusionCandidates.size()); + DenseMap origOpToFusedOp; + origOpToFusedOp[rootOp.getOperation()] = tiledOp; for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) { - LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges); + LinalgOp origOp = candidate.value(); + LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges); + origOpToFusedOp[origOp.getOperation()] = fusedOp; fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; + // If the producer consumer operations are linalg operations on tensors, the + // dependence is due to value produced (as a return tensor) by the producer + // and used in the consumer. The returned value of the fused op needs to be + // made the operand of the tiled/fused consumer operation. By construction + // the value returned by the producer is the value used by the consumer. + for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) { + if (origOp.hasTensorSemantics() && + dependence.dependenceType == + LinalgDependenceGraph::DependenceType::RAW) { + unsigned resultIndex = + dependence.getDependentOpViewResultNum().getValue(); + LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp()); + if (!consumer) + continue; + Value replacementValue = fusedOp.getOperation()->getResult(resultIndex); + consumer.getOperation()->setOperand( + dependence.getIndexingOpViewOperandNum().getValue(), + replacementValue); + } + } builder.setInsertionPoint(fusedOp); } return fusedOps; @@ -828,14 +901,16 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef ops, if (ops.size() < 2) return llvm::None; LinalgOp rootOp = ops.back(); - for (auto op : enumerate(ops)) { - // TODO: Nothing in the fusion of sequence of ops is specific to - // buffers. This check can be removed after it is tested on tensors. - LinalgOp linalgOp = op.value(); - if (!linalgOp.hasBufferSemantics()) { - linalgOp.emitRemark("tile and fuse only tested for buffer operation"); - return llvm::None; - } + if (!llvm::all_of( + ops, + [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) && + !llvm::all_of(ops, [](LinalgOp linalgOp) { + return linalgOp.hasTensorSemantics(); + })) { + rootOp.emitError( + "unable to fuse operations that have tensor semantics with operations " + "that have buffer semantics and viceversa."); + return llvm::None; } // TODO: Support interchange with tile + fuse. This might actually help do // better fusion. @@ -877,8 +952,9 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef ops, ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); // Fuse the other operations into the fused inter-tile loops produced above. - ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(), + ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(), fusableDependences, ret.fusedLoopDims); + return ret; } diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir index a02c878ef341..2738eb0f9114 100644 --- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir +++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir @@ -58,7 +58,7 @@ module { module { func @sequence_of_matmul(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, - %arg4: memref) { + %arg4: memref) { %cst = constant 0.000000e+00 : f32 %c0 = constant 0 : index %c1 = constant 1 : index @@ -131,3 +131,115 @@ module { // CHECK: scf.yield // CHECK: } +// ----- + +module { + func @tensor_op_fusion(%arg0: tensor, %arg1: tensor, + %arg2: tensor, %arg3: tensor) + -> tensor { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + %1 = dim %0, %c0 : tensor + %2 = dim %0, %c1 : tensor + %3 = linalg.init_tensor [%1, %2] : tensor + %4 = linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%0, %arg3 : tensor, tensor) + outs(%3 : tensor) { + ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): + %5 = addf %arg4, %arg5 : f32 + linalg.yield %5 : f32 + } -> tensor + return %4 : tensor + } +} +// CHECK-LABEL: func @tensor_op_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[R0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG5:.+]] = %[[INIT]]) -> (tensor) { +// CHECK: %[[R1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG7:.+]] = %[[ARG5]]) -> (tensor) { +// CHECK-DAG: %[[STARG3:.+]] = subtensor %[[ARG3]] +// CHECK-DAG: %[[STARG7:.+]] = subtensor %[[ARG7]] +// CHECK-DAG: %[[STARG0:.+]] = subtensor %[[ARG0]] +// CHECK-DAG: %[[STARG1:.+]] = subtensor %[[ARG1]] +// CHECK-DAG: %[[STARG2:.+]] = subtensor %[[ARG2]] +// CHECK: %[[T0:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] : tensor, tensor) +// CHECK-SAME: outs(%[[STARG2]] : tensor) -> tensor +// CHECK: %[[T1:.+]] = linalg.generic +// CHECK-SAME: ins(%[[T0:.+]], %[[STARG3]] : tensor, tensor) +// CHECK-SAME: outs(%[[STARG7]] : tensor) +// CHECK: %[[RESULT:.+]] = subtensor_insert %[[T1]] into %[[ARG7]] +// CHECK: scf.yield %[[RESULT]] +// CHECK: } +// CHECK: scf.yield %[[R1]] +// CHECK: } +// CHECK: return %[[R0]] + +// ----- + +module { + func @tensor_matmul_fusion(%arg0: tensor, %arg1: tensor, + %arg2: tensor, %arg3: tensor, + %arg4: tensor, %arg5: tensor, + %arg6: tensor) -> tensor { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor // [M, N0] * [N0, N1] + %1 = linalg.matmul ins(%0, %arg3 : tensor, tensor) + outs(%arg4 : tensor) -> tensor // [M, N1] * [N1, N2] + %2 = linalg.matmul ins(%1, %arg5 : tensor, tensor) + outs(%arg6 : tensor) -> tensor // [M, N2] * [N2, N3] + return %2 : tensor + } +} +// CHECK-LABEL: func @tensor_matmul_fusion( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor) -> tensor { +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK: %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] = +// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor) { +// CHECK: %[[N3:.+]] = dim %[[ARG8]], %[[C1]] +// CHECK: %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0] +// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N3]]] +// CHECK: %[[N2:.+]] = dim %[[ARG3]], %[[C1]] +// CHECK: %[[N1:.+]] = dim %[[ARG1]], %[[C1]] +// CHECK: %[[STARG3:.+]] = subtensor %[[ARG3]][0, 0] +// CHECK-SAME: [%[[N1]], %[[N2]]] +// CHECK: %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0] +// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N2]]] +// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]] +// CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0] +// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N0]]] +// CHECK: %[[STARG1:.+]] = subtensor %[[ARG1]][0, 0] +// CHECK-SAME: [%[[N0]], %[[N1]]] +// CHECK: %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0] +// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N1]]] +// CHECK: %[[T0:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] +// CHECK-SAME: ) outs(%[[STARG2]] : tensor) +// CHECK: %[[T1:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[T0]], %[[STARG3]] +// CHECK-SAME: ) outs(%[[STARG4]] : tensor) +// CHECK: %[[T2:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[T1]], %[[ARG5]] +// CHECK-SAME: ) outs(%[[STARG6]] : tensor) +// CHECK: %[[R1:.+]] = subtensor_insert %[[T2]] +// CHECK-SAME: into %[[ARG8]][%[[IV0]], %[[C0]]] +// CHECK: scf.yield %[[R1]] +// CHECK: } +// CHECK: return %[[R0]] +// CHECK: } diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp index 5d55f0375f37..4ed00e4fbefc 100644 --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -226,14 +226,23 @@ struct TestLinalgTileAndFuseSequencePass Aliases aliases; LinalgDependenceGraph dependenceGraph(aliases, linalgOps); OpBuilder builder(funcOp.getContext()); + linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; + if (llvm::all_of(linalgOps, [](LinalgOp linalgOp) { + return linalgOp.hasTensorSemantics(); + })) + loopType = LinalgTilingLoopType::Loops; Optional tileAndFuseOps = tileAndFuseLinalgOps( builder, linalgOps, dependenceGraph, - LinalgTilingOptions().setTileSizes(tileSizes).setLoopType( - LinalgTilingLoopType::ParallelLoops)); + LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); if (!tileAndFuseOps) return signalPassFailure(); + if (linalgOps.back().hasTensorSemantics()) { + linalgOps.back().getOperation()->replaceAllUsesWith( + tileAndFuseOps->fusedLoops.front()); + } for (auto op : linalgOps) - op.erase(); + if (op.hasBufferSemantics()) + op.erase(); } }; } // namespace