forked from OSchip/llvm-project
Rewrite OpStats to use llvm formatting utilities.
Example Output: Operations encountered: ----------------------- addf , 11 constant , 4 return , 19 some_op , 1 tf.AvgPool , 3 tf.DepthwiseConv2dNative , 3 tf.FusedBatchNorm , 2 tfl.add , 7 tfl.average_pool_2d , 1 tfl.leaky_relu , 1 PiperOrigin-RevId: 229937190
This commit is contained in:
parent
c4237ae990
commit
a1c0da42ec
|
@ -21,6 +21,7 @@
|
|||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Pass.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -67,41 +68,36 @@ void PrintOpStatsPass::printSummary() {
|
|||
std::vector<StringRef> sorted(opCount.keys().begin(), opCount.keys().end());
|
||||
llvm::sort(sorted);
|
||||
|
||||
// Returns the lenght of the dialect prefix of an op.
|
||||
auto dialectLen = [](StringRef opName) -> size_t {
|
||||
auto dialectEnd = opName.find_last_of('.');
|
||||
if (dialectEnd == StringRef::npos)
|
||||
return 0;
|
||||
// Count the period too.
|
||||
return dialectEnd + 1;
|
||||
// Split an operation name from its dialect prefix.
|
||||
auto splitOperationName = [](StringRef opName) {
|
||||
auto splitName = opName.split('.');
|
||||
return splitName.second.empty() ? std::make_pair("", splitName.first)
|
||||
: splitName;
|
||||
};
|
||||
|
||||
// Left-align the names (aligning on the dialect) and right-align count below.
|
||||
// The alignment is for readability and does not affect CSV/FileCheck parsing.
|
||||
size_t maxLenName = 0;
|
||||
size_t maxLenNamePrefixLen = 0;
|
||||
size_t maxLenDialect = 0;
|
||||
int maxLenCount = 0;
|
||||
// Compute the largest dialect and operation name.
|
||||
StringRef dialectName, opName;
|
||||
size_t maxLenOpName = 0, maxLenDialect = 0;
|
||||
for (const auto &key : sorted) {
|
||||
size_t len = key.size();
|
||||
size_t prefix = dialectLen(key);
|
||||
if (len > maxLenName) {
|
||||
maxLenName = len;
|
||||
maxLenNamePrefixLen = prefix;
|
||||
}
|
||||
maxLenDialect = max(maxLenDialect, prefix);
|
||||
// This takes advantage of the fact that opCount[key] > 0.
|
||||
maxLenCount = max(maxLenCount, (int)log10(opCount[key]) + 1);
|
||||
std::tie(dialectName, opName) = splitOperationName(key);
|
||||
maxLenDialect = std::max(maxLenDialect, dialectName.size());
|
||||
maxLenOpName = std::max(maxLenOpName, opName.size());
|
||||
}
|
||||
// Adjust the max name length to account for the dialect.
|
||||
maxLenName += (maxLenDialect - maxLenNamePrefixLen);
|
||||
|
||||
for (const auto &key : sorted) {
|
||||
size_t prefix = maxLenDialect - dialectLen(key);
|
||||
os.indent(2 + prefix) << '\'' << key << '\'';
|
||||
// Add one to compensate for the period of the dialect.
|
||||
os.indent(maxLenName + 1 - key.size() - prefix) << " ,";
|
||||
os.indent(maxLenCount - (int)log10(opCount[key])) << opCount[key] << "\n";
|
||||
std::tie(dialectName, opName) = splitOperationName(key);
|
||||
|
||||
// Left-align the names (aligning on the dialect) and right-align the count
|
||||
// below. The alignment is for readability and does not affect CSV/FileCheck
|
||||
// parsing.
|
||||
if (dialectName.empty())
|
||||
os.indent(maxLenDialect + 3);
|
||||
else
|
||||
os << llvm::right_justify(dialectName, maxLenDialect + 2) << '.';
|
||||
|
||||
// Left justify the operation name.
|
||||
os << llvm::left_justify(opName, maxLenOpName) << " , " << opCount[key]
|
||||
<< '\n';
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: Operations encountered
|
||||
// CHECK: 'addf' , 6
|
||||
// CHECK: 'long_op_name' , 1
|
||||
// CHECK: 'return' , 1
|
||||
// CHECK: 'xla.add' , 17
|
||||
// CHECK: addf , 6
|
||||
// CHECK: long_op_name , 1
|
||||
// CHECK: return , 1
|
||||
// CHECK: xla.add , 17
|
||||
|
|
Loading…
Reference in New Issue