From 81f2f4dfb2922e4f7af8bbfd8b653eda7c1f1339 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 15 Feb 2022 14:32:37 -0800 Subject: [PATCH] [PDLL] Add support for tablegen includes and importing ODS information This commit adds support for processing tablegen include files, and importing various information from ODS. This includes operations, attribute+type constraints, attribute/operation/type interfaces, etc. This will allow for much more robust tooling, and also allows for referencing ODS constructs directly within PDLL (imported interfaces can be used as constraints, operation result names can be used for member access, etc). Differential Revision: https://reviews.llvm.org/D119900 --- llvm/include/llvm/Support/SourceMgr.h | 22 ++ llvm/lib/Support/SourceMgr.cpp | 16 +- mlir/include/mlir/IR/OpBase.td | 6 +- mlir/include/mlir/TableGen/Constraint.h | 5 + mlir/include/mlir/Tools/PDLL/AST/Context.h | 13 +- mlir/include/mlir/Tools/PDLL/ODS/Constraint.h | 98 +++++ mlir/include/mlir/Tools/PDLL/ODS/Context.h | 78 ++++ mlir/include/mlir/Tools/PDLL/ODS/Dialect.h | 64 ++++ mlir/include/mlir/Tools/PDLL/ODS/Operation.h | 189 ++++++++++ mlir/lib/TableGen/Constraint.cpp | 23 ++ mlir/lib/Tools/PDLL/AST/CMakeLists.txt | 1 + mlir/lib/Tools/PDLL/AST/Context.cpp | 2 +- mlir/lib/Tools/PDLL/CMakeLists.txt | 1 + mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp | 31 +- mlir/lib/Tools/PDLL/ODS/CMakeLists.txt | 8 + mlir/lib/Tools/PDLL/ODS/Context.cpp | 174 +++++++++ mlir/lib/Tools/PDLL/ODS/Dialect.cpp | 39 ++ mlir/lib/Tools/PDLL/ODS/Operation.cpp | 26 ++ mlir/lib/Tools/PDLL/Parser/CMakeLists.txt | 6 + mlir/lib/Tools/PDLL/Parser/Parser.cpp | 344 ++++++++++++++++-- mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll | 20 +- .../mlir-pdll/CodeGen/MLIR/include/ops.td | 9 + .../mlir-pdll/Parser/directive-failure.pdll | 2 +- mlir/test/mlir-pdll/Parser/expr-failure.pdll | 22 +- mlir/test/mlir-pdll/Parser/expr.pdll | 21 +- .../mlir-pdll/Parser/include/interfaces.td | 5 + mlir/test/mlir-pdll/Parser/include/ops.td | 26 ++ mlir/test/mlir-pdll/Parser/include_td.pdll | 52 +++ mlir/test/mlir-pdll/Parser/stmt-failure.pdll | 24 +- mlir/tools/mlir-pdll/mlir-pdll.cpp | 24 +- 30 files changed, 1312 insertions(+), 39 deletions(-) create mode 100644 mlir/include/mlir/Tools/PDLL/ODS/Constraint.h create mode 100644 mlir/include/mlir/Tools/PDLL/ODS/Context.h create mode 100644 mlir/include/mlir/Tools/PDLL/ODS/Dialect.h create mode 100644 mlir/include/mlir/Tools/PDLL/ODS/Operation.h create mode 100644 mlir/lib/Tools/PDLL/ODS/CMakeLists.txt create mode 100644 mlir/lib/Tools/PDLL/ODS/Context.cpp create mode 100644 mlir/lib/Tools/PDLL/ODS/Dialect.cpp create mode 100644 mlir/lib/Tools/PDLL/ODS/Operation.cpp create mode 100644 mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td create mode 100644 mlir/test/mlir-pdll/Parser/include/interfaces.td create mode 100644 mlir/test/mlir-pdll/Parser/include/ops.td create mode 100644 mlir/test/mlir-pdll/Parser/include_td.pdll diff --git a/llvm/include/llvm/Support/SourceMgr.h b/llvm/include/llvm/Support/SourceMgr.h index 28716b42f4ab..fc6d651a37b0 100644 --- a/llvm/include/llvm/Support/SourceMgr.h +++ b/llvm/include/llvm/Support/SourceMgr.h @@ -100,6 +100,9 @@ public: SourceMgr &operator=(SourceMgr &&) = default; ~SourceMgr() = default; + /// Return the include directories of this source manager. + ArrayRef getIncludeDirs() const { return IncludeDirectories; } + void setIncludeDirs(const std::vector &Dirs) { IncludeDirectories = Dirs; } @@ -147,6 +150,14 @@ public: return Buffers.size(); } + /// Takes the source buffers from the given source manager and append them to + /// the current manager. + void takeSourceBuffersFrom(SourceMgr &SrcMgr) { + std::move(SrcMgr.Buffers.begin(), SrcMgr.Buffers.end(), + std::back_inserter(Buffers)); + SrcMgr.Buffers.clear(); + } + /// Search for a file with the specified name in the current directory or in /// one of the IncludeDirs. /// @@ -156,6 +167,17 @@ public: unsigned AddIncludeFile(const std::string &Filename, SMLoc IncludeLoc, std::string &IncludedFile); + /// Search for a file with the specified name in the current directory or in + /// one of the IncludeDirs, and try to open it **without** adding to the + /// SourceMgr. If the opened file is intended to be added to the source + /// manager, prefer `AddIncludeFile` instead. + /// + /// If no file is found, this returns an Error, otherwise it returns the + /// buffer of the stacked file. The full path to the included file can be + /// found in \p IncludedFile. + ErrorOr> + OpenIncludeFile(const std::string &Filename, std::string &IncludedFile); + /// Return the ID of the buffer containing the specified location. /// /// 0 is returned if the buffer is not found. diff --git a/llvm/lib/Support/SourceMgr.cpp b/llvm/lib/Support/SourceMgr.cpp index 2eb2989b200b..42982b4c8e6c 100644 --- a/llvm/lib/Support/SourceMgr.cpp +++ b/llvm/lib/Support/SourceMgr.cpp @@ -40,6 +40,17 @@ static const size_t TabStop = 8; unsigned SourceMgr::AddIncludeFile(const std::string &Filename, SMLoc IncludeLoc, std::string &IncludedFile) { + ErrorOr> NewBufOrErr = + OpenIncludeFile(Filename, IncludedFile); + if (!NewBufOrErr) + return 0; + + return AddNewSourceBuffer(std::move(*NewBufOrErr), IncludeLoc); +} + +ErrorOr> +SourceMgr::OpenIncludeFile(const std::string &Filename, + std::string &IncludedFile) { IncludedFile = Filename; ErrorOr> NewBufOrErr = MemoryBuffer::getFile(IncludedFile); @@ -52,10 +63,7 @@ unsigned SourceMgr::AddIncludeFile(const std::string &Filename, NewBufOrErr = MemoryBuffer::getFile(IncludedFile); } - if (!NewBufOrErr) - return 0; - - return AddNewSourceBuffer(std::move(*NewBufOrErr), IncludeLoc); + return NewBufOrErr; } unsigned SourceMgr::FindBufferContainingLoc(SMLoc Loc) const { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 11bc1a639794..fa40ca7819c0 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -363,7 +363,8 @@ class DialectType : TypeConstraint { +class Variadic : TypeConstraint { Type baseType = type; } @@ -379,7 +380,8 @@ class VariadicOfVariadic // An optional type constraint. It expands to either zero or one of the base // type. This class is used for supporting optional operands/results. -class Optional : TypeConstraint { +class Optional : TypeConstraint { Type baseType = type; } diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h index 4e099aa33416..b24b9b7459ee 100644 --- a/mlir/include/mlir/TableGen/Constraint.h +++ b/mlir/include/mlir/TableGen/Constraint.h @@ -54,6 +54,11 @@ public: // description is not provided, returns the TableGen def name. StringRef getSummary() const; + /// Returns the name of the TablGen def of this constraint. In some cases + /// where the current def is anonymous, the name of the base def is used (e.g. + /// `Optional<>`/`Variadic<>` type constraints). + StringRef getDefName() const; + Kind getKind() const { return kind; } protected: diff --git a/mlir/include/mlir/Tools/PDLL/AST/Context.h b/mlir/include/mlir/Tools/PDLL/AST/Context.h index 978158951cff..f9a9424e125e 100644 --- a/mlir/include/mlir/Tools/PDLL/AST/Context.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Context.h @@ -14,13 +14,17 @@ namespace mlir { namespace pdll { +namespace ods { +class Context; +} // namespace ods + namespace ast { /// This class represents the main context of the PDLL AST. It handles /// allocating all of the AST constructs, and manages all state necessary for /// the AST. class Context { public: - Context(); + explicit Context(ods::Context &odsContext); Context(const Context &) = delete; Context &operator=(const Context &) = delete; @@ -30,6 +34,10 @@ public: /// Return the storage uniquer used for AST types. StorageUniquer &getTypeUniquer() { return typeUniquer; } + /// Return the ODS context used by the AST. + ods::Context &getODSContext() { return odsContext; } + const ods::Context &getODSContext() const { return odsContext; } + /// Return the diagnostic engine of this context. DiagnosticEngine &getDiagEngine() { return diagEngine; } @@ -37,6 +45,9 @@ private: /// The diagnostic engine of this AST context. DiagnosticEngine diagEngine; + /// The ODS context used by the AST. + ods::Context &odsContext; + /// The allocator used for AST nodes, and other entities allocated within the /// context. llvm::BumpPtrAllocator allocator; diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h b/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h new file mode 100644 index 000000000000..270330966de4 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h @@ -0,0 +1,98 @@ +//===- Constraint.h - MLIR PDLL ODS Constraints -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a PDLL description of ODS constraints. These are used to +// support the import of constraints defined outside of PDLL. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_ +#define MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_ + +#include + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +namespace pdll { +namespace ods { + +//===----------------------------------------------------------------------===// +// Constraint +//===----------------------------------------------------------------------===// + +/// This class represents a generic ODS constraint. +class Constraint { +public: + /// Return the name of this constraint. + StringRef getName() const { return name; } + + /// Return the summary of this constraint. + StringRef getSummary() const { return summary; } + +protected: + Constraint(StringRef name, StringRef summary) + : name(name.str()), summary(summary.str()) {} + Constraint(const Constraint &) = delete; + +private: + /// The name of the constraint. + std::string name; + /// A summary of the constraint. + std::string summary; +}; + +//===----------------------------------------------------------------------===// +// AttributeConstraint +//===----------------------------------------------------------------------===// + +/// This class represents a generic ODS Attribute constraint. +class AttributeConstraint : public Constraint { +public: + /// Return the name of the underlying c++ class of this constraint. + StringRef getCppClass() const { return cppClassName; } + +private: + AttributeConstraint(StringRef name, StringRef summary, StringRef cppClassName) + : Constraint(name, summary), cppClassName(cppClassName.str()) {} + + /// The c++ class of the constraint. + std::string cppClassName; + + /// Allow access to the constructor. + friend class Context; +}; + +//===----------------------------------------------------------------------===// +// TypeConstraint +//===----------------------------------------------------------------------===// + +/// This class represents a generic ODS Type constraint. +class TypeConstraint : public Constraint { +public: + /// Return the name of the underlying c++ class of this constraint. + StringRef getCppClass() const { return cppClassName; } + +private: + TypeConstraint(StringRef name, StringRef summary, StringRef cppClassName) + : Constraint(name, summary), cppClassName(cppClassName.str()) {} + + /// The c++ class of the constraint. + std::string cppClassName; + + /// Allow access to the constructor. + friend class Context; +}; + +} // namespace ods +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_ diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Context.h b/mlir/include/mlir/Tools/PDLL/ODS/Context.h new file mode 100644 index 000000000000..d0955ab62d8e --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/ODS/Context.h @@ -0,0 +1,78 @@ +//===- Context.h - MLIR PDLL ODS Context ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_ODS_CONTEXT_H_ +#define MLIR_TOOLS_PDLL_ODS_CONTEXT_H_ + +#include + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" + +namespace llvm { +class SMLoc; +} // namespace llvm + +namespace mlir { +namespace pdll { +namespace ods { +class AttributeConstraint; +class Dialect; +class Operation; +class TypeConstraint; + +/// This class contains all of the registered ODS operation classes. +class Context { +public: + Context(); + ~Context(); + + /// Insert a new attribute constraint with the context. Returns the inserted + /// constraint, or a previously inserted constraint with the same name. + const AttributeConstraint &insertAttributeConstraint(StringRef name, + StringRef summary, + StringRef cppClass); + + /// Insert a new type constraint with the context. Returns the inserted + /// constraint, or a previously inserted constraint with the same name. + const TypeConstraint &insertTypeConstraint(StringRef name, StringRef summary, + StringRef cppClass); + + /// Insert a new dialect with the context. Returns the inserted dialect, or a + /// previously inserted dialect with the same name. + Dialect &insertDialect(StringRef name); + + /// Lookup a dialect registered with the given name, or null if no dialect + /// with that name was inserted. + const Dialect *lookupDialect(StringRef name) const; + + /// Insert a new operation with the context. Returns the inserted operation, + /// and a boolean indicating if the operation newly inserted (false if the + /// operation already existed). + std::pair + insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc); + + /// Lookup an operation registered with the given name, or null if no + /// operation with that name is registered. + const Operation *lookupOperation(StringRef name) const; + + /// Print the contents of this context to the provided stream. + void print(raw_ostream &os) const; + +private: + llvm::StringMap> attributeConstraints; + llvm::StringMap> dialects; + llvm::StringMap> typeConstraints; +}; +} // namespace ods +} // namespace pdll +} // namespace mlir + +#endif // MLIR_PDL_pdll_ODS_CONTEXT_H_ diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h new file mode 100644 index 000000000000..f75d497867b8 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h @@ -0,0 +1,64 @@ +//===- Dialect.h - PDLL ODS Dialect -----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_ODS_DIALECT_H_ +#define MLIR_TOOLS_PDLL_ODS_DIALECT_H_ + +#include + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +namespace pdll { +namespace ods { +class Operation; + +/// This class represents an ODS dialect, and contains information on the +/// constructs held within the dialect. +class Dialect { +public: + ~Dialect(); + + /// Return the name of this dialect. + StringRef getName() const { return name; } + + /// Insert a new operation with the dialect. Returns the inserted operation, + /// and a boolean indicating if the operation newly inserted (false if the + /// operation already existed). + std::pair + insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc); + + /// Lookup an operation registered with the given name, or null if no + /// operation with that name is registered. + Operation *lookupOperation(StringRef name) const; + + /// Return a map of all of the operations registered to this dialect. + const llvm::StringMap> &getOperations() const { + return operations; + } + +private: + explicit Dialect(StringRef name); + + /// The name of the dialect. + std::string name; + + /// The operations defined by the dialect. + llvm::StringMap> operations; + + /// Allow access to the constructor. + friend class Context; +}; +} // namespace ods +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_ODS_DIALECT_H_ diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h new file mode 100644 index 000000000000..c5b86e1733d0 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h @@ -0,0 +1,189 @@ +//===- Operation.h - MLIR PDLL ODS Operation --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_ODS_OPERATION_H_ +#define MLIR_TOOLS_PDLL_ODS_OPERATION_H_ + +#include + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/SMLoc.h" + +namespace mlir { +namespace pdll { +namespace ods { +class AttributeConstraint; +class TypeConstraint; + +//===----------------------------------------------------------------------===// +// VariableLengthKind +//===----------------------------------------------------------------------===// + +enum VariableLengthKind { Single, Optional, Variadic }; + +//===----------------------------------------------------------------------===// +// Attribute +//===----------------------------------------------------------------------===// + +/// This class provides an ODS representation of a specific operation attribute. +/// This includes the name, optionality, and more. +class Attribute { +public: + /// Return the name of this operand. + StringRef getName() const { return name; } + + /// Return true if this attribute is optional. + bool isOptional() const { return optional; } + + /// Return the constraint of this attribute. + const AttributeConstraint &getConstraint() const { return constraint; } + +private: + Attribute(StringRef name, bool optional, + const AttributeConstraint &constraint) + : name(name.str()), optional(optional), constraint(constraint) {} + + /// The ODS name of the attribute. + std::string name; + + /// A flag indicating if the attribute is optional. + bool optional; + + /// The ODS constraint of this attribute. + const AttributeConstraint &constraint; + + /// Allow access to the private constructor. + friend class Operation; +}; + +//===----------------------------------------------------------------------===// +// OperandOrResult +//===----------------------------------------------------------------------===// + +/// This class provides an ODS representation of a specific operation operand or +/// result. This includes the name, variable length flags, and more. +class OperandOrResult { +public: + /// Return the name of this value. + StringRef getName() const { return name; } + + /// Returns true if this value is variadic (Note this is false if the value is + /// Optional). + bool isVariadic() const { + return variableLengthKind == VariableLengthKind::Variadic; + } + + /// Returns the variable length kind of this value. + VariableLengthKind getVariableLengthKind() const { + return variableLengthKind; + } + + /// Return the constraint of this value. + const TypeConstraint &getConstraint() const { return constraint; } + +private: + OperandOrResult(StringRef name, VariableLengthKind variableLengthKind, + const TypeConstraint &constraint) + : name(name.str()), variableLengthKind(variableLengthKind), + constraint(constraint) {} + + /// The ODS name of this value. + std::string name; + + /// The variable length kind of this value. + VariableLengthKind variableLengthKind; + + /// The ODS constraint of this value. + const TypeConstraint &constraint; + + /// Allow access to the private constructor. + friend class Operation; +}; + +//===----------------------------------------------------------------------===// +// Operation +//===----------------------------------------------------------------------===// + +/// This class provides an ODS representation of a specific operation. This +/// includes all of the information necessary for use by the PDL frontend for +/// generating code for a pattern rewrite. +class Operation { +public: + /// Return the source location of this operation. + SMRange getLoc() const { return location; } + + /// Append an attribute to this operation. + void appendAttribute(StringRef name, bool optional, + const AttributeConstraint &constraint) { + attributes.emplace_back(Attribute(name, optional, constraint)); + } + + /// Append an operand to this operation. + void appendOperand(StringRef name, VariableLengthKind variableLengthKind, + const TypeConstraint &constraint) { + operands.emplace_back( + OperandOrResult(name, variableLengthKind, constraint)); + } + + /// Append a result to this operation. + void appendResult(StringRef name, VariableLengthKind variableLengthKind, + const TypeConstraint &constraint) { + results.emplace_back(OperandOrResult(name, variableLengthKind, constraint)); + } + + /// Returns the name of the operation. + StringRef getName() const { return name; } + + /// Returns the summary of the operation. + StringRef getSummary() const { return summary; } + + /// Returns the description of the operation. + StringRef getDescription() const { return description; } + + /// Returns the attributes of this operation. + ArrayRef getAttributes() const { return attributes; } + + /// Returns the operands of this operation. + ArrayRef getOperands() const { return operands; } + + /// Returns the results of this operation. + ArrayRef getResults() const { return results; } + +private: + Operation(StringRef name, StringRef summary, StringRef desc, SMLoc loc); + + /// The name of the operation. + std::string name; + + /// The documentation of the operation. + std::string summary; + std::string description; + + /// The source location of this operation. + SMRange location; + + /// The operands of the operation. + SmallVector operands; + + /// The results of the operation. + SmallVector results; + + /// The attributes of the operation. + SmallVector attributes; + + /// Allow access to the private constructor. + friend class Dialect; +}; +} // namespace ods +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_ODS_OPERATION_H_ diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp index 759e28fbc903..249c22eebbfb 100644 --- a/mlir/lib/TableGen/Constraint.cpp +++ b/mlir/lib/TableGen/Constraint.cpp @@ -57,6 +57,29 @@ StringRef Constraint::getSummary() const { return def->getName(); } +StringRef Constraint::getDefName() const { + // Functor used to check a base def in the case where the current def is + // anonymous. + auto checkBaseDefFn = [&](StringRef baseName) { + if (const auto *init = dyn_cast(def->getValueInit(baseName))) + return Constraint(init->getDef(), kind).getDefName(); + return def->getName(); + }; + + switch (kind) { + case CK_Attr: + if (def->isAnonymous()) + return checkBaseDefFn("baseAttr"); + return def->getName(); + case CK_Type: + if (def->isAnonymous()) + return checkBaseDefFn("baseType"); + return def->getName(); + default: + return def->getName(); + } +} + AppliedConstraint::AppliedConstraint(Constraint &&constraint, llvm::StringRef self, std::vector &&entities) diff --git a/mlir/lib/Tools/PDLL/AST/CMakeLists.txt b/mlir/lib/Tools/PDLL/AST/CMakeLists.txt index 3eb9c62a37e1..5e67ee02b9c7 100644 --- a/mlir/lib/Tools/PDLL/AST/CMakeLists.txt +++ b/mlir/lib/Tools/PDLL/AST/CMakeLists.txt @@ -6,5 +6,6 @@ add_mlir_library(MLIRPDLLAST Types.cpp LINK_LIBS PUBLIC + MLIRPDLLODS MLIRSupport ) diff --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp index 09ae0e6ad6e0..6f2e4cd58082 100644 --- a/mlir/lib/Tools/PDLL/AST/Context.cpp +++ b/mlir/lib/Tools/PDLL/AST/Context.cpp @@ -12,7 +12,7 @@ using namespace mlir; using namespace mlir::pdll::ast; -Context::Context() { +Context::Context(ods::Context &odsContext) : odsContext(odsContext) { typeUniquer.registerSingletonStorageType(); typeUniquer.registerSingletonStorageType(); typeUniquer.registerSingletonStorageType(); diff --git a/mlir/lib/Tools/PDLL/CMakeLists.txt b/mlir/lib/Tools/PDLL/CMakeLists.txt index ac83f5e5fae7..522429b282e3 100644 --- a/mlir/lib/Tools/PDLL/CMakeLists.txt +++ b/mlir/lib/Tools/PDLL/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(AST) add_subdirectory(CodeGen) +add_subdirectory(ODS) add_subdirectory(Parser) diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp index 81b719c63365..1f8466f5c121 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -17,6 +17,8 @@ #include "mlir/Tools/PDLL/AST/Context.h" #include "mlir/Tools/PDLL/AST/Nodes.h" #include "mlir/Tools/PDLL/AST/Types.h" +#include "mlir/Tools/PDLL/ODS/Context.h" +#include "mlir/Tools/PDLL/ODS/Operation.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -33,7 +35,8 @@ class CodeGen { public: CodeGen(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr) - : builder(mlirContext), sourceMgr(sourceMgr) { + : builder(mlirContext), odsContext(context.getODSContext()), + sourceMgr(sourceMgr) { // Make sure that the PDL dialect is loaded. mlirContext->loadDialect(); } @@ -117,6 +120,9 @@ private: llvm::ScopedHashTable>; VariableMapTy variables; + /// A reference to the ODS context. + const ods::Context &odsContext; + /// The source manager of the PDLL ast. const llvm::SourceMgr &sourceMgr; }; @@ -435,7 +441,28 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { builder.getI32IntegerAttr(0)); return builder.create(loc, mlirType, parentExprs[0]); } - llvm_unreachable("unhandled operation member access expression"); + + assert(opType.getName() && "expected valid operation name"); + const ods::Operation *odsOp = odsContext.lookupOperation(*opType.getName()); + assert(odsOp && "expected valid ODS operation information"); + + // Find the result with the member name or by index. + ArrayRef results = odsOp->getResults(); + unsigned resultIndex = results.size(); + if (llvm::isDigit(name[0])) { + name.getAsInteger(/*Radix=*/10, resultIndex); + } else { + auto findFn = [&](const ods::OperandOrResult &result) { + return result.getName() == name; + }; + resultIndex = llvm::find_if(results, findFn) - results.begin(); + } + assert(resultIndex < results.size() && "invalid result index"); + + // Generate the result access. + IntegerAttr index = builder.getI32IntegerAttr(resultIndex); + return builder.create(loc, genType(expr->getType()), + parentExprs[0], index); } // Handle tuple based member access. diff --git a/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt b/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt new file mode 100644 index 000000000000..3abbaab33ab3 --- /dev/null +++ b/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_library(MLIRPDLLODS + Context.cpp + Dialect.cpp + Operation.cpp + + LINK_LIBS PUBLIC + MLIRSupport + ) diff --git a/mlir/lib/Tools/PDLL/ODS/Context.cpp b/mlir/lib/Tools/PDLL/ODS/Context.cpp new file mode 100644 index 000000000000..7684da5e05fa --- /dev/null +++ b/mlir/lib/Tools/PDLL/ODS/Context.cpp @@ -0,0 +1,174 @@ +//===- Context.cpp --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/PDLL/ODS/Context.h" +#include "mlir/Tools/PDLL/ODS/Constraint.h" +#include "mlir/Tools/PDLL/ODS/Dialect.h" +#include "mlir/Tools/PDLL/ODS/Operation.h" +#include "llvm/Support/ScopedPrinter.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::pdll::ods; + +//===----------------------------------------------------------------------===// +// Context +//===----------------------------------------------------------------------===// + +Context::Context() = default; +Context::~Context() = default; + +const AttributeConstraint & +Context::insertAttributeConstraint(StringRef name, StringRef summary, + StringRef cppClass) { + std::unique_ptr &constraint = attributeConstraints[name]; + if (!constraint) { + constraint.reset(new AttributeConstraint(name, summary, cppClass)); + } else { + assert(constraint->getCppClass() == cppClass && + constraint->getSummary() == summary && + "constraint with the same name was already registered with a " + "different class"); + } + return *constraint; +} + +const TypeConstraint &Context::insertTypeConstraint(StringRef name, + StringRef summary, + StringRef cppClass) { + std::unique_ptr &constraint = typeConstraints[name]; + if (!constraint) + constraint.reset(new TypeConstraint(name, summary, cppClass)); + return *constraint; +} + +Dialect &Context::insertDialect(StringRef name) { + std::unique_ptr &dialect = dialects[name]; + if (!dialect) + dialect.reset(new Dialect(name)); + return *dialect; +} + +const Dialect *Context::lookupDialect(StringRef name) const { + auto it = dialects.find(name); + return it == dialects.end() ? nullptr : &*it->second; +} + +std::pair Context::insertOperation(StringRef name, + StringRef summary, + StringRef desc, + SMLoc loc) { + std::pair dialectAndName = name.split('.'); + return insertDialect(dialectAndName.first) + .insertOperation(name, summary, desc, loc); +} + +const Operation *Context::lookupOperation(StringRef name) const { + std::pair dialectAndName = name.split('.'); + if (const Dialect *dialect = lookupDialect(dialectAndName.first)) + return dialect->lookupOperation(name); + return nullptr; +} + +template +SmallVector sortMapByName(const llvm::StringMap> &map) { + SmallVector storage; + for (auto &entry : map) + storage.push_back(entry.second.get()); + llvm::sort(storage, [](const auto &lhs, const auto &rhs) { + return lhs->getName() < rhs->getName(); + }); + return storage; +} + +void Context::print(raw_ostream &os) const { + auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) { + switch (kind) { + case VariableLengthKind::Optional: + os << "Optional<" << cst << ">"; + break; + case VariableLengthKind::Single: + os << cst; + break; + case VariableLengthKind::Variadic: + os << "Variadic<" << cst << ">"; + break; + } + }; + + llvm::ScopedPrinter printer(os); + llvm::DictScope odsScope(printer, "ODSContext"); + for (const Dialect *dialect : sortMapByName(dialects)) { + printer.startLine() << "Dialect `" << dialect->getName() << "` {\n"; + printer.indent(); + + for (const Operation *op : sortMapByName(dialect->getOperations())) { + printer.startLine() << "Operation `" << op->getName() << "` {\n"; + printer.indent(); + + // Attributes. + ArrayRef attributes = op->getAttributes(); + if (!attributes.empty()) { + printer.startLine() << "Attributes { "; + llvm::interleaveComma(attributes, os, [&](const Attribute &attr) { + os << attr.getName() << " : "; + + auto kind = attr.isOptional() ? VariableLengthKind::Optional + : VariableLengthKind::Single; + printVariableLengthCst(attr.getConstraint().getName(), kind); + }); + os << " }\n"; + } + + // Operands. + ArrayRef operands = op->getOperands(); + if (!operands.empty()) { + printer.startLine() << "Operands { "; + llvm::interleaveComma( + operands, os, [&](const OperandOrResult &operand) { + os << operand.getName() << " : "; + printVariableLengthCst(operand.getConstraint().getName(), + operand.getVariableLengthKind()); + }); + os << " }\n"; + } + + // Results. + ArrayRef results = op->getResults(); + if (!results.empty()) { + printer.startLine() << "Results { "; + llvm::interleaveComma(results, os, [&](const OperandOrResult &result) { + os << result.getName() << " : "; + printVariableLengthCst(result.getConstraint().getName(), + result.getVariableLengthKind()); + }); + os << " }\n"; + } + + printer.objectEnd(); + } + printer.objectEnd(); + } + for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) { + printer.startLine() << "AttributeConstraint `" << cst->getName() << "` {\n"; + printer.indent(); + + printer.startLine() << "Summary: " << cst->getSummary() << "\n"; + printer.startLine() << "CppClass: " << cst->getCppClass() << "\n"; + printer.objectEnd(); + } + for (const TypeConstraint *cst : sortMapByName(typeConstraints)) { + printer.startLine() << "TypeConstraint `" << cst->getName() << "` {\n"; + printer.indent(); + + printer.startLine() << "Summary: " << cst->getSummary() << "\n"; + printer.startLine() << "CppClass: " << cst->getCppClass() << "\n"; + printer.objectEnd(); + } + printer.objectEnd(); +} diff --git a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp new file mode 100644 index 000000000000..ce9c23421c0e --- /dev/null +++ b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp @@ -0,0 +1,39 @@ +//===- Dialect.cpp --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/PDLL/ODS/Dialect.h" +#include "mlir/Tools/PDLL/ODS/Constraint.h" +#include "mlir/Tools/PDLL/ODS/Operation.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::pdll::ods; + +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + +Dialect::Dialect(StringRef name) : name(name.str()) {} +Dialect::~Dialect() = default; + +std::pair Dialect::insertOperation(StringRef name, + StringRef summary, + StringRef desc, + llvm::SMLoc loc) { + std::unique_ptr &operation = operations[name]; + if (operation) + return std::make_pair(&*operation, /*wasInserted*/ false); + + operation.reset(new Operation(name, summary, desc, loc)); + return std::make_pair(&*operation, /*wasInserted*/ true); +} + +Operation *Dialect::lookupOperation(StringRef name) const { + auto it = operations.find(name); + return it != operations.end() ? it->second.get() : nullptr; +} diff --git a/mlir/lib/Tools/PDLL/ODS/Operation.cpp b/mlir/lib/Tools/PDLL/ODS/Operation.cpp new file mode 100644 index 000000000000..121c6c8c4c88 --- /dev/null +++ b/mlir/lib/Tools/PDLL/ODS/Operation.cpp @@ -0,0 +1,26 @@ +//===- Operation.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/PDLL/ODS/Operation.h" +#include "mlir/Support/IndentedOstream.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::pdll::ods; + +//===----------------------------------------------------------------------===// +// Operation +//===----------------------------------------------------------------------===// + +Operation::Operation(StringRef name, StringRef summary, StringRef desc, + llvm::SMLoc loc) + : name(name.str()), summary(summary.str()), + location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) { + llvm::raw_string_ostream descOS(description); + raw_indented_ostream(descOS).printReindented(desc.rtrim(" \t")); +} diff --git a/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt b/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt index fb933cdfc9f7..5d466cf15885 100644 --- a/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt +++ b/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt @@ -1,3 +1,8 @@ +set(LLVM_LINK_COMPONENTS + Support + TableGen +) + add_mlir_library(MLIRPDLLParser Lexer.cpp Parser.cpp @@ -5,4 +10,5 @@ add_mlir_library(MLIRPDLLParser LINK_LIBS PUBLIC MLIRPDLLAST MLIRSupport + MLIRTableGen ) diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index da7283799d95..3da77839c9cc 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -9,15 +9,26 @@ #include "mlir/Tools/PDLL/Parser/Parser.h" #include "Lexer.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/TableGen/Argument.h" +#include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Constraint.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/Operator.h" #include "mlir/Tools/PDLL/AST/Context.h" #include "mlir/Tools/PDLL/AST/Diagnostic.h" #include "mlir/Tools/PDLL/AST/Nodes.h" #include "mlir/Tools/PDLL/AST/Types.h" +#include "mlir/Tools/PDLL/ODS/Constraint.h" +#include "mlir/Tools/PDLL/ODS/Context.h" +#include "mlir/Tools/PDLL/ODS/Operation.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Parser.h" #include using namespace mlir; @@ -36,7 +47,8 @@ public: valueTy(ast::ValueType::get(ctx)), valueRangeTy(ast::ValueRangeType::get(ctx)), typeTy(ast::TypeType::get(ctx)), - typeRangeTy(ast::TypeRangeType::get(ctx)) {} + typeRangeTy(ast::TypeRangeType::get(ctx)), + attrTy(ast::AttributeType::get(ctx)) {} /// Try to parse a new module. Returns nullptr in the case of failure. FailureOr parseModule(); @@ -78,7 +90,7 @@ private: void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } /// Parse the body of an AST module. - LogicalResult parseModuleBody(SmallVector &decls); + LogicalResult parseModuleBody(SmallVectorImpl &decls); /// Try to convert the given expression to `type`. Returns failure and emits /// an error if a conversion is not viable. On failure, `noteAttachFn` is @@ -92,11 +104,34 @@ private: /// typed expression. ast::Expr *convertOpToValue(const ast::Expr *opExpr); + /// Lookup ODS information for the given operation, returns nullptr if no + /// information is found. + const ods::Operation *lookupODSOperation(Optional opName) { + return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; + } + //===--------------------------------------------------------------------===// // Directives - LogicalResult parseDirective(SmallVector &decls); - LogicalResult parseInclude(SmallVector &decls); + LogicalResult parseDirective(SmallVectorImpl &decls); + LogicalResult parseInclude(SmallVectorImpl &decls); + LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, + SmallVectorImpl &decls); + + /// Process the records of a parsed tablegen include file. + void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, + SmallVectorImpl &decls); + + /// Create a user defined native constraint for a constraint imported from + /// ODS. + template + ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, + StringRef codeBlock, SMRange loc, + ast::Type type); + template + ast::Decl * + createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, + SMRange loc, ast::Type type); //===--------------------------------------------------------------------===// // Decls @@ -340,13 +375,16 @@ private: MutableArrayRef results); LogicalResult validateOperationOperands(SMRange loc, Optional name, + const ods::Operation *odsOp, MutableArrayRef operands); LogicalResult validateOperationResults(SMRange loc, Optional name, + const ods::Operation *odsOp, MutableArrayRef results); - LogicalResult - validateOperationOperandsOrResults(SMRange loc, Optional name, - MutableArrayRef values, - ast::Type singleTy, ast::Type rangeTy); + LogicalResult validateOperationOperandsOrResults( + StringRef groupName, SMRange loc, Optional odsOpLoc, + Optional name, MutableArrayRef values, + ArrayRef odsValues, ast::Type singleTy, + ast::Type rangeTy); FailureOr createTupleExpr(SMRange loc, ArrayRef elements, ArrayRef elementNames); @@ -440,6 +478,7 @@ private: /// Cached types to simplify verification and expression creation. ast::Type valueTy, valueRangeTy; ast::Type typeTy, typeRangeTy; + ast::Type attrTy; /// A counter used when naming anonymous constraints and rewrites. unsigned anonymousDeclNameCounter = 0; @@ -459,7 +498,7 @@ FailureOr Parser::parseModule() { return ast::Module::create(ctx, moduleLoc, decls); } -LogicalResult Parser::parseModuleBody(SmallVector &decls) { +LogicalResult Parser::parseModuleBody(SmallVectorImpl &decls) { while (curToken.isNot(Token::eof)) { if (curToken.is(Token::directive)) { if (failed(parseDirective(decls))) @@ -516,6 +555,32 @@ LogicalResult Parser::convertExpressionTo( // Allow conversion to a single value by constraining the result range. if (type == valueTy) { + // If the operation is registered, we can verify if it can ever have a + // single result. + Optional opName = exprOpType.getName(); + if (const ods::Operation *odsOp = lookupODSOperation(opName)) { + if (odsOp->getResults().empty()) { + return emitConvertError()->attachNote( + llvm::formatv("see the definition of `{0}`, which was defined " + "with zero results", + odsOp->getName()), + odsOp->getLoc()); + } + + unsigned numSingleResults = llvm::count_if( + odsOp->getResults(), [](const ods::OperandOrResult &result) { + return result.getVariableLengthKind() == + ods::VariableLengthKind::Single; + }); + if (numSingleResults > 1) { + return emitConvertError()->attachNote( + llvm::formatv("see the definition of `{0}`, which was defined " + "with at least {1} results", + odsOp->getName(), numSingleResults), + odsOp->getLoc()); + } + } + expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, valueTy); return success(); @@ -569,7 +634,7 @@ LogicalResult Parser::convertExpressionTo( //===----------------------------------------------------------------------===// // Directives -LogicalResult Parser::parseDirective(SmallVector &decls) { +LogicalResult Parser::parseDirective(SmallVectorImpl &decls) { StringRef directive = curToken.getSpelling(); if (directive == "#include") return parseInclude(decls); @@ -577,7 +642,7 @@ LogicalResult Parser::parseDirective(SmallVector &decls) { return emitError("unknown directive `" + directive + "`"); } -LogicalResult Parser::parseInclude(SmallVector &decls) { +LogicalResult Parser::parseInclude(SmallVectorImpl &decls) { SMRange loc = curToken.getLoc(); consumeToken(Token::directive); @@ -607,7 +672,193 @@ LogicalResult Parser::parseInclude(SmallVector &decls) { return result; } - return emitError(fileLoc, "expected include filename to end with `.pdll`"); + // Otherwise, this must be a `.td` include. + if (filename.endswith(".td")) + return parseTdInclude(filename, fileLoc, decls); + + return emitError(fileLoc, + "expected include filename to end with `.pdll` or `.td`"); +} + +LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, + SmallVectorImpl &decls) { + llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr(); + + // This class provides a context argument for the llvm::SourceMgr diagnostic + // handler. + struct DiagHandlerContext { + Parser &parser; + StringRef filename; + llvm::SMRange loc; + } handlerContext{*this, filename, fileLoc}; + + // Set the diagnostic handler for the tablegen source manager. + llvm::SrcMgr.setDiagHandler( + [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) { + auto *ctx = reinterpret_cast(rawHandlerContext); + (void)ctx->parser.emitError( + ctx->loc, + llvm::formatv("error while processing include file `{0}`: {1}", + ctx->filename, diag.getMessage())); + }, + &handlerContext); + + // Use the source manager to open the file, but don't yet add it. + std::string includedFile; + llvm::ErrorOr> includeBuffer = + parserSrcMgr.OpenIncludeFile(filename.str(), includedFile); + if (!includeBuffer) + return emitError(fileLoc, "unable to open include file `" + filename + "`"); + + auto processFn = [&](llvm::RecordKeeper &records) { + processTdIncludeRecords(records, decls); + + // After we are done processing, move all of the tablegen source buffers to + // the main parser source mgr. This allows for directly using source + // locations from the .td files without needing to remap them. + parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr); + return false; + }; + if (llvm::TableGenParseFile(std::move(*includeBuffer), + parserSrcMgr.getIncludeDirs(), processFn)) + return failure(); + + return success(); +} + +void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords, + SmallVectorImpl &decls) { + // Return the length kind of the given value. + auto getLengthKind = [](const auto &value) { + if (value.isOptional()) + return ods::VariableLengthKind::Optional; + return value.isVariadic() ? ods::VariableLengthKind::Variadic + : ods::VariableLengthKind::Single; + }; + + // Insert a type constraint into the ODS context. + ods::Context &odsContext = ctx.getODSContext(); + auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) + -> const ods::TypeConstraint & { + return odsContext.insertTypeConstraint(cst.constraint.getDefName(), + cst.constraint.getSummary(), + cst.constraint.getCPPClassName()); + }; + auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { + return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; + }; + + // Process the parsed tablegen records to build ODS information. + /// Operations. + for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { + tblgen::Operator op(def); + + bool inserted = false; + ods::Operation *odsOp = nullptr; + std::tie(odsOp, inserted) = + odsContext.insertOperation(op.getOperationName(), op.getSummary(), + op.getDescription(), op.getLoc().front()); + + // Ignore operations that have already been added. + if (!inserted) + continue; + + for (const tblgen::NamedAttribute &attr : op.getAttributes()) { + odsOp->appendAttribute( + attr.name, attr.attr.isOptional(), + odsContext.insertAttributeConstraint(attr.attr.getAttrDefName(), + attr.attr.getSummary(), + attr.attr.getStorageType())); + } + for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { + odsOp->appendOperand(operand.name, getLengthKind(operand), + addTypeConstraint(operand)); + } + for (const tblgen::NamedTypeConstraint &result : op.getResults()) { + odsOp->appendResult(result.name, getLengthKind(result), + addTypeConstraint(result)); + } + } + /// Attr constraints. + for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { + if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { + decls.push_back( + createODSNativePDLLConstraintDecl( + tblgen::AttrConstraint(def), + convertLocToRange(def->getLoc().front()), attrTy)); + } + } + /// Type constraints. + for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { + if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { + decls.push_back( + createODSNativePDLLConstraintDecl( + tblgen::TypeConstraint(def), + convertLocToRange(def->getLoc().front()), typeTy)); + } + } + /// Interfaces. + ast::Type opTy = ast::OperationType::get(ctx); + for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { + StringRef name = def->getName(); + if (def->isAnonymous() || curDeclScope->lookup(name) || + def->isSubClassOf("DeclareInterfaceMethods")) + continue; + SMRange loc = convertLocToRange(def->getLoc().front()); + + StringRef className = def->getValueAsString("cppClassName"); + StringRef cppNamespace = def->getValueAsString("cppNamespace"); + std::string codeBlock = + llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className) + .str(); + + if (def->isSubClassOf("OpInterface")) { + decls.push_back(createODSNativePDLLConstraintDecl( + name, codeBlock, loc, opTy)); + } else if (def->isSubClassOf("AttrInterface")) { + decls.push_back( + createODSNativePDLLConstraintDecl( + name, codeBlock, loc, attrTy)); + } else if (def->isSubClassOf("TypeInterface")) { + decls.push_back( + createODSNativePDLLConstraintDecl( + name, codeBlock, loc, typeTy)); + } + } +} + +template +ast::Decl * +Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, + SMRange loc, ast::Type type) { + // Build the single input parameter. + ast::DeclScope *argScope = pushDeclScope(); + auto *paramVar = ast::VariableDecl::create( + ctx, ast::Name::create(ctx, "self", loc), type, + /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); + argScope->add(paramVar); + popDeclScope(); + + // Build the native constraint. + auto *constraintDecl = ast::UserConstraintDecl::createNative( + ctx, ast::Name::create(ctx, name, loc), paramVar, + /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx)); + curDeclScope->add(constraintDecl); + return constraintDecl; +} + +template +ast::Decl * +Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, + SMRange loc, ast::Type type) { + // Format the condition template. + tblgen::FmtContext fmtContext; + fmtContext.withSelf("self"); + std::string codeBlock = + tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext); + + return createODSNativePDLLConstraintDecl(constraint.getDefName(), + codeBlock, loc, type); } //===----------------------------------------------------------------------===// @@ -2302,9 +2553,29 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, FailureOr Parser::validateMemberAccess(ast::Expr *parentExpr, StringRef name, SMRange loc) { ast::Type parentType = parentExpr->getType(); - if (parentType.isa()) { + if (ast::OperationType opType = parentType.dyn_cast()) { if (name == ast::AllResultsMemberAccessExpr::getMemberName()) return valueRangeTy; + + // Verify member access based on the operation type. + if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { + auto results = odsOp->getResults(); + + // Handle indexed results. + unsigned index = 0; + if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && + index < results.size()) { + return results[index].isVariadic() ? valueRangeTy : valueTy; + } + + // Handle named results. + const auto *it = llvm::find_if(results, [&](const auto &result) { + return result.getName() == name; + }); + if (it != results.end()) + return it->isVariadic() ? valueRangeTy : valueTy; + } + } else if (auto tupleType = parentType.dyn_cast()) { // Handle indexed results. unsigned index = 0; @@ -2331,9 +2602,10 @@ FailureOr Parser::createOperationExpr( MutableArrayRef attributes, MutableArrayRef results) { Optional opNameRef = name->getName(); + const ods::Operation *odsOp = lookupODSOperation(opNameRef); // Verify the inputs operands. - if (failed(validateOperationOperands(loc, opNameRef, operands))) + if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) return failure(); // Verify the attribute list. @@ -2348,7 +2620,7 @@ FailureOr Parser::createOperationExpr( } // Verify the result types. - if (failed(validateOperationResults(loc, opNameRef, results))) + if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) return failure(); return ast::OperationExpr::create(ctx, loc, name, operands, results, @@ -2357,21 +2629,28 @@ FailureOr Parser::createOperationExpr( LogicalResult Parser::validateOperationOperands(SMRange loc, Optional name, + const ods::Operation *odsOp, MutableArrayRef operands) { - return validateOperationOperandsOrResults(loc, name, operands, valueTy, - valueRangeTy); + return validateOperationOperandsOrResults( + "operand", loc, odsOp ? odsOp->getLoc() : Optional(), name, + operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, + valueRangeTy); } LogicalResult Parser::validateOperationResults(SMRange loc, Optional name, + const ods::Operation *odsOp, MutableArrayRef results) { - return validateOperationOperandsOrResults(loc, name, results, typeTy, - typeRangeTy); + return validateOperationOperandsOrResults( + "result", loc, odsOp ? odsOp->getLoc() : Optional(), name, + results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); } LogicalResult Parser::validateOperationOperandsOrResults( - SMRange loc, Optional name, MutableArrayRef values, - ast::Type singleTy, ast::Type rangeTy) { + StringRef groupName, SMRange loc, Optional odsOpLoc, + Optional name, MutableArrayRef values, + ArrayRef odsValues, ast::Type singleTy, + ast::Type rangeTy) { // All operation types accept a single range parameter. if (values.size() == 1) { if (failed(convertExpressionTo(values[0], rangeTy))) @@ -2379,6 +2658,29 @@ LogicalResult Parser::validateOperationOperandsOrResults( return success(); } + /// If the operation has ODS information, we can more accurately verify the + /// values. + if (odsOpLoc) { + if (odsValues.size() != values.size()) { + return emitErrorAndNote( + loc, + llvm::formatv("invalid number of {0} groups for `{1}`; expected " + "{2}, but got {3}", + groupName, *name, odsValues.size(), values.size()), + *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); + } + auto diagFn = [&](ast::Diagnostic &diag) { + diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), + *odsOpLoc); + }; + for (unsigned i = 0, e = values.size(); i < e; ++i) { + ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; + if (failed(convertExpressionTo(values[i], expectedType, diagFn))) + return failure(); + } + return success(); + } + // Otherwise, accept the value groups as they have been defined and just // ensure they are one of the expected types. for (ast::Expr *&valueExpr : values) { diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index 3e652ad8b49e..e8db46c1dd3d 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -1,4 +1,4 @@ -// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s +// RUN: mlir-pdll %s -I %S -I %S/../../../../include -split-input-file -x mlir | FileCheck %s //===----------------------------------------------------------------------===// // AttributeExpr @@ -55,6 +55,24 @@ Pattern OpAllResultMemberAccess { // ----- +// Handle implicit "named" operation results access. + +#include "include/ops.td" + +// CHECK: pdl.pattern @OpResultMemberAccess +// CHECK: %[[OP0:.*]] = operation +// CHECK: %[[RES:.*]] = results 0 of %[[OP0]] -> !pdl.value +// CHECK: %[[RES1:.*]] = results 0 of %[[OP0]] -> !pdl.value +// CHECK: %[[RES2:.*]] = results 1 of %[[OP0]] -> !pdl.range +// CHECK: %[[RES3:.*]] = results 1 of %[[OP0]] -> !pdl.range +// CHECK: operation(%[[RES]], %[[RES1]], %[[RES2]], %[[RES3]] : !pdl.value, !pdl.value, !pdl.range, !pdl.range) +Pattern OpResultMemberAccess { + let op: Op; + erase op<>(op.0, op.result, op.1, op.var_result); +} + +// ----- + // CHECK: pdl.pattern @TupleMemberAccessNumber // CHECK: %[[FIRST:.*]] = operation "test.first" // CHECK: %[[SECOND:.*]] = operation "test.second" diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td b/mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td new file mode 100644 index 000000000000..588b290c4578 --- /dev/null +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td @@ -0,0 +1,9 @@ +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; +} + +def OpWithResults : Op { + let results = (outs I64:$result, Variadic:$var_result); +} diff --git a/mlir/test/mlir-pdll/Parser/directive-failure.pdll b/mlir/test/mlir-pdll/Parser/directive-failure.pdll index 14924fe61104..14f8db8aa5e2 100644 --- a/mlir/test/mlir-pdll/Parser/directive-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/directive-failure.pdll @@ -19,5 +19,5 @@ // ----- -// CHECK: expected include filename to end with `.pdll` +// CHECK: expected include filename to end with `.pdll` or `.td` #include "unknown_file.foo" diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index 7ed3ba8057bd..08174de7cf16 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -1,4 +1,4 @@ -// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s +// RUN: not mlir-pdll %s -I %S -I %S/../../../include -split-input-file 2>&1 | FileCheck %s //===----------------------------------------------------------------------===// // Reference Expr @@ -276,6 +276,26 @@ Pattern { // ----- +#include "include/ops.td" + +Pattern { + // CHECK: invalid number of operand groups for `test.all_empty`; expected 0, but got 2 + // CHECK: see the definition of `test.all_empty` here + let foo = op(operand1: Value, operand2: Value); +} + +// ----- + +#include "include/ops.td" + +Pattern { + // CHECK: invalid number of result groups for `test.all_empty`; expected 0, but got 2 + // CHECK: see the definition of `test.all_empty` here + let foo = op -> (result1: Type, result2: Type); +} + +// ----- + //===----------------------------------------------------------------------===// // `type` Expr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 9919fe5c0d07..c7d96035e2f7 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -1,4 +1,4 @@ -// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s +// RUN: mlir-pdll %s -I %S -I %S/../../../include -split-input-file | FileCheck %s //===----------------------------------------------------------------------===// // AttrExpr @@ -71,6 +71,25 @@ Pattern { // ----- +#include "include/ops.td" + +// CHECK: Module +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `-MemberAccessExpr {{.*}} Member<0> Type +// CHECK: `-DeclRefExpr {{.*}} Type> +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `-MemberAccessExpr {{.*}} Member Type +// CHECK: `-DeclRefExpr {{.*}} Type> +Pattern { + let op: Op; + let firstEltIndex = op.0; + let firstEltName = op.result; + + erase op; +} + +// ----- + //===----------------------------------------------------------------------===// // OperationExpr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/include/interfaces.td b/mlir/test/mlir-pdll/Parser/include/interfaces.td new file mode 100644 index 000000000000..eea8783545f0 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/include/interfaces.td @@ -0,0 +1,5 @@ +include "mlir/IR/OpBase.td" + +def TestAttrInterface : AttrInterface<"TestAttrInterface">; +def TestOpInterface : OpInterface<"TestOpInterface">; +def TestTypeInterface : TypeInterface<"TestTypeInterface">; diff --git a/mlir/test/mlir-pdll/Parser/include/ops.td b/mlir/test/mlir-pdll/Parser/include/ops.td new file mode 100644 index 000000000000..1727d1a5444b --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/include/ops.td @@ -0,0 +1,26 @@ +include "include/interfaces.td" + +def Test_Dialect : Dialect { + let name = "test"; +} + +def OpAllEmpty : Op; + +def OpAllSingle : Op { + let arguments = (ins I64:$operand, I64Attr:$attr); + let results = (outs I64:$result); +} + +def OpAllOptional : Op { + let arguments = (ins Optional:$operand, OptionalAttr:$attr); + let results = (outs Optional:$result); +} + +def OpAllVariadic : Op { + let arguments = (ins Variadic:$operands); + let results = (outs Variadic:$results); +} + +def OpMultipleSingleResult : Op { + let results = (outs I64:$result, I64:$result2); +} diff --git a/mlir/test/mlir-pdll/Parser/include_td.pdll b/mlir/test/mlir-pdll/Parser/include_td.pdll new file mode 100644 index 000000000000..c55ed1d0f154 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/include_td.pdll @@ -0,0 +1,52 @@ +// RUN: mlir-pdll %s -I %S -I %S/../../../include -dump-ods 2>&1 | FileCheck %s + +#include "include/ops.td" + +// CHECK: Operation `test.all_empty` { +// CHECK-NEXT: } + +// CHECK: Operation `test.all_optional` { +// CHECK-NEXT: Attributes { attr : Optional } +// CHECK-NEXT: Operands { operand : Optional } +// CHECK-NEXT: Results { result : Optional } +// CHECK-NEXT: } + +// CHECK: Operation `test.all_single` { +// CHECK-NEXT: Attributes { attr : I64Attr } +// CHECK-NEXT: Operands { operand : I64 } +// CHECK-NEXT: Results { result : I64 } +// CHECK-NEXT: } + +// CHECK: Operation `test.all_variadic` { +// CHECK-NEXT: Operands { operands : Variadic } +// CHECK-NEXT: Results { results : Variadic } +// CHECK-NEXT: } + +// CHECK: AttributeConstraint `I64Attr` { +// CHECK-NEXT: Summary: 64-bit signless integer attribute +// CHECK-NEXT: CppClass: ::mlir::IntegerAttr +// CHECK-NEXT: } + +// CHECK: TypeConstraint `I64` { +// CHECK-NEXT: Summary: 64-bit signless integer +// CHECK-NEXT: CppClass: ::mlir::IntegerType +// CHECK-NEXT: } + +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self)> +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-AttrConstraintDecl + +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self)> +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-OpConstraintDecl +// CHECK: `-OpNameDecl + +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self)> +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-TypeConstraintDecl {{.*}} diff --git a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll index d52a2d28b6d0..4220259e60ee 100644 --- a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll @@ -1,4 +1,4 @@ -// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s +// RUN: not mlir-pdll %s -I %S -I %S/../../../include -split-input-file 2>&1 | FileCheck %s // CHECK: expected top-level declaration, such as a `Pattern` 10 @@ -250,6 +250,28 @@ Pattern { // ----- +#include "include/ops.td" + +Pattern { + // CHECK: unable to convert expression of type `Op` to the expected type of `Value` + // CHECK: see the definition of `test.all_empty`, which was defined with zero results + let value: Value = op; + erase _: Op; +} + +// ----- + +#include "include/ops.td" + +Pattern { + // CHECK: unable to convert expression of type `Op` to the expected type of `Value` + // CHECK: see the definition of `test.multiple_single_result`, which was defined with at least 2 results + let value: Value = op; + erase _: Op; +} + +// ----- + //===----------------------------------------------------------------------===// // `replace` //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp index e133d9e45c54..904fb77a51b1 100644 --- a/mlir/tools/mlir-pdll/mlir-pdll.cpp +++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp @@ -13,6 +13,7 @@ #include "mlir/Tools/PDLL/AST/Nodes.h" #include "mlir/Tools/PDLL/CodeGen/CPPGen.h" #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" +#include "mlir/Tools/PDLL/ODS/Context.h" #include "mlir/Tools/PDLL/Parser/Parser.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" @@ -35,16 +36,23 @@ enum class OutputType { static LogicalResult processBuffer(raw_ostream &os, std::unique_ptr chunkBuffer, - OutputType outputType, std::vector &includeDirs) { + OutputType outputType, std::vector &includeDirs, + bool dumpODS) { llvm::SourceMgr sourceMgr; sourceMgr.setIncludeDirs(includeDirs); sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), SMLoc()); - ast::Context astContext; + ods::Context odsContext; + ast::Context astContext(odsContext); FailureOr module = parsePDLAST(astContext, sourceMgr); if (failed(module)) return failure(); + // Print out the ODS information if requested. + if (dumpODS) + odsContext.print(llvm::errs()); + + // Generate the output. if (outputType == OutputType::AST) { (*module)->print(os); return success(); @@ -66,6 +74,10 @@ processBuffer(raw_ostream &os, std::unique_ptr chunkBuffer, } int main(int argc, char **argv) { + // FIXME: This is necessary because we link in TableGen, which defines its + // options as static variables.. some of which overlap with our options. + llvm::cl::ResetCommandLineParser(); + llvm::cl::opt inputFilename( llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::value_desc("filename")); @@ -78,6 +90,11 @@ int main(int argc, char **argv) { "I", llvm::cl::desc("Directory of include files"), llvm::cl::value_desc("directory"), llvm::cl::Prefix); + llvm::cl::opt dumpODS( + "dump-ods", + llvm::cl::desc( + "Print out the parsed ODS information from the input file"), + llvm::cl::init(false)); llvm::cl::opt splitInputFile( "split-input-file", llvm::cl::desc("Split the input file into pieces and process each " @@ -118,7 +135,8 @@ int main(int argc, char **argv) { // up into small pieces and checks each independently. auto processFn = [&](std::unique_ptr chunkBuffer, raw_ostream &os) { - return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs); + return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs, + dumpODS); }; if (splitInputFile) { if (failed(splitAndProcessBuffer(std::move(inputFile), processFn,