[MLIR] Add option to print users of an operation as comment in the printer

This allows printing the users of an operation as proposed in the git issue #53286.
To be able to refer to operations with no result, these operations are assigned an
ID in SSANameState.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D124048
This commit is contained in:
cpillmayer 2022-04-22 18:52:54 +00:00 committed by Mehdi Amini
parent 907ed12d95
commit 3e8560f890
5 changed files with 244 additions and 2 deletions

View File

@ -752,6 +752,9 @@ public:
/// the full module.
OpPrintingFlags &useLocalScope();
/// Print users of values as comments.
OpPrintingFlags &printValueUsers();
/// Return if the given ElementsAttr should be elided.
bool shouldElideElementsAttr(ElementsAttr attr) const;
@ -773,6 +776,9 @@ public:
/// Return if the printer should use local scope when dumping the IR.
bool shouldUseLocalScope() const;
/// Return if the printer should print users of values.
bool shouldPrintValueUsers() const;
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
@ -790,6 +796,9 @@ private:
/// Print operations with numberings local to the current operation.
bool printLocalScope : 1;
/// Print users of values.
bool printValueUsersFlag : 1;
};
//===----------------------------------------------------------------------===//

View File

@ -152,6 +152,11 @@ struct AsmPrinterOptions {
"mlir-print-local-scope", llvm::cl::init(false),
llvm::cl::desc("Print with local scope and inline information (eliding "
"aliases for attributes, types, and locations")};
llvm::cl::opt<bool> printValueUsers{
"mlir-print-value-users", llvm::cl::init(false),
llvm::cl::desc(
"Print users of operation results and block arguments as a comment")};
};
} // namespace
@ -168,7 +173,7 @@ void mlir::registerAsmPrinterCLOptions() {
OpPrintingFlags::OpPrintingFlags()
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
printGenericOpFormFlag(false), assumeVerifiedFlag(false),
printLocalScope(false) {
printLocalScope(false), printValueUsersFlag(false) {
// Initialize based upon command line options, if they are available.
if (!clOptions.isConstructed())
return;
@ -179,6 +184,7 @@ OpPrintingFlags::OpPrintingFlags()
printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
printLocalScope = clOptions->printLocalScopeOpt;
printValueUsersFlag = clOptions->printValueUsers;
}
/// Enable the elision of large elements attributes, by printing a '...'
@ -219,6 +225,12 @@ OpPrintingFlags &OpPrintingFlags::useLocalScope() {
return *this;
}
/// Print users of values as comments.
OpPrintingFlags &OpPrintingFlags::printValueUsers() {
printValueUsersFlag = true;
return *this;
}
/// Return if the given ElementsAttr should be elided.
bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
return elementsAttrElementLimit.hasValue() &&
@ -254,6 +266,11 @@ bool OpPrintingFlags::shouldAssumeVerified() const {
/// Return if the printer should use local scope when dumping the IR.
bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
/// Return if the printer should print users of values.
bool OpPrintingFlags::shouldPrintValueUsers() const {
return printValueUsersFlag;
}
/// Returns true if an ElementsAttr with the given number of elements should be
/// printed with hex.
static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
@ -831,6 +848,9 @@ public:
/// of this value.
void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
/// Print the operation identifier.
void printOperationID(Operation *op, raw_ostream &stream) const;
/// Return the result indices for each of the result groups registered by this
/// operation, or empty if none exist.
ArrayRef<int> getOpResultGroups(Operation *op);
@ -868,6 +888,10 @@ private:
DenseMap<Value, unsigned> valueIDs;
DenseMap<Value, StringRef> valueNames;
/// When printing users of values, an operation without a result might
/// be the user. This map holds ids for such operations.
DenseMap<Operation *, unsigned> operationIDs;
/// This is a map of operations that contain multiple named result groups,
/// i.e. there may be multiple names for the results of the operation. The
/// value of this map are the result numbers that start a result group.
@ -990,6 +1014,15 @@ void SSANameState::printValueID(Value value, bool printResultNo,
stream << '#' << resultNo;
}
void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
auto it = operationIDs.find(op);
if (it == operationIDs.end()) {
stream << "<<UNKOWN OPERATION>>";
} else {
stream << '%' << it->second;
}
}
ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
auto it = opResultGroups.find(op);
return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
@ -1121,8 +1154,14 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
unsigned numResults = op.getNumResults();
if (numResults == 0)
if (numResults == 0) {
// If value users should be printed, operations with no result need an id.
if (printerFlags.shouldPrintValueUsers()) {
if (operationIDs.try_emplace(&op, nextValueID).second)
++nextValueID;
}
return;
}
Value resultBegin = op.getResult(0);
// If the first result wasn't numbered, give it a default number.
@ -2481,6 +2520,10 @@ public:
void printValueID(Value value, bool printResultNo = true,
raw_ostream *streamOverride = nullptr) const;
/// Print the ID of the given operation.
void printOperationID(Operation *op,
raw_ostream *streamOverride = nullptr) const;
//===--------------------------------------------------------------------===//
// OpAsmPrinter methods
//===--------------------------------------------------------------------===//
@ -2549,6 +2592,19 @@ public:
void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
ValueRange symOperands) override;
/// Print users of this operation or id of this operation if it has no result.
void printUsersComment(Operation *op);
/// Print users of this block arg.
void printUsersComment(BlockArgument arg);
/// Print the users of a value.
void printValueUsers(Value value);
/// Print either the ids of the result values or the id of the operation if
/// the operation has no results.
void printUserIDs(Operation *user, bool prefixComma = false);
private:
// Contains the stack of default dialects to use when printing regions.
// A new dialect is pushed to the stack before parsing regions nested under an
@ -2602,6 +2658,8 @@ void OperationPrinter::print(Operation *op) {
os.indent(currentIndent);
printOperation(op);
printTrailingLocation(op->getLoc());
if (printerFlags.shouldPrintValueUsers())
printUsersComment(op);
}
void OperationPrinter::printOperation(Operation *op) {
@ -2657,6 +2715,80 @@ void OperationPrinter::printOperation(Operation *op) {
printGenericOp(op, /*printOpName=*/true);
}
void OperationPrinter::printUsersComment(Operation *op) {
unsigned numResults = op->getNumResults();
if (!numResults && op->getNumOperands()) {
os << " // id: ";
printOperationID(op);
} else if (numResults && op->use_empty()) {
os << " // unused";
} else if (numResults && !op->use_empty()) {
// Print "user" if the operation has one result used to compute one other
// result, or is used in one operation with no result.
unsigned usedInNResults = 0;
unsigned usedInNOperations = 0;
SmallPtrSet<Operation *, 1> userSet;
for (Operation *user : op->getUsers()) {
if (userSet.insert(user).second) {
++usedInNOperations;
usedInNResults += user->getNumResults();
}
}
// We already know that users is not empty.
bool exactlyOneUniqueUse =
usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1;
os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": ";
bool shouldPrintBrackets = numResults > 1;
auto printOpResult = [&](OpResult opResult) {
if (shouldPrintBrackets)
os << "(";
printValueUsers(opResult);
if (shouldPrintBrackets)
os << ")";
};
interleaveComma(op->getResults(), printOpResult);
}
}
void OperationPrinter::printUsersComment(BlockArgument arg) {
os << "// ";
printValueID(arg);
if (arg.use_empty()) {
os << " is unused";
} else {
os << " is used by ";
printValueUsers(arg);
}
os << newLine;
}
void OperationPrinter::printValueUsers(Value value) {
if (value.use_empty())
os << "unused";
// One value might be used as the operand of an operation more than once.
// Only print the operations results once in that case.
SmallPtrSet<Operation *, 1> userSet;
for (auto &indexedUser : enumerate(value.getUsers())) {
if (userSet.insert(indexedUser.value()).second)
printUserIDs(indexedUser.value(), indexedUser.index());
}
}
void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
if (prefixComma)
os << ", ";
if (!user->getNumResults()) {
printOperationID(user);
} else {
interleaveComma(user->getResults(),
[this](Value result) { printValueID(result); });
}
}
void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
if (printOpName) {
os << '"';
@ -2745,6 +2877,14 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
}
currentIndent += indentWidth;
if (printerFlags.shouldPrintValueUsers()) {
for (BlockArgument arg : block->getArguments()) {
os.indent(currentIndent);
printUsersComment(arg);
}
}
bool hasTerminator =
!block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
auto range = llvm::make_range(
@ -2764,6 +2904,12 @@ void OperationPrinter::printValueID(Value value, bool printResultNo,
streamOverride ? *streamOverride : os);
}
void OperationPrinter::printOperationID(Operation *op,
raw_ostream *streamOverride) const {
state->getSSANameState().printOperationID(op, streamOverride ? *streamOverride
: os);
}
void OperationPrinter::printSuccessor(Block *successor) {
printBlockName(successor);
}

View File

@ -0,0 +1,65 @@
// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-value-users -split-input-file %s | FileCheck %s
module {
// CHECK: %[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32
func @foo(%arg0: i32, %arg1: i32, %arg3: i32) -> i32 {
// CHECK-NEXT: // %[[ARG0]] is used by %[[ARG0U1:.+]], %[[ARG0U2:.+]], %[[ARG0U3:.+]]
// CHECK-NEXT: // %[[ARG1]] is used by %[[ARG1U1:.+]], %[[ARG1U2:.+]]
// CHECK-NEXT: // %[[ARG2]] is unused
// CHECK-NEXT: test.noop
// CHECK-NOT: // unused
"test.noop"() : () -> ()
// When no result is produced, an id should be printed.
// CHECK-NEXT: // id: %[[ARG0U3]]
"test.no_result"(%arg0) {} : (i32) -> ()
// Check for unused result.
// CHECK-NEXT: %[[ARG0U2]] =
// CHECK-SAME: // unused
%1 = "test.unused_result"(%arg0, %arg1) {} : (i32, i32) -> i32
// Check that both users are printed.
// CHECK-NEXT: %[[ARG0U1]] =
// CHECK-SAME: // users: %[[A:.+]]#0, %[[A]]#1
%2 = "test.one_result"(%arg0, %arg1) {} : (i32, i32) -> i32
// For multiple results, users should be grouped per result.
// CHECK-NEXT: %[[A]]:2 =
// CHECK-SAME: // users: (%[[B:.+]], %[[C:.+]]), (%[[B]], %[[D:.+]])
%3:2 = "test.many_results"(%2) {} : (i32) -> (i32, i32)
// Two results are produced, but there is only one user.
// CHECK-NEXT: // users:
%7:2 = "test.many_results"() : () -> (i32, i32)
// CHECK-NEXT: %[[C]] =
// Result is used twice in next operation but it produces only one result.
// CHECK-SAME: // user:
%4 = "test.foo"(%3#0) {} : (i32) -> i32
// CHECK-NEXT: %[[D]] =
%5 = "test.foo"(%3#1, %4, %4) {} : (i32, i32, i32) -> i32
// CHECK-NEXT: %[[B]] =
// Result is not used in any other result but in two operations.
// CHECK-SAME: // users:
%6 = "test.foo"(%3#0, %3#1) {} : (i32, i32) -> i32
"test.no_result"(%6) {} : (i32) -> ()
"test.no_result"(%7#0) : (i32) -> ()
return %6: i32
}
}
// -----
module {
// Check with nested operation.
// CHECK: %[[CONSTNAME:.+]] = arith.constant
%0 = arith.constant 42 : i32
%test = "test.outerop"(%0) ({
// CHECK: "test.innerop"(%[[CONSTNAME]]) : (i32) -> () // id: %
"test.innerop"(%0) : (i32) -> ()
// CHECK: (i32) -> i32 // users: %r, %s, %p, %p_0, %q
}): (i32) -> i32
// Check named results.
// CHECK-NEXT: // users: (%u, %v), (unused), (%u, %v, %r, %s)
%p:2, %q = "test.custom_result_name"(%test) {names = ["p", "p", "q"]} : (i32) -> (i32, i32, i32)
// CHECK-NEXT: // users: (unused), (%u, %v)
%r, %s = "test.custom_result_name"(%q#0, %q#0, %test) {names = ["r", "s"]} : (i32, i32, i32) -> (i32, i32)
// CHECK-NEXT: // unused
%u, %v = "test.custom_result_name"(%s, %q#0, %p) {names = ["u", "v"]} : (i32, i32, i32) -> (i32, i32)
}

View File

@ -1168,6 +1168,15 @@ void StringAttrPrettyNameOp::getAsmResultNames(
setNameFn(getResult(i), str.getValue());
}
void CustomResultsNameOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
ArrayAttr value = getNames();
for (size_t i = 0, e = value.size(); i != e; ++i)
if (auto str = value[i].dyn_cast<StringAttr>())
if (!str.getValue().empty())
setNameFn(getResult(i), str.getValue());
}
//===----------------------------------------------------------------------===//
// ResultTypeWithTraitOp
//===----------------------------------------------------------------------===//

View File

@ -732,6 +732,19 @@ def StringAttrPrettyNameOp
let hasCustomAssemblyFormat = 1;
}
// This is used to test encoding of a string attribute into an SSA name of a
// pretty printed value name.
def CustomResultsNameOp
: TEST_Op<"custom_result_name",
[DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let arguments = (ins
Variadic<AnyInteger>:$optional,
StrArrayAttr:$names
);
let results = (outs Variadic<AnyInteger>:$r);
}
// This is used to test the OpAsmOpInterface::getDefaultDialect() feature:
// operations nested in a region under this op will drop the "test." dialect
// prefix.