[mlir][ods] Cleanup of Class Codegen helper

Depends on D113331

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D113714
This commit is contained in:
Mogball 2021-11-12 01:17:05 +00:00
parent ece17064b5
commit 2696a9529e
8 changed files with 764 additions and 854 deletions

View File

@ -0,0 +1,412 @@
//===- Class.h - Helper classes for C++ code emission -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines several classes for Op C++ code emission. They are only
// expected to be used by MLIR TableGen backends.
//
// We emit the op declaration and definition into separate files: *Ops.h.inc
// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
//
// In order to do this split, we need to track method signature and
// implementation logic separately. Signature information is used for both
// declaration and definition, while implementation logic is only for
// definition. So we have the following classes for C++ code emission.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_CLASS_H_
#define MLIR_TABLEGEN_CLASS_H_
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include <set>
#include <string>
namespace mlir {
namespace tblgen {
class FmtObjectBase;
/// This class contains a single method parameter for a C++ function.
class MethodParameter {
public:
/// Create a method parameter with a C++ type, parameter name, and an optional
/// default value. Marking a parameter as "optional" is a cosmetic effect on
/// the generated code.
template <typename TypeT, typename NameT, typename DefaultT>
MethodParameter(TypeT &&type, NameT &&name, DefaultT &&defaultValue,
bool optional = false)
: type(stringify(std::forward<TypeT>(type))),
name(stringify(std::forward<NameT>(name))),
defaultValue(stringify(std::forward<DefaultT>(defaultValue))),
optional(optional) {}
/// Create a method parameter with a C++ type, parameter name, and no default
/// value.
template <typename TypeT, typename NameT>
MethodParameter(TypeT &&type, NameT &&name, bool optional = false)
: MethodParameter(std::forward<TypeT>(type), std::forward<NameT>(name),
/*defaultValue=*/"", optional) {}
/// Write the parameter as part of a method declaration.
void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
/// Write the parameter as part of a method definition.
void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
/// Get the C++ type.
const std::string &getType() const { return type; }
/// Returns true if the parameter has a default value.
bool hasDefaultValue() const { return !defaultValue.empty(); }
private:
void writeTo(raw_ostream &os, bool emitDefault) const;
/// The C++ type.
std::string type;
/// The variable name.
std::string name;
/// An optional default value. The default value exists if the string is not
/// empty.
std::string defaultValue;
/// Whether the parameter should be indicated as "optional".
bool optional;
};
/// This class contains a list of method parameters for constructor, class
/// methods, and method signatures.
class MethodParameters {
public:
/// Create a list of method parameters.
MethodParameters(std::initializer_list<MethodParameter> parameters)
: parameters(parameters) {}
MethodParameters(SmallVector<MethodParameter> parameters)
: parameters(std::move(parameters)) {}
/// Write the parameters as part of a method declaration.
void writeDeclTo(raw_ostream &os) const;
/// Write the parameters as part of a method definition.
void writeDefTo(raw_ostream &os) const;
/// Determine whether this list of parameters "subsumes" another, which occurs
/// when this parameter list is identical to the other and has zero or more
/// additional default-valued parameters.
bool subsumes(const MethodParameters &other) const;
/// Return the number of parameters.
unsigned getNumParameters() const { return parameters.size(); }
private:
llvm::SmallVector<MethodParameter> parameters;
};
/// This class contains the signature of a C++ method, including the return
/// type. method name, and method parameters.
class MethodSignature {
public:
MethodSignature(StringRef retType, StringRef name,
SmallVector<MethodParameter> &&parameters)
: returnType(retType), methodName(name),
parameters(std::move(parameters)) {}
template <typename... Parameters>
MethodSignature(StringRef retType, StringRef name, Parameters &&...parameters)
: returnType(retType), methodName(name),
parameters({std::forward<Parameters>(parameters)...}) {}
/// Determine whether a method with this signature makes a method with
/// `other` signature redundant. This occurs if the signatures have the same
/// name and this signature's parameteres subsume the other's.
///
/// A method that makes another method redundant with a different return type
/// can replace the other, the assumption being that the subsuming method
/// provides a more resolved return type, e.g. IntegerAttr vs. Attribute.
bool makesRedundant(const MethodSignature &other) const;
/// Get the name of the method.
StringRef getName() const { return methodName; }
/// Get the number of parameters.
unsigned getNumParameters() const { return parameters.getNumParameters(); }
/// Write the signature as part of a method declaration.
void writeDeclTo(raw_ostream &os) const;
/// Write the signature as part of a method definition. `namePrefix` is to be
/// prepended to the method name (typically namespaces for qualifying the
/// method definition).
void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
private:
/// The method's C++ return type.
std::string returnType;
/// The method name.
std::string methodName;
/// The method's parameter list.
MethodParameters parameters;
};
/// Class for holding the body of an op's method for C++ code emission
class MethodBody {
public:
explicit MethodBody(bool declOnly);
MethodBody &operator<<(Twine content);
MethodBody &operator<<(int content);
MethodBody &operator<<(const FmtObjectBase &content);
void writeTo(raw_ostream &os) const;
private:
/// Whether this class should record method body.
bool isEffective;
/// The body of the method.
std::string body;
};
/// Class for holding an op's method for C++ code emission
class Method {
public:
/// Properties (qualifiers) of class methods. Bitfield is used here to help
/// querying properties.
enum Property {
MP_None = 0x0,
MP_Static = 0x1,
MP_Constructor = 0x2,
MP_Private = 0x4,
MP_Declaration = 0x8,
MP_Inline = 0x10,
MP_Constexpr = 0x20 | MP_Inline,
MP_StaticDeclaration = MP_Static | MP_Declaration,
};
template <typename... Args>
Method(StringRef retType, StringRef name, Property property, Args &&...args)
: properties(property),
methodSignature(retType, name, std::forward<Args>(args)...),
methodBody(properties & MP_Declaration) {}
Method(Method &&) = default;
Method &operator=(Method &&) = default;
virtual ~Method() = default;
MethodBody &body() { return methodBody; }
/// Returns true if this is a static method.
bool isStatic() const { return properties & MP_Static; }
/// Returns true if this is a private method.
bool isPrivate() const { return properties & MP_Private; }
/// Returns true if this is an inline method.
bool isInline() const { return properties & MP_Inline; }
/// Returns the name of this method.
StringRef getName() const { return methodSignature.getName(); }
/// Returns if this method makes the `other` method redundant.
bool makesRedundant(const Method &other) const {
return methodSignature.makesRedundant(other.methodSignature);
}
/// Writes the method as a declaration to the given `os`.
virtual void writeDeclTo(raw_ostream &os) const;
/// Writes the method as a definition to the given `os`. `namePrefix` is the
/// prefix to be prepended to the method name (typically namespaces for
/// qualifying the method definition).
virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
protected:
/// A collection of method properties.
Property properties;
/// The signature of the method.
MethodSignature methodSignature;
/// The body of the method, if it has one.
MethodBody methodBody;
};
} // end namespace tblgen
} // end namespace mlir
/// The OR of two method properties should return method properties. Ensure that
/// this function is visible to `Class`.
inline constexpr mlir::tblgen::Method::Property
operator|(mlir::tblgen::Method::Property lhs,
mlir::tblgen::Method::Property rhs) {
return mlir::tblgen::Method::Property(static_cast<unsigned>(lhs) |
static_cast<unsigned>(rhs));
}
namespace mlir {
namespace tblgen {
/// Class for holding an op's constructor method for C++ code emission.
class Constructor : public Method {
public:
template <typename... Parameters>
Constructor(StringRef className, Property property,
Parameters &&...parameters)
: Method("", className, property,
std::forward<Parameters>(parameters)...) {}
/// Add member initializer to constructor initializing `name` with `value`.
void addMemberInitializer(StringRef name, StringRef value);
/// Writes the method as a definition to the given `os`. `namePrefix` is the
/// prefix to be prepended to the method name (typically namespaces for
/// qualifying the method definition).
void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
private:
/// Member initializers.
std::string memberInitializers;
};
/// A class used to emit C++ classes from Tablegen. Contains a list of public
/// methods and a list of private fields to be emitted.
class Class {
public:
explicit Class(StringRef name);
/// Add a new constructor to this class and prune and constructors made
/// redundant by it. Returns null if the constructor was not added. Else,
/// returns a pointer to the new constructor.
template <typename... Parameters>
Constructor *addConstructorAndPrune(Parameters &&...parameters) {
return addConstructorAndPrune(
Constructor(getClassName(), Method::MP_Constructor,
std::forward<Parameters>(parameters)...));
}
/// Add a new method to this class and prune any methods made redundant by it.
/// Returns null if the method was not added (because an existing method would
/// make it redundant). Else, returns a pointer to the new method.
template <typename... Parameters>
Method *addMethod(StringRef retType, StringRef name,
Method::Property properties, Parameters &&...parameters) {
return addMethodAndPrune(Method(retType, name, properties,
std::forward<Parameters>(parameters)...));
}
/// Add a method with statically-known properties.
template <Method::Property Properties = Method::MP_None,
typename... Parameters>
Method *addMethod(StringRef retType, StringRef name,
Parameters &&...parameters) {
return addMethod(retType, name, Properties,
std::forward<Parameters>(parameters)...);
}
/// Add a static method.
template <Method::Property Properties = Method::MP_None,
typename... Parameters>
Method *addStaticMethod(StringRef retType, StringRef name,
Parameters &&...parameters) {
return addMethod<Properties | Method::MP_Static>(
retType, name, std::forward<Parameters>(parameters)...);
}
/// Add an inline static method.
template <Method::Property Properties = Method::MP_None,
typename... Parameters>
Method *addStaticInlineMethod(StringRef retType, StringRef name,
Parameters &&...parameters) {
return addMethod<Properties | Method::MP_Static | Method::MP_Inline>(
retType, name, std::forward<Parameters>(parameters)...);
}
/// Add an inline method.
template <Method::Property Properties = Method::MP_None,
typename... Parameters>
Method *addInlineMethod(StringRef retType, StringRef name,
Parameters &&...parameters) {
return addMethod<Properties | Method::MP_Inline>(
retType, name, std::forward<Parameters>(parameters)...);
}
/// Add a declaration for a method.
template <Method::Property Properties = Method::MP_None,
typename... Parameters>
Method *declareMethod(StringRef retType, StringRef name,
Parameters &&...parameters) {
return addMethod<Properties | Method::MP_Declaration>(
retType, name, std::forward<Parameters>(parameters)...);
}
/// Add a declaration for a static method.
template <Method::Property Properties = Method::MP_None,
typename... Parameters>
Method *declareStaticMethod(StringRef retType, StringRef name,
Parameters &&...parameters) {
return addMethod<Properties | Method::MP_StaticDeclaration>(
retType, name, std::forward<Parameters>(parameters)...);
}
/// Creates a new field in this class.
void newField(StringRef type, StringRef name, StringRef defaultValue = "");
/// Writes this op's class as a declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const;
/// Writes the method definitions in this op's class to the given `os`.
void writeDefTo(raw_ostream &os) const;
/// Returns the C++ class name of the op.
StringRef getClassName() const { return className; }
protected:
/// Get a list of all the methods to emit, filtering out hidden ones.
void forAllMethods(llvm::function_ref<void(const Method &)> func) const {
llvm::for_each(constructors, [&](auto &ctor) { func(ctor); });
llvm::for_each(methods, [&](auto &method) { func(method); });
}
/// Add a new constructor if it is not made redundant by any existing
/// constructors and prune and existing constructors made redundant.
Constructor *addConstructorAndPrune(Constructor &&newCtor);
/// Add a new method if it is not made redundant by any existing methods and
/// prune and existing methods made redundant.
Method *addMethodAndPrune(Method &&newMethod);
/// The C++ class name.
std::string className;
/// The list of constructors.
std::vector<Constructor> constructors;
/// The list of class methods.
std::vector<Method> methods;
/// The list of class members.
SmallVector<std::string, 4> fields;
};
// Class for holding an op for C++ code emission
class OpClass : public Class {
public:
explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
/// Adds an op trait.
void addTrait(Twine trait);
/// Writes this op's class as a declaration to the given `os`. Redefines
/// Class::writeDeclTo to also emit traits and extra class declarations.
void writeDeclTo(raw_ostream &os) const;
private:
StringRef extraClassDeclaration;
llvm::SetVector<std::string, SmallVector<std::string>, StringSet<>> traits;
};
} // namespace tblgen
} // namespace mlir
#endif // MLIR_TABLEGEN_CLASS_H_

View File

@ -216,9 +216,28 @@ private:
ConstraintMap regionConstraints;
};
// Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar.
/// Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar.
std::string escapeString(StringRef value);
namespace detail {
template <typename> struct stringifier {
template <typename T> static std::string apply(T &&t) {
return std::string(std::forward<T>(t));
}
};
template <> struct stringifier<Twine> {
static std::string apply(const Twine &twine) {
return twine.str();
}
};
} // end namespace detail
/// Generically convert a value to a std::string.
template <typename T> std::string stringify(T &&t) {
return detail::stringifier<std::remove_reference_t<std::remove_const_t<T>>>::
apply(std::forward<T>(t));
}
} // namespace tblgen
} // namespace mlir

View File

@ -1,442 +0,0 @@
//===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines several classes for Op C++ code emission. They are only
// expected to be used by MLIR TableGen backends.
//
// We emit the op declaration and definition into separate files: *Ops.h.inc
// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
//
// In order to do this split, we need to track method signature and
// implementation logic separately. Signature information is used for both
// declaration and definition, while implementation logic is only for
// definition. So we have the following classes for C++ code emission.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_OPCLASS_H_
#define MLIR_TABLEGEN_OPCLASS_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/raw_ostream.h"
#include <set>
#include <string>
namespace mlir {
namespace tblgen {
class FmtObjectBase;
// Class for holding a single parameter of an op's method for C++ code emission.
class OpMethodParameter {
public:
// Properties (qualifiers) for the parameter.
enum Property {
PP_None = 0x0,
PP_Optional = 0x1,
};
OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "",
Property properties = PP_None)
: type(type), name(name), defaultValue(defaultValue),
properties(properties) {}
OpMethodParameter(StringRef type, StringRef name, Property property)
: OpMethodParameter(type, name, "", property) {}
// Writes the parameter as a part of a method declaration to `os`.
void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
// Writes the parameter as a part of a method definition to `os`
void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
const std::string &getType() const { return type; }
bool hasDefaultValue() const { return !defaultValue.empty(); }
private:
void writeTo(raw_ostream &os, bool emitDefault) const;
std::string type;
std::string name;
std::string defaultValue;
Property properties;
};
// Base class for holding parameters of an op's method for C++ code emission.
class OpMethodParameters {
public:
// Discriminator for LLVM-style RTTI.
enum ParamsKind {
// Separate type and name for each parameter is not known.
PK_Unresolved,
// Each parameter is resolved to a type and name.
PK_Resolved,
};
OpMethodParameters(ParamsKind kind) : kind(kind) {}
virtual ~OpMethodParameters() {}
// LLVM-style RTTI support.
ParamsKind getKind() const { return kind; }
// Writes the parameters as a part of a method declaration to `os`.
virtual void writeDeclTo(raw_ostream &os) const = 0;
// Writes the parameters as a part of a method definition to `os`
virtual void writeDefTo(raw_ostream &os) const = 0;
// Factory methods to create the correct type of `OpMethodParameters`
// object based on the arguments.
static std::unique_ptr<OpMethodParameters> create();
static std::unique_ptr<OpMethodParameters> create(StringRef params);
static std::unique_ptr<OpMethodParameters>
create(llvm::SmallVectorImpl<OpMethodParameter> &&params);
static std::unique_ptr<OpMethodParameters>
create(StringRef type, StringRef name, StringRef defaultValue = "");
private:
const ParamsKind kind;
};
// Class for holding unresolved parameters.
class OpMethodUnresolvedParameters : public OpMethodParameters {
public:
OpMethodUnresolvedParameters(StringRef params)
: OpMethodParameters(PK_Unresolved), parameters(params) {}
// write the parameters as a part of a method declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const override;
// write the parameters as a part of a method definition to the given `os`
void writeDefTo(raw_ostream &os) const override;
// LLVM-style RTTI support.
static bool classof(const OpMethodParameters *params) {
return params->getKind() == PK_Unresolved;
}
private:
std::string parameters;
};
// Class for holding resolved parameters.
class OpMethodResolvedParameters : public OpMethodParameters {
public:
OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {}
OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> &&params)
: OpMethodParameters(PK_Resolved) {
for (OpMethodParameter &param : params)
parameters.emplace_back(std::move(param));
}
OpMethodResolvedParameters(StringRef type, StringRef name,
StringRef defaultValue)
: OpMethodParameters(PK_Resolved) {
parameters.emplace_back(type, name, defaultValue);
}
// Returns the number of parameters.
size_t getNumParameters() const { return parameters.size(); }
// Returns if this method makes the `other` method redundant. Note that this
// is more than just finding conflicting methods. This method determines if
// the 2 set of parameters are conflicting and if so, returns true if this
// method has a more general set of parameters that can replace all possible
// calls to the `other` method.
bool makesRedundant(const OpMethodResolvedParameters &other) const;
// write the parameters as a part of a method declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const override;
// write the parameters as a part of a method definition to the given `os`
void writeDefTo(raw_ostream &os) const override;
// LLVM-style RTTI support.
static bool classof(const OpMethodParameters *params) {
return params->getKind() == PK_Resolved;
}
private:
llvm::SmallVector<OpMethodParameter, 4> parameters;
};
// Class for holding the signature of an op's method for C++ code emission
class OpMethodSignature {
public:
template <typename... Args>
OpMethodSignature(StringRef retType, StringRef name, Args &&...args)
: returnType(retType), methodName(name),
parameters(OpMethodParameters::create(std::forward<Args>(args)...)) {}
OpMethodSignature(OpMethodSignature &&) = default;
// Returns if a method with this signature makes a method with `other`
// signature redundant. Only supports resolved parameters.
bool makesRedundant(const OpMethodSignature &other) const;
// Returns the number of parameters (for resolved parameters).
size_t getNumParameters() const {
return cast<OpMethodResolvedParameters>(parameters.get())
->getNumParameters();
}
// Returns the name of the method.
StringRef getName() const { return methodName; }
// Writes the signature as a method declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const;
// Writes the signature as the start of a method definition to the given `os`.
// `namePrefix` is the prefix to be prepended to the method name (typically
// namespaces for qualifying the method definition).
void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
private:
std::string returnType;
std::string methodName;
std::unique_ptr<OpMethodParameters> parameters;
};
// Class for holding the body of an op's method for C++ code emission
class OpMethodBody {
public:
explicit OpMethodBody(bool declOnly);
OpMethodBody &operator<<(Twine content);
OpMethodBody &operator<<(int content);
OpMethodBody &operator<<(const FmtObjectBase &content);
void writeTo(raw_ostream &os) const;
private:
// Whether this class should record method body.
bool isEffective;
std::string body;
};
// Class for holding an op's method for C++ code emission
class OpMethod {
public:
// Properties (qualifiers) of class methods. Bitfield is used here to help
// querying properties.
enum Property {
MP_None = 0x0,
MP_Static = 0x1,
MP_Constructor = 0x2,
MP_Private = 0x4,
MP_Declaration = 0x8,
MP_Inline = 0x10,
MP_Constexpr = 0x20 | MP_Inline,
MP_StaticDeclaration = MP_Static | MP_Declaration,
};
template <typename... Args>
OpMethod(StringRef retType, StringRef name, Property property, unsigned id,
Args &&...args)
: properties(property),
methodSignature(retType, name, std::forward<Args>(args)...),
methodBody(properties & MP_Declaration), id(id) {}
OpMethod(OpMethod &&) = default;
virtual ~OpMethod() = default;
OpMethodBody &body() { return methodBody; }
// Returns true if this is a static method.
bool isStatic() const { return properties & MP_Static; }
// Returns true if this is a private method.
bool isPrivate() const { return properties & MP_Private; }
// Returns true if this is an inline method.
bool isInline() const { return properties & MP_Inline; }
// Returns the name of this method.
StringRef getName() const { return methodSignature.getName(); }
// Returns the ID for this method
unsigned getID() const { return id; }
// Returns if this method makes the `other` method redundant.
bool makesRedundant(const OpMethod &other) const {
return methodSignature.makesRedundant(other.methodSignature);
}
// Writes the method as a declaration to the given `os`.
virtual void writeDeclTo(raw_ostream &os) const;
// Writes the method as a definition to the given `os`. `namePrefix` is the
// prefix to be prepended to the method name (typically namespaces for
// qualifying the method definition).
virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
protected:
Property properties;
OpMethodSignature methodSignature;
OpMethodBody methodBody;
const unsigned id;
};
// Class for holding an op's constructor method for C++ code emission.
class OpConstructor : public OpMethod {
public:
template <typename... Args>
OpConstructor(StringRef className, Property property, unsigned id,
Args &&...args)
: OpMethod("", className, property, id, std::forward<Args>(args)...) {}
// Add member initializer to constructor initializing `name` with `value`.
void addMemberInitializer(StringRef name, StringRef value);
// Writes the method as a definition to the given `os`. `namePrefix` is the
// prefix to be prepended to the method name (typically namespaces for
// qualifying the method definition).
void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
private:
// Member initializers.
std::string memberInitializers;
};
// A class used to emit C++ classes from Tablegen. Contains a list of public
// methods and a list of private fields to be emitted.
class Class {
public:
explicit Class(StringRef name);
// Adds a new method to this class and prune redundant methods. Returns null
// if the method was not added (because an existing method would make it
// redundant), else returns a pointer to the added method. Note that this call
// may also delete existing methods that are made redundant by a method to the
// class.
template <typename... Args>
OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
OpMethod::Property properties, Args &&...args) {
auto newMethod = std::make_unique<OpMethod>(
retType, name, properties, nextMethodID++, std::forward<Args>(args)...);
return addMethodAndPrune(methods, std::move(newMethod));
}
template <typename... Args>
OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
Args &&...args) {
return addMethodAndPrune(retType, name, OpMethod::MP_None,
std::forward<Args>(args)...);
}
template <typename... Args>
OpConstructor *addConstructorAndPrune(Args &&...args) {
auto newConstructor = std::make_unique<OpConstructor>(
getClassName(), OpMethod::MP_Constructor, nextMethodID++,
std::forward<Args>(args)...);
return addMethodAndPrune(constructors, std::move(newConstructor));
}
// Creates a new field in this class.
void newField(StringRef type, StringRef name, StringRef defaultValue = "");
// Writes this op's class as a declaration to the given `os`.
void writeDeclTo(raw_ostream &os) const;
// Writes the method definitions in this op's class to the given `os`.
void writeDefTo(raw_ostream &os) const;
// Returns the C++ class name of the op.
StringRef getClassName() const { return className; }
protected:
// Get a list of all the methods to emit, filtering out hidden ones.
void forAllMethods(llvm::function_ref<void(const OpMethod &)> func) const {
using ConsRef = const std::unique_ptr<OpConstructor> &;
using MethodRef = const std::unique_ptr<OpMethod> &;
llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); });
llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); });
}
// For deterministic code generation, keep methods sorted in the order in
// which they were generated.
template <typename MethodTy>
struct MethodCompare {
bool operator()(const std::unique_ptr<MethodTy> &x,
const std::unique_ptr<MethodTy> &y) const {
return x->getID() < y->getID();
}
};
template <typename MethodTy>
using MethodSet =
std::set<std::unique_ptr<MethodTy>, MethodCompare<MethodTy>>;
template <typename MethodTy>
MethodTy *addMethodAndPrune(MethodSet<MethodTy> &set,
std::unique_ptr<MethodTy> &&newMethod) {
// Check if the new method will be made redundant by existing methods.
for (auto &method : set)
if (method->makesRedundant(*newMethod))
return nullptr;
// We can add this a method to the set. Prune any existing methods that will
// be made redundant by adding this new method. Note that the redundant
// check between two methods is more than a conflict check. makesRedundant()
// below will check if the new method conflicts with an existing method and
// if so, returns true if the new method makes the existing method redundant
// because all calls to the existing method can be subsumed by the new
// method. So makesRedundant() does a combined job of finding conflicts and
// deciding which of the 2 conflicting methods survive.
//
// Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing
// it manually here.
for (auto it = set.begin(), end = set.end(); it != end;) {
if (newMethod->makesRedundant(*(it->get())))
it = set.erase(it);
else
++it;
}
MethodTy *ret = newMethod.get();
set.insert(std::move(newMethod));
return ret;
}
std::string className;
MethodSet<OpConstructor> constructors;
MethodSet<OpMethod> methods;
unsigned nextMethodID = 0;
SmallVector<std::string, 4> fields;
};
// Class for holding an op for C++ code emission
class OpClass : public Class {
public:
explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
// Adds an op trait.
void addTrait(Twine trait);
// Writes this op's class as a declaration to the given `os`. Redefines
// Class::writeDeclTo to also emit traits and extra class declarations.
void writeDeclTo(raw_ostream &os) const;
private:
StringRef extraClassDeclaration;
SmallVector<std::string, 4> traitsVec;
StringSet<> traitsSet;
};
} // namespace tblgen
} // namespace mlir
#endif // MLIR_TABLEGEN_OPCLASS_H_

View File

@ -13,12 +13,12 @@ llvm_add_library(MLIRTableGen STATIC
Attribute.cpp
AttrOrTypeDef.cpp
Builder.cpp
Class.cpp
Constraint.cpp
Dialect.cpp
Format.cpp
Interfaces.cpp
Operator.cpp
OpClass.cpp
Pass.cpp
Pattern.cpp
Predicate.cpp

View File

@ -1,4 +1,4 @@
//===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
//===- Class.cpp - Helper classes for Op C++ code emission --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/Format.h"
#include "llvm/ADT/Sequence.h"
@ -20,173 +20,102 @@
using namespace mlir;
using namespace mlir::tblgen;
namespace {
// Returns space to be emitted after the given C++ `type`. return "" if the
// ends with '&' or '*', or is empty, else returns " ".
StringRef getSpaceAfterType(StringRef type) {
static StringRef getSpaceAfterType(StringRef type) {
return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
}
} // namespace
//===----------------------------------------------------------------------===//
// OpMethodParameter definitions
// MethodParameter definitions
//===----------------------------------------------------------------------===//
void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
if (properties & PP_Optional)
void MethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
if (optional)
os << "/*optional*/";
os << type << getSpaceAfterType(type) << name;
if (emitDefault && !defaultValue.empty())
if (emitDefault && hasDefaultValue())
os << " = " << defaultValue;
}
//===----------------------------------------------------------------------===//
// OpMethodParameters definitions
// MethodParameters definitions
//===----------------------------------------------------------------------===//
// Factory methods to construct the correct type of `OpMethodParameters`
// object based on the arguments.
std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
return std::make_unique<OpMethodResolvedParameters>();
void MethodParameters::writeDeclTo(raw_ostream &os) const {
llvm::interleaveComma(parameters, os,
[&os](auto &param) { param.writeDeclTo(os); });
}
void MethodParameters::writeDefTo(raw_ostream &os) const {
llvm::interleaveComma(parameters, os,
[&os](auto &param) { param.writeDefTo(os); });
}
std::unique_ptr<OpMethodParameters>
OpMethodParameters::create(StringRef params) {
return std::make_unique<OpMethodUnresolvedParameters>(params);
}
std::unique_ptr<OpMethodParameters>
OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &&params) {
return std::make_unique<OpMethodResolvedParameters>(std::move(params));
}
std::unique_ptr<OpMethodParameters>
OpMethodParameters::create(StringRef type, StringRef name,
StringRef defaultValue) {
return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
}
//===----------------------------------------------------------------------===//
// OpMethodUnresolvedParameters definitions
//===----------------------------------------------------------------------===//
void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
os << parameters;
}
void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
// We need to remove the default values for parameters in method definition.
// TODO: We are using '=' and ',' as delimiters for parameter
// initializers. This is incorrect for initializer list with more than one
// element. Change to a more robust approach.
llvm::SmallVector<StringRef, 4> tokens;
StringRef params = parameters;
while (!params.empty()) {
std::pair<StringRef, StringRef> parts = params.split("=");
tokens.push_back(parts.first);
params = parts.second.split(',').second;
}
llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; });
}
//===----------------------------------------------------------------------===//
// OpMethodResolvedParameters definitions
//===----------------------------------------------------------------------===//
// Returns true if a method with these parameters makes a method with parameters
// `other` redundant. This should return true only if all possible calls to the
// other method can be replaced by calls to this method.
bool OpMethodResolvedParameters::makesRedundant(
const OpMethodResolvedParameters &other) const {
const size_t otherNumParams = other.getNumParameters();
const size_t thisNumParams = getNumParameters();
// All calls to the other method can be replaced this method only if this
// method has the same or more arguments number of arguments as the other, and
// the common arguments have the same type.
if (thisNumParams < otherNumParams)
bool MethodParameters::subsumes(const MethodParameters &other) const {
// These parameters do not subsume the others if there are fewer parameters
// or their types do not match.
if (parameters.size() < other.parameters.size())
return false;
for (int idx : llvm::seq<int>(0, otherNumParams))
if (parameters[idx].getType() != other.parameters[idx].getType())
return false;
// If all the common arguments have the same type, we can elide the other
// method if this method has the same number of arguments as other or the
// first argument after the common ones has a default value (and by C++
// requirement, all the later ones will also have a default value).
return thisNumParams == otherNumParams ||
parameters[otherNumParams].hasDefaultValue();
}
void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
param.writeDeclTo(os);
});
}
void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
param.writeDefTo(os);
});
}
//===----------------------------------------------------------------------===//
// OpMethodSignature definitions
//===----------------------------------------------------------------------===//
// Returns if a method with this signature makes a method with `other` signature
// redundant. Only supports resolved parameters.
bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
if (methodName != other.methodName)
if (!std::equal(
other.parameters.begin(), other.parameters.end(), parameters.begin(),
[](auto &lhs, auto &rhs) { return lhs.getType() == rhs.getType(); }))
return false;
auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get());
auto *resolvedOther =
dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
if (resolvedThis && resolvedOther)
return resolvedThis->makesRedundant(*resolvedOther);
return false;
// If all the common parameters have the same type, we can elide the other
// method if this method has the same number of parameters as other or if the
// first paramater after the common parameters has a default value (and, as
// required by C++, subsequent parameters will have default values too).
return parameters.size() == other.parameters.size() ||
parameters[other.parameters.size()].hasDefaultValue();
}
void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
//===----------------------------------------------------------------------===//
// MethodSignature definitions
//===----------------------------------------------------------------------===//
bool MethodSignature::makesRedundant(const MethodSignature &other) const {
return methodName == other.methodName &&
parameters.subsumes(other.parameters);
}
void MethodSignature::writeDeclTo(raw_ostream &os) const {
os << returnType << getSpaceAfterType(returnType) << methodName << "(";
parameters->writeDeclTo(os);
parameters.writeDeclTo(os);
os << ")";
}
void OpMethodSignature::writeDefTo(raw_ostream &os,
StringRef namePrefix) const {
void MethodSignature::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
os << returnType << getSpaceAfterType(returnType) << namePrefix
<< (namePrefix.empty() ? "" : "::") << methodName << "(";
parameters->writeDefTo(os);
parameters.writeDefTo(os);
os << ")";
}
//===----------------------------------------------------------------------===//
// OpMethodBody definitions
// MethodBody definitions
//===----------------------------------------------------------------------===//
OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
MethodBody::MethodBody(bool declOnly) : isEffective(!declOnly) {}
OpMethodBody &OpMethodBody::operator<<(Twine content) {
MethodBody &MethodBody::operator<<(Twine content) {
if (isEffective)
body.append(content.str());
return *this;
}
OpMethodBody &OpMethodBody::operator<<(int content) {
MethodBody &MethodBody::operator<<(int content) {
if (isEffective)
body.append(std::to_string(content));
return *this;
}
OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
MethodBody &MethodBody::operator<<(const FmtObjectBase &content) {
if (isEffective)
body.append(content.str());
return *this;
}
void OpMethodBody::writeTo(raw_ostream &os) const {
void MethodBody::writeTo(raw_ostream &os) const {
auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
os << bodyRef;
if (bodyRef.empty() || bodyRef.back() != '\n')
@ -194,10 +123,10 @@ void OpMethodBody::writeTo(raw_ostream &os) const {
}
//===----------------------------------------------------------------------===//
// OpMethod definitions
// Method definitions
//===----------------------------------------------------------------------===//
void OpMethod::writeDeclTo(raw_ostream &os) const {
void Method::writeDeclTo(raw_ostream &os) const {
os.indent(2);
if (isStatic())
os << "static ";
@ -213,7 +142,7 @@ void OpMethod::writeDeclTo(raw_ostream &os) const {
}
}
void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
void Method::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
// Do not write definition if the method is decl only.
if (properties & MP_Declaration)
return;
@ -227,15 +156,15 @@ void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
}
//===----------------------------------------------------------------------===//
// OpConstructor definitions
// Constructor definitions
//===----------------------------------------------------------------------===//
void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
void Constructor::addMemberInitializer(StringRef name, StringRef value) {
memberInitializers.append(std::string(llvm::formatv(
"{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
}
void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
void Constructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
// Do not write definition if the method is decl only.
if (properties & MP_Declaration)
return;
@ -243,7 +172,7 @@ void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
methodSignature.writeDefTo(os, namePrefix);
os << " " << memberInitializers << " {\n";
methodBody.writeTo(os);
os << "}";
os << "}\n";
}
//===----------------------------------------------------------------------===//
@ -259,12 +188,13 @@ void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
: formatv("{0} = {1}", varName, defaultValue).str();
fields.push_back(std::move(field));
}
void Class::writeDeclTo(raw_ostream &os) const {
bool hasPrivateMethod = false;
os << "class " << className << " {\n";
os << "public:\n";
forAllMethods([&](const OpMethod &method) {
forAllMethods([&](const Method &method) {
if (!method.isPrivate()) {
method.writeDeclTo(os);
os << '\n';
@ -276,7 +206,7 @@ void Class::writeDeclTo(raw_ostream &os) const {
os << '\n';
os << "private:\n";
if (hasPrivateMethod) {
forAllMethods([&](const OpMethod &method) {
forAllMethods([&](const Method &method) {
if (method.isPrivate()) {
method.writeDeclTo(os);
os << '\n';
@ -291,12 +221,35 @@ void Class::writeDeclTo(raw_ostream &os) const {
}
void Class::writeDefTo(raw_ostream &os) const {
forAllMethods([&](const OpMethod &method) {
forAllMethods([&](const Method &method) {
method.writeDefTo(os, className);
os << "\n\n";
os << "\n";
});
}
// Insert a new method into a list of methods, if it would not be pruned, and
// prune and existing methods.
template <typename ContainerT, typename MethodT>
MethodT *insertAndPrune(ContainerT &methods, MethodT newMethod) {
if (llvm::any_of(methods, [&](auto &method) {
return method.makesRedundant(newMethod);
}))
return nullptr;
llvm::erase_if(
methods, [&](auto &method) { return newMethod.makesRedundant(method); });
methods.push_back(std::move(newMethod));
return &methods.back();
}
Method *Class::addMethodAndPrune(Method &&newMethod) {
return insertAndPrune(methods, std::move(newMethod));
}
Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) {
return insertAndPrune(constructors, std::move(newCtor));
}
//===----------------------------------------------------------------------===//
// OpClass definitions
//===----------------------------------------------------------------------===//
@ -304,15 +257,11 @@ void Class::writeDefTo(raw_ostream &os) const {
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
: Class(name), extraClassDeclaration(extraClassDeclaration) {}
void OpClass::addTrait(Twine trait) {
auto traitStr = trait.str();
if (traitsSet.insert(traitStr).second)
traitsVec.push_back(std::move(traitStr));
}
void OpClass::addTrait(Twine trait) { traits.insert(trait.str()); }
void OpClass::writeDeclTo(raw_ostream &os) const {
os << "class " << className << " : public ::mlir::Op<" << className;
for (const auto &trait : traitsVec)
for (const auto &trait : traits)
os << ", " << trait;
os << "> {\npublic:\n"
<< " using Op::Op;\n"
@ -320,7 +269,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
<< " using Adaptor = " << className << "Adaptor;\n";
bool hasPrivateMethod = false;
forAllMethods([&](const OpMethod &method) {
forAllMethods([&](const Method &method) {
if (!method.isPrivate()) {
method.writeDeclTo(os);
os << "\n";
@ -335,7 +284,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
if (hasPrivateMethod) {
os << "\nprivate:\n";
forAllMethods([&](const OpMethod &method) {
forAllMethods([&](const Method &method) {
if (method.isPrivate()) {
method.writeDeclTo(os);
os << "\n";

View File

@ -10,11 +10,11 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/Optional.h"

View File

@ -13,11 +13,11 @@
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
#include "mlir/TableGen/Trait.h"
@ -361,7 +361,7 @@ private:
// types. `inferredAttributes` is populated with any attributes that are
// elided from the build list. The given `typeParamKind` and `attrParamKind`
// controls how result types and attributes are placed in the parameter list.
void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> &paramList,
void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
llvm::StringSet<> &inferredAttributes,
SmallVectorImpl<std::string> &resultTypeNames,
TypeParamKind typeParamKind,
@ -369,7 +369,7 @@ private:
// Adds op arguments and regions into operation state for build() methods.
void
genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
llvm::StringSet<> &inferredAttributes,
bool isRawValueAttr = false);
@ -390,17 +390,16 @@ private:
// Generates verify statements for operands and results in the operation.
// The generated code will be attached to `body`.
void genOperandResultVerifier(OpMethodBody &body,
Operator::value_range values,
void genOperandResultVerifier(MethodBody &body, Operator::value_range values,
StringRef valueKind);
// Generates verify statements for regions in the operation.
// The generated code will be attached to `body`.
void genRegionVerifier(OpMethodBody &body);
void genRegionVerifier(MethodBody &body);
// Generates verify statements for successors in the operation.
// The generated code will be attached to `body`.
void genSuccessorVerifier(OpMethodBody &body);
void genSuccessorVerifier(MethodBody &body);
// Generates the traits used by the object.
void genTraits();
@ -413,8 +412,8 @@ private:
// Generate op interface method for the given interface method. If
// 'declaration' is true, generates a declaration, else a definition.
OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
bool declaration = true);
Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
bool declaration = true);
// Generate the side effect interface methods.
void genSideEffectInterfaceMethods();
@ -470,7 +469,7 @@ static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
static void genAttributeVerifier(
const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, OpMethodBody &body,
const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
// Check that a required attribute exists.
//
@ -602,7 +601,7 @@ void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
static void errorIfPruned(size_t line, OpMethod *m, const Twine &methodName,
static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
const Operator &op) {
if (m)
return;
@ -627,18 +626,15 @@ void OpEmitter::genAttrNameGetters() {
for (const NamedAttribute &namedAttr : op.getAttributes())
addAttrName(namedAttr.name);
// Include key attributes from several traits as implicitly registered.
std::string operandSizes = "operand_segment_sizes";
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
addAttrName(operandSizes);
std::string attrSizes = "result_segment_sizes";
addAttrName("operand_segment_sizes");
if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
addAttrName(attrSizes);
addAttrName("result_segment_sizes");
// Emit the getAttributeNames method.
{
auto *method = opClass.addMethodAndPrune(
"::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames",
OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline));
auto *method = opClass.addStaticInlineMethod(
"::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
ERROR_IF_PRUNED(method, "getAttributeNames", op);
auto &body = method->body();
if (attributeNames.empty()) {
@ -658,20 +654,18 @@ void OpEmitter::genAttrNameGetters() {
// Emit the getAttributeNameForIndex methods.
{
auto *method = opClass.addMethodAndPrune(
auto *method = opClass.addInlineMethod<Method::MP_Private>(
"::mlir::Identifier", "getAttributeNameForIndex",
OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline),
"unsigned", "index");
MethodParameter("unsigned", "index"));
ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
method->body()
<< " return getAttributeNameForIndex((*this)->getName(), index);";
}
{
auto *method = opClass.addMethodAndPrune(
auto *method = opClass.addStaticInlineMethod<Method::MP_Private>(
"::mlir::Identifier", "getAttributeNameForIndex",
OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline |
OpMethod::MP_Static),
"::mlir::OperationName name, unsigned index");
MethodParameter("::mlir::OperationName", "name"),
MethodParameter("unsigned", "index"));
ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
method->body() << "assert(index < " << attributeNames.size()
<< " && \"invalid attribute index\");\n"
@ -689,8 +683,7 @@ void OpEmitter::genAttrNameGetters() {
// Generate the non-static variant.
{
auto *method =
opClass.addMethodAndPrune("::mlir::Identifier", methodName,
OpMethod::Property(OpMethod::MP_Inline));
opClass.addInlineMethod("::mlir::Identifier", methodName);
ERROR_IF_PRUNED(method, methodName, op);
method->body()
<< llvm::formatv(attrNameMethodBody, attrIt.second).str();
@ -698,10 +691,9 @@ void OpEmitter::genAttrNameGetters() {
// Generate the static variant.
{
auto *method = opClass.addMethodAndPrune(
auto *method = opClass.addStaticInlineMethod(
"::mlir::Identifier", methodName,
OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static),
"::mlir::OperationName", "name");
MethodParameter("::mlir::OperationName", "name"));
ERROR_IF_PRUNED(method, methodName, op);
method->body() << llvm::formatv(attrNameMethodBody,
"name, " + Twine(attrIt.second))
@ -717,13 +709,13 @@ void OpEmitter::genAttrGetters() {
// Emit the derived attribute body.
auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
if (auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name))
if (auto *method = opClass.addMethod(attr.getReturnType(), name))
method->body() << " " << attr.getDerivedCodeBody() << "\n";
};
// Emit with return type specified.
auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
auto *method = opClass.addMethod(attr.getReturnType(), name);
ERROR_IF_PRUNED(method, name, op);
auto &body = method->body();
body << " auto attr = " << name << "Attr();\n";
@ -748,7 +740,7 @@ void OpEmitter::genAttrGetters() {
// use the string interface for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
auto *method =
opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
opClass.addMethod(attr.getStorageType(), (name + "Attr").str());
if (!method)
return;
method->body() << formatv(
@ -773,68 +765,69 @@ void OpEmitter::genAttrGetters() {
[](const NamedAttribute &namedAttr) {
return namedAttr.attr.isDerivedAttr();
});
if (!derivedAttrs.empty()) {
opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
// Generate helper method to query whether a named attribute is a derived
// attribute. This enables, for example, avoiding adding an attribute that
// overlaps with a derived attribute.
{
auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
OpMethod::MP_Static,
"::llvm::StringRef", "name");
ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
auto &body = method->body();
for (auto namedAttr : derivedAttrs)
body << " if (name == \"" << namedAttr.name << "\") return true;\n";
body << " return false;";
}
// Generate method to materialize derived attributes as a DictionaryAttr.
{
auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
"materializeDerivedAttributes");
ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
auto &body = method->body();
if (derivedAttrs.empty())
return;
auto nonMaterializable =
make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
return namedAttr.attr.getConvertFromStorageCall().empty();
});
if (!nonMaterializable.empty()) {
std::string attrs;
llvm::raw_string_ostream os(attrs);
interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
os << op.getGetterName(attr.name);
opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
// Generate helper method to query whether a named attribute is a derived
// attribute. This enables, for example, avoiding adding an attribute that
// overlaps with a derived attribute.
{
auto *method =
opClass.addStaticMethod("bool", "isDerivedAttribute",
MethodParameter("::llvm::StringRef", "name"));
ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
auto &body = method->body();
for (auto namedAttr : derivedAttrs)
body << " if (name == \"" << namedAttr.name << "\") return true;\n";
body << " return false;";
}
// Generate method to materialize derived attributes as a DictionaryAttr.
{
auto *method = opClass.addMethod("::mlir::DictionaryAttr",
"materializeDerivedAttributes");
ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
auto &body = method->body();
auto nonMaterializable =
make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
return namedAttr.attr.getConvertFromStorageCall().empty();
});
PrintWarning(
op.getLoc(),
formatv(
"op has non-materializable derived attributes '{0}', skipping",
os.str()));
body << formatv(" emitOpError(\"op has non-materializable derived "
"attributes '{0}'\");\n",
attrs);
body << " return nullptr;";
return;
}
body << " ::mlir::MLIRContext* ctx = getContext();\n";
body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
body << " return ::mlir::DictionaryAttr::get(";
body << " ctx, {\n";
interleave(
derivedAttrs, body,
[&](const NamedAttribute &namedAttr) {
auto tmpl = namedAttr.attr.getConvertFromStorageCall();
std::string name = op.getGetterName(namedAttr.name);
body << " {" << name << "AttrName(),\n"
<< tgfmt(tmpl, &fctx.withSelf(name + "()")
.withBuilder("odsBuilder")
.addSubst("_ctx", "ctx"))
<< "}";
},
",\n");
body << "});";
if (!nonMaterializable.empty()) {
std::string attrs;
llvm::raw_string_ostream os(attrs);
interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
os << op.getGetterName(attr.name);
});
PrintWarning(
op.getLoc(),
formatv(
"op has non-materializable derived attributes '{0}', skipping",
os.str()));
body << formatv(" emitOpError(\"op has non-materializable derived "
"attributes '{0}'\");\n",
attrs);
body << " return nullptr;";
return;
}
body << " ::mlir::MLIRContext* ctx = getContext();\n";
body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
body << " return ::mlir::DictionaryAttr::get(";
body << " ctx, {\n";
interleave(
derivedAttrs, body,
[&](const NamedAttribute &namedAttr) {
auto tmpl = namedAttr.attr.getConvertFromStorageCall();
std::string name = op.getGetterName(namedAttr.name);
body << " {" << name << "AttrName(),\n"
<< tgfmt(tmpl, &fctx.withSelf(name + "()")
.withBuilder("odsBuilder")
.addSubst("_ctx", "ctx"))
<< "}";
},
",\n");
body << "});";
}
}
@ -844,19 +837,21 @@ void OpEmitter::genAttrSetters() {
// for better compile time verification.
auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
Attribute attr) {
auto *method = opClass.addMethodAndPrune(
"void", (setterName + "Attr").str(), attr.getStorageType(), "attr");
auto *method =
opClass.addMethod("void", (setterName + "Attr").str(),
MethodParameter(attr.getStorageType(), "attr"));
if (method)
method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);",
getterName);
};
for (const NamedAttribute &namedAttr : op.getAttributes()) {
if (!namedAttr.attr.isDerivedAttr())
for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
op.getGetterNames(namedAttr.name)))
emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
namedAttr.attr);
if (namedAttr.attr.isDerivedAttr())
continue;
for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
op.getGetterNames(namedAttr.name)))
emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
namedAttr.attr);
}
}
@ -866,7 +861,7 @@ void OpEmitter::genOptionalAttrRemovers() {
auto emitRemoveAttr = [&](StringRef name) {
auto upperInitial = name.take_front().upper();
auto suffix = name.drop_front();
auto *method = opClass.addMethodAndPrune(
auto *method = opClass.addMethod(
"::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
if (!method)
return;
@ -887,8 +882,8 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
int numVariadic, int numNonVariadic,
StringRef rangeSizeCall, bool hasAttrSegmentSize,
StringRef sizeAttrInit, RangeT &&odsValues) {
auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
methodName, "unsigned", "index");
auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
MethodParameter("unsigned", "index"));
if (!method)
return;
auto &body = method->body();
@ -900,7 +895,7 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
// Because the op can have arbitrarily interleaved variadic and non-variadic
// operands, we need to embed a list in the "sink" getter method for
// calculation at run-time.
llvm::SmallVector<StringRef, 4> isVariadic;
SmallVector<StringRef, 4> isVariadic;
isVariadic.reserve(llvm::size(odsValues));
for (auto &it : odsValues)
isVariadic.push_back(it.isVariableLength() ? "true" : "false");
@ -959,8 +954,8 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
rangeSizeCall, attrSizedOperands, sizeAttrInit,
const_cast<Operator &>(op).getOperands());
auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
"index");
auto *m = opClass.addMethod(rangeType, "getODSOperands",
MethodParameter("unsigned", "index"));
ERROR_IF_PRUNED(m, "getODSOperands", op);
auto &body = m->body();
body << formatv(valueRangeReturnCode, rangeBeginCall,
@ -974,7 +969,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
continue;
for (StringRef name : op.getGetterNames(operand.name)) {
if (operand.isOptional()) {
m = opClass.addMethodAndPrune("::mlir::Value", name);
m = opClass.addMethod("::mlir::Value", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " auto operands = getODSOperands(" << i << ");\n"
<< " return operands.empty() ? ::mlir::Value() : "
@ -983,24 +978,24 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
std::string segmentAttr = op.getGetterName(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
if (isAdaptor) {
m = opClass.addMethodAndPrune(
"::llvm::SmallVector<::mlir::ValueRange>", name);
m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>",
name);
ERROR_IF_PRUNED(m, name, op);
m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
segmentAttr, i);
continue;
}
m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", name);
m = opClass.addMethod("::mlir::OperandRangeRange", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSOperands(" << i << ").split("
<< segmentAttr << "Attr());";
} else if (operand.isVariadic()) {
m = opClass.addMethodAndPrune(rangeType, name);
m = opClass.addMethod(rangeType, name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSOperands(" << i << ");";
} else {
m = opClass.addMethodAndPrune("::mlir::Value", name);
m = opClass.addMethod("::mlir::Value", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSOperands(" << i << ").begin();";
}
@ -1035,10 +1030,10 @@ void OpEmitter::genNamedOperandSetters() {
if (operand.name.empty())
continue;
for (StringRef name : op.getGetterNames(operand.name)) {
auto *m = opClass.addMethodAndPrune(
operand.isVariadicOfVariadic() ? "::mlir::MutableOperandRangeRange"
: "::mlir::MutableOperandRange",
(name + "Mutable").str());
auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
? "::mlir::MutableOperandRangeRange"
: "::mlir::MutableOperandRange",
(name + "Mutable").str());
ERROR_IF_PRUNED(m, name, op);
auto &body = m->body();
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
@ -1110,8 +1105,9 @@ void OpEmitter::genNamedResultGetters() {
numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
attrSizeInitCode, op.getResults());
auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
"getODSResults", "unsigned", "index");
auto *m =
opClass.addMethod("::mlir::Operation::result_range", "getODSResults",
MethodParameter("unsigned", "index"));
ERROR_IF_PRUNED(m, "getODSResults", op);
m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
"getODSResultIndexAndLength(index)");
@ -1122,17 +1118,17 @@ void OpEmitter::genNamedResultGetters() {
continue;
for (StringRef name : op.getGetterNames(result.name)) {
if (result.isOptional()) {
m = opClass.addMethodAndPrune("::mlir::Value", name);
m = opClass.addMethod("::mlir::Value", name);
ERROR_IF_PRUNED(m, name, op);
m->body()
<< " auto results = getODSResults(" << i << ");\n"
<< " return results.empty() ? ::mlir::Value() : *results.begin();";
} else if (result.isVariadic()) {
m = opClass.addMethodAndPrune("::mlir::Operation::result_range", name);
m = opClass.addMethod("::mlir::Operation::result_range", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSResults(" << i << ");";
} else {
m = opClass.addMethodAndPrune("::mlir::Value", name);
m = opClass.addMethod("::mlir::Value", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSResults(" << i << ").begin();";
}
@ -1150,15 +1146,15 @@ void OpEmitter::genNamedRegionGetters() {
for (StringRef name : op.getGetterNames(region.name)) {
// Generate the accessors for a variadic region.
if (region.isVariadic()) {
auto *m = opClass.addMethodAndPrune(
"::mlir::MutableArrayRef<::mlir::Region>", name);
auto *m =
opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
i);
continue;
}
auto *m = opClass.addMethodAndPrune("::mlir::Region &", name);
auto *m = opClass.addMethod("::mlir::Region &", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getRegion({0});", i);
}
@ -1175,7 +1171,7 @@ void OpEmitter::genNamedSuccessorGetters() {
for (StringRef name : op.getGetterNames(successor.name)) {
// Generate the accessors for a variadic successor list.
if (successor.isVariadic()) {
auto *m = opClass.addMethodAndPrune("::mlir::SuccessorRange", name);
auto *m = opClass.addMethod("::mlir::SuccessorRange", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(
" return {std::next((*this)->successor_begin(), {0}), "
@ -1184,7 +1180,7 @@ void OpEmitter::genNamedSuccessorGetters() {
continue;
}
auto *m = opClass.addMethodAndPrune("::mlir::Block *", name);
auto *m = opClass.addMethod("::mlir::Block *", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getSuccessor({0});", i);
}
@ -1227,14 +1223,13 @@ void OpEmitter::genSeparateArgParamBuilder() {
// inferring result type.
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
bool inferType) {
llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<std::string, 4> resultNames;
SmallVector<MethodParameter> paramList;
SmallVector<std::string, 4> resultNames;
llvm::StringSet<> inferredAttributes;
buildParamList(paramList, inferredAttributes, resultNames, paramKind,
attrType);
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method.
if (!m)
return;
@ -1308,7 +1303,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
int numResults = op.getNumResults();
// Signature
llvm::SmallVector<OpMethodParameter, 4> paramList;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
@ -1319,8 +1314,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
@ -1348,14 +1342,13 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
// TODO: Expand to support regions.
SmallVector<OpMethodParameter, 4> paramList;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", "{}");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
@ -1407,14 +1400,13 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
}
void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<std::string, 4> resultNames;
SmallVector<MethodParameter> paramList;
SmallVector<std::string, 4> resultNames;
llvm::StringSet<> inferredAttributes;
buildParamList(paramList, inferredAttributes, resultNames,
TypeParamKind::None);
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
@ -1436,14 +1428,13 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
}
void OpEmitter::genUseAttrAsResultTypeBuilder() {
SmallVector<OpMethodParameter, 4> paramList;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", "{}");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
@ -1480,16 +1471,15 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
/// Returns a signature of the builder. Updates the context `fctx` to enable
/// replacement of $_builder and $_state in the body.
static std::string getBuilderSignature(const Builder &builder) {
static SmallVector<MethodParameter>
getBuilderSignature(const Builder &builder) {
ArrayRef<Builder::Parameter> params(builder.getParameters());
// Inject builder and state arguments.
llvm::SmallVector<std::string, 8> arguments;
SmallVector<MethodParameter> arguments;
arguments.reserve(params.size() + 2);
arguments.push_back(
llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str());
arguments.push_back(
llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
arguments.emplace_back("::mlir::OperationState &", builderOpState);
for (unsigned i = 0, e = params.size(); i < e; ++i) {
// If no name is provided, generate one.
@ -1497,27 +1487,27 @@ static std::string getBuilderSignature(const Builder &builder) {
std::string name =
paramName ? paramName->str() : "odsArg" + std::to_string(i);
std::string defaultValue;
StringRef defaultValue;
if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str();
arguments.push_back(
llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue)
.str());
defaultValue = *defaultParamValue;
arguments.emplace_back(params[i].getCppType(), std::move(name),
defaultValue);
}
return llvm::join(arguments, ", ");
return arguments;
}
void OpEmitter::genBuilder() {
// Handle custom builders if provided.
for (const Builder &builder : op.getBuilders()) {
std::string paramStr = getBuilderSignature(builder);
SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
Optional<StringRef> body = builder.getBody();
OpMethod::Property properties =
body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
Method::Property properties =
body ? Method::MP_Static : Method::MP_StaticDeclaration;
auto *method =
opClass.addMethodAndPrune("void", "build", properties, paramStr);
opClass.addMethod("void", "build", properties, std::move(arguments));
if (body)
ERROR_IF_PRUNED(method, "build", op);
@ -1561,7 +1551,7 @@ void OpEmitter::genCollectiveParamBuilder() {
int numVariadicOperands = op.getNumVariableLengthOperands();
int numNonVariadicOperands = numOperands - numVariadicOperands;
SmallVector<OpMethodParameter, 4> paramList;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::TypeRange", "resultTypes");
@ -1573,8 +1563,7 @@ void OpEmitter::genCollectiveParamBuilder() {
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)
return;
@ -1612,7 +1601,7 @@ void OpEmitter::genCollectiveParamBuilder() {
genInferredTypeCollectiveParamBuilder();
}
void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
llvm::StringSet<> &inferredAttributes,
SmallVectorImpl<std::string> &resultTypeNames,
TypeParamKind typeParamKind,
@ -1637,11 +1626,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
StringRef type =
result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (result.isOptional())
properties = OpMethodParameter::PP_Optional;
paramList.emplace_back(type, resultName, properties);
paramList.emplace_back(type, resultName, result.isOptional());
resultTypeNames.emplace_back(std::move(resultName));
}
} break;
@ -1699,11 +1685,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
else
type = "::mlir::Value";
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (operand->isOptional())
properties = OpMethodParameter::PP_Optional;
paramList.emplace_back(type, getArgumentName(op, numOperands++),
properties);
operand->isOptional());
continue;
}
const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
@ -1713,10 +1696,6 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
if (inferredAttributes.contains(namedAttr.name))
continue;
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (attr.isOptional())
properties = OpMethodParameter::PP_Optional;
StringRef type;
switch (attrParamKind) {
case AttrParamKind::WrappedAttr:
@ -1736,7 +1715,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
i >= defaultValuedAttrStartIndex) {
defaultValue += attr.getDefaultValue();
}
paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
paramList.emplace_back(type, namedAttr.name, defaultValue,
attr.isOptional());
}
/// Insert parameters for each successor.
@ -1754,7 +1734,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
}
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
OpMethodBody &body, llvm::StringSet<> &inferredAttributes,
MethodBody &body, llvm::StringSet<> &inferredAttributes,
bool isRawValueAttr) {
// Push all operands to the result.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
@ -1871,12 +1851,11 @@ void OpEmitter::genCanonicalizerDecls() {
if (hasCanonicalizeMethod) {
// static LogicResult FooOp::
// canonicalize(FooOp op, PatternRewriter &rewriter);
SmallVector<OpMethodParameter, 2> paramList;
SmallVector<MethodParameter> paramList;
paramList.emplace_back(op.getCppClassName(), "op");
paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize",
OpMethod::MP_StaticDeclaration,
std::move(paramList));
auto *m = opClass.declareStaticMethod("::mlir::LogicalResult",
"canonicalize", std::move(paramList));
ERROR_IF_PRUNED(m, "canonicalize", op);
}
@ -1892,12 +1871,12 @@ void OpEmitter::genCanonicalizerDecls() {
// Add a signature for getCanonicalizationPatterns if implemented by the
// dialect or if synthesized to call 'canonicalize'.
SmallVector<OpMethodParameter, 2> paramList;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::RewritePatternSet &", "results");
paramList.emplace_back("::mlir::MLIRContext *", "context");
auto kind = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
auto *method = opClass.addMethodAndPrune(
"void", "getCanonicalizationPatterns", kind, std::move(paramList));
auto kind = hasBody ? Method::MP_Static : Method::MP_StaticDeclaration;
auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind,
std::move(paramList));
// If synthesizing the method, fill it it.
if (hasBody) {
@ -1912,18 +1891,17 @@ void OpEmitter::genFolderDecls() {
if (def.getValueAsBit("hasFolder")) {
if (hasSingleResult) {
auto *m = opClass.addMethodAndPrune(
"::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
"::llvm::ArrayRef<::mlir::Attribute>", "operands");
auto *m = opClass.declareMethod(
"::mlir::OpFoldResult", "fold",
MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands"));
ERROR_IF_PRUNED(m, "operands", op);
} else {
SmallVector<OpMethodParameter, 2> paramList;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
"results");
auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
OpMethod::MP_Declaration,
std::move(paramList));
auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold",
std::move(paramList));
ERROR_IF_PRUNED(m, "fold", op);
}
}
@ -1953,18 +1931,18 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
}
}
OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
bool declaration) {
SmallVector<OpMethodParameter, 4> paramList;
Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
bool declaration) {
SmallVector<MethodParameter> paramList;
for (const InterfaceMethod::Argument &arg : method.getArguments())
paramList.emplace_back(arg.type, arg.name);
auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
auto properties = method.isStatic() ? Method::MP_Static : Method::MP_None;
if (declaration)
properties =
static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
properties, std::move(paramList));
static_cast<Method::Property>(properties | Method::MP_Declaration);
return opClass.addMethod(method.getReturnType(), method.getName(), properties,
std::move(paramList));
}
void OpEmitter::genOpInterfaceMethods() {
@ -2039,8 +2017,8 @@ void OpEmitter::genSideEffectInterfaceMethods() {
"SideEffects::EffectInstance<{0}>> &",
it.first())
.str();
auto *getEffects =
opClass.addMethodAndPrune("void", "getEffects", type, "effects");
auto *getEffects = opClass.addMethod("void", "getEffects",
MethodParameter(type, "effects"));
ERROR_IF_PRUNED(getEffects, "getEffects", op);
auto &body = getEffects->body();
@ -2082,7 +2060,7 @@ void OpEmitter::genTypeInterfaceMethods() {
const auto *trait = dyn_cast<InterfaceTrait>(
op.getTrait("::mlir::InferTypeOpInterface::Trait"));
Interface interface = trait->getInterface();
OpMethod *method = [&]() -> OpMethod * {
Method *method = [&]() -> Method * {
for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
if (interfaceMethod.getName() == "inferReturnTypes") {
return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
@ -2099,8 +2077,7 @@ void OpEmitter::genTypeInterfaceMethods() {
fctx.withBuilder("odsBuilder");
body << " ::mlir::Builder odsBuilder(context);\n";
auto emitType =
[&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
if (!type.isArg())
return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
auto argIndex = type.getArg();
@ -2129,12 +2106,11 @@ void OpEmitter::genParser() {
hasStringAttribute(def, "assemblyFormat"))
return;
SmallVector<OpMethodParameter, 2> paramList;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpAsmParser &", "parser");
paramList.emplace_back("::mlir::OperationState &", "result");
auto *method =
opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
OpMethod::MP_Static, std::move(paramList));
auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
std::move(paramList));
ERROR_IF_PRUNED(method, "parse", op);
FmtContext fctx;
@ -2152,8 +2128,8 @@ void OpEmitter::genPrinter() {
if (!stringInit)
return;
auto *method =
opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
auto *method = opClass.addMethod(
"void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
ERROR_IF_PRUNED(method, "print", op);
FmtContext fctx;
fctx.addSubst("cppClass", opClass.getClassName());
@ -2162,7 +2138,7 @@ void OpEmitter::genPrinter() {
}
/// Generate verification on native traits requiring attributes.
static void genNativeTraitAttrVerifier(OpMethodBody &body,
static void genNativeTraitAttrVerifier(MethodBody &body,
const OpOrAdaptorHelper &emitHelper) {
// Check that the variadic segment sizes attribute exists and contains the
// expected number of elements.
@ -2209,7 +2185,7 @@ static void genNativeTraitAttrVerifier(OpMethodBody &body,
}
void OpEmitter::genVerifier() {
auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
auto *method = opClass.addMethod("::mlir::LogicalResult", "verify");
ERROR_IF_PRUNED(method, "verify", op);
auto &body = method->body();
@ -2247,7 +2223,7 @@ void OpEmitter::genVerifier() {
}
}
void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
void OpEmitter::genOperandResultVerifier(MethodBody &body,
Operator::value_range values,
StringRef valueKind) {
// Check that an optional value is at most 1 element.
@ -2321,7 +2297,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
body << " }\n";
}
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
void OpEmitter::genRegionVerifier(MethodBody &body) {
/// Code to verify a region.
///
/// {0}: Getter for the regions.
@ -2363,7 +2339,7 @@ void OpEmitter::genRegionVerifier(OpMethodBody &body) {
body << " }\n";
}
void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
void OpEmitter::genSuccessorVerifier(MethodBody &body) {
const char *const verifySuccessor = R"(
for (auto *successor : {0})
if (::mlir::failed({1}(*this, successor, "{2}", index++)))
@ -2485,9 +2461,8 @@ void OpEmitter::genTraits() {
}
void OpEmitter::genOpNameGetter() {
auto *method = opClass.addMethodAndPrune(
"::llvm::StringLiteral", "getOperationName",
OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
auto *method = opClass.addStaticMethod<Method::MP_Constexpr>(
"::llvm::StringLiteral", "getOperationName");
ERROR_IF_PRUNED(method, "getOperationName", op);
method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
<< "\");";
@ -2514,8 +2489,9 @@ void OpEmitter::genOpAsmInterface() {
opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
// Generate the right accessor for the number of results.
auto *method = opClass.addMethodAndPrune(
"void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
auto *method = opClass.addMethod(
"void", "getAsmResultNames",
MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
ERROR_IF_PRUNED(method, "getAsmResultNames", op);
auto &body = method->body();
for (int i = 0; i != numResults; ++i) {
@ -2567,7 +2543,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
const auto *attrSizedOperands =
op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
{
SmallVector<OpMethodParameter, 2> paramList;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::ValueRange", "values");
paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
attrSizedOperands ? "" : "nullptr");
@ -2581,14 +2557,14 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
{
auto *constructor = adaptor.addConstructorAndPrune(
llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
MethodParameter(op.getCppClassName() + " &", "op"));
constructor->addMemberInitializer("odsOperands", "op->getOperands()");
constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
constructor->addMemberInitializer("odsRegions", "op->getRegions()");
}
{
auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands");
auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands");
ERROR_IF_PRUNED(m, "getOperands", op);
m->body() << " return odsOperands;";
}
@ -2605,7 +2581,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
auto emitAttr = [&](StringRef name, StringRef emitName, Attribute attr) {
auto *method = adaptor.addMethodAndPrune(attr.getStorageType(), emitName);
auto *method = adaptor.addMethod(attr.getStorageType(), emitName);
ERROR_IF_PRUNED(method, "Adaptor::" + emitName, op);
auto &body = method->body();
body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
@ -2629,8 +2605,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
};
{
auto *m =
adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes");
auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes");
ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
m->body() << " return odsAttrs;";
}
@ -2645,7 +2620,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
unsigned numRegions = op.getNumRegions();
if (numRegions > 0) {
auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions");
auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions");
ERROR_IF_PRUNED(m, "Adaptor::getRegions", op);
m->body() << " return odsRegions;";
}
@ -2657,13 +2632,13 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
// Generate the accessors for a variadic region.
for (StringRef name : op.getGetterNames(region.name)) {
if (region.isVariadic()) {
auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", name);
auto *m = adaptor.addMethod("::mlir::RegionRange", name);
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
m->body() << formatv(" return odsRegions.drop_front({0});", i);
continue;
}
auto *m = adaptor.addMethodAndPrune("::mlir::Region &", name);
auto *m = adaptor.addMethod("::mlir::Region &", name);
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
m->body() << formatv(" return *odsRegions[{0}];", i);
}
@ -2674,8 +2649,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
}
void OpOperandAdaptorEmitter::addVerification() {
auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
"::mlir::Location", "loc");
auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify",
MethodParameter("::mlir::Location", "loc"));
ERROR_IF_PRUNED(method, "verify", op);
auto &body = method->body();

View File

@ -9,10 +9,10 @@
#include "OpFormatGen.h"
#include "FormatGen.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/MapVector.h"
@ -140,8 +140,7 @@ using SuccessorVariable =
namespace {
/// This class implements single kind directives.
template <Element::Kind type>
class DirectiveElement : public Element {
template <Element::Kind type> class DirectiveElement : public Element {
public:
DirectiveElement() : Element(type){};
static bool classof(const Element *ele) { return ele->getKind() == type; }
@ -422,23 +421,23 @@ struct OperationFormat {
/// Generate the operation parser from this format.
void genParser(Operator &op, OpClass &opClass);
/// Generate the parser code for a specific format element.
void genElementParser(Element *element, OpMethodBody &body,
void genElementParser(Element *element, MethodBody &body,
FmtContext &attrTypeCtx);
/// Generate the c++ to resolve the types of operands and results during
/// parsing.
void genParserTypeResolution(Operator &op, OpMethodBody &body);
void genParserTypeResolution(Operator &op, MethodBody &body);
/// Generate the c++ to resolve regions during parsing.
void genParserRegionResolution(Operator &op, OpMethodBody &body);
void genParserRegionResolution(Operator &op, MethodBody &body);
/// Generate the c++ to resolve successors during parsing.
void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
void genParserSuccessorResolution(Operator &op, MethodBody &body);
/// Generate the c++ to handling variadic segment size traits.
void genParserVariadicSegmentResolution(Operator &op, OpMethodBody &body);
void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
/// Generate the operation printer from this format.
void genPrinter(Operator &op, OpClass &opClass);
/// Generate the printer code for a specific format element.
void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
void genElementPrinter(Element *element, MethodBody &body, Operator &op,
bool &shouldEmitSpace, bool &lastWasPunctuation);
/// The various elements in this format.
@ -813,7 +812,7 @@ static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
}
/// Generate the parser for a literal value.
static void genLiteralParser(StringRef value, OpMethodBody &body) {
static void genLiteralParser(StringRef value, MethodBody &body) {
// Handle the case of a keyword/identifier.
if (value.front() == '_' || isalpha(value.front())) {
body << "Keyword(\"" << value << "\")";
@ -839,7 +838,7 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
/// Generate the storage code required for parsing the given element.
static void genElementParserStorage(Element *element, const Operator &op,
OpMethodBody &body) {
MethodBody &body) {
if (auto *optional = dyn_cast<OptionalElement>(element)) {
auto elements = optional->getThenElements();
@ -937,7 +936,7 @@ static void genElementParserStorage(Element *element, const Operator &op,
}
/// Generate the parser for a parameter to a custom directive.
static void genCustomParameterParser(Element &param, OpMethodBody &body) {
static void genCustomParameterParser(Element &param, MethodBody &body) {
if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
body << attr->getVar()->name << "Attr";
} else if (isa<AttrDictDirective>(&param)) {
@ -988,7 +987,7 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
}
/// Generate the parser for a custom directive.
static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
body << " {\n";
// Preprocess the directive variables.
@ -1098,7 +1097,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
}
/// Generate the parser for a enum attribute.
static void genEnumAttrParser(const NamedAttribute *var, OpMethodBody &body,
static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
FmtContext &attrTypeCtx) {
Attribute baseAttr = var->attr.getBaseAttr();
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
@ -1141,13 +1140,12 @@ static void genEnumAttrParser(const NamedAttribute *var, OpMethodBody &body,
}
void OperationFormat::genParser(Operator &op, OpClass &opClass) {
llvm::SmallVector<OpMethodParameter, 4> paramList;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpAsmParser &", "parser");
paramList.emplace_back("::mlir::OperationState &", "result");
auto *method =
opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
OpMethod::MP_Static, std::move(paramList));
auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
std::move(paramList));
auto &body = method->body();
// Generate variables to store the operands and type within the format. This
@ -1174,7 +1172,7 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
body << " return ::mlir::success();\n";
}
void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
void OperationFormat::genElementParser(Element *element, MethodBody &body,
FmtContext &attrTypeCtx) {
/// Optional Group.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
@ -1353,8 +1351,7 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
}
}
void OperationFormat::genParserTypeResolution(Operator &op,
OpMethodBody &body) {
void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
// If any of type resolutions use transformed variables, make sure that the
// types of those variables are resolved.
SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
@ -1528,7 +1525,7 @@ void OperationFormat::genParserTypeResolution(Operator &op,
}
void OperationFormat::genParserRegionResolution(Operator &op,
OpMethodBody &body) {
MethodBody &body) {
// Check for the case where all regions were parsed.
bool hasAllRegions = llvm::any_of(
elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
@ -1547,7 +1544,7 @@ void OperationFormat::genParserRegionResolution(Operator &op,
}
void OperationFormat::genParserSuccessorResolution(Operator &op,
OpMethodBody &body) {
MethodBody &body) {
// Check for the case where all successors were parsed.
bool hasAllSuccessors = llvm::any_of(
elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
@ -1566,7 +1563,7 @@ void OperationFormat::genParserSuccessorResolution(Operator &op,
}
void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
OpMethodBody &body) {
MethodBody &body) {
if (!allOperands) {
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
body << " result.addAttribute(\"operand_segment_sizes\", "
@ -1641,7 +1638,7 @@ const char *enumAttrBeginPrinterCode = R"(
/// Generate the printer for the 'attr-dict' directive.
static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
OpMethodBody &body, bool withKeyword) {
MethodBody &body, bool withKeyword) {
body << " _odsPrinter.printOptionalAttrDict"
<< (withKeyword ? "WithKeyword" : "")
<< "((*this)->getAttrs(), /*elidedAttrs=*/{";
@ -1665,7 +1662,7 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
/// Generate the printer for a literal value. `shouldEmitSpace` is true if a
/// space should be emitted before this element. `lastWasPunctuation` is true if
/// the previous element was a punctuation literal.
static void genLiteralPrinter(StringRef value, OpMethodBody &body,
static void genLiteralPrinter(StringRef value, MethodBody &body,
bool &shouldEmitSpace, bool &lastWasPunctuation) {
body << " _odsPrinter";
@ -1682,8 +1679,8 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
/// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
/// are set to false.
static void genSpacePrinter(bool value, OpMethodBody &body,
bool &shouldEmitSpace, bool &lastWasPunctuation) {
static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
bool &lastWasPunctuation) {
if (value) {
body << " _odsPrinter << ' ';\n";
lastWasPunctuation = false;
@ -1696,7 +1693,7 @@ static void genSpacePrinter(bool value, OpMethodBody &body,
/// Generate the printer for a custom directive parameter.
static void genCustomDirectiveParameterPrinter(Element *element,
const Operator &op,
OpMethodBody &body) {
MethodBody &body) {
if (auto *attr = dyn_cast<AttributeVariable>(element)) {
body << op.getGetterName(attr->getVar()->name) << "Attr()";
@ -1734,7 +1731,7 @@ static void genCustomDirectiveParameterPrinter(Element *element,
/// Generate the printer for a custom directive.
static void genCustomDirectivePrinter(CustomDirective *customDir,
const Operator &op, OpMethodBody &body) {
const Operator &op, MethodBody &body) {
body << " print" << customDir->getName() << "(_odsPrinter, *this";
for (Element &param : customDir->getArguments()) {
body << ", ";
@ -1744,7 +1741,7 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
}
/// Generate the printer for a region with the given variable name.
static void genRegionPrinter(const Twine &regionName, OpMethodBody &body,
static void genRegionPrinter(const Twine &regionName, MethodBody &body,
bool hasImplicitTermTrait) {
if (hasImplicitTermTrait)
body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
@ -1753,7 +1750,7 @@ static void genRegionPrinter(const Twine &regionName, OpMethodBody &body,
body << " _odsPrinter.printRegion(" << regionName << ");\n";
}
static void genVariadicRegionPrinter(const Twine &regionListName,
OpMethodBody &body,
MethodBody &body,
bool hasImplicitTermTrait) {
body << " llvm::interleaveComma(" << regionListName
<< ", _odsPrinter, [&](::mlir::Region &region) {\n ";
@ -1762,8 +1759,8 @@ static void genVariadicRegionPrinter(const Twine &regionListName,
}
/// Generate the C++ for an operand to a (*-)type directive.
static OpMethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
OpMethodBody &body) {
static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
MethodBody &body) {
if (isa<OperandsDirective>(arg))
return body << "getOperation()->getOperandTypes()";
if (isa<ResultsDirective>(arg))
@ -1786,7 +1783,7 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
/// Generate the printer for an enum attribute.
static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
OpMethodBody &body) {
MethodBody &body) {
Attribute baseAttr = var->attr.getBaseAttr();
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
@ -1864,7 +1861,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
/// Generate the check for the anchor of an optional group.
static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
OpMethodBody &body) {
MethodBody &body) {
TypeSwitch<Element *>(anchor)
.Case<OperandVariable, ResultVariable>([&](auto *element) {
const NamedTypeConstraint *var = element->getVar();
@ -1892,7 +1889,7 @@ static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
});
}
void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
Operator &op, bool &shouldEmitSpace,
bool &lastWasPunctuation) {
if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
@ -2047,8 +2044,9 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
}
void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
auto *method = opClass.addMethodAndPrune("void", "print",
"::mlir::OpAsmPrinter &_odsPrinter");
auto *method = opClass.addMethod(
"void", "print",
MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
auto &body = method->body();
// Flags for if we should emit a space, and if the last element was
@ -2065,8 +2063,7 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
/// Function to find an element within the given range that has the same name as
/// 'name'.
template <typename RangeT>
static auto findArg(RangeT &&range, StringRef name) {
template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
return it != range.end() ? &*it : nullptr;
}