forked from OSchip/llvm-project
Add op stats pass to mlir-opt.
op-stats pass currently returns the number of occurrences of different operations in a Module. Useful for verifying transformation properties (e.g., 3 ops of specific dialect, 0 of another), but probably not useful outside of that so keeping it local to mlir-opt. This does not consider op attributes when counting. PiperOrigin-RevId: 222259727
This commit is contained in:
parent
d63ab4b47a
commit
d0590caa90
|
@ -1,4 +1,4 @@
|
||||||
//===- CFGFunctionViewGraph.h - View/write graphviz graphs ------*- C++ -*-===//
|
//===- CFGFunctionViewGraph.cpp - View/write graphviz graphs --------------===//
|
||||||
//
|
//
|
||||||
// Copyright 2019 The MLIR Authors.
|
// Copyright 2019 The MLIR Authors.
|
||||||
//
|
//
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
// RUN: mlir-opt -print-op-stats %s -o=/dev/null 2>&1 | FileCheck %s
|
||||||
|
|
||||||
|
cfgfunc @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
|
||||||
|
bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
|
||||||
|
%0 = addf %arg0, %arg1 : tensor<4xf32>
|
||||||
|
%1 = addf %arg0, %arg1 : tensor<4xf32>
|
||||||
|
%2 = addf %arg0, %arg1 : tensor<4xf32>
|
||||||
|
%3 = addf %arg0, %arg1 : tensor<4xf32>
|
||||||
|
%4 = addf %arg0, %arg1 : tensor<4xf32>
|
||||||
|
%5 = addf %arg0, %arg1 : tensor<4xf32>
|
||||||
|
%10 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%11 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%12 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%13 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%14 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%15 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%16 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%17 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%18 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%19 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%20 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%21 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%22 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%23 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%24 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%25 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%26 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
%30 = "long_op_name"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
|
||||||
|
return %1 : tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: Operations encountered
|
||||||
|
// CHECK: 'addf' , 6
|
||||||
|
// CHECK: 'long_op_name' , 1
|
||||||
|
// CHECK: 'return' , 1
|
||||||
|
// CHECK: 'xla.add' , 17
|
|
@ -0,0 +1,125 @@
|
||||||
|
//===- OpStats.cpp - Prints stats of operations in module -----------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#include "mlir/IR/CFGFunction.h"
|
||||||
|
#include "mlir/IR/MLFunction.h"
|
||||||
|
#include "mlir/IR/OperationSupport.h"
|
||||||
|
#include "mlir/IR/Statements.h"
|
||||||
|
#include "mlir/IR/StmtVisitor.h"
|
||||||
|
#include "mlir/Pass.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {
|
||||||
|
explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs())
|
||||||
|
: FunctionPass(&PrintOpStatsPass::passID), os(os) {}
|
||||||
|
|
||||||
|
// Prints the resultant operation stats post iterating over the module.
|
||||||
|
PassResult runOnModule(Module *m) override;
|
||||||
|
|
||||||
|
// Process CFG function considering the instructions in basic blocks.
|
||||||
|
PassResult runOnCFGFunction(CFGFunction *function) override;
|
||||||
|
|
||||||
|
// Process ML functions and operation statments in ML functions.
|
||||||
|
PassResult runOnMLFunction(MLFunction *function) override;
|
||||||
|
void visitOperationStmt(OperationStmt *stmt);
|
||||||
|
|
||||||
|
// Print summary of op stats.
|
||||||
|
void printSummary();
|
||||||
|
|
||||||
|
static char passID;
|
||||||
|
|
||||||
|
private:
|
||||||
|
llvm::StringMap<int64_t> opCount;
|
||||||
|
|
||||||
|
llvm::raw_ostream &os;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
char PrintOpStatsPass::passID = 0;
|
||||||
|
|
||||||
|
PassResult PrintOpStatsPass::runOnModule(Module *m) {
|
||||||
|
auto result = FunctionPass::runOnModule(m);
|
||||||
|
if (!result)
|
||||||
|
printSummary();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
PassResult PrintOpStatsPass::runOnCFGFunction(CFGFunction *function) {
|
||||||
|
for (const auto &bb : *function)
|
||||||
|
for (const auto &inst : bb)
|
||||||
|
++opCount[inst.getName().getStringRef()];
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PrintOpStatsPass::visitOperationStmt(OperationStmt *stmt) {
|
||||||
|
++opCount[stmt->getName().getStringRef()];
|
||||||
|
}
|
||||||
|
|
||||||
|
PassResult PrintOpStatsPass::runOnMLFunction(MLFunction *function) {
|
||||||
|
walk(function);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PrintOpStatsPass::printSummary() {
|
||||||
|
os << "Operations encountered:\n";
|
||||||
|
os << "-----------------------\n";
|
||||||
|
std::vector<StringRef> sorted(opCount.keys().begin(), opCount.keys().end());
|
||||||
|
llvm::sort(sorted);
|
||||||
|
|
||||||
|
// Returns the lenght of the dialect prefix of an op.
|
||||||
|
auto dialectLen = [](StringRef opName) -> size_t {
|
||||||
|
auto dialectEnd = opName.find_last_of('.');
|
||||||
|
if (dialectEnd == StringRef::npos)
|
||||||
|
return 0;
|
||||||
|
// Count the periond too.
|
||||||
|
return dialectEnd + 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Left-align the names (aligning on the dialect) and right-align count below.
|
||||||
|
// The alignment is for readability and does not affect CSV/FileCheck parsing.
|
||||||
|
size_t maxLenName = 0;
|
||||||
|
size_t maxLenNamePrefixLen = 0;
|
||||||
|
size_t maxLenDialect = 0;
|
||||||
|
int maxLenCount = 0;
|
||||||
|
for (const auto &key : sorted) {
|
||||||
|
size_t len = key.size();
|
||||||
|
size_t prefix = dialectLen(key);
|
||||||
|
if (len > maxLenName) {
|
||||||
|
maxLenName = len;
|
||||||
|
maxLenNamePrefixLen = prefix;
|
||||||
|
}
|
||||||
|
maxLenDialect = max(maxLenDialect, prefix);
|
||||||
|
// This takes advantage of the fact that opCount[key] > 0.
|
||||||
|
maxLenCount = max(maxLenCount, (int)log10(opCount[key]) + 1);
|
||||||
|
}
|
||||||
|
// Adjust the max name length to account for the dialect.
|
||||||
|
maxLenName += (maxLenDialect - maxLenNamePrefixLen);
|
||||||
|
|
||||||
|
for (const auto &key : sorted) {
|
||||||
|
size_t prefix = maxLenDialect - dialectLen(key);
|
||||||
|
os.indent(2 + prefix) << '\'' << key << '\'';
|
||||||
|
os.indent(maxLenName - key.size() - prefix) << " ,";
|
||||||
|
os.indent(maxLenCount - (int)log10(opCount[key])) << opCount[key] << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<PrintOpStatsPass>
|
||||||
|
pass("print-op-stats", "Print statistics of operations");
|
Loading…
Reference in New Issue