[mlir][Python] Add an Operation.result property.

* If ODS redefines this, it is fine, but I have found this accessor to be universally useful in the old npcomp bindings and I'm closing gaps that will let me switch.

Differential Revision: https://reviews.llvm.org/D92287
This commit is contained in:
Stella Laurenzo 2020-11-29 13:52:11 -08:00
parent bd2083c2fa
commit ba0fe76b7e
3 changed files with 80 additions and 23 deletions

View File

@ -23,6 +23,8 @@ using namespace mlir;
using namespace mlir::python;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
//------------------------------------------------------------------------------
// Docstrings (trivial, non-duplicated docstrings are included inline).
@ -631,7 +633,7 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key,
getContext()->get(), {canonKey->data(), canonKey->size()});
if (mlirDialectIsNull(dialect)) {
throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
llvm::Twine("Dialect '") + key + "' not found");
Twine("Dialect '") + key + "' not found");
}
return dialect;
}
@ -793,7 +795,7 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
return created;
}
void PyOperation::checkValid() {
void PyOperation::checkValid() const {
if (!valid) {
throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
}
@ -817,7 +819,7 @@ void PyOperationBase::print(py::object fileObject, bool binary,
PyFileAccumulator accum(fileObject, binary);
py::gil_scoped_release();
mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(),
mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
accum.getUserData());
mlirOpPrintingFlagsDestroy(flags);
}
@ -975,7 +977,7 @@ py::object PyOperation::createOpView() {
MlirIdentifier ident = mlirOperationGetName(get());
MlirStringRef identStr = mlirIdentifierStr(ident);
auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
llvm::StringRef(identStr.data, identStr.length));
StringRef(identStr.data, identStr.length));
if (opViewClass)
return (*opViewClass)(getRef().getObject());
return py::cast(PyOpView(getRef().getObject()));
@ -1044,7 +1046,7 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) {
(*refOperation)->checkValid();
beforeOp = (*refOperation)->get();
}
mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation.get());
mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
operation.setAttached();
}
@ -1158,7 +1160,7 @@ public:
static MlirValue castFrom(PyValue &orig) {
if (!DerivedTy::isaFunction(orig.get())) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast value to ") +
throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
DerivedTy::pyClassName +
" (from " + origRepr + ")");
}
@ -1416,9 +1418,9 @@ public:
static MlirAttribute castFrom(PyAttribute &orig) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError,
llvm::Twine("Cannot cast attribute to ") +
DerivedTy::pyClassName + " (from " + origRepr + ")");
throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
DerivedTy::pyClassName +
" (from " + origRepr + ")");
}
return orig;
}
@ -1449,7 +1451,7 @@ public:
// in C API.
if (mlirAttributeIsNull(attr)) {
throw SetPyError(PyExc_ValueError,
llvm::Twine("invalid '") +
Twine("invalid '") +
py::repr(py::cast(type)).cast<std::string>() +
"' and expected floating point type.");
}
@ -1943,7 +1945,7 @@ public:
static MlirType castFrom(PyType &orig) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
DerivedTy::pyClassName +
" (from " + origRepr + ")");
}
@ -2142,7 +2144,7 @@ public:
}
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
},
@ -2247,7 +2249,7 @@ public:
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
}
@ -2278,7 +2280,7 @@ public:
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
@ -2309,7 +2311,7 @@ public:
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
@ -2344,7 +2346,7 @@ public:
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
@ -2390,7 +2392,7 @@ public:
if (mlirTypeIsNull(t)) {
throw SetPyError(
PyExc_ValueError,
llvm::Twine("invalid '") +
Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point, integer, vector or "
"complex "
@ -2544,7 +2546,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
self.get(), {name.data(), name.size()});
if (mlirDialectIsNull(dialect)) {
throw SetPyError(PyExc_ValueError,
llvm::Twine("Dialect '") + name + "' not found");
Twine("Dialect '") + name + "' not found");
}
return PyDialectDescriptor(self.getRef(), dialect);
},
@ -2763,6 +2765,26 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return PyOpResultList(self.getOperation().getRef());
},
"Returns the list of Operation results.")
.def_property_readonly(
"result",
[](PyOperationBase &self) {
auto &operation = self.getOperation();
auto numResults = mlirOperationGetNumResults(operation);
if (numResults != 1) {
auto name = mlirIdentifierStr(mlirOperationGetName(operation));
throw SetPyError(
PyExc_ValueError,
Twine("Cannot call .result on operation ") +
StringRef(name.data, name.length) + " which has " +
Twine(numResults) +
" results (it is only valid for operations with a "
"single result)");
}
return PyOpResult(operation.getRef(),
mlirOperationGetResult(operation, 0));
},
"Shortcut to get an op result if it has only one (throws an error "
"otherwise).")
.def("__iter__",
[](PyOperationBase &self) {
return PyRegionIterator(self.getOperation().getRef());
@ -2931,7 +2953,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// in C API.
if (mlirAttributeIsNull(type)) {
throw SetPyError(PyExc_ValueError,
llvm::Twine("Unable to parse attribute: '") +
Twine("Unable to parse attribute: '") +
attrSpec + "'");
}
return PyAttribute(context->getRef(), type);
@ -3042,8 +3064,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// in C API.
if (mlirTypeIsNull(type)) {
throw SetPyError(PyExc_ValueError,
llvm::Twine("Unable to parse type: '") +
typeSpec + "'");
Twine("Unable to parse type: '") + typeSpec +
"'");
}
return PyType(context->getRef(), type);
},

View File

@ -425,7 +425,8 @@ public:
pybind11::object parentKeepAlive = pybind11::object());
/// Gets the backing operation.
MlirOperation get() {
operator MlirOperation() const { return get(); }
MlirOperation get() const {
checkValid();
return operation;
}
@ -440,7 +441,7 @@ public:
assert(!attached && "operation already attached");
attached = true;
}
void checkValid();
void checkValid() const;
/// Gets the owning block or raises an exception if the operation has no
/// owning block.

View File

@ -474,6 +474,7 @@ def testOperationPrint():
run(testOperationPrint)
# CHECK-LABEL: TEST: testKnownOpView
def testKnownOpView():
with Context(), Location.unknown():
Context.current.allow_unregistered_dialects = True
@ -503,3 +504,36 @@ def testKnownOpView():
print(repr(custom))
run(testKnownOpView)
# CHECK-LABEL: TEST: testSingleResultProperty
def testSingleResultProperty():
with Context(), Location.unknown():
Context.current.allow_unregistered_dialects = True
module = Module.parse(r"""
"custom.no_result"() : () -> ()
%0:2 = "custom.two_result"() : () -> (f32, f32)
%1 = "custom.one_result"() : () -> f32
""")
print(module)
try:
module.body.operations[0].result
except ValueError as e:
# CHECK: Cannot call .result on operation custom.no_result which has 0 results
print(e)
else:
assert False, "Expected exception"
try:
module.body.operations[1].result
except ValueError as e:
# CHECK: Cannot call .result on operation custom.two_result which has 2 results
print(e)
else:
assert False, "Expected exception"
# CHECK: %1 = "custom.one_result"() : () -> f32
print(module.body.operations[2])
run(testSingleResultProperty)