diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h index cd1d8c71a80d..fb4329e4d66f 100644 --- a/mlir/include/mlir/Bytecode/BytecodeWriter.h +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -27,6 +27,10 @@ public: /// of the bytecode when reading. It has no functional effect on the bytecode /// serialization. BytecodeWriterConfig(StringRef producer = "MLIR" LLVM_VERSION_STRING); + /// `map` is a fallback resource map, which when provided will attach resource + /// printers for the fallback resources within the map. + BytecodeWriterConfig(FallbackAsmResourceMap &map, + StringRef producer = "MLIR" LLVM_VERSION_STRING); ~BytecodeWriterConfig(); /// An internal implementation class that contains the state of the @@ -53,6 +57,13 @@ public: name, std::forward(printFn))); } + /// Attach resource printers to the AsmState for the fallback resources + /// in the given map. + void attachFallbackResourcePrinter(FallbackAsmResourceMap &map) { + for (auto &printer : map.getPrinters()) + attachResourcePrinter(std::move(printer)); + } + private: /// A pointer to allocated storage for the impl state. std::unique_ptr impl; diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h index 3ff4cfdfad2f..87f3a37b637d 100644 --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -16,6 +16,8 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/StringMap.h" #include @@ -401,6 +403,50 @@ private: std::string name; }; +/// A fallback map containing external resources not explicitly handled by +/// another parser/printer. +class FallbackAsmResourceMap { +public: + /// This class represents an opaque resource. + struct OpaqueAsmResource { + OpaqueAsmResource(StringRef key, + std::variant value) + : key(key.str()), value(std::move(value)) {} + + /// The key identifying the resource. + std::string key; + /// An opaque value for the resource, whose variant values align 1-1 with + /// the kinds defined in AsmResourceEntryKind. + std::variant value; + }; + + /// Return a parser than can be used for parsing entries for the given + /// identifier key. + AsmResourceParser &getParserFor(StringRef key); + + /// Build a set of resource printers to print the resources within this map. + std::vector> getPrinters(); + +private: + struct ResourceCollection : public AsmResourceParser { + ResourceCollection(StringRef name) : AsmResourceParser(name) {} + + /// Parse a resource into this collection. + LogicalResult parseResource(AsmParsedResourceEntry &entry) final; + + /// Build the resources held by this collection. + void buildResources(Operation *op, AsmResourceBuilder &builder) const; + + /// The set of resources parsed into this collection. + SmallVector resources; + }; + + /// The set of opaque resources. + llvm::MapVector, + llvm::StringMap> + keyToResources; +}; + //===----------------------------------------------------------------------===// // ParserConfig //===----------------------------------------------------------------------===// @@ -409,7 +455,12 @@ private: /// contains all of the necessary state to parse a MLIR source file. class ParserConfig { public: - ParserConfig(MLIRContext *context) : context(context) { + /// Construct a parser configuration with the given context. + /// `fallbackResourceMap` is an optional fallback handler that can be used to + /// parse external resources not explicitly handled by another parser. + ParserConfig(MLIRContext *context, + FallbackAsmResourceMap *fallbackResourceMap = nullptr) + : context(context), fallbackResourceMap(fallbackResourceMap) { assert(context && "expected valid MLIR context"); } @@ -420,7 +471,11 @@ public: /// parser with `name` is registered. AsmResourceParser *getResourceParser(StringRef name) const { auto it = resourceParsers.find(name); - return it == resourceParsers.end() ? nullptr : it->second.get(); + if (it != resourceParsers.end()) + return it->second.get(); + if (fallbackResourceMap) + return &fallbackResourceMap->getParserFor(name); + return nullptr; } /// Attach the given resource parser. @@ -444,6 +499,7 @@ public: private: MLIRContext *context; DenseMap> resourceParsers; + FallbackAsmResourceMap *fallbackResourceMap; }; //===----------------------------------------------------------------------===// @@ -466,13 +522,17 @@ public: using LocationMap = DenseMap>; /// Initialize the asm state at the level of the given operation. A location - /// map may optionally be provided to be populated when printing. + /// map may optionally be provided to be populated when printing. `map` is an + /// optional fallback resource map, which when provided will attach resource + /// printers for the fallback resources within the map. AsmState(Operation *op, const OpPrintingFlags &printerFlags = OpPrintingFlags(), - LocationMap *locationMap = nullptr); + LocationMap *locationMap = nullptr, + FallbackAsmResourceMap *map = nullptr); AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags = OpPrintingFlags(), - LocationMap *locationMap = nullptr); + LocationMap *locationMap = nullptr, + FallbackAsmResourceMap *map = nullptr); ~AsmState(); /// Get the printer flags. @@ -498,6 +558,13 @@ public: name, std::forward(printFn))); } + /// Attach resource printers to the AsmState for the fallback resources + /// in the given map. + void attachFallbackResourcePrinter(FallbackAsmResourceMap &map) { + for (auto &printer : map.getPrinters()) + attachResourcePrinter(std::move(printer)); + } + /// Returns a map of dialect resources that were referenced when using this /// state to print IR. DenseMap> & diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp index ff53cec15d77..7bcc1a841c24 100644 --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -39,6 +39,11 @@ struct BytecodeWriterConfig::Impl { BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer) : impl(std::make_unique(producer)) {} +BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map, + StringRef producer) + : BytecodeWriterConfig(producer) { + attachFallbackResourcePrinter(map); +} BytecodeWriterConfig::~BytecodeWriterConfig() = default; void BytecodeWriterConfig::attachResourcePrinter( diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 395bd03bb579..310e5efbd8f8 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1283,6 +1283,69 @@ StringRef mlir::toString(AsmResourceEntryKind kind) { llvm_unreachable("unknown AsmResourceEntryKind"); } +AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) { + std::unique_ptr &collection = keyToResources[key.str()]; + if (!collection) + collection = std::make_unique(key); + return *collection; +} + +std::vector> +FallbackAsmResourceMap::getPrinters() { + std::vector> printers; + for (auto &it : keyToResources) { + ResourceCollection *collection = it.second.get(); + auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) { + return collection->buildResources(op, builder); + }; + printers.emplace_back( + AsmResourcePrinter::fromCallable(collection->getName(), buildValues)); + } + return printers; +} + +LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource( + AsmParsedResourceEntry &entry) { + switch (entry.getKind()) { + case AsmResourceEntryKind::Blob: { + FailureOr blob = entry.parseAsBlob(); + if (failed(blob)) + return failure(); + resources.emplace_back(entry.getKey(), std::move(*blob)); + return success(); + } + case AsmResourceEntryKind::Bool: { + FailureOr value = entry.parseAsBool(); + if (failed(value)) + return failure(); + resources.emplace_back(entry.getKey(), *value); + break; + } + case AsmResourceEntryKind::String: { + FailureOr str = entry.parseAsString(); + if (failed(str)) + return failure(); + resources.emplace_back(entry.getKey(), std::move(*str)); + break; + } + } + return success(); +} + +void FallbackAsmResourceMap::ResourceCollection::buildResources( + Operation *op, AsmResourceBuilder &builder) const { + for (const auto &entry : resources) { + if (const auto *value = std::get_if(&entry.value)) + builder.buildBlob(entry.key, *value); + else if (const auto *value = std::get_if(&entry.value)) + builder.buildBool(entry.key, *value); + else if (const auto *value = std::get_if(&entry.value)) + builder.buildString(entry.key, *value); + else + llvm_unreachable("unknown AsmResourceEntryKind"); + } +} + //===----------------------------------------------------------------------===// // AsmState //===----------------------------------------------------------------------===// @@ -1401,12 +1464,18 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, } AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, - LocationMap *locationMap) + LocationMap *locationMap, FallbackAsmResourceMap *map) : impl(std::make_unique( - op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {} + op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) { + if (map) + attachFallbackResourcePrinter(*map); +} AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags, - LocationMap *locationMap) - : impl(std::make_unique(ctx, printerFlags, locationMap)) {} + LocationMap *locationMap, FallbackAsmResourceMap *map) + : impl(std::make_unique(ctx, printerFlags, locationMap)) { + if (map) + attachFallbackResourcePrinter(*map); +} AsmState::~AsmState() = default; const OpPrintingFlags &AsmState::getPrinterFlags() const { @@ -3308,14 +3377,6 @@ void Value::printAsOperand(raw_ostream &os, AsmState &state) { } void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { - // If this is a top level operation, we also print aliases. - if (!getParent() && !printerFlags.shouldUseLocalScope()) { - AsmState state(this, printerFlags); - state.getImpl().initializeAliases(this); - print(os, state); - return; - } - // Find the operation to number from based upon the provided flags. Operation *op = this; bool shouldUseLocalScope = printerFlags.shouldUseLocalScope(); @@ -3337,10 +3398,12 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { } void Operation::print(raw_ostream &os, AsmState &state) { OperationPrinter printer(os, state.getImpl()); - if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) + if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) { + state.getImpl().initializeAliases(this); printer.printTopLevelOperation(this); - else + } else { printer.print(this); + } } void Operation::dump() { diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp index 22f01bf719d5..93cc4dc33538 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -319,6 +319,10 @@ struct MLIRDocument { /// The container for the IR parsed from the input file. Block parsedIR; + /// A collection of external resources, which we want to propagate up to the + /// user. + FallbackAsmResourceMap fallbackResourceMap; + /// The source manager containing the contents of the input file. llvm::SourceMgr sourceMgr; }; @@ -338,11 +342,13 @@ MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, return; } + ParserConfig config(&context, &fallbackResourceMap); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); - if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, &context, &asmState))) { + if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) { // If parsing failed, clear out any of the current state. parsedIR.clear(); asmState = AsmParserState(); + fallbackResourceMap = FallbackAsmResourceMap(); return; } } @@ -875,9 +881,11 @@ MLIRDocument::convertToBytecode() { lsp::MLIRConvertBytecodeResult result; { + BytecodeWriterConfig writerConfig(fallbackResourceMap); + std::string rawBytecodeBuffer; llvm::raw_string_ostream os(rawBytecodeBuffer); - writeBytecodeToFile(&parsedIR.front(), os); + writeBytecodeToFile(&parsedIR.front(), os, writerConfig); result.output = llvm::encodeBase64(rawBytecodeBuffer); } return result; @@ -1284,11 +1292,15 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) { &tempContext, [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; }); + // Handling for external resources, which we want to propagate up to the user. + FallbackAsmResourceMap fallbackResourceMap; + + // Setup the parser config. + ParserConfig parserConfig(&tempContext, &fallbackResourceMap); + // Try to parse the given source file. - // TODO: This won't preserve external resources or the producer, we should try - // to fix this. Block parsedBlock; - if (failed(parseSourceFile(uri.file(), &parsedBlock, &tempContext))) { + if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) { return llvm::make_error( "failed to parse bytecode source file: " + errorMsg, lsp::ErrorCode::RequestFailed); @@ -1310,8 +1322,11 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) { OwningOpRef topOp = &parsedBlock.front(); (*topOp)->remove(); + AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(), + /*locationMap=*/nullptr, &fallbackResourceMap); + llvm::raw_string_ostream os(result.output); - (*topOp)->print(os, OpPrintingFlags().enableDebugInfo().assumeVerified()); + (*topOp)->print(os, state); } return std::move(result); } diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 519da97af0c7..ccd095dcf0bb 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -66,8 +66,11 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics, pm.enableTiming(timing); // Prepare the parser config, and attach any useful/necessary resource - // handlers. - ParserConfig config(context); + // handlers. Unhandled external resources are treated as passthrough, i.e. + // they are not processed and will be emitted directly to the output + // untouched. + FallbackAsmResourceMap fallbackResourceMap; + ParserConfig config(context, &fallbackResourceMap); attachPassReproducerAsmResource(config, pm, wasThreadingEnabled); // Parse the input file and reset the context threading state. @@ -89,9 +92,12 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics, // Print the output. TimingScope outputTiming = timing.nest("Output"); if (emitBytecode) { - writeBytecodeToFile(module->getOperation(), os); + BytecodeWriterConfig writerConfig(fallbackResourceMap); + writeBytecodeToFile(module->getOperation(), os, writerConfig); } else { - module->print(os); + AsmState asmState(*module, OpPrintingFlags(), /*locationMap=*/nullptr, + &fallbackResourceMap); + module->print(os, asmState); os << '\n'; } return success(); diff --git a/mlir/test/IR/file-metadata-resources.mlir b/mlir/test/IR/file-metadata-resources.mlir index 57562555c964..a531c7ce9756 100644 --- a/mlir/test/IR/file-metadata-resources.mlir +++ b/mlir/test/IR/file-metadata-resources.mlir @@ -5,6 +5,13 @@ // CHECK-NEXT: blob1: "0x08000000010000000000000002000000000000000300000000000000" // CHECK-NEXT: } +// Check that we properly preserve unknown external resources. +// CHECK: external: { +// CHECK-NEXT: blob: "0x08000000010000000000000002000000000000000300000000000000" +// CHECK-NEXT: bool: true +// CHECK-NEXT: string: "string" +// CHECK-NEXT: } + module attributes { test.blob_ref = #test.e1di64_elements : tensor<*xi1>} {} {-# @@ -13,5 +20,12 @@ module attributes { test.blob_ref = #test.e1di64_elements : tensor<*xi1>} blob1: "0x08000000010000000000000002000000000000000300000000000000", blob2: "0x08000000040000000000000005000000000000000600000000000000" } + }, + external_resources: { + external: { + blob: "0x08000000010000000000000002000000000000000300000000000000", + bool: true, + string: "string" + } } #-} diff --git a/mlir/test/IR/invalid-file-metadata.mlir b/mlir/test/IR/invalid-file-metadata.mlir index 352cf19f11be..553bd43c6aee 100644 --- a/mlir/test/IR/invalid-file-metadata.mlir +++ b/mlir/test/IR/invalid-file-metadata.mlir @@ -129,14 +129,3 @@ entry "value" } #-} - -// ----- - -// expected-warning@+3 {{ignoring unknown external resources for 'foobar'}} -{-# - external_resources: { - foobar: { - entry: "foo" - } - } -#-}