[mlir][Linalg] Extend tile+fuse to work on Linalg operation on tensors.
Differential Revision: https://reviews.llvm.org/D93086
parent 3317b38ef8
commit 01defcc8d7
@@ -247,8 +247,9 @@ private:
   // Uses std::pair to keep operations and view together and avoid usage errors
   // related to src/dst and producer/consumer terminology in the context of
   // dependences.
-  void addDependenceElem(DependenceType dt, OpOperand *indexingOpView,
-                         OpOperand *dependentOpView);
+  void addDependenceElem(DependenceType dt,
+                         LinalgDependenceGraphElem::OpView indexingOpView,
+                         LinalgDependenceGraphElem::OpView dependentOpView);

   /// Implementation detail for findCoveringxxx.
   SmallVector<Operation *, 8>
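Note on the API change above: a dependence endpoint is no longer always an
OpOperand * (a buffer operand); with tensor semantics it can also be the Value
returned by the producer. The patch funnels both cases through
LinalgDependenceGraphElem::OpView, with static getValue/getOwner helpers used
in the next hunk. The declaration of OpView is not part of this diff; a
minimal sketch, assuming it is a PointerUnion over the two endpoint kinds:

    // Sketch only; the actual declaration is not shown in this diff.
    using OpView = llvm::PointerUnion<OpOperand *, Value>;

    // Resolve the SSA value behind either endpoint kind.
    static Value getValue(OpView view) {
      if (auto *operand = view.dyn_cast<OpOperand *>())
        return operand->get();  // buffer case: the value held by the operand
      return view.get<Value>(); // tensor case: the result value itself
    }

    // Resolve the operation owning either endpoint kind.
    static Operation *getOwner(OpView view) {
      if (auto *operand = view.dyn_cast<OpOperand *>())
        return operand->getOwner();
      return view.get<Value>().cast<OpResult>().getOwner();
    }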
@@ -113,18 +113,21 @@ LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
   }
 }

-void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
-                                              OpOperand *indexingOpView,
-                                              OpOperand *dependentOpView) {
+void LinalgDependenceGraph::addDependenceElem(
+    DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView,
+    LinalgDependenceGraphElem::OpView dependentOpView) {
   LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
-                    << indexingOpView->get() << " @"
-                    << indexingOpView->getOperandNumber() << ") -> \n\t\t("
-                    << dependentOpView->get() << " @"
-                    << dependentOpView->getOperandNumber() << ")");
-  dependencesFromGraphs[dt][indexingOpView->getOwner()].push_back(
-      LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
-  dependencesIntoGraphs[dt][dependentOpView->getOwner()].push_back(
-      LinalgDependenceGraphElem{indexingOpView, dependentOpView, dt});
+                    << LinalgDependenceGraphElem::getValue(indexingOpView)
+                    << " @) -> \n\t\t("
+                    << LinalgDependenceGraphElem::getValue(dependentOpView)
+                    << " @)");
+  dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)]
+      .push_back(
+          LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
+  dependencesIntoGraphs[dt]
+                       [LinalgDependenceGraphElem::getOwner(dependentOpView)]
+      .push_back(LinalgDependenceGraphElem{
+          indexingOpView, dependentOpView, dt});
 }

 LinalgDependenceGraph::dependence_range
@@ -158,6 +161,18 @@ LinalgDependenceGraph::getDependencesInto(
 }

 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
+  if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
+    for (OpOperand &dstOpOperand : dst.getInputOpOperands()) {
+      // Check if the operand is defined by the src.
+      auto definingOp = dstOpOperand.get().getDefiningOp<LinalgOp>();
+      if (definingOp && definingOp == src)
+        addDependenceElem(DependenceType::RAW, dstOpOperand.get(),
+                          &dstOpOperand);
+    }
+    return;
+  }
+  assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
+         "unhandled dependence tracking for mixed buffer/tensor operations");
   for (OpOperand *srcOpOperand : src.getOutputBuffersOpOperands()) { // W
     // RAW graph
     for (OpOperand *dstOpOperand : dst.getInputBuffersOpOperands()) // R
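With tensor semantics, the hunk above reduces dependence tracking to a plain
SSA use-def check: an input operand of dst whose defining op is src yields a
RAW edge from the produced tensor value to the consuming operand. The core
test, extracted into a self-contained illustrative helper (the helper itself
is not part of the patch):

    // Illustration only: does `consumerOperand` read a tensor produced by
    // `producer`? For Linalg ops on tensors this is exactly a RAW dependence.
    static bool isTensorRAWDependence(LinalgOp producer,
                                      OpOperand &consumerOperand) {
      auto definingOp = consumerOperand.get().getDefiningOp<LinalgOp>();
      return definingOp && definingOp == producer;
    }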
@@ -25,6 +25,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -348,13 +349,15 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
   return true;
 }

 /// For `consumer` with buffer semantics, find the Linalg operation on buffers
 /// that is the last writer of `consumerOpOperand`. For now the fusable
 /// dependence is returned as an instance of the `dependenceGraph`.
 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
 findFusableProducer(OpOperand &consumerOpOperand,
                     const LinalgDependenceGraph &dependenceGraph) {
-  LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
-  // Note that buffer semantics implies that the dependence will only be from
-  // OpOperand -> OpOperand.
-  assert(consumerOp.hasBufferSemantics() && "revisit usage of shaped operand");
+  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
+  if (!consumerOp)
+    return {};

   // Only consider RAW and WAW atm.
   for (auto depType : {
@@ -378,18 +381,21 @@ findFusableProducer(OpOperand &consumerOpOperand,
       LLVM_DEBUG(llvm::dbgs()
                  << "\n"
                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
-                 << "producer: " << *dependence.getDependentOp() << " view: "
-                 << dependence.getDependentValue() << " output index: "
-                 << (dependence.getDependentOpViewOperandNum().getValue() -
-                     producer.getNumInputs())
-                 << "\n");
+                 << "producer: " << *dependence.getDependentOp()
+                 << " view: " << dependence.getDependentValue() << "\n");

-      // Simple fusability checks.
-      if (!isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
-                         producer))
-        continue;
-
-      return dependence;
+      // If the producer and consumer have tensor semantics, the only dependence
+      // between them is through a RAW dependence and they are fusable by
+      // construction. For buffer semantics need additional checks.
+      if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
+          isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
+                        producer))
+        return dependence;
+      if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
+        assert(dependence.dependenceType ==
+               LinalgDependenceGraph::DependenceType::RAW);
+        return dependence;
+      }
     }
   }
   return {};
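This dispatch is the crux of the change: buffer producer/consumer pairs still
have to pass isFusableInto, while tensor pairs are fusable by construction,
because the only edge the dependence graph records for them is the RAW edge on
the SSA value (a tensor is never written in place, so WAR/WAW cannot arise).
A hypothetical condensation of the decision, for illustration only:

    // Not part of the patch; `passesBufferChecks` stands in for the
    // isFusableInto(...) result computed above.
    static bool canFuseDependence(LinalgOp producer, LinalgOp consumer,
                                  bool passesBufferChecks,
                                  LinalgDependenceGraph::DependenceType depType) {
      if (producer.hasBufferSemantics() && consumer.hasBufferSemantics())
        return passesBufferChecks;
      if (producer.hasTensorSemantics() && consumer.hasTensorSemantics())
        return depType == LinalgDependenceGraph::DependenceType::RAW;
      return false; // mixed buffer/tensor pairs are not fused
    }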
@@ -439,6 +445,10 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,

 /// Walk back use-def chain through scf::For yields.
 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
+
+// TODO(ravishankarm, ntv): This can be moved into the dependence graphs
+// dependence tracking since the dependence tracking is similar to what is done
+// w.r.t to buffers.
 static void getProducerOfTensor(Value tensor, OpResult &opResult) {
   if (!tensor.getType().isa<RankedTensorType>())
     return;
@@ -722,6 +732,45 @@ collectFusableLoops(ArrayRef<LinalgOp> ops,
   return fusableLoops;
 }

+// /// For `consumer` with tensor semantics, find the Linalg operation on
+// tensors
+// /// producer the operand at position `consumerIdx`. This is a simple use-def
+// /// chain using the SSA value, but returned as an element of the
+// /// `LinalgDependenceGraphElem` to use the same analysis for both tensors and
+// /// buffers.
+// static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
+// findFusableProducerForTensorOp(OpOperand &consumerOpOperand) {
+//   // For now only looking for cases where the operand is produced by another
+//   // Linalg structured operation.
+//   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
+//   if (!consumer || !consumer.hasTensorSemantics())
+//     return llvm::None;
+//   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
+//   Value value = consumerOpOperand.get();
+//   if (auto linalgOp = value.getDefiningOp<LinalgOp>()) {
+//     return LinalgDependenceGraph::LinalgDependenceGraphElem{
+//         &(linalgOp
+//               .getOutputOpOperands()[value.cast<OpResult>().getResultNumber()]),
+//         &(consumer.getInputOpOperands()[consumerIdx]),
+//         LinalgDependenceGraph::DependenceType::RAW};
+//   }
+//   return llvm::None;
+// }
+
+// static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
+// findFusableProducer(OpOperand &consumerOpOperand,
+//                     const LinalgDependenceGraph &dependenceGraph) {
+//   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
+//   if (!consumer)
+//     return llvm::None;
+//   if (consumer.hasBufferSemantics())
+//     return findFusableProducerForBufferOp(consumerOpOperand,
+//                                           dependenceGraph);
+//   if (consumer.hasTensorSemantics())
+//     return findFusableProducerForTensorOp(consumerOpOperand);
+//   return llvm::None;
+// }
+
 /// Find all dependences that are fusable.
 FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
     ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
@@ -798,7 +847,7 @@ static Optional<TiledLinalgOp> tileRootOperation(
 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
 /// `tiledOp`.
 static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
+fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp,
                ArrayRef<LinalgOp> fusionCandidates,
                const FusableOpDependencesTy &fusableDependences,
                const std::set<unsigned> &fusedLoops) {
@@ -812,9 +861,33 @@ fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
   }

   SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
+  DenseMap<Operation *, LinalgOp> origOpToFusedOp;
+  origOpToFusedOp[rootOp.getOperation()] = tiledOp;
   for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
-    LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
+    LinalgOp origOp = candidate.value();
+    LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges);
+    origOpToFusedOp[origOp.getOperation()] = fusedOp;
     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
+    // If the producer consumer operations are linalg operations on tensors, the
+    // dependence is due to value produced (as a return tensor) by the producer
+    // and used in the consumer. The returned value of the fused op needs to be
+    // made the operand of the tiled/fused consumer operation. By construction
+    // the value returned by the producer is the value used by the consumer.
+    for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
+      if (origOp.hasTensorSemantics() &&
+          dependence.dependenceType ==
+              LinalgDependenceGraph::DependenceType::RAW) {
+        unsigned resultIndex =
+            dependence.getDependentOpViewResultNum().getValue();
+        LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
+        if (!consumer)
+          continue;
+        Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
+        consumer.getOperation()->setOperand(
+            dependence.getIndexingOpViewOperandNum().getValue(),
+            replacementValue);
+      }
+    }
     builder.setInsertionPoint(fusedOp);
   }
   return fusedOps;
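The new loop above does the tensor-specific bookkeeping: once a producer is
tiled and fused it yields a fresh tensor result, and that result has to replace
the original operand of the already-rewritten consumer (looked up via
origOpToFusedOp). The rewiring step, as a hypothetical standalone helper:

    // Not part of the patch: route result `resultIndex` of the fused producer
    // into operand `operandIndex` of the fused consumer, mirroring the
    // setOperand call in the hunk above.
    static void rewireTensorRAW(LinalgOp fusedProducer, LinalgOp fusedConsumer,
                                unsigned resultIndex, unsigned operandIndex) {
      Value replacement = fusedProducer.getOperation()->getResult(resultIndex);
      fusedConsumer.getOperation()->setOperand(operandIndex, replacement);
    }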
@@ -828,15 +901,17 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
   if (ops.size() < 2)
     return llvm::None;
   LinalgOp rootOp = ops.back();
-  for (auto op : enumerate(ops)) {
-    // TODO: Nothing in the fusion of sequence of ops is specific to
-    // buffers. This check can be removed after it is tested on tensors.
-    LinalgOp linalgOp = op.value();
-    if (!linalgOp.hasBufferSemantics()) {
-      linalgOp.emitRemark("tile and fuse only tested for buffer operation");
+  if (!llvm::all_of(
+          ops,
+          [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
+      !llvm::all_of(ops, [](LinalgOp linalgOp) {
+        return linalgOp.hasTensorSemantics();
+      })) {
+    rootOp.emitError(
+        "unable to fuse operations that have tensor semantics with operations "
+        "that have buffer semantics and viceversa.");
     return llvm::None;
   }
-  }
   // TODO: Support interchange with tile + fuse. This might actually help do
   // better fusion.
   if (!tilingOptions.interchangeVector.empty()) {
@@ -877,8 +952,9 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());

   // Fuse the other operations into the fused inter-tile loops produced above.
-  ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
+  ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(),
                                       fusableDependences, ret.fusedLoopDims);

   return ret;
 }
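Taken together, the driver now accepts an all-tensor list of ops as well as an
all-buffer one. A hedged usage sketch, mirroring the test pass updated in the
last hunk of this commit (the tile sizes are made up for the example):

    // Build the dependence graph over the candidate ops, then tile the last
    // op and fuse the earlier producers into its inter-tile loops.
    Aliases aliases;
    LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
    OpBuilder builder(funcOp.getContext());
    Optional<TiledAndFusedLinalgOps> tiledAndFused = tileAndFuseLinalgOps(
        builder, linalgOps, dependenceGraph,
        LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(
            LinalgTilingLoopType::Loops));
    // On tensors, the fused computation's result comes out of the outermost
    // scf.for, so uses of the original root op are redirected to it.
    if (tiledAndFused && linalgOps.back().hasTensorSemantics())
      linalgOps.back().getOperation()->replaceAllUsesWith(
          tiledAndFused->fusedLoops.front());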
@@ -131,3 +131,115 @@ module {
 // CHECK: scf.yield
 // CHECK: }

+// -----
+
+module {
+  func @tensor_op_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                         %arg2: tensor<?x?xf32>, %arg3: tensor<?xf32>)
+    -> tensor<?x?xf32> {
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %1 = dim %0, %c0 : tensor<?x?xf32>
+    %2 = dim %0, %c1 : tensor<?x?xf32>
+    %3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
+    %4 = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%0, %arg3 : tensor<?x?xf32>, tensor<?xf32>)
+      outs(%3 : tensor<?x?xf32>) {
+    ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
+      %5 = addf %arg4, %arg5 : f32
+      linalg.yield %5 : f32
+    } -> tensor<?x?xf32>
+    return %4 : tensor<?x?xf32>
+  }
+}
+// CHECK-LABEL: func @tensor_op_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[R0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG5:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[R1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG7:.+]] = %[[ARG5]]) -> (tensor<?x?xf32>) {
+// CHECK-DAG: %[[STARG3:.+]] = subtensor %[[ARG3]]
+// CHECK-DAG: %[[STARG7:.+]] = subtensor %[[ARG7]]
+// CHECK-DAG: %[[STARG0:.+]] = subtensor %[[ARG0]]
+// CHECK-DAG: %[[STARG1:.+]] = subtensor %[[ARG1]]
+// CHECK-DAG: %[[STARG2:.+]] = subtensor %[[ARG2]]
+// CHECK: %[[T0:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[STARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T1:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[T0:.+]], %[[STARG3]] : tensor<?x?xf32>, tensor<?xf32>)
+// CHECK-SAME: outs(%[[STARG7]] : tensor<?x?xf32>)
+// CHECK: %[[RESULT:.+]] = subtensor_insert %[[T1]] into %[[ARG7]]
+// CHECK: scf.yield %[[RESULT]]
+// CHECK: }
+// CHECK: scf.yield %[[R1]]
+// CHECK: }
+// CHECK: return %[[R0]]
+
+// -----
+
+module {
+  func @tensor_matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                             %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
+                             %arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>,
+                             %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
+    %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
+    %2 = linalg.matmul ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
+    return %2 : tensor<?x?xf32>
+  }
+}
+// CHECK-LABEL: func @tensor_matmul_fusion(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] =
+// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[N3:.+]] = dim %[[ARG8]], %[[C1]]
+// CHECK: %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0]
+// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N3]]]
+// CHECK: %[[N2:.+]] = dim %[[ARG3]], %[[C1]]
+// CHECK: %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[STARG3:.+]] = subtensor %[[ARG3]][0, 0]
+// CHECK-SAME: [%[[N1]], %[[N2]]]
+// CHECK: %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0]
+// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N2]]]
+// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N0]]]
+// CHECK: %[[STARG1:.+]] = subtensor %[[ARG1]][0, 0]
+// CHECK-SAME: [%[[N0]], %[[N1]]]
+// CHECK: %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
+// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N1]]]
+// CHECK: %[[T0:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]]
+// CHECK-SAME: ) outs(%[[STARG2]] : tensor<?x?xf32>)
+// CHECK: %[[T1:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[T0]], %[[STARG3]]
+// CHECK-SAME: ) outs(%[[STARG4]] : tensor<?x?xf32>)
+// CHECK: %[[T2:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[T1]], %[[ARG5]]
+// CHECK-SAME: ) outs(%[[STARG6]] : tensor<?x?xf32>)
+// CHECK: %[[R1:.+]] = subtensor_insert %[[T2]]
+// CHECK-SAME: into %[[ARG8]][%[[IV0]], %[[C0]]]
+// CHECK: scf.yield %[[R1]]
+// CHECK: }
+// CHECK: return %[[R0]]
+// CHECK: }
@@ -226,13 +226,22 @@ struct TestLinalgTileAndFuseSequencePass
     Aliases aliases;
     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
    OpBuilder builder(funcOp.getContext());
+    linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
+    if (llvm::all_of(linalgOps, [](LinalgOp linalgOp) {
+          return linalgOp.hasTensorSemantics();
+        }))
+      loopType = LinalgTilingLoopType::Loops;
     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
         builder, linalgOps, dependenceGraph,
-        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
-            LinalgTilingLoopType::ParallelLoops));
+        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
     if (!tileAndFuseOps)
       return signalPassFailure();
+    if (linalgOps.back().hasTensorSemantics()) {
+      linalgOps.back().getOperation()->replaceAllUsesWith(
+          tileAndFuseOps->fusedLoops.front());
+    }
     for (auto op : linalgOps)
-      op.erase();
+      if (op.hasBufferSemantics())
+        op.erase();
   }
 };