forked from OSchip/llvm-project
[mlir][Python] Make DenseElementsAttr loading be int size agnostic.
* I had missed the note about "Standard size" in the docs. On Windows, the 'l' types are 32bit. * This fixes the only failing MLIR-Python test on Windows. Differential Revision: https://reviews.llvm.org/D91283
This commit is contained in:
parent
e2537353e6
commit
989b194429
|
@ -1534,6 +1534,7 @@ public:
|
|||
MlirContext context = contextWrapper->get();
|
||||
// Switch on the types that can be bulk loaded between the Python and
|
||||
// MLIR-C APIs.
|
||||
// See: https://docs.python.org/3/library/struct.html#format-characters
|
||||
if (arrayInfo.format == "f") {
|
||||
// f32
|
||||
assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
|
||||
|
@ -1548,42 +1549,44 @@ public:
|
|||
contextWrapper->getRef(),
|
||||
bulkLoad(context, mlirDenseElementsAttrDoubleGet,
|
||||
mlirF64TypeGet(context), arrayInfo));
|
||||
} else if (arrayInfo.format == "i") {
|
||||
// i32
|
||||
assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
|
||||
MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
|
||||
: mlirIntegerTypeSignedGet(context, 32);
|
||||
return PyDenseElementsAttribute(contextWrapper->getRef(),
|
||||
bulkLoad(context,
|
||||
mlirDenseElementsAttrInt32Get,
|
||||
elementType, arrayInfo));
|
||||
} else if (arrayInfo.format == "I") {
|
||||
// unsigned i32
|
||||
assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
|
||||
MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
|
||||
: mlirIntegerTypeUnsignedGet(context, 32);
|
||||
return PyDenseElementsAttribute(contextWrapper->getRef(),
|
||||
bulkLoad(context,
|
||||
mlirDenseElementsAttrUInt32Get,
|
||||
elementType, arrayInfo));
|
||||
} else if (arrayInfo.format == "l") {
|
||||
// i64
|
||||
assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
|
||||
MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
|
||||
: mlirIntegerTypeSignedGet(context, 64);
|
||||
return PyDenseElementsAttribute(contextWrapper->getRef(),
|
||||
bulkLoad(context,
|
||||
mlirDenseElementsAttrInt64Get,
|
||||
elementType, arrayInfo));
|
||||
} else if (arrayInfo.format == "L") {
|
||||
// unsigned i64
|
||||
assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
|
||||
MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
|
||||
: mlirIntegerTypeUnsignedGet(context, 64);
|
||||
return PyDenseElementsAttribute(contextWrapper->getRef(),
|
||||
bulkLoad(context,
|
||||
mlirDenseElementsAttrUInt64Get,
|
||||
elementType, arrayInfo));
|
||||
} else if (isSignedIntegerFormat(arrayInfo.format)) {
|
||||
if (arrayInfo.itemsize == 4) {
|
||||
// i32
|
||||
MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
|
||||
: mlirIntegerTypeSignedGet(context, 32);
|
||||
return PyDenseElementsAttribute(contextWrapper->getRef(),
|
||||
bulkLoad(context,
|
||||
mlirDenseElementsAttrInt32Get,
|
||||
elementType, arrayInfo));
|
||||
} else if (arrayInfo.itemsize == 8) {
|
||||
// i64
|
||||
MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
|
||||
: mlirIntegerTypeSignedGet(context, 64);
|
||||
return PyDenseElementsAttribute(contextWrapper->getRef(),
|
||||
bulkLoad(context,
|
||||
mlirDenseElementsAttrInt64Get,
|
||||
elementType, arrayInfo));
|
||||
}
|
||||
} else if (isUnsignedIntegerFormat(arrayInfo.format)) {
|
||||
if (arrayInfo.itemsize == 4) {
|
||||
// unsigned i32
|
||||
MlirType elementType = signless
|
||||
? mlirIntegerTypeGet(context, 32)
|
||||
: mlirIntegerTypeUnsignedGet(context, 32);
|
||||
return PyDenseElementsAttribute(contextWrapper->getRef(),
|
||||
bulkLoad(context,
|
||||
mlirDenseElementsAttrUInt32Get,
|
||||
elementType, arrayInfo));
|
||||
} else if (arrayInfo.itemsize == 8) {
|
||||
// unsigned i64
|
||||
MlirType elementType = signless
|
||||
? mlirIntegerTypeGet(context, 64)
|
||||
: mlirIntegerTypeUnsignedGet(context, 64);
|
||||
return PyDenseElementsAttribute(contextWrapper->getRef(),
|
||||
bulkLoad(context,
|
||||
mlirDenseElementsAttrUInt64Get,
|
||||
elementType, arrayInfo));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Fall back to string-based get.
|
||||
|
@ -1656,7 +1659,23 @@ private:
|
|||
const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
|
||||
return ctor(shapedType, numElements, contents);
|
||||
}
|
||||
};
|
||||
|
||||
static bool isUnsignedIntegerFormat(const std::string &format) {
|
||||
if (format.empty())
|
||||
return false;
|
||||
char code = format[0];
|
||||
return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
|
||||
code == 'Q';
|
||||
}
|
||||
|
||||
static bool isSignedIntegerFormat(const std::string &format) {
|
||||
if (format.empty())
|
||||
return false;
|
||||
char code = format[0];
|
||||
return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
|
||||
code == 'q';
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
/// Refinement of the PyDenseElementsAttribute for attributes containing integer
|
||||
/// (and boolean) values. Supports element access.
|
||||
|
|
Loading…
Reference in New Issue