diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h index 924ffd6f248b..e58013052fa4 100644 --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -37,6 +37,7 @@ class DiagnosticEngine; class Identifier; struct LogicalResult; class MLIRContext; +class Operation; class Type; namespace detail { @@ -64,6 +65,7 @@ public: Attribute, Double, Integer, + Operation, String, Type, Unsigned, @@ -90,6 +92,12 @@ public: return static_cast(opaqueVal); } + /// Returns this argument as an operation. + Operation &getAsOperation() const { + assert(getKind() == DiagnosticArgumentKind::Operation); + return *reinterpret_cast(opaqueVal); + } + /// Returns this argument as a string. StringRef getAsString() const { assert(getKind() == DiagnosticArgumentKind::String); @@ -132,6 +140,14 @@ private: sizeof(T) <= sizeof(uint64_t)>::type * = 0) : kind(DiagnosticArgumentKind::Unsigned), opaqueVal(uint64_t(val)) {} + // Construct from an operation reference. + explicit DiagnosticArgument(Operation &val) : DiagnosticArgument(&val) {} + explicit DiagnosticArgument(Operation *val) + : kind(DiagnosticArgumentKind::Operation), + opaqueVal(reinterpret_cast(val)) { + assert(val && "expected valid operation"); + } + // Construct from a string reference. explicit DiagnosticArgument(StringRef val) : kind(DiagnosticArgumentKind::String), stringVal(val) {} diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 5283cf086da3..194f427f4f42 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/Identifier.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringMap.h" @@ -70,6 +71,9 @@ void DiagnosticArgument::print(raw_ostream &os) const { case DiagnosticArgumentKind::Integer: os << getAsInteger(); break; + case DiagnosticArgumentKind::Operation: + os << getAsOperation(); + break; case DiagnosticArgumentKind::String: os << getAsString(); break;