forked from OSchip/llvm-project
Add Location, Region and Block to MLIR Python bindings.
* This is just enough to create regions/blocks and iterate over them. * Does not yet implement the preferred iteration strategy (python pseudo containers). * Refinements need to come after doing basic mappings of operations and values so that the whole hierarchy can be used. Differential Revision: https://reviews.llvm.org/D86683
This commit is contained in:
parent
1b201914b5
commit
2d1362e09a
|
@ -85,6 +85,9 @@ typedef void (*MlirStringCallback)(const char *, intptr_t, void *);
|
|||
/** Creates an MLIR context and transfers its ownership to the caller. */
|
||||
MlirContext mlirContextCreate();
|
||||
|
||||
/** Checks if two contexts are equal. */
|
||||
int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
|
||||
|
||||
/** Takes an MLIR context owned by the caller and destroys it. */
|
||||
void mlirContextDestroy(MlirContext context);
|
||||
|
||||
|
@ -315,6 +318,9 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback,
|
|||
/** Parses a type. The type is owned by the context. */
|
||||
MlirType mlirTypeParseGet(MlirContext context, const char *type);
|
||||
|
||||
/** Gets the context that a type was created with. */
|
||||
MlirContext mlirTypeGetContext(MlirType type);
|
||||
|
||||
/** Checks whether a type is null. */
|
||||
inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
|
||||
|
||||
|
|
|
@ -28,13 +28,59 @@ Returns a new MlirModule or raises a ValueError if the parsing fails.
|
|||
See also: https://mlir.llvm.org/docs/LangRef/
|
||||
)";
|
||||
|
||||
static const char kContextParseType[] = R"(Parses the assembly form of a type.
|
||||
static const char kContextParseTypeDocstring[] =
|
||||
R"(Parses the assembly form of a type.
|
||||
|
||||
Returns a Type object or raises a ValueError if the type cannot be parsed.
|
||||
|
||||
See also: https://mlir.llvm.org/docs/LangRef/#type-system
|
||||
)";
|
||||
|
||||
static const char kContextGetUnknownLocationDocstring[] =
|
||||
R"(Gets a Location representing an unknown location)";
|
||||
|
||||
static const char kContextGetFileLocationDocstring[] =
|
||||
R"(Gets a Location representing a file, line and column)";
|
||||
|
||||
static const char kContextCreateBlockDocstring[] =
|
||||
R"(Creates a detached block)";
|
||||
|
||||
static const char kContextCreateRegionDocstring[] =
|
||||
R"(Creates a detached region)";
|
||||
|
||||
static const char kRegionAppendBlockDocstring[] =
|
||||
R"(Appends a block to a region.
|
||||
|
||||
Raises:
|
||||
ValueError: If the block is already attached to another region.
|
||||
)";
|
||||
|
||||
static const char kRegionInsertBlockDocstring[] =
|
||||
R"(Inserts a block at a postiion in a region.
|
||||
|
||||
Raises:
|
||||
ValueError: If the block is already attached to another region.
|
||||
)";
|
||||
|
||||
static const char kRegionFirstBlockDocstring[] =
|
||||
R"(Gets the first block in a region.
|
||||
|
||||
Blocks can also be accessed via the `blocks` container.
|
||||
|
||||
Raises:
|
||||
IndexError: If the region has no blocks.
|
||||
)";
|
||||
|
||||
static const char kBlockNextInRegionDocstring[] =
|
||||
R"(Gets the next block in the enclosing region.
|
||||
|
||||
Blocks can also be accessed via the `blocks` container of the owning region.
|
||||
This method exists to mirror the lower level API and should not be preferred.
|
||||
|
||||
Raises:
|
||||
IndexError: If there are no further blocks.
|
||||
)";
|
||||
|
||||
static const char kOperationStrDunderDocstring[] =
|
||||
R"(Prints the assembly form of the operation with default options.
|
||||
|
||||
|
@ -106,6 +152,24 @@ private:
|
|||
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyBlock, PyRegion, and PyOperation.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
void PyRegion::attachToParent() {
|
||||
if (!detached) {
|
||||
throw SetPyError(PyExc_ValueError, "Region is already attached to an op");
|
||||
}
|
||||
detached = false;
|
||||
}
|
||||
|
||||
void PyBlock::attachToParent() {
|
||||
if (!detached) {
|
||||
throw SetPyError(PyExc_ValueError, "Block is already attached to an op");
|
||||
}
|
||||
detached = false;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyAttribute.
|
||||
//------------------------------------------------------------------------------
|
||||
|
@ -454,7 +518,59 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
}
|
||||
return PyType(type);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextParseType);
|
||||
py::keep_alive<0, 1>(), kContextParseTypeDocstring)
|
||||
.def(
|
||||
"get_unknown_location",
|
||||
[](PyMlirContext &self) {
|
||||
return PyLocation(mlirLocationUnknownGet(self.context));
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextGetUnknownLocationDocstring)
|
||||
.def(
|
||||
"get_file_location",
|
||||
[](PyMlirContext &self, std::string filename, int line, int col) {
|
||||
return PyLocation(mlirLocationFileLineColGet(
|
||||
self.context, filename.c_str(), line, col));
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextGetFileLocationDocstring,
|
||||
py::arg("filename"), py::arg("line"), py::arg("col"))
|
||||
.def(
|
||||
"create_region",
|
||||
[](PyMlirContext &self) {
|
||||
// The creating context is explicitly captured on regions to
|
||||
// facilitate illegal assemblies of objects from multiple contexts
|
||||
// that would invalidate the memory model.
|
||||
return PyRegion(self.context, mlirRegionCreate(),
|
||||
/*detached=*/true);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
|
||||
.def(
|
||||
"create_block",
|
||||
[](PyMlirContext &self, std::vector<PyType> pyTypes) {
|
||||
// In order for the keep_alive extend the proper lifetime, all
|
||||
// types must be from the same context.
|
||||
for (auto pyType : pyTypes) {
|
||||
if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
|
||||
self.context)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
"All types used to construct a block must be from "
|
||||
"the same context as the block");
|
||||
}
|
||||
}
|
||||
llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
|
||||
pyTypes.end());
|
||||
return PyBlock(self.context,
|
||||
mlirBlockCreate(types.size(), &types[0]),
|
||||
/*detached=*/true);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
|
||||
|
||||
py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
mlirLocationPrint(self.loc, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
return printAccum.join();
|
||||
});
|
||||
|
||||
// Mapping of Module
|
||||
py::class_<PyModule>(m, "Module")
|
||||
|
@ -475,6 +591,70 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
},
|
||||
kOperationStrDunderDocstring);
|
||||
|
||||
// Mapping of PyRegion.
|
||||
py::class_<PyRegion>(m, "Region")
|
||||
.def(
|
||||
"append_block",
|
||||
[](PyRegion &self, PyBlock &block) {
|
||||
if (!mlirContextEqual(self.context, block.context)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
"Block must have been created from the same context as "
|
||||
"this region");
|
||||
}
|
||||
|
||||
block.attachToParent();
|
||||
mlirRegionAppendOwnedBlock(self.region, block.block);
|
||||
},
|
||||
kRegionAppendBlockDocstring)
|
||||
.def(
|
||||
"insert_block",
|
||||
[](PyRegion &self, int pos, PyBlock &block) {
|
||||
if (!mlirContextEqual(self.context, block.context)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
"Block must have been created from the same context as "
|
||||
"this region");
|
||||
}
|
||||
block.attachToParent();
|
||||
// TODO: Make this return a failure and raise if out of bounds.
|
||||
mlirRegionInsertOwnedBlock(self.region, pos, block.block);
|
||||
},
|
||||
kRegionInsertBlockDocstring)
|
||||
.def_property_readonly(
|
||||
"first_block",
|
||||
[](PyRegion &self) {
|
||||
MlirBlock block = mlirRegionGetFirstBlock(self.region);
|
||||
if (mlirBlockIsNull(block)) {
|
||||
throw SetPyError(PyExc_IndexError, "Region has no blocks");
|
||||
}
|
||||
return PyBlock(self.context, block, /*detached=*/false);
|
||||
},
|
||||
kRegionFirstBlockDocstring);
|
||||
|
||||
// Mapping of PyBlock.
|
||||
py::class_<PyBlock>(m, "Block")
|
||||
.def_property_readonly(
|
||||
"next_in_region",
|
||||
[](PyBlock &self) {
|
||||
MlirBlock block = mlirBlockGetNextInRegion(self.block);
|
||||
if (mlirBlockIsNull(block)) {
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"Attempt to read past last block");
|
||||
}
|
||||
return PyBlock(self.context, block, /*detached=*/false);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kBlockNextInRegionDocstring)
|
||||
.def(
|
||||
"__str__",
|
||||
[](PyBlock &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
mlirBlockPrint(self.block, printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
return printAccum.join();
|
||||
},
|
||||
kTypeStrDunderDocstring);
|
||||
|
||||
// Mapping of Type.
|
||||
py::class_<PyAttribute>(m, "Attribute")
|
||||
.def(
|
||||
|
|
|
@ -28,6 +28,13 @@ public:
|
|||
MlirContext context;
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirLocation.
|
||||
class PyLocation {
|
||||
public:
|
||||
PyLocation(MlirLocation loc) : loc(loc) {}
|
||||
MlirLocation loc;
|
||||
};
|
||||
|
||||
/// Wrapper around MlirModule.
|
||||
class PyModule {
|
||||
public:
|
||||
|
@ -45,6 +52,72 @@ public:
|
|||
MlirModule module;
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirRegion.
|
||||
/// Note that region can exist in a detached state (where this instance is
|
||||
/// responsible for clearing) or an attached state (where its owner is
|
||||
/// responsible).
|
||||
///
|
||||
/// This python wrapper retains a redundant reference to its creating context
|
||||
/// in order to facilitate checking that parts of the operation hierarchy
|
||||
/// are only assembled from the same context.
|
||||
class PyRegion {
|
||||
public:
|
||||
PyRegion(MlirContext context, MlirRegion region, bool detached)
|
||||
: context(context), region(region), detached(detached) {}
|
||||
PyRegion(PyRegion &&other)
|
||||
: context(other.context), region(other.region), detached(other.detached) {
|
||||
other.detached = false;
|
||||
}
|
||||
~PyRegion() {
|
||||
if (detached)
|
||||
mlirRegionDestroy(region);
|
||||
}
|
||||
|
||||
// Call prior to attaching the region to a parent.
|
||||
// This will transition to the attached state and will throw an exception
|
||||
// if already attached.
|
||||
void attachToParent();
|
||||
|
||||
MlirContext context;
|
||||
MlirRegion region;
|
||||
|
||||
private:
|
||||
bool detached;
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirBlock.
|
||||
/// Note that blocks can exist in a detached state (where this instance is
|
||||
/// responsible for clearing) or an attached state (where its owner is
|
||||
/// responsible).
|
||||
///
|
||||
/// This python wrapper retains a redundant reference to its creating context
|
||||
/// in order to facilitate checking that parts of the operation hierarchy
|
||||
/// are only assembled from the same context.
|
||||
class PyBlock {
|
||||
public:
|
||||
PyBlock(MlirContext context, MlirBlock block, bool detached)
|
||||
: context(context), block(block), detached(detached) {}
|
||||
PyBlock(PyBlock &&other)
|
||||
: context(other.context), block(other.block), detached(other.detached) {
|
||||
other.detached = false;
|
||||
}
|
||||
~PyBlock() {
|
||||
if (detached)
|
||||
mlirBlockDestroy(block);
|
||||
}
|
||||
|
||||
// Call prior to attaching the block to a parent.
|
||||
// This will transition to the attached state and will throw an exception
|
||||
// if already attached.
|
||||
void attachToParent();
|
||||
|
||||
MlirContext context;
|
||||
MlirBlock block;
|
||||
|
||||
private:
|
||||
bool detached;
|
||||
};
|
||||
|
||||
/// Wrapper around the generic MlirAttribute.
|
||||
/// The lifetime of a type is bound by the PyContext that created it.
|
||||
class PyAttribute {
|
||||
|
@ -84,6 +157,7 @@ class PyType {
|
|||
public:
|
||||
PyType(MlirType type) : type(type) {}
|
||||
bool operator==(const PyType &other);
|
||||
operator MlirType() const { return type; }
|
||||
|
||||
MlirType type;
|
||||
};
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "llvm/ADT/Twine.h"
|
||||
|
||||
|
|
|
@ -55,6 +55,10 @@ MlirContext mlirContextCreate() {
|
|||
return wrap(context);
|
||||
}
|
||||
|
||||
int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
|
||||
return unwrap(ctx1) == unwrap(ctx2);
|
||||
}
|
||||
|
||||
void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
|
||||
|
||||
/* ========================================================================== */
|
||||
|
@ -350,6 +354,10 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type) {
|
|||
return wrap(mlir::parseType(type, unwrap(context)));
|
||||
}
|
||||
|
||||
MlirContext mlirTypeGetContext(MlirType type) {
|
||||
return wrap(unwrap(type).getContext());
|
||||
}
|
||||
|
||||
int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
|
||||
|
||||
void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import mlir
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
|
||||
# CHECK-LABEL: TEST: testUnknown
|
||||
def testUnknown():
|
||||
ctx = mlir.ir.Context()
|
||||
loc = ctx.get_unknown_location()
|
||||
# CHECK: unknown str: loc(unknown)
|
||||
print("unknown str:", str(loc))
|
||||
# CHECK: unknown repr: loc(unknown)
|
||||
print("unknown repr:", repr(loc))
|
||||
|
||||
run(testUnknown)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testFileLineCol
|
||||
def testFileLineCol():
|
||||
ctx = mlir.ir.Context()
|
||||
loc = ctx.get_file_location("foo.txt", 123, 56)
|
||||
# CHECK: file str: loc("foo.txt":123:56)
|
||||
print("file str:", str(loc))
|
||||
# CHECK: file repr: loc("foo.txt":123:56)
|
||||
print("file repr:", repr(loc))
|
||||
|
||||
run(testFileLineCol)
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import mlir
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDetachedRegionBlock
|
||||
def testDetachedRegionBlock():
|
||||
ctx = mlir.ir.Context()
|
||||
t = mlir.ir.F32Type(ctx)
|
||||
region = ctx.create_region()
|
||||
block = ctx.create_block([t, t])
|
||||
# CHECK: <<UNLINKED BLOCK>>
|
||||
print(block)
|
||||
|
||||
run(testDetachedRegionBlock)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testBlockTypeContextMismatch
|
||||
def testBlockTypeContextMismatch():
|
||||
c1 = mlir.ir.Context()
|
||||
c2 = mlir.ir.Context()
|
||||
t1 = mlir.ir.F32Type(c1)
|
||||
t2 = mlir.ir.F32Type(c2)
|
||||
try:
|
||||
block = c1.create_block([t1, t2])
|
||||
except ValueError as e:
|
||||
# CHECK: ERROR: All types used to construct a block must be from the same context as the block
|
||||
print("ERROR:", e)
|
||||
|
||||
run(testBlockTypeContextMismatch)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testBlockAppend
|
||||
def testBlockAppend():
|
||||
ctx = mlir.ir.Context()
|
||||
t = mlir.ir.F32Type(ctx)
|
||||
region = ctx.create_region()
|
||||
try:
|
||||
region.first_block
|
||||
except IndexError:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("Expected exception not raised")
|
||||
|
||||
block = ctx.create_block([t, t])
|
||||
region.append_block(block)
|
||||
try:
|
||||
region.append_block(block)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("Expected exception not raised")
|
||||
|
||||
block2 = ctx.create_block([t])
|
||||
region.insert_block(1, block2)
|
||||
# CHECK: <<UNLINKED BLOCK>>
|
||||
block_first = region.first_block
|
||||
print(block_first)
|
||||
block_next = block_first.next_in_region
|
||||
try:
|
||||
block_next = block_next.next_in_region
|
||||
except IndexError:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("Expected exception not raised")
|
||||
|
||||
run(testBlockAppend)
|
Loading…
Reference in New Issue