forked from OSchip/llvm-project
[mlir][ods] Cleanup of Class Codegen helper
Depends on D113331 Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D113714
This commit is contained in:
parent
ece17064b5
commit
2696a9529e
|
@ -0,0 +1,412 @@
|
|||
//===- Class.h - Helper classes for C++ code emission -----------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines several classes for Op C++ code emission. They are only
|
||||
// expected to be used by MLIR TableGen backends.
|
||||
//
|
||||
// We emit the op declaration and definition into separate files: *Ops.h.inc
|
||||
// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
|
||||
// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
|
||||
//
|
||||
// In order to do this split, we need to track method signature and
|
||||
// implementation logic separately. Signature information is used for both
|
||||
// declaration and definition, while implementation logic is only for
|
||||
// definition. So we have the following classes for C++ code emission.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TABLEGEN_CLASS_H_
|
||||
#define MLIR_TABLEGEN_CLASS_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/CodeGenHelpers.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
class FmtObjectBase;
|
||||
|
||||
/// This class contains a single method parameter for a C++ function.
|
||||
class MethodParameter {
|
||||
public:
|
||||
/// Create a method parameter with a C++ type, parameter name, and an optional
|
||||
/// default value. Marking a parameter as "optional" is a cosmetic effect on
|
||||
/// the generated code.
|
||||
template <typename TypeT, typename NameT, typename DefaultT>
|
||||
MethodParameter(TypeT &&type, NameT &&name, DefaultT &&defaultValue,
|
||||
bool optional = false)
|
||||
: type(stringify(std::forward<TypeT>(type))),
|
||||
name(stringify(std::forward<NameT>(name))),
|
||||
defaultValue(stringify(std::forward<DefaultT>(defaultValue))),
|
||||
optional(optional) {}
|
||||
|
||||
/// Create a method parameter with a C++ type, parameter name, and no default
|
||||
/// value.
|
||||
template <typename TypeT, typename NameT>
|
||||
MethodParameter(TypeT &&type, NameT &&name, bool optional = false)
|
||||
: MethodParameter(std::forward<TypeT>(type), std::forward<NameT>(name),
|
||||
/*defaultValue=*/"", optional) {}
|
||||
|
||||
/// Write the parameter as part of a method declaration.
|
||||
void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
|
||||
/// Write the parameter as part of a method definition.
|
||||
void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
|
||||
|
||||
/// Get the C++ type.
|
||||
const std::string &getType() const { return type; }
|
||||
/// Returns true if the parameter has a default value.
|
||||
bool hasDefaultValue() const { return !defaultValue.empty(); }
|
||||
|
||||
private:
|
||||
void writeTo(raw_ostream &os, bool emitDefault) const;
|
||||
|
||||
/// The C++ type.
|
||||
std::string type;
|
||||
/// The variable name.
|
||||
std::string name;
|
||||
/// An optional default value. The default value exists if the string is not
|
||||
/// empty.
|
||||
std::string defaultValue;
|
||||
/// Whether the parameter should be indicated as "optional".
|
||||
bool optional;
|
||||
};
|
||||
|
||||
/// This class contains a list of method parameters for constructor, class
|
||||
/// methods, and method signatures.
|
||||
class MethodParameters {
|
||||
public:
|
||||
/// Create a list of method parameters.
|
||||
MethodParameters(std::initializer_list<MethodParameter> parameters)
|
||||
: parameters(parameters) {}
|
||||
MethodParameters(SmallVector<MethodParameter> parameters)
|
||||
: parameters(std::move(parameters)) {}
|
||||
|
||||
/// Write the parameters as part of a method declaration.
|
||||
void writeDeclTo(raw_ostream &os) const;
|
||||
/// Write the parameters as part of a method definition.
|
||||
void writeDefTo(raw_ostream &os) const;
|
||||
|
||||
/// Determine whether this list of parameters "subsumes" another, which occurs
|
||||
/// when this parameter list is identical to the other and has zero or more
|
||||
/// additional default-valued parameters.
|
||||
bool subsumes(const MethodParameters &other) const;
|
||||
|
||||
/// Return the number of parameters.
|
||||
unsigned getNumParameters() const { return parameters.size(); }
|
||||
|
||||
private:
|
||||
llvm::SmallVector<MethodParameter> parameters;
|
||||
};
|
||||
|
||||
/// This class contains the signature of a C++ method, including the return
|
||||
/// type. method name, and method parameters.
|
||||
class MethodSignature {
|
||||
public:
|
||||
MethodSignature(StringRef retType, StringRef name,
|
||||
SmallVector<MethodParameter> &¶meters)
|
||||
: returnType(retType), methodName(name),
|
||||
parameters(std::move(parameters)) {}
|
||||
template <typename... Parameters>
|
||||
MethodSignature(StringRef retType, StringRef name, Parameters &&...parameters)
|
||||
: returnType(retType), methodName(name),
|
||||
parameters({std::forward<Parameters>(parameters)...}) {}
|
||||
|
||||
/// Determine whether a method with this signature makes a method with
|
||||
/// `other` signature redundant. This occurs if the signatures have the same
|
||||
/// name and this signature's parameteres subsume the other's.
|
||||
///
|
||||
/// A method that makes another method redundant with a different return type
|
||||
/// can replace the other, the assumption being that the subsuming method
|
||||
/// provides a more resolved return type, e.g. IntegerAttr vs. Attribute.
|
||||
bool makesRedundant(const MethodSignature &other) const;
|
||||
|
||||
/// Get the name of the method.
|
||||
StringRef getName() const { return methodName; }
|
||||
|
||||
/// Get the number of parameters.
|
||||
unsigned getNumParameters() const { return parameters.getNumParameters(); }
|
||||
|
||||
/// Write the signature as part of a method declaration.
|
||||
void writeDeclTo(raw_ostream &os) const;
|
||||
|
||||
/// Write the signature as part of a method definition. `namePrefix` is to be
|
||||
/// prepended to the method name (typically namespaces for qualifying the
|
||||
/// method definition).
|
||||
void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
|
||||
|
||||
private:
|
||||
/// The method's C++ return type.
|
||||
std::string returnType;
|
||||
/// The method name.
|
||||
std::string methodName;
|
||||
/// The method's parameter list.
|
||||
MethodParameters parameters;
|
||||
};
|
||||
|
||||
/// Class for holding the body of an op's method for C++ code emission
|
||||
class MethodBody {
|
||||
public:
|
||||
explicit MethodBody(bool declOnly);
|
||||
|
||||
MethodBody &operator<<(Twine content);
|
||||
MethodBody &operator<<(int content);
|
||||
MethodBody &operator<<(const FmtObjectBase &content);
|
||||
|
||||
void writeTo(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
/// Whether this class should record method body.
|
||||
bool isEffective;
|
||||
/// The body of the method.
|
||||
std::string body;
|
||||
};
|
||||
|
||||
/// Class for holding an op's method for C++ code emission
|
||||
class Method {
|
||||
public:
|
||||
/// Properties (qualifiers) of class methods. Bitfield is used here to help
|
||||
/// querying properties.
|
||||
enum Property {
|
||||
MP_None = 0x0,
|
||||
MP_Static = 0x1,
|
||||
MP_Constructor = 0x2,
|
||||
MP_Private = 0x4,
|
||||
MP_Declaration = 0x8,
|
||||
MP_Inline = 0x10,
|
||||
MP_Constexpr = 0x20 | MP_Inline,
|
||||
MP_StaticDeclaration = MP_Static | MP_Declaration,
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
Method(StringRef retType, StringRef name, Property property, Args &&...args)
|
||||
: properties(property),
|
||||
methodSignature(retType, name, std::forward<Args>(args)...),
|
||||
methodBody(properties & MP_Declaration) {}
|
||||
|
||||
Method(Method &&) = default;
|
||||
Method &operator=(Method &&) = default;
|
||||
|
||||
virtual ~Method() = default;
|
||||
|
||||
MethodBody &body() { return methodBody; }
|
||||
|
||||
/// Returns true if this is a static method.
|
||||
bool isStatic() const { return properties & MP_Static; }
|
||||
|
||||
/// Returns true if this is a private method.
|
||||
bool isPrivate() const { return properties & MP_Private; }
|
||||
|
||||
/// Returns true if this is an inline method.
|
||||
bool isInline() const { return properties & MP_Inline; }
|
||||
|
||||
/// Returns the name of this method.
|
||||
StringRef getName() const { return methodSignature.getName(); }
|
||||
|
||||
/// Returns if this method makes the `other` method redundant.
|
||||
bool makesRedundant(const Method &other) const {
|
||||
return methodSignature.makesRedundant(other.methodSignature);
|
||||
}
|
||||
|
||||
/// Writes the method as a declaration to the given `os`.
|
||||
virtual void writeDeclTo(raw_ostream &os) const;
|
||||
|
||||
/// Writes the method as a definition to the given `os`. `namePrefix` is the
|
||||
/// prefix to be prepended to the method name (typically namespaces for
|
||||
/// qualifying the method definition).
|
||||
virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
|
||||
|
||||
protected:
|
||||
/// A collection of method properties.
|
||||
Property properties;
|
||||
/// The signature of the method.
|
||||
MethodSignature methodSignature;
|
||||
/// The body of the method, if it has one.
|
||||
MethodBody methodBody;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
/// The OR of two method properties should return method properties. Ensure that
|
||||
/// this function is visible to `Class`.
|
||||
inline constexpr mlir::tblgen::Method::Property
|
||||
operator|(mlir::tblgen::Method::Property lhs,
|
||||
mlir::tblgen::Method::Property rhs) {
|
||||
return mlir::tblgen::Method::Property(static_cast<unsigned>(lhs) |
|
||||
static_cast<unsigned>(rhs));
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
/// Class for holding an op's constructor method for C++ code emission.
|
||||
class Constructor : public Method {
|
||||
public:
|
||||
template <typename... Parameters>
|
||||
Constructor(StringRef className, Property property,
|
||||
Parameters &&...parameters)
|
||||
: Method("", className, property,
|
||||
std::forward<Parameters>(parameters)...) {}
|
||||
|
||||
/// Add member initializer to constructor initializing `name` with `value`.
|
||||
void addMemberInitializer(StringRef name, StringRef value);
|
||||
|
||||
/// Writes the method as a definition to the given `os`. `namePrefix` is the
|
||||
/// prefix to be prepended to the method name (typically namespaces for
|
||||
/// qualifying the method definition).
|
||||
void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
|
||||
|
||||
private:
|
||||
/// Member initializers.
|
||||
std::string memberInitializers;
|
||||
};
|
||||
|
||||
/// A class used to emit C++ classes from Tablegen. Contains a list of public
|
||||
/// methods and a list of private fields to be emitted.
|
||||
class Class {
|
||||
public:
|
||||
explicit Class(StringRef name);
|
||||
|
||||
/// Add a new constructor to this class and prune and constructors made
|
||||
/// redundant by it. Returns null if the constructor was not added. Else,
|
||||
/// returns a pointer to the new constructor.
|
||||
template <typename... Parameters>
|
||||
Constructor *addConstructorAndPrune(Parameters &&...parameters) {
|
||||
return addConstructorAndPrune(
|
||||
Constructor(getClassName(), Method::MP_Constructor,
|
||||
std::forward<Parameters>(parameters)...));
|
||||
}
|
||||
|
||||
/// Add a new method to this class and prune any methods made redundant by it.
|
||||
/// Returns null if the method was not added (because an existing method would
|
||||
/// make it redundant). Else, returns a pointer to the new method.
|
||||
template <typename... Parameters>
|
||||
Method *addMethod(StringRef retType, StringRef name,
|
||||
Method::Property properties, Parameters &&...parameters) {
|
||||
return addMethodAndPrune(Method(retType, name, properties,
|
||||
std::forward<Parameters>(parameters)...));
|
||||
}
|
||||
|
||||
/// Add a method with statically-known properties.
|
||||
template <Method::Property Properties = Method::MP_None,
|
||||
typename... Parameters>
|
||||
Method *addMethod(StringRef retType, StringRef name,
|
||||
Parameters &&...parameters) {
|
||||
return addMethod(retType, name, Properties,
|
||||
std::forward<Parameters>(parameters)...);
|
||||
}
|
||||
|
||||
/// Add a static method.
|
||||
template <Method::Property Properties = Method::MP_None,
|
||||
typename... Parameters>
|
||||
Method *addStaticMethod(StringRef retType, StringRef name,
|
||||
Parameters &&...parameters) {
|
||||
return addMethod<Properties | Method::MP_Static>(
|
||||
retType, name, std::forward<Parameters>(parameters)...);
|
||||
}
|
||||
|
||||
/// Add an inline static method.
|
||||
template <Method::Property Properties = Method::MP_None,
|
||||
typename... Parameters>
|
||||
Method *addStaticInlineMethod(StringRef retType, StringRef name,
|
||||
Parameters &&...parameters) {
|
||||
return addMethod<Properties | Method::MP_Static | Method::MP_Inline>(
|
||||
retType, name, std::forward<Parameters>(parameters)...);
|
||||
}
|
||||
|
||||
/// Add an inline method.
|
||||
template <Method::Property Properties = Method::MP_None,
|
||||
typename... Parameters>
|
||||
Method *addInlineMethod(StringRef retType, StringRef name,
|
||||
Parameters &&...parameters) {
|
||||
return addMethod<Properties | Method::MP_Inline>(
|
||||
retType, name, std::forward<Parameters>(parameters)...);
|
||||
}
|
||||
|
||||
/// Add a declaration for a method.
|
||||
template <Method::Property Properties = Method::MP_None,
|
||||
typename... Parameters>
|
||||
Method *declareMethod(StringRef retType, StringRef name,
|
||||
Parameters &&...parameters) {
|
||||
return addMethod<Properties | Method::MP_Declaration>(
|
||||
retType, name, std::forward<Parameters>(parameters)...);
|
||||
}
|
||||
|
||||
/// Add a declaration for a static method.
|
||||
template <Method::Property Properties = Method::MP_None,
|
||||
typename... Parameters>
|
||||
Method *declareStaticMethod(StringRef retType, StringRef name,
|
||||
Parameters &&...parameters) {
|
||||
return addMethod<Properties | Method::MP_StaticDeclaration>(
|
||||
retType, name, std::forward<Parameters>(parameters)...);
|
||||
}
|
||||
|
||||
/// Creates a new field in this class.
|
||||
void newField(StringRef type, StringRef name, StringRef defaultValue = "");
|
||||
|
||||
/// Writes this op's class as a declaration to the given `os`.
|
||||
void writeDeclTo(raw_ostream &os) const;
|
||||
/// Writes the method definitions in this op's class to the given `os`.
|
||||
void writeDefTo(raw_ostream &os) const;
|
||||
|
||||
/// Returns the C++ class name of the op.
|
||||
StringRef getClassName() const { return className; }
|
||||
|
||||
protected:
|
||||
/// Get a list of all the methods to emit, filtering out hidden ones.
|
||||
void forAllMethods(llvm::function_ref<void(const Method &)> func) const {
|
||||
llvm::for_each(constructors, [&](auto &ctor) { func(ctor); });
|
||||
llvm::for_each(methods, [&](auto &method) { func(method); });
|
||||
}
|
||||
|
||||
/// Add a new constructor if it is not made redundant by any existing
|
||||
/// constructors and prune and existing constructors made redundant.
|
||||
Constructor *addConstructorAndPrune(Constructor &&newCtor);
|
||||
/// Add a new method if it is not made redundant by any existing methods and
|
||||
/// prune and existing methods made redundant.
|
||||
Method *addMethodAndPrune(Method &&newMethod);
|
||||
|
||||
/// The C++ class name.
|
||||
std::string className;
|
||||
/// The list of constructors.
|
||||
std::vector<Constructor> constructors;
|
||||
/// The list of class methods.
|
||||
std::vector<Method> methods;
|
||||
/// The list of class members.
|
||||
SmallVector<std::string, 4> fields;
|
||||
};
|
||||
|
||||
// Class for holding an op for C++ code emission
|
||||
class OpClass : public Class {
|
||||
public:
|
||||
explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
|
||||
|
||||
/// Adds an op trait.
|
||||
void addTrait(Twine trait);
|
||||
|
||||
/// Writes this op's class as a declaration to the given `os`. Redefines
|
||||
/// Class::writeDeclTo to also emit traits and extra class declarations.
|
||||
void writeDeclTo(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
StringRef extraClassDeclaration;
|
||||
llvm::SetVector<std::string, SmallVector<std::string>, StringSet<>> traits;
|
||||
};
|
||||
|
||||
} // namespace tblgen
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_TABLEGEN_CLASS_H_
|
|
@ -216,9 +216,28 @@ private:
|
|||
ConstraintMap regionConstraints;
|
||||
};
|
||||
|
||||
// Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar.
|
||||
/// Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar.
|
||||
std::string escapeString(StringRef value);
|
||||
|
||||
namespace detail {
|
||||
template <typename> struct stringifier {
|
||||
template <typename T> static std::string apply(T &&t) {
|
||||
return std::string(std::forward<T>(t));
|
||||
}
|
||||
};
|
||||
template <> struct stringifier<Twine> {
|
||||
static std::string apply(const Twine &twine) {
|
||||
return twine.str();
|
||||
}
|
||||
};
|
||||
} // end namespace detail
|
||||
|
||||
/// Generically convert a value to a std::string.
|
||||
template <typename T> std::string stringify(T &&t) {
|
||||
return detail::stringifier<std::remove_reference_t<std::remove_const_t<T>>>::
|
||||
apply(std::forward<T>(t));
|
||||
}
|
||||
|
||||
} // namespace tblgen
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -1,442 +0,0 @@
|
|||
//===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines several classes for Op C++ code emission. They are only
|
||||
// expected to be used by MLIR TableGen backends.
|
||||
//
|
||||
// We emit the op declaration and definition into separate files: *Ops.h.inc
|
||||
// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
|
||||
// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
|
||||
//
|
||||
// In order to do this split, we need to track method signature and
|
||||
// implementation logic separately. Signature information is used for both
|
||||
// declaration and definition, while implementation logic is only for
|
||||
// definition. So we have the following classes for C++ code emission.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TABLEGEN_OPCLASS_H_
|
||||
#define MLIR_TABLEGEN_OPCLASS_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
class FmtObjectBase;
|
||||
|
||||
// Class for holding a single parameter of an op's method for C++ code emission.
|
||||
class OpMethodParameter {
|
||||
public:
|
||||
// Properties (qualifiers) for the parameter.
|
||||
enum Property {
|
||||
PP_None = 0x0,
|
||||
PP_Optional = 0x1,
|
||||
};
|
||||
|
||||
OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "",
|
||||
Property properties = PP_None)
|
||||
: type(type), name(name), defaultValue(defaultValue),
|
||||
properties(properties) {}
|
||||
|
||||
OpMethodParameter(StringRef type, StringRef name, Property property)
|
||||
: OpMethodParameter(type, name, "", property) {}
|
||||
|
||||
// Writes the parameter as a part of a method declaration to `os`.
|
||||
void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); }
|
||||
|
||||
// Writes the parameter as a part of a method definition to `os`
|
||||
void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); }
|
||||
|
||||
const std::string &getType() const { return type; }
|
||||
bool hasDefaultValue() const { return !defaultValue.empty(); }
|
||||
|
||||
private:
|
||||
void writeTo(raw_ostream &os, bool emitDefault) const;
|
||||
|
||||
std::string type;
|
||||
std::string name;
|
||||
std::string defaultValue;
|
||||
Property properties;
|
||||
};
|
||||
|
||||
// Base class for holding parameters of an op's method for C++ code emission.
|
||||
class OpMethodParameters {
|
||||
public:
|
||||
// Discriminator for LLVM-style RTTI.
|
||||
enum ParamsKind {
|
||||
// Separate type and name for each parameter is not known.
|
||||
PK_Unresolved,
|
||||
// Each parameter is resolved to a type and name.
|
||||
PK_Resolved,
|
||||
};
|
||||
|
||||
OpMethodParameters(ParamsKind kind) : kind(kind) {}
|
||||
virtual ~OpMethodParameters() {}
|
||||
|
||||
// LLVM-style RTTI support.
|
||||
ParamsKind getKind() const { return kind; }
|
||||
|
||||
// Writes the parameters as a part of a method declaration to `os`.
|
||||
virtual void writeDeclTo(raw_ostream &os) const = 0;
|
||||
|
||||
// Writes the parameters as a part of a method definition to `os`
|
||||
virtual void writeDefTo(raw_ostream &os) const = 0;
|
||||
|
||||
// Factory methods to create the correct type of `OpMethodParameters`
|
||||
// object based on the arguments.
|
||||
static std::unique_ptr<OpMethodParameters> create();
|
||||
|
||||
static std::unique_ptr<OpMethodParameters> create(StringRef params);
|
||||
|
||||
static std::unique_ptr<OpMethodParameters>
|
||||
create(llvm::SmallVectorImpl<OpMethodParameter> &¶ms);
|
||||
|
||||
static std::unique_ptr<OpMethodParameters>
|
||||
create(StringRef type, StringRef name, StringRef defaultValue = "");
|
||||
|
||||
private:
|
||||
const ParamsKind kind;
|
||||
};
|
||||
|
||||
// Class for holding unresolved parameters.
|
||||
class OpMethodUnresolvedParameters : public OpMethodParameters {
|
||||
public:
|
||||
OpMethodUnresolvedParameters(StringRef params)
|
||||
: OpMethodParameters(PK_Unresolved), parameters(params) {}
|
||||
|
||||
// write the parameters as a part of a method declaration to the given `os`.
|
||||
void writeDeclTo(raw_ostream &os) const override;
|
||||
|
||||
// write the parameters as a part of a method definition to the given `os`
|
||||
void writeDefTo(raw_ostream &os) const override;
|
||||
|
||||
// LLVM-style RTTI support.
|
||||
static bool classof(const OpMethodParameters *params) {
|
||||
return params->getKind() == PK_Unresolved;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string parameters;
|
||||
};
|
||||
|
||||
// Class for holding resolved parameters.
|
||||
class OpMethodResolvedParameters : public OpMethodParameters {
|
||||
public:
|
||||
OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {}
|
||||
|
||||
OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> &¶ms)
|
||||
: OpMethodParameters(PK_Resolved) {
|
||||
for (OpMethodParameter ¶m : params)
|
||||
parameters.emplace_back(std::move(param));
|
||||
}
|
||||
|
||||
OpMethodResolvedParameters(StringRef type, StringRef name,
|
||||
StringRef defaultValue)
|
||||
: OpMethodParameters(PK_Resolved) {
|
||||
parameters.emplace_back(type, name, defaultValue);
|
||||
}
|
||||
|
||||
// Returns the number of parameters.
|
||||
size_t getNumParameters() const { return parameters.size(); }
|
||||
|
||||
// Returns if this method makes the `other` method redundant. Note that this
|
||||
// is more than just finding conflicting methods. This method determines if
|
||||
// the 2 set of parameters are conflicting and if so, returns true if this
|
||||
// method has a more general set of parameters that can replace all possible
|
||||
// calls to the `other` method.
|
||||
bool makesRedundant(const OpMethodResolvedParameters &other) const;
|
||||
|
||||
// write the parameters as a part of a method declaration to the given `os`.
|
||||
void writeDeclTo(raw_ostream &os) const override;
|
||||
|
||||
// write the parameters as a part of a method definition to the given `os`
|
||||
void writeDefTo(raw_ostream &os) const override;
|
||||
|
||||
// LLVM-style RTTI support.
|
||||
static bool classof(const OpMethodParameters *params) {
|
||||
return params->getKind() == PK_Resolved;
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::SmallVector<OpMethodParameter, 4> parameters;
|
||||
};
|
||||
|
||||
// Class for holding the signature of an op's method for C++ code emission
|
||||
class OpMethodSignature {
|
||||
public:
|
||||
template <typename... Args>
|
||||
OpMethodSignature(StringRef retType, StringRef name, Args &&...args)
|
||||
: returnType(retType), methodName(name),
|
||||
parameters(OpMethodParameters::create(std::forward<Args>(args)...)) {}
|
||||
OpMethodSignature(OpMethodSignature &&) = default;
|
||||
|
||||
// Returns if a method with this signature makes a method with `other`
|
||||
// signature redundant. Only supports resolved parameters.
|
||||
bool makesRedundant(const OpMethodSignature &other) const;
|
||||
|
||||
// Returns the number of parameters (for resolved parameters).
|
||||
size_t getNumParameters() const {
|
||||
return cast<OpMethodResolvedParameters>(parameters.get())
|
||||
->getNumParameters();
|
||||
}
|
||||
|
||||
// Returns the name of the method.
|
||||
StringRef getName() const { return methodName; }
|
||||
|
||||
// Writes the signature as a method declaration to the given `os`.
|
||||
void writeDeclTo(raw_ostream &os) const;
|
||||
|
||||
// Writes the signature as the start of a method definition to the given `os`.
|
||||
// `namePrefix` is the prefix to be prepended to the method name (typically
|
||||
// namespaces for qualifying the method definition).
|
||||
void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
|
||||
|
||||
private:
|
||||
std::string returnType;
|
||||
std::string methodName;
|
||||
std::unique_ptr<OpMethodParameters> parameters;
|
||||
};
|
||||
|
||||
// Class for holding the body of an op's method for C++ code emission
|
||||
class OpMethodBody {
|
||||
public:
|
||||
explicit OpMethodBody(bool declOnly);
|
||||
|
||||
OpMethodBody &operator<<(Twine content);
|
||||
OpMethodBody &operator<<(int content);
|
||||
OpMethodBody &operator<<(const FmtObjectBase &content);
|
||||
|
||||
void writeTo(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
// Whether this class should record method body.
|
||||
bool isEffective;
|
||||
std::string body;
|
||||
};
|
||||
|
||||
// Class for holding an op's method for C++ code emission
|
||||
class OpMethod {
|
||||
public:
|
||||
// Properties (qualifiers) of class methods. Bitfield is used here to help
|
||||
// querying properties.
|
||||
enum Property {
|
||||
MP_None = 0x0,
|
||||
MP_Static = 0x1,
|
||||
MP_Constructor = 0x2,
|
||||
MP_Private = 0x4,
|
||||
MP_Declaration = 0x8,
|
||||
MP_Inline = 0x10,
|
||||
MP_Constexpr = 0x20 | MP_Inline,
|
||||
MP_StaticDeclaration = MP_Static | MP_Declaration,
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
OpMethod(StringRef retType, StringRef name, Property property, unsigned id,
|
||||
Args &&...args)
|
||||
: properties(property),
|
||||
methodSignature(retType, name, std::forward<Args>(args)...),
|
||||
methodBody(properties & MP_Declaration), id(id) {}
|
||||
|
||||
OpMethod(OpMethod &&) = default;
|
||||
|
||||
virtual ~OpMethod() = default;
|
||||
|
||||
OpMethodBody &body() { return methodBody; }
|
||||
|
||||
// Returns true if this is a static method.
|
||||
bool isStatic() const { return properties & MP_Static; }
|
||||
|
||||
// Returns true if this is a private method.
|
||||
bool isPrivate() const { return properties & MP_Private; }
|
||||
|
||||
// Returns true if this is an inline method.
|
||||
bool isInline() const { return properties & MP_Inline; }
|
||||
|
||||
// Returns the name of this method.
|
||||
StringRef getName() const { return methodSignature.getName(); }
|
||||
|
||||
// Returns the ID for this method
|
||||
unsigned getID() const { return id; }
|
||||
|
||||
// Returns if this method makes the `other` method redundant.
|
||||
bool makesRedundant(const OpMethod &other) const {
|
||||
return methodSignature.makesRedundant(other.methodSignature);
|
||||
}
|
||||
|
||||
// Writes the method as a declaration to the given `os`.
|
||||
virtual void writeDeclTo(raw_ostream &os) const;
|
||||
|
||||
// Writes the method as a definition to the given `os`. `namePrefix` is the
|
||||
// prefix to be prepended to the method name (typically namespaces for
|
||||
// qualifying the method definition).
|
||||
virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
|
||||
|
||||
protected:
|
||||
Property properties;
|
||||
OpMethodSignature methodSignature;
|
||||
OpMethodBody methodBody;
|
||||
const unsigned id;
|
||||
};
|
||||
|
||||
// Class for holding an op's constructor method for C++ code emission.
|
||||
class OpConstructor : public OpMethod {
|
||||
public:
|
||||
template <typename... Args>
|
||||
OpConstructor(StringRef className, Property property, unsigned id,
|
||||
Args &&...args)
|
||||
: OpMethod("", className, property, id, std::forward<Args>(args)...) {}
|
||||
|
||||
// Add member initializer to constructor initializing `name` with `value`.
|
||||
void addMemberInitializer(StringRef name, StringRef value);
|
||||
|
||||
// Writes the method as a definition to the given `os`. `namePrefix` is the
|
||||
// prefix to be prepended to the method name (typically namespaces for
|
||||
// qualifying the method definition).
|
||||
void writeDefTo(raw_ostream &os, StringRef namePrefix) const override;
|
||||
|
||||
private:
|
||||
// Member initializers.
|
||||
std::string memberInitializers;
|
||||
};
|
||||
|
||||
// A class used to emit C++ classes from Tablegen. Contains a list of public
|
||||
// methods and a list of private fields to be emitted.
|
||||
class Class {
|
||||
public:
|
||||
explicit Class(StringRef name);
|
||||
|
||||
// Adds a new method to this class and prune redundant methods. Returns null
|
||||
// if the method was not added (because an existing method would make it
|
||||
// redundant), else returns a pointer to the added method. Note that this call
|
||||
// may also delete existing methods that are made redundant by a method to the
|
||||
// class.
|
||||
template <typename... Args>
|
||||
OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
|
||||
OpMethod::Property properties, Args &&...args) {
|
||||
auto newMethod = std::make_unique<OpMethod>(
|
||||
retType, name, properties, nextMethodID++, std::forward<Args>(args)...);
|
||||
return addMethodAndPrune(methods, std::move(newMethod));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
OpMethod *addMethodAndPrune(StringRef retType, StringRef name,
|
||||
Args &&...args) {
|
||||
return addMethodAndPrune(retType, name, OpMethod::MP_None,
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
OpConstructor *addConstructorAndPrune(Args &&...args) {
|
||||
auto newConstructor = std::make_unique<OpConstructor>(
|
||||
getClassName(), OpMethod::MP_Constructor, nextMethodID++,
|
||||
std::forward<Args>(args)...);
|
||||
return addMethodAndPrune(constructors, std::move(newConstructor));
|
||||
}
|
||||
|
||||
// Creates a new field in this class.
|
||||
void newField(StringRef type, StringRef name, StringRef defaultValue = "");
|
||||
|
||||
// Writes this op's class as a declaration to the given `os`.
|
||||
void writeDeclTo(raw_ostream &os) const;
|
||||
// Writes the method definitions in this op's class to the given `os`.
|
||||
void writeDefTo(raw_ostream &os) const;
|
||||
|
||||
// Returns the C++ class name of the op.
|
||||
StringRef getClassName() const { return className; }
|
||||
|
||||
protected:
|
||||
// Get a list of all the methods to emit, filtering out hidden ones.
|
||||
void forAllMethods(llvm::function_ref<void(const OpMethod &)> func) const {
|
||||
using ConsRef = const std::unique_ptr<OpConstructor> &;
|
||||
using MethodRef = const std::unique_ptr<OpMethod> &;
|
||||
llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); });
|
||||
llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); });
|
||||
}
|
||||
|
||||
// For deterministic code generation, keep methods sorted in the order in
|
||||
// which they were generated.
|
||||
template <typename MethodTy>
|
||||
struct MethodCompare {
|
||||
bool operator()(const std::unique_ptr<MethodTy> &x,
|
||||
const std::unique_ptr<MethodTy> &y) const {
|
||||
return x->getID() < y->getID();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename MethodTy>
|
||||
using MethodSet =
|
||||
std::set<std::unique_ptr<MethodTy>, MethodCompare<MethodTy>>;
|
||||
|
||||
template <typename MethodTy>
|
||||
MethodTy *addMethodAndPrune(MethodSet<MethodTy> &set,
|
||||
std::unique_ptr<MethodTy> &&newMethod) {
|
||||
// Check if the new method will be made redundant by existing methods.
|
||||
for (auto &method : set)
|
||||
if (method->makesRedundant(*newMethod))
|
||||
return nullptr;
|
||||
|
||||
// We can add this a method to the set. Prune any existing methods that will
|
||||
// be made redundant by adding this new method. Note that the redundant
|
||||
// check between two methods is more than a conflict check. makesRedundant()
|
||||
// below will check if the new method conflicts with an existing method and
|
||||
// if so, returns true if the new method makes the existing method redundant
|
||||
// because all calls to the existing method can be subsumed by the new
|
||||
// method. So makesRedundant() does a combined job of finding conflicts and
|
||||
// deciding which of the 2 conflicting methods survive.
|
||||
//
|
||||
// Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing
|
||||
// it manually here.
|
||||
for (auto it = set.begin(), end = set.end(); it != end;) {
|
||||
if (newMethod->makesRedundant(*(it->get())))
|
||||
it = set.erase(it);
|
||||
else
|
||||
++it;
|
||||
}
|
||||
|
||||
MethodTy *ret = newMethod.get();
|
||||
set.insert(std::move(newMethod));
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string className;
|
||||
MethodSet<OpConstructor> constructors;
|
||||
MethodSet<OpMethod> methods;
|
||||
unsigned nextMethodID = 0;
|
||||
SmallVector<std::string, 4> fields;
|
||||
};
|
||||
|
||||
// Class for holding an op for C++ code emission
|
||||
class OpClass : public Class {
|
||||
public:
|
||||
explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
|
||||
|
||||
// Adds an op trait.
|
||||
void addTrait(Twine trait);
|
||||
|
||||
// Writes this op's class as a declaration to the given `os`. Redefines
|
||||
// Class::writeDeclTo to also emit traits and extra class declarations.
|
||||
void writeDeclTo(raw_ostream &os) const;
|
||||
|
||||
private:
|
||||
StringRef extraClassDeclaration;
|
||||
SmallVector<std::string, 4> traitsVec;
|
||||
StringSet<> traitsSet;
|
||||
};
|
||||
|
||||
} // namespace tblgen
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_TABLEGEN_OPCLASS_H_
|
|
@ -13,12 +13,12 @@ llvm_add_library(MLIRTableGen STATIC
|
|||
Attribute.cpp
|
||||
AttrOrTypeDef.cpp
|
||||
Builder.cpp
|
||||
Class.cpp
|
||||
Constraint.cpp
|
||||
Dialect.cpp
|
||||
Format.cpp
|
||||
Interfaces.cpp
|
||||
Operator.cpp
|
||||
OpClass.cpp
|
||||
Pass.cpp
|
||||
Pattern.cpp
|
||||
Predicate.cpp
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
|
||||
//===- Class.cpp - Helper classes for Op C++ code emission --------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -6,7 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/OpClass.h"
|
||||
#include "mlir/TableGen/Class.h"
|
||||
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
@ -20,173 +20,102 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns space to be emitted after the given C++ `type`. return "" if the
|
||||
// ends with '&' or '*', or is empty, else returns " ".
|
||||
StringRef getSpaceAfterType(StringRef type) {
|
||||
static StringRef getSpaceAfterType(StringRef type) {
|
||||
return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpMethodParameter definitions
|
||||
// MethodParameter definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
|
||||
if (properties & PP_Optional)
|
||||
void MethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
|
||||
if (optional)
|
||||
os << "/*optional*/";
|
||||
os << type << getSpaceAfterType(type) << name;
|
||||
if (emitDefault && !defaultValue.empty())
|
||||
if (emitDefault && hasDefaultValue())
|
||||
os << " = " << defaultValue;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpMethodParameters definitions
|
||||
// MethodParameters definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Factory methods to construct the correct type of `OpMethodParameters`
|
||||
// object based on the arguments.
|
||||
std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
|
||||
return std::make_unique<OpMethodResolvedParameters>();
|
||||
void MethodParameters::writeDeclTo(raw_ostream &os) const {
|
||||
llvm::interleaveComma(parameters, os,
|
||||
[&os](auto ¶m) { param.writeDeclTo(os); });
|
||||
}
|
||||
void MethodParameters::writeDefTo(raw_ostream &os) const {
|
||||
llvm::interleaveComma(parameters, os,
|
||||
[&os](auto ¶m) { param.writeDefTo(os); });
|
||||
}
|
||||
|
||||
std::unique_ptr<OpMethodParameters>
|
||||
OpMethodParameters::create(StringRef params) {
|
||||
return std::make_unique<OpMethodUnresolvedParameters>(params);
|
||||
}
|
||||
|
||||
std::unique_ptr<OpMethodParameters>
|
||||
OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &¶ms) {
|
||||
return std::make_unique<OpMethodResolvedParameters>(std::move(params));
|
||||
}
|
||||
|
||||
std::unique_ptr<OpMethodParameters>
|
||||
OpMethodParameters::create(StringRef type, StringRef name,
|
||||
StringRef defaultValue) {
|
||||
return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpMethodUnresolvedParameters definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
|
||||
os << parameters;
|
||||
}
|
||||
|
||||
void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
|
||||
// We need to remove the default values for parameters in method definition.
|
||||
// TODO: We are using '=' and ',' as delimiters for parameter
|
||||
// initializers. This is incorrect for initializer list with more than one
|
||||
// element. Change to a more robust approach.
|
||||
llvm::SmallVector<StringRef, 4> tokens;
|
||||
StringRef params = parameters;
|
||||
while (!params.empty()) {
|
||||
std::pair<StringRef, StringRef> parts = params.split("=");
|
||||
tokens.push_back(parts.first);
|
||||
params = parts.second.split(',').second;
|
||||
}
|
||||
llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpMethodResolvedParameters definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Returns true if a method with these parameters makes a method with parameters
|
||||
// `other` redundant. This should return true only if all possible calls to the
|
||||
// other method can be replaced by calls to this method.
|
||||
bool OpMethodResolvedParameters::makesRedundant(
|
||||
const OpMethodResolvedParameters &other) const {
|
||||
const size_t otherNumParams = other.getNumParameters();
|
||||
const size_t thisNumParams = getNumParameters();
|
||||
|
||||
// All calls to the other method can be replaced this method only if this
|
||||
// method has the same or more arguments number of arguments as the other, and
|
||||
// the common arguments have the same type.
|
||||
if (thisNumParams < otherNumParams)
|
||||
bool MethodParameters::subsumes(const MethodParameters &other) const {
|
||||
// These parameters do not subsume the others if there are fewer parameters
|
||||
// or their types do not match.
|
||||
if (parameters.size() < other.parameters.size())
|
||||
return false;
|
||||
for (int idx : llvm::seq<int>(0, otherNumParams))
|
||||
if (parameters[idx].getType() != other.parameters[idx].getType())
|
||||
return false;
|
||||
|
||||
// If all the common arguments have the same type, we can elide the other
|
||||
// method if this method has the same number of arguments as other or the
|
||||
// first argument after the common ones has a default value (and by C++
|
||||
// requirement, all the later ones will also have a default value).
|
||||
return thisNumParams == otherNumParams ||
|
||||
parameters[otherNumParams].hasDefaultValue();
|
||||
}
|
||||
|
||||
void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
|
||||
llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) {
|
||||
param.writeDeclTo(os);
|
||||
});
|
||||
}
|
||||
|
||||
void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
|
||||
llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) {
|
||||
param.writeDefTo(os);
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpMethodSignature definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Returns if a method with this signature makes a method with `other` signature
|
||||
// redundant. Only supports resolved parameters.
|
||||
bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
|
||||
if (methodName != other.methodName)
|
||||
if (!std::equal(
|
||||
other.parameters.begin(), other.parameters.end(), parameters.begin(),
|
||||
[](auto &lhs, auto &rhs) { return lhs.getType() == rhs.getType(); }))
|
||||
return false;
|
||||
auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get());
|
||||
auto *resolvedOther =
|
||||
dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
|
||||
if (resolvedThis && resolvedOther)
|
||||
return resolvedThis->makesRedundant(*resolvedOther);
|
||||
return false;
|
||||
|
||||
// If all the common parameters have the same type, we can elide the other
|
||||
// method if this method has the same number of parameters as other or if the
|
||||
// first paramater after the common parameters has a default value (and, as
|
||||
// required by C++, subsequent parameters will have default values too).
|
||||
return parameters.size() == other.parameters.size() ||
|
||||
parameters[other.parameters.size()].hasDefaultValue();
|
||||
}
|
||||
|
||||
void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MethodSignature definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool MethodSignature::makesRedundant(const MethodSignature &other) const {
|
||||
return methodName == other.methodName &&
|
||||
parameters.subsumes(other.parameters);
|
||||
}
|
||||
|
||||
void MethodSignature::writeDeclTo(raw_ostream &os) const {
|
||||
os << returnType << getSpaceAfterType(returnType) << methodName << "(";
|
||||
parameters->writeDeclTo(os);
|
||||
parameters.writeDeclTo(os);
|
||||
os << ")";
|
||||
}
|
||||
|
||||
void OpMethodSignature::writeDefTo(raw_ostream &os,
|
||||
StringRef namePrefix) const {
|
||||
void MethodSignature::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
|
||||
os << returnType << getSpaceAfterType(returnType) << namePrefix
|
||||
<< (namePrefix.empty() ? "" : "::") << methodName << "(";
|
||||
parameters->writeDefTo(os);
|
||||
parameters.writeDefTo(os);
|
||||
os << ")";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpMethodBody definitions
|
||||
// MethodBody definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
|
||||
MethodBody::MethodBody(bool declOnly) : isEffective(!declOnly) {}
|
||||
|
||||
OpMethodBody &OpMethodBody::operator<<(Twine content) {
|
||||
MethodBody &MethodBody::operator<<(Twine content) {
|
||||
if (isEffective)
|
||||
body.append(content.str());
|
||||
return *this;
|
||||
}
|
||||
|
||||
OpMethodBody &OpMethodBody::operator<<(int content) {
|
||||
MethodBody &MethodBody::operator<<(int content) {
|
||||
if (isEffective)
|
||||
body.append(std::to_string(content));
|
||||
return *this;
|
||||
}
|
||||
|
||||
OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
|
||||
MethodBody &MethodBody::operator<<(const FmtObjectBase &content) {
|
||||
if (isEffective)
|
||||
body.append(content.str());
|
||||
return *this;
|
||||
}
|
||||
|
||||
void OpMethodBody::writeTo(raw_ostream &os) const {
|
||||
void MethodBody::writeTo(raw_ostream &os) const {
|
||||
auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
|
||||
os << bodyRef;
|
||||
if (bodyRef.empty() || bodyRef.back() != '\n')
|
||||
|
@ -194,10 +123,10 @@ void OpMethodBody::writeTo(raw_ostream &os) const {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpMethod definitions
|
||||
// Method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void OpMethod::writeDeclTo(raw_ostream &os) const {
|
||||
void Method::writeDeclTo(raw_ostream &os) const {
|
||||
os.indent(2);
|
||||
if (isStatic())
|
||||
os << "static ";
|
||||
|
@ -213,7 +142,7 @@ void OpMethod::writeDeclTo(raw_ostream &os) const {
|
|||
}
|
||||
}
|
||||
|
||||
void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
|
||||
void Method::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
|
||||
// Do not write definition if the method is decl only.
|
||||
if (properties & MP_Declaration)
|
||||
return;
|
||||
|
@ -227,15 +156,15 @@ void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpConstructor definitions
|
||||
// Constructor definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
|
||||
void Constructor::addMemberInitializer(StringRef name, StringRef value) {
|
||||
memberInitializers.append(std::string(llvm::formatv(
|
||||
"{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
|
||||
}
|
||||
|
||||
void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
|
||||
void Constructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
|
||||
// Do not write definition if the method is decl only.
|
||||
if (properties & MP_Declaration)
|
||||
return;
|
||||
|
@ -243,7 +172,7 @@ void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
|
|||
methodSignature.writeDefTo(os, namePrefix);
|
||||
os << " " << memberInitializers << " {\n";
|
||||
methodBody.writeTo(os);
|
||||
os << "}";
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -259,12 +188,13 @@ void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
|
|||
: formatv("{0} = {1}", varName, defaultValue).str();
|
||||
fields.push_back(std::move(field));
|
||||
}
|
||||
|
||||
void Class::writeDeclTo(raw_ostream &os) const {
|
||||
bool hasPrivateMethod = false;
|
||||
os << "class " << className << " {\n";
|
||||
os << "public:\n";
|
||||
|
||||
forAllMethods([&](const OpMethod &method) {
|
||||
forAllMethods([&](const Method &method) {
|
||||
if (!method.isPrivate()) {
|
||||
method.writeDeclTo(os);
|
||||
os << '\n';
|
||||
|
@ -276,7 +206,7 @@ void Class::writeDeclTo(raw_ostream &os) const {
|
|||
os << '\n';
|
||||
os << "private:\n";
|
||||
if (hasPrivateMethod) {
|
||||
forAllMethods([&](const OpMethod &method) {
|
||||
forAllMethods([&](const Method &method) {
|
||||
if (method.isPrivate()) {
|
||||
method.writeDeclTo(os);
|
||||
os << '\n';
|
||||
|
@ -291,12 +221,35 @@ void Class::writeDeclTo(raw_ostream &os) const {
|
|||
}
|
||||
|
||||
void Class::writeDefTo(raw_ostream &os) const {
|
||||
forAllMethods([&](const OpMethod &method) {
|
||||
forAllMethods([&](const Method &method) {
|
||||
method.writeDefTo(os, className);
|
||||
os << "\n\n";
|
||||
os << "\n";
|
||||
});
|
||||
}
|
||||
|
||||
// Insert a new method into a list of methods, if it would not be pruned, and
|
||||
// prune and existing methods.
|
||||
template <typename ContainerT, typename MethodT>
|
||||
MethodT *insertAndPrune(ContainerT &methods, MethodT newMethod) {
|
||||
if (llvm::any_of(methods, [&](auto &method) {
|
||||
return method.makesRedundant(newMethod);
|
||||
}))
|
||||
return nullptr;
|
||||
|
||||
llvm::erase_if(
|
||||
methods, [&](auto &method) { return newMethod.makesRedundant(method); });
|
||||
methods.push_back(std::move(newMethod));
|
||||
return &methods.back();
|
||||
}
|
||||
|
||||
Method *Class::addMethodAndPrune(Method &&newMethod) {
|
||||
return insertAndPrune(methods, std::move(newMethod));
|
||||
}
|
||||
|
||||
Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) {
|
||||
return insertAndPrune(constructors, std::move(newCtor));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpClass definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -304,15 +257,11 @@ void Class::writeDefTo(raw_ostream &os) const {
|
|||
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
|
||||
: Class(name), extraClassDeclaration(extraClassDeclaration) {}
|
||||
|
||||
void OpClass::addTrait(Twine trait) {
|
||||
auto traitStr = trait.str();
|
||||
if (traitsSet.insert(traitStr).second)
|
||||
traitsVec.push_back(std::move(traitStr));
|
||||
}
|
||||
void OpClass::addTrait(Twine trait) { traits.insert(trait.str()); }
|
||||
|
||||
void OpClass::writeDeclTo(raw_ostream &os) const {
|
||||
os << "class " << className << " : public ::mlir::Op<" << className;
|
||||
for (const auto &trait : traitsVec)
|
||||
for (const auto &trait : traits)
|
||||
os << ", " << trait;
|
||||
os << "> {\npublic:\n"
|
||||
<< " using Op::Op;\n"
|
||||
|
@ -320,7 +269,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
|
|||
<< " using Adaptor = " << className << "Adaptor;\n";
|
||||
|
||||
bool hasPrivateMethod = false;
|
||||
forAllMethods([&](const OpMethod &method) {
|
||||
forAllMethods([&](const Method &method) {
|
||||
if (!method.isPrivate()) {
|
||||
method.writeDeclTo(os);
|
||||
os << "\n";
|
||||
|
@ -335,7 +284,7 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
|
|||
|
||||
if (hasPrivateMethod) {
|
||||
os << "\nprivate:\n";
|
||||
forAllMethods([&](const OpMethod &method) {
|
||||
forAllMethods([&](const Method &method) {
|
||||
if (method.isPrivate()) {
|
||||
method.writeDeclTo(os);
|
||||
os << "\n";
|
|
@ -10,11 +10,11 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/Class.h"
|
||||
#include "mlir/TableGen/CodeGenHelpers.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Interfaces.h"
|
||||
#include "mlir/TableGen/OpClass.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "mlir/TableGen/Trait.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
|
|
|
@ -13,11 +13,11 @@
|
|||
|
||||
#include "OpFormatGen.h"
|
||||
#include "OpGenHelpers.h"
|
||||
#include "mlir/TableGen/Class.h"
|
||||
#include "mlir/TableGen/CodeGenHelpers.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Interfaces.h"
|
||||
#include "mlir/TableGen/OpClass.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "mlir/TableGen/SideEffects.h"
|
||||
#include "mlir/TableGen/Trait.h"
|
||||
|
@ -361,7 +361,7 @@ private:
|
|||
// types. `inferredAttributes` is populated with any attributes that are
|
||||
// elided from the build list. The given `typeParamKind` and `attrParamKind`
|
||||
// controls how result types and attributes are placed in the parameter list.
|
||||
void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> ¶mList,
|
||||
void buildParamList(SmallVectorImpl<MethodParameter> ¶mList,
|
||||
llvm::StringSet<> &inferredAttributes,
|
||||
SmallVectorImpl<std::string> &resultTypeNames,
|
||||
TypeParamKind typeParamKind,
|
||||
|
@ -369,7 +369,7 @@ private:
|
|||
|
||||
// Adds op arguments and regions into operation state for build() methods.
|
||||
void
|
||||
genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
|
||||
genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
|
||||
llvm::StringSet<> &inferredAttributes,
|
||||
bool isRawValueAttr = false);
|
||||
|
||||
|
@ -390,17 +390,16 @@ private:
|
|||
|
||||
// Generates verify statements for operands and results in the operation.
|
||||
// The generated code will be attached to `body`.
|
||||
void genOperandResultVerifier(OpMethodBody &body,
|
||||
Operator::value_range values,
|
||||
void genOperandResultVerifier(MethodBody &body, Operator::value_range values,
|
||||
StringRef valueKind);
|
||||
|
||||
// Generates verify statements for regions in the operation.
|
||||
// The generated code will be attached to `body`.
|
||||
void genRegionVerifier(OpMethodBody &body);
|
||||
void genRegionVerifier(MethodBody &body);
|
||||
|
||||
// Generates verify statements for successors in the operation.
|
||||
// The generated code will be attached to `body`.
|
||||
void genSuccessorVerifier(OpMethodBody &body);
|
||||
void genSuccessorVerifier(MethodBody &body);
|
||||
|
||||
// Generates the traits used by the object.
|
||||
void genTraits();
|
||||
|
@ -413,8 +412,8 @@ private:
|
|||
|
||||
// Generate op interface method for the given interface method. If
|
||||
// 'declaration' is true, generates a declaration, else a definition.
|
||||
OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
|
||||
bool declaration = true);
|
||||
Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
|
||||
bool declaration = true);
|
||||
|
||||
// Generate the side effect interface methods.
|
||||
void genSideEffectInterfaceMethods();
|
||||
|
@ -470,7 +469,7 @@ static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
|
|||
// Generate attribute verification. If an op instance is not available, then
|
||||
// attribute checks that require one will not be emitted.
|
||||
static void genAttributeVerifier(
|
||||
const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, OpMethodBody &body,
|
||||
const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
|
||||
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
|
||||
// Check that a required attribute exists.
|
||||
//
|
||||
|
@ -602,7 +601,7 @@ void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
|
|||
|
||||
void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
|
||||
|
||||
static void errorIfPruned(size_t line, OpMethod *m, const Twine &methodName,
|
||||
static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
|
||||
const Operator &op) {
|
||||
if (m)
|
||||
return;
|
||||
|
@ -627,18 +626,15 @@ void OpEmitter::genAttrNameGetters() {
|
|||
for (const NamedAttribute &namedAttr : op.getAttributes())
|
||||
addAttrName(namedAttr.name);
|
||||
// Include key attributes from several traits as implicitly registered.
|
||||
std::string operandSizes = "operand_segment_sizes";
|
||||
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
|
||||
addAttrName(operandSizes);
|
||||
std::string attrSizes = "result_segment_sizes";
|
||||
addAttrName("operand_segment_sizes");
|
||||
if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
|
||||
addAttrName(attrSizes);
|
||||
addAttrName("result_segment_sizes");
|
||||
|
||||
// Emit the getAttributeNames method.
|
||||
{
|
||||
auto *method = opClass.addMethodAndPrune(
|
||||
"::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames",
|
||||
OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline));
|
||||
auto *method = opClass.addStaticInlineMethod(
|
||||
"::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
|
||||
ERROR_IF_PRUNED(method, "getAttributeNames", op);
|
||||
auto &body = method->body();
|
||||
if (attributeNames.empty()) {
|
||||
|
@ -658,20 +654,18 @@ void OpEmitter::genAttrNameGetters() {
|
|||
|
||||
// Emit the getAttributeNameForIndex methods.
|
||||
{
|
||||
auto *method = opClass.addMethodAndPrune(
|
||||
auto *method = opClass.addInlineMethod<Method::MP_Private>(
|
||||
"::mlir::Identifier", "getAttributeNameForIndex",
|
||||
OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline),
|
||||
"unsigned", "index");
|
||||
MethodParameter("unsigned", "index"));
|
||||
ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
|
||||
method->body()
|
||||
<< " return getAttributeNameForIndex((*this)->getName(), index);";
|
||||
}
|
||||
{
|
||||
auto *method = opClass.addMethodAndPrune(
|
||||
auto *method = opClass.addStaticInlineMethod<Method::MP_Private>(
|
||||
"::mlir::Identifier", "getAttributeNameForIndex",
|
||||
OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline |
|
||||
OpMethod::MP_Static),
|
||||
"::mlir::OperationName name, unsigned index");
|
||||
MethodParameter("::mlir::OperationName", "name"),
|
||||
MethodParameter("unsigned", "index"));
|
||||
ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
|
||||
method->body() << "assert(index < " << attributeNames.size()
|
||||
<< " && \"invalid attribute index\");\n"
|
||||
|
@ -689,8 +683,7 @@ void OpEmitter::genAttrNameGetters() {
|
|||
// Generate the non-static variant.
|
||||
{
|
||||
auto *method =
|
||||
opClass.addMethodAndPrune("::mlir::Identifier", methodName,
|
||||
OpMethod::Property(OpMethod::MP_Inline));
|
||||
opClass.addInlineMethod("::mlir::Identifier", methodName);
|
||||
ERROR_IF_PRUNED(method, methodName, op);
|
||||
method->body()
|
||||
<< llvm::formatv(attrNameMethodBody, attrIt.second).str();
|
||||
|
@ -698,10 +691,9 @@ void OpEmitter::genAttrNameGetters() {
|
|||
|
||||
// Generate the static variant.
|
||||
{
|
||||
auto *method = opClass.addMethodAndPrune(
|
||||
auto *method = opClass.addStaticInlineMethod(
|
||||
"::mlir::Identifier", methodName,
|
||||
OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static),
|
||||
"::mlir::OperationName", "name");
|
||||
MethodParameter("::mlir::OperationName", "name"));
|
||||
ERROR_IF_PRUNED(method, methodName, op);
|
||||
method->body() << llvm::formatv(attrNameMethodBody,
|
||||
"name, " + Twine(attrIt.second))
|
||||
|
@ -717,13 +709,13 @@ void OpEmitter::genAttrGetters() {
|
|||
|
||||
// Emit the derived attribute body.
|
||||
auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
|
||||
if (auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name))
|
||||
if (auto *method = opClass.addMethod(attr.getReturnType(), name))
|
||||
method->body() << " " << attr.getDerivedCodeBody() << "\n";
|
||||
};
|
||||
|
||||
// Emit with return type specified.
|
||||
auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
|
||||
auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
|
||||
auto *method = opClass.addMethod(attr.getReturnType(), name);
|
||||
ERROR_IF_PRUNED(method, name, op);
|
||||
auto &body = method->body();
|
||||
body << " auto attr = " << name << "Attr();\n";
|
||||
|
@ -748,7 +740,7 @@ void OpEmitter::genAttrGetters() {
|
|||
// use the string interface for better compile time verification.
|
||||
auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
|
||||
auto *method =
|
||||
opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
|
||||
opClass.addMethod(attr.getStorageType(), (name + "Attr").str());
|
||||
if (!method)
|
||||
return;
|
||||
method->body() << formatv(
|
||||
|
@ -773,68 +765,69 @@ void OpEmitter::genAttrGetters() {
|
|||
[](const NamedAttribute &namedAttr) {
|
||||
return namedAttr.attr.isDerivedAttr();
|
||||
});
|
||||
if (!derivedAttrs.empty()) {
|
||||
opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
|
||||
// Generate helper method to query whether a named attribute is a derived
|
||||
// attribute. This enables, for example, avoiding adding an attribute that
|
||||
// overlaps with a derived attribute.
|
||||
{
|
||||
auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
|
||||
OpMethod::MP_Static,
|
||||
"::llvm::StringRef", "name");
|
||||
ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
|
||||
auto &body = method->body();
|
||||
for (auto namedAttr : derivedAttrs)
|
||||
body << " if (name == \"" << namedAttr.name << "\") return true;\n";
|
||||
body << " return false;";
|
||||
}
|
||||
// Generate method to materialize derived attributes as a DictionaryAttr.
|
||||
{
|
||||
auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
|
||||
"materializeDerivedAttributes");
|
||||
ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
|
||||
auto &body = method->body();
|
||||
if (derivedAttrs.empty())
|
||||
return;
|
||||
|
||||
auto nonMaterializable =
|
||||
make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
|
||||
return namedAttr.attr.getConvertFromStorageCall().empty();
|
||||
});
|
||||
if (!nonMaterializable.empty()) {
|
||||
std::string attrs;
|
||||
llvm::raw_string_ostream os(attrs);
|
||||
interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
|
||||
os << op.getGetterName(attr.name);
|
||||
opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
|
||||
// Generate helper method to query whether a named attribute is a derived
|
||||
// attribute. This enables, for example, avoiding adding an attribute that
|
||||
// overlaps with a derived attribute.
|
||||
{
|
||||
auto *method =
|
||||
opClass.addStaticMethod("bool", "isDerivedAttribute",
|
||||
MethodParameter("::llvm::StringRef", "name"));
|
||||
ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
|
||||
auto &body = method->body();
|
||||
for (auto namedAttr : derivedAttrs)
|
||||
body << " if (name == \"" << namedAttr.name << "\") return true;\n";
|
||||
body << " return false;";
|
||||
}
|
||||
// Generate method to materialize derived attributes as a DictionaryAttr.
|
||||
{
|
||||
auto *method = opClass.addMethod("::mlir::DictionaryAttr",
|
||||
"materializeDerivedAttributes");
|
||||
ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
|
||||
auto &body = method->body();
|
||||
|
||||
auto nonMaterializable =
|
||||
make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
|
||||
return namedAttr.attr.getConvertFromStorageCall().empty();
|
||||
});
|
||||
PrintWarning(
|
||||
op.getLoc(),
|
||||
formatv(
|
||||
"op has non-materializable derived attributes '{0}', skipping",
|
||||
os.str()));
|
||||
body << formatv(" emitOpError(\"op has non-materializable derived "
|
||||
"attributes '{0}'\");\n",
|
||||
attrs);
|
||||
body << " return nullptr;";
|
||||
return;
|
||||
}
|
||||
|
||||
body << " ::mlir::MLIRContext* ctx = getContext();\n";
|
||||
body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
|
||||
body << " return ::mlir::DictionaryAttr::get(";
|
||||
body << " ctx, {\n";
|
||||
interleave(
|
||||
derivedAttrs, body,
|
||||
[&](const NamedAttribute &namedAttr) {
|
||||
auto tmpl = namedAttr.attr.getConvertFromStorageCall();
|
||||
std::string name = op.getGetterName(namedAttr.name);
|
||||
body << " {" << name << "AttrName(),\n"
|
||||
<< tgfmt(tmpl, &fctx.withSelf(name + "()")
|
||||
.withBuilder("odsBuilder")
|
||||
.addSubst("_ctx", "ctx"))
|
||||
<< "}";
|
||||
},
|
||||
",\n");
|
||||
body << "});";
|
||||
if (!nonMaterializable.empty()) {
|
||||
std::string attrs;
|
||||
llvm::raw_string_ostream os(attrs);
|
||||
interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
|
||||
os << op.getGetterName(attr.name);
|
||||
});
|
||||
PrintWarning(
|
||||
op.getLoc(),
|
||||
formatv(
|
||||
"op has non-materializable derived attributes '{0}', skipping",
|
||||
os.str()));
|
||||
body << formatv(" emitOpError(\"op has non-materializable derived "
|
||||
"attributes '{0}'\");\n",
|
||||
attrs);
|
||||
body << " return nullptr;";
|
||||
return;
|
||||
}
|
||||
|
||||
body << " ::mlir::MLIRContext* ctx = getContext();\n";
|
||||
body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
|
||||
body << " return ::mlir::DictionaryAttr::get(";
|
||||
body << " ctx, {\n";
|
||||
interleave(
|
||||
derivedAttrs, body,
|
||||
[&](const NamedAttribute &namedAttr) {
|
||||
auto tmpl = namedAttr.attr.getConvertFromStorageCall();
|
||||
std::string name = op.getGetterName(namedAttr.name);
|
||||
body << " {" << name << "AttrName(),\n"
|
||||
<< tgfmt(tmpl, &fctx.withSelf(name + "()")
|
||||
.withBuilder("odsBuilder")
|
||||
.addSubst("_ctx", "ctx"))
|
||||
<< "}";
|
||||
},
|
||||
",\n");
|
||||
body << "});";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -844,19 +837,21 @@ void OpEmitter::genAttrSetters() {
|
|||
// for better compile time verification.
|
||||
auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
|
||||
Attribute attr) {
|
||||
auto *method = opClass.addMethodAndPrune(
|
||||
"void", (setterName + "Attr").str(), attr.getStorageType(), "attr");
|
||||
auto *method =
|
||||
opClass.addMethod("void", (setterName + "Attr").str(),
|
||||
MethodParameter(attr.getStorageType(), "attr"));
|
||||
if (method)
|
||||
method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);",
|
||||
getterName);
|
||||
};
|
||||
|
||||
for (const NamedAttribute &namedAttr : op.getAttributes()) {
|
||||
if (!namedAttr.attr.isDerivedAttr())
|
||||
for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
|
||||
op.getGetterNames(namedAttr.name)))
|
||||
emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
|
||||
namedAttr.attr);
|
||||
if (namedAttr.attr.isDerivedAttr())
|
||||
continue;
|
||||
for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
|
||||
op.getGetterNames(namedAttr.name)))
|
||||
emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
|
||||
namedAttr.attr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -866,7 +861,7 @@ void OpEmitter::genOptionalAttrRemovers() {
|
|||
auto emitRemoveAttr = [&](StringRef name) {
|
||||
auto upperInitial = name.take_front().upper();
|
||||
auto suffix = name.drop_front();
|
||||
auto *method = opClass.addMethodAndPrune(
|
||||
auto *method = opClass.addMethod(
|
||||
"::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
|
||||
if (!method)
|
||||
return;
|
||||
|
@ -887,8 +882,8 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
|
|||
int numVariadic, int numNonVariadic,
|
||||
StringRef rangeSizeCall, bool hasAttrSegmentSize,
|
||||
StringRef sizeAttrInit, RangeT &&odsValues) {
|
||||
auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
|
||||
methodName, "unsigned", "index");
|
||||
auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
|
||||
MethodParameter("unsigned", "index"));
|
||||
if (!method)
|
||||
return;
|
||||
auto &body = method->body();
|
||||
|
@ -900,7 +895,7 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
|
|||
// Because the op can have arbitrarily interleaved variadic and non-variadic
|
||||
// operands, we need to embed a list in the "sink" getter method for
|
||||
// calculation at run-time.
|
||||
llvm::SmallVector<StringRef, 4> isVariadic;
|
||||
SmallVector<StringRef, 4> isVariadic;
|
||||
isVariadic.reserve(llvm::size(odsValues));
|
||||
for (auto &it : odsValues)
|
||||
isVariadic.push_back(it.isVariableLength() ? "true" : "false");
|
||||
|
@ -959,8 +954,8 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
|
|||
rangeSizeCall, attrSizedOperands, sizeAttrInit,
|
||||
const_cast<Operator &>(op).getOperands());
|
||||
|
||||
auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
|
||||
"index");
|
||||
auto *m = opClass.addMethod(rangeType, "getODSOperands",
|
||||
MethodParameter("unsigned", "index"));
|
||||
ERROR_IF_PRUNED(m, "getODSOperands", op);
|
||||
auto &body = m->body();
|
||||
body << formatv(valueRangeReturnCode, rangeBeginCall,
|
||||
|
@ -974,7 +969,7 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
|
|||
continue;
|
||||
for (StringRef name : op.getGetterNames(operand.name)) {
|
||||
if (operand.isOptional()) {
|
||||
m = opClass.addMethodAndPrune("::mlir::Value", name);
|
||||
m = opClass.addMethod("::mlir::Value", name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << " auto operands = getODSOperands(" << i << ");\n"
|
||||
<< " return operands.empty() ? ::mlir::Value() : "
|
||||
|
@ -983,24 +978,24 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
|
|||
std::string segmentAttr = op.getGetterName(
|
||||
operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
|
||||
if (isAdaptor) {
|
||||
m = opClass.addMethodAndPrune(
|
||||
"::llvm::SmallVector<::mlir::ValueRange>", name);
|
||||
m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>",
|
||||
name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
|
||||
segmentAttr, i);
|
||||
continue;
|
||||
}
|
||||
|
||||
m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", name);
|
||||
m = opClass.addMethod("::mlir::OperandRangeRange", name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << " return getODSOperands(" << i << ").split("
|
||||
<< segmentAttr << "Attr());";
|
||||
} else if (operand.isVariadic()) {
|
||||
m = opClass.addMethodAndPrune(rangeType, name);
|
||||
m = opClass.addMethod(rangeType, name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << " return getODSOperands(" << i << ");";
|
||||
} else {
|
||||
m = opClass.addMethodAndPrune("::mlir::Value", name);
|
||||
m = opClass.addMethod("::mlir::Value", name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << " return *getODSOperands(" << i << ").begin();";
|
||||
}
|
||||
|
@ -1035,10 +1030,10 @@ void OpEmitter::genNamedOperandSetters() {
|
|||
if (operand.name.empty())
|
||||
continue;
|
||||
for (StringRef name : op.getGetterNames(operand.name)) {
|
||||
auto *m = opClass.addMethodAndPrune(
|
||||
operand.isVariadicOfVariadic() ? "::mlir::MutableOperandRangeRange"
|
||||
: "::mlir::MutableOperandRange",
|
||||
(name + "Mutable").str());
|
||||
auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
|
||||
? "::mlir::MutableOperandRangeRange"
|
||||
: "::mlir::MutableOperandRange",
|
||||
(name + "Mutable").str());
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
auto &body = m->body();
|
||||
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
|
||||
|
@ -1110,8 +1105,9 @@ void OpEmitter::genNamedResultGetters() {
|
|||
numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
|
||||
attrSizeInitCode, op.getResults());
|
||||
|
||||
auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
|
||||
"getODSResults", "unsigned", "index");
|
||||
auto *m =
|
||||
opClass.addMethod("::mlir::Operation::result_range", "getODSResults",
|
||||
MethodParameter("unsigned", "index"));
|
||||
ERROR_IF_PRUNED(m, "getODSResults", op);
|
||||
m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
|
||||
"getODSResultIndexAndLength(index)");
|
||||
|
@ -1122,17 +1118,17 @@ void OpEmitter::genNamedResultGetters() {
|
|||
continue;
|
||||
for (StringRef name : op.getGetterNames(result.name)) {
|
||||
if (result.isOptional()) {
|
||||
m = opClass.addMethodAndPrune("::mlir::Value", name);
|
||||
m = opClass.addMethod("::mlir::Value", name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body()
|
||||
<< " auto results = getODSResults(" << i << ");\n"
|
||||
<< " return results.empty() ? ::mlir::Value() : *results.begin();";
|
||||
} else if (result.isVariadic()) {
|
||||
m = opClass.addMethodAndPrune("::mlir::Operation::result_range", name);
|
||||
m = opClass.addMethod("::mlir::Operation::result_range", name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << " return getODSResults(" << i << ");";
|
||||
} else {
|
||||
m = opClass.addMethodAndPrune("::mlir::Value", name);
|
||||
m = opClass.addMethod("::mlir::Value", name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << " return *getODSResults(" << i << ").begin();";
|
||||
}
|
||||
|
@ -1150,15 +1146,15 @@ void OpEmitter::genNamedRegionGetters() {
|
|||
for (StringRef name : op.getGetterNames(region.name)) {
|
||||
// Generate the accessors for a variadic region.
|
||||
if (region.isVariadic()) {
|
||||
auto *m = opClass.addMethodAndPrune(
|
||||
"::mlir::MutableArrayRef<::mlir::Region>", name);
|
||||
auto *m =
|
||||
opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
|
||||
i);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto *m = opClass.addMethodAndPrune("::mlir::Region &", name);
|
||||
auto *m = opClass.addMethod("::mlir::Region &", name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << formatv(" return (*this)->getRegion({0});", i);
|
||||
}
|
||||
|
@ -1175,7 +1171,7 @@ void OpEmitter::genNamedSuccessorGetters() {
|
|||
for (StringRef name : op.getGetterNames(successor.name)) {
|
||||
// Generate the accessors for a variadic successor list.
|
||||
if (successor.isVariadic()) {
|
||||
auto *m = opClass.addMethodAndPrune("::mlir::SuccessorRange", name);
|
||||
auto *m = opClass.addMethod("::mlir::SuccessorRange", name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << formatv(
|
||||
" return {std::next((*this)->successor_begin(), {0}), "
|
||||
|
@ -1184,7 +1180,7 @@ void OpEmitter::genNamedSuccessorGetters() {
|
|||
continue;
|
||||
}
|
||||
|
||||
auto *m = opClass.addMethodAndPrune("::mlir::Block *", name);
|
||||
auto *m = opClass.addMethod("::mlir::Block *", name);
|
||||
ERROR_IF_PRUNED(m, name, op);
|
||||
m->body() << formatv(" return (*this)->getSuccessor({0});", i);
|
||||
}
|
||||
|
@ -1227,14 +1223,13 @@ void OpEmitter::genSeparateArgParamBuilder() {
|
|||
// inferring result type.
|
||||
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
|
||||
bool inferType) {
|
||||
llvm::SmallVector<OpMethodParameter, 4> paramList;
|
||||
llvm::SmallVector<std::string, 4> resultNames;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
SmallVector<std::string, 4> resultNames;
|
||||
llvm::StringSet<> inferredAttributes;
|
||||
buildParamList(paramList, inferredAttributes, resultNames, paramKind,
|
||||
attrType);
|
||||
|
||||
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
||||
std::move(paramList));
|
||||
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
|
||||
// If the builder is redundant, skip generating the method.
|
||||
if (!m)
|
||||
return;
|
||||
|
@ -1308,7 +1303,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
|
|||
int numResults = op.getNumResults();
|
||||
|
||||
// Signature
|
||||
llvm::SmallVector<OpMethodParameter, 4> paramList;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
||||
paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
||||
paramList.emplace_back("::mlir::ValueRange", "operands");
|
||||
|
@ -1319,8 +1314,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
|
|||
if (op.getNumVariadicRegions())
|
||||
paramList.emplace_back("unsigned", "numRegions");
|
||||
|
||||
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
||||
std::move(paramList));
|
||||
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
|
||||
// If the builder is redundant, skip generating the method
|
||||
if (!m)
|
||||
return;
|
||||
|
@ -1348,14 +1342,13 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
|
|||
|
||||
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
|
||||
// TODO: Expand to support regions.
|
||||
SmallVector<OpMethodParameter, 4> paramList;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
||||
paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
||||
paramList.emplace_back("::mlir::ValueRange", "operands");
|
||||
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
|
||||
"attributes", "{}");
|
||||
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
||||
std::move(paramList));
|
||||
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
|
||||
// If the builder is redundant, skip generating the method
|
||||
if (!m)
|
||||
return;
|
||||
|
@ -1407,14 +1400,13 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
|
|||
}
|
||||
|
||||
void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
|
||||
llvm::SmallVector<OpMethodParameter, 4> paramList;
|
||||
llvm::SmallVector<std::string, 4> resultNames;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
SmallVector<std::string, 4> resultNames;
|
||||
llvm::StringSet<> inferredAttributes;
|
||||
buildParamList(paramList, inferredAttributes, resultNames,
|
||||
TypeParamKind::None);
|
||||
|
||||
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
||||
std::move(paramList));
|
||||
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
|
||||
// If the builder is redundant, skip generating the method
|
||||
if (!m)
|
||||
return;
|
||||
|
@ -1436,14 +1428,13 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
|
|||
}
|
||||
|
||||
void OpEmitter::genUseAttrAsResultTypeBuilder() {
|
||||
SmallVector<OpMethodParameter, 4> paramList;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
|
||||
paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
||||
paramList.emplace_back("::mlir::ValueRange", "operands");
|
||||
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
|
||||
"attributes", "{}");
|
||||
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
||||
std::move(paramList));
|
||||
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
|
||||
// If the builder is redundant, skip generating the method
|
||||
if (!m)
|
||||
return;
|
||||
|
@ -1480,16 +1471,15 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
|
|||
|
||||
/// Returns a signature of the builder. Updates the context `fctx` to enable
|
||||
/// replacement of $_builder and $_state in the body.
|
||||
static std::string getBuilderSignature(const Builder &builder) {
|
||||
static SmallVector<MethodParameter>
|
||||
getBuilderSignature(const Builder &builder) {
|
||||
ArrayRef<Builder::Parameter> params(builder.getParameters());
|
||||
|
||||
// Inject builder and state arguments.
|
||||
llvm::SmallVector<std::string, 8> arguments;
|
||||
SmallVector<MethodParameter> arguments;
|
||||
arguments.reserve(params.size() + 2);
|
||||
arguments.push_back(
|
||||
llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str());
|
||||
arguments.push_back(
|
||||
llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
|
||||
arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
|
||||
arguments.emplace_back("::mlir::OperationState &", builderOpState);
|
||||
|
||||
for (unsigned i = 0, e = params.size(); i < e; ++i) {
|
||||
// If no name is provided, generate one.
|
||||
|
@ -1497,27 +1487,27 @@ static std::string getBuilderSignature(const Builder &builder) {
|
|||
std::string name =
|
||||
paramName ? paramName->str() : "odsArg" + std::to_string(i);
|
||||
|
||||
std::string defaultValue;
|
||||
StringRef defaultValue;
|
||||
if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
|
||||
defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str();
|
||||
arguments.push_back(
|
||||
llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue)
|
||||
.str());
|
||||
defaultValue = *defaultParamValue;
|
||||
|
||||
arguments.emplace_back(params[i].getCppType(), std::move(name),
|
||||
defaultValue);
|
||||
}
|
||||
|
||||
return llvm::join(arguments, ", ");
|
||||
return arguments;
|
||||
}
|
||||
|
||||
void OpEmitter::genBuilder() {
|
||||
// Handle custom builders if provided.
|
||||
for (const Builder &builder : op.getBuilders()) {
|
||||
std::string paramStr = getBuilderSignature(builder);
|
||||
SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
|
||||
|
||||
Optional<StringRef> body = builder.getBody();
|
||||
OpMethod::Property properties =
|
||||
body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
|
||||
Method::Property properties =
|
||||
body ? Method::MP_Static : Method::MP_StaticDeclaration;
|
||||
auto *method =
|
||||
opClass.addMethodAndPrune("void", "build", properties, paramStr);
|
||||
opClass.addMethod("void", "build", properties, std::move(arguments));
|
||||
if (body)
|
||||
ERROR_IF_PRUNED(method, "build", op);
|
||||
|
||||
|
@ -1561,7 +1551,7 @@ void OpEmitter::genCollectiveParamBuilder() {
|
|||
int numVariadicOperands = op.getNumVariableLengthOperands();
|
||||
int numNonVariadicOperands = numOperands - numVariadicOperands;
|
||||
|
||||
SmallVector<OpMethodParameter, 4> paramList;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
paramList.emplace_back("::mlir::OpBuilder &", "");
|
||||
paramList.emplace_back("::mlir::OperationState &", builderOpState);
|
||||
paramList.emplace_back("::mlir::TypeRange", "resultTypes");
|
||||
|
@ -1573,8 +1563,7 @@ void OpEmitter::genCollectiveParamBuilder() {
|
|||
if (op.getNumVariadicRegions())
|
||||
paramList.emplace_back("unsigned", "numRegions");
|
||||
|
||||
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
|
||||
std::move(paramList));
|
||||
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
|
||||
// If the builder is redundant, skip generating the method
|
||||
if (!m)
|
||||
return;
|
||||
|
@ -1612,7 +1601,7 @@ void OpEmitter::genCollectiveParamBuilder() {
|
|||
genInferredTypeCollectiveParamBuilder();
|
||||
}
|
||||
|
||||
void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
||||
void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList,
|
||||
llvm::StringSet<> &inferredAttributes,
|
||||
SmallVectorImpl<std::string> &resultTypeNames,
|
||||
TypeParamKind typeParamKind,
|
||||
|
@ -1637,11 +1626,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
|
||||
StringRef type =
|
||||
result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
|
||||
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
||||
if (result.isOptional())
|
||||
properties = OpMethodParameter::PP_Optional;
|
||||
|
||||
paramList.emplace_back(type, resultName, properties);
|
||||
paramList.emplace_back(type, resultName, result.isOptional());
|
||||
resultTypeNames.emplace_back(std::move(resultName));
|
||||
}
|
||||
} break;
|
||||
|
@ -1699,11 +1685,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
else
|
||||
type = "::mlir::Value";
|
||||
|
||||
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
||||
if (operand->isOptional())
|
||||
properties = OpMethodParameter::PP_Optional;
|
||||
paramList.emplace_back(type, getArgumentName(op, numOperands++),
|
||||
properties);
|
||||
operand->isOptional());
|
||||
continue;
|
||||
}
|
||||
const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
|
||||
|
@ -1713,10 +1696,6 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
if (inferredAttributes.contains(namedAttr.name))
|
||||
continue;
|
||||
|
||||
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
|
||||
if (attr.isOptional())
|
||||
properties = OpMethodParameter::PP_Optional;
|
||||
|
||||
StringRef type;
|
||||
switch (attrParamKind) {
|
||||
case AttrParamKind::WrappedAttr:
|
||||
|
@ -1736,7 +1715,8 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
i >= defaultValuedAttrStartIndex) {
|
||||
defaultValue += attr.getDefaultValue();
|
||||
}
|
||||
paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
|
||||
paramList.emplace_back(type, namedAttr.name, defaultValue,
|
||||
attr.isOptional());
|
||||
}
|
||||
|
||||
/// Insert parameters for each successor.
|
||||
|
@ -1754,7 +1734,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
|
|||
}
|
||||
|
||||
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
|
||||
OpMethodBody &body, llvm::StringSet<> &inferredAttributes,
|
||||
MethodBody &body, llvm::StringSet<> &inferredAttributes,
|
||||
bool isRawValueAttr) {
|
||||
// Push all operands to the result.
|
||||
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||
|
@ -1871,12 +1851,11 @@ void OpEmitter::genCanonicalizerDecls() {
|
|||
if (hasCanonicalizeMethod) {
|
||||
// static LogicResult FooOp::
|
||||
// canonicalize(FooOp op, PatternRewriter &rewriter);
|
||||
SmallVector<OpMethodParameter, 2> paramList;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
paramList.emplace_back(op.getCppClassName(), "op");
|
||||
paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
|
||||
auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize",
|
||||
OpMethod::MP_StaticDeclaration,
|
||||
std::move(paramList));
|
||||
auto *m = opClass.declareStaticMethod("::mlir::LogicalResult",
|
||||
"canonicalize", std::move(paramList));
|
||||
ERROR_IF_PRUNED(m, "canonicalize", op);
|
||||
}
|
||||
|
||||
|
@ -1892,12 +1871,12 @@ void OpEmitter::genCanonicalizerDecls() {
|
|||
|
||||
// Add a signature for getCanonicalizationPatterns if implemented by the
|
||||
// dialect or if synthesized to call 'canonicalize'.
|
||||
SmallVector<OpMethodParameter, 2> paramList;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
paramList.emplace_back("::mlir::RewritePatternSet &", "results");
|
||||
paramList.emplace_back("::mlir::MLIRContext *", "context");
|
||||
auto kind = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
|
||||
auto *method = opClass.addMethodAndPrune(
|
||||
"void", "getCanonicalizationPatterns", kind, std::move(paramList));
|
||||
auto kind = hasBody ? Method::MP_Static : Method::MP_StaticDeclaration;
|
||||
auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind,
|
||||
std::move(paramList));
|
||||
|
||||
// If synthesizing the method, fill it it.
|
||||
if (hasBody) {
|
||||
|
@ -1912,18 +1891,17 @@ void OpEmitter::genFolderDecls() {
|
|||
|
||||
if (def.getValueAsBit("hasFolder")) {
|
||||
if (hasSingleResult) {
|
||||
auto *m = opClass.addMethodAndPrune(
|
||||
"::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
|
||||
"::llvm::ArrayRef<::mlir::Attribute>", "operands");
|
||||
auto *m = opClass.declareMethod(
|
||||
"::mlir::OpFoldResult", "fold",
|
||||
MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands"));
|
||||
ERROR_IF_PRUNED(m, "operands", op);
|
||||
} else {
|
||||
SmallVector<OpMethodParameter, 2> paramList;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
|
||||
paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
|
||||
"results");
|
||||
auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
|
||||
OpMethod::MP_Declaration,
|
||||
std::move(paramList));
|
||||
auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold",
|
||||
std::move(paramList));
|
||||
ERROR_IF_PRUNED(m, "fold", op);
|
||||
}
|
||||
}
|
||||
|
@ -1953,18 +1931,18 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
|
|||
}
|
||||
}
|
||||
|
||||
OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
|
||||
bool declaration) {
|
||||
SmallVector<OpMethodParameter, 4> paramList;
|
||||
Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
|
||||
bool declaration) {
|
||||
SmallVector<MethodParameter> paramList;
|
||||
for (const InterfaceMethod::Argument &arg : method.getArguments())
|
||||
paramList.emplace_back(arg.type, arg.name);
|
||||
|
||||
auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
|
||||
auto properties = method.isStatic() ? Method::MP_Static : Method::MP_None;
|
||||
if (declaration)
|
||||
properties =
|
||||
static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
|
||||
return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
|
||||
properties, std::move(paramList));
|
||||
static_cast<Method::Property>(properties | Method::MP_Declaration);
|
||||
return opClass.addMethod(method.getReturnType(), method.getName(), properties,
|
||||
std::move(paramList));
|
||||
}
|
||||
|
||||
void OpEmitter::genOpInterfaceMethods() {
|
||||
|
@ -2039,8 +2017,8 @@ void OpEmitter::genSideEffectInterfaceMethods() {
|
|||
"SideEffects::EffectInstance<{0}>> &",
|
||||
it.first())
|
||||
.str();
|
||||
auto *getEffects =
|
||||
opClass.addMethodAndPrune("void", "getEffects", type, "effects");
|
||||
auto *getEffects = opClass.addMethod("void", "getEffects",
|
||||
MethodParameter(type, "effects"));
|
||||
ERROR_IF_PRUNED(getEffects, "getEffects", op);
|
||||
auto &body = getEffects->body();
|
||||
|
||||
|
@ -2082,7 +2060,7 @@ void OpEmitter::genTypeInterfaceMethods() {
|
|||
const auto *trait = dyn_cast<InterfaceTrait>(
|
||||
op.getTrait("::mlir::InferTypeOpInterface::Trait"));
|
||||
Interface interface = trait->getInterface();
|
||||
OpMethod *method = [&]() -> OpMethod * {
|
||||
Method *method = [&]() -> Method * {
|
||||
for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
|
||||
if (interfaceMethod.getName() == "inferReturnTypes") {
|
||||
return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
|
||||
|
@ -2099,8 +2077,7 @@ void OpEmitter::genTypeInterfaceMethods() {
|
|||
fctx.withBuilder("odsBuilder");
|
||||
body << " ::mlir::Builder odsBuilder(context);\n";
|
||||
|
||||
auto emitType =
|
||||
[&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
|
||||
auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
|
||||
if (!type.isArg())
|
||||
return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
|
||||
auto argIndex = type.getArg();
|
||||
|
@ -2129,12 +2106,11 @@ void OpEmitter::genParser() {
|
|||
hasStringAttribute(def, "assemblyFormat"))
|
||||
return;
|
||||
|
||||
SmallVector<OpMethodParameter, 2> paramList;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
paramList.emplace_back("::mlir::OpAsmParser &", "parser");
|
||||
paramList.emplace_back("::mlir::OperationState &", "result");
|
||||
auto *method =
|
||||
opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
|
||||
OpMethod::MP_Static, std::move(paramList));
|
||||
auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
|
||||
std::move(paramList));
|
||||
ERROR_IF_PRUNED(method, "parse", op);
|
||||
|
||||
FmtContext fctx;
|
||||
|
@ -2152,8 +2128,8 @@ void OpEmitter::genPrinter() {
|
|||
if (!stringInit)
|
||||
return;
|
||||
|
||||
auto *method =
|
||||
opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
|
||||
auto *method = opClass.addMethod(
|
||||
"void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
|
||||
ERROR_IF_PRUNED(method, "print", op);
|
||||
FmtContext fctx;
|
||||
fctx.addSubst("cppClass", opClass.getClassName());
|
||||
|
@ -2162,7 +2138,7 @@ void OpEmitter::genPrinter() {
|
|||
}
|
||||
|
||||
/// Generate verification on native traits requiring attributes.
|
||||
static void genNativeTraitAttrVerifier(OpMethodBody &body,
|
||||
static void genNativeTraitAttrVerifier(MethodBody &body,
|
||||
const OpOrAdaptorHelper &emitHelper) {
|
||||
// Check that the variadic segment sizes attribute exists and contains the
|
||||
// expected number of elements.
|
||||
|
@ -2209,7 +2185,7 @@ static void genNativeTraitAttrVerifier(OpMethodBody &body,
|
|||
}
|
||||
|
||||
void OpEmitter::genVerifier() {
|
||||
auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
|
||||
auto *method = opClass.addMethod("::mlir::LogicalResult", "verify");
|
||||
ERROR_IF_PRUNED(method, "verify", op);
|
||||
auto &body = method->body();
|
||||
|
||||
|
@ -2247,7 +2223,7 @@ void OpEmitter::genVerifier() {
|
|||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
|
||||
void OpEmitter::genOperandResultVerifier(MethodBody &body,
|
||||
Operator::value_range values,
|
||||
StringRef valueKind) {
|
||||
// Check that an optional value is at most 1 element.
|
||||
|
@ -2321,7 +2297,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
|
|||
body << " }\n";
|
||||
}
|
||||
|
||||
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
|
||||
void OpEmitter::genRegionVerifier(MethodBody &body) {
|
||||
/// Code to verify a region.
|
||||
///
|
||||
/// {0}: Getter for the regions.
|
||||
|
@ -2363,7 +2339,7 @@ void OpEmitter::genRegionVerifier(OpMethodBody &body) {
|
|||
body << " }\n";
|
||||
}
|
||||
|
||||
void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
|
||||
void OpEmitter::genSuccessorVerifier(MethodBody &body) {
|
||||
const char *const verifySuccessor = R"(
|
||||
for (auto *successor : {0})
|
||||
if (::mlir::failed({1}(*this, successor, "{2}", index++)))
|
||||
|
@ -2485,9 +2461,8 @@ void OpEmitter::genTraits() {
|
|||
}
|
||||
|
||||
void OpEmitter::genOpNameGetter() {
|
||||
auto *method = opClass.addMethodAndPrune(
|
||||
"::llvm::StringLiteral", "getOperationName",
|
||||
OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
|
||||
auto *method = opClass.addStaticMethod<Method::MP_Constexpr>(
|
||||
"::llvm::StringLiteral", "getOperationName");
|
||||
ERROR_IF_PRUNED(method, "getOperationName", op);
|
||||
method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
|
||||
<< "\");";
|
||||
|
@ -2514,8 +2489,9 @@ void OpEmitter::genOpAsmInterface() {
|
|||
opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
|
||||
|
||||
// Generate the right accessor for the number of results.
|
||||
auto *method = opClass.addMethodAndPrune(
|
||||
"void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
|
||||
auto *method = opClass.addMethod(
|
||||
"void", "getAsmResultNames",
|
||||
MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
|
||||
ERROR_IF_PRUNED(method, "getAsmResultNames", op);
|
||||
auto &body = method->body();
|
||||
for (int i = 0; i != numResults; ++i) {
|
||||
|
@ -2567,7 +2543,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
|
|||
const auto *attrSizedOperands =
|
||||
op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
|
||||
{
|
||||
SmallVector<OpMethodParameter, 2> paramList;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
paramList.emplace_back("::mlir::ValueRange", "values");
|
||||
paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
|
||||
attrSizedOperands ? "" : "nullptr");
|
||||
|
@ -2581,14 +2557,14 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
|
|||
|
||||
{
|
||||
auto *constructor = adaptor.addConstructorAndPrune(
|
||||
llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
|
||||
MethodParameter(op.getCppClassName() + " &", "op"));
|
||||
constructor->addMemberInitializer("odsOperands", "op->getOperands()");
|
||||
constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
|
||||
constructor->addMemberInitializer("odsRegions", "op->getRegions()");
|
||||
}
|
||||
|
||||
{
|
||||
auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands");
|
||||
auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands");
|
||||
ERROR_IF_PRUNED(m, "getOperands", op);
|
||||
m->body() << " return odsOperands;";
|
||||
}
|
||||
|
@ -2605,7 +2581,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
|
|||
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
|
||||
|
||||
auto emitAttr = [&](StringRef name, StringRef emitName, Attribute attr) {
|
||||
auto *method = adaptor.addMethodAndPrune(attr.getStorageType(), emitName);
|
||||
auto *method = adaptor.addMethod(attr.getStorageType(), emitName);
|
||||
ERROR_IF_PRUNED(method, "Adaptor::" + emitName, op);
|
||||
auto &body = method->body();
|
||||
body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
|
||||
|
@ -2629,8 +2605,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
|
|||
};
|
||||
|
||||
{
|
||||
auto *m =
|
||||
adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes");
|
||||
auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes");
|
||||
ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
|
||||
m->body() << " return odsAttrs;";
|
||||
}
|
||||
|
@ -2645,7 +2620,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
|
|||
|
||||
unsigned numRegions = op.getNumRegions();
|
||||
if (numRegions > 0) {
|
||||
auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions");
|
||||
auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions");
|
||||
ERROR_IF_PRUNED(m, "Adaptor::getRegions", op);
|
||||
m->body() << " return odsRegions;";
|
||||
}
|
||||
|
@ -2657,13 +2632,13 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
|
|||
// Generate the accessors for a variadic region.
|
||||
for (StringRef name : op.getGetterNames(region.name)) {
|
||||
if (region.isVariadic()) {
|
||||
auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", name);
|
||||
auto *m = adaptor.addMethod("::mlir::RegionRange", name);
|
||||
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
|
||||
m->body() << formatv(" return odsRegions.drop_front({0});", i);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto *m = adaptor.addMethodAndPrune("::mlir::Region &", name);
|
||||
auto *m = adaptor.addMethod("::mlir::Region &", name);
|
||||
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
|
||||
m->body() << formatv(" return *odsRegions[{0}];", i);
|
||||
}
|
||||
|
@ -2674,8 +2649,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
|
|||
}
|
||||
|
||||
void OpOperandAdaptorEmitter::addVerification() {
|
||||
auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
|
||||
"::mlir::Location", "loc");
|
||||
auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify",
|
||||
MethodParameter("::mlir::Location", "loc"));
|
||||
ERROR_IF_PRUNED(method, "verify", op);
|
||||
auto &body = method->body();
|
||||
|
||||
|
|
|
@ -9,10 +9,10 @@
|
|||
#include "OpFormatGen.h"
|
||||
#include "FormatGen.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/TableGen/Class.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Interfaces.h"
|
||||
#include "mlir/TableGen/OpClass.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "mlir/TableGen/Trait.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
@ -140,8 +140,7 @@ using SuccessorVariable =
|
|||
|
||||
namespace {
|
||||
/// This class implements single kind directives.
|
||||
template <Element::Kind type>
|
||||
class DirectiveElement : public Element {
|
||||
template <Element::Kind type> class DirectiveElement : public Element {
|
||||
public:
|
||||
DirectiveElement() : Element(type){};
|
||||
static bool classof(const Element *ele) { return ele->getKind() == type; }
|
||||
|
@ -422,23 +421,23 @@ struct OperationFormat {
|
|||
/// Generate the operation parser from this format.
|
||||
void genParser(Operator &op, OpClass &opClass);
|
||||
/// Generate the parser code for a specific format element.
|
||||
void genElementParser(Element *element, OpMethodBody &body,
|
||||
void genElementParser(Element *element, MethodBody &body,
|
||||
FmtContext &attrTypeCtx);
|
||||
/// Generate the c++ to resolve the types of operands and results during
|
||||
/// parsing.
|
||||
void genParserTypeResolution(Operator &op, OpMethodBody &body);
|
||||
void genParserTypeResolution(Operator &op, MethodBody &body);
|
||||
/// Generate the c++ to resolve regions during parsing.
|
||||
void genParserRegionResolution(Operator &op, OpMethodBody &body);
|
||||
void genParserRegionResolution(Operator &op, MethodBody &body);
|
||||
/// Generate the c++ to resolve successors during parsing.
|
||||
void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
|
||||
void genParserSuccessorResolution(Operator &op, MethodBody &body);
|
||||
/// Generate the c++ to handling variadic segment size traits.
|
||||
void genParserVariadicSegmentResolution(Operator &op, OpMethodBody &body);
|
||||
void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
|
||||
|
||||
/// Generate the operation printer from this format.
|
||||
void genPrinter(Operator &op, OpClass &opClass);
|
||||
|
||||
/// Generate the printer code for a specific format element.
|
||||
void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
|
||||
void genElementPrinter(Element *element, MethodBody &body, Operator &op,
|
||||
bool &shouldEmitSpace, bool &lastWasPunctuation);
|
||||
|
||||
/// The various elements in this format.
|
||||
|
@ -813,7 +812,7 @@ static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
|
|||
}
|
||||
|
||||
/// Generate the parser for a literal value.
|
||||
static void genLiteralParser(StringRef value, OpMethodBody &body) {
|
||||
static void genLiteralParser(StringRef value, MethodBody &body) {
|
||||
// Handle the case of a keyword/identifier.
|
||||
if (value.front() == '_' || isalpha(value.front())) {
|
||||
body << "Keyword(\"" << value << "\")";
|
||||
|
@ -839,7 +838,7 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
|
|||
|
||||
/// Generate the storage code required for parsing the given element.
|
||||
static void genElementParserStorage(Element *element, const Operator &op,
|
||||
OpMethodBody &body) {
|
||||
MethodBody &body) {
|
||||
if (auto *optional = dyn_cast<OptionalElement>(element)) {
|
||||
auto elements = optional->getThenElements();
|
||||
|
||||
|
@ -937,7 +936,7 @@ static void genElementParserStorage(Element *element, const Operator &op,
|
|||
}
|
||||
|
||||
/// Generate the parser for a parameter to a custom directive.
|
||||
static void genCustomParameterParser(Element ¶m, OpMethodBody &body) {
|
||||
static void genCustomParameterParser(Element ¶m, MethodBody &body) {
|
||||
if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
|
||||
body << attr->getVar()->name << "Attr";
|
||||
} else if (isa<AttrDictDirective>(¶m)) {
|
||||
|
@ -988,7 +987,7 @@ static void genCustomParameterParser(Element ¶m, OpMethodBody &body) {
|
|||
}
|
||||
|
||||
/// Generate the parser for a custom directive.
|
||||
static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
|
||||
static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
|
||||
body << " {\n";
|
||||
|
||||
// Preprocess the directive variables.
|
||||
|
@ -1098,7 +1097,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
|
|||
}
|
||||
|
||||
/// Generate the parser for a enum attribute.
|
||||
static void genEnumAttrParser(const NamedAttribute *var, OpMethodBody &body,
|
||||
static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
|
||||
FmtContext &attrTypeCtx) {
|
||||
Attribute baseAttr = var->attr.getBaseAttr();
|
||||
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
|
||||
|
@ -1141,13 +1140,12 @@ static void genEnumAttrParser(const NamedAttribute *var, OpMethodBody &body,
|
|||
}
|
||||
|
||||
void OperationFormat::genParser(Operator &op, OpClass &opClass) {
|
||||
llvm::SmallVector<OpMethodParameter, 4> paramList;
|
||||
SmallVector<MethodParameter> paramList;
|
||||
paramList.emplace_back("::mlir::OpAsmParser &", "parser");
|
||||
paramList.emplace_back("::mlir::OperationState &", "result");
|
||||
|
||||
auto *method =
|
||||
opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
|
||||
OpMethod::MP_Static, std::move(paramList));
|
||||
auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
|
||||
std::move(paramList));
|
||||
auto &body = method->body();
|
||||
|
||||
// Generate variables to store the operands and type within the format. This
|
||||
|
@ -1174,7 +1172,7 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
|
|||
body << " return ::mlir::success();\n";
|
||||
}
|
||||
|
||||
void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
|
||||
void OperationFormat::genElementParser(Element *element, MethodBody &body,
|
||||
FmtContext &attrTypeCtx) {
|
||||
/// Optional Group.
|
||||
if (auto *optional = dyn_cast<OptionalElement>(element)) {
|
||||
|
@ -1353,8 +1351,7 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
|
|||
}
|
||||
}
|
||||
|
||||
void OperationFormat::genParserTypeResolution(Operator &op,
|
||||
OpMethodBody &body) {
|
||||
void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
|
||||
// If any of type resolutions use transformed variables, make sure that the
|
||||
// types of those variables are resolved.
|
||||
SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
|
||||
|
@ -1528,7 +1525,7 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
}
|
||||
|
||||
void OperationFormat::genParserRegionResolution(Operator &op,
|
||||
OpMethodBody &body) {
|
||||
MethodBody &body) {
|
||||
// Check for the case where all regions were parsed.
|
||||
bool hasAllRegions = llvm::any_of(
|
||||
elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
|
||||
|
@ -1547,7 +1544,7 @@ void OperationFormat::genParserRegionResolution(Operator &op,
|
|||
}
|
||||
|
||||
void OperationFormat::genParserSuccessorResolution(Operator &op,
|
||||
OpMethodBody &body) {
|
||||
MethodBody &body) {
|
||||
// Check for the case where all successors were parsed.
|
||||
bool hasAllSuccessors = llvm::any_of(
|
||||
elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
|
||||
|
@ -1566,7 +1563,7 @@ void OperationFormat::genParserSuccessorResolution(Operator &op,
|
|||
}
|
||||
|
||||
void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
|
||||
OpMethodBody &body) {
|
||||
MethodBody &body) {
|
||||
if (!allOperands) {
|
||||
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
|
||||
body << " result.addAttribute(\"operand_segment_sizes\", "
|
||||
|
@ -1641,7 +1638,7 @@ const char *enumAttrBeginPrinterCode = R"(
|
|||
|
||||
/// Generate the printer for the 'attr-dict' directive.
|
||||
static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
|
||||
OpMethodBody &body, bool withKeyword) {
|
||||
MethodBody &body, bool withKeyword) {
|
||||
body << " _odsPrinter.printOptionalAttrDict"
|
||||
<< (withKeyword ? "WithKeyword" : "")
|
||||
<< "((*this)->getAttrs(), /*elidedAttrs=*/{";
|
||||
|
@ -1665,7 +1662,7 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
|
|||
/// Generate the printer for a literal value. `shouldEmitSpace` is true if a
|
||||
/// space should be emitted before this element. `lastWasPunctuation` is true if
|
||||
/// the previous element was a punctuation literal.
|
||||
static void genLiteralPrinter(StringRef value, OpMethodBody &body,
|
||||
static void genLiteralPrinter(StringRef value, MethodBody &body,
|
||||
bool &shouldEmitSpace, bool &lastWasPunctuation) {
|
||||
body << " _odsPrinter";
|
||||
|
||||
|
@ -1682,8 +1679,8 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
|
|||
|
||||
/// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
|
||||
/// are set to false.
|
||||
static void genSpacePrinter(bool value, OpMethodBody &body,
|
||||
bool &shouldEmitSpace, bool &lastWasPunctuation) {
|
||||
static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
|
||||
bool &lastWasPunctuation) {
|
||||
if (value) {
|
||||
body << " _odsPrinter << ' ';\n";
|
||||
lastWasPunctuation = false;
|
||||
|
@ -1696,7 +1693,7 @@ static void genSpacePrinter(bool value, OpMethodBody &body,
|
|||
/// Generate the printer for a custom directive parameter.
|
||||
static void genCustomDirectiveParameterPrinter(Element *element,
|
||||
const Operator &op,
|
||||
OpMethodBody &body) {
|
||||
MethodBody &body) {
|
||||
if (auto *attr = dyn_cast<AttributeVariable>(element)) {
|
||||
body << op.getGetterName(attr->getVar()->name) << "Attr()";
|
||||
|
||||
|
@ -1734,7 +1731,7 @@ static void genCustomDirectiveParameterPrinter(Element *element,
|
|||
|
||||
/// Generate the printer for a custom directive.
|
||||
static void genCustomDirectivePrinter(CustomDirective *customDir,
|
||||
const Operator &op, OpMethodBody &body) {
|
||||
const Operator &op, MethodBody &body) {
|
||||
body << " print" << customDir->getName() << "(_odsPrinter, *this";
|
||||
for (Element ¶m : customDir->getArguments()) {
|
||||
body << ", ";
|
||||
|
@ -1744,7 +1741,7 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
|
|||
}
|
||||
|
||||
/// Generate the printer for a region with the given variable name.
|
||||
static void genRegionPrinter(const Twine ®ionName, OpMethodBody &body,
|
||||
static void genRegionPrinter(const Twine ®ionName, MethodBody &body,
|
||||
bool hasImplicitTermTrait) {
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
|
||||
|
@ -1753,7 +1750,7 @@ static void genRegionPrinter(const Twine ®ionName, OpMethodBody &body,
|
|||
body << " _odsPrinter.printRegion(" << regionName << ");\n";
|
||||
}
|
||||
static void genVariadicRegionPrinter(const Twine ®ionListName,
|
||||
OpMethodBody &body,
|
||||
MethodBody &body,
|
||||
bool hasImplicitTermTrait) {
|
||||
body << " llvm::interleaveComma(" << regionListName
|
||||
<< ", _odsPrinter, [&](::mlir::Region ®ion) {\n ";
|
||||
|
@ -1762,8 +1759,8 @@ static void genVariadicRegionPrinter(const Twine ®ionListName,
|
|||
}
|
||||
|
||||
/// Generate the C++ for an operand to a (*-)type directive.
|
||||
static OpMethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
|
||||
OpMethodBody &body) {
|
||||
static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
|
||||
MethodBody &body) {
|
||||
if (isa<OperandsDirective>(arg))
|
||||
return body << "getOperation()->getOperandTypes()";
|
||||
if (isa<ResultsDirective>(arg))
|
||||
|
@ -1786,7 +1783,7 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
|
|||
|
||||
/// Generate the printer for an enum attribute.
|
||||
static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
|
||||
OpMethodBody &body) {
|
||||
MethodBody &body) {
|
||||
Attribute baseAttr = var->attr.getBaseAttr();
|
||||
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
|
||||
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
|
||||
|
@ -1864,7 +1861,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
|
|||
|
||||
/// Generate the check for the anchor of an optional group.
|
||||
static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
|
||||
OpMethodBody &body) {
|
||||
MethodBody &body) {
|
||||
TypeSwitch<Element *>(anchor)
|
||||
.Case<OperandVariable, ResultVariable>([&](auto *element) {
|
||||
const NamedTypeConstraint *var = element->getVar();
|
||||
|
@ -1892,7 +1889,7 @@ static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
|
|||
});
|
||||
}
|
||||
|
||||
void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
|
||||
void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
|
||||
Operator &op, bool &shouldEmitSpace,
|
||||
bool &lastWasPunctuation) {
|
||||
if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
|
||||
|
@ -2047,8 +2044,9 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
|
|||
}
|
||||
|
||||
void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
|
||||
auto *method = opClass.addMethodAndPrune("void", "print",
|
||||
"::mlir::OpAsmPrinter &_odsPrinter");
|
||||
auto *method = opClass.addMethod(
|
||||
"void", "print",
|
||||
MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
|
||||
auto &body = method->body();
|
||||
|
||||
// Flags for if we should emit a space, and if the last element was
|
||||
|
@ -2065,8 +2063,7 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
|
|||
|
||||
/// Function to find an element within the given range that has the same name as
|
||||
/// 'name'.
|
||||
template <typename RangeT>
|
||||
static auto findArg(RangeT &&range, StringRef name) {
|
||||
template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
|
||||
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
|
||||
return it != range.end() ? &*it : nullptr;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue