forked from OSchip/llvm-project
[mlir][Linalg] Add a hoistViewAllocOps helper function
This revision adds a helper function to hoist alloc/dealloc pairs and alloca op out of immediately enclosing scf::ForOp if both conditions are true: 1. all operands are defined outside the loop. 2. all uses are ViewLikeOp or DeallocOp. This is now considered Linalg-specific and will be generalized on a per-need basis. Differential Revision: https://reviews.llvm.org/D81152
This commit is contained in:
parent
a95c08db12
commit
3463d9835b
|
@ -0,0 +1,27 @@
|
|||
//===- Hoisting.h - Linalg hoisting transformations -------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_HOISTING_H_
|
||||
#define MLIR_DIALECT_LINALG_TRANSFORMS_HOISTING_H_
|
||||
|
||||
namespace mlir {
|
||||
class FuncOp;
|
||||
|
||||
namespace linalg {
|
||||
|
||||
/// Hoist alloc/dealloc pairs and alloca op out of immediately enclosing
|
||||
/// scf::ForOp if both conditions are true:
|
||||
/// 1. all operands are defined outside the loop.
|
||||
/// 2. all uses are ViewLikeOp or DeallocOp.
|
||||
// TODO: generalize on a per-need basis.
|
||||
void hoistViewAllocOps(FuncOp func);
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_TRANSFORMS_HOISTING_H_
|
|
@ -1,6 +1,7 @@
|
|||
add_mlir_dialect_library(MLIRLinalgTransforms
|
||||
DropUnitDims.cpp
|
||||
Fusion.cpp
|
||||
Hoisting.cpp
|
||||
Interchange.cpp
|
||||
Loops.cpp
|
||||
Promotion.cpp
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements functions concerned with hoisting invariant operations
|
||||
// in the context of Linalg transformations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "linalg-hoisting"
|
||||
|
||||
#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
using llvm::dbgs;
|
||||
|
||||
void mlir::linalg::hoistViewAllocOps(FuncOp func) {
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
func.walk([&changed](Operation *op) {
|
||||
if (!isa<AllocOp>(op) && !isa<AllocaOp>(op) && !isa<DeallocOp>(op))
|
||||
return;
|
||||
|
||||
LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n");
|
||||
auto loop = dyn_cast<scf::ForOp>(op->getParentOp());
|
||||
LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n");
|
||||
|
||||
// Only hoist out of immediately enclosing scf::ForOp.
|
||||
if (!loop)
|
||||
return;
|
||||
|
||||
// If any operand is defined inside the loop don't hoist.
|
||||
if (llvm::any_of(op->getOperands(), [&](Value v) {
|
||||
return !loop.isDefinedOutsideOfLoop(v);
|
||||
}))
|
||||
return;
|
||||
|
||||
LLVM_DEBUG(DBGS() << "All operands defined outside \n");
|
||||
|
||||
// If alloc has other uses than ViewLikeOp and DeallocOp don't hoist.
|
||||
Value v;
|
||||
if (op->getNumResults() > 0) {
|
||||
assert(op->getNumResults() == 1 && "Unexpected multi-result alloc");
|
||||
v = op->getResult(0);
|
||||
}
|
||||
if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) {
|
||||
return isa<ViewLikeOpInterface>(operand.getOwner()) ||
|
||||
isa<DeallocOp>(operand.getOwner());
|
||||
})) {
|
||||
LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n");
|
||||
return;
|
||||
}
|
||||
|
||||
// Move AllocOp before the loop.
|
||||
if (isa<AllocOp>(op) || isa<AllocaOp>(op))
|
||||
loop.moveOutOfLoop({op});
|
||||
else // Move DeallocOp outside of the loop.
|
||||
op->moveAfter(loop);
|
||||
changed = true;
|
||||
});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @hoist(
|
||||
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
|
||||
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
|
||||
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
|
||||
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
|
||||
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
|
||||
func @hoist(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
|
||||
// CHECK-DAG: alloca(%[[VAL]]) : memref<?xi8>
|
||||
// CHECK-DAG: %[[A0:.*]] = alloc(%[[VAL]]) : memref<?xi8>
|
||||
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
|
||||
// CHECK: alloca(%[[I]]) : memref<?xi8>
|
||||
// CHECK: %[[A1:.*]] = alloc(%[[I]]) : memref<?xi8>
|
||||
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
|
||||
// CHECK-DAG: alloca(%[[J]]) : memref<?xi8>
|
||||
// CHECK-DAG: %[[A2:.*]] = alloc(%[[J]]) : memref<?xi8>
|
||||
// CHECK: scf.for %[[K:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
|
||||
scf.for %i = %lb to %ub step %step {
|
||||
scf.for %j = %lb to %ub step %step {
|
||||
scf.for %k = %lb to %ub step %step {
|
||||
// Hoist allocs / deallocs outermost, keep view/subview below k.
|
||||
%sa0 = alloca(%val) : memref<? x i8>
|
||||
%a0 = alloc(%val) : memref<? x i8>
|
||||
// CHECK: std.view %[[A0]][%[[LB]]][] : memref<?xi8> to memref<16xf32>
|
||||
// CHECK: subview %[[A0]][0] [42] [1] : memref<?xi8> to memref<42xi8>
|
||||
%v0 = view %a0[%lb][] : memref<? x i8> to memref<16 x f32>
|
||||
%sv0 = subview %a0[0][42][1] : memref<? x i8> to memref<42 x i8>
|
||||
dealloc %a0 : memref<? x i8>
|
||||
|
||||
// Hoist below i.
|
||||
%sa1 = alloca(%i) : memref<? x i8>
|
||||
%a1 = alloc(%i) : memref<? x i8>
|
||||
dealloc %a1 : memref<? x i8>
|
||||
|
||||
// Hoist below j.
|
||||
%sa2 = alloca(%j) : memref<? x i8>
|
||||
%a2 = alloc(%j) : memref<? x i8>
|
||||
dealloc %a2 : memref<? x i8>
|
||||
|
||||
// Don't hoist since k innermost.
|
||||
// CHECK: alloca(%[[K]]) : memref<?xi8>
|
||||
// CHECK: %[[A3:.*]] = alloc(%[[K]]) : memref<?xi8>
|
||||
// CHECK: dealloc %[[A3]] : memref<?xi8>
|
||||
%sa3 = alloca(%k) : memref<? x i8>
|
||||
%a3 = alloc(%k) : memref<? x i8>
|
||||
dealloc %a3 : memref<? x i8>
|
||||
|
||||
// No hoisting due to control flow.
|
||||
// CHECK: scf.if %[[CMP]] {
|
||||
// CHECK: alloca(%[[VAL]]) : memref<?xi8>
|
||||
// CHECK: %[[A4:.*]] = alloc(%[[VAL]]) : memref<?xi8>
|
||||
// CHECK: dealloc %[[A4]] : memref<?xi8>
|
||||
scf.if %cmp {
|
||||
%sa4 = alloca(%val) : memref<? x i8>
|
||||
%a4 = alloc(%val) : memref<? x i8>
|
||||
dealloc %a4 : memref<? x i8>
|
||||
}
|
||||
|
||||
// No hoisting due to load/store.
|
||||
// CHECK: %[[SA5:.*]] = alloca(%[[VAL]]) : memref<?xi8>
|
||||
// CHECK: %[[A5:.*]] = alloc(%[[VAL]]) : memref<?xi8>
|
||||
// CHECK: load %[[A5]][%[[LB]]] : memref<?xi8>
|
||||
// CHECK: store %{{.*}}, %[[SA5]][%[[LB]]] : memref<?xi8>
|
||||
// CHECK: dealloc %[[A5]] : memref<?xi8>
|
||||
%sa5 = alloca(%val) : memref<? x i8>
|
||||
%a5 = alloc(%val) : memref<? x i8>
|
||||
%v5 = load %a5[%lb] : memref<? x i8>
|
||||
store %v5, %sa5[%lb] : memref<? x i8>
|
||||
dealloc %a5 : memref<? x i8>
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
// CHECK: }
|
||||
// CHECK: dealloc %[[A2]] : memref<?xi8>
|
||||
// CHECK: }
|
||||
// CHECK: dealloc %[[A1]] : memref<?xi8>
|
||||
// CHECK: }
|
||||
// CHECK: dealloc %[[A0]] : memref<?xi8>
|
||||
return
|
||||
}
|
|
@ -11,6 +11,7 @@ add_mlir_library(MLIRTestTransforms
|
|||
TestGpuMemoryPromotion.cpp
|
||||
TestGpuParallelLoopMapping.cpp
|
||||
TestInlining.cpp
|
||||
TestLinalgHoisting.cpp
|
||||
TestLinalgTransforms.cpp
|
||||
TestLiveness.cpp
|
||||
TestLoopMapping.cpp
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
//===- TestLinalgHoisting.cpp - Test Linalg hoisting functions ------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements logic for testing Linalg hoisting functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
namespace {
|
||||
struct TestLinalgHoisting
|
||||
: public PassWrapper<TestLinalgHoisting, FunctionPass> {
|
||||
TestLinalgHoisting() = default;
|
||||
TestLinalgHoisting(const TestLinalgHoisting &pass) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
Option<bool> testHoistViewAllocs{
|
||||
*this, "test-hoist-view-allocs",
|
||||
llvm::cl::desc("Test hoisting alloc used by view"),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void TestLinalgHoisting::runOnFunction() {
|
||||
if (testHoistViewAllocs) {
|
||||
hoistViewAllocOps(getFunction());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
void registerTestLinalgHoisting() {
|
||||
PassRegistration<TestLinalgHoisting> testTestLinalgHoistingPass(
|
||||
"test-linalg-hoisting", "Test Linalg hoisting functions.");
|
||||
}
|
||||
} // namespace mlir
|
|
@ -50,6 +50,7 @@ void registerTestConvertGPUKernelToHsacoPass();
|
|||
void registerTestDominancePass();
|
||||
void registerTestFunc();
|
||||
void registerTestGpuMemoryPromotionPass();
|
||||
void registerTestLinalgHoisting();
|
||||
void registerTestLinalgTransforms();
|
||||
void registerTestLivenessPass();
|
||||
void registerTestLoopFusion();
|
||||
|
@ -121,6 +122,7 @@ void registerTestPasses() {
|
|||
registerTestDominancePass();
|
||||
registerTestFunc();
|
||||
registerTestGpuMemoryPromotionPass();
|
||||
registerTestLinalgHoisting();
|
||||
registerTestLinalgTransforms();
|
||||
registerTestLivenessPass();
|
||||
registerTestLoopFusion();
|
||||
|
|
Loading…
Reference in New Issue