[mlir] Add support for fusion into TiledLoopOp.

Differential Revision: https://reviews.llvm.org/D102722
Alexander Belyaev 2021-05-21 18:13:09 +02:00
parent eaaf7a6a09
commit 9ecc8178d7
5 changed files with 330 additions and 25 deletions


@@ -584,6 +584,15 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
];
let extraClassDeclaration = [{
/// Number of loops
unsigned getNumLoops() { return step().size(); }
/// Number of input operands
unsigned getNumInputs() { return inputs().size(); }
/// Number of output operands
unsigned getNumOutputs() { return outputs().size(); }
/// Number of operands controlling the loop: lbs, ubs, steps
unsigned getNumControlOperands() { return 3 * getNumLoops(); }
@@ -597,7 +606,6 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
return getBody()->getArguments().take_back(outputs().size());
}
void setLowerBounds(ValueRange lowerBounds) {
unsigned numLoops = getNumLoops();
assert(lowerBounds.size() == numLoops &&
@@ -622,6 +630,16 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
setOperand(pos, steps[i]);
}
/// Block argument that corresponds to the `input` or `output` operand.
BlockArgument getTiedBlockArgument(OpOperand& operand) {
auto operandIndex = operand.getOperandNumber();
assert(
operandIndex >= getNumControlOperands() &&
operandIndex < getNumOperands() &&
"tied block arg is defined only for `input` and `output` arguments");
return getBody()->getArgument(operandIndex - 2 * getNumLoops());
}
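The `- 2 * getNumLoops()` above follows from the operand and block-argument layout of `linalg.tiled_loop`: the operands are ordered as lower bounds, upper bounds and steps (the 3 * numLoops control operands), then `inputs`, then `outputs`, while the entry block takes only the induction variables followed by one argument per input and output. For example (a hypothetical configuration), a loop with two induction variables, one input and one output has eight operands; the input operand sits at index 6 and maps to block argument 6 - 2 * 2 = 2, the first argument after the two induction variables.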
/// Result that corresponds to the `outputs` argument of tensor type.
OpResult getTiedOpResult(OpOperand& opOperand) {
// No result can correspond to a memref argument.
@@ -642,7 +660,76 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
return getOperation()->getResult(tensorId);
}
unsigned getNumLoops() { return step().size(); }
/// Append `operand` to the `input` arguments.
OpOperand& appendInputOperand(OpBuilder& builder, Value operand) {
int numLoops = getNumLoops();
int numInputs = getNumInputs();
int numOutputs = getNumOutputs();
getOperation()->insertOperands(getNumControlOperands() + numInputs,
operand);
getBody()->insertArgument(numLoops + numInputs, operand.getType());
getOperation()->setAttr(
TiledLoopOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr(
{numLoops, numLoops, numLoops, numInputs + 1, numOutputs}));
return getOperation()->getOpOperand(getNumControlOperands() + numInputs);
}
/// Append `operand` to the `output` arguments.
OpOperand& appendOutputOperand(OpBuilder& builder, Value operand) {
int numLoops = getNumLoops();
int numInputs = getNumInputs();
int numOutputs = getNumOutputs();
getOperation()->insertOperands(
getNumControlOperands() + numInputs + numOutputs, operand);
getBody()->insertArgument(numLoops + numInputs + numOutputs,
operand.getType());
getOperation()->setAttr(
TiledLoopOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr(
{numLoops, numLoops, numLoops, numInputs, numOutputs + 1}));
return getOperation()->getOpOperand(getNumControlOperands() + numInputs +
numOutputs);
}
/// Erase `operand` from the `input` or `output` arguments.
void eraseOperand(OpBuilder& builder, OpOperand& operand) {
int numInputs = getNumInputs();
int numLoops = getNumLoops();
int numOutputs = getNumOutputs();
int numControlOperands = getNumControlOperands();
auto operandIndex = operand.getOperandNumber();
assert(operandIndex >= numControlOperands &&
operandIndex < getNumOperands() &&
"Can erase only `input` or `output` operand");
if (operandIndex >= numControlOperands + numInputs)
--numOutputs;
else
--numInputs;
getOperation()->eraseOperand(operandIndex);
getBody()->eraseArgument(operandIndex - 2 * numLoops);
getOperation()->setAttr(
TiledLoopOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr(
{numLoops, numLoops, numLoops, numInputs, numOutputs}));
}
OpOperand* findInputOperand(Value value) {
OperandRange::iterator it = llvm::find(inputs(), value);
if (it == inputs().end()) return nullptr;
return it.getBase();
}
OpOperand* findOutputOperand(Value value) {
OperandRange::iterator it = llvm::find(outputs(), value);
if (it == outputs().end()) return nullptr;
return it.getBase();
}
}];
let hasCanonicalizer = 1;
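Taken together, these helpers let a transformation grow and shrink the loop's `ins`/`outs` sections while keeping the operand segment sizes attribute and the block arguments in sync. A minimal C++ sketch of the intended usage (a fragment, not part of the patch; it assumes `tiledLoop` is a `TiledLoopOp`, `b` an `OpBuilder`, and `producerOutput` a `Value` that should become visible inside the loop body):

  // Reuse the operand if the value is already passed into the loop; otherwise
  // append it to the `ins` section, which also inserts a block argument and
  // updates the operand segment sizes.
  OpOperand *operand = tiledLoop.findInputOperand(producerOutput);
  if (!operand)
    operand = &tiledLoop.appendInputOperand(b, producerOutput);
  // The tied block argument is what code cloned into the loop body should use.
  BlockArgument bbArg = tiledLoop.getTiedBlockArgument(*operand);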


@@ -107,6 +107,66 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
llvm_unreachable("Expect to be able to extract a shape defining loop range");
}
// Return tiled operands for the fused producer op. When fusing into
// `linalg.tiled_loop`, the `input` and `output` arguments of the loop have to
// be updated accordingly.
// Each input tensor of the producer op has to be added to `inputs` of the
// `tiled_loop` if it is not present there already. Each output tensor has to
// be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending
// on whether the corresponding result is an input or an output to the loop.
//
// NOTE: This way of updating the arguments of the `tiled_loop` assumes that the
// intermediate result is not used by any other operation but the consumer. A
// more generic way is to append all missing output tensors of the producer to
// the tiled loop outputs, which changes the number of loop results, since the
// intermediate results would also have to be added to `linalg.yield`. A
// canonicalization pass could then move the unused output args of the
// `tiled_loop` to the `input` section.
static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
if (!tiledLoop)
return llvm::to_vector<4>(producer.getShapedOperands());
SmallVector<Value, 4> tiledOperands;
assert(producer.hasTensorSemantics() &&
"only fusion on tensors is currently supported for TiledLinalgOp");
for (auto producerInput : producer.getInputTensors()) {
OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
if (addedInput == nullptr)
addedInput = &tiledLoop.appendInputOperand(b, producerInput);
BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
tiledOperands.push_back(addedBlockArg);
}
for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
Value producerOutput = en.value();
Value result = producer->getResult(en.index());
OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
"The result should be present in `input` or `output` args of "
"`tiled_loop");
bool isInput = resultInputOperand;
int opNumber = isInput ? resultInputOperand->getOperandNumber()
: resultOutputOperand->getOperandNumber();
OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
if (addedOutput == nullptr)
addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
: &tiledLoop.appendOutputOperand(b, producerOutput);
OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand);
resultOperandBlockArg.replaceAllUsesWith(addedBlockArg);
tiledLoop.eraseOperand(b, resultOperand);
tiledOperands.push_back(addedBlockArg);
}
return tiledOperands;
}
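To make the rewrite concrete, here is a hand-written, simplified IR sketch of fusing a producer whose result feeds the loop's `ins` (all SSA names hypothetical; this is not pass output, compare the TLOOP checks in the test file below):

  // Before fusion: the producer runs outside the loop and its result %ab is a
  // loop input.
  %ab = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                      outs(%ab_init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %res = linalg.tiled_loop (%i) = (%c0) to (%d0) step (%c32)
             ins (%ab_ = %ab : tensor<?x?xf32>)
             outs (%out_ = %out : tensor<?x?xf32>) { ... }

  // After fusion: %a, %b and %ab_init are appended to `ins`, uses of the block
  // argument tied to %ab are redirected to the one tied to %ab_init, the %ab
  // operand is erased, and the matmul is cloned into the body to operate on
  // subtensors of the new block arguments.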
/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
/// provides the loop range information for the fused loops. The rest are
/// obtained from the producer itself, since they are not tiled + fused.
@@ -143,8 +203,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
clonedShapes.reserve(producer.getNumShapedOperands());
// Compute subranges for all tensor input/output operands.
auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
clonedShapes.append(makeTiledShapes(b, loc, producer, tiledOperands, ivs,
clonedShapes.append(makeTiledShapes(b, loc, producer,
getTiledOperands(b, producer), ivs,
tileSizes, sizeBounds));
// Append the other operands.
@@ -808,7 +868,7 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
origOpToFusedOp[origOp.getOperation()] = fusedOp;
fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
// Prepare the b for the next insertion point.
// Prepare the builder for the next insertion point.
auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
if (!origOp.hasTensorSemantics())
continue;
@@ -844,16 +904,16 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
// 2. encode destructive updates that may be inplaceable by bufferization.
// To keep the second type of information while letting the unfused op die
// unused, we need to forward the producer output operand.
for (auto &operand :
cast<scf::ForOp>(tiledLinalgOp.loops.front()).getIterOpOperands())
if (auto opResult = operand.get().dyn_cast<OpResult>())
if (opResult.getOwner() == origOp)
operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
for (auto &operand : forOp.getIterOpOperands())
if (auto opResult = operand.get().dyn_cast<OpResult>())
if (opResult.getOwner() == origOp)
operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
}
}
return fusedOps;
}
template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
@@ -928,11 +988,9 @@ mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
const LinalgTilingOptions &tilingOptions) {
switch (tilingOptions.loopType) {
case LinalgTilingLoopType::Loops:
return tileAndFuseLinalgOpsImpl<scf::ForOp>(b, ops, dependenceGraph,
tilingOptions);
case LinalgTilingLoopType::ParallelLoops:
return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(b, ops, dependenceGraph,
tilingOptions);
case LinalgTilingLoopType::TiledLoops:
return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
default:;
}
return llvm::None;
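For clients, the new path is selected purely through the existing tiling options. A hedged sketch (not part of this commit; it assumes `using namespace mlir; using namespace mlir::linalg;`, an `OpBuilder` `b`, the candidate `ops`, and their `dependenceGraph` are already in scope):

  LinalgTilingOptions options;
  options.setTileSizes({32, 64, 16})
      .setLoopType(LinalgTilingLoopType::TiledLoops);
  Optional<TiledAndFusedLinalgOps> tiledAndFused =
      tileAndFuseLinalgOps(b, ops, dependenceGraph, options);
  if (!tiledAndFused)
    return;  // Tiling/fusion did not apply; nothing was rewritten.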


@@ -1,15 +1,16 @@
// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
module {
func @matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
%arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
%arg4: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN1> <N1xN2>
%1 = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN2> <N2xN3>
return %1 : tensor<?x?xf32>
func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
%AB_init: tensor<?x?xf32>, %C: tensor<?x?xf32>,
%ABC_init: tensor<?x?xf32>) -> tensor<?x?xf32> {
%AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%AB_init : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN1> <N1xN2>
%ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%ABC_init : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN2> <N2xN3>
return %ABC : tensor<?x?xf32>
}
}
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (32, d0 - d1)>
@@ -90,6 +91,64 @@ module {
// CHECK: }
// CHECK: return %[[RESULT]]
// TLOOP-LABEL: func @matmul_fusion(
// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
// TLOOP-SAME: %[[AB_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
// TLOOP-SAME: %[[C:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
// TLOOP-SAME: %[[ABC_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// TLOOP: %[[C32:.*]] = constant 32 : index
// TLOOP: %[[C64:.*]] = constant 64 : index
// TLOOP: %[[C16:.*]] = constant 16 : index
// TLOOP: %[[C0:.*]] = constant 0 : index
// TLOOP: %[[C1:.*]] = constant 1 : index
// TLOOP: %[[DIM_A0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
// TLOOP: %[[ABC:.*]] = linalg.tiled_loop (%[[IV0:.*]]) = (%[[C0]])
// TLOOP-SAME: to (%[[DIM_A0]]) step (%[[C32]])
// TLOOP-SAME: ins (%[[C_:.*]] = %[[C]]: tensor<?x?xf32>,
// TLOOP-SAME: %[[A_:.*]] = %[[A]]: tensor<?x?xf32>,
// TLOOP-SAME: %[[B_:.*]] = %[[B]]: tensor<?x?xf32>,
// TLOOP-SAME: %[[AB_INIT_:.*]] = %[[AB_INIT]]: tensor<?x?xf32>)
// TLOOP-SAME: outs (%[[ABC_INIT_:.*]] = %[[ABC_INIT]]: tensor<?x?xf32>) {
// TLOOP: %[[ABC_INIT_SUB:.*]] = subtensor %[[ABC_INIT_]][%[[IV0]], 0]
// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
// TLOOP: %[[AB_INIT_SUB:.*]] = subtensor %[[AB_INIT_]][%[[IV0]], 0]
// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
// TLOOP-SAME: ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B_]], %[[C1]] : [[TY]]
// TLOOP: %[[DIM_C_1:.*]] = memref.dim %[[C_]], %[[C1]] : [[TY]]
// TLOOP: %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) =
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
// TLOOP-SAME: step (%[[C64]], %[[C16]])
// TLOOP-SAME: ins (%[[AB_SUB_:.*]] = %[[AB_SUB]]: [[TY]],
// TLOOP-SAME: %[[C__:.*]] = %[[C_]]: [[TY]])
// TLOOP-SAME: outs (%[[ABC_INIT_SUB_:.*]] = %[[ABC_INIT_SUB]]: [[TY]])
// TLOOP-SAME: iterators["parallel", "reduction"] {
// TLOOP: %[[AB_SUB_SUB:.*]] = subtensor %[[AB_SUB_]][0, %[[IV2]]]
// TLOOP: %[[C__SUB:.*]] = subtensor %[[C__]][%[[IV2]], %[[IV1]]]
// TLOOP: %[[ABS_INIT_SUB_SUB:.*]] = subtensor %[[ABC_INIT_SUB_]][0, %[[IV1]]]
// TLOOP: %[[ABC_SUB_SUB:.*]] = linalg.matmul
// TLOOP-SAME: ins(%[[AB_SUB_SUB]], %[[C__SUB]] : [[TY]], [[TY]])
// TLOOP-SAME: outs(%[[ABS_INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
// TLOOP: %[[RES0:.*]] = subtensor_insert %[[ABC_SUB_SUB]]
// TLOOP-SAME: into %[[ABC_INIT_SUB_]][0, %[[IV1]]]
// TLOOP: linalg.yield %[[RES0]] : [[TY]]
// TLOOP: }
// TLOOP: %[[RES1:.*]] = subtensor_insert %[[ABC_SUB_]] into %[[ABC_INIT_]][%[[IV0]], 0]
// TLOOP: linalg.yield %[[RES1]] : [[TY]]
// TLOOP: }
// TLOOP: return %[[ABC]] : [[TY]]
// -----
module {
@@ -144,6 +203,48 @@ module {
// CHECK: scf.yield %[[YIELD]]
// CHECK: return %[[RESULT]]
// TLOOP-LABEL: func @matmul_plus_matmul
// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
// TLOOP-SAME: %[[AB:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// TLOOP: %[[C32:.*]] = constant 32 : index
// TLOOP: %[[C64:.*]] = constant 64 : index
// TLOOP: %[[C0:.*]] = constant 0 : index
// TLOOP: %[[C1:.*]] = constant 1 : index
// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
// TLOOP: %[[INIT:.*]] = linalg.init_tensor [%[[DIM_A_0]], %[[DIM_B_1]]]
// TLOOP: %[[RESULT:.*]] = linalg.tiled_loop (%[[IV0:.*]], %[[IV1:.*]]) =
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
// TLOOP-SAME: step (%[[C32]], %[[C64]])
// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]],
// TLOOP-SAME: %[[AB_:.*]] = %[[AB]]: [[TY]])
// TLOOP-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: [[TY]]) {
// TLOOP: %[[INIT_SUB:.*]] = subtensor %[[INIT_]][%[[IV0]], %[[IV1]]]
// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[IV1]]]
// TLOOP: %[[AB_SUB_INIT:.*]] = subtensor %[[AB_]][%[[IV0]], %[[IV1]]]
// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
// TLOOP-SAME: ins(%[[A_SUB]], %[[B_SUB]] : [[TY]], [[TY]])
// TLOOP-SAME: outs(%[[AB_SUB_INIT]] : [[TY]])
// TLOOP: %[[DOUBLE_AB:.*]] = linalg.generic
// TLOOP-SAME: ins(%[[AB_SUB]] : [[TY]]) outs(%[[INIT_SUB]] : [[TY]])
// TLOOP: %[[RESULT_SUB:.*]] = subtensor_insert
// TLOOP-SAME: %[[DOUBLE_AB:.*]] into %[[INIT_]][%[[IV0]], %[[IV1]]]
// TLOOP: linalg.yield %[[RESULT_SUB]] : [[TY]]
// TLOOP: }
// TLOOP: return %[[RESULT]] : [[TY]]
// -----
module {
@@ -174,3 +275,53 @@ module {
// CHECK: scf.yield %[[ST_MM]] : tensor<?x?xf32>
// CHECK: %[[MM:.*]] = subtensor_insert %[[ST_MM_RES]] into {{.*}}
// CHECK: scf.yield %[[MM]] : tensor<?x?xf32>
// TLOOP-LABEL: func @matmul_out_fusion(
// TLOOP-SAME: %[[OUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// TLOOP-DAG: %[[C0_F32:.*]] = constant 0.0
// TLOOP-DAG: %[[C32:.*]] = constant 32 : index
// TLOOP-DAG: %[[C64:.*]] = constant 64 : index
// TLOOP-DAG: %[[C16:.*]] = constant 16 : index
// TLOOP-DAG: %[[C0:.*]] = constant 0 : index
// TLOOP-DAG: %[[C1:.*]] = constant 1 : index
// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
// TLOOP-SAME: step (%[[C32]], %[[C64]])
// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]])
// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
// TLOOP: %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]]
// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0]
// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]]
// TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]])
// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
// TLOOP-SAME: %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]])
// TLOOP-SAME: iterators["reduction"] {
// TLOOP: %[[A_SUB_SUB:.*]] = subtensor %[[A_SUB_]][0, %[[K]]]
// TLOOP: %[[B_SUB_SUB:.*]] = subtensor %[[B_SUB_]][%[[K]], 0]
// TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul
// TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]]
// TLOOP: }
// TLOOP: %[[SUB_RESULT:.*]] = subtensor_insert %[[AB_SUB]]
// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]]
// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]]
// TLOOP: }
// TLOOP: return %[[AB]] : [[TY]]


@@ -278,6 +278,13 @@ void registerTestLinalgTensorFusionTransforms() {
"Test Linalg on tensor fusion transformation "
"patterns by applying them greedily.");
}
void registerTestLinalgTiledLoopFusionTransforms() {
PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops>>
testTiledLoopFusionTransformsPass(
"test-linalg-tiled-loop-fusion-transform-patterns",
"Test Linalg on tensor fusion transformation "
"patterns by applying them greedily.");
}
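This mirrors the existing registrations and is what the new TLOOP RUN line above exercises, e.g. `mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file` (with `%s` standing for the test input file).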
void registerTestLinalgGreedyFusion() {
PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
"test-linalg-greedy-fusion",


@@ -81,6 +81,7 @@ void registerTestLinalgElementwiseFusion();
void registerTestPushExpandingReshape();
void registerTestLinalgFusionTransforms();
void registerTestLinalgTensorFusionTransforms();
void registerTestLinalgTiledLoopFusionTransforms();
void registerTestLinalgGreedyFusion();
void registerTestLinalgHoisting();
void registerTestLinalgTileAndFuseSequencePass();
@@ -159,6 +160,7 @@ void registerTestPasses() {
test::registerTestPushExpandingReshape();
test::registerTestLinalgFusionTransforms();
test::registerTestLinalgTensorFusionTransforms();
test::registerTestLinalgTiledLoopFusionTransforms();
test::registerTestLinalgGreedyFusion();
test::registerTestLinalgHoisting();
test::registerTestLinalgTileAndFuseSequencePass();