Added a TableGen generator for structured data

Similar to enum, added a generator for structured data. This provide Dictionary that stores a fixed set of values and guarantees the values are valid. It is intended to store a fixed number of values by a given name.

PiperOrigin-RevId: 266437460
This commit is contained in:
Rob Suderman 2019-08-30 12:51:31 -07:00 committed by A. Unique TensorFlower
parent 037742cdf2
commit 8f90a442c3
8 changed files with 548 additions and 1 deletions

View File

@ -838,7 +838,7 @@ class I64EnumAttr<string name, string description,
//===----------------------------------------------------------------------===//
// Composite attribute kinds
def DictionaryAttr : Attr<CPred<"$_self.isa<DictionaryAttr>()">,
class DictionaryAttr : Attr<CPred<"$_self.isa<DictionaryAttr>()">,
"dictionary of named attribute values"> {
let storageType = [{ DictionaryAttr }];
let returnType = [{ DictionaryAttr }];
@ -914,6 +914,29 @@ def I32ElementsAttr : Attr<
"{$_builder.getI32IntegerAttr($0)})";
let convertFromStorage = "$_self";
}
// Attribute information for an Attribute field within a StructAttr.
class StructFieldAttr<string thisName, Attr thisType> {
// Name of this field in the StructAttr.
string name = thisName;
// Attribute type wrapped by the struct attr.
Attr type = thisType;
}
// Structured attribute that wraps a DictionaryAttr and provides both a
// validation method and set of accessors for a fixed set of fields. This is
// useful when representing data that would normally be in a structure.
class StructAttr<string name, Dialect dialect,
list<StructFieldAttr> attributes> : DictionaryAttr {
// Name for this StructAttr.
string className = name;
// The dialect this StructAttr belongs to.
Dialect structDialect = dialect;
// List of fields that the StructAttr contains.
list<StructFieldAttr> fields = attributes;
}
// Attributes containing symbol references.
def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,

View File

@ -180,6 +180,36 @@ public:
std::vector<EnumAttrCase> getAllCases() const;
};
class StructFieldAttr {
public:
explicit StructFieldAttr(const llvm::Record *record);
explicit StructFieldAttr(const llvm::Record &record);
explicit StructFieldAttr(const llvm::DefInit *init);
StringRef getName() const;
Attribute getType() const;
private:
const llvm::Record *def;
};
// Wrapper class providing helper methods for accessing struct attributes
// defined in TableGen.
class StructAttr : public Attribute {
public:
explicit StructAttr(const llvm::Record *record);
explicit StructAttr(const llvm::Record &record) : StructAttr(&record){};
explicit StructAttr(const llvm::DefInit *init);
// Returns the struct class name.
StringRef getStructClassName() const;
// Returns the C++ namespaces this struct class should be placed in.
StringRef getCppNamespace() const;
std::vector<StructFieldAttr> getAllFields() const;
};
} // end namespace tblgen
} // end namespace mlir

View File

@ -210,3 +210,55 @@ std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const {
return cases;
}
tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record *record)
: def(record) {
assert(def->isSubClassOf("StructFieldAttr") &&
"must be subclass of TableGen 'StructFieldAttr' class");
}
tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record &record)
: StructFieldAttr(&record) {}
tblgen::StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
: StructFieldAttr(init->getDef()) {}
StringRef tblgen::StructFieldAttr::getName() const {
return def->getValueAsString("name");
}
tblgen::Attribute tblgen::StructFieldAttr::getType() const {
auto init = def->getValueInit("type");
return tblgen::Attribute(cast<llvm::DefInit>(init));
}
tblgen::StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
assert(def->isSubClassOf("StructAttr") &&
"must be subclass of TableGen 'StructAttr' class");
}
tblgen::StructAttr::StructAttr(const llvm::DefInit *init)
: StructAttr(init->getDef()) {}
StringRef tblgen::StructAttr::getStructClassName() const {
return def->getValueAsString("className");
}
StringRef tblgen::StructAttr::getCppNamespace() const {
Dialect dialect(def->getValueAsDef("structDialect"));
return dialect.getCppNamespace();
}
std::vector<mlir::tblgen::StructFieldAttr>
tblgen::StructAttr::getAllFields() const {
std::vector<mlir::tblgen::StructFieldAttr> attributes;
const auto *inits = def->getValueAsListInit("fields");
attributes.reserve(inits->size());
for (const llvm::Init *init : *inits) {
attributes.emplace_back(cast<llvm::DefInit>(init));
}
return attributes;
}

View File

@ -13,5 +13,6 @@ add_tablegen(mlir-tblgen MLIR
ReferenceImplGen.cpp
RewriterGen.cpp
SPIRVUtilsGen.cpp
StructsGen.cpp
)
set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")

View File

@ -0,0 +1,259 @@
//===- StructsGen.cpp - MLIR struct utility generator ---------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// StructsGen generates common utility functions for grouping attributes into a
// set of structured data.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::StringRef;
using mlir::tblgen::StructAttr;
static void
emitStructClass(const Record &structDef, StringRef structName,
llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
StringRef description, raw_ostream &os) {
const char *structInfo = R"(
// {0}
class {1} : public mlir::DictionaryAttr)";
const char *structInfoEnd = R"( {
public:
using DictionaryAttr::DictionaryAttr;
static bool classof(mlir::Attribute attr);
)";
os << formatv(structInfo, description, structName) << structInfoEnd;
// Declares a constructor function for the tablegen structure.
// TblgenStruct::get(MLIRContext context, Type1 Field1, Type2 Field2, ...);
const char *getInfoDecl = " static {0} get(\n";
const char *getInfoDeclArg = " {0} {1},\n";
const char *getInfoDeclEnd = " mlir::MLIRContext* context);\n\n";
os << llvm::formatv(getInfoDecl, structName);
for (auto field : fields) {
auto name = field.getName();
auto type = field.getType();
auto storage = type.getStorageType();
os << llvm::formatv(getInfoDeclArg, storage, name);
}
os << getInfoDeclEnd;
// Declares an accessor for the fields owned by the tablegen structure.
// namespace::storage TblgenStruct::field1() const;
const char *fieldInfo = R"( {0} {1}() const;
)";
for (const auto field : fields) {
auto name = field.getName();
auto type = field.getType();
auto storage = type.getStorageType();
os << formatv(fieldInfo, storage, name);
}
os << "};\n\n";
}
static void emitStructDecl(const Record &structDef, raw_ostream &os) {
StructAttr structAttr(&structDef);
StringRef structName = structAttr.getStructClassName();
StringRef cppNamespace = structAttr.getCppNamespace();
StringRef description = structAttr.getDescription();
auto fields = structAttr.getAllFields();
// Wrap in the appropriate namespace.
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
// Emit the struct class definition
emitStructClass(structDef, structName, fields, description, os);
// Close the declared namespace.
for (auto ns : namespaces)
os << "} // namespace " << ns << "\n";
}
static bool emitStructDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Struct Utility Declarations", os);
auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr");
for (const auto *def : defs) {
emitStructDecl(*def, os);
}
return false;
}
static void emitFactoryDef(llvm::StringRef structName,
llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
raw_ostream &os) {
const char *getInfoDecl = "{0} {0}::get(\n";
const char *getInfoDeclArg = " {0} {1},\n";
const char *getInfoDeclEnd = " mlir::MLIRContext* context) {";
os << llvm::formatv(getInfoDecl, structName);
for (auto field : fields) {
auto name = field.getName();
auto type = field.getType();
auto storage = type.getStorageType();
os << llvm::formatv(getInfoDeclArg, storage, name);
}
os << getInfoDeclEnd;
const char *fieldStart = R"(
llvm::SmallVector<mlir::NamedAttribute, {0}> fields;
)";
os << llvm::formatv(fieldStart, fields.size());
const char *getFieldInfo = R"(
assert({0});
auto {0}_id = mlir::Identifier::get("{0}", context);
fields.emplace_back({0}_id, {0});
)";
for (auto field : fields) {
os << llvm::formatv(getFieldInfo, field.getName());
}
const char *getEndInfo = R"(
Attribute dict = mlir::DictionaryAttr::get(fields, context);
return dict.dyn_cast<{0}>();
}
)";
os << llvm::formatv(getEndInfo, structName);
}
static void emitClassofDef(llvm::StringRef structName,
llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
raw_ostream &os) {
const char *classofInfo = R"(
bool {0}::classof(mlir::Attribute attr))";
const char *classofInfoHeader = R"(
auto derived = attr.dyn_cast<mlir::DictionaryAttr>();
if (!derived)
return false;
if (derived.size() != {0})
return false;
)";
os << llvm::formatv(classofInfo, structName) << " {";
os << llvm::formatv(classofInfoHeader, fields.size());
const char *classofArgInfo = R"(
auto {0} = derived.get("{0}");
if (!{0} || !{0}.isa<{1}>())
return false;
)";
for (auto field : fields) {
auto name = field.getName();
auto type = field.getType();
auto storage = type.getStorageType();
os << llvm::formatv(classofArgInfo, name, storage);
}
const char *classofEndInfo = R"(
return true;
}
)";
os << classofEndInfo;
}
static void
emitAccessorDef(llvm::StringRef structName,
llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
raw_ostream &os) {
const char *fieldInfo = R"(
{0} {2}::{1}() const {
auto derived = this->cast<mlir::DictionaryAttr>();
auto {1} = derived.get("{1}");
assert({1} && "attribute not found.");
assert({1}.isa<{0}>() && "incorrect Attribute type found.");
return {1}.cast<{0}>();
}
)";
for (auto field : fields) {
auto name = field.getName();
auto type = field.getType();
auto storage = type.getStorageType();
os << llvm::formatv(fieldInfo, storage, name, structName);
}
}
static void emitStructDef(const Record &structDef, raw_ostream &os) {
StructAttr structAttr(&structDef);
StringRef cppNamespace = structAttr.getCppNamespace();
StringRef structName = structAttr.getStructClassName();
mlir::tblgen::FmtContext ctx;
auto fields = structAttr.getAllFields();
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
emitFactoryDef(structName, fields, os);
emitClassofDef(structName, fields, os);
emitAccessorDef(structName, fields, os);
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
}
static bool emitStructDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Struct Utility Definitions", os);
auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr");
for (const auto *def : defs)
emitStructDef(*def, os);
return false;
}
// Registers the struct utility generator to mlir-tblgen.
static mlir::GenRegistration
genStructDecls("gen-struct-attr-decls",
"Generate struct utility declarations",
[](const RecordKeeper &records, raw_ostream &os) {
return emitStructDecls(records, os);
});
// Registers the struct utility generator to mlir-tblgen.
static mlir::GenRegistration
genStructDefs("gen-struct-attr-defs", "Generate struct utility definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitStructDefs(records, os);
});

View File

@ -3,12 +3,19 @@ mlir_tablegen(EnumsGenTest.h.inc -gen-enum-decls)
mlir_tablegen(EnumsGenTest.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTableGenEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS structs.td)
mlir_tablegen(StructAttrGenTest.h.inc -gen-struct-attr-decls)
mlir_tablegen(StructAttrGenTest.cpp.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIRTableGenStructAttrIncGen)
add_mlir_unittest(MLIRTableGenTests
EnumsGenTest.cpp
StructAttrGenTest.cpp
FormatTest.cpp
)
add_dependencies(MLIRTableGenTests MLIRTableGenEnumsIncGen)
add_dependencies(MLIRTableGenTests MLIRTableGenStructAttrIncGen)
target_link_libraries(MLIRTableGenTests
PRIVATE LLVMMLIRTableGen)

View File

@ -0,0 +1,146 @@
//===- StructsGenTest.cpp - TableGen StructsGen Tests ---------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringSwitch.h"
#include "gmock/gmock.h"
#include <type_traits>
namespace mlir {
// Pull in generated enum utility declarations
#include "StructAttrGenTest.h.inc"
// And definitions
#include "StructAttrGenTest.cpp.inc"
// Helper that returns an example test::TestStruct for testing its
// implementation.
static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
auto integerType = mlir::IntegerType::get(32, context);
auto integerAttr = mlir::IntegerAttr::get(integerType, 127);
auto floatType = mlir::FloatType::getF16(context);
auto floatAttr = mlir::FloatAttr::get(floatType, 0.25);
auto elementsType = mlir::RankedTensorType::get({2, 3}, integerType);
auto elementsAttr =
mlir::DenseElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
return test::TestStruct::get(integerAttr, floatAttr, elementsAttr, context);
}
// Validates that test::TestStruct::classof correctly identifies a valid
// test::TestStruct.
TEST(StructsGenTest, ClassofTrue) {
mlir::MLIRContext context;
auto structAttr = getTestStruct(&context);
ASSERT_TRUE(test::TestStruct::classof(structAttr));
}
// Validates that test::TestStruct::classof fails when an extra attribute is in
// the class.
TEST(StructsGenTest, ClassofExtraFalse) {
mlir::MLIRContext context;
mlir::DictionaryAttr structAttr = getTestStruct(&context);
auto expectedValues = structAttr.getValue();
ASSERT_EQ(expectedValues.size(), 3);
// Copy the set of named attributes.
llvm::SmallVector<mlir::NamedAttribute, 5> newValues(expectedValues.begin(),
expectedValues.end());
// Add an extra NamedAttribute.
auto wrongId = mlir::Identifier::get("wrong", &context);
auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
newValues.push_back(wrongAttr);
// Make a new DictionaryAttr and validate.
auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
// Validates that test::TestStruct::classof fails when a NamedAttribute has an
// incorrect name.
TEST(StructsGenTest, ClassofBadNameFalse) {
mlir::MLIRContext context;
mlir::DictionaryAttr structAttr = getTestStruct(&context);
auto expectedValues = structAttr.getValue();
ASSERT_EQ(expectedValues.size(), 3);
// Create a copy of all but the first NamedAttributes.
llvm::SmallVector<mlir::NamedAttribute, 4> newValues(
expectedValues.begin() + 1, expectedValues.end());
// Add a copy of the first attribute with the wrong Identifier.
auto wrongId = mlir::Identifier::get("wrong", &context);
auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
newValues.push_back(wrongAttr);
auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
// Validates that test::TestStruct::classof fails when a NamedAttribute is
// missing.
TEST(StructsGenTest, ClassofMissingFalse) {
mlir::MLIRContext context;
mlir::DictionaryAttr structAttr = getTestStruct(&context);
auto expectedValues = structAttr.getValue();
ASSERT_EQ(expectedValues.size(), 3);
// Copy a subset of the structures Named Attributes.
llvm::SmallVector<mlir::NamedAttribute, 3> newValues(
expectedValues.begin() + 1, expectedValues.end());
// Make a new DictionaryAttr and validate it is not a validte TestStruct.
auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
// Validate the accessor for the FloatAttr value.
TEST(StructsGenTest, GetFloat) {
mlir::MLIRContext context;
auto structAttr = getTestStruct(&context);
auto returnedAttr = structAttr.sample_float();
EXPECT_EQ(returnedAttr.getValueAsDouble(), 0.25);
}
// Validate the accessor for the IntegerAttr value.
TEST(StructsGenTest, GetInteger) {
mlir::MLIRContext context;
auto structAttr = getTestStruct(&context);
auto returnedAttr = structAttr.sample_integer();
EXPECT_EQ(returnedAttr.getInt(), 127);
}
// Validate the accessor for the ElementsAttr value.
TEST(StructsGenTest, GetElements) {
mlir::MLIRContext context;
auto structAttr = getTestStruct(&context);
auto returnedAttr = structAttr.sample_elements();
auto denseAttr = returnedAttr.dyn_cast<mlir::DenseElementsAttr>();
ASSERT_TRUE(denseAttr);
for (const auto &valIndexIt : llvm::enumerate(denseAttr.getIntValues())) {
EXPECT_EQ(valIndexIt.value(), valIndexIt.index() + 1);
}
}
} // namespace mlir

View File

@ -0,0 +1,29 @@
//===-- structss.td - StructsGen test definition file ------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
include "mlir/IR/OpBase.td"
def Test_Dialect : Dialect {
let name = "test";
}
def Test_Struct : StructAttr<"TestStruct", Test_Dialect, [
StructFieldAttr<"sample_integer", I32Attr>,
StructFieldAttr<"sample_float", F32Attr>,
StructFieldAttr<"sample_elements", ElementsAttr>] > {
let description = "Structure for test data";
}