[mlir] Add fallback support for parsing/printing unknown external resources

This is necessary/useful for building generic tooling that can roundtrip external
resources without needing to explicitly handle them. For example, this allows
for viewing the resources encoded within a bytecode file without having to
explicitly know how to process them (e.g. making it easier to interact with a
reproducer encoded in bytecode).

Differential Revision: https://reviews.llvm.org/D133460
This commit is contained in:
River Riddle 2022-09-13 01:07:29 -07:00
parent 6ab2bcffe4
commit 34300ee369
8 changed files with 210 additions and 40 deletions

View File

@ -27,6 +27,10 @@ public:
/// of the bytecode when reading. It has no functional effect on the bytecode
/// serialization.
BytecodeWriterConfig(StringRef producer = "MLIR" LLVM_VERSION_STRING);
/// `map` is a fallback resource map, which when provided will attach resource
/// printers for the fallback resources within the map.
BytecodeWriterConfig(FallbackAsmResourceMap &map,
StringRef producer = "MLIR" LLVM_VERSION_STRING);
~BytecodeWriterConfig();
/// An internal implementation class that contains the state of the
@ -53,6 +57,13 @@ public:
name, std::forward<CallableT>(printFn)));
}
/// Attach resource printers to the AsmState for the fallback resources
/// in the given map.
void attachFallbackResourcePrinter(FallbackAsmResourceMap &map) {
for (auto &printer : map.getPrinters())
attachResourcePrinter(std::move(printer));
}
private:
/// A pointer to allocated storage for the impl state.
std::unique_ptr<Impl> impl;

View File

@ -16,6 +16,8 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringMap.h"
#include <memory>
@ -401,6 +403,50 @@ private:
std::string name;
};
/// A fallback map containing external resources not explicitly handled by
/// another parser/printer.
class FallbackAsmResourceMap {
public:
/// This class represents an opaque resource.
struct OpaqueAsmResource {
OpaqueAsmResource(StringRef key,
std::variant<AsmResourceBlob, bool, std::string> value)
: key(key.str()), value(std::move(value)) {}
/// The key identifying the resource.
std::string key;
/// An opaque value for the resource, whose variant values align 1-1 with
/// the kinds defined in AsmResourceEntryKind.
std::variant<AsmResourceBlob, bool, std::string> value;
};
/// Return a parser than can be used for parsing entries for the given
/// identifier key.
AsmResourceParser &getParserFor(StringRef key);
/// Build a set of resource printers to print the resources within this map.
std::vector<std::unique_ptr<AsmResourcePrinter>> getPrinters();
private:
struct ResourceCollection : public AsmResourceParser {
ResourceCollection(StringRef name) : AsmResourceParser(name) {}
/// Parse a resource into this collection.
LogicalResult parseResource(AsmParsedResourceEntry &entry) final;
/// Build the resources held by this collection.
void buildResources(Operation *op, AsmResourceBuilder &builder) const;
/// The set of resources parsed into this collection.
SmallVector<OpaqueAsmResource> resources;
};
/// The set of opaque resources.
llvm::MapVector<std::string, std::unique_ptr<ResourceCollection>,
llvm::StringMap<unsigned>>
keyToResources;
};
//===----------------------------------------------------------------------===//
// ParserConfig
//===----------------------------------------------------------------------===//
@ -409,7 +455,12 @@ private:
/// contains all of the necessary state to parse a MLIR source file.
class ParserConfig {
public:
ParserConfig(MLIRContext *context) : context(context) {
/// Construct a parser configuration with the given context.
/// `fallbackResourceMap` is an optional fallback handler that can be used to
/// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context,
FallbackAsmResourceMap *fallbackResourceMap = nullptr)
: context(context), fallbackResourceMap(fallbackResourceMap) {
assert(context && "expected valid MLIR context");
}
@ -420,7 +471,11 @@ public:
/// parser with `name` is registered.
AsmResourceParser *getResourceParser(StringRef name) const {
auto it = resourceParsers.find(name);
return it == resourceParsers.end() ? nullptr : it->second.get();
if (it != resourceParsers.end())
return it->second.get();
if (fallbackResourceMap)
return &fallbackResourceMap->getParserFor(name);
return nullptr;
}
/// Attach the given resource parser.
@ -444,6 +499,7 @@ public:
private:
MLIRContext *context;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
};
//===----------------------------------------------------------------------===//
@ -466,13 +522,17 @@ public:
using LocationMap = DenseMap<Operation *, std::pair<unsigned, unsigned>>;
/// Initialize the asm state at the level of the given operation. A location
/// map may optionally be provided to be populated when printing.
/// map may optionally be provided to be populated when printing. `map` is an
/// optional fallback resource map, which when provided will attach resource
/// printers for the fallback resources within the map.
AsmState(Operation *op,
const OpPrintingFlags &printerFlags = OpPrintingFlags(),
LocationMap *locationMap = nullptr);
LocationMap *locationMap = nullptr,
FallbackAsmResourceMap *map = nullptr);
AsmState(MLIRContext *ctx,
const OpPrintingFlags &printerFlags = OpPrintingFlags(),
LocationMap *locationMap = nullptr);
LocationMap *locationMap = nullptr,
FallbackAsmResourceMap *map = nullptr);
~AsmState();
/// Get the printer flags.
@ -498,6 +558,13 @@ public:
name, std::forward<CallableT>(printFn)));
}
/// Attach resource printers to the AsmState for the fallback resources
/// in the given map.
void attachFallbackResourcePrinter(FallbackAsmResourceMap &map) {
for (auto &printer : map.getPrinters())
attachResourcePrinter(std::move(printer));
}
/// Returns a map of dialect resources that were referenced when using this
/// state to print IR.
DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &

View File

@ -39,6 +39,11 @@ struct BytecodeWriterConfig::Impl {
BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer)
: impl(std::make_unique<Impl>(producer)) {}
BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map,
StringRef producer)
: BytecodeWriterConfig(producer) {
attachFallbackResourcePrinter(map);
}
BytecodeWriterConfig::~BytecodeWriterConfig() = default;
void BytecodeWriterConfig::attachResourcePrinter(

View File

@ -1283,6 +1283,69 @@ StringRef mlir::toString(AsmResourceEntryKind kind) {
llvm_unreachable("unknown AsmResourceEntryKind");
}
AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) {
std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()];
if (!collection)
collection = std::make_unique<ResourceCollection>(key);
return *collection;
}
std::vector<std::unique_ptr<AsmResourcePrinter>>
FallbackAsmResourceMap::getPrinters() {
std::vector<std::unique_ptr<AsmResourcePrinter>> printers;
for (auto &it : keyToResources) {
ResourceCollection *collection = it.second.get();
auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) {
return collection->buildResources(op, builder);
};
printers.emplace_back(
AsmResourcePrinter::fromCallable(collection->getName(), buildValues));
}
return printers;
}
LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource(
AsmParsedResourceEntry &entry) {
switch (entry.getKind()) {
case AsmResourceEntryKind::Blob: {
FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
if (failed(blob))
return failure();
resources.emplace_back(entry.getKey(), std::move(*blob));
return success();
}
case AsmResourceEntryKind::Bool: {
FailureOr<bool> value = entry.parseAsBool();
if (failed(value))
return failure();
resources.emplace_back(entry.getKey(), *value);
break;
}
case AsmResourceEntryKind::String: {
FailureOr<std::string> str = entry.parseAsString();
if (failed(str))
return failure();
resources.emplace_back(entry.getKey(), std::move(*str));
break;
}
}
return success();
}
void FallbackAsmResourceMap::ResourceCollection::buildResources(
Operation *op, AsmResourceBuilder &builder) const {
for (const auto &entry : resources) {
if (const auto *value = std::get_if<AsmResourceBlob>(&entry.value))
builder.buildBlob(entry.key, *value);
else if (const auto *value = std::get_if<bool>(&entry.value))
builder.buildBool(entry.key, *value);
else if (const auto *value = std::get_if<std::string>(&entry.value))
builder.buildString(entry.key, *value);
else
llvm_unreachable("unknown AsmResourceEntryKind");
}
}
//===----------------------------------------------------------------------===//
// AsmState
//===----------------------------------------------------------------------===//
@ -1401,12 +1464,18 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
}
AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
LocationMap *locationMap)
LocationMap *locationMap, FallbackAsmResourceMap *map)
: impl(std::make_unique<AsmStateImpl>(
op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {
if (map)
attachFallbackResourcePrinter(*map);
}
AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
LocationMap *locationMap)
: impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {}
LocationMap *locationMap, FallbackAsmResourceMap *map)
: impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {
if (map)
attachFallbackResourcePrinter(*map);
}
AsmState::~AsmState() = default;
const OpPrintingFlags &AsmState::getPrinterFlags() const {
@ -3308,14 +3377,6 @@ void Value::printAsOperand(raw_ostream &os, AsmState &state) {
}
void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
// If this is a top level operation, we also print aliases.
if (!getParent() && !printerFlags.shouldUseLocalScope()) {
AsmState state(this, printerFlags);
state.getImpl().initializeAliases(this);
print(os, state);
return;
}
// Find the operation to number from based upon the provided flags.
Operation *op = this;
bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
@ -3337,10 +3398,12 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
}
void Operation::print(raw_ostream &os, AsmState &state) {
OperationPrinter printer(os, state.getImpl());
if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope())
if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) {
state.getImpl().initializeAliases(this);
printer.printTopLevelOperation(this);
else
} else {
printer.print(this);
}
}
void Operation::dump() {

View File

@ -319,6 +319,10 @@ struct MLIRDocument {
/// The container for the IR parsed from the input file.
Block parsedIR;
/// A collection of external resources, which we want to propagate up to the
/// user.
FallbackAsmResourceMap fallbackResourceMap;
/// The source manager containing the contents of the input file.
llvm::SourceMgr sourceMgr;
};
@ -338,11 +342,13 @@ MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
return;
}
ParserConfig config(&context, &fallbackResourceMap);
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, &context, &asmState))) {
if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
// If parsing failed, clear out any of the current state.
parsedIR.clear();
asmState = AsmParserState();
fallbackResourceMap = FallbackAsmResourceMap();
return;
}
}
@ -875,9 +881,11 @@ MLIRDocument::convertToBytecode() {
lsp::MLIRConvertBytecodeResult result;
{
BytecodeWriterConfig writerConfig(fallbackResourceMap);
std::string rawBytecodeBuffer;
llvm::raw_string_ostream os(rawBytecodeBuffer);
writeBytecodeToFile(&parsedIR.front(), os);
writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
result.output = llvm::encodeBase64(rawBytecodeBuffer);
}
return result;
@ -1284,11 +1292,15 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
&tempContext,
[&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
// Handling for external resources, which we want to propagate up to the user.
FallbackAsmResourceMap fallbackResourceMap;
// Setup the parser config.
ParserConfig parserConfig(&tempContext, &fallbackResourceMap);
// Try to parse the given source file.
// TODO: This won't preserve external resources or the producer, we should try
// to fix this.
Block parsedBlock;
if (failed(parseSourceFile(uri.file(), &parsedBlock, &tempContext))) {
if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
return llvm::make_error<lsp::LSPError>(
"failed to parse bytecode source file: " + errorMsg,
lsp::ErrorCode::RequestFailed);
@ -1310,8 +1322,11 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
OwningOpRef<Operation *> topOp = &parsedBlock.front();
(*topOp)->remove();
AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
/*locationMap=*/nullptr, &fallbackResourceMap);
llvm::raw_string_ostream os(result.output);
(*topOp)->print(os, OpPrintingFlags().enableDebugInfo().assumeVerified());
(*topOp)->print(os, state);
}
return std::move(result);
}

View File

@ -66,8 +66,11 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
pm.enableTiming(timing);
// Prepare the parser config, and attach any useful/necessary resource
// handlers.
ParserConfig config(context);
// handlers. Unhandled external resources are treated as passthrough, i.e.
// they are not processed and will be emitted directly to the output
// untouched.
FallbackAsmResourceMap fallbackResourceMap;
ParserConfig config(context, &fallbackResourceMap);
attachPassReproducerAsmResource(config, pm, wasThreadingEnabled);
// Parse the input file and reset the context threading state.
@ -89,9 +92,12 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
// Print the output.
TimingScope outputTiming = timing.nest("Output");
if (emitBytecode) {
writeBytecodeToFile(module->getOperation(), os);
BytecodeWriterConfig writerConfig(fallbackResourceMap);
writeBytecodeToFile(module->getOperation(), os, writerConfig);
} else {
module->print(os);
AsmState asmState(*module, OpPrintingFlags(), /*locationMap=*/nullptr,
&fallbackResourceMap);
module->print(os, asmState);
os << '\n';
}
return success();

View File

@ -5,6 +5,13 @@
// CHECK-NEXT: blob1: "0x08000000010000000000000002000000000000000300000000000000"
// CHECK-NEXT: }
// Check that we properly preserve unknown external resources.
// CHECK: external: {
// CHECK-NEXT: blob: "0x08000000010000000000000002000000000000000300000000000000"
// CHECK-NEXT: bool: true
// CHECK-NEXT: string: "string"
// CHECK-NEXT: }
module attributes { test.blob_ref = #test.e1di64_elements<blob1> : tensor<*xi1>} {}
{-#
@ -13,5 +20,12 @@ module attributes { test.blob_ref = #test.e1di64_elements<blob1> : tensor<*xi1>}
blob1: "0x08000000010000000000000002000000000000000300000000000000",
blob2: "0x08000000040000000000000005000000000000000600000000000000"
}
},
external_resources: {
external: {
blob: "0x08000000010000000000000002000000000000000300000000000000",
bool: true,
string: "string"
}
}
#-}

View File

@ -129,14 +129,3 @@
entry "value"
}
#-}
// -----
// expected-warning@+3 {{ignoring unknown external resources for 'foobar'}}
{-#
external_resources: {
foobar: {
entry: "foo"
}
}
#-}