forked from OSchip/llvm-project
1428 lines
56 KiB
C++
1428 lines
56 KiB
C++
//===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// SPIRVSerializationGen generates common utility functions for SPIR-V
|
|
// serialization.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/TableGen/Attribute.h"
|
|
#include "mlir/TableGen/CodeGenHelpers.h"
|
|
#include "mlir/TableGen/Format.h"
|
|
#include "mlir/TableGen/GenInfo.h"
|
|
#include "mlir/TableGen/Operator.h"
|
|
#include "llvm/ADT/Sequence.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/StringMap.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "llvm/TableGen/Error.h"
|
|
#include "llvm/TableGen/Record.h"
|
|
#include "llvm/TableGen/TableGenBackend.h"
|
|
|
|
#include <list>
|
|
|
|
using llvm::ArrayRef;
|
|
using llvm::formatv;
|
|
using llvm::raw_ostream;
|
|
using llvm::raw_string_ostream;
|
|
using llvm::Record;
|
|
using llvm::RecordKeeper;
|
|
using llvm::SmallVector;
|
|
using llvm::SMLoc;
|
|
using llvm::StringMap;
|
|
using llvm::StringRef;
|
|
using llvm::Twine;
|
|
using mlir::tblgen::Attribute;
|
|
using mlir::tblgen::EnumAttr;
|
|
using mlir::tblgen::EnumAttrCase;
|
|
using mlir::tblgen::NamedAttribute;
|
|
using mlir::tblgen::NamedTypeConstraint;
|
|
using mlir::tblgen::NamespaceEmitter;
|
|
using mlir::tblgen::Operator;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Availability Wrapper Class
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
// Wrapper class with helper methods for accessing availability defined in
|
|
// TableGen.
|
|
class Availability {
|
|
public:
|
|
explicit Availability(const Record *def);
|
|
|
|
// Returns the name of the direct TableGen class for this availability
|
|
// instance.
|
|
StringRef getClass() const;
|
|
|
|
// Returns the generated C++ interface's class namespace.
|
|
StringRef getInterfaceClassNamespace() const;
|
|
|
|
// Returns the generated C++ interface's class name.
|
|
StringRef getInterfaceClassName() const;
|
|
|
|
// Returns the generated C++ interface's description.
|
|
StringRef getInterfaceDescription() const;
|
|
|
|
// Returns the name of the query function insided the generated C++ interface.
|
|
StringRef getQueryFnName() const;
|
|
|
|
// Returns the return type of the query function insided the generated C++
|
|
// interface.
|
|
StringRef getQueryFnRetType() const;
|
|
|
|
// Returns the code for merging availability requirements.
|
|
StringRef getMergeActionCode() const;
|
|
|
|
// Returns the initializer expression for initializing the final availability
|
|
// requirements.
|
|
StringRef getMergeInitializer() const;
|
|
|
|
// Returns the C++ type for an availability instance.
|
|
StringRef getMergeInstanceType() const;
|
|
|
|
// Returns the C++ statements for preparing availability instance.
|
|
StringRef getMergeInstancePreparation() const;
|
|
|
|
// Returns the concrete availability instance carried in this case.
|
|
StringRef getMergeInstance() const;
|
|
|
|
// Returns the underlying LLVM TableGen Record.
|
|
const llvm::Record *getDef() const { return def; }
|
|
|
|
private:
|
|
// The TableGen definition of this availability.
|
|
const llvm::Record *def;
|
|
};
|
|
} // namespace
|
|
|
|
/// Wraps `record`, which must derive from the TableGen `Availability` class.
Availability::Availability(const llvm::Record *record) : def(record) {
  // The accessors of this class read fields that `Availability`-derived
  // records are expected to define.
  assert(record->isSubClassOf("Availability") &&
         "must be subclass of TableGen 'Availability' class");
}
|
|
|
|
/// Returns the name of the single direct TableGen superclass of the wrapped
/// record. Emits a fatal TableGen error if there is not exactly one.
StringRef Availability::getClass() const {
  SmallVector<Record *, 1> directSupers;
  def->getDirectSuperClasses(directSupers);
  if (directSupers.size() == 1)
    return directSupers.front()->getName();
  PrintFatalError(def->getLoc(), "expected to only have one direct superclass");
}
|
|
|
|
/// Reads the `cppNamespace` field of the wrapped record.
StringRef Availability::getInterfaceClassNamespace() const {
  return this->def->getValueAsString("cppNamespace");
}

/// Reads the `interfaceName` field of the wrapped record.
StringRef Availability::getInterfaceClassName() const {
  return this->def->getValueAsString("interfaceName");
}

/// Reads the `interfaceDescription` field of the wrapped record.
StringRef Availability::getInterfaceDescription() const {
  return this->def->getValueAsString("interfaceDescription");
}

/// Reads the `queryFnRetType` field of the wrapped record.
StringRef Availability::getQueryFnRetType() const {
  return this->def->getValueAsString("queryFnRetType");
}

/// Reads the `queryFnName` field of the wrapped record.
StringRef Availability::getQueryFnName() const {
  return this->def->getValueAsString("queryFnName");
}

/// Reads the `mergeAction` field of the wrapped record.
StringRef Availability::getMergeActionCode() const {
  return this->def->getValueAsString("mergeAction");
}

/// Reads the `initializer` field of the wrapped record.
StringRef Availability::getMergeInitializer() const {
  return this->def->getValueAsString("initializer");
}

/// Reads the `instanceType` field of the wrapped record.
StringRef Availability::getMergeInstanceType() const {
  return this->def->getValueAsString("instanceType");
}

/// Reads the `instancePreparation` field of the wrapped record.
StringRef Availability::getMergeInstancePreparation() const {
  return this->def->getValueAsString("instancePreparation");
}

/// Reads the `instance` field of the wrapped record.
StringRef Availability::getMergeInstance() const {
  return this->def->getValueAsString("instance");
}
|
|
|
|
// Returns the availability spec of the given `def`.
|
|
std::vector<Availability> getAvailabilities(const Record &def) {
|
|
std::vector<Availability> availabilities;
|
|
|
|
if (def.getValue("availability")) {
|
|
std::vector<Record *> availDefs = def.getValueAsListOfDefs("availability");
|
|
availabilities.reserve(availDefs.size());
|
|
for (const Record *avail : availDefs)
|
|
availabilities.emplace_back(avail);
|
|
}
|
|
|
|
return availabilities;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Availability Interface Definitions AutoGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static void emitInterfaceDef(const Availability &availability,
|
|
raw_ostream &os) {
|
|
|
|
os << availability.getQueryFnRetType() << " ";
|
|
|
|
StringRef cppNamespace = availability.getInterfaceClassNamespace();
|
|
cppNamespace.consume_front("::");
|
|
if (!cppNamespace.empty())
|
|
os << cppNamespace << "::";
|
|
|
|
StringRef methodName = availability.getQueryFnName();
|
|
os << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
|
|
<< " return getImpl()->" << methodName << "(getImpl(), getOperation());\n"
|
|
<< "}\n";
|
|
}
|
|
|
|
static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("Availability Interface Definitions", os);
|
|
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
|
|
SmallVector<const Record *, 1> handledClasses;
|
|
for (const Record *def : defs) {
|
|
SmallVector<Record *, 1> parent;
|
|
def->getDirectSuperClasses(parent);
|
|
if (parent.size() != 1) {
|
|
PrintFatalError(def->getLoc(),
|
|
"expected to only have one direct superclass");
|
|
}
|
|
if (llvm::is_contained(handledClasses, parent.front()))
|
|
continue;
|
|
|
|
Availability availability(def);
|
|
emitInterfaceDef(availability, os);
|
|
handledClasses.push_back(parent.front());
|
|
}
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Availability Interface Declarations AutoGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
|
|
os << " class Concept {\n"
|
|
<< " public:\n"
|
|
<< " virtual ~Concept() = default;\n"
|
|
<< " virtual " << availability.getQueryFnRetType() << " "
|
|
<< availability.getQueryFnName()
|
|
<< "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n"
|
|
<< " };\n";
|
|
}
|
|
|
|
static void emitModelDecl(const Availability &availability, raw_ostream &os) {
|
|
for (const char *modelClass : {"Model", "FallbackModel"}) {
|
|
os << " template<typename ConcreteOp>\n";
|
|
os << " class " << modelClass << " : public Concept {\n"
|
|
<< " public:\n"
|
|
<< " " << availability.getQueryFnRetType() << " "
|
|
<< availability.getQueryFnName()
|
|
<< "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
|
|
<< " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
|
|
<< " (void)op;\n"
|
|
// Forward to the method on the concrete operation type.
|
|
<< " return op." << availability.getQueryFnName() << "();\n"
|
|
<< " }\n"
|
|
<< " };\n";
|
|
}
|
|
os << " template<typename ConcreteModel, typename ConcreteOp>\n";
|
|
os << " class ExternalModel : public FallbackModel<ConcreteOp> {};\n";
|
|
}
|
|
|
|
static void emitInterfaceDecl(const Availability &availability,
|
|
raw_ostream &os) {
|
|
StringRef interfaceName = availability.getInterfaceClassName();
|
|
std::string interfaceTraitsName =
|
|
std::string(formatv("{0}Traits", interfaceName));
|
|
|
|
StringRef cppNamespace = availability.getInterfaceClassNamespace();
|
|
NamespaceEmitter nsEmitter(os, cppNamespace);
|
|
|
|
// Emit the traits struct containing the concept and model declarations.
|
|
os << "namespace detail {\n"
|
|
<< "struct " << interfaceTraitsName << " {\n";
|
|
emitConceptDecl(availability, os);
|
|
os << '\n';
|
|
emitModelDecl(availability, os);
|
|
os << "};\n} // namespace detail\n\n";
|
|
|
|
// Emit the main interface class declaration.
|
|
os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
|
|
os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
|
|
"public:\n"
|
|
" using OpInterface<{1}, detail::{2}>::OpInterface;\n",
|
|
interfaceName, interfaceName, interfaceTraitsName);
|
|
|
|
// Emit query function declaration.
|
|
os << " " << availability.getQueryFnRetType() << " "
|
|
<< availability.getQueryFnName() << "();\n";
|
|
os << "};\n\n";
|
|
}
|
|
|
|
static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("Availability Interface Declarations", os);
|
|
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
|
|
SmallVector<const Record *, 4> handledClasses;
|
|
for (const Record *def : defs) {
|
|
SmallVector<Record *, 1> parent;
|
|
def->getDirectSuperClasses(parent);
|
|
if (parent.size() != 1) {
|
|
PrintFatalError(def->getLoc(),
|
|
"expected to only have one direct superclass");
|
|
}
|
|
if (llvm::is_contained(handledClasses, parent.front()))
|
|
continue;
|
|
|
|
Availability avail(def);
|
|
emitInterfaceDecl(avail, os);
|
|
handledClasses.push_back(parent.front());
|
|
}
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Availability Interface Hook Registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Registers the operation interface generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genInterfaceDecls("gen-avail-interface-decls",
|
|
"Generate availability interface declarations",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitInterfaceDecls(records, os);
|
|
});
|
|
|
|
// Registers the operation interface generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genInterfaceDefs("gen-avail-interface-defs",
|
|
"Generate op interface definitions",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitInterfaceDefs(records, os);
|
|
});
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Enum Availability Query AutoGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
|
|
raw_ostream &os) {
|
|
EnumAttr enumAttr(enumDef);
|
|
StringRef enumName = enumAttr.getEnumClassName();
|
|
std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
|
|
|
|
// Mapping from availability class name to (enumerant, availability
|
|
// specification) pairs.
|
|
llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
|
|
classCaseMap;
|
|
|
|
// Place all availability specifications to their corresponding
|
|
// availability classes.
|
|
for (const EnumAttrCase &enumerant : enumerants)
|
|
for (const Availability &avail : getAvailabilities(enumerant.getDef()))
|
|
classCaseMap[avail.getClass()].push_back({enumerant, avail});
|
|
|
|
for (const auto &classCasePair : classCaseMap) {
|
|
Availability avail = classCasePair.getValue().front().second;
|
|
|
|
os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n",
|
|
avail.getMergeInstanceType(), avail.getQueryFnName(),
|
|
enumName);
|
|
|
|
os << " switch (value) {\n";
|
|
for (const auto &caseSpecPair : classCasePair.getValue()) {
|
|
EnumAttrCase enumerant = caseSpecPair.first;
|
|
Availability avail = caseSpecPair.second;
|
|
os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
|
|
enumerant.getSymbol(), avail.getMergeInstancePreparation(),
|
|
avail.getMergeInstanceType(), avail.getMergeInstance());
|
|
}
|
|
// Only emit default if uncovered cases.
|
|
if (classCasePair.getValue().size() < enumAttr.getAllCases().size())
|
|
os << " default: break;\n";
|
|
os << " }\n"
|
|
<< " return llvm::None;\n"
|
|
<< "}\n";
|
|
}
|
|
}
|
|
|
|
static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
|
|
raw_ostream &os) {
|
|
EnumAttr enumAttr(enumDef);
|
|
StringRef enumName = enumAttr.getEnumClassName();
|
|
std::string underlyingType = std::string(enumAttr.getUnderlyingType());
|
|
std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
|
|
|
|
// Mapping from availability class name to (enumerant, availability
|
|
// specification) pairs.
|
|
llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
|
|
classCaseMap;
|
|
|
|
// Place all availability specifications to their corresponding
|
|
// availability classes.
|
|
for (const EnumAttrCase &enumerant : enumerants)
|
|
for (const Availability &avail : getAvailabilities(enumerant.getDef()))
|
|
classCaseMap[avail.getClass()].push_back({enumerant, avail});
|
|
|
|
for (const auto &classCasePair : classCaseMap) {
|
|
Availability avail = classCasePair.getValue().front().second;
|
|
|
|
os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n",
|
|
avail.getMergeInstanceType(), avail.getQueryFnName(),
|
|
enumName);
|
|
|
|
os << formatv(
|
|
" assert(::llvm::countPopulation(static_cast<{0}>(value)) <= 1"
|
|
" && \"cannot have more than one bit set\");\n",
|
|
underlyingType);
|
|
|
|
os << " switch (value) {\n";
|
|
for (const auto &caseSpecPair : classCasePair.getValue()) {
|
|
EnumAttrCase enumerant = caseSpecPair.first;
|
|
Availability avail = caseSpecPair.second;
|
|
os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
|
|
enumerant.getSymbol(), avail.getMergeInstancePreparation(),
|
|
avail.getMergeInstanceType(), avail.getMergeInstance());
|
|
}
|
|
os << " default: break;\n";
|
|
os << " }\n"
|
|
<< " return llvm::None;\n"
|
|
<< "}\n";
|
|
}
|
|
}
|
|
|
|
static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
|
|
EnumAttr enumAttr(enumDef);
|
|
StringRef enumName = enumAttr.getEnumClassName();
|
|
StringRef cppNamespace = enumAttr.getCppNamespace();
|
|
auto enumerants = enumAttr.getAllCases();
|
|
|
|
llvm::SmallVector<StringRef, 2> namespaces;
|
|
llvm::SplitString(cppNamespace, namespaces, "::");
|
|
|
|
for (auto ns : namespaces)
|
|
os << "namespace " << ns << " {\n";
|
|
|
|
llvm::StringSet<> handledClasses;
|
|
|
|
// Place all availability specifications to their corresponding
|
|
// availability classes.
|
|
for (const EnumAttrCase &enumerant : enumerants)
|
|
for (const Availability &avail : getAvailabilities(enumerant.getDef())) {
|
|
StringRef className = avail.getClass();
|
|
if (handledClasses.count(className))
|
|
continue;
|
|
os << formatv("llvm::Optional<{0}> {1}({2} value);\n",
|
|
avail.getMergeInstanceType(), avail.getQueryFnName(),
|
|
enumName);
|
|
handledClasses.insert(className);
|
|
}
|
|
|
|
for (auto ns : llvm::reverse(namespaces))
|
|
os << "} // namespace " << ns << "\n";
|
|
}
|
|
|
|
static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os);
|
|
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
|
|
for (const auto *def : defs)
|
|
emitEnumDecl(*def, os);
|
|
|
|
return false;
|
|
}
|
|
|
|
static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
|
|
EnumAttr enumAttr(enumDef);
|
|
StringRef cppNamespace = enumAttr.getCppNamespace();
|
|
|
|
llvm::SmallVector<StringRef, 2> namespaces;
|
|
llvm::SplitString(cppNamespace, namespaces, "::");
|
|
|
|
for (auto ns : namespaces)
|
|
os << "namespace " << ns << " {\n";
|
|
|
|
if (enumAttr.isBitEnum()) {
|
|
emitAvailabilityQueryForBitEnum(enumDef, os);
|
|
} else {
|
|
emitAvailabilityQueryForIntEnum(enumDef, os);
|
|
}
|
|
|
|
for (auto ns : llvm::reverse(namespaces))
|
|
os << "} // namespace " << ns << "\n";
|
|
os << "\n";
|
|
}
|
|
|
|
static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os);
|
|
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
|
|
for (const auto *def : defs)
|
|
emitEnumDef(*def, os);
|
|
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Enum Availability Query Hook Registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Registers the enum utility generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genEnumDecls("gen-spirv-enum-avail-decls",
|
|
"Generate SPIR-V enum availability declarations",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitEnumDecls(records, os);
|
|
});
|
|
|
|
// Registers the enum utility generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genEnumDefs("gen-spirv-enum-avail-defs",
|
|
"Generate SPIR-V enum availability definitions",
|
|
[](const RecordKeeper &records, raw_ostream &os) {
|
|
return emitEnumDefs(records, os);
|
|
});
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Serialization AutoGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Generates code to serialize attributes of a SPV_Op `op` into `os`. The
|
|
/// generates code extracts the attribute with name `attrName` from
|
|
/// `operandList` of `op`.
|
|
static void emitAttributeSerialization(const Attribute &attr,
|
|
ArrayRef<SMLoc> loc, StringRef tabs,
|
|
StringRef opVar, StringRef operandList,
|
|
StringRef attrName, raw_ostream &os) {
|
|
os << tabs
|
|
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
|
|
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
|
|
attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
|
|
os << tabs
|
|
<< formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), "
|
|
"attr.cast<IntegerAttr>()));\n",
|
|
operandList, opVar);
|
|
} else if (attr.getAttrDefName() == "I32ArrayAttr") {
|
|
// Serialize all the elements of the array
|
|
os << tabs << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
|
|
os << tabs
|
|
<< formatv(" {0}.push_back(static_cast<uint32_t>("
|
|
"attrElem.cast<IntegerAttr>().getValue().getZExtValue()));\n",
|
|
operandList);
|
|
os << tabs << " }\n";
|
|
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
|
|
os << tabs
|
|
<< formatv(" {0}.push_back(static_cast<uint32_t>("
|
|
"attr.cast<IntegerAttr>().getValue().getZExtValue()));\n",
|
|
operandList);
|
|
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
|
|
os << tabs
|
|
<< formatv(" {0}.push_back(static_cast<uint32_t>("
|
|
"getTypeID(attr.cast<TypeAttr>().getValue())));\n",
|
|
operandList);
|
|
} else {
|
|
PrintFatalError(
|
|
loc,
|
|
llvm::Twine(
|
|
"unhandled attribute type in SPIR-V serialization generation : '") +
|
|
attr.getAttrDefName() + llvm::Twine("'"));
|
|
}
|
|
os << tabs << "}\n";
|
|
}
|
|
|
|
/// Generates code to serialize the operands of a SPV_Op `op` into `os`. The
|
|
/// generated queries the SSA-ID if operand is a SSA-Value, or serializes the
|
|
/// attributes. The `operands` vector is updated appropriately. `elidedAttrs`
|
|
/// updated as well to include the serialized attributes.
|
|
static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
|
|
StringRef tabs, StringRef opVar,
|
|
StringRef operands, StringRef elidedAttrs,
|
|
raw_ostream &os) {
|
|
using mlir::tblgen::Argument;
|
|
|
|
// SPIR-V ops can mix operands and attributes in the definition. These
|
|
// operands and attributes are serialized in the exact order of the definition
|
|
// to match SPIR-V binary format requirements. It can cause excessive
|
|
// generated code bloat because we are emitting code to handle each
|
|
// operand/attribute separately. So here we probe first to check whether all
|
|
// the operands are ahead of attributes. Then we can serialize all operands
|
|
// together.
|
|
|
|
// Whether all operands are ahead of all attributes in the op's spec.
|
|
bool areOperandsAheadOfAttrs = true;
|
|
// Find the first attribute.
|
|
const Argument *it = llvm::find_if(op.getArgs(), [](const Argument &arg) {
|
|
return arg.is<NamedAttribute *>();
|
|
});
|
|
// Check whether all following arguments are attributes.
|
|
for (const Argument *ie = op.arg_end(); it != ie; ++it) {
|
|
if (!it->is<NamedAttribute *>()) {
|
|
areOperandsAheadOfAttrs = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Serialize all operands together.
|
|
if (areOperandsAheadOfAttrs) {
|
|
if (op.getNumOperands() != 0) {
|
|
os << tabs
|
|
<< formatv("for (Value operand : {0}->getOperands()) {{\n", opVar);
|
|
os << tabs << " auto id = getValueID(operand);\n";
|
|
os << tabs << " assert(id && \"use before def!\");\n";
|
|
os << tabs << formatv(" {0}.push_back(id);\n", operands);
|
|
os << tabs << "}\n";
|
|
}
|
|
for (const NamedAttribute &attr : op.getAttributes()) {
|
|
emitAttributeSerialization(
|
|
(attr.attr.isOptional() ? attr.attr.getBaseAttr() : attr.attr), loc,
|
|
tabs, opVar, operands, attr.name, os);
|
|
os << tabs
|
|
<< formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr.name);
|
|
}
|
|
return;
|
|
}
|
|
|
|
// Serialize operands separately.
|
|
auto operandNum = 0;
|
|
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
|
|
auto argument = op.getArg(i);
|
|
os << tabs << "{\n";
|
|
if (argument.is<NamedTypeConstraint *>()) {
|
|
os << tabs
|
|
<< formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar,
|
|
operandNum);
|
|
os << tabs << " auto argID = getValueID(arg);\n";
|
|
os << tabs << " if (!argID) {\n";
|
|
os << tabs
|
|
<< formatv(" return emitError({0}.getLoc(), "
|
|
"\"operand #{1} has a use before def\");\n",
|
|
opVar, operandNum);
|
|
os << tabs << " }\n";
|
|
os << tabs << formatv(" {0}.push_back(argID);\n", operands);
|
|
os << " }\n";
|
|
operandNum++;
|
|
} else {
|
|
NamedAttribute *attr = argument.get<NamedAttribute *>();
|
|
auto newtabs = tabs.str() + " ";
|
|
emitAttributeSerialization(
|
|
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
|
|
loc, newtabs, opVar, operands, attr->name, os);
|
|
os << newtabs
|
|
<< formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr->name);
|
|
}
|
|
os << tabs << "}\n";
|
|
}
|
|
}
|
|
|
|
/// Generates code to serializes the result of SPV_Op `op` into `os`. The
|
|
/// generated gets the ID for the type of the result (if any), the SSA-ID of
|
|
/// the result and updates `resultID` with the SSA-ID.
|
|
static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc,
|
|
StringRef tabs, StringRef opVar,
|
|
StringRef operands, StringRef resultID,
|
|
raw_ostream &os) {
|
|
if (op.getNumResults() == 1) {
|
|
StringRef resultTypeID("resultTypeID");
|
|
os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID);
|
|
os << tabs
|
|
<< formatv(
|
|
"if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
|
|
opVar, resultTypeID);
|
|
os << tabs << " return failure();\n";
|
|
os << tabs << "}\n";
|
|
os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID);
|
|
// Create an SSA result <id> for the op
|
|
os << tabs << formatv("{0} = getNextID();\n", resultID);
|
|
os << tabs
|
|
<< formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar, resultID);
|
|
os << tabs << formatv("{0}.push_back({1});\n", operands, resultID);
|
|
} else if (op.getNumResults() != 0) {
|
|
PrintFatalError(loc, "SPIR-V ops can only have zero or one result");
|
|
}
|
|
}
|
|
|
|
/// Generates code to serialize attributes of SPV_Op `op` that become
|
|
/// decorations on the `resultID` of the serialized operation `opVar` in the
|
|
/// SPIR-V binary.
|
|
static void emitDecorationSerialization(const Operator &op, StringRef tabs,
|
|
StringRef opVar, StringRef elidedAttrs,
|
|
StringRef resultID, raw_ostream &os) {
|
|
if (op.getNumResults() == 1) {
|
|
// All non-argument attributes translated into OpDecorate instruction
|
|
os << tabs << formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar);
|
|
os << tabs
|
|
<< formatv(" if (llvm::is_contained({0}, attr.first)) {{", elidedAttrs);
|
|
os << tabs << " continue;\n";
|
|
os << tabs << " }\n";
|
|
os << tabs
|
|
<< formatv(
|
|
" if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
|
|
opVar, resultID);
|
|
os << tabs << " return failure();\n";
|
|
os << tabs << " }\n";
|
|
os << tabs << "}\n";
|
|
}
|
|
}
|
|
|
|
/// Generates code to serialize an SPV_Op `op` into `os`.
|
|
static void emitSerializationFunction(const Record *attrClass,
|
|
const Record *record, const Operator &op,
|
|
raw_ostream &os) {
|
|
// If the record has 'autogenSerialization' set to 0, nothing to do
|
|
if (!record->getValueAsBit("autogenSerialization"))
|
|
return;
|
|
|
|
StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
|
|
resultID("resultID");
|
|
|
|
os << formatv(
|
|
"template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
|
|
op.getQualCppClassName(), opVar);
|
|
|
|
// Special case for ops without attributes in TableGen definitions
|
|
if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
|
|
std::string extInstSet;
|
|
std::string opcode;
|
|
if (record->isSubClassOf("SPV_ExtInstOp")) {
|
|
extInstSet =
|
|
formatv("\"{0}\"", record->getValueAsString("extendedInstSetName"));
|
|
opcode = std::to_string(record->getValueAsInt("extendedInstOpcode"));
|
|
} else {
|
|
extInstSet = "\"\"";
|
|
opcode = formatv("static_cast<uint32_t>(spirv::Opcode::{0})",
|
|
record->getValueAsString("spirvOpName"));
|
|
}
|
|
|
|
os << formatv(" return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n",
|
|
opVar, extInstSet, opcode);
|
|
return;
|
|
}
|
|
|
|
os << formatv(" SmallVector<uint32_t, 4> {0};\n", operands);
|
|
os << formatv(" SmallVector<StringRef, 2> {0};\n", elidedAttrs);
|
|
|
|
// Serialize result information.
|
|
if (op.getNumResults() == 1) {
|
|
os << formatv(" uint32_t {0} = 0;\n", resultID);
|
|
emitResultSerialization(op, record->getLoc(), " ", opVar, operands,
|
|
resultID, os);
|
|
}
|
|
|
|
// Process arguments.
|
|
emitArgumentSerialization(op, record->getLoc(), " ", opVar, operands,
|
|
elidedAttrs, os);
|
|
|
|
if (record->isSubClassOf("SPV_ExtInstOp")) {
|
|
os << formatv(
|
|
" (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", opVar,
|
|
record->getValueAsString("extendedInstSetName"),
|
|
record->getValueAsInt("extendedInstOpcode"), operands);
|
|
} else {
|
|
// Emit debug info.
|
|
os << formatv(" (void)emitDebugLine(functionBody, {0}.getLoc());\n",
|
|
opVar);
|
|
os << formatv(" (void)encodeInstructionInto("
|
|
"functionBody, spirv::Opcode::{1}, {2});\n",
|
|
op.getQualCppClassName(),
|
|
record->getValueAsString("spirvOpName"), operands);
|
|
}
|
|
|
|
// Process decorations.
|
|
emitDecorationSerialization(op, " ", opVar, elidedAttrs, resultID, os);
|
|
|
|
os << " return success();\n";
|
|
os << "}\n\n";
|
|
}
|
|
|
|
/// Generates the prologue for the function that dispatches the serialization of
|
|
/// the operation `opVar` based on its opcode.
|
|
static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
|
|
os << formatv(
|
|
"LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
|
|
"*{0}) {{\n",
|
|
opVar);
|
|
}
|
|
|
|
/// Generates the body of the dispatch function. This function generates the
|
|
/// check that if satisfied, will call the serialization function generated for
|
|
/// the `op`.
|
|
static void emitSerializationDispatch(const Operator &op, StringRef tabs,
|
|
StringRef opVar, raw_ostream &os) {
|
|
os << tabs
|
|
<< formatv("if (isa<{0}>({1})) {{\n", op.getQualCppClassName(), opVar);
|
|
os << tabs
|
|
<< formatv(" return processOp(cast<{0}>({1}));\n",
|
|
op.getQualCppClassName(), opVar);
|
|
os << tabs << "}\n";
|
|
}
|
|
|
|
/// Generates the epilogue for the function that dispatches the serialization of
|
|
/// the operation.
|
|
static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
|
|
os << formatv(
|
|
" return {0}->emitError(\"unhandled operation serialization\");\n",
|
|
opVar);
|
|
os << "}\n\n";
|
|
}
|
|
|
|
/// Generates code to deserialize the attribute of a SPV_Op into `os`. The
|
|
/// generated code reads the `words` of the serialized instruction at
|
|
/// position `wordIndex` and adds the deserialized attribute into `attrList`.
|
|
static void emitAttributeDeserialization(const Attribute &attr,
|
|
ArrayRef<SMLoc> loc, StringRef tabs,
|
|
StringRef attrList, StringRef attrName,
|
|
StringRef words, StringRef wordIndex,
|
|
raw_ostream &os) {
|
|
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
|
|
attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
|
|
os << tabs
|
|
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
|
|
"getConstantInt({2}[{3}++])));\n",
|
|
attrList, attrName, words, wordIndex);
|
|
} else if (attr.getAttrDefName() == "I32ArrayAttr") {
|
|
os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
|
|
os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words);
|
|
os << tabs
|
|
<< formatv(
|
|
" "
|
|
"attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
|
|
";\n",
|
|
words, wordIndex);
|
|
os << tabs << "}\n";
|
|
os << tabs
|
|
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
|
|
"opBuilder.getArrayAttr(attrListElems)));\n",
|
|
attrList, attrName);
|
|
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
|
|
os << tabs
|
|
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
|
|
"opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
|
|
attrList, attrName, words, wordIndex);
|
|
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
|
|
os << tabs
|
|
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
|
|
"TypeAttr::get(getType({2}[{3}++]))));\n",
|
|
attrList, attrName, words, wordIndex);
|
|
} else {
|
|
PrintFatalError(
|
|
loc, llvm::Twine(
|
|
"unhandled attribute type in deserialization generation : '") +
|
|
attr.getAttrDefName() + llvm::Twine("'"));
|
|
}
|
|
}
|
|
|
|
/// Generates the code to deserialize the result of an SPV_Op `op` into
|
|
/// `os`. The generated code gets the type of the result specified at
|
|
/// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
|
|
/// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
|
|
/// respectively.
|
|
static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
|
|
StringRef tabs, StringRef words,
|
|
StringRef wordIndex,
|
|
StringRef resultTypes, StringRef valueID,
|
|
raw_ostream &os) {
|
|
// Deserialize result information if it exists
|
|
if (op.getNumResults() == 1) {
|
|
os << tabs << "{\n";
|
|
os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words);
|
|
os << tabs
|
|
<< formatv(
|
|
" return emitError(unknownLoc, \"expected result type <id> "
|
|
"while deserializing {0}\");\n",
|
|
op.getQualCppClassName());
|
|
os << tabs << " }\n";
|
|
os << tabs << formatv(" auto ty = getType({0}[{1}]);\n", words, wordIndex);
|
|
os << tabs << " if (!ty) {\n";
|
|
os << tabs
|
|
<< formatv(
|
|
" return emitError(unknownLoc, \"unknown type result <id> : "
|
|
"\") << {0}[{1}];\n",
|
|
words, wordIndex);
|
|
os << tabs << " }\n";
|
|
os << tabs << formatv(" {0}.push_back(ty);\n", resultTypes);
|
|
os << tabs << formatv(" {0}++;\n", wordIndex);
|
|
os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words);
|
|
os << tabs
|
|
<< formatv(
|
|
" return emitError(unknownLoc, \"expected result <id> while "
|
|
"deserializing {0}\");\n",
|
|
op.getQualCppClassName());
|
|
os << tabs << " }\n";
|
|
os << tabs << "}\n";
|
|
os << tabs << formatv("{0} = {1}[{2}++];\n", valueID, words, wordIndex);
|
|
} else if (op.getNumResults() != 0) {
|
|
PrintFatalError(loc, "SPIR-V ops can have only zero or one result");
|
|
}
|
|
}
|
|
|
|
/// Generates the code to deserialize the operands of an SPV_Op `op` into
|
|
/// `os`. The generated code reads the `words` of the binary instruction, from
|
|
/// position `wordIndex` to the end, and either gets the Value corresponding to
|
|
/// the ID encoded, or deserializes the attributes encoded. The parsed
|
|
/// operand(attribute) is added to the `operands` list or `attributes` list.
|
|
static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
|
|
StringRef tabs, StringRef words,
|
|
StringRef wordIndex, StringRef operands,
|
|
StringRef attributes, raw_ostream &os) {
|
|
// Process operands/attributes
|
|
unsigned operandNum = 0;
|
|
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
|
|
auto argument = op.getArg(i);
|
|
if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
|
|
if (valueArg->isVariableLength()) {
|
|
if (i != e - 1) {
|
|
PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
|
|
"Optional<...> arguments only if "
|
|
"it's the last argument");
|
|
}
|
|
os << tabs
|
|
<< formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);
|
|
} else {
|
|
os << tabs << formatv("if ({0} < {1}.size())", wordIndex, words);
|
|
}
|
|
os << " {\n";
|
|
os << tabs
|
|
<< formatv(" auto arg = getValue({0}[{1}]);\n", words, wordIndex);
|
|
os << tabs << " if (!arg) {\n";
|
|
os << tabs
|
|
<< formatv(
|
|
" return emitError(unknownLoc, \"unknown result <id> : \") "
|
|
"<< {0}[{1}];\n",
|
|
words, wordIndex);
|
|
os << tabs << " }\n";
|
|
os << tabs << formatv(" {0}.push_back(arg);\n", operands);
|
|
if (!valueArg->isVariableLength()) {
|
|
os << tabs << formatv(" {0}++;\n", wordIndex);
|
|
}
|
|
operandNum++;
|
|
os << tabs << "}\n";
|
|
} else {
|
|
os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words);
|
|
auto attr = argument.get<NamedAttribute *>();
|
|
auto newtabs = tabs.str() + " ";
|
|
emitAttributeDeserialization(
|
|
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
|
|
loc, newtabs, attributes, attr->name, words, wordIndex, os);
|
|
os << " }\n";
|
|
}
|
|
}
|
|
|
|
os << tabs << formatv("if ({0} != {1}.size()) {{\n", wordIndex, words);
|
|
os << tabs
|
|
<< formatv(
|
|
" return emitError(unknownLoc, \"found more operands than "
|
|
"expected when deserializing {0}, only \") << {1} << \" of \" << "
|
|
"{2}.size() << \" processed\";\n",
|
|
op.getQualCppClassName(), wordIndex, words);
|
|
os << tabs << "}\n\n";
|
|
}
|
|
|
|
/// Generates code to update the `attributes` vector with the attributes
|
|
/// obtained from parsing the decorations in the SPIR-V binary associated with
|
|
/// an <id> `valueID`
|
|
static void emitDecorationDeserialization(const Operator &op, StringRef tabs,
|
|
StringRef valueID,
|
|
StringRef attributes,
|
|
raw_ostream &os) {
|
|
// Import decorations parsed
|
|
if (op.getNumResults() == 1) {
|
|
os << tabs << formatv("if (decorations.count({0})) {{\n", valueID);
|
|
os << tabs
|
|
<< formatv(" auto attrs = decorations[{0}].getAttrs();\n", valueID);
|
|
os << tabs
|
|
<< formatv(" {0}.append(attrs.begin(), attrs.end());\n", attributes);
|
|
os << tabs << "}\n";
|
|
}
|
|
}
|
|
|
|
/// Generates code to deserialize an SPV_Op `op` into `os`.
|
|
static void emitDeserializationFunction(const Record *attrClass,
|
|
const Record *record,
|
|
const Operator &op, raw_ostream &os) {
|
|
// If the record has 'autogenSerialization' set to 0, nothing to do
|
|
if (!record->getValueAsBit("autogenSerialization"))
|
|
return;
|
|
|
|
StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
|
|
wordIndex("wordIndex"), opVar("op"), operands("operands"),
|
|
attributes("attributes");
|
|
|
|
// Method declaration
|
|
os << formatv("template <> "
|
|
"LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
|
|
"uint32_t> {1}) {{\n",
|
|
op.getQualCppClassName(), words);
|
|
|
|
// Special case for ops without attributes in TableGen definitions
|
|
if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
|
|
os << formatv(" return processOpWithoutGrammarAttr("
|
|
"{0}, \"{1}\", {2}, {3});\n}\n\n",
|
|
words, op.getOperationName(),
|
|
op.getNumResults() ? "true" : "false", op.getNumOperands());
|
|
return;
|
|
}
|
|
|
|
os << formatv(" SmallVector<Type, 1> {0};\n", resultTypes);
|
|
os << formatv(" size_t {0} = 0; (void){0};\n", wordIndex);
|
|
os << formatv(" uint32_t {0} = 0; (void){0};\n", valueID);
|
|
|
|
// Deserialize result information
|
|
emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex,
|
|
resultTypes, valueID, os);
|
|
|
|
os << formatv(" SmallVector<Value, 4> {0};\n", operands);
|
|
os << formatv(" SmallVector<NamedAttribute, 4> {0};\n", attributes);
|
|
// Operand deserialization
|
|
emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex,
|
|
operands, attributes, os);
|
|
|
|
// Decorations
|
|
emitDecorationDeserialization(op, " ", valueID, attributes, os);
|
|
|
|
os << formatv(" Location loc = createFileLineColLoc(opBuilder);\n");
|
|
os << formatv(" auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
|
|
"(void){1};\n",
|
|
op.getQualCppClassName(), opVar, resultTypes, operands,
|
|
attributes);
|
|
if (op.getNumResults() == 1) {
|
|
os << formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar);
|
|
}
|
|
|
|
// According to SPIR-V spec:
|
|
// This location information applies to the instructions physically following
|
|
// this instruction, up to the first occurrence of any of the following: the
|
|
// next end of block.
|
|
os << formatv(" if ({0}.hasTrait<OpTrait::IsTerminator>())\n", opVar);
|
|
os << formatv(" (void)clearDebugLine();\n");
|
|
os << " return success();\n";
|
|
os << "}\n\n";
|
|
}
|
|
|
|
/// Generates the prologue for the function that dispatches the deserialization
|
|
/// based on the `opcode`.
|
|
static void initDispatchDeserializationFn(StringRef opcode, StringRef words,
|
|
raw_ostream &os) {
|
|
os << formatv("LogicalResult spirv::Deserializer::"
|
|
"dispatchToAutogenDeserialization(spirv::Opcode {0},"
|
|
" ArrayRef<uint32_t> {1}) {{\n",
|
|
opcode, words);
|
|
os << formatv(" switch ({0}) {{\n", opcode);
|
|
}
|
|
|
|
/// Generates the body of the dispatch function, by generating the case label
|
|
/// for an opcode and the call to the method to perform the deserialization.
|
|
static void emitDeserializationDispatch(const Operator &op, const Record *def,
|
|
StringRef tabs, StringRef words,
|
|
raw_ostream &os) {
|
|
os << tabs
|
|
<< formatv("case spirv::Opcode::{0}:\n",
|
|
def->getValueAsString("spirvOpName"));
|
|
os << tabs
|
|
<< formatv(" return processOp<{0}>({1});\n", op.getQualCppClassName(),
|
|
words);
|
|
}
|
|
|
|
/// Generates the epilogue for the function that dispatches the deserialization
|
|
/// of the operation.
|
|
static void finalizeDispatchDeserializationFn(StringRef opcode,
|
|
raw_ostream &os) {
|
|
os << " default:\n";
|
|
os << " ;\n";
|
|
os << " }\n";
|
|
StringRef opcodeVar("opcodeString");
|
|
os << formatv(" auto {0} = spirv::stringifyOpcode({1});\n", opcodeVar,
|
|
opcode);
|
|
os << formatv(" if (!{0}.empty()) {{\n", opcodeVar);
|
|
os << formatv(" return emitError(unknownLoc, \"unhandled deserialization "
|
|
"of \") << {0};\n",
|
|
opcodeVar);
|
|
os << " } else {\n";
|
|
os << formatv(" return emitError(unknownLoc, \"unhandled opcode \") << "
|
|
"static_cast<uint32_t>({0});\n",
|
|
opcode);
|
|
os << " }\n";
|
|
os << "}\n";
|
|
}
|
|
|
|
/// Emits the signature of
/// `Deserializer::dispatchToExtensionSetAutogenDeserialization`, naming its
/// parameters `extensionSetName`, `instructionID` and `words`.
static void initExtendedSetDeserializationDispatch(StringRef extensionSetName,
                                                   StringRef instructionID,
                                                   StringRef words,
                                                   raw_ostream &os) {
  const char *const signatureFmt =
      "LogicalResult spirv::Deserializer::"
      "dispatchToExtensionSetAutogenDeserialization("
      "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n";
  os << formatv(signatureFmt, extensionSetName, instructionID, words);
}
|
|
|
|
static void
|
|
emitExtendedSetDeserializationDispatch(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
StringRef extensionSetName("extensionSetName"),
|
|
instructionID("instructionID"), words("words");
|
|
|
|
// First iterate over all ops derived from SPV_ExtensionSetOps to get all
|
|
// extensionSets.
|
|
|
|
// For each of the extensions a separate raw_string_ostream is used to
|
|
// generate code into. These are then concatenated at the end. Since
|
|
// raw_string_ostream needs a string&, use a vector to store all the string
|
|
// that are captured by reference within raw_string_ostream.
|
|
StringMap<raw_string_ostream> extensionSets;
|
|
std::list<std::string> extensionSetNames;
|
|
|
|
initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words,
|
|
os);
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_ExtInstOp");
|
|
for (const auto *def : defs) {
|
|
if (!def->getValueAsBit("autogenSerialization")) {
|
|
continue;
|
|
}
|
|
Operator op(def);
|
|
auto setName = def->getValueAsString("extendedInstSetName");
|
|
if (!extensionSets.count(setName)) {
|
|
extensionSetNames.push_back("");
|
|
extensionSets.try_emplace(setName, extensionSetNames.back());
|
|
auto &setos = extensionSets.find(setName)->second;
|
|
setos << formatv(" if ({0} == \"{1}\") {{\n", extensionSetName, setName);
|
|
setos << formatv(" switch ({0}) {{\n", instructionID);
|
|
}
|
|
auto &setos = extensionSets.find(setName)->second;
|
|
setos << formatv(" case {0}:\n",
|
|
def->getValueAsInt("extendedInstOpcode"));
|
|
setos << formatv(" return processOp<{0}>({1});\n",
|
|
op.getQualCppClassName(), words);
|
|
}
|
|
|
|
// Append the dispatch code for all the extended sets.
|
|
for (auto &extensionSet : extensionSets) {
|
|
os << extensionSet.second.str();
|
|
os << " default:\n";
|
|
os << formatv(
|
|
" return emitError(unknownLoc, \"unhandled deserializations of "
|
|
"\") << {0} << \" from extension set \" << {1};\n",
|
|
instructionID, extensionSetName);
|
|
os << " }\n";
|
|
os << " }\n";
|
|
}
|
|
|
|
os << formatv(" return emitError(unknownLoc, \"unhandled deserialization of "
|
|
"extended instruction set {0}\");\n",
|
|
extensionSetName);
|
|
os << "}\n";
|
|
}
|
|
|
|
/// Emits all the autogenerated serialization/deserializations functions for the
|
|
/// SPV_Ops.
|
|
static bool emitSerializationFns(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os);
|
|
|
|
std::string dSerFnString, dDesFnString, serFnString, deserFnString,
|
|
utilsString;
|
|
raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
|
|
serFn(serFnString), deserFn(deserFnString);
|
|
Record *attrClass = recordKeeper.getClass("Attr");
|
|
|
|
// Emit the serialization and deserialization functions simultaneously.
|
|
StringRef opVar("op");
|
|
StringRef opcode("opcode"), words("words");
|
|
|
|
// Handle the SPIR-V ops.
|
|
initDispatchSerializationFn(opVar, dSerFn);
|
|
initDispatchDeserializationFn(opcode, words, dDesFn);
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
|
|
for (const auto *def : defs) {
|
|
Operator op(def);
|
|
emitSerializationFunction(attrClass, def, op, serFn);
|
|
emitDeserializationFunction(attrClass, def, op, deserFn);
|
|
if (def->getValueAsBit("hasOpcode") || def->isSubClassOf("SPV_ExtInstOp")) {
|
|
emitSerializationDispatch(op, " ", opVar, dSerFn);
|
|
}
|
|
if (def->getValueAsBit("hasOpcode")) {
|
|
emitDeserializationDispatch(op, def, " ", words, dDesFn);
|
|
}
|
|
}
|
|
finalizeDispatchSerializationFn(opVar, dSerFn);
|
|
finalizeDispatchDeserializationFn(opcode, dDesFn);
|
|
|
|
emitExtendedSetDeserializationDispatch(recordKeeper, dDesFn);
|
|
|
|
os << "#ifdef GET_SERIALIZATION_FNS\n\n";
|
|
os << serFn.str();
|
|
os << dSerFn.str();
|
|
os << "#endif // GET_SERIALIZATION_FNS\n\n";
|
|
|
|
os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
|
|
os << deserFn.str();
|
|
os << dDesFn.str();
|
|
os << "#endif // GET_DESERIALIZATION_FNS\n\n";
|
|
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Serialization Hook Registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Registers emitSerializationFns under the `gen-spirv-serialization` option.
static mlir::GenRegistration genSerialization(
    "gen-spirv-serialization",
    "Generate SPIR-V (de)serialization utilities and functions",
    [](const RecordKeeper &keeper, raw_ostream &out) {
      return emitSerializationFns(keeper, out);
    });
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Op Utils AutoGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Emits the primary template declaration `attributeName<EnumClass>()`.
/// Per-enum explicit specializations are emitted separately.
static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
  // The text contains no replacement fields, so it is streamed directly.
  os << "template <typename EnumClass> inline constexpr StringRef "
        "attributeName();\n";
}
|
|
|
|
static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
|
|
raw_ostream &os) {
|
|
auto enumName = enumAttr.getEnumClassName();
|
|
os << formatv("template <> inline StringRef attributeName<{0}>() {{\n",
|
|
enumName);
|
|
os << " "
|
|
<< formatv("static constexpr const char attrName[] = \"{0}\";\n",
|
|
llvm::convertToSnakeFromCamelCase(enumName));
|
|
os << " return attrName;\n";
|
|
os << "}\n";
|
|
}
|
|
|
|
static bool emitAttrUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("SPIR-V Attribute Utilities", os);
|
|
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
|
|
os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
|
|
os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
|
|
emitEnumGetAttrNameFnDecl(os);
|
|
for (const auto *def : defs) {
|
|
EnumAttr enumAttr(*def);
|
|
emitEnumGetAttrNameFnDefn(enumAttr, os);
|
|
}
|
|
os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n";
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Op Utils Hook Registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Registers emitAttrUtils under the `gen-spirv-attr-utils` option.
static mlir::GenRegistration genOpUtils(
    "gen-spirv-attr-utils", "Generate SPIR-V attribute utility definitions",
    [](const RecordKeeper &keeper, raw_ostream &out) {
      return emitAttrUtils(keeper, out);
    });
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SPIR-V Availability Impl AutoGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
|
|
mlir::tblgen::FmtContext fctx;
|
|
fctx.addSubst("overall", "tblgen_overall");
|
|
|
|
std::vector<Availability> opAvailabilities =
|
|
getAvailabilities(srcOp.getDef());
|
|
|
|
// First collect all availability classes this op should implement.
|
|
// All availability instances keep information for the generated interface and
|
|
// the instance's specific requirement. Here we remember a random instance so
|
|
// we can get the information regarding the generated interface.
|
|
llvm::StringMap<Availability> availClasses;
|
|
for (const Availability &avail : opAvailabilities)
|
|
availClasses.try_emplace(avail.getClass(), avail);
|
|
for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
|
|
const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr);
|
|
if (!enumAttr)
|
|
continue;
|
|
|
|
for (const EnumAttrCase &enumerant : enumAttr->getAllCases())
|
|
for (const Availability &caseAvail :
|
|
getAvailabilities(enumerant.getDef()))
|
|
availClasses.try_emplace(caseAvail.getClass(), caseAvail);
|
|
}
|
|
|
|
// Then generate implementation for each availability class.
|
|
for (const auto &availClass : availClasses) {
|
|
StringRef availClassName = availClass.getKey();
|
|
Availability avail = availClass.getValue();
|
|
|
|
// Generate the implementation method signature.
|
|
os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(),
|
|
srcOp.getCppClassName(), avail.getQueryFnName());
|
|
|
|
// Create the variable for the final requirement and initialize it.
|
|
os << formatv(" {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(),
|
|
avail.getMergeInitializer());
|
|
|
|
// Update with the op's specific availability spec.
|
|
for (const Availability &avail : opAvailabilities)
|
|
if (avail.getClass() == availClassName &&
|
|
(!avail.getMergeInstancePreparation().empty() ||
|
|
!avail.getMergeActionCode().empty())) {
|
|
os << " {\n "
|
|
// Prepare this instance.
|
|
<< avail.getMergeInstancePreparation()
|
|
<< "\n "
|
|
// Merge this instance.
|
|
<< std::string(
|
|
tgfmt(avail.getMergeActionCode(),
|
|
&fctx.addSubst("instance", avail.getMergeInstance())))
|
|
<< ";\n }\n";
|
|
}
|
|
|
|
// Update with enum attributes' specific availability spec.
|
|
for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
|
|
const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr);
|
|
if (!enumAttr)
|
|
continue;
|
|
|
|
// (enumerant, availability specification) pairs for this availability
|
|
// class.
|
|
SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
|
|
|
|
// Collect all cases' availability specs.
|
|
for (const EnumAttrCase &enumerant : enumAttr->getAllCases())
|
|
for (const Availability &caseAvail :
|
|
getAvailabilities(enumerant.getDef()))
|
|
if (availClassName == caseAvail.getClass())
|
|
caseSpecs.push_back({enumerant, caseAvail});
|
|
|
|
// If this attribute kind does not have any availability spec from any of
|
|
// its cases, no more work to do.
|
|
if (caseSpecs.empty())
|
|
continue;
|
|
|
|
if (enumAttr->isBitEnum()) {
|
|
// For BitEnumAttr, we need to iterate over each bit to query its
|
|
// availability spec.
|
|
os << formatv(" for (unsigned i = 0; "
|
|
"i < std::numeric_limits<{0}>::digits; ++i) {{\n",
|
|
enumAttr->getUnderlyingType());
|
|
os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
|
|
"static_cast<{0}::{1}>(1 << i);\n",
|
|
enumAttr->getCppNamespace(), enumAttr->getEnumClassName(),
|
|
namedAttr.name);
|
|
os << formatv(
|
|
" if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
|
|
enumAttr->getUnderlyingType());
|
|
} else {
|
|
// For IntEnumAttr, we just need to query the value as a whole.
|
|
os << " {\n";
|
|
os << formatv(" auto tblgen_attrVal = this->{0}();\n",
|
|
namedAttr.name);
|
|
}
|
|
os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
|
|
enumAttr->getCppNamespace(), avail.getQueryFnName());
|
|
os << " if (tblgen_instance) "
|
|
// TODO` here once ODS supports
|
|
// dialect-specific contents so that we can use not implementing the
|
|
// availability interface as indication of no requirements.
|
|
<< std::string(tgfmt(caseSpecs.front().second.getMergeActionCode(),
|
|
&fctx.addSubst("instance", "*tblgen_instance")))
|
|
<< ";\n";
|
|
os << " }\n";
|
|
}
|
|
|
|
os << " return tblgen_overall;\n";
|
|
os << "}\n";
|
|
}
|
|
}
|
|
|
|
static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os);
|
|
|
|
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
|
|
for (const auto *def : defs) {
|
|
Operator op(def);
|
|
emitAvailabilityImpl(op, os);
|
|
}
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Op Availability Implementation Hook Registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Registers the op availability implementation generator. The description
// says what is actually generated: availability query implementations rather
// than generic utility definitions.
static mlir::GenRegistration
    genOpAvailabilityImpl("gen-spirv-avail-impls",
                          "Generate SPIR-V operation availability "
                          "implementations",
                          [](const RecordKeeper &records, raw_ostream &os) {
                            return emitAvailabilityImpl(records, os);
                          });
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SPIR-V Capability Implication AutoGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
|
|
raw_ostream &os) {
|
|
llvm::emitSourceFileHeader("SPIR-V Capability Implication", os);
|
|
|
|
EnumAttr enumAttr(recordKeeper.getDef("SPV_CapabilityAttr"));
|
|
|
|
os << "ArrayRef<spirv::Capability> "
|
|
"spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
|
|
<< " switch (cap) {\n"
|
|
<< " default: return {};\n";
|
|
for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) {
|
|
const Record &def = enumerant.getDef();
|
|
if (!def.getValue("implies"))
|
|
continue;
|
|
|
|
std::vector<Record *> impliedCapsDefs = def.getValueAsListOfDefs("implies");
|
|
os << " case spirv::Capability::" << enumerant.getSymbol()
|
|
<< ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
|
|
<< "] = {";
|
|
llvm::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) {
|
|
os << "spirv::Capability::" << EnumAttrCase(capDef).getSymbol();
|
|
});
|
|
os << "}; return ArrayRef<spirv::Capability>(implies, "
|
|
<< impliedCapsDefs.size() << "); }\n";
|
|
}
|
|
os << " }\n";
|
|
os << "}\n";
|
|
|
|
return false;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SPIR-V Capability Implication Hook Registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Registers emitCapabilityImplication under the
// `gen-spirv-capability-implication` option.
static mlir::GenRegistration genCapabilityImplication(
    "gen-spirv-capability-implication",
    "Generate utility function to return implied "
    "capabilities for a given capability",
    [](const RecordKeeper &keeper, raw_ostream &out) {
      return emitCapabilityImplication(keeper, out);
    });
|