diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index dd9d4e29e4bc..b4c921e6f52d 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -838,7 +838,7 @@ class I64EnumAttr()">, +class DictionaryAttr : Attr()">, "dictionary of named attribute values"> { let storageType = [{ DictionaryAttr }]; let returnType = [{ DictionaryAttr }]; @@ -914,6 +914,29 @@ def I32ElementsAttr : Attr< "{$_builder.getI32IntegerAttr($0)})"; let convertFromStorage = "$_self"; } +// Attribute information for an Attribute field within a StructAttr. +class StructFieldAttr { + // Name of this field in the StructAttr. + string name = thisName; + + // Attribute type wrapped by the struct attr. + Attr type = thisType; +} + +// Structured attribute that wraps a DictionaryAttr and provides both a +// validation method and set of accessors for a fixed set of fields. This is +// useful when representing data that would normally be in a structure. +class StructAttr attributes> : DictionaryAttr { + // Name for this StructAttr. + string className = name; + + // The dialect this StructAttr belongs to. + Dialect structDialect = dialect; + + // List of fields that the StructAttr contains. + list fields = attributes; +} // Attributes containing symbol references. def SymbolRefAttr : Attr()">, diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 2f137a2aca45..1cff9fdfa8b2 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -180,6 +180,36 @@ public: std::vector getAllCases() const; }; +class StructFieldAttr { +public: + explicit StructFieldAttr(const llvm::Record *record); + explicit StructFieldAttr(const llvm::Record &record); + explicit StructFieldAttr(const llvm::DefInit *init); + + StringRef getName() const; + Attribute getType() const; + +private: + const llvm::Record *def; +}; + +// Wrapper class providing helper methods for accessing struct attributes +// defined in TableGen. +class StructAttr : public Attribute { +public: + explicit StructAttr(const llvm::Record *record); + explicit StructAttr(const llvm::Record &record) : StructAttr(&record){}; + explicit StructAttr(const llvm::DefInit *init); + + // Returns the struct class name. + StringRef getStructClassName() const; + + // Returns the C++ namespaces this struct class should be placed in. + StringRef getCppNamespace() const; + + std::vector getAllFields() const; +}; + } // end namespace tblgen } // end namespace mlir diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index b42bb94e3fc4..3d19de244298 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -210,3 +210,55 @@ std::vector tblgen::EnumAttr::getAllCases() const { return cases; } + +tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record *record) + : def(record) { + assert(def->isSubClassOf("StructFieldAttr") && + "must be subclass of TableGen 'StructFieldAttr' class"); +} + +tblgen::StructFieldAttr::StructFieldAttr(const llvm::Record &record) + : StructFieldAttr(&record) {} + +tblgen::StructFieldAttr::StructFieldAttr(const llvm::DefInit *init) + : StructFieldAttr(init->getDef()) {} + +StringRef tblgen::StructFieldAttr::getName() const { + return def->getValueAsString("name"); +} + +tblgen::Attribute tblgen::StructFieldAttr::getType() const { + auto init = def->getValueInit("type"); + return tblgen::Attribute(cast(init)); +} + +tblgen::StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) { + assert(def->isSubClassOf("StructAttr") && + "must be subclass of TableGen 'StructAttr' class"); +} + +tblgen::StructAttr::StructAttr(const llvm::DefInit *init) + : StructAttr(init->getDef()) {} + +StringRef tblgen::StructAttr::getStructClassName() const { + return def->getValueAsString("className"); +} + +StringRef tblgen::StructAttr::getCppNamespace() const { + Dialect dialect(def->getValueAsDef("structDialect")); + return dialect.getCppNamespace(); +} + +std::vector +tblgen::StructAttr::getAllFields() const { + std::vector attributes; + + const auto *inits = def->getValueAsListInit("fields"); + attributes.reserve(inits->size()); + + for (const llvm::Init *init : *inits) { + attributes.emplace_back(cast(init)); + } + + return attributes; +} diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt index 067e1725e240..31c23b8bd387 100644 --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -13,5 +13,6 @@ add_tablegen(mlir-tblgen MLIR ReferenceImplGen.cpp RewriterGen.cpp SPIRVUtilsGen.cpp + StructsGen.cpp ) set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning") diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp new file mode 100644 index 000000000000..d8844957ece6 --- /dev/null +++ b/mlir/tools/mlir-tblgen/StructsGen.cpp @@ -0,0 +1,259 @@ +//===- StructsGen.cpp - MLIR struct utility generator ---------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// StructsGen generates common utility functions for grouping attributes into a +// set of structured data. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Operator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +using llvm::raw_ostream; +using llvm::Record; +using llvm::RecordKeeper; +using llvm::StringRef; +using mlir::tblgen::StructAttr; + +static void +emitStructClass(const Record &structDef, StringRef structName, + llvm::ArrayRef fields, + StringRef description, raw_ostream &os) { + const char *structInfo = R"( +// {0} +class {1} : public mlir::DictionaryAttr)"; + const char *structInfoEnd = R"( { +public: + using DictionaryAttr::DictionaryAttr; + static bool classof(mlir::Attribute attr); +)"; + os << formatv(structInfo, description, structName) << structInfoEnd; + + // Declares a constructor function for the tablegen structure. + // TblgenStruct::get(MLIRContext context, Type1 Field1, Type2 Field2, ...); + const char *getInfoDecl = " static {0} get(\n"; + const char *getInfoDeclArg = " {0} {1},\n"; + const char *getInfoDeclEnd = " mlir::MLIRContext* context);\n\n"; + + os << llvm::formatv(getInfoDecl, structName); + + for (auto field : fields) { + auto name = field.getName(); + auto type = field.getType(); + auto storage = type.getStorageType(); + os << llvm::formatv(getInfoDeclArg, storage, name); + } + os << getInfoDeclEnd; + + // Declares an accessor for the fields owned by the tablegen structure. + // namespace::storage TblgenStruct::field1() const; + const char *fieldInfo = R"( {0} {1}() const; +)"; + for (const auto field : fields) { + auto name = field.getName(); + auto type = field.getType(); + auto storage = type.getStorageType(); + os << formatv(fieldInfo, storage, name); + } + + os << "};\n\n"; +} + +static void emitStructDecl(const Record &structDef, raw_ostream &os) { + StructAttr structAttr(&structDef); + StringRef structName = structAttr.getStructClassName(); + StringRef cppNamespace = structAttr.getCppNamespace(); + StringRef description = structAttr.getDescription(); + auto fields = structAttr.getAllFields(); + + // Wrap in the appropriate namespace. + llvm::SmallVector namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + // Emit the struct class definition + emitStructClass(structDef, structName, fields, description, os); + + // Close the declared namespace. + for (auto ns : namespaces) + os << "} // namespace " << ns << "\n"; +} + +static bool emitStructDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("Struct Utility Declarations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr"); + for (const auto *def : defs) { + emitStructDecl(*def, os); + } + + return false; +} + +static void emitFactoryDef(llvm::StringRef structName, + llvm::ArrayRef fields, + raw_ostream &os) { + const char *getInfoDecl = "{0} {0}::get(\n"; + const char *getInfoDeclArg = " {0} {1},\n"; + const char *getInfoDeclEnd = " mlir::MLIRContext* context) {"; + + os << llvm::formatv(getInfoDecl, structName); + + for (auto field : fields) { + auto name = field.getName(); + auto type = field.getType(); + auto storage = type.getStorageType(); + os << llvm::formatv(getInfoDeclArg, storage, name); + } + os << getInfoDeclEnd; + + const char *fieldStart = R"( + llvm::SmallVector fields; +)"; + os << llvm::formatv(fieldStart, fields.size()); + + const char *getFieldInfo = R"( + assert({0}); + auto {0}_id = mlir::Identifier::get("{0}", context); + fields.emplace_back({0}_id, {0}); +)"; + + for (auto field : fields) { + os << llvm::formatv(getFieldInfo, field.getName()); + } + + const char *getEndInfo = R"( + Attribute dict = mlir::DictionaryAttr::get(fields, context); + return dict.dyn_cast<{0}>(); +} +)"; + os << llvm::formatv(getEndInfo, structName); +} + +static void emitClassofDef(llvm::StringRef structName, + llvm::ArrayRef fields, + raw_ostream &os) { + const char *classofInfo = R"( +bool {0}::classof(mlir::Attribute attr))"; + + const char *classofInfoHeader = R"( + auto derived = attr.dyn_cast(); + if (!derived) + return false; + if (derived.size() != {0}) + return false; +)"; + + os << llvm::formatv(classofInfo, structName) << " {"; + os << llvm::formatv(classofInfoHeader, fields.size()); + + const char *classofArgInfo = R"( + auto {0} = derived.get("{0}"); + if (!{0} || !{0}.isa<{1}>()) + return false; +)"; + for (auto field : fields) { + auto name = field.getName(); + auto type = field.getType(); + auto storage = type.getStorageType(); + os << llvm::formatv(classofArgInfo, name, storage); + } + + const char *classofEndInfo = R"( + return true; +} +)"; + os << classofEndInfo; +} + +static void +emitAccessorDef(llvm::StringRef structName, + llvm::ArrayRef fields, + raw_ostream &os) { + const char *fieldInfo = R"( +{0} {2}::{1}() const { + auto derived = this->cast(); + auto {1} = derived.get("{1}"); + assert({1} && "attribute not found."); + assert({1}.isa<{0}>() && "incorrect Attribute type found."); + return {1}.cast<{0}>(); +} +)"; + for (auto field : fields) { + auto name = field.getName(); + auto type = field.getType(); + auto storage = type.getStorageType(); + os << llvm::formatv(fieldInfo, storage, name, structName); + } +} + +static void emitStructDef(const Record &structDef, raw_ostream &os) { + StructAttr structAttr(&structDef); + StringRef cppNamespace = structAttr.getCppNamespace(); + StringRef structName = structAttr.getStructClassName(); + mlir::tblgen::FmtContext ctx; + auto fields = structAttr.getAllFields(); + + llvm::SmallVector namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + emitFactoryDef(structName, fields, os); + emitClassofDef(structName, fields, os); + emitAccessorDef(structName, fields, os); + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; +} + +static bool emitStructDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("Struct Utility Definitions", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr"); + for (const auto *def : defs) + emitStructDef(*def, os); + + return false; +} + +// Registers the struct utility generator to mlir-tblgen. +static mlir::GenRegistration + genStructDecls("gen-struct-attr-decls", + "Generate struct utility declarations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitStructDecls(records, os); + }); + +// Registers the struct utility generator to mlir-tblgen. +static mlir::GenRegistration + genStructDefs("gen-struct-attr-defs", "Generate struct utility definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitStructDefs(records, os); + }); diff --git a/mlir/unittests/TableGen/CMakeLists.txt b/mlir/unittests/TableGen/CMakeLists.txt index aa55adbdae82..0c1227462442 100644 --- a/mlir/unittests/TableGen/CMakeLists.txt +++ b/mlir/unittests/TableGen/CMakeLists.txt @@ -3,12 +3,19 @@ mlir_tablegen(EnumsGenTest.h.inc -gen-enum-decls) mlir_tablegen(EnumsGenTest.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRTableGenEnumsIncGen) +set(LLVM_TARGET_DEFINITIONS structs.td) +mlir_tablegen(StructAttrGenTest.h.inc -gen-struct-attr-decls) +mlir_tablegen(StructAttrGenTest.cpp.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRTableGenStructAttrIncGen) + add_mlir_unittest(MLIRTableGenTests EnumsGenTest.cpp + StructAttrGenTest.cpp FormatTest.cpp ) add_dependencies(MLIRTableGenTests MLIRTableGenEnumsIncGen) +add_dependencies(MLIRTableGenTests MLIRTableGenStructAttrIncGen) target_link_libraries(MLIRTableGenTests PRIVATE LLVMMLIRTableGen) diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp new file mode 100644 index 000000000000..4457e6c495c2 --- /dev/null +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -0,0 +1,146 @@ +//===- StructsGenTest.cpp - TableGen StructsGen Tests ---------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/StandardTypes.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringSwitch.h" +#include "gmock/gmock.h" +#include + +namespace mlir { + +// Pull in generated enum utility declarations +#include "StructAttrGenTest.h.inc" +// And definitions +#include "StructAttrGenTest.cpp.inc" +// Helper that returns an example test::TestStruct for testing its +// implementation. +static test::TestStruct getTestStruct(mlir::MLIRContext *context) { + auto integerType = mlir::IntegerType::get(32, context); + auto integerAttr = mlir::IntegerAttr::get(integerType, 127); + + auto floatType = mlir::FloatType::getF16(context); + auto floatAttr = mlir::FloatAttr::get(floatType, 0.25); + + auto elementsType = mlir::RankedTensorType::get({2, 3}, integerType); + auto elementsAttr = + mlir::DenseElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6}); + + return test::TestStruct::get(integerAttr, floatAttr, elementsAttr, context); +} + +// Validates that test::TestStruct::classof correctly identifies a valid +// test::TestStruct. +TEST(StructsGenTest, ClassofTrue) { + mlir::MLIRContext context; + auto structAttr = getTestStruct(&context); + ASSERT_TRUE(test::TestStruct::classof(structAttr)); +} + +// Validates that test::TestStruct::classof fails when an extra attribute is in +// the class. +TEST(StructsGenTest, ClassofExtraFalse) { + mlir::MLIRContext context; + mlir::DictionaryAttr structAttr = getTestStruct(&context); + auto expectedValues = structAttr.getValue(); + ASSERT_EQ(expectedValues.size(), 3); + + // Copy the set of named attributes. + llvm::SmallVector newValues(expectedValues.begin(), + expectedValues.end()); + + // Add an extra NamedAttribute. + auto wrongId = mlir::Identifier::get("wrong", &context); + auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second); + newValues.push_back(wrongAttr); + + // Make a new DictionaryAttr and validate. + auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + ASSERT_FALSE(test::TestStruct::classof(badDictionary)); +} + +// Validates that test::TestStruct::classof fails when a NamedAttribute has an +// incorrect name. +TEST(StructsGenTest, ClassofBadNameFalse) { + mlir::MLIRContext context; + mlir::DictionaryAttr structAttr = getTestStruct(&context); + auto expectedValues = structAttr.getValue(); + ASSERT_EQ(expectedValues.size(), 3); + + // Create a copy of all but the first NamedAttributes. + llvm::SmallVector newValues( + expectedValues.begin() + 1, expectedValues.end()); + + // Add a copy of the first attribute with the wrong Identifier. + auto wrongId = mlir::Identifier::get("wrong", &context); + auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second); + newValues.push_back(wrongAttr); + + auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + ASSERT_FALSE(test::TestStruct::classof(badDictionary)); +} + +// Validates that test::TestStruct::classof fails when a NamedAttribute is +// missing. +TEST(StructsGenTest, ClassofMissingFalse) { + mlir::MLIRContext context; + mlir::DictionaryAttr structAttr = getTestStruct(&context); + auto expectedValues = structAttr.getValue(); + ASSERT_EQ(expectedValues.size(), 3); + + // Copy a subset of the structures Named Attributes. + llvm::SmallVector newValues( + expectedValues.begin() + 1, expectedValues.end()); + + // Make a new DictionaryAttr and validate it is not a validte TestStruct. + auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + ASSERT_FALSE(test::TestStruct::classof(badDictionary)); +} + +// Validate the accessor for the FloatAttr value. +TEST(StructsGenTest, GetFloat) { + mlir::MLIRContext context; + auto structAttr = getTestStruct(&context); + auto returnedAttr = structAttr.sample_float(); + EXPECT_EQ(returnedAttr.getValueAsDouble(), 0.25); +} + +// Validate the accessor for the IntegerAttr value. +TEST(StructsGenTest, GetInteger) { + mlir::MLIRContext context; + auto structAttr = getTestStruct(&context); + auto returnedAttr = structAttr.sample_integer(); + EXPECT_EQ(returnedAttr.getInt(), 127); +} + +// Validate the accessor for the ElementsAttr value. +TEST(StructsGenTest, GetElements) { + mlir::MLIRContext context; + auto structAttr = getTestStruct(&context); + auto returnedAttr = structAttr.sample_elements(); + auto denseAttr = returnedAttr.dyn_cast(); + ASSERT_TRUE(denseAttr); + + for (const auto &valIndexIt : llvm::enumerate(denseAttr.getIntValues())) { + EXPECT_EQ(valIndexIt.value(), valIndexIt.index() + 1); + } +} + +} // namespace mlir diff --git a/mlir/unittests/TableGen/structs.td b/mlir/unittests/TableGen/structs.td new file mode 100644 index 000000000000..be847aeafd17 --- /dev/null +++ b/mlir/unittests/TableGen/structs.td @@ -0,0 +1,29 @@ +//===-- structss.td - StructsGen test definition file ------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; +} + +def Test_Struct : StructAttr<"TestStruct", Test_Dialect, [ + StructFieldAttr<"sample_integer", I32Attr>, + StructFieldAttr<"sample_float", F32Attr>, + StructFieldAttr<"sample_elements", ElementsAttr>] > { + let description = "Structure for test data"; +}