Rewrite OpStats to use llvm formatting utilities.

Example Output:

Operations encountered:
-----------------------
      addf                  , 11
      constant              , 4
      return                , 19
      some_op               , 1
   tf.AvgPool               , 3
   tf.DepthwiseConv2dNative , 3
   tf.FusedBatchNorm        , 2
  tfl.add                   , 7
  tfl.average_pool_2d       , 1
  tfl.leaky_relu            , 1

PiperOrigin-RevId: 229937190
This commit is contained in:
River Riddle 2019-01-18 09:00:34 -08:00 committed by jpienaar
parent c4237ae990
commit a1c0da42ec
2 changed files with 29 additions and 33 deletions

View File

@ -21,6 +21,7 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@ -67,41 +68,36 @@ void PrintOpStatsPass::printSummary() {
std::vector<StringRef> sorted(opCount.keys().begin(), opCount.keys().end());
llvm::sort(sorted);
// Returns the lenght of the dialect prefix of an op.
auto dialectLen = [](StringRef opName) -> size_t {
auto dialectEnd = opName.find_last_of('.');
if (dialectEnd == StringRef::npos)
return 0;
// Count the period too.
return dialectEnd + 1;
// Split an operation name from its dialect prefix.
auto splitOperationName = [](StringRef opName) {
auto splitName = opName.split('.');
return splitName.second.empty() ? std::make_pair("", splitName.first)
: splitName;
};
// Left-align the names (aligning on the dialect) and right-align count below.
// The alignment is for readability and does not affect CSV/FileCheck parsing.
size_t maxLenName = 0;
size_t maxLenNamePrefixLen = 0;
size_t maxLenDialect = 0;
int maxLenCount = 0;
// Compute the largest dialect and operation name.
StringRef dialectName, opName;
size_t maxLenOpName = 0, maxLenDialect = 0;
for (const auto &key : sorted) {
size_t len = key.size();
size_t prefix = dialectLen(key);
if (len > maxLenName) {
maxLenName = len;
maxLenNamePrefixLen = prefix;
}
maxLenDialect = max(maxLenDialect, prefix);
// This takes advantage of the fact that opCount[key] > 0.
maxLenCount = max(maxLenCount, (int)log10(opCount[key]) + 1);
std::tie(dialectName, opName) = splitOperationName(key);
maxLenDialect = std::max(maxLenDialect, dialectName.size());
maxLenOpName = std::max(maxLenOpName, opName.size());
}
// Adjust the max name length to account for the dialect.
maxLenName += (maxLenDialect - maxLenNamePrefixLen);
for (const auto &key : sorted) {
size_t prefix = maxLenDialect - dialectLen(key);
os.indent(2 + prefix) << '\'' << key << '\'';
// Add one to compensate for the period of the dialect.
os.indent(maxLenName + 1 - key.size() - prefix) << " ,";
os.indent(maxLenCount - (int)log10(opCount[key])) << opCount[key] << "\n";
std::tie(dialectName, opName) = splitOperationName(key);
// Left-align the names (aligning on the dialect) and right-align the count
// below. The alignment is for readability and does not affect CSV/FileCheck
// parsing.
if (dialectName.empty())
os.indent(maxLenDialect + 3);
else
os << llvm::right_justify(dialectName, maxLenDialect + 2) << '.';
// Left justify the operation name.
os << llvm::left_justify(opName, maxLenOpName) << " , " << opCount[key]
<< '\n';
}
}

View File

@ -30,7 +30,7 @@ func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
}
// CHECK-LABEL: Operations encountered
// CHECK: 'addf' , 6
// CHECK: 'long_op_name' , 1
// CHECK: 'return' , 1
// CHECK: 'xla.add' , 17
// CHECK: addf , 6
// CHECK: long_op_name , 1
// CHECK: return , 1
// CHECK: xla.add , 17