forked from OSchip/llvm-project
[mlir] Add basic support for attributes in ODS-generated Python bindings
In ODS, attributes of an operation can be provided as a part of the "arguments" field, together with operands. Such attributes are accepted by the op builder and have accessors generated. Implement similar functionality for ODS-generated op-specific Python bindings: the `__init__` method now accepts arguments together with operands, in the same order as in the ODS `arguments` field; the instance properties are introduced to OpView classes to access the attributes. This initial implementation accepts and returns instances of the corresponding attribute class, and not the underlying values since the mapping scheme of the value types between C++, C and Python is not yet clear. Default-valued attributes are not supported as that would require Python to be able to parse C++ literals. Since attributes in ODS are tightely related to the actual C++ type system, provide a separate Tablegen file with the mapping between ODS storage type for attributes (typically, the underlying C++ attribute class), and the corresponding class name. So far, this might look unnecessary since all names match exactly, but this is not necessarily the cases for non-standard, out-of-tree attributes, which may also be placed in non-default namespaces or Python modules. This also allows out-of-tree users to generate Python bindings without having to modify the bindings generator itself. Storage type was preferred over the Tablegen "def" of the attribute class because ODS essentially encodes attribute _constraints_ rather than classes, e.g. there may be many Tablegen "def"s in the ODS that correspond to the same attribute type with additional constraints The presence of the explicit mapping requires the change in the .td file structure: instead of just calling the bindings generator directly on the main ODS file of the dialect, it becomes necessary to create a new file that includes the main ODS file of the dialect and provides the mapping for attribute types. Arguably, this approach offers better separability of the Python bindings in the build system as the main dialect no longer needs to know that it is being processed by the bindings generator. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D91542
This commit is contained in:
parent
af0d607e72
commit
c5a6712f8c
|
@ -100,12 +100,6 @@ include_directories( ${MLIR_INCLUDE_DIR})
|
|||
# from another directory like tools
|
||||
add_subdirectory(tools/mlir-tblgen)
|
||||
|
||||
# Create an anchor target that will depend on dialect-specific op bindings.
|
||||
if (MLIR_BINDINGS_PYTHON_ENABLED)
|
||||
add_custom_target(MLIRBindingsPythonIncGen)
|
||||
include(AddMLIRPythonExtension)
|
||||
endif()
|
||||
|
||||
add_subdirectory(include/mlir)
|
||||
add_subdirectory(lib)
|
||||
# C API needs all dialects for registration, but should be built before tests.
|
||||
|
|
|
@ -132,16 +132,10 @@ function(add_mlir_python_extension libname extname)
|
|||
|
||||
endfunction()
|
||||
|
||||
function(add_mlir_dialect_python_bindings filename dialectname)
|
||||
function(add_mlir_dialect_python_bindings tblgen_target filename dialectname)
|
||||
set(LLVM_TARGET_DEFINITIONS ${filename})
|
||||
mlir_tablegen("${dialectname}.py" -gen-python-op-bindings
|
||||
-bind-dialect=${dialectname})
|
||||
if (${ARGC} GREATER 2)
|
||||
set(suffix ${ARGV2})
|
||||
else()
|
||||
get_filename_component(suffix ${filename} NAME_WE)
|
||||
endif()
|
||||
set(tblgen_target "MLIRBindingsPython${suffix}")
|
||||
add_public_tablegen_target(${tblgen_target})
|
||||
|
||||
add_custom_command(
|
||||
|
@ -150,6 +144,5 @@ function(add_mlir_dialect_python_bindings filename dialectname)
|
|||
COMMAND "${CMAKE_COMMAND}" -E copy_if_different
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/${dialectname}.py"
|
||||
"${PROJECT_BINARY_DIR}/python/mlir/dialects/${dialectname}.py")
|
||||
add_dependencies(MLIRBindingsPythonIncGen ${tblgen_target})
|
||||
endfunction()
|
||||
|
||||
|
|
|
@ -7,7 +7,3 @@ mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
|||
add_public_tablegen_target(MLIRStandardOpsIncGen)
|
||||
|
||||
add_mlir_doc(Ops -gen-op-doc StandardOps Dialects/)
|
||||
|
||||
if (MLIR_BINDINGS_PYTHON_ENABLED)
|
||||
add_mlir_dialect_python_bindings(Ops.td std StandardOps)
|
||||
endif()
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
//===-- Attributes.td - Attribute mapping for Python -------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This defines the mapping between MLIR ODS attributes and the corresponding
|
||||
// Python binding classes.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_BINDINGS_ATTRIBUTES
|
||||
#define PYTHON_BINDINGS_ATTRIBUTES
|
||||
|
||||
// A mapping between the attribute storage type and the corresponding Python
|
||||
// type. There is not necessarily a 1-1 match for non-standard attributes.
|
||||
class PythonAttr<string c, string p> {
|
||||
string cppStorageType = c;
|
||||
string pythonType = p;
|
||||
}
|
||||
|
||||
// Mappings between supported standard attribtues and Python types.
|
||||
def : PythonAttr<"::mlir::Attribute", "_ir.Attribute">;
|
||||
def : PythonAttr<"::mlir::BoolAttr", "_ir.BoolAttr">;
|
||||
def : PythonAttr<"::mlir::IntegerAttr", "_ir.IntegerAttr">;
|
||||
def : PythonAttr<"::mlir::FloatAttr", "_ir.FloatAttr">;
|
||||
def : PythonAttr<"::mlir::StringAttr", "_ir.StringAttr">;
|
||||
def : PythonAttr<"::mlir::DenseElementsAttr", "_ir.DenseElementsAttr">;
|
||||
def : PythonAttr<"::mlir::DenseIntElementsAttr", "_ir.DenseIntElementsAttr">;
|
||||
def : PythonAttr<"::mlir::DenseFPElementsAttr", "_ir.DenseFPElementsAttr">;
|
||||
|
||||
#endif
|
|
@ -1,5 +1,15 @@
|
|||
include(AddMLIRPythonExtension)
|
||||
add_custom_target(MLIRBindingsPythonExtension)
|
||||
|
||||
################################################################################
|
||||
# Generate dialect-specific bindings.
|
||||
################################################################################
|
||||
|
||||
add_mlir_dialect_python_bindings(MLIRBindingsPythonStandardOps
|
||||
StandardOps.td
|
||||
std)
|
||||
add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonStandardOps)
|
||||
|
||||
################################################################################
|
||||
# Copy python source tree.
|
||||
################################################################################
|
||||
|
@ -19,8 +29,6 @@ add_custom_target(MLIRBindingsPythonSources ALL
|
|||
)
|
||||
add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonSources)
|
||||
|
||||
add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonIncGen)
|
||||
|
||||
foreach(PY_SRC_FILE ${PY_SRC_FILES})
|
||||
set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}")
|
||||
add_custom_command(
|
||||
|
|
|
@ -1310,8 +1310,14 @@ public:
|
|||
return mlirOperationGetNumAttributes(operation->get());
|
||||
}
|
||||
|
||||
bool dunderContains(const std::string &name) {
|
||||
return !mlirAttributeIsNull(
|
||||
mlirOperationGetAttributeByName(operation->get(), name.c_str()));
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
|
||||
.def("__contains__", &PyOpAttributeMap::dunderContains)
|
||||
.def("__len__", &PyOpAttributeMap::dunderLen)
|
||||
.def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
|
||||
.def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed);
|
||||
|
@ -1747,6 +1753,24 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Unit Attribute subclass. Unit attributes don't have values.
|
||||
class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
|
||||
static constexpr const char *pyClassName = "UnitAttr";
|
||||
using PyConcreteAttribute::PyConcreteAttribute;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get",
|
||||
[](DefaultingPyMlirContext context) {
|
||||
return PyUnitAttribute(context->getRef(),
|
||||
mlirUnitAttrGet(context->get()));
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a Unit attribute.");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -2852,6 +2876,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
PyDenseElementsAttribute::bind(m);
|
||||
PyDenseIntElementsAttribute::bind(m);
|
||||
PyDenseFPElementsAttribute::bind(m);
|
||||
PyUnitAttribute::bind(m);
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyType.
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
//===-- StandardOps.td - Entry point for StandardOps bind --*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the main file from which the Python bindings for the Standard
|
||||
// dialect are generated.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_BINDINGS_STANDARD_OPS
|
||||
#define PYTHON_BINDINGS_STANDARD_OPS
|
||||
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "Attributes.td"
|
||||
|
||||
#endif
|
|
@ -147,7 +147,7 @@ Operator::arg_range Operator::getArgs() const {
|
|||
|
||||
StringRef Operator::getArgName(int index) const {
|
||||
DagInit *argumentValues = def.getValueAsDag("arguments");
|
||||
return argumentValues->getArgName(index)->getValue();
|
||||
return argumentValues->getArgNameStr(index);
|
||||
}
|
||||
|
||||
auto Operator::getArgDecorators(int index) const -> var_decorator_range {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include %s | FileCheck %s
|
||||
// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include -I %S/../../lib/Bindings/Python %s | FileCheck %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "Attributes.td"
|
||||
|
||||
// CHECK: @_cext.register_dialect
|
||||
// CHECK: class _Dialect(_ir.Dialect):
|
||||
|
@ -105,6 +106,75 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
|
|||
Optional<AnyType>:$variadic2);
|
||||
}
|
||||
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class AttributedOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op"
|
||||
def AttributedOp : TestOp<"attributed_op"> {
|
||||
// CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: attributes["i32attr"] = i32attr
|
||||
// CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
|
||||
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ir.UnitAttr.get(
|
||||
// CHECK: _ir.Location.current.context if loc is None else loc.context)
|
||||
// CHECK: attributes["in"] = in_
|
||||
// CHECK: super().__init__(_ir.Operation.create(
|
||||
// CHECK: "test.attributed_op", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
// CHECK: def i32attr(self):
|
||||
// CHECK: return _ir.IntegerAttr(self.operation.attributes["i32attr"])
|
||||
|
||||
// CHECK: @property
|
||||
// CHECK: def optionalF32Attr(self):
|
||||
// CHECK: if "optionalF32Attr" not in self.operation.attributes:
|
||||
// CHECK: return None
|
||||
// CHECK: return _ir.FloatAttr(self.operation.attributes["optionalF32Attr"])
|
||||
|
||||
// CHECK: @property
|
||||
// CHECK: def unitAttr(self):
|
||||
// CHECK: return "unitAttr" in self.operation.attributes
|
||||
|
||||
// CHECK: @property
|
||||
// CHECK: def in_(self):
|
||||
// CHECK: return _ir.IntegerAttr(self.operation.attributes["in"])
|
||||
let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr,
|
||||
UnitAttr:$unitAttr, I32Attr:$in);
|
||||
}
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class AttributedOpWithOperands(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands"
|
||||
def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
|
||||
// CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: operands.append(_gen_arg_0)
|
||||
// CHECK: operands.append(_gen_arg_2)
|
||||
// CHECK: if bool(in_): attributes["in"] = _ir.UnitAttr.get(
|
||||
// CHECK: _ir.Location.current.context if loc is None else loc.context)
|
||||
// CHECK: if is_ is not None: attributes["is"] = is_
|
||||
// CHECK: super().__init__(_ir.Operation.create(
|
||||
// CHECK: "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results,
|
||||
// CHECK: loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @property
|
||||
// CHECK: def in_(self):
|
||||
// CHECK: return "in" in self.operation.attributes
|
||||
|
||||
// CHECK: @property
|
||||
// CHECK: def is_(self):
|
||||
// CHECK: if "is" not in self.operation.attributes:
|
||||
// CHECK: return None
|
||||
// CHECK: return _ir.FloatAttr(self.operation.attributes["is"])
|
||||
let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
|
||||
}
|
||||
|
||||
|
||||
// CHECK: @_cext.register_operation(_Dialect)
|
||||
// CHECK: class EmptyOp(_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.empty"
|
||||
|
|
|
@ -145,6 +145,39 @@ constexpr const char *opVariadicSegmentTemplate = R"Py(
|
|||
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
|
||||
R"Py([0] if len({0}_range) > 0 else None)Py";
|
||||
|
||||
/// Template for an operation attribute getter:
|
||||
/// {0} is the name of the attribute sanitized for Python;
|
||||
/// {1} is the Python type of the attribute;
|
||||
/// {2} os the original name of the attribute.
|
||||
constexpr const char *attributeGetterTemplate = R"Py(
|
||||
@property
|
||||
def {0}(self):
|
||||
return {1}(self.operation.attributes["{2}"])
|
||||
)Py";
|
||||
|
||||
/// Template for an optional operation attribute getter:
|
||||
/// {0} is the name of the attribute sanitized for Python;
|
||||
/// {1} is the Python type of the attribute;
|
||||
/// {2} is the original name of the attribute.
|
||||
constexpr const char *optionalAttributeGetterTemplate = R"Py(
|
||||
@property
|
||||
def {0}(self):
|
||||
if "{2}" not in self.operation.attributes:
|
||||
return None
|
||||
return {1}(self.operation.attributes["{2}"])
|
||||
)Py";
|
||||
|
||||
/// Template for a accessing a unit operation attribute, returns True of the
|
||||
/// unit attribute is present, False otherwise (unit attributes have meaning
|
||||
/// by mere presence):
|
||||
/// {0} is the name of the attribute sanitized for Python,
|
||||
/// {1} is the original name of the attribute.
|
||||
constexpr const char *unitAttributeGetterTemplate = R"Py(
|
||||
@property
|
||||
def {0}(self):
|
||||
return "{1}" in self.operation.attributes
|
||||
)Py";
|
||||
|
||||
static llvm::cl::OptionCategory
|
||||
clOpPythonBindingCat("Options for -gen-python-op-bindings");
|
||||
|
||||
|
@ -153,6 +186,8 @@ static llvm::cl::opt<std::string>
|
|||
llvm::cl::desc("The dialect to run the generator for"),
|
||||
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
|
||||
|
||||
using AttributeClasses = DenseMap<StringRef, StringRef>;
|
||||
|
||||
/// Checks whether `str` is a Python keyword.
|
||||
static bool isPythonKeyword(StringRef str) {
|
||||
static llvm::StringSet<> keywords(
|
||||
|
@ -285,7 +320,7 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) {
|
|||
return op.getResult(i);
|
||||
}
|
||||
|
||||
/// Emits accessor to Op operands.
|
||||
/// Emits accessors to Op operands.
|
||||
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
|
||||
auto getNumVariadic = [](const Operator &oper) {
|
||||
return oper.getNumVariableLengthOperands();
|
||||
|
@ -294,7 +329,7 @@ static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
|
|||
getOperand);
|
||||
}
|
||||
|
||||
/// Emits access or Op results.
|
||||
/// Emits accessors Op results.
|
||||
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
|
||||
auto getNumVariadic = [](const Operator &oper) {
|
||||
return oper.getNumVariableLengthResults();
|
||||
|
@ -303,6 +338,39 @@ static void emitResultAccessors(const Operator &op, raw_ostream &os) {
|
|||
getResult);
|
||||
}
|
||||
|
||||
/// Emits accessors to Op attributes.
|
||||
static void emitAttributeAccessors(const Operator &op,
|
||||
const AttributeClasses &attributeClasses,
|
||||
raw_ostream &os) {
|
||||
for (const auto &namedAttr : op.getAttributes()) {
|
||||
// Skip "derived" attributes because they are just C++ functions that we
|
||||
// don't currently expose.
|
||||
if (namedAttr.attr.isDerivedAttr())
|
||||
continue;
|
||||
|
||||
if (namedAttr.name.empty())
|
||||
continue;
|
||||
|
||||
// Unit attributes are handled specially.
|
||||
if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
|
||||
os << llvm::formatv(unitAttributeGetterTemplate,
|
||||
sanitizeName(namedAttr.name), namedAttr.name);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Other kinds of attributes need a mapping to a Python type.
|
||||
if (!attributeClasses.count(namedAttr.attr.getStorageType().trim()))
|
||||
continue;
|
||||
|
||||
os << llvm::formatv(
|
||||
namedAttr.attr.isOptional() ? optionalAttributeGetterTemplate
|
||||
: attributeGetterTemplate,
|
||||
sanitizeName(namedAttr.name),
|
||||
attributeClasses.lookup(namedAttr.attr.getStorageType()),
|
||||
namedAttr.name);
|
||||
}
|
||||
}
|
||||
|
||||
/// Template for the default auto-generated builder.
|
||||
/// {0} is the operation name;
|
||||
/// {1} is a comma-separated list of builder arguments, including the trailing
|
||||
|
@ -362,14 +430,82 @@ constexpr const char *optionalSegmentTemplate =
|
|||
constexpr const char *variadicSegmentTemplate =
|
||||
"{0}_segment_sizes.append(len({1}))";
|
||||
|
||||
/// Populates `builderArgs` with the list of `__init__` arguments that
|
||||
/// correspond to either operands or results of `op`, and `builderLines` with
|
||||
/// additional lines that are required in the builder. `kind` must be either
|
||||
/// "operand" or "result". `unnamedTemplate` is used to generate names for
|
||||
/// operands or results that don't have the name in ODS.
|
||||
/// Template for setting an attribute in the operation builder.
|
||||
/// {0} is the attribute name;
|
||||
/// {1} is the builder argument name.
|
||||
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
|
||||
|
||||
/// Template for setting an optional attribute in the operation builder.
|
||||
/// {0} is the attribute name;
|
||||
/// {1} is the builder argument name.
|
||||
constexpr const char *initOptionalAttributeTemplate =
|
||||
R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
|
||||
|
||||
constexpr const char *initUnitAttributeTemplate =
|
||||
R"Py(if bool({1}): attributes["{0}"] = _ir.UnitAttr.get(
|
||||
_ir.Location.current.context if loc is None else loc.context))Py";
|
||||
|
||||
/// Populates `builderArgs` with the Python-compatible names of builder function
|
||||
/// arguments, first the results, then the intermixed attributes and operands in
|
||||
/// the same order as they appear in the `arguments` field of the op definition.
|
||||
/// Additionally, `operandNames` is populated with names of operands in their
|
||||
/// order of appearance.
|
||||
static void
|
||||
populateBuilderArgs(const Operator &op,
|
||||
llvm::SmallVectorImpl<std::string> &builderArgs,
|
||||
llvm::SmallVectorImpl<std::string> &operandNames) {
|
||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
std::string name = op.getResultName(i).str();
|
||||
if (name.empty())
|
||||
name = llvm::formatv("_gen_res_{0}", i);
|
||||
name = sanitizeName(name);
|
||||
builderArgs.push_back(name);
|
||||
}
|
||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
std::string name = op.getArgName(i).str();
|
||||
if (name.empty())
|
||||
name = llvm::formatv("_gen_arg_{0}", i);
|
||||
name = sanitizeName(name);
|
||||
builderArgs.push_back(name);
|
||||
if (!op.getArg(i).is<NamedAttribute *>())
|
||||
operandNames.push_back(name);
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates `builderLines` with additional lines that are required in the
|
||||
/// builder to set up operation attributes. `argNames` is expected to contain
|
||||
/// the names of builder arguments that correspond to op arguments, i.e. to the
|
||||
/// operands and attributes in the same order as they appear in the `arguments`
|
||||
/// field.
|
||||
static void
|
||||
populateBuilderLinesAttr(const Operator &op,
|
||||
llvm::ArrayRef<std::string> argNames,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
Argument arg = op.getArg(i);
|
||||
auto *attribute = arg.dyn_cast<NamedAttribute *>();
|
||||
if (!attribute)
|
||||
continue;
|
||||
|
||||
// Unit attributes are handled specially.
|
||||
if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
|
||||
builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
|
||||
attribute->name, argNames[i]));
|
||||
continue;
|
||||
}
|
||||
|
||||
builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
|
||||
? initOptionalAttributeTemplate
|
||||
: initAttributeTemplate,
|
||||
attribute->name, argNames[i]));
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates `builderLines` with additional lines that are required in the
|
||||
/// builder. `kind` must be either "operand" or "result". `names` contains the
|
||||
/// names of init arguments that correspond to the elements.
|
||||
static void populateBuilderLines(
|
||||
const Operator &op, const char *kind, const char *unnamedTemplate,
|
||||
llvm::SmallVectorImpl<std::string> &builderArgs,
|
||||
const Operator &op, const char *kind, llvm::ArrayRef<std::string> names,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines,
|
||||
llvm::function_ref<int(const Operator &)> getNumElements,
|
||||
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
|
||||
|
@ -383,11 +519,7 @@ static void populateBuilderLines(
|
|||
// For each element, find or generate a name.
|
||||
for (int i = 0, e = getNumElements(op); i < e; ++i) {
|
||||
const NamedTypeConstraint &element = getElement(op, i);
|
||||
std::string name = element.name.str();
|
||||
if (name.empty())
|
||||
name = llvm::formatv(unnamedTemplate, i).str();
|
||||
name = sanitizeName(name);
|
||||
builderArgs.push_back(name);
|
||||
std::string name = names[i];
|
||||
|
||||
// Choose the formatting string based on the element kind.
|
||||
llvm::StringRef formatString, segmentFormatString;
|
||||
|
@ -417,21 +549,25 @@ static void populateBuilderLines(
|
|||
/// Emits a default builder constructing an operation from the list of its
|
||||
/// result types, followed by a list of its operands.
|
||||
static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
|
||||
// TODO: support attribute types.
|
||||
if (op.getNumNativeAttributes() != 0)
|
||||
return;
|
||||
|
||||
// If we are asked to skip default builders, comply.
|
||||
if (op.skipDefaultBuilders())
|
||||
return;
|
||||
|
||||
llvm::SmallVector<std::string, 8> builderArgs;
|
||||
llvm::SmallVector<std::string, 8> builderLines;
|
||||
builderArgs.reserve(op.getNumOperands() + op.getNumResults());
|
||||
populateBuilderLines(op, "result", "_gen_res_{0}", builderArgs, builderLines,
|
||||
getNumResults, getResult);
|
||||
populateBuilderLines(op, "operand", "_gen_arg_{0}", builderArgs, builderLines,
|
||||
llvm::SmallVector<std::string, 4> operandArgNames;
|
||||
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
|
||||
op.getNumNativeAttributes());
|
||||
populateBuilderArgs(op, builderArgs, operandArgNames);
|
||||
populateBuilderLines(
|
||||
op, "result",
|
||||
llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()),
|
||||
builderLines, getNumResults, getResult);
|
||||
populateBuilderLines(op, "operand", operandArgNames, builderLines,
|
||||
getNumOperands, getOperand);
|
||||
populateBuilderLinesAttr(
|
||||
op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()),
|
||||
builderLines);
|
||||
|
||||
builderArgs.push_back("loc=None");
|
||||
builderArgs.push_back("ip=None");
|
||||
|
@ -440,12 +576,24 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
|
|||
llvm::join(builderLines, "\n "));
|
||||
}
|
||||
|
||||
static void constructAttributeMapping(const llvm::RecordKeeper &records,
|
||||
AttributeClasses &attributeClasses) {
|
||||
for (const llvm::Record *rec :
|
||||
records.getAllDerivedDefinitions("PythonAttr")) {
|
||||
attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(),
|
||||
rec->getValueAsString("pythonType").trim());
|
||||
}
|
||||
}
|
||||
|
||||
/// Emits bindings for a specific Op to the given output stream.
|
||||
static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
||||
static void emitOpBindings(const Operator &op,
|
||||
const AttributeClasses &attributeClasses,
|
||||
raw_ostream &os) {
|
||||
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
|
||||
op.getOperationName());
|
||||
emitDefaultOpBuilder(op, os);
|
||||
emitOperandAccessors(op, os);
|
||||
emitAttributeAccessors(op, attributeClasses, os);
|
||||
emitResultAccessors(op, os);
|
||||
}
|
||||
|
||||
|
@ -456,12 +604,15 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
|
|||
if (clDialectName.empty())
|
||||
llvm::PrintFatalError("dialect name not provided");
|
||||
|
||||
AttributeClasses attributeClasses;
|
||||
constructAttributeMapping(records, attributeClasses);
|
||||
|
||||
os << fileHeader;
|
||||
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
|
||||
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
|
||||
Operator op(rec);
|
||||
if (op.getDialectName() == clDialectName.getValue())
|
||||
emitOpBindings(op, os);
|
||||
emitOpBindings(op, attributeClasses, os);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue