forked from OSchip/llvm-project
[mlir][Analysis] Allow Slice Analysis to work with linalg::LinalgOp
Differential Revision: https://reviews.llvm.org/D87307
This commit is contained in:
parent
bc0a35f3b7
commit
0a391c6079
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
@ -84,7 +85,8 @@ static void getBackwardSliceImpl(Operation *op,
|
|||
if (!op)
|
||||
return;
|
||||
|
||||
assert((op->getNumRegions() == 0 || isa<AffineForOp, scf::ForOp>(op)) &&
|
||||
assert((op->getNumRegions() == 0 ||
|
||||
isa<AffineForOp, scf::ForOp, linalg::LinalgOp>(op)) &&
|
||||
"unexpected generic op with regions");
|
||||
|
||||
// Evaluate whether we should keep this def.
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
// RUN: mlir-opt -slice-analysis-test %s | FileCheck %s
|
||||
|
||||
func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%a = alloc(%arg0, %arg2) : memref<?x?xf32>
|
||||
%b = alloc(%arg2, %arg1) : memref<?x?xf32>
|
||||
%c = alloc(%arg0, %arg1) : memref<?x?xf32>
|
||||
%d = alloc(%arg0, %arg1) : memref<?x?xf32>
|
||||
linalg.matmul %a, %b, %c : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
|
||||
linalg.matmul %a, %b, %d : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
|
||||
dealloc %c : memref<?x?xf32>
|
||||
dealloc %b : memref<?x?xf32>
|
||||
dealloc %a : memref<?x?xf32>
|
||||
dealloc %d : memref<?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @slicing_linalg_op__backward_slice__0
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
|
||||
// CHECK-DAG: %[[A:.+]] = alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
|
||||
// CHECK-DAG: %[[B:.+]] = alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
|
||||
// CHECK-DAG: %[[C:.+]] = alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
|
||||
// CHECK: return
|
||||
|
||||
// CHECK-LABEL: func @slicing_linalg_op__backward_slice__1
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
|
||||
// CHECK-DAG: %[[A:.+]] = alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
|
||||
// CHECK-DAG: %[[B:.+]] = alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
|
||||
// CHECK-DAG: %[[C:.+]] = alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
|
||||
// CHECK: return
|
|
@ -6,6 +6,7 @@ add_mlir_library(MLIRTestIR
|
|||
TestPrintDefUse.cpp
|
||||
TestPrintNesting.cpp
|
||||
TestSideEffects.cpp
|
||||
TestSlicing.cpp
|
||||
TestSymbolUses.cpp
|
||||
TestTypes.cpp
|
||||
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
//===- TestSlicing.cpp - Testing slice functionality ----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a simple testing pass for slicing.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Create a function with the same signature as the parent function of `op`
|
||||
/// with name being the function name and a `suffix`.
|
||||
static LogicalResult createBackwardSliceFunction(Operation *op,
|
||||
StringRef suffix) {
|
||||
FuncOp parentFuncOp = op->getParentOfType<FuncOp>();
|
||||
OpBuilder builder(parentFuncOp);
|
||||
Location loc = op->getLoc();
|
||||
std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str();
|
||||
FuncOp clonedFuncOp =
|
||||
builder.create<FuncOp>(loc, clonedFuncOpName, parentFuncOp.getType());
|
||||
BlockAndValueMapping mapper;
|
||||
builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock());
|
||||
for (auto arg : enumerate(parentFuncOp.getArguments()))
|
||||
mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index()));
|
||||
llvm::SetVector<Operation *> slice;
|
||||
getBackwardSlice(op, &slice);
|
||||
for (Operation *slicedOp : slice)
|
||||
builder.clone(*slicedOp, mapper);
|
||||
builder.create<ReturnOp>(loc);
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Pass to test slice generated from slice analysis.
|
||||
struct SliceAnalysisTestPass
|
||||
: public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override;
|
||||
SliceAnalysisTestPass() = default;
|
||||
SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void SliceAnalysisTestPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
auto funcOps = module.getOps<FuncOp>();
|
||||
unsigned opNum = 0;
|
||||
for (auto funcOp : funcOps) {
|
||||
// TODO: For now this is just looking for Linalg ops. It can be generalized
|
||||
// to look for other ops using flags.
|
||||
funcOp.walk([&](Operation *op) {
|
||||
if (!isa<linalg::LinalgOp>(op))
|
||||
return WalkResult::advance();
|
||||
std::string append =
|
||||
std::string("__backward_slice__") + std::to_string(opNum);
|
||||
createBackwardSliceFunction(op, append);
|
||||
opNum++;
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
void registerSliceAnalysisTestPass() {
|
||||
PassRegistration<SliceAnalysisTestPass> pass(
|
||||
"slice-analysis-test", "Test Slice analysis functionality.");
|
||||
}
|
||||
} // namespace mlir
|
|
@ -38,6 +38,7 @@ void registerPatternsTestPass();
|
|||
void registerPrintOpAvailabilityPass();
|
||||
void registerSideEffectTestPasses();
|
||||
void registerSimpleParametricTilingPass();
|
||||
void registerSliceAnalysisTestPass();
|
||||
void registerSymbolTestPasses();
|
||||
void registerTestAffineDataCopyPass();
|
||||
void registerTestAffineLoopUnswitchingPass();
|
||||
|
@ -88,6 +89,7 @@ void registerTestPasses() {
|
|||
registerPrintOpAvailabilityPass();
|
||||
registerSideEffectTestPasses();
|
||||
registerSimpleParametricTilingPass();
|
||||
registerSliceAnalysisTestPass();
|
||||
registerSymbolTestPasses();
|
||||
registerTestAffineDataCopyPass();
|
||||
registerTestAllReduceLoweringPass();
|
||||
|
|
Loading…
Reference in New Issue