forked from OSchip/llvm-project
846 lines
33 KiB
C++
846 lines
33 KiB
C++
//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "IRModule.h"
|
|
|
|
#include "PybindUtils.h"
|
|
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
#include "mlir-c/BuiltinTypes.h"
|
|
|
|
namespace py = pybind11;
|
|
using namespace mlir;
|
|
using namespace mlir::python;
|
|
|
|
using llvm::None;
|
|
using llvm::Optional;
|
|
using llvm::SmallVector;
|
|
using llvm::Twine;
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Docstrings (trivial, non-duplicated docstrings are included inline).
|
|
//------------------------------------------------------------------------------
|
|
|
|
// User-facing Python docstring for `DenseElementsAttr.get`; it is passed to
// def_static in PyDenseElementsAttribute::bindDerived. The contents are
// runtime text shown by help(), not a C++ comment.
static const char kDenseElementsAttrGetDocstring[] =
|
|
R"(Gets a DenseElementsAttr from a Python buffer or array.
|
|
|
|
When `type` is not provided, then some limited type inferencing is done based
|
|
on the buffer format. Support presently exists for 8/16/32/64 signed and
|
|
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
|
|
types can also be converted back to a corresponding buffer.
|
|
|
|
For conversions outside of these types, a `type=` must be explicitly provided
|
|
and the buffer contents must be bit-castable to the MLIR internal
|
|
representation:
|
|
|
|
* Integer types (except for i1): the buffer must be byte aligned to the
|
|
next byte boundary.
|
|
* Floating point types: Must be bit-castable to the given floating point
|
|
size.
|
|
* i1 (bool): Bit packed into 8bit words where the bit pattern matches a
|
|
row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
|
|
this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
|
|
|
|
If a single element buffer is passed (or for i1, a single byte with value 0
|
|
or 255), then a splat will be created.
|
|
|
|
Args:
|
|
array: The array or buffer to convert.
|
|
signless: If inferring an appropriate MLIR type, use signless types for
|
|
integers (defaults True).
|
|
type: Skips inference of the MLIR element type and uses this instead. The
|
|
storage size must be consistent with the actual contents of the buffer.
|
|
shape: Overrides the shape of the buffer when constructing the MLIR
|
|
shaped type. This is needed when the physical and logical shape differ (as
|
|
for i1).
|
|
context: Explicit context, if not from context manager.
|
|
|
|
Returns:
|
|
DenseElementsAttr on success.
|
|
|
|
Raises:
|
|
ValueError: If the type of the buffer or array cannot be matched to an MLIR
|
|
type or if the buffer does not meet expectations.
|
|
)";
|
|
|
|
namespace {
|
|
|
|
/// Views the bytes of `s` as an MlirStringRef (pointer + length). No copy is
/// made, so the result must not be used after `s` is destroyed or modified.
static MlirStringRef toMlirStringRef(const std::string &s) {
  const char *bytes = s.data();
  size_t byteCount = s.size();
  return mlirStringRefCreate(bytes, byteCount);
}
|
|
|
|
/// AffineMapAttr binding: an attribute wrapping a single AffineMap.
class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
  static constexpr const char *pyClassName = "AffineMapAttr";
  using PyConcreteAttribute::PyConcreteAttribute;

  /// Registers the static `AffineMapAttr.get(affine_map)` factory.
  static void bindDerived(ClassTy &c) {
    auto getFromAffineMap = [](PyAffineMap &affineMap) {
      // The resulting Python object shares the affine map's context.
      MlirAttribute wrapped = mlirAffineMapAttrGet(affineMap.get());
      return PyAffineMapAttribute(affineMap.getContext(), wrapped);
    };
    c.def_static("get", getFromAffineMap, py::arg("affine_map"),
                 "Gets an attribute wrapping an AffineMap.");
  }
};
|
|
|
|
template <typename T>
|
|
static T pyTryCast(py::handle object) {
|
|
try {
|
|
return object.cast<T>();
|
|
} catch (py::cast_error &err) {
|
|
std::string msg =
|
|
std::string(
|
|
"Invalid attribute when attempting to create an ArrayAttribute (") +
|
|
err.what() + ")";
|
|
throw py::cast_error(msg);
|
|
} catch (py::reference_cast_error &err) {
|
|
std::string msg = std::string("Invalid attribute (None?) when attempting "
|
|
"to create an ArrayAttribute (") +
|
|
err.what() + ")";
|
|
throw py::cast_error(msg);
|
|
}
|
|
}
|
|
|
|
/// ArrayAttr binding: an ordered list of attributes supporting indexing,
/// iteration, len() and concatenation with a Python list.
class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
  static constexpr const char *pyClassName = "ArrayAttr";
  using PyConcreteAttribute::PyConcreteAttribute;

  /// Python iterator over the elements of an ArrayAttr. It keeps its own copy
  /// of the attribute handle plus a cursor to the next element.
  class PyArrayAttributeIterator {
  public:
    PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}

    PyArrayAttributeIterator &dunderIter() { return *this; }

    /// Returns the next element, raising StopIteration once exhausted.
    PyAttribute dunderNext() {
      if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
        throw py::stop_iteration();
      }
      return PyAttribute(attr.getContext(),
                         mlirArrayAttrGetElement(attr.get(), nextIndex++));
    }

    static void bind(py::module &m) {
      py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
                                           py::module_local())
          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
          .def("__next__", &PyArrayAttributeIterator::dunderNext);
    }

  private:
    PyAttribute attr;
    // Same type as the element count returned by mlirArrayAttrGetNumElements.
    intptr_t nextIndex = 0;
  };

  /// Returns element `i` without bounds checking; callers must validate `i`.
  PyAttribute getItem(intptr_t i) {
    return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
  }

  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get",
        [](py::list attributes, DefaultingPyMlirContext context) {
          SmallVector<MlirAttribute> mlirAttributes;
          mlirAttributes.reserve(py::len(attributes));
          for (auto attribute : attributes) {
            mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
          }
          MlirAttribute attr = mlirArrayAttrGet(
              context->get(), mlirAttributes.size(), mlirAttributes.data());
          return PyArrayAttribute(context->getRef(), attr);
        },
        py::arg("attributes"), py::arg("context") = py::none(),
        "Gets a uniqued Array attribute");
    c.def("__getitem__",
          [](PyArrayAttribute &arr, intptr_t i) {
            // Python passes negative indices through unchanged, and the C
            // accessor is not guaranteed to validate its position, so reject
            // both ends of the range here.
            if (i < 0 || i >= mlirArrayAttrGetNumElements(arr))
              throw py::index_error("ArrayAttribute index out of range");
            return arr.getItem(i);
          })
        .def("__len__",
             [](const PyArrayAttribute &arr) {
               return mlirArrayAttrGetNumElements(arr);
             })
        .def("__iter__", [](const PyArrayAttribute &arr) {
          return PyArrayAttributeIterator(arr);
        });
    // `arr + [attrs...]` builds a new ArrayAttr with the extras appended.
    c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
      intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
      SmallVector<MlirAttribute> attributes;
      attributes.reserve(numOldElements + py::len(extras));
      for (intptr_t i = 0; i < numOldElements; ++i)
        attributes.push_back(arr.getItem(i));
      for (py::handle attr : extras)
        attributes.push_back(pyTryCast<PyAttribute>(attr));
      MlirAttribute arrayAttr = mlirArrayAttrGet(
          arr.getContext()->get(), attributes.size(), attributes.data());
      return PyArrayAttribute(arr.getContext(), arrayAttr);
    });
  }
};
|
|
|
|
/// Float Point Attribute subclass - FloatAttr.
|
|
class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
|
|
static constexpr const char *pyClassName = "FloatAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](PyType &type, double value, DefaultingPyLocation loc) {
|
|
MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
|
|
// TODO: Rework error reporting once diagnostic engine is exposed
|
|
// in C API.
|
|
if (mlirAttributeIsNull(attr)) {
|
|
throw SetPyError(PyExc_ValueError,
|
|
Twine("invalid '") +
|
|
py::repr(py::cast(type)).cast<std::string>() +
|
|
"' and expected floating point type.");
|
|
}
|
|
return PyFloatAttribute(type.getContext(), attr);
|
|
},
|
|
py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
|
|
"Gets an uniqued float point attribute associated to a type");
|
|
c.def_static(
|
|
"get_f32",
|
|
[](double value, DefaultingPyMlirContext context) {
|
|
MlirAttribute attr = mlirFloatAttrDoubleGet(
|
|
context->get(), mlirF32TypeGet(context->get()), value);
|
|
return PyFloatAttribute(context->getRef(), attr);
|
|
},
|
|
py::arg("value"), py::arg("context") = py::none(),
|
|
"Gets an uniqued float point attribute associated to a f32 type");
|
|
c.def_static(
|
|
"get_f64",
|
|
[](double value, DefaultingPyMlirContext context) {
|
|
MlirAttribute attr = mlirFloatAttrDoubleGet(
|
|
context->get(), mlirF64TypeGet(context->get()), value);
|
|
return PyFloatAttribute(context->getRef(), attr);
|
|
},
|
|
py::arg("value"), py::arg("context") = py::none(),
|
|
"Gets an uniqued float point attribute associated to a f64 type");
|
|
c.def_property_readonly(
|
|
"value",
|
|
[](PyFloatAttribute &self) {
|
|
return mlirFloatAttrGetValueDouble(self);
|
|
},
|
|
"Returns the value of the float point attribute");
|
|
}
|
|
};
|
|
|
|
/// Integer Attribute subclass - IntegerAttr.
|
|
class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
|
|
static constexpr const char *pyClassName = "IntegerAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](PyType &type, int64_t value) {
|
|
MlirAttribute attr = mlirIntegerAttrGet(type, value);
|
|
return PyIntegerAttribute(type.getContext(), attr);
|
|
},
|
|
py::arg("type"), py::arg("value"),
|
|
"Gets an uniqued integer attribute associated to a type");
|
|
c.def_property_readonly(
|
|
"value",
|
|
[](PyIntegerAttribute &self) {
|
|
return mlirIntegerAttrGetValueInt(self);
|
|
},
|
|
"Returns the value of the integer attribute");
|
|
}
|
|
};
|
|
|
|
/// Bool Attribute subclass - BoolAttr.
|
|
class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
|
|
static constexpr const char *pyClassName = "BoolAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](bool value, DefaultingPyMlirContext context) {
|
|
MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
|
|
return PyBoolAttribute(context->getRef(), attr);
|
|
},
|
|
py::arg("value"), py::arg("context") = py::none(),
|
|
"Gets an uniqued bool attribute");
|
|
c.def_property_readonly(
|
|
"value",
|
|
[](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
|
|
"Returns the value of the bool attribute");
|
|
}
|
|
};
|
|
|
|
class PyFlatSymbolRefAttribute
|
|
: public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
|
|
static constexpr const char *pyClassName = "FlatSymbolRefAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](std::string value, DefaultingPyMlirContext context) {
|
|
MlirAttribute attr =
|
|
mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
|
|
return PyFlatSymbolRefAttribute(context->getRef(), attr);
|
|
},
|
|
py::arg("value"), py::arg("context") = py::none(),
|
|
"Gets a uniqued FlatSymbolRef attribute");
|
|
c.def_property_readonly(
|
|
"value",
|
|
[](PyFlatSymbolRefAttribute &self) {
|
|
MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
|
|
return py::str(stringRef.data, stringRef.length);
|
|
},
|
|
"Returns the value of the FlatSymbolRef attribute as a string");
|
|
}
|
|
};
|
|
|
|
/// StringAttr binding: untyped (`get`) or typed (`get_typed`) string
/// attributes, with a read-only `value` property.
class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
  static constexpr const char *pyClassName = "StringAttr";
  using PyConcreteAttribute::PyConcreteAttribute;

  static void bindDerived(ClassTy &c) {
    auto getUntyped = [](std::string value, DefaultingPyMlirContext context) {
      MlirStringRef text = toMlirStringRef(value);
      MlirAttribute strAttr = mlirStringAttrGet(context->get(), text);
      return PyStringAttribute(context->getRef(), strAttr);
    };
    auto getTyped = [](PyType &type, std::string value) {
      MlirStringRef text = toMlirStringRef(value);
      MlirAttribute strAttr = mlirStringAttrTypedGet(type, text);
      // Typed strings share the context of their type.
      return PyStringAttribute(type.getContext(), strAttr);
    };
    auto readValue = [](PyStringAttribute &self) {
      MlirStringRef text = mlirStringAttrGetValue(self);
      return py::str(text.data, text.length);
    };

    c.def_static("get", getUntyped, py::arg("value"),
                 py::arg("context") = py::none(),
                 "Gets a uniqued string attribute");
    c.def_static("get_typed", getTyped, py::arg("type"), py::arg("value"),
                 "Gets a uniqued string attribute associated to a type");
    c.def_property_readonly("value", readValue,
                            "Returns the value of the string attribute");
  }
};
|
|
|
|
// TODO: Support construction of string elements.
|
|
class PyDenseElementsAttribute
|
|
: public PyConcreteAttribute<PyDenseElementsAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
|
|
static constexpr const char *pyClassName = "DenseElementsAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
static PyDenseElementsAttribute
|
|
getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
|
|
Optional<std::vector<int64_t>> explicitShape,
|
|
DefaultingPyMlirContext contextWrapper) {
|
|
// Request a contiguous view. In exotic cases, this will cause a copy.
|
|
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
|
|
Py_buffer *view = new Py_buffer();
|
|
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
|
|
delete view;
|
|
throw py::error_already_set();
|
|
}
|
|
py::buffer_info arrayInfo(view);
|
|
SmallVector<int64_t> shape;
|
|
if (explicitShape) {
|
|
shape.append(explicitShape->begin(), explicitShape->end());
|
|
} else {
|
|
shape.append(arrayInfo.shape.begin(),
|
|
arrayInfo.shape.begin() + arrayInfo.ndim);
|
|
}
|
|
|
|
MlirAttribute encodingAttr = mlirAttributeGetNull();
|
|
MlirContext context = contextWrapper->get();
|
|
|
|
// Detect format codes that are suitable for bulk loading. This includes
|
|
// all byte aligned integer and floating point types up to 8 bytes.
|
|
// Notably, this excludes, bool (which needs to be bit-packed) and
|
|
// other exotics which do not have a direct representation in the buffer
|
|
// protocol (i.e. complex, etc).
|
|
Optional<MlirType> bulkLoadElementType;
|
|
if (explicitType) {
|
|
bulkLoadElementType = *explicitType;
|
|
} else if (arrayInfo.format == "f") {
|
|
// f32
|
|
assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
|
|
bulkLoadElementType = mlirF32TypeGet(context);
|
|
} else if (arrayInfo.format == "d") {
|
|
// f64
|
|
assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
|
|
bulkLoadElementType = mlirF64TypeGet(context);
|
|
} else if (arrayInfo.format == "e") {
|
|
// f16
|
|
assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
|
|
bulkLoadElementType = mlirF16TypeGet(context);
|
|
} else if (isSignedIntegerFormat(arrayInfo.format)) {
|
|
if (arrayInfo.itemsize == 4) {
|
|
// i32
|
|
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
|
|
: mlirIntegerTypeSignedGet(context, 32);
|
|
} else if (arrayInfo.itemsize == 8) {
|
|
// i64
|
|
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
|
|
: mlirIntegerTypeSignedGet(context, 64);
|
|
} else if (arrayInfo.itemsize == 1) {
|
|
// i8
|
|
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
|
|
: mlirIntegerTypeSignedGet(context, 8);
|
|
} else if (arrayInfo.itemsize == 2) {
|
|
// i16
|
|
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
|
|
: mlirIntegerTypeSignedGet(context, 16);
|
|
}
|
|
} else if (isUnsignedIntegerFormat(arrayInfo.format)) {
|
|
if (arrayInfo.itemsize == 4) {
|
|
// unsigned i32
|
|
bulkLoadElementType = signless
|
|
? mlirIntegerTypeGet(context, 32)
|
|
: mlirIntegerTypeUnsignedGet(context, 32);
|
|
} else if (arrayInfo.itemsize == 8) {
|
|
// unsigned i64
|
|
bulkLoadElementType = signless
|
|
? mlirIntegerTypeGet(context, 64)
|
|
: mlirIntegerTypeUnsignedGet(context, 64);
|
|
} else if (arrayInfo.itemsize == 1) {
|
|
// i8
|
|
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
|
|
: mlirIntegerTypeUnsignedGet(context, 8);
|
|
} else if (arrayInfo.itemsize == 2) {
|
|
// i16
|
|
bulkLoadElementType = signless
|
|
? mlirIntegerTypeGet(context, 16)
|
|
: mlirIntegerTypeUnsignedGet(context, 16);
|
|
}
|
|
}
|
|
if (bulkLoadElementType) {
|
|
auto shapedType = mlirRankedTensorTypeGet(
|
|
shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
|
|
size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
|
|
MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
|
|
shapedType, rawBufferSize, arrayInfo.ptr);
|
|
if (mlirAttributeIsNull(attr)) {
|
|
throw std::invalid_argument(
|
|
"DenseElementsAttr could not be constructed from the given buffer. "
|
|
"This may mean that the Python buffer layout does not match that "
|
|
"MLIR expected layout and is a bug.");
|
|
}
|
|
return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
|
|
}
|
|
|
|
throw std::invalid_argument(
|
|
std::string("unimplemented array format conversion from format: ") +
|
|
arrayInfo.format);
|
|
}
|
|
|
|
static PyDenseElementsAttribute getSplat(PyType shapedType,
|
|
PyAttribute &elementAttr) {
|
|
auto contextWrapper =
|
|
PyMlirContext::forContext(mlirTypeGetContext(shapedType));
|
|
if (!mlirAttributeIsAInteger(elementAttr) &&
|
|
!mlirAttributeIsAFloat(elementAttr)) {
|
|
std::string message = "Illegal element type for DenseElementsAttr: ";
|
|
message.append(py::repr(py::cast(elementAttr)));
|
|
throw SetPyError(PyExc_ValueError, message);
|
|
}
|
|
if (!mlirTypeIsAShaped(shapedType) ||
|
|
!mlirShapedTypeHasStaticShape(shapedType)) {
|
|
std::string message =
|
|
"Expected a static ShapedType for the shaped_type parameter: ";
|
|
message.append(py::repr(py::cast(shapedType)));
|
|
throw SetPyError(PyExc_ValueError, message);
|
|
}
|
|
MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
|
|
MlirType attrType = mlirAttributeGetType(elementAttr);
|
|
if (!mlirTypeEqual(shapedElementType, attrType)) {
|
|
std::string message =
|
|
"Shaped element type and attribute type must be equal: shaped=";
|
|
message.append(py::repr(py::cast(shapedType)));
|
|
message.append(", element=");
|
|
message.append(py::repr(py::cast(elementAttr)));
|
|
throw SetPyError(PyExc_ValueError, message);
|
|
}
|
|
|
|
MlirAttribute elements =
|
|
mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
|
|
return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
|
|
}
|
|
|
|
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
|
|
|
|
py::buffer_info accessBuffer() {
|
|
if (mlirDenseElementsAttrIsSplat(*this)) {
|
|
// TODO: Currently crashes the program.
|
|
// Reported as https://github.com/pybind/pybind11/issues/3336
|
|
throw std::invalid_argument(
|
|
"unsupported data type for conversion to Python buffer");
|
|
}
|
|
|
|
MlirType shapedType = mlirAttributeGetType(*this);
|
|
MlirType elementType = mlirShapedTypeGetElementType(shapedType);
|
|
std::string format;
|
|
|
|
if (mlirTypeIsAF32(elementType)) {
|
|
// f32
|
|
return bufferInfo<float>(shapedType);
|
|
} else if (mlirTypeIsAF64(elementType)) {
|
|
// f64
|
|
return bufferInfo<double>(shapedType);
|
|
} else if (mlirTypeIsAF16(elementType)) {
|
|
// f16
|
|
return bufferInfo<uint16_t>(shapedType, "e");
|
|
} else if (mlirTypeIsAInteger(elementType) &&
|
|
mlirIntegerTypeGetWidth(elementType) == 32) {
|
|
if (mlirIntegerTypeIsSignless(elementType) ||
|
|
mlirIntegerTypeIsSigned(elementType)) {
|
|
// i32
|
|
return bufferInfo<int32_t>(shapedType);
|
|
} else if (mlirIntegerTypeIsUnsigned(elementType)) {
|
|
// unsigned i32
|
|
return bufferInfo<uint32_t>(shapedType);
|
|
}
|
|
} else if (mlirTypeIsAInteger(elementType) &&
|
|
mlirIntegerTypeGetWidth(elementType) == 64) {
|
|
if (mlirIntegerTypeIsSignless(elementType) ||
|
|
mlirIntegerTypeIsSigned(elementType)) {
|
|
// i64
|
|
return bufferInfo<int64_t>(shapedType);
|
|
} else if (mlirIntegerTypeIsUnsigned(elementType)) {
|
|
// unsigned i64
|
|
return bufferInfo<uint64_t>(shapedType);
|
|
}
|
|
} else if (mlirTypeIsAInteger(elementType) &&
|
|
mlirIntegerTypeGetWidth(elementType) == 8) {
|
|
if (mlirIntegerTypeIsSignless(elementType) ||
|
|
mlirIntegerTypeIsSigned(elementType)) {
|
|
// i8
|
|
return bufferInfo<int8_t>(shapedType);
|
|
} else if (mlirIntegerTypeIsUnsigned(elementType)) {
|
|
// unsigned i8
|
|
return bufferInfo<uint8_t>(shapedType);
|
|
}
|
|
} else if (mlirTypeIsAInteger(elementType) &&
|
|
mlirIntegerTypeGetWidth(elementType) == 16) {
|
|
if (mlirIntegerTypeIsSignless(elementType) ||
|
|
mlirIntegerTypeIsSigned(elementType)) {
|
|
// i16
|
|
return bufferInfo<int16_t>(shapedType);
|
|
} else if (mlirIntegerTypeIsUnsigned(elementType)) {
|
|
// unsigned i16
|
|
return bufferInfo<uint16_t>(shapedType);
|
|
}
|
|
}
|
|
|
|
// TODO: Currently crashes the program.
|
|
// Reported as https://github.com/pybind/pybind11/issues/3336
|
|
throw std::invalid_argument(
|
|
"unsupported data type for conversion to Python buffer");
|
|
}
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def("__len__", &PyDenseElementsAttribute::dunderLen)
|
|
.def_static("get", PyDenseElementsAttribute::getFromBuffer,
|
|
py::arg("array"), py::arg("signless") = true,
|
|
py::arg("type") = py::none(), py::arg("shape") = py::none(),
|
|
py::arg("context") = py::none(),
|
|
kDenseElementsAttrGetDocstring)
|
|
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
|
|
py::arg("shaped_type"), py::arg("element_attr"),
|
|
"Gets a DenseElementsAttr where all values are the same")
|
|
.def_property_readonly("is_splat",
|
|
[](PyDenseElementsAttribute &self) -> bool {
|
|
return mlirDenseElementsAttrIsSplat(self);
|
|
})
|
|
.def_buffer(&PyDenseElementsAttribute::accessBuffer);
|
|
}
|
|
|
|
private:
|
|
static bool isUnsignedIntegerFormat(const std::string &format) {
|
|
if (format.empty())
|
|
return false;
|
|
char code = format[0];
|
|
return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
|
|
code == 'Q';
|
|
}
|
|
|
|
static bool isSignedIntegerFormat(const std::string &format) {
|
|
if (format.empty())
|
|
return false;
|
|
char code = format[0];
|
|
return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
|
|
code == 'q';
|
|
}
|
|
|
|
template <typename Type>
|
|
py::buffer_info bufferInfo(MlirType shapedType,
|
|
const char *explicitFormat = nullptr) {
|
|
intptr_t rank = mlirShapedTypeGetRank(shapedType);
|
|
// Prepare the data for the buffer_info.
|
|
// Buffer is configured for read-only access below.
|
|
Type *data = static_cast<Type *>(
|
|
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
|
|
// Prepare the shape for the buffer_info.
|
|
SmallVector<intptr_t, 4> shape;
|
|
for (intptr_t i = 0; i < rank; ++i)
|
|
shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
|
|
// Prepare the strides for the buffer_info.
|
|
SmallVector<intptr_t, 4> strides;
|
|
intptr_t strideFactor = 1;
|
|
for (intptr_t i = 1; i < rank; ++i) {
|
|
strideFactor = 1;
|
|
for (intptr_t j = i; j < rank; ++j) {
|
|
strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
|
|
}
|
|
strides.push_back(sizeof(Type) * strideFactor);
|
|
}
|
|
strides.push_back(sizeof(Type));
|
|
std::string format;
|
|
if (explicitFormat) {
|
|
format = explicitFormat;
|
|
} else {
|
|
format = py::format_descriptor<Type>::format();
|
|
}
|
|
return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
|
|
/*readonly=*/true);
|
|
}
|
|
}; // namespace
|
|
|
|
/// Refinement of the PyDenseElementsAttribute for attributes containing integer
|
|
/// (and boolean) values. Supports element access.
|
|
class PyDenseIntElementsAttribute
|
|
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
|
|
PyDenseElementsAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
|
|
static constexpr const char *pyClassName = "DenseIntElementsAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
/// Returns the element at the given linear position. Asserts if the index is
|
|
/// out of range.
|
|
py::int_ dunderGetItem(intptr_t pos) {
|
|
if (pos < 0 || pos >= dunderLen()) {
|
|
throw SetPyError(PyExc_IndexError,
|
|
"attempt to access out of bounds element");
|
|
}
|
|
|
|
MlirType type = mlirAttributeGetType(*this);
|
|
type = mlirShapedTypeGetElementType(type);
|
|
assert(mlirTypeIsAInteger(type) &&
|
|
"expected integer element type in dense int elements attribute");
|
|
// Dispatch element extraction to an appropriate C function based on the
|
|
// elemental type of the attribute. py::int_ is implicitly constructible
|
|
// from any C++ integral type and handles bitwidth correctly.
|
|
// TODO: consider caching the type properties in the constructor to avoid
|
|
// querying them on each element access.
|
|
unsigned width = mlirIntegerTypeGetWidth(type);
|
|
bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
|
|
if (isUnsigned) {
|
|
if (width == 1) {
|
|
return mlirDenseElementsAttrGetBoolValue(*this, pos);
|
|
}
|
|
if (width == 32) {
|
|
return mlirDenseElementsAttrGetUInt32Value(*this, pos);
|
|
}
|
|
if (width == 64) {
|
|
return mlirDenseElementsAttrGetUInt64Value(*this, pos);
|
|
}
|
|
} else {
|
|
if (width == 1) {
|
|
return mlirDenseElementsAttrGetBoolValue(*this, pos);
|
|
}
|
|
if (width == 32) {
|
|
return mlirDenseElementsAttrGetInt32Value(*this, pos);
|
|
}
|
|
if (width == 64) {
|
|
return mlirDenseElementsAttrGetInt64Value(*this, pos);
|
|
}
|
|
}
|
|
throw SetPyError(PyExc_TypeError, "Unsupported integer type");
|
|
}
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
|
|
}
|
|
};
|
|
|
|
/// DictAttr binding: a mapping from string names to attributes that can be
/// indexed by name (returns the attribute) or by position (returns a
/// NamedAttribute).
class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
  static constexpr const char *pyClassName = "DictAttr";
  using PyConcreteAttribute::PyConcreteAttribute;

  /// Number of named entries in the dictionary.
  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }

  /// True if an entry named `name` exists.
  bool dunderContains(const std::string &name) {
    return !mlirAttributeIsNull(
        mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
  }

  static void bindDerived(ClassTy &c) {
    c.def("__contains__", &PyDictAttribute::dunderContains);
    c.def("__len__", &PyDictAttribute::dunderLen);
    c.def_static(
        "get",
        [](py::dict attributes, DefaultingPyMlirContext context) {
          SmallVector<MlirNamedAttribute> mlirNamedAttributes;
          mlirNamedAttributes.reserve(attributes.size());
          for (const auto &entry : attributes) {
            // Keys must be str and values Attribute; pybind11 raises a cast
            // error otherwise.
            auto &mlirAttr = entry.second.cast<PyAttribute &>();
            auto name = entry.first.cast<std::string>();
            mlirNamedAttributes.push_back(mlirNamedAttributeGet(
                mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
                                  toMlirStringRef(name)),
                mlirAttr));
          }
          MlirAttribute attr =
              mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
          return PyDictAttribute(context->getRef(), attr);
        },
        py::arg("value") = py::dict(), py::arg("context") = py::none(),
        "Gets an uniqued dict attribute");
    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
      MlirAttribute attr =
          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
      if (mlirAttributeIsNull(attr)) {
        throw SetPyError(PyExc_KeyError,
                         "attempt to access a non-existent attribute");
      }
      return PyAttribute(self.getContext(), attr);
    });
    c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
      if (index < 0 || index >= self.dunderLen()) {
        throw SetPyError(PyExc_IndexError,
                         "attempt to access out of bounds attribute");
      }
      MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
      // MlirStringRef is a (data, length) pair that is not guaranteed to be
      // null-terminated, so copy exactly `length` bytes.
      MlirStringRef nameRef = mlirIdentifierStr(namedAttr.name);
      return PyNamedAttribute(namedAttr.attribute,
                              std::string(nameRef.data, nameRef.length));
    });
  }
};
|
|
|
|
/// Refinement of PyDenseElementsAttribute for attributes containing
|
|
/// floating-point values. Supports element access.
|
|
class PyDenseFPElementsAttribute
|
|
: public PyConcreteAttribute<PyDenseFPElementsAttribute,
|
|
PyDenseElementsAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
|
|
static constexpr const char *pyClassName = "DenseFPElementsAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
py::float_ dunderGetItem(intptr_t pos) {
|
|
if (pos < 0 || pos >= dunderLen()) {
|
|
throw SetPyError(PyExc_IndexError,
|
|
"attempt to access out of bounds element");
|
|
}
|
|
|
|
MlirType type = mlirAttributeGetType(*this);
|
|
type = mlirShapedTypeGetElementType(type);
|
|
// Dispatch element extraction to an appropriate C function based on the
|
|
// elemental type of the attribute. py::float_ is implicitly constructible
|
|
// from float and double.
|
|
// TODO: consider caching the type properties in the constructor to avoid
|
|
// querying them on each element access.
|
|
if (mlirTypeIsAF32(type)) {
|
|
return mlirDenseElementsAttrGetFloatValue(*this, pos);
|
|
}
|
|
if (mlirTypeIsAF64(type)) {
|
|
return mlirDenseElementsAttrGetDoubleValue(*this, pos);
|
|
}
|
|
throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
|
|
}
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
|
|
}
|
|
};
|
|
|
|
/// TypeAttr binding: an attribute wrapping an MLIR type, exposed back to
/// Python through the read-only `value` property.
class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
  static constexpr const char *pyClassName = "TypeAttr";
  using PyConcreteAttribute::PyConcreteAttribute;

  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get",
        [](PyType value, DefaultingPyMlirContext context) {
          // mlirTypeAttrGet takes only the type, so the attribute lives in
          // the type's own context. Tie the Python object to that context
          // rather than to `context`, which could differ if passed
          // explicitly (consistent with AffineMapAttr/FloatAttr). `context`
          // is kept so the Python-visible signature is unchanged.
          (void)context;
          MlirAttribute attr = mlirTypeAttrGet(value.get());
          return PyTypeAttribute(value.getContext(), attr);
        },
        py::arg("value"), py::arg("context") = py::none(),
        "Gets a uniqued Type attribute");
    c.def_property_readonly("value", [](PyTypeAttribute &self) {
      return PyType(self.getContext()->getRef(),
                    mlirTypeAttrGetValue(self.get()));
    });
  }
};
|
|
|
|
/// Unit Attribute subclass. Unit attributes don't have values.
|
|
class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
|
|
static constexpr const char *pyClassName = "UnitAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](DefaultingPyMlirContext context) {
|
|
return PyUnitAttribute(context->getRef(),
|
|
mlirUnitAttrGet(context->get()));
|
|
},
|
|
py::arg("context") = py::none(), "Create a Unit attribute.");
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers every attribute binding defined in this file (plus the
/// ArrayAttr iterator helper type) on module `m`.
void mlir::python::populateIRAttributes(py::module &m) {
|
|
PyAffineMapAttribute::bind(m);
|
|
PyArrayAttribute::bind(m);
|
|
  // Iterator type returned by ArrayAttr.__iter__.
PyArrayAttribute::PyArrayAttributeIterator::bind(m);
|
|
PyBoolAttribute::bind(m);
|
|
PyDenseElementsAttribute::bind(m);
|
|
  // NOTE(review): the DenseFP/DenseInt classes derive from
  // PyDenseElementsAttribute; presumably the base must be bound first, so
  // keep this ordering.
PyDenseFPElementsAttribute::bind(m);
|
|
PyDenseIntElementsAttribute::bind(m);
|
|
PyDictAttribute::bind(m);
|
|
PyFlatSymbolRefAttribute::bind(m);
|
|
PyFloatAttribute::bind(m);
|
|
PyIntegerAttribute::bind(m);
|
|
PyStringAttribute::bind(m);
|
|
PyTypeAttribute::bind(m);
|
|
PyUnitAttribute::bind(m);
|
|
}
|