[PDLL] Add support for tablegen includes and importing ODS information

This commit adds support for processing tablegen include files, and importing
various information from ODS. This includes operations, attribute+type constraints,
attribute/operation/type interfaces, etc. This will allow for much more robust tooling,
and also allows for referencing ODS constructs directly within PDLL (imported interfaces
can be used as constraints, operation result names can be used for member access, etc).

Differential Revision: https://reviews.llvm.org/D119900
This commit is contained in:
River Riddle 2022-02-15 14:32:37 -08:00
parent e865fa7530
commit 81f2f4dfb2
30 changed files with 1312 additions and 39 deletions

View File

@ -100,6 +100,9 @@ public:
SourceMgr &operator=(SourceMgr &&) = default; SourceMgr &operator=(SourceMgr &&) = default;
~SourceMgr() = default; ~SourceMgr() = default;
/// Return the include directories of this source manager.
ArrayRef<std::string> getIncludeDirs() const { return IncludeDirectories; }
void setIncludeDirs(const std::vector<std::string> &Dirs) { void setIncludeDirs(const std::vector<std::string> &Dirs) {
IncludeDirectories = Dirs; IncludeDirectories = Dirs;
} }
@ -147,6 +150,14 @@ public:
return Buffers.size(); return Buffers.size();
} }
/// Takes the source buffers from the given source manager and append them to
/// the current manager.
void takeSourceBuffersFrom(SourceMgr &SrcMgr) {
std::move(SrcMgr.Buffers.begin(), SrcMgr.Buffers.end(),
std::back_inserter(Buffers));
SrcMgr.Buffers.clear();
}
/// Search for a file with the specified name in the current directory or in /// Search for a file with the specified name in the current directory or in
/// one of the IncludeDirs. /// one of the IncludeDirs.
/// ///
@ -156,6 +167,17 @@ public:
unsigned AddIncludeFile(const std::string &Filename, SMLoc IncludeLoc, unsigned AddIncludeFile(const std::string &Filename, SMLoc IncludeLoc,
std::string &IncludedFile); std::string &IncludedFile);
/// Search for a file with the specified name in the current directory or in
/// one of the IncludeDirs, and try to open it **without** adding to the
/// SourceMgr. If the opened file is intended to be added to the source
/// manager, prefer `AddIncludeFile` instead.
///
/// If no file is found, this returns an Error, otherwise it returns the
/// buffer of the stacked file. The full path to the included file can be
/// found in \p IncludedFile.
ErrorOr<std::unique_ptr<MemoryBuffer>>
OpenIncludeFile(const std::string &Filename, std::string &IncludedFile);
/// Return the ID of the buffer containing the specified location. /// Return the ID of the buffer containing the specified location.
/// ///
/// 0 is returned if the buffer is not found. /// 0 is returned if the buffer is not found.

View File

@ -40,6 +40,17 @@ static const size_t TabStop = 8;
unsigned SourceMgr::AddIncludeFile(const std::string &Filename, unsigned SourceMgr::AddIncludeFile(const std::string &Filename,
SMLoc IncludeLoc, SMLoc IncludeLoc,
std::string &IncludedFile) { std::string &IncludedFile) {
ErrorOr<std::unique_ptr<MemoryBuffer>> NewBufOrErr =
OpenIncludeFile(Filename, IncludedFile);
if (!NewBufOrErr)
return 0;
return AddNewSourceBuffer(std::move(*NewBufOrErr), IncludeLoc);
}
ErrorOr<std::unique_ptr<MemoryBuffer>>
SourceMgr::OpenIncludeFile(const std::string &Filename,
std::string &IncludedFile) {
IncludedFile = Filename; IncludedFile = Filename;
ErrorOr<std::unique_ptr<MemoryBuffer>> NewBufOrErr = ErrorOr<std::unique_ptr<MemoryBuffer>> NewBufOrErr =
MemoryBuffer::getFile(IncludedFile); MemoryBuffer::getFile(IncludedFile);
@ -52,10 +63,7 @@ unsigned SourceMgr::AddIncludeFile(const std::string &Filename,
NewBufOrErr = MemoryBuffer::getFile(IncludedFile); NewBufOrErr = MemoryBuffer::getFile(IncludedFile);
} }
if (!NewBufOrErr) return NewBufOrErr;
return 0;
return AddNewSourceBuffer(std::move(*NewBufOrErr), IncludeLoc);
} }
unsigned SourceMgr::FindBufferContainingLoc(SMLoc Loc) const { unsigned SourceMgr::FindBufferContainingLoc(SMLoc Loc) const {

View File

@ -363,7 +363,8 @@ class DialectType<Dialect d, Pred condition, string descr = "",
// A variadic type constraint. It expands to zero or more of the base type. This // A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results. // class is used for supporting variadic operands/results.
class Variadic<Type type> : TypeConstraint<type.predicate, type.summary> { class Variadic<Type type> : TypeConstraint<type.predicate, type.summary,
type.cppClassName> {
Type baseType = type; Type baseType = type;
} }
@ -379,7 +380,8 @@ class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
// An optional type constraint. It expands to either zero or one of the base // An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results. // type. This class is used for supporting optional operands/results.
class Optional<Type type> : TypeConstraint<type.predicate, type.summary> { class Optional<Type type> : TypeConstraint<type.predicate, type.summary,
type.cppClassName> {
Type baseType = type; Type baseType = type;
} }

View File

@ -54,6 +54,11 @@ public:
// description is not provided, returns the TableGen def name. // description is not provided, returns the TableGen def name.
StringRef getSummary() const; StringRef getSummary() const;
/// Returns the name of the TablGen def of this constraint. In some cases
/// where the current def is anonymous, the name of the base def is used (e.g.
/// `Optional<>`/`Variadic<>` type constraints).
StringRef getDefName() const;
Kind getKind() const { return kind; } Kind getKind() const { return kind; }
protected: protected:

View File

@ -14,13 +14,17 @@
namespace mlir { namespace mlir {
namespace pdll { namespace pdll {
namespace ods {
class Context;
} // namespace ods
namespace ast { namespace ast {
/// This class represents the main context of the PDLL AST. It handles /// This class represents the main context of the PDLL AST. It handles
/// allocating all of the AST constructs, and manages all state necessary for /// allocating all of the AST constructs, and manages all state necessary for
/// the AST. /// the AST.
class Context { class Context {
public: public:
Context(); explicit Context(ods::Context &odsContext);
Context(const Context &) = delete; Context(const Context &) = delete;
Context &operator=(const Context &) = delete; Context &operator=(const Context &) = delete;
@ -30,6 +34,10 @@ public:
/// Return the storage uniquer used for AST types. /// Return the storage uniquer used for AST types.
StorageUniquer &getTypeUniquer() { return typeUniquer; } StorageUniquer &getTypeUniquer() { return typeUniquer; }
/// Return the ODS context used by the AST.
ods::Context &getODSContext() { return odsContext; }
const ods::Context &getODSContext() const { return odsContext; }
/// Return the diagnostic engine of this context. /// Return the diagnostic engine of this context.
DiagnosticEngine &getDiagEngine() { return diagEngine; } DiagnosticEngine &getDiagEngine() { return diagEngine; }
@ -37,6 +45,9 @@ private:
/// The diagnostic engine of this AST context. /// The diagnostic engine of this AST context.
DiagnosticEngine diagEngine; DiagnosticEngine diagEngine;
/// The ODS context used by the AST.
ods::Context &odsContext;
/// The allocator used for AST nodes, and other entities allocated within the /// The allocator used for AST nodes, and other entities allocated within the
/// context. /// context.
llvm::BumpPtrAllocator allocator; llvm::BumpPtrAllocator allocator;

View File

@ -0,0 +1,98 @@
//===- Constraint.h - MLIR PDLL ODS Constraints -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a PDLL description of ODS constraints. These are used to
// support the import of constraints defined outside of PDLL.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_
#define MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_
#include <string>
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
namespace mlir {
namespace pdll {
namespace ods {
//===----------------------------------------------------------------------===//
// Constraint
//===----------------------------------------------------------------------===//
/// This class represents a generic ODS constraint.
class Constraint {
public:
/// Return the name of this constraint.
StringRef getName() const { return name; }
/// Return the summary of this constraint.
StringRef getSummary() const { return summary; }
protected:
Constraint(StringRef name, StringRef summary)
: name(name.str()), summary(summary.str()) {}
Constraint(const Constraint &) = delete;
private:
/// The name of the constraint.
std::string name;
/// A summary of the constraint.
std::string summary;
};
//===----------------------------------------------------------------------===//
// AttributeConstraint
//===----------------------------------------------------------------------===//
/// This class represents a generic ODS Attribute constraint.
class AttributeConstraint : public Constraint {
public:
/// Return the name of the underlying c++ class of this constraint.
StringRef getCppClass() const { return cppClassName; }
private:
AttributeConstraint(StringRef name, StringRef summary, StringRef cppClassName)
: Constraint(name, summary), cppClassName(cppClassName.str()) {}
/// The c++ class of the constraint.
std::string cppClassName;
/// Allow access to the constructor.
friend class Context;
};
//===----------------------------------------------------------------------===//
// TypeConstraint
//===----------------------------------------------------------------------===//
/// This class represents a generic ODS Type constraint.
class TypeConstraint : public Constraint {
public:
/// Return the name of the underlying c++ class of this constraint.
StringRef getCppClass() const { return cppClassName; }
private:
TypeConstraint(StringRef name, StringRef summary, StringRef cppClassName)
: Constraint(name, summary), cppClassName(cppClassName.str()) {}
/// The c++ class of the constraint.
std::string cppClassName;
/// Allow access to the constructor.
friend class Context;
};
} // namespace ods
} // namespace pdll
} // namespace mlir
#endif // MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_

View File

@ -0,0 +1,78 @@
//===- Context.h - MLIR PDLL ODS Context ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_PDLL_ODS_CONTEXT_H_
#define MLIR_TOOLS_PDLL_ODS_CONTEXT_H_
#include <string>
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
namespace llvm {
class SMLoc;
} // namespace llvm
namespace mlir {
namespace pdll {
namespace ods {
class AttributeConstraint;
class Dialect;
class Operation;
class TypeConstraint;
/// This class contains all of the registered ODS operation classes.
class Context {
public:
Context();
~Context();
/// Insert a new attribute constraint with the context. Returns the inserted
/// constraint, or a previously inserted constraint with the same name.
const AttributeConstraint &insertAttributeConstraint(StringRef name,
StringRef summary,
StringRef cppClass);
/// Insert a new type constraint with the context. Returns the inserted
/// constraint, or a previously inserted constraint with the same name.
const TypeConstraint &insertTypeConstraint(StringRef name, StringRef summary,
StringRef cppClass);
/// Insert a new dialect with the context. Returns the inserted dialect, or a
/// previously inserted dialect with the same name.
Dialect &insertDialect(StringRef name);
/// Lookup a dialect registered with the given name, or null if no dialect
/// with that name was inserted.
const Dialect *lookupDialect(StringRef name) const;
/// Insert a new operation with the context. Returns the inserted operation,
/// and a boolean indicating if the operation newly inserted (false if the
/// operation already existed).
std::pair<Operation *, bool>
insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc);
/// Lookup an operation registered with the given name, or null if no
/// operation with that name is registered.
const Operation *lookupOperation(StringRef name) const;
/// Print the contents of this context to the provided stream.
void print(raw_ostream &os) const;
private:
llvm::StringMap<std::unique_ptr<AttributeConstraint>> attributeConstraints;
llvm::StringMap<std::unique_ptr<Dialect>> dialects;
llvm::StringMap<std::unique_ptr<TypeConstraint>> typeConstraints;
};
} // namespace ods
} // namespace pdll
} // namespace mlir
#endif // MLIR_PDL_pdll_ODS_CONTEXT_H_

View File

@ -0,0 +1,64 @@
//===- Dialect.h - PDLL ODS Dialect -----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_PDLL_ODS_DIALECT_H_
#define MLIR_TOOLS_PDLL_ODS_DIALECT_H_
#include <string>
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
namespace mlir {
namespace pdll {
namespace ods {
class Operation;
/// This class represents an ODS dialect, and contains information on the
/// constructs held within the dialect.
class Dialect {
public:
~Dialect();
/// Return the name of this dialect.
StringRef getName() const { return name; }
/// Insert a new operation with the dialect. Returns the inserted operation,
/// and a boolean indicating if the operation newly inserted (false if the
/// operation already existed).
std::pair<Operation *, bool>
insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc);
/// Lookup an operation registered with the given name, or null if no
/// operation with that name is registered.
Operation *lookupOperation(StringRef name) const;
/// Return a map of all of the operations registered to this dialect.
const llvm::StringMap<std::unique_ptr<Operation>> &getOperations() const {
return operations;
}
private:
explicit Dialect(StringRef name);
/// The name of the dialect.
std::string name;
/// The operations defined by the dialect.
llvm::StringMap<std::unique_ptr<Operation>> operations;
/// Allow access to the constructor.
friend class Context;
};
} // namespace ods
} // namespace pdll
} // namespace mlir
#endif // MLIR_TOOLS_PDLL_ODS_DIALECT_H_

View File

@ -0,0 +1,189 @@
//===- Operation.h - MLIR PDLL ODS Operation --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_PDLL_ODS_OPERATION_H_
#define MLIR_TOOLS_PDLL_ODS_OPERATION_H_
#include <string>
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
namespace mlir {
namespace pdll {
namespace ods {
class AttributeConstraint;
class TypeConstraint;
//===----------------------------------------------------------------------===//
// VariableLengthKind
//===----------------------------------------------------------------------===//
enum VariableLengthKind { Single, Optional, Variadic };
//===----------------------------------------------------------------------===//
// Attribute
//===----------------------------------------------------------------------===//
/// This class provides an ODS representation of a specific operation attribute.
/// This includes the name, optionality, and more.
class Attribute {
public:
/// Return the name of this operand.
StringRef getName() const { return name; }
/// Return true if this attribute is optional.
bool isOptional() const { return optional; }
/// Return the constraint of this attribute.
const AttributeConstraint &getConstraint() const { return constraint; }
private:
Attribute(StringRef name, bool optional,
const AttributeConstraint &constraint)
: name(name.str()), optional(optional), constraint(constraint) {}
/// The ODS name of the attribute.
std::string name;
/// A flag indicating if the attribute is optional.
bool optional;
/// The ODS constraint of this attribute.
const AttributeConstraint &constraint;
/// Allow access to the private constructor.
friend class Operation;
};
//===----------------------------------------------------------------------===//
// OperandOrResult
//===----------------------------------------------------------------------===//
/// This class provides an ODS representation of a specific operation operand or
/// result. This includes the name, variable length flags, and more.
class OperandOrResult {
public:
/// Return the name of this value.
StringRef getName() const { return name; }
/// Returns true if this value is variadic (Note this is false if the value is
/// Optional).
bool isVariadic() const {
return variableLengthKind == VariableLengthKind::Variadic;
}
/// Returns the variable length kind of this value.
VariableLengthKind getVariableLengthKind() const {
return variableLengthKind;
}
/// Return the constraint of this value.
const TypeConstraint &getConstraint() const { return constraint; }
private:
OperandOrResult(StringRef name, VariableLengthKind variableLengthKind,
const TypeConstraint &constraint)
: name(name.str()), variableLengthKind(variableLengthKind),
constraint(constraint) {}
/// The ODS name of this value.
std::string name;
/// The variable length kind of this value.
VariableLengthKind variableLengthKind;
/// The ODS constraint of this value.
const TypeConstraint &constraint;
/// Allow access to the private constructor.
friend class Operation;
};
//===----------------------------------------------------------------------===//
// Operation
//===----------------------------------------------------------------------===//
/// This class provides an ODS representation of a specific operation. This
/// includes all of the information necessary for use by the PDL frontend for
/// generating code for a pattern rewrite.
class Operation {
public:
/// Return the source location of this operation.
SMRange getLoc() const { return location; }
/// Append an attribute to this operation.
void appendAttribute(StringRef name, bool optional,
const AttributeConstraint &constraint) {
attributes.emplace_back(Attribute(name, optional, constraint));
}
/// Append an operand to this operation.
void appendOperand(StringRef name, VariableLengthKind variableLengthKind,
const TypeConstraint &constraint) {
operands.emplace_back(
OperandOrResult(name, variableLengthKind, constraint));
}
/// Append a result to this operation.
void appendResult(StringRef name, VariableLengthKind variableLengthKind,
const TypeConstraint &constraint) {
results.emplace_back(OperandOrResult(name, variableLengthKind, constraint));
}
/// Returns the name of the operation.
StringRef getName() const { return name; }
/// Returns the summary of the operation.
StringRef getSummary() const { return summary; }
/// Returns the description of the operation.
StringRef getDescription() const { return description; }
/// Returns the attributes of this operation.
ArrayRef<Attribute> getAttributes() const { return attributes; }
/// Returns the operands of this operation.
ArrayRef<OperandOrResult> getOperands() const { return operands; }
/// Returns the results of this operation.
ArrayRef<OperandOrResult> getResults() const { return results; }
private:
Operation(StringRef name, StringRef summary, StringRef desc, SMLoc loc);
/// The name of the operation.
std::string name;
/// The documentation of the operation.
std::string summary;
std::string description;
/// The source location of this operation.
SMRange location;
/// The operands of the operation.
SmallVector<OperandOrResult> operands;
/// The results of the operation.
SmallVector<OperandOrResult> results;
/// The attributes of the operation.
SmallVector<Attribute> attributes;
/// Allow access to the private constructor.
friend class Dialect;
};
} // namespace ods
} // namespace pdll
} // namespace mlir
#endif // MLIR_TOOLS_PDLL_ODS_OPERATION_H_

View File

@ -57,6 +57,29 @@ StringRef Constraint::getSummary() const {
return def->getName(); return def->getName();
} }
StringRef Constraint::getDefName() const {
// Functor used to check a base def in the case where the current def is
// anonymous.
auto checkBaseDefFn = [&](StringRef baseName) {
if (const auto *init = dyn_cast<llvm::DefInit>(def->getValueInit(baseName)))
return Constraint(init->getDef(), kind).getDefName();
return def->getName();
};
switch (kind) {
case CK_Attr:
if (def->isAnonymous())
return checkBaseDefFn("baseAttr");
return def->getName();
case CK_Type:
if (def->isAnonymous())
return checkBaseDefFn("baseType");
return def->getName();
default:
return def->getName();
}
}
AppliedConstraint::AppliedConstraint(Constraint &&constraint, AppliedConstraint::AppliedConstraint(Constraint &&constraint,
llvm::StringRef self, llvm::StringRef self,
std::vector<std::string> &&entities) std::vector<std::string> &&entities)

View File

@ -6,5 +6,6 @@ add_mlir_library(MLIRPDLLAST
Types.cpp Types.cpp
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRPDLLODS
MLIRSupport MLIRSupport
) )

View File

@ -12,7 +12,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::pdll::ast; using namespace mlir::pdll::ast;
Context::Context() { Context::Context(ods::Context &odsContext) : odsContext(odsContext) {
typeUniquer.registerSingletonStorageType<detail::AttributeTypeStorage>(); typeUniquer.registerSingletonStorageType<detail::AttributeTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::ConstraintTypeStorage>(); typeUniquer.registerSingletonStorageType<detail::ConstraintTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::RewriteTypeStorage>(); typeUniquer.registerSingletonStorageType<detail::RewriteTypeStorage>();

View File

@ -1,3 +1,4 @@
add_subdirectory(AST) add_subdirectory(AST)
add_subdirectory(CodeGen) add_subdirectory(CodeGen)
add_subdirectory(ODS)
add_subdirectory(Parser) add_subdirectory(Parser)

View File

@ -17,6 +17,8 @@
#include "mlir/Tools/PDLL/AST/Context.h" #include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Nodes.h" #include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Types.h" #include "mlir/Tools/PDLL/AST/Types.h"
#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/TypeSwitch.h"
@ -33,7 +35,8 @@ class CodeGen {
public: public:
CodeGen(MLIRContext *mlirContext, const ast::Context &context, CodeGen(MLIRContext *mlirContext, const ast::Context &context,
const llvm::SourceMgr &sourceMgr) const llvm::SourceMgr &sourceMgr)
: builder(mlirContext), sourceMgr(sourceMgr) { : builder(mlirContext), odsContext(context.getODSContext()),
sourceMgr(sourceMgr) {
// Make sure that the PDL dialect is loaded. // Make sure that the PDL dialect is loaded.
mlirContext->loadDialect<pdl::PDLDialect>(); mlirContext->loadDialect<pdl::PDLDialect>();
} }
@ -117,6 +120,9 @@ private:
llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>; llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
VariableMapTy variables; VariableMapTy variables;
/// A reference to the ODS context.
const ods::Context &odsContext;
/// The source manager of the PDLL ast. /// The source manager of the PDLL ast.
const llvm::SourceMgr &sourceMgr; const llvm::SourceMgr &sourceMgr;
}; };
@ -435,7 +441,28 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
builder.getI32IntegerAttr(0)); builder.getI32IntegerAttr(0));
return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]); return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
} }
llvm_unreachable("unhandled operation member access expression");
assert(opType.getName() && "expected valid operation name");
const ods::Operation *odsOp = odsContext.lookupOperation(*opType.getName());
assert(odsOp && "expected valid ODS operation information");
// Find the result with the member name or by index.
ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
unsigned resultIndex = results.size();
if (llvm::isDigit(name[0])) {
name.getAsInteger(/*Radix=*/10, resultIndex);
} else {
auto findFn = [&](const ods::OperandOrResult &result) {
return result.getName() == name;
};
resultIndex = llvm::find_if(results, findFn) - results.begin();
}
assert(resultIndex < results.size() && "invalid result index");
// Generate the result access.
IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
parentExprs[0], index);
} }
// Handle tuple based member access. // Handle tuple based member access.

View File

@ -0,0 +1,8 @@
add_mlir_library(MLIRPDLLODS
Context.cpp
Dialect.cpp
Operation.cpp
LINK_LIBS PUBLIC
MLIRSupport
)

View File

@ -0,0 +1,174 @@
//===- Context.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/ODS/Constraint.h"
#include "mlir/Tools/PDLL/ODS/Dialect.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::pdll::ods;
//===----------------------------------------------------------------------===//
// Context
//===----------------------------------------------------------------------===//
Context::Context() = default;
Context::~Context() = default;
const AttributeConstraint &
Context::insertAttributeConstraint(StringRef name, StringRef summary,
StringRef cppClass) {
std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name];
if (!constraint) {
constraint.reset(new AttributeConstraint(name, summary, cppClass));
} else {
assert(constraint->getCppClass() == cppClass &&
constraint->getSummary() == summary &&
"constraint with the same name was already registered with a "
"different class");
}
return *constraint;
}
const TypeConstraint &Context::insertTypeConstraint(StringRef name,
StringRef summary,
StringRef cppClass) {
std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name];
if (!constraint)
constraint.reset(new TypeConstraint(name, summary, cppClass));
return *constraint;
}
Dialect &Context::insertDialect(StringRef name) {
std::unique_ptr<Dialect> &dialect = dialects[name];
if (!dialect)
dialect.reset(new Dialect(name));
return *dialect;
}
const Dialect *Context::lookupDialect(StringRef name) const {
auto it = dialects.find(name);
return it == dialects.end() ? nullptr : &*it->second;
}
std::pair<Operation *, bool> Context::insertOperation(StringRef name,
StringRef summary,
StringRef desc,
SMLoc loc) {
std::pair<StringRef, StringRef> dialectAndName = name.split('.');
return insertDialect(dialectAndName.first)
.insertOperation(name, summary, desc, loc);
}
const Operation *Context::lookupOperation(StringRef name) const {
std::pair<StringRef, StringRef> dialectAndName = name.split('.');
if (const Dialect *dialect = lookupDialect(dialectAndName.first))
return dialect->lookupOperation(name);
return nullptr;
}
template <typename T>
SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) {
SmallVector<T *> storage;
for (auto &entry : map)
storage.push_back(entry.second.get());
llvm::sort(storage, [](const auto &lhs, const auto &rhs) {
return lhs->getName() < rhs->getName();
});
return storage;
}
void Context::print(raw_ostream &os) const {
auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) {
switch (kind) {
case VariableLengthKind::Optional:
os << "Optional<" << cst << ">";
break;
case VariableLengthKind::Single:
os << cst;
break;
case VariableLengthKind::Variadic:
os << "Variadic<" << cst << ">";
break;
}
};
llvm::ScopedPrinter printer(os);
llvm::DictScope odsScope(printer, "ODSContext");
for (const Dialect *dialect : sortMapByName(dialects)) {
printer.startLine() << "Dialect `" << dialect->getName() << "` {\n";
printer.indent();
for (const Operation *op : sortMapByName(dialect->getOperations())) {
printer.startLine() << "Operation `" << op->getName() << "` {\n";
printer.indent();
// Attributes.
ArrayRef<Attribute> attributes = op->getAttributes();
if (!attributes.empty()) {
printer.startLine() << "Attributes { ";
llvm::interleaveComma(attributes, os, [&](const Attribute &attr) {
os << attr.getName() << " : ";
auto kind = attr.isOptional() ? VariableLengthKind::Optional
: VariableLengthKind::Single;
printVariableLengthCst(attr.getConstraint().getName(), kind);
});
os << " }\n";
}
// Operands.
ArrayRef<OperandOrResult> operands = op->getOperands();
if (!operands.empty()) {
printer.startLine() << "Operands { ";
llvm::interleaveComma(
operands, os, [&](const OperandOrResult &operand) {
os << operand.getName() << " : ";
printVariableLengthCst(operand.getConstraint().getName(),
operand.getVariableLengthKind());
});
os << " }\n";
}
// Results.
ArrayRef<OperandOrResult> results = op->getResults();
if (!results.empty()) {
printer.startLine() << "Results { ";
llvm::interleaveComma(results, os, [&](const OperandOrResult &result) {
os << result.getName() << " : ";
printVariableLengthCst(result.getConstraint().getName(),
result.getVariableLengthKind());
});
os << " }\n";
}
printer.objectEnd();
}
printer.objectEnd();
}
for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) {
printer.startLine() << "AttributeConstraint `" << cst->getName() << "` {\n";
printer.indent();
printer.startLine() << "Summary: " << cst->getSummary() << "\n";
printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
printer.objectEnd();
}
for (const TypeConstraint *cst : sortMapByName(typeConstraints)) {
printer.startLine() << "TypeConstraint `" << cst->getName() << "` {\n";
printer.indent();
printer.startLine() << "Summary: " << cst->getSummary() << "\n";
printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
printer.objectEnd();
}
printer.objectEnd();
}

View File

@ -0,0 +1,39 @@
//===- Dialect.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/ODS/Dialect.h"
#include "mlir/Tools/PDLL/ODS/Constraint.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::pdll::ods;
//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//
Dialect::Dialect(StringRef name) : name(name.str()) {}
Dialect::~Dialect() = default;
std::pair<Operation *, bool> Dialect::insertOperation(StringRef name,
StringRef summary,
StringRef desc,
llvm::SMLoc loc) {
std::unique_ptr<Operation> &operation = operations[name];
if (operation)
return std::make_pair(&*operation, /*wasInserted*/ false);
operation.reset(new Operation(name, summary, desc, loc));
return std::make_pair(&*operation, /*wasInserted*/ true);
}
Operation *Dialect::lookupOperation(StringRef name) const {
auto it = operations.find(name);
return it != operations.end() ? it->second.get() : nullptr;
}

View File

@ -0,0 +1,26 @@
//===- Operation.cpp ------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::pdll::ods;
//===----------------------------------------------------------------------===//
// Operation
//===----------------------------------------------------------------------===//
Operation::Operation(StringRef name, StringRef summary, StringRef desc,
llvm::SMLoc loc)
: name(name.str()), summary(summary.str()),
location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) {
llvm::raw_string_ostream descOS(description);
raw_indented_ostream(descOS).printReindented(desc.rtrim(" \t"));
}

View File

@ -1,3 +1,8 @@
set(LLVM_LINK_COMPONENTS
Support
TableGen
)
add_mlir_library(MLIRPDLLParser add_mlir_library(MLIRPDLLParser
Lexer.cpp Lexer.cpp
Parser.cpp Parser.cpp
@ -5,4 +10,5 @@ add_mlir_library(MLIRPDLLParser
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRPDLLAST MLIRPDLLAST
MLIRSupport MLIRSupport
MLIRTableGen
) )

View File

@ -9,15 +9,26 @@
#include "mlir/Tools/PDLL/Parser/Parser.h" #include "mlir/Tools/PDLL/Parser/Parser.h"
#include "Lexer.h" #include "Lexer.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Constraint.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/Tools/PDLL/AST/Context.h" #include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h" #include "mlir/Tools/PDLL/AST/Diagnostic.h"
#include "mlir/Tools/PDLL/AST/Nodes.h" #include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Types.h" #include "mlir/Tools/PDLL/AST/Types.h"
#include "mlir/Tools/PDLL/ODS/Constraint.h"
#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/ScopedPrinter.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Parser.h"
#include <string> #include <string>
using namespace mlir; using namespace mlir;
@ -36,7 +47,8 @@ public:
valueTy(ast::ValueType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
valueRangeTy(ast::ValueRangeType::get(ctx)), valueRangeTy(ast::ValueRangeType::get(ctx)),
typeTy(ast::TypeType::get(ctx)), typeTy(ast::TypeType::get(ctx)),
typeRangeTy(ast::TypeRangeType::get(ctx)) {} typeRangeTy(ast::TypeRangeType::get(ctx)),
attrTy(ast::AttributeType::get(ctx)) {}
/// Try to parse a new module. Returns nullptr in the case of failure. /// Try to parse a new module. Returns nullptr in the case of failure.
FailureOr<ast::Module *> parseModule(); FailureOr<ast::Module *> parseModule();
@ -78,7 +90,7 @@ private:
void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
/// Parse the body of an AST module. /// Parse the body of an AST module.
LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls); LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
/// Try to convert the given expression to `type`. Returns failure and emits /// Try to convert the given expression to `type`. Returns failure and emits
/// an error if a conversion is not viable. On failure, `noteAttachFn` is /// an error if a conversion is not viable. On failure, `noteAttachFn` is
@ -92,11 +104,34 @@ private:
/// typed expression. /// typed expression.
ast::Expr *convertOpToValue(const ast::Expr *opExpr); ast::Expr *convertOpToValue(const ast::Expr *opExpr);
/// Lookup ODS information for the given operation, returns nullptr if no
/// information is found.
const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
}
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Directives // Directives
LogicalResult parseDirective(SmallVector<ast::Decl *> &decls); LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
LogicalResult parseInclude(SmallVector<ast::Decl *> &decls); LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
SmallVectorImpl<ast::Decl *> &decls);
/// Process the records of a parsed tablegen include file.
void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
SmallVectorImpl<ast::Decl *> &decls);
/// Create a user defined native constraint for a constraint imported from
/// ODS.
template <typename ConstraintT>
ast::Decl *createODSNativePDLLConstraintDecl(StringRef name,
StringRef codeBlock, SMRange loc,
ast::Type type);
template <typename ConstraintT>
ast::Decl *
createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
SMRange loc, ast::Type type);
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Decls // Decls
@ -340,13 +375,16 @@ private:
MutableArrayRef<ast::Expr *> results); MutableArrayRef<ast::Expr *> results);
LogicalResult LogicalResult
validateOperationOperands(SMRange loc, Optional<StringRef> name, validateOperationOperands(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> operands); MutableArrayRef<ast::Expr *> operands);
LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name, LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> results); MutableArrayRef<ast::Expr *> results);
LogicalResult LogicalResult validateOperationOperandsOrResults(
validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name, StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
MutableArrayRef<ast::Expr *> values, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
ast::Type singleTy, ast::Type rangeTy); ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
ast::Type rangeTy);
FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc, FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
ArrayRef<ast::Expr *> elements, ArrayRef<ast::Expr *> elements,
ArrayRef<StringRef> elementNames); ArrayRef<StringRef> elementNames);
@ -440,6 +478,7 @@ private:
/// Cached types to simplify verification and expression creation. /// Cached types to simplify verification and expression creation.
ast::Type valueTy, valueRangeTy; ast::Type valueTy, valueRangeTy;
ast::Type typeTy, typeRangeTy; ast::Type typeTy, typeRangeTy;
ast::Type attrTy;
/// A counter used when naming anonymous constraints and rewrites. /// A counter used when naming anonymous constraints and rewrites.
unsigned anonymousDeclNameCounter = 0; unsigned anonymousDeclNameCounter = 0;
@ -459,7 +498,7 @@ FailureOr<ast::Module *> Parser::parseModule() {
return ast::Module::create(ctx, moduleLoc, decls); return ast::Module::create(ctx, moduleLoc, decls);
} }
LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) { LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
while (curToken.isNot(Token::eof)) { while (curToken.isNot(Token::eof)) {
if (curToken.is(Token::directive)) { if (curToken.is(Token::directive)) {
if (failed(parseDirective(decls))) if (failed(parseDirective(decls)))
@ -516,6 +555,32 @@ LogicalResult Parser::convertExpressionTo(
// Allow conversion to a single value by constraining the result range. // Allow conversion to a single value by constraining the result range.
if (type == valueTy) { if (type == valueTy) {
// If the operation is registered, we can verify if it can ever have a
// single result.
Optional<StringRef> opName = exprOpType.getName();
if (const ods::Operation *odsOp = lookupODSOperation(opName)) {
if (odsOp->getResults().empty()) {
return emitConvertError()->attachNote(
llvm::formatv("see the definition of `{0}`, which was defined "
"with zero results",
odsOp->getName()),
odsOp->getLoc());
}
unsigned numSingleResults = llvm::count_if(
odsOp->getResults(), [](const ods::OperandOrResult &result) {
return result.getVariableLengthKind() ==
ods::VariableLengthKind::Single;
});
if (numSingleResults > 1) {
return emitConvertError()->attachNote(
llvm::formatv("see the definition of `{0}`, which was defined "
"with at least {1} results",
odsOp->getName(), numSingleResults),
odsOp->getLoc());
}
}
expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
valueTy); valueTy);
return success(); return success();
@ -569,7 +634,7 @@ LogicalResult Parser::convertExpressionTo(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Directives // Directives
LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) { LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
StringRef directive = curToken.getSpelling(); StringRef directive = curToken.getSpelling();
if (directive == "#include") if (directive == "#include")
return parseInclude(decls); return parseInclude(decls);
@ -577,7 +642,7 @@ LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
return emitError("unknown directive `" + directive + "`"); return emitError("unknown directive `" + directive + "`");
} }
LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) { LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
SMRange loc = curToken.getLoc(); SMRange loc = curToken.getLoc();
consumeToken(Token::directive); consumeToken(Token::directive);
@ -607,7 +672,193 @@ LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
return result; return result;
} }
return emitError(fileLoc, "expected include filename to end with `.pdll`"); // Otherwise, this must be a `.td` include.
if (filename.endswith(".td"))
return parseTdInclude(filename, fileLoc, decls);
return emitError(fileLoc,
"expected include filename to end with `.pdll` or `.td`");
}
LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
SmallVectorImpl<ast::Decl *> &decls) {
llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
// This class provides a context argument for the llvm::SourceMgr diagnostic
// handler.
struct DiagHandlerContext {
Parser &parser;
StringRef filename;
llvm::SMRange loc;
} handlerContext{*this, filename, fileLoc};
// Set the diagnostic handler for the tablegen source manager.
llvm::SrcMgr.setDiagHandler(
[](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
(void)ctx->parser.emitError(
ctx->loc,
llvm::formatv("error while processing include file `{0}`: {1}",
ctx->filename, diag.getMessage()));
},
&handlerContext);
// Use the source manager to open the file, but don't yet add it.
std::string includedFile;
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
if (!includeBuffer)
return emitError(fileLoc, "unable to open include file `" + filename + "`");
auto processFn = [&](llvm::RecordKeeper &records) {
processTdIncludeRecords(records, decls);
// After we are done processing, move all of the tablegen source buffers to
// the main parser source mgr. This allows for directly using source
// locations from the .td files without needing to remap them.
parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr);
return false;
};
if (llvm::TableGenParseFile(std::move(*includeBuffer),
parserSrcMgr.getIncludeDirs(), processFn))
return failure();
return success();
}
void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
SmallVectorImpl<ast::Decl *> &decls) {
// Return the length kind of the given value.
auto getLengthKind = [](const auto &value) {
if (value.isOptional())
return ods::VariableLengthKind::Optional;
return value.isVariadic() ? ods::VariableLengthKind::Variadic
: ods::VariableLengthKind::Single;
};
// Insert a type constraint into the ODS context.
ods::Context &odsContext = ctx.getODSContext();
auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
-> const ods::TypeConstraint & {
return odsContext.insertTypeConstraint(cst.constraint.getDefName(),
cst.constraint.getSummary(),
cst.constraint.getCPPClassName());
};
auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
};
// Process the parsed tablegen records to build ODS information.
/// Operations.
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
tblgen::Operator op(def);
bool inserted = false;
ods::Operation *odsOp = nullptr;
std::tie(odsOp, inserted) =
odsContext.insertOperation(op.getOperationName(), op.getSummary(),
op.getDescription(), op.getLoc().front());
// Ignore operations that have already been added.
if (!inserted)
continue;
for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
odsOp->appendAttribute(
attr.name, attr.attr.isOptional(),
odsContext.insertAttributeConstraint(attr.attr.getAttrDefName(),
attr.attr.getSummary(),
attr.attr.getStorageType()));
}
for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
odsOp->appendOperand(operand.name, getLengthKind(operand),
addTypeConstraint(operand));
}
for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
odsOp->appendResult(result.name, getLengthKind(result),
addTypeConstraint(result));
}
}
/// Attr constraints.
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
decls.push_back(
createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
tblgen::AttrConstraint(def),
convertLocToRange(def->getLoc().front()), attrTy));
}
}
/// Type constraints.
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
decls.push_back(
createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
tblgen::TypeConstraint(def),
convertLocToRange(def->getLoc().front()), typeTy));
}
}
/// Interfaces.
ast::Type opTy = ast::OperationType::get(ctx);
for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
StringRef name = def->getName();
if (def->isAnonymous() || curDeclScope->lookup(name) ||
def->isSubClassOf("DeclareInterfaceMethods"))
continue;
SMRange loc = convertLocToRange(def->getLoc().front());
StringRef className = def->getValueAsString("cppClassName");
StringRef cppNamespace = def->getValueAsString("cppNamespace");
std::string codeBlock =
llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className)
.str();
if (def->isSubClassOf("OpInterface")) {
decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
name, codeBlock, loc, opTy));
} else if (def->isSubClassOf("AttrInterface")) {
decls.push_back(
createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
name, codeBlock, loc, attrTy));
} else if (def->isSubClassOf("TypeInterface")) {
decls.push_back(
createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
name, codeBlock, loc, typeTy));
}
}
}
template <typename ConstraintT>
ast::Decl *
Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
SMRange loc, ast::Type type) {
// Build the single input parameter.
ast::DeclScope *argScope = pushDeclScope();
auto *paramVar = ast::VariableDecl::create(
ctx, ast::Name::create(ctx, "self", loc), type,
/*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
argScope->add(paramVar);
popDeclScope();
// Build the native constraint.
auto *constraintDecl = ast::UserConstraintDecl::createNative(
ctx, ast::Name::create(ctx, name, loc), paramVar,
/*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx));
curDeclScope->add(constraintDecl);
return constraintDecl;
}
template <typename ConstraintT>
ast::Decl *
Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
SMRange loc, ast::Type type) {
// Format the condition template.
tblgen::FmtContext fmtContext;
fmtContext.withSelf("self");
std::string codeBlock =
tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext);
return createODSNativePDLLConstraintDecl<ConstraintT>(constraint.getDefName(),
codeBlock, loc, type);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2302,9 +2553,29 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr, FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
StringRef name, SMRange loc) { StringRef name, SMRange loc) {
ast::Type parentType = parentExpr->getType(); ast::Type parentType = parentExpr->getType();
if (parentType.isa<ast::OperationType>()) { if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
if (name == ast::AllResultsMemberAccessExpr::getMemberName()) if (name == ast::AllResultsMemberAccessExpr::getMemberName())
return valueRangeTy; return valueRangeTy;
// Verify member access based on the operation type.
if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) {
auto results = odsOp->getResults();
// Handle indexed results.
unsigned index = 0;
if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
index < results.size()) {
return results[index].isVariadic() ? valueRangeTy : valueTy;
}
// Handle named results.
const auto *it = llvm::find_if(results, [&](const auto &result) {
return result.getName() == name;
});
if (it != results.end())
return it->isVariadic() ? valueRangeTy : valueTy;
}
} else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
// Handle indexed results. // Handle indexed results.
unsigned index = 0; unsigned index = 0;
@ -2331,9 +2602,10 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
MutableArrayRef<ast::NamedAttributeDecl *> attributes, MutableArrayRef<ast::NamedAttributeDecl *> attributes,
MutableArrayRef<ast::Expr *> results) { MutableArrayRef<ast::Expr *> results) {
Optional<StringRef> opNameRef = name->getName(); Optional<StringRef> opNameRef = name->getName();
const ods::Operation *odsOp = lookupODSOperation(opNameRef);
// Verify the inputs operands. // Verify the inputs operands.
if (failed(validateOperationOperands(loc, opNameRef, operands))) if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
return failure(); return failure();
// Verify the attribute list. // Verify the attribute list.
@ -2348,7 +2620,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
} }
// Verify the result types. // Verify the result types.
if (failed(validateOperationResults(loc, opNameRef, results))) if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
return failure(); return failure();
return ast::OperationExpr::create(ctx, loc, name, operands, results, return ast::OperationExpr::create(ctx, loc, name, operands, results,
@ -2357,21 +2629,28 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
LogicalResult LogicalResult
Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name, Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> operands) { MutableArrayRef<ast::Expr *> operands) {
return validateOperationOperandsOrResults(loc, name, operands, valueTy, return validateOperationOperandsOrResults(
valueRangeTy); "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
valueRangeTy);
} }
LogicalResult LogicalResult
Parser::validateOperationResults(SMRange loc, Optional<StringRef> name, Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> results) { MutableArrayRef<ast::Expr *> results) {
return validateOperationOperandsOrResults(loc, name, results, typeTy, return validateOperationOperandsOrResults(
typeRangeTy); "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
} }
LogicalResult Parser::validateOperationOperandsOrResults( LogicalResult Parser::validateOperationOperandsOrResults(
SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values, StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
ast::Type singleTy, ast::Type rangeTy) { Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
ast::Type rangeTy) {
// All operation types accept a single range parameter. // All operation types accept a single range parameter.
if (values.size() == 1) { if (values.size() == 1) {
if (failed(convertExpressionTo(values[0], rangeTy))) if (failed(convertExpressionTo(values[0], rangeTy)))
@ -2379,6 +2658,29 @@ LogicalResult Parser::validateOperationOperandsOrResults(
return success(); return success();
} }
/// If the operation has ODS information, we can more accurately verify the
/// values.
if (odsOpLoc) {
if (odsValues.size() != values.size()) {
return emitErrorAndNote(
loc,
llvm::formatv("invalid number of {0} groups for `{1}`; expected "
"{2}, but got {3}",
groupName, *name, odsValues.size(), values.size()),
*odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
}
auto diagFn = [&](ast::Diagnostic &diag) {
diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
*odsOpLoc);
};
for (unsigned i = 0, e = values.size(); i < e; ++i) {
ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
return failure();
}
return success();
}
// Otherwise, accept the value groups as they have been defined and just // Otherwise, accept the value groups as they have been defined and just
// ensure they are one of the expected types. // ensure they are one of the expected types.
for (ast::Expr *&valueExpr : values) { for (ast::Expr *&valueExpr : values) {

View File

@ -1,4 +1,4 @@
// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s // RUN: mlir-pdll %s -I %S -I %S/../../../../include -split-input-file -x mlir | FileCheck %s
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AttributeExpr // AttributeExpr
@ -55,6 +55,24 @@ Pattern OpAllResultMemberAccess {
// ----- // -----
// Handle implicit "named" operation results access.
#include "include/ops.td"
// CHECK: pdl.pattern @OpResultMemberAccess
// CHECK: %[[OP0:.*]] = operation
// CHECK: %[[RES:.*]] = results 0 of %[[OP0]] -> !pdl.value
// CHECK: %[[RES1:.*]] = results 0 of %[[OP0]] -> !pdl.value
// CHECK: %[[RES2:.*]] = results 1 of %[[OP0]] -> !pdl.range<value>
// CHECK: %[[RES3:.*]] = results 1 of %[[OP0]] -> !pdl.range<value>
// CHECK: operation(%[[RES]], %[[RES1]], %[[RES2]], %[[RES3]] : !pdl.value, !pdl.value, !pdl.range<value>, !pdl.range<value>)
Pattern OpResultMemberAccess {
let op: Op<test.with_results>;
erase op<>(op.0, op.result, op.1, op.var_result);
}
// -----
// CHECK: pdl.pattern @TupleMemberAccessNumber // CHECK: pdl.pattern @TupleMemberAccessNumber
// CHECK: %[[FIRST:.*]] = operation "test.first" // CHECK: %[[FIRST:.*]] = operation "test.first"
// CHECK: %[[SECOND:.*]] = operation "test.second" // CHECK: %[[SECOND:.*]] = operation "test.second"

View File

@ -0,0 +1,9 @@
include "mlir/IR/OpBase.td"
def Test_Dialect : Dialect {
let name = "test";
}
def OpWithResults : Op<Test_Dialect, "with_results"> {
let results = (outs I64:$result, Variadic<I64>:$var_result);
}

View File

@ -19,5 +19,5 @@
// ----- // -----
// CHECK: expected include filename to end with `.pdll` // CHECK: expected include filename to end with `.pdll` or `.td`
#include "unknown_file.foo" #include "unknown_file.foo"

View File

@ -1,4 +1,4 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s // RUN: not mlir-pdll %s -I %S -I %S/../../../include -split-input-file 2>&1 | FileCheck %s
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Reference Expr // Reference Expr
@ -276,6 +276,26 @@ Pattern {
// ----- // -----
#include "include/ops.td"
Pattern {
// CHECK: invalid number of operand groups for `test.all_empty`; expected 0, but got 2
// CHECK: see the definition of `test.all_empty` here
let foo = op<test.all_empty>(operand1: Value, operand2: Value);
}
// -----
#include "include/ops.td"
Pattern {
// CHECK: invalid number of result groups for `test.all_empty`; expected 0, but got 2
// CHECK: see the definition of `test.all_empty` here
let foo = op<test.all_empty> -> (result1: Type, result2: Type);
}
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// `type` Expr // `type` Expr
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1,4 +1,4 @@
// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s // RUN: mlir-pdll %s -I %S -I %S/../../../include -split-input-file | FileCheck %s
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AttrExpr // AttrExpr
@ -71,6 +71,25 @@ Pattern {
// ----- // -----
#include "include/ops.td"
// CHECK: Module
// CHECK: `-VariableDecl {{.*}} Name<firstEltIndex> Type<Value>
// CHECK: `-MemberAccessExpr {{.*}} Member<0> Type<Value>
// CHECK: `-DeclRefExpr {{.*}} Type<Op<test.all_single>>
// CHECK: `-VariableDecl {{.*}} Name<firstEltName> Type<Value>
// CHECK: `-MemberAccessExpr {{.*}} Member<result> Type<Value>
// CHECK: `-DeclRefExpr {{.*}} Type<Op<test.all_single>>
Pattern {
let op: Op<test.all_single>;
let firstEltIndex = op.0;
let firstEltName = op.result;
erase op;
}
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// OperationExpr // OperationExpr
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,5 @@
include "mlir/IR/OpBase.td"
def TestAttrInterface : AttrInterface<"TestAttrInterface">;
def TestOpInterface : OpInterface<"TestOpInterface">;
def TestTypeInterface : TypeInterface<"TestTypeInterface">;

View File

@ -0,0 +1,26 @@
include "include/interfaces.td"
def Test_Dialect : Dialect {
let name = "test";
}
def OpAllEmpty : Op<Test_Dialect, "all_empty">;
def OpAllSingle : Op<Test_Dialect, "all_single"> {
let arguments = (ins I64:$operand, I64Attr:$attr);
let results = (outs I64:$result);
}
def OpAllOptional : Op<Test_Dialect, "all_optional"> {
let arguments = (ins Optional<I64>:$operand, OptionalAttr<I64Attr>:$attr);
let results = (outs Optional<I64>:$result);
}
def OpAllVariadic : Op<Test_Dialect, "all_variadic"> {
let arguments = (ins Variadic<I64>:$operands);
let results = (outs Variadic<I64>:$results);
}
def OpMultipleSingleResult : Op<Test_Dialect, "multiple_single_result"> {
let results = (outs I64:$result, I64:$result2);
}

View File

@ -0,0 +1,52 @@
// RUN: mlir-pdll %s -I %S -I %S/../../../include -dump-ods 2>&1 | FileCheck %s
#include "include/ops.td"
// CHECK: Operation `test.all_empty` {
// CHECK-NEXT: }
// CHECK: Operation `test.all_optional` {
// CHECK-NEXT: Attributes { attr : Optional<I64Attr> }
// CHECK-NEXT: Operands { operand : Optional<I64> }
// CHECK-NEXT: Results { result : Optional<I64> }
// CHECK-NEXT: }
// CHECK: Operation `test.all_single` {
// CHECK-NEXT: Attributes { attr : I64Attr }
// CHECK-NEXT: Operands { operand : I64 }
// CHECK-NEXT: Results { result : I64 }
// CHECK-NEXT: }
// CHECK: Operation `test.all_variadic` {
// CHECK-NEXT: Operands { operands : Variadic<I64> }
// CHECK-NEXT: Results { results : Variadic<I64> }
// CHECK-NEXT: }
// CHECK: AttributeConstraint `I64Attr` {
// CHECK-NEXT: Summary: 64-bit signless integer attribute
// CHECK-NEXT: CppClass: ::mlir::IntegerAttr
// CHECK-NEXT: }
// CHECK: TypeConstraint `I64` {
// CHECK-NEXT: Summary: 64-bit signless integer
// CHECK-NEXT: CppClass: ::mlir::IntegerType
// CHECK-NEXT: }
// CHECK: UserConstraintDecl {{.*}} Name<TestAttrInterface> ResultType<Tuple<>> Code<llvm::isa<::TestAttrInterface>(self)>
// CHECK: `Inputs`
// CHECK: `-VariableDecl {{.*}} Name<self> Type<Attr>
// CHECK: `Constraints`
// CHECK: `-AttrConstraintDecl
// CHECK: UserConstraintDecl {{.*}} Name<TestOpInterface> ResultType<Tuple<>> Code<llvm::isa<::TestOpInterface>(self)>
// CHECK: `Inputs`
// CHECK: `-VariableDecl {{.*}} Name<self> Type<Op>
// CHECK: `Constraints`
// CHECK: `-OpConstraintDecl
// CHECK: `-OpNameDecl
// CHECK: UserConstraintDecl {{.*}} Name<TestTypeInterface> ResultType<Tuple<>> Code<llvm::isa<::TestTypeInterface>(self)>
// CHECK: `Inputs`
// CHECK: `-VariableDecl {{.*}} Name<self> Type<Type>
// CHECK: `Constraints`
// CHECK: `-TypeConstraintDecl {{.*}}

View File

@ -1,4 +1,4 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s // RUN: not mlir-pdll %s -I %S -I %S/../../../include -split-input-file 2>&1 | FileCheck %s
// CHECK: expected top-level declaration, such as a `Pattern` // CHECK: expected top-level declaration, such as a `Pattern`
10 10
@ -250,6 +250,28 @@ Pattern {
// ----- // -----
#include "include/ops.td"
Pattern {
// CHECK: unable to convert expression of type `Op<test.all_empty>` to the expected type of `Value`
// CHECK: see the definition of `test.all_empty`, which was defined with zero results
let value: Value = op<test.all_empty>;
erase _: Op;
}
// -----
#include "include/ops.td"
Pattern {
// CHECK: unable to convert expression of type `Op<test.multiple_single_result>` to the expected type of `Value`
// CHECK: see the definition of `test.multiple_single_result`, which was defined with at least 2 results
let value: Value = op<test.multiple_single_result>;
erase _: Op;
}
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// `replace` // `replace`
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -13,6 +13,7 @@
#include "mlir/Tools/PDLL/AST/Nodes.h" #include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/CodeGen/CPPGen.h" #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/Parser/Parser.h" #include "mlir/Tools/PDLL/Parser/Parser.h"
#include "llvm/Support/CommandLine.h" #include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h" #include "llvm/Support/InitLLVM.h"
@ -35,16 +36,23 @@ enum class OutputType {
static LogicalResult static LogicalResult
processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
OutputType outputType, std::vector<std::string> &includeDirs) { OutputType outputType, std::vector<std::string> &includeDirs,
bool dumpODS) {
llvm::SourceMgr sourceMgr; llvm::SourceMgr sourceMgr;
sourceMgr.setIncludeDirs(includeDirs); sourceMgr.setIncludeDirs(includeDirs);
sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), SMLoc()); sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), SMLoc());
ast::Context astContext; ods::Context odsContext;
ast::Context astContext(odsContext);
FailureOr<ast::Module *> module = parsePDLAST(astContext, sourceMgr); FailureOr<ast::Module *> module = parsePDLAST(astContext, sourceMgr);
if (failed(module)) if (failed(module))
return failure(); return failure();
// Print out the ODS information if requested.
if (dumpODS)
odsContext.print(llvm::errs());
// Generate the output.
if (outputType == OutputType::AST) { if (outputType == OutputType::AST) {
(*module)->print(os); (*module)->print(os);
return success(); return success();
@ -66,6 +74,10 @@ processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
} }
int main(int argc, char **argv) { int main(int argc, char **argv) {
// FIXME: This is necessary because we link in TableGen, which defines its
// options as static variables.. some of which overlap with our options.
llvm::cl::ResetCommandLineParser();
llvm::cl::opt<std::string> inputFilename( llvm::cl::opt<std::string> inputFilename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"), llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::value_desc("filename")); llvm::cl::value_desc("filename"));
@ -78,6 +90,11 @@ int main(int argc, char **argv) {
"I", llvm::cl::desc("Directory of include files"), "I", llvm::cl::desc("Directory of include files"),
llvm::cl::value_desc("directory"), llvm::cl::Prefix); llvm::cl::value_desc("directory"), llvm::cl::Prefix);
llvm::cl::opt<bool> dumpODS(
"dump-ods",
llvm::cl::desc(
"Print out the parsed ODS information from the input file"),
llvm::cl::init(false));
llvm::cl::opt<bool> splitInputFile( llvm::cl::opt<bool> splitInputFile(
"split-input-file", "split-input-file",
llvm::cl::desc("Split the input file into pieces and process each " llvm::cl::desc("Split the input file into pieces and process each "
@ -118,7 +135,8 @@ int main(int argc, char **argv) {
// up into small pieces and checks each independently. // up into small pieces and checks each independently.
auto processFn = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, auto processFn = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
raw_ostream &os) { raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs); return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs,
dumpODS);
}; };
if (splitInputFile) { if (splitInputFile) {
if (failed(splitAndProcessBuffer(std::move(inputFile), processFn, if (failed(splitAndProcessBuffer(std::move(inputFile), processFn,