forked from OSchip/llvm-project
[mlir] Add support for fusion into TiledLoopOp.
Differential Revision: https://reviews.llvm.org/D102722
This commit is contained in:
parent
eaaf7a6a09
commit
9ecc8178d7
|
@ -584,6 +584,15 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
|
|||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Number of loops
|
||||
unsigned getNumLoops() { return step().size(); }
|
||||
|
||||
/// Number of input operands
|
||||
unsigned getNumInputs() { return inputs().size(); }
|
||||
|
||||
/// Number of output operands
|
||||
unsigned getNumOutputs() { return outputs().size(); }
|
||||
|
||||
/// Number of operands controlling the loop: lbs, ubs, steps
|
||||
unsigned getNumControlOperands() { return 3 * getNumLoops(); }
|
||||
|
||||
|
@ -597,7 +606,6 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
|
|||
return getBody()->getArguments().take_back(outputs().size());
|
||||
}
|
||||
|
||||
|
||||
void setLowerBounds(ValueRange lowerBounds) {
|
||||
unsigned numLoops = getNumLoops();
|
||||
assert(lowerBounds.size() == numLoops &&
|
||||
|
@ -622,6 +630,16 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
|
|||
setOperand(pos, steps[i]);
|
||||
}
|
||||
|
||||
/// Block argument that corresponds to the `input` or `output` operand.
|
||||
BlockArgument getTiedBlockArgument(OpOperand& operand) {
|
||||
auto operandIndex = operand.getOperandNumber();
|
||||
assert(
|
||||
operandIndex >= getNumControlOperands() &&
|
||||
operandIndex < getNumOperands() &&
|
||||
"tied block arg is defined only for `input` and `output` arguments");
|
||||
return getBody()->getArgument(operandIndex - 2 * getNumLoops());
|
||||
}
|
||||
|
||||
/// Result that corresponds to the `outputs` argument of tensor type.
|
||||
OpResult getTiedOpResult(OpOperand& opOperand) {
|
||||
// No result can correspond to a memref argument.
|
||||
|
@ -642,7 +660,76 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
|
|||
return getOperation()->getResult(tensorId);
|
||||
}
|
||||
|
||||
unsigned getNumLoops() { return step().size(); }
|
||||
/// Append `operand` to the `input` arguments.
|
||||
OpOperand& appendInputOperand(OpBuilder& builder, Value operand) {
|
||||
int numLoops = getNumLoops();
|
||||
int numInputs = getNumInputs();
|
||||
int numOutputs = getNumOutputs();
|
||||
|
||||
getOperation()->insertOperands(getNumControlOperands() + numInputs,
|
||||
operand);
|
||||
getBody()->insertArgument(numLoops + numInputs, operand.getType());
|
||||
getOperation()->setAttr(
|
||||
TiledLoopOp::getOperandSegmentSizeAttr(),
|
||||
builder.getI32VectorAttr(
|
||||
{numLoops, numLoops, numLoops, numInputs + 1, numOutputs}));
|
||||
return getOperation()->getOpOperand(getNumControlOperands() + numInputs);
|
||||
}
|
||||
|
||||
/// Append `operand` to the `output` arguments.
|
||||
OpOperand& appendOutputOperand(OpBuilder& builder, Value operand) {
|
||||
int numLoops = getNumLoops();
|
||||
int numInputs = getNumInputs();
|
||||
int numOutputs = getNumOutputs();
|
||||
|
||||
getOperation()->insertOperands(
|
||||
getNumControlOperands() + numInputs + numOutputs, operand);
|
||||
getBody()->insertArgument(numLoops + numInputs + numOutputs,
|
||||
operand.getType());
|
||||
getOperation()->setAttr(
|
||||
TiledLoopOp::getOperandSegmentSizeAttr(),
|
||||
builder.getI32VectorAttr(
|
||||
{numLoops, numLoops, numLoops, numInputs, numOutputs + 1}));
|
||||
return getOperation()->getOpOperand(getNumControlOperands() + numInputs +
|
||||
numOutputs);
|
||||
}
|
||||
|
||||
/// Erase `operand` from the `input` or `output` arguments.
|
||||
void eraseOperand(OpBuilder& builder, OpOperand& operand) {
|
||||
int numInputs = getNumInputs();
|
||||
int numLoops = getNumLoops();
|
||||
int numOutputs = getNumOutputs();
|
||||
int numControlOperands = getNumControlOperands();
|
||||
|
||||
auto operandIndex = operand.getOperandNumber();
|
||||
assert(operandIndex >= numControlOperands &&
|
||||
operandIndex < getNumOperands() &&
|
||||
"Can erase only `input` or `output` operand");
|
||||
|
||||
if (operandIndex >= numControlOperands + numInputs)
|
||||
--numOutputs;
|
||||
else
|
||||
--numInputs;
|
||||
|
||||
getOperation()->eraseOperand(operandIndex);
|
||||
getBody()->eraseArgument(operandIndex - 2 * numLoops);
|
||||
getOperation()->setAttr(
|
||||
TiledLoopOp::getOperandSegmentSizeAttr(),
|
||||
builder.getI32VectorAttr(
|
||||
{numLoops, numLoops, numLoops, numInputs, numOutputs}));
|
||||
}
|
||||
|
||||
OpOperand* findInputOperand(Value value) {
|
||||
OperandRange::iterator it = llvm::find(inputs(), value);
|
||||
if (it == inputs().end()) return nullptr;
|
||||
return it.getBase();
|
||||
}
|
||||
|
||||
OpOperand* findOutputOperand(Value value) {
|
||||
OperandRange::iterator it = llvm::find(outputs(), value);
|
||||
if (it == outputs().end()) return nullptr;
|
||||
return it.getBase();
|
||||
}
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
|
|
@ -107,6 +107,66 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
|
|||
llvm_unreachable("Expect to be able to extract a shape defining loop range");
|
||||
}
|
||||
|
||||
// Return tiled operands for the fused producer op. When fusing into
|
||||
// `linalg.tiled_loop` one has to update `input` and `output` arguments of the
|
||||
// loop correspondingly.
|
||||
// Each input tensor of the producer op has to be added to `inputs` of the
|
||||
// `tiled_loop` if it is not present there already. Each output tensor has to
|
||||
// be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending
|
||||
// on whether the correponding result is an input or an output to the loop.
|
||||
//
|
||||
// NOTE: This way of updating the arguments of the `tiled_loop` assumes that the
|
||||
// intermediate result is not used by any other operation but the consumer. A
|
||||
// more generic way is to append all missing output tensors of the producer to
|
||||
// the tiled loop outputs and hence modify the number of the results, since we
|
||||
// would need to add the intermediate results to `linalg.yield`. After that a
|
||||
// canonicalization pass would move the unused output args of the `tiled_loop`
|
||||
// to the `input` section.
|
||||
static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
|
||||
auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
|
||||
if (!tiledLoop)
|
||||
return llvm::to_vector<4>(producer.getShapedOperands());
|
||||
|
||||
SmallVector<Value, 4> tiledOperands;
|
||||
assert(producer.hasTensorSemantics() &&
|
||||
"only fusion on tensors is currently supported for TiledLinalgOp");
|
||||
|
||||
for (auto producerInput : producer.getInputTensors()) {
|
||||
OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
|
||||
if (addedInput == nullptr)
|
||||
addedInput = &tiledLoop.appendInputOperand(b, producerInput);
|
||||
BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
|
||||
tiledOperands.push_back(addedBlockArg);
|
||||
}
|
||||
for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
|
||||
Value producerOutput = en.value();
|
||||
|
||||
Value result = producer->getResult(en.index());
|
||||
OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
|
||||
OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
|
||||
assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
|
||||
"The result should be present in `input` or `output` args of "
|
||||
"`tiled_loop");
|
||||
|
||||
bool isInput = resultInputOperand;
|
||||
int opNumber = isInput ? resultInputOperand->getOperandNumber()
|
||||
: resultOutputOperand->getOperandNumber();
|
||||
|
||||
OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
|
||||
if (addedOutput == nullptr)
|
||||
addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
|
||||
: &tiledLoop.appendOutputOperand(b, producerOutput);
|
||||
|
||||
OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
|
||||
auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
|
||||
auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand);
|
||||
resultOperandBlockArg.replaceAllUsesWith(addedBlockArg);
|
||||
tiledLoop.eraseOperand(b, resultOperand);
|
||||
tiledOperands.push_back(addedBlockArg);
|
||||
}
|
||||
return tiledOperands;
|
||||
}
|
||||
|
||||
/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
|
||||
/// provides the loop range information for the fused loops. The rest are
|
||||
/// obtained from the producer itself, since they are not tiled + fused.
|
||||
|
@ -143,8 +203,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
|
|||
clonedShapes.reserve(producer.getNumShapedOperands());
|
||||
|
||||
// Compute subranges for all tensor input/output operands.
|
||||
auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
|
||||
clonedShapes.append(makeTiledShapes(b, loc, producer, tiledOperands, ivs,
|
||||
clonedShapes.append(makeTiledShapes(b, loc, producer,
|
||||
getTiledOperands(b, producer), ivs,
|
||||
tileSizes, sizeBounds));
|
||||
|
||||
// Append the other operands.
|
||||
|
@ -808,7 +868,7 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
|
|||
origOpToFusedOp[origOp.getOperation()] = fusedOp;
|
||||
fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
|
||||
|
||||
// Prepare the b for the next insertion point.
|
||||
// Prepare the builder for the next insertion point.
|
||||
auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
|
||||
if (!origOp.hasTensorSemantics())
|
||||
continue;
|
||||
|
@ -844,16 +904,16 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
|
|||
// 2. encode destructive updates that may be inplaceable by bufferization.
|
||||
// To keep the second type of information while letting the unfused op die
|
||||
// unused, we need to forward the producer output operand.
|
||||
for (auto &operand :
|
||||
cast<scf::ForOp>(tiledLinalgOp.loops.front()).getIterOpOperands())
|
||||
if (auto opResult = operand.get().dyn_cast<OpResult>())
|
||||
if (opResult.getOwner() == origOp)
|
||||
operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
|
||||
for (auto &operand : forOp.getIterOpOperands())
|
||||
if (auto opResult = operand.get().dyn_cast<OpResult>())
|
||||
if (opResult.getOwner() == origOp)
|
||||
operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
|
||||
}
|
||||
}
|
||||
return fusedOps;
|
||||
}
|
||||
|
||||
template <typename LoopType>
|
||||
static Optional<TiledAndFusedLinalgOps>
|
||||
tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
|
||||
const LinalgDependenceGraph &dependenceGraph,
|
||||
|
@ -928,11 +988,9 @@ mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
|
|||
const LinalgTilingOptions &tilingOptions) {
|
||||
switch (tilingOptions.loopType) {
|
||||
case LinalgTilingLoopType::Loops:
|
||||
return tileAndFuseLinalgOpsImpl<scf::ForOp>(b, ops, dependenceGraph,
|
||||
tilingOptions);
|
||||
case LinalgTilingLoopType::ParallelLoops:
|
||||
return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(b, ops, dependenceGraph,
|
||||
tilingOptions);
|
||||
case LinalgTilingLoopType::TiledLoops:
|
||||
return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
|
||||
default:;
|
||||
}
|
||||
return llvm::None;
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
|
||||
|
||||
module {
|
||||
func @matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
|
||||
%arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
|
||||
%arg4: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN1> <N1xN2>
|
||||
%1 = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
|
||||
ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN2> <N2xN3>
|
||||
return %1 : tensor<?x?xf32>
|
||||
func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
|
||||
%AB_init: tensor<?x?xf32>, %C: tensor<?x?xf32>,
|
||||
%ABC_init: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%AB_init : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN1> <N1xN2>
|
||||
%ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
|
||||
ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%ABC_init : tensor<?x?xf32>) -> tensor<?x?xf32> // <MxN2> <N2xN3>
|
||||
return %ABC : tensor<?x?xf32>
|
||||
}
|
||||
}
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (32, d0 - d1)>
|
||||
|
@ -90,6 +91,64 @@ module {
|
|||
// CHECK: }
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
// TLOOP-LABEL: func @matmul_fusion(
|
||||
// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
|
||||
// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
|
||||
// TLOOP-SAME: %[[AB_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
|
||||
// TLOOP-SAME: %[[C:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
|
||||
// TLOOP-SAME: %[[ABC_INIT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
|
||||
// TLOOP: %[[C32:.*]] = constant 32 : index
|
||||
// TLOOP: %[[C64:.*]] = constant 64 : index
|
||||
// TLOOP: %[[C16:.*]] = constant 16 : index
|
||||
// TLOOP: %[[C0:.*]] = constant 0 : index
|
||||
// TLOOP: %[[C1:.*]] = constant 1 : index
|
||||
|
||||
// TLOOP: %[[DIM_A0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
|
||||
|
||||
// TLOOP: %[[ABC:.*]] = linalg.tiled_loop (%[[IV0:.*]]) = (%[[C0]])
|
||||
// TLOOP-SAME: to (%[[DIM_A0]]) step (%[[C32]])
|
||||
// TLOOP-SAME: ins (%[[C_:.*]] = %[[C]]: tensor<?x?xf32>,
|
||||
// TLOOP-SAME: %[[A_:.*]] = %[[A]]: tensor<?x?xf32>,
|
||||
// TLOOP-SAME: %[[B_:.*]] = %[[B]]: tensor<?x?xf32>,
|
||||
// TLOOP-SAME: %[[AB_INIT_:.*]] = %[[AB_INIT]]: tensor<?x?xf32>)
|
||||
// TLOOP-SAME: outs (%[[ABC_INIT_:.*]] = %[[ABC_INIT]]: tensor<?x?xf32>) {
|
||||
|
||||
// TLOOP: %[[ABC_INIT_SUB:.*]] = subtensor %[[ABC_INIT_]][%[[IV0]], 0]
|
||||
// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
|
||||
// TLOOP: %[[AB_INIT_SUB:.*]] = subtensor %[[AB_INIT_]][%[[IV0]], 0]
|
||||
|
||||
// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
|
||||
// TLOOP-SAME: ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
|
||||
|
||||
// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B_]], %[[C1]] : [[TY]]
|
||||
// TLOOP: %[[DIM_C_1:.*]] = memref.dim %[[C_]], %[[C1]] : [[TY]]
|
||||
|
||||
// TLOOP: %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) =
|
||||
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
|
||||
// TLOOP-SAME: step (%[[C64]], %[[C16]])
|
||||
// TLOOP-SAME: ins (%[[AB_SUB_:.*]] = %[[AB_SUB]]: [[TY]],
|
||||
// TLOOP-SAME: %[[C__:.*]] = %[[C_]]: [[TY]])
|
||||
// TLOOP-SAME: outs (%[[ABC_INIT_SUB_:.*]] = %[[ABC_INIT_SUB]]: [[TY]])
|
||||
// TLOOP-SAME: iterators["parallel", "reduction"] {
|
||||
|
||||
// TLOOP: %[[AB_SUB_SUB:.*]] = subtensor %[[AB_SUB_]][0, %[[IV2]]]
|
||||
// TLOOP: %[[C__SUB:.*]] = subtensor %[[C__]][%[[IV2]], %[[IV1]]]
|
||||
// TLOOP: %[[ABS_INIT_SUB_SUB:.*]] = subtensor %[[ABC_INIT_SUB_]][0, %[[IV1]]]
|
||||
|
||||
// TLOOP: %[[ABC_SUB_SUB:.*]] = linalg.matmul
|
||||
// TLOOP-SAME: ins(%[[AB_SUB_SUB]], %[[C__SUB]] : [[TY]], [[TY]])
|
||||
// TLOOP-SAME: outs(%[[ABS_INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
|
||||
|
||||
// TLOOP: %[[RES0:.*]] = subtensor_insert %[[ABC_SUB_SUB]]
|
||||
// TLOOP-SAME: into %[[ABC_INIT_SUB_]][0, %[[IV1]]]
|
||||
// TLOOP: linalg.yield %[[RES0]] : [[TY]]
|
||||
// TLOOP: }
|
||||
// TLOOP: %[[RES1:.*]] = subtensor_insert %[[ABC_SUB_]] into %[[ABC_INIT_]][%[[IV0]], 0]
|
||||
// TLOOP: linalg.yield %[[RES1]] : [[TY]]
|
||||
// TLOOP: }
|
||||
// TLOOP: return %[[ABC]] : [[TY]]
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
|
@ -144,6 +203,48 @@ module {
|
|||
// CHECK: scf.yield %[[YIELD]]
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
// TLOOP-LABEL: func @matmul_plus_matmul
|
||||
// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
|
||||
// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
|
||||
// TLOOP-SAME: %[[AB:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
|
||||
|
||||
// TLOOP: %[[C32:.*]] = constant 32 : index
|
||||
// TLOOP: %[[C64:.*]] = constant 64 : index
|
||||
// TLOOP: %[[C0:.*]] = constant 0 : index
|
||||
// TLOOP: %[[C1:.*]] = constant 1 : index
|
||||
|
||||
// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
|
||||
// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
|
||||
|
||||
// TLOOP: %[[INIT:.*]] = linalg.init_tensor [%[[DIM_A_0]], %[[DIM_B_1]]]
|
||||
|
||||
// TLOOP: %[[RESULT:.*]] = linalg.tiled_loop (%[[IV0:.*]], %[[IV1:.*]]) =
|
||||
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
|
||||
// TLOOP-SAME: step (%[[C32]], %[[C64]])
|
||||
// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
|
||||
// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]],
|
||||
// TLOOP-SAME: %[[AB_:.*]] = %[[AB]]: [[TY]])
|
||||
// TLOOP-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: [[TY]]) {
|
||||
|
||||
// TLOOP: %[[INIT_SUB:.*]] = subtensor %[[INIT_]][%[[IV0]], %[[IV1]]]
|
||||
// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[IV0]], 0]
|
||||
// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[IV1]]]
|
||||
// TLOOP: %[[AB_SUB_INIT:.*]] = subtensor %[[AB_]][%[[IV0]], %[[IV1]]]
|
||||
|
||||
// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
|
||||
// TLOOP-SAME: ins(%[[A_SUB]], %[[B_SUB]] : [[TY]], [[TY]])
|
||||
// TLOOP-SAME: outs(%[[AB_SUB_INIT]] : [[TY]])
|
||||
|
||||
// TLOOP: %[[DOUBLE_AB:.*]] = linalg.generic
|
||||
// TLOOP-SAME: ins(%[[AB_SUB]] : [[TY]]) outs(%[[INIT_SUB]] : [[TY]])
|
||||
|
||||
// TLOOP: %[[RESULT_SUB:.*]] = subtensor_insert
|
||||
// TLOOP-SAME: %[[DOUBLE_AB:.*]] into %[[INIT_]][%[[IV0]], %[[IV1]]]
|
||||
|
||||
// TLOOP: linalg.yield %[[RESULT_SUB]] : [[TY]]
|
||||
// TLOOP: }
|
||||
// TLOOP: return %[[RESULT]] : [[TY]]
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
|
@ -174,3 +275,53 @@ module {
|
|||
// CHECK: scf.yield %[[ST_MM]] : tensor<?x?xf32>
|
||||
// CHECK: %[[MM:.*]] = subtensor_insert %[[ST_MM_RES]] into {{.*}}
|
||||
// CHECK: scf.yield %[[MM]] : tensor<?x?xf32>
|
||||
|
||||
|
||||
// TLOOP-LABEL: func @matmul_out_fusion(
|
||||
// TLOOP-SAME: %[[OUT:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
|
||||
// TLOOP-SAME: %[[A:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
|
||||
// TLOOP-SAME: %[[B:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
|
||||
|
||||
// TLOOP-DAG: %[[C0_F32:.*]] = constant 0.0
|
||||
// TLOOP-DAG: %[[C32:.*]] = constant 32 : index
|
||||
// TLOOP-DAG: %[[C64:.*]] = constant 64 : index
|
||||
// TLOOP-DAG: %[[C16:.*]] = constant 16 : index
|
||||
// TLOOP-DAG: %[[C0:.*]] = constant 0 : index
|
||||
// TLOOP-DAG: %[[C1:.*]] = constant 1 : index
|
||||
|
||||
// TLOOP: %[[DIM_A_0:.*]] = memref.dim %[[A]], %[[C0]] : [[TY:.*]]
|
||||
// TLOOP: %[[DIM_B_1:.*]] = memref.dim %[[B]], %[[C1]] : [[TY]]
|
||||
|
||||
// TLOOP: %[[AB:.*]] = linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) =
|
||||
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_A_0]], %[[DIM_B_1]])
|
||||
// TLOOP-SAME: step (%[[C32]], %[[C64]])
|
||||
// TLOOP-SAME: ins (%[[A_:.*]] = %[[A]]: [[TY]],
|
||||
// TLOOP-SAME: %[[B_:.*]] = %[[B]]: [[TY]])
|
||||
// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
|
||||
|
||||
// TLOOP: %[[DIM_A__1:.*]] = memref.dim %[[A_]], %[[C1]] : [[TY]]
|
||||
// TLOOP: %[[A_SUB:.*]] = subtensor %[[A_]][%[[I]], 0]
|
||||
// TLOOP: %[[B_SUB:.*]] = subtensor %[[B_]][0, %[[J]]]
|
||||
// TLOOP: %[[OUT_SUB:.*]] = subtensor %[[OUT_]][%[[I]], %[[J]]]
|
||||
// TLOOP: %[[INIT_SUB:.*]] = linalg.fill(%[[OUT_SUB]], %[[C0_F32]])
|
||||
|
||||
// TLOOP: %[[AB_SUB:.*]] = linalg.tiled_loop (%[[K:.*]]) = (%[[C0]])
|
||||
// TLOOP-SAME: to (%[[DIM_A__1]]) step (%[[C16]])
|
||||
// TLOOP-SAME: ins (%[[A_SUB_:.*]] = %[[A_SUB]]: [[TY]],
|
||||
// TLOOP-SAME: %[[B_SUB_:.*]] = %[[B_SUB]]: [[TY]])
|
||||
// TLOOP-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: [[TY]])
|
||||
// TLOOP-SAME: iterators["reduction"] {
|
||||
|
||||
// TLOOP: %[[A_SUB_SUB:.*]] = subtensor %[[A_SUB_]][0, %[[K]]]
|
||||
// TLOOP: %[[B_SUB_SUB:.*]] = subtensor %[[B_SUB_]][%[[K]], 0]
|
||||
|
||||
// TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul
|
||||
// TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
|
||||
// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
|
||||
// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]]
|
||||
// TLOOP: }
|
||||
// TLOOP: %[[SUB_RESULT:.*]] = subtensor_insert %[[AB_SUB]]
|
||||
// TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]]
|
||||
// TLOOP: linalg.yield %[[SUB_RESULT]] : [[TY]]
|
||||
// TLOOP: }
|
||||
// TLOOP: return %[[AB]] : [[TY]]
|
||||
|
|
|
@ -278,6 +278,13 @@ void registerTestLinalgTensorFusionTransforms() {
|
|||
"Test Linalg on tensor fusion transformation "
|
||||
"patterns by applying them greedily.");
|
||||
}
|
||||
void registerTestLinalgTiledLoopFusionTransforms() {
|
||||
PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops>>
|
||||
testTiledLoopFusionTransformsPass(
|
||||
"test-linalg-tiled-loop-fusion-transform-patterns",
|
||||
"Test Linalg on tensor fusion transformation "
|
||||
"patterns by applying them greedily.");
|
||||
}
|
||||
void registerTestLinalgGreedyFusion() {
|
||||
PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
|
||||
"test-linalg-greedy-fusion",
|
||||
|
|
|
@ -81,6 +81,7 @@ void registerTestLinalgElementwiseFusion();
|
|||
void registerTestPushExpandingReshape();
|
||||
void registerTestLinalgFusionTransforms();
|
||||
void registerTestLinalgTensorFusionTransforms();
|
||||
void registerTestLinalgTiledLoopFusionTransforms();
|
||||
void registerTestLinalgGreedyFusion();
|
||||
void registerTestLinalgHoisting();
|
||||
void registerTestLinalgTileAndFuseSequencePass();
|
||||
|
@ -159,6 +160,7 @@ void registerTestPasses() {
|
|||
test::registerTestPushExpandingReshape();
|
||||
test::registerTestLinalgFusionTransforms();
|
||||
test::registerTestLinalgTensorFusionTransforms();
|
||||
test::registerTestLinalgTiledLoopFusionTransforms();
|
||||
test::registerTestLinalgGreedyFusion();
|
||||
test::registerTestLinalgHoisting();
|
||||
test::registerTestLinalgTileAndFuseSequencePass();
|
||||
|
|
Loading…
Reference in New Issue