forked from OSchip/llvm-project
[mlir][Python] Add an Operation.result property.
* If ODS redefines this, it is fine, but I have found this accessor to be universally useful in the old npcomp bindings and I'm closing gaps that will let me switch. Differential Revision: https://reviews.llvm.org/D92287
This commit is contained in:
parent
bd2083c2fa
commit
ba0fe76b7e
|
@ -23,6 +23,8 @@ using namespace mlir;
|
|||
using namespace mlir::python;
|
||||
|
||||
using llvm::SmallVector;
|
||||
using llvm::StringRef;
|
||||
using llvm::Twine;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Docstrings (trivial, non-duplicated docstrings are included inline).
|
||||
|
@ -631,7 +633,7 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key,
|
|||
getContext()->get(), {canonKey->data(), canonKey->size()});
|
||||
if (mlirDialectIsNull(dialect)) {
|
||||
throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
|
||||
llvm::Twine("Dialect '") + key + "' not found");
|
||||
Twine("Dialect '") + key + "' not found");
|
||||
}
|
||||
return dialect;
|
||||
}
|
||||
|
@ -793,7 +795,7 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
|
|||
return created;
|
||||
}
|
||||
|
||||
void PyOperation::checkValid() {
|
||||
void PyOperation::checkValid() const {
|
||||
if (!valid) {
|
||||
throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
|
||||
}
|
||||
|
@ -817,7 +819,7 @@ void PyOperationBase::print(py::object fileObject, bool binary,
|
|||
|
||||
PyFileAccumulator accum(fileObject, binary);
|
||||
py::gil_scoped_release();
|
||||
mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(),
|
||||
mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
|
||||
accum.getUserData());
|
||||
mlirOpPrintingFlagsDestroy(flags);
|
||||
}
|
||||
|
@ -975,7 +977,7 @@ py::object PyOperation::createOpView() {
|
|||
MlirIdentifier ident = mlirOperationGetName(get());
|
||||
MlirStringRef identStr = mlirIdentifierStr(ident);
|
||||
auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
|
||||
llvm::StringRef(identStr.data, identStr.length));
|
||||
StringRef(identStr.data, identStr.length));
|
||||
if (opViewClass)
|
||||
return (*opViewClass)(getRef().getObject());
|
||||
return py::cast(PyOpView(getRef().getObject()));
|
||||
|
@ -1044,7 +1046,7 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) {
|
|||
(*refOperation)->checkValid();
|
||||
beforeOp = (*refOperation)->get();
|
||||
}
|
||||
mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation.get());
|
||||
mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
|
||||
operation.setAttached();
|
||||
}
|
||||
|
||||
|
@ -1158,7 +1160,7 @@ public:
|
|||
static MlirValue castFrom(PyValue &orig) {
|
||||
if (!DerivedTy::isaFunction(orig.get())) {
|
||||
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
||||
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast value to ") +
|
||||
throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
|
||||
DerivedTy::pyClassName +
|
||||
" (from " + origRepr + ")");
|
||||
}
|
||||
|
@ -1416,9 +1418,9 @@ public:
|
|||
static MlirAttribute castFrom(PyAttribute &orig) {
|
||||
if (!DerivedTy::isaFunction(orig)) {
|
||||
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Cannot cast attribute to ") +
|
||||
DerivedTy::pyClassName + " (from " + origRepr + ")");
|
||||
throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
|
||||
DerivedTy::pyClassName +
|
||||
" (from " + origRepr + ")");
|
||||
}
|
||||
return orig;
|
||||
}
|
||||
|
@ -1449,7 +1451,7 @@ public:
|
|||
// in C API.
|
||||
if (mlirAttributeIsNull(attr)) {
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
Twine("invalid '") +
|
||||
py::repr(py::cast(type)).cast<std::string>() +
|
||||
"' and expected floating point type.");
|
||||
}
|
||||
|
@ -1943,7 +1945,7 @@ public:
|
|||
static MlirType castFrom(PyType &orig) {
|
||||
if (!DerivedTy::isaFunction(orig)) {
|
||||
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
||||
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
|
||||
throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
|
||||
DerivedTy::pyClassName +
|
||||
" (from " + origRepr + ")");
|
||||
}
|
||||
|
@ -2142,7 +2144,7 @@ public:
|
|||
}
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point or integer type.");
|
||||
},
|
||||
|
@ -2247,7 +2249,7 @@ public:
|
|||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point or integer type.");
|
||||
}
|
||||
|
@ -2278,7 +2280,7 @@ public:
|
|||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
|
@ -2309,7 +2311,7 @@ public:
|
|||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
|
@ -2344,7 +2346,7 @@ public:
|
|||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
|
@ -2390,7 +2392,7 @@ public:
|
|||
if (mlirTypeIsNull(t)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
llvm::Twine("invalid '") +
|
||||
Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
"' and expected floating point, integer, vector or "
|
||||
"complex "
|
||||
|
@ -2544,7 +2546,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
self.get(), {name.data(), name.size()});
|
||||
if (mlirDialectIsNull(dialect)) {
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Dialect '") + name + "' not found");
|
||||
Twine("Dialect '") + name + "' not found");
|
||||
}
|
||||
return PyDialectDescriptor(self.getRef(), dialect);
|
||||
},
|
||||
|
@ -2763,6 +2765,26 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
return PyOpResultList(self.getOperation().getRef());
|
||||
},
|
||||
"Returns the list of Operation results.")
|
||||
.def_property_readonly(
|
||||
"result",
|
||||
[](PyOperationBase &self) {
|
||||
auto &operation = self.getOperation();
|
||||
auto numResults = mlirOperationGetNumResults(operation);
|
||||
if (numResults != 1) {
|
||||
auto name = mlirIdentifierStr(mlirOperationGetName(operation));
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
Twine("Cannot call .result on operation ") +
|
||||
StringRef(name.data, name.length) + " which has " +
|
||||
Twine(numResults) +
|
||||
" results (it is only valid for operations with a "
|
||||
"single result)");
|
||||
}
|
||||
return PyOpResult(operation.getRef(),
|
||||
mlirOperationGetResult(operation, 0));
|
||||
},
|
||||
"Shortcut to get an op result if it has only one (throws an error "
|
||||
"otherwise).")
|
||||
.def("__iter__",
|
||||
[](PyOperationBase &self) {
|
||||
return PyRegionIterator(self.getOperation().getRef());
|
||||
|
@ -2931,7 +2953,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
// in C API.
|
||||
if (mlirAttributeIsNull(type)) {
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Unable to parse attribute: '") +
|
||||
Twine("Unable to parse attribute: '") +
|
||||
attrSpec + "'");
|
||||
}
|
||||
return PyAttribute(context->getRef(), type);
|
||||
|
@ -3042,8 +3064,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
|||
// in C API.
|
||||
if (mlirTypeIsNull(type)) {
|
||||
throw SetPyError(PyExc_ValueError,
|
||||
llvm::Twine("Unable to parse type: '") +
|
||||
typeSpec + "'");
|
||||
Twine("Unable to parse type: '") + typeSpec +
|
||||
"'");
|
||||
}
|
||||
return PyType(context->getRef(), type);
|
||||
},
|
||||
|
|
|
@ -425,7 +425,8 @@ public:
|
|||
pybind11::object parentKeepAlive = pybind11::object());
|
||||
|
||||
/// Gets the backing operation.
|
||||
MlirOperation get() {
|
||||
operator MlirOperation() const { return get(); }
|
||||
MlirOperation get() const {
|
||||
checkValid();
|
||||
return operation;
|
||||
}
|
||||
|
@ -440,7 +441,7 @@ public:
|
|||
assert(!attached && "operation already attached");
|
||||
attached = true;
|
||||
}
|
||||
void checkValid();
|
||||
void checkValid() const;
|
||||
|
||||
/// Gets the owning block or raises an exception if the operation has no
|
||||
/// owning block.
|
||||
|
|
|
@ -474,6 +474,7 @@ def testOperationPrint():
|
|||
run(testOperationPrint)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testKnownOpView
|
||||
def testKnownOpView():
|
||||
with Context(), Location.unknown():
|
||||
Context.current.allow_unregistered_dialects = True
|
||||
|
@ -503,3 +504,36 @@ def testKnownOpView():
|
|||
print(repr(custom))
|
||||
|
||||
run(testKnownOpView)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testSingleResultProperty
|
||||
def testSingleResultProperty():
|
||||
with Context(), Location.unknown():
|
||||
Context.current.allow_unregistered_dialects = True
|
||||
module = Module.parse(r"""
|
||||
"custom.no_result"() : () -> ()
|
||||
%0:2 = "custom.two_result"() : () -> (f32, f32)
|
||||
%1 = "custom.one_result"() : () -> f32
|
||||
""")
|
||||
print(module)
|
||||
|
||||
try:
|
||||
module.body.operations[0].result
|
||||
except ValueError as e:
|
||||
# CHECK: Cannot call .result on operation custom.no_result which has 0 results
|
||||
print(e)
|
||||
else:
|
||||
assert False, "Expected exception"
|
||||
|
||||
try:
|
||||
module.body.operations[1].result
|
||||
except ValueError as e:
|
||||
# CHECK: Cannot call .result on operation custom.two_result which has 2 results
|
||||
print(e)
|
||||
else:
|
||||
assert False, "Expected exception"
|
||||
|
||||
# CHECK: %1 = "custom.one_result"() : () -> f32
|
||||
print(module.body.operations[2])
|
||||
|
||||
run(testSingleResultProperty)
|
||||
|
|
Loading…
Reference in New Issue