[mlir][Python] Make DenseElementsAttr loading be int size agnostic.

* I had missed the note about "Standard size" in the docs. On Windows, the 'l' types are 32bit.
* This fixes the only failing MLIR-Python test on Windows.

Differential Revision: https://reviews.llvm.org/D91283
This commit is contained in:
Stella Laurenzo 2020-11-11 10:21:53 -08:00
parent e2537353e6
commit 989b194429
1 changed files with 56 additions and 37 deletions

View File

@ -1534,6 +1534,7 @@ public:
MlirContext context = contextWrapper->get();
// Switch on the types that can be bulk loaded between the Python and
// MLIR-C APIs.
// See: https://docs.python.org/3/library/struct.html#format-characters
if (arrayInfo.format == "f") {
// f32
assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
@ -1548,42 +1549,44 @@ public:
contextWrapper->getRef(),
bulkLoad(context, mlirDenseElementsAttrDoubleGet,
mlirF64TypeGet(context), arrayInfo));
} else if (arrayInfo.format == "i") {
// i32
assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeSignedGet(context, 32);
return PyDenseElementsAttribute(contextWrapper->getRef(),
bulkLoad(context,
mlirDenseElementsAttrInt32Get,
elementType, arrayInfo));
} else if (arrayInfo.format == "I") {
// unsigned i32
assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeUnsignedGet(context, 32);
return PyDenseElementsAttribute(contextWrapper->getRef(),
bulkLoad(context,
mlirDenseElementsAttrUInt32Get,
elementType, arrayInfo));
} else if (arrayInfo.format == "l") {
// i64
assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeSignedGet(context, 64);
return PyDenseElementsAttribute(contextWrapper->getRef(),
bulkLoad(context,
mlirDenseElementsAttrInt64Get,
elementType, arrayInfo));
} else if (arrayInfo.format == "L") {
// unsigned i64
assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeUnsignedGet(context, 64);
return PyDenseElementsAttribute(contextWrapper->getRef(),
bulkLoad(context,
mlirDenseElementsAttrUInt64Get,
elementType, arrayInfo));
} else if (isSignedIntegerFormat(arrayInfo.format)) {
if (arrayInfo.itemsize == 4) {
// i32
MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeSignedGet(context, 32);
return PyDenseElementsAttribute(contextWrapper->getRef(),
bulkLoad(context,
mlirDenseElementsAttrInt32Get,
elementType, arrayInfo));
} else if (arrayInfo.itemsize == 8) {
// i64
MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeSignedGet(context, 64);
return PyDenseElementsAttribute(contextWrapper->getRef(),
bulkLoad(context,
mlirDenseElementsAttrInt64Get,
elementType, arrayInfo));
}
} else if (isUnsignedIntegerFormat(arrayInfo.format)) {
if (arrayInfo.itemsize == 4) {
// unsigned i32
MlirType elementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeUnsignedGet(context, 32);
return PyDenseElementsAttribute(contextWrapper->getRef(),
bulkLoad(context,
mlirDenseElementsAttrUInt32Get,
elementType, arrayInfo));
} else if (arrayInfo.itemsize == 8) {
// unsigned i64
MlirType elementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeUnsignedGet(context, 64);
return PyDenseElementsAttribute(contextWrapper->getRef(),
bulkLoad(context,
mlirDenseElementsAttrUInt64Get,
elementType, arrayInfo));
}
}
// TODO: Fall back to string-based get.
@ -1656,7 +1659,23 @@ private:
const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
return ctor(shapedType, numElements, contents);
}
};
static bool isUnsignedIntegerFormat(const std::string &format) {
if (format.empty())
return false;
char code = format[0];
return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
code == 'Q';
}
static bool isSignedIntegerFormat(const std::string &format) {
if (format.empty())
return false;
char code = format[0];
return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
code == 'q';
}
}; // namespace
/// Refinement of the PyDenseElementsAttribute for attributes containing integer
/// (and boolean) values. Supports element access.