[mlir][Analysis] Allow Slice Analysis to work with linalg::LinalgOp

Differential Revision: https://reviews.llvm.org/D87307
This commit is contained in:
MaheshRavishankar 2020-09-10 16:47:29 -07:00
parent bc0a35f3b7
commit 0a391c6079
5 changed files with 120 additions and 1 deletions

View File

@ -12,6 +12,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
@ -84,7 +85,8 @@ static void getBackwardSliceImpl(Operation *op,
if (!op)
return;
assert((op->getNumRegions() == 0 || isa<AffineForOp, scf::ForOp>(op)) &&
assert((op->getNumRegions() == 0 ||
isa<AffineForOp, scf::ForOp, linalg::LinalgOp>(op)) &&
"unexpected generic op with regions");
// Evaluate whether we should keep this def.

33
mlir/test/IR/slice.mlir Normal file
View File

@ -0,0 +1,33 @@
// RUN: mlir-opt -slice-analysis-test %s | FileCheck %s
func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) {
%a = alloc(%arg0, %arg2) : memref<?x?xf32>
%b = alloc(%arg2, %arg1) : memref<?x?xf32>
%c = alloc(%arg0, %arg1) : memref<?x?xf32>
%d = alloc(%arg0, %arg1) : memref<?x?xf32>
linalg.matmul %a, %b, %c : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
linalg.matmul %a, %b, %d : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
dealloc %c : memref<?x?xf32>
dealloc %b : memref<?x?xf32>
dealloc %a : memref<?x?xf32>
dealloc %d : memref<?x?xf32>
return
}
// CHECK-LABEL: func @slicing_linalg_op__backward_slice__0
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[A:.+]] = alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
// CHECK-DAG: %[[B:.+]] = alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
// CHECK-DAG: %[[C:.+]] = alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
// CHECK: return
// CHECK-LABEL: func @slicing_linalg_op__backward_slice__1
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[A:.+]] = alloc(%[[ARG0]], %[[ARG2]]) : memref<?x?xf32>
// CHECK-DAG: %[[B:.+]] = alloc(%[[ARG2]], %[[ARG1]]) : memref<?x?xf32>
// CHECK-DAG: %[[C:.+]] = alloc(%[[ARG0]], %[[ARG1]]) : memref<?x?xf32>
// CHECK: return

View File

@ -6,6 +6,7 @@ add_mlir_library(MLIRTestIR
TestPrintDefUse.cpp
TestPrintNesting.cpp
TestSideEffects.cpp
TestSlicing.cpp
TestSymbolUses.cpp
TestTypes.cpp

View File

@ -0,0 +1,81 @@
//===- TestSlicing.cpp - Testing slice functionality ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a simple testing pass for slicing.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
/// Create a function with the same signature as the parent function of `op`
/// with name being the function name and a `suffix`.
static LogicalResult createBackwardSliceFunction(Operation *op,
StringRef suffix) {
FuncOp parentFuncOp = op->getParentOfType<FuncOp>();
OpBuilder builder(parentFuncOp);
Location loc = op->getLoc();
std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str();
FuncOp clonedFuncOp =
builder.create<FuncOp>(loc, clonedFuncOpName, parentFuncOp.getType());
BlockAndValueMapping mapper;
builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock());
for (auto arg : enumerate(parentFuncOp.getArguments()))
mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index()));
llvm::SetVector<Operation *> slice;
getBackwardSlice(op, &slice);
for (Operation *slicedOp : slice)
builder.clone(*slicedOp, mapper);
builder.create<ReturnOp>(loc);
return success();
}
namespace {
/// Pass to test slice generated from slice analysis.
struct SliceAnalysisTestPass
: public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
SliceAnalysisTestPass() = default;
SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}
};
} // namespace
void SliceAnalysisTestPass::runOnOperation() {
ModuleOp module = getOperation();
auto funcOps = module.getOps<FuncOp>();
unsigned opNum = 0;
for (auto funcOp : funcOps) {
// TODO: For now this is just looking for Linalg ops. It can be generalized
// to look for other ops using flags.
funcOp.walk([&](Operation *op) {
if (!isa<linalg::LinalgOp>(op))
return WalkResult::advance();
std::string append =
std::string("__backward_slice__") + std::to_string(opNum);
createBackwardSliceFunction(op, append);
opNum++;
return WalkResult::advance();
});
}
}
namespace mlir {
void registerSliceAnalysisTestPass() {
PassRegistration<SliceAnalysisTestPass> pass(
"slice-analysis-test", "Test Slice analysis functionality.");
}
} // namespace mlir

View File

@ -38,6 +38,7 @@ void registerPatternsTestPass();
void registerPrintOpAvailabilityPass();
void registerSideEffectTestPasses();
void registerSimpleParametricTilingPass();
void registerSliceAnalysisTestPass();
void registerSymbolTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAffineLoopUnswitchingPass();
@ -88,6 +89,7 @@ void registerTestPasses() {
registerPrintOpAvailabilityPass();
registerSideEffectTestPasses();
registerSimpleParametricTilingPass();
registerSliceAnalysisTestPass();
registerSymbolTestPasses();
registerTestAffineDataCopyPass();
registerTestAllReduceLoweringPass();