forked from OSchip/llvm-project
[mlir][python][f16] add ctype python binding support for f16
Similar to complex128/complex64, float16 has no direct support in the ctypes implementation. This fixes the issue by using a custom F16 type to change the view in and out of MLIR code Reviewed By: wrengr Differential Revision: https://reviews.llvm.org/D126928
This commit is contained in:
parent
b64f6e5722
commit
f8b692dd31
|
@ -18,15 +18,33 @@ class C64(ctypes.Structure):
|
|||
_fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
|
||||
|
||||
|
||||
class F16(ctypes.Structure):
|
||||
"""A ctype representation for MLIR's Float16."""
|
||||
_fields_ = [("f16", ctypes.c_int16)]
|
||||
|
||||
|
||||
def as_ctype(dtp):
|
||||
"""Converts dtype to ctype."""
|
||||
if dtp is np.dtype(np.complex128):
|
||||
return C128
|
||||
if dtp is np.dtype(np.complex64):
|
||||
return C64
|
||||
if dtp is np.dtype(np.float16):
|
||||
return F16
|
||||
return np.ctypeslib.as_ctypes_type(dtp)
|
||||
|
||||
|
||||
def to_numpy(array):
|
||||
"""Converts ctypes array back to numpy dtype array."""
|
||||
if array.dtype == C128:
|
||||
return array.view("complex128")
|
||||
if array.dtype == C64:
|
||||
return array.view("complex64")
|
||||
if array.dtype == F16:
|
||||
return array.view("float16")
|
||||
return array
|
||||
|
||||
|
||||
def make_nd_memref_descriptor(rank, dtype):
|
||||
|
||||
class MemRefDescriptor(ctypes.Structure):
|
||||
|
@ -105,11 +123,7 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype):
|
|||
np.ctypeslib.as_array(val[0].shape),
|
||||
np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
|
||||
)
|
||||
if strided_arr.dtype == C128:
|
||||
return strided_arr.view("complex128")
|
||||
if strided_arr.dtype == C64:
|
||||
return strided_arr.view("complex64")
|
||||
return strided_arr
|
||||
return to_numpy(strided_arr)
|
||||
|
||||
|
||||
def ranked_memref_to_numpy(ranked_memref):
|
||||
|
@ -121,8 +135,4 @@ def ranked_memref_to_numpy(ranked_memref):
|
|||
np.ctypeslib.as_array(ranked_memref[0].shape),
|
||||
np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
|
||||
)
|
||||
if strided_arr.dtype == C128:
|
||||
return strided_arr.view("complex128")
|
||||
if strided_arr.dtype == C64:
|
||||
return strided_arr.view("complex64")
|
||||
return strided_arr
|
||||
return to_numpy(strided_arr)
|
||||
|
|
|
@ -266,6 +266,50 @@ def testMemrefAdd():
|
|||
run(testMemrefAdd)
|
||||
|
||||
|
||||
# Test addition of two f16 memrefs
|
||||
# CHECK-LABEL: TEST: testF16MemrefAdd
|
||||
def testF16MemrefAdd():
|
||||
with Context():
|
||||
module = Module.parse("""
|
||||
module {
|
||||
func.func @main(%arg0: memref<1xf16>,
|
||||
%arg1: memref<1xf16>,
|
||||
%arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = memref.load %arg0[%0] : memref<1xf16>
|
||||
%2 = memref.load %arg1[%0] : memref<1xf16>
|
||||
%3 = arith.addf %1, %2 : f16
|
||||
memref.store %3, %arg2[%0] : memref<1xf16>
|
||||
return
|
||||
}
|
||||
} """)
|
||||
|
||||
arg1 = np.array([11.]).astype(np.float16)
|
||||
arg2 = np.array([22.]).astype(np.float16)
|
||||
arg3 = np.array([0.]).astype(np.float16)
|
||||
|
||||
arg1_memref_ptr = ctypes.pointer(
|
||||
ctypes.pointer(get_ranked_memref_descriptor(arg1)))
|
||||
arg2_memref_ptr = ctypes.pointer(
|
||||
ctypes.pointer(get_ranked_memref_descriptor(arg2)))
|
||||
arg3_memref_ptr = ctypes.pointer(
|
||||
ctypes.pointer(get_ranked_memref_descriptor(arg3)))
|
||||
|
||||
execution_engine = ExecutionEngine(lowerToLLVM(module))
|
||||
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr,
|
||||
arg3_memref_ptr)
|
||||
# CHECK: [11.] + [22.] = [33.]
|
||||
log("{0} + {1} = {2}".format(arg1, arg2, arg3))
|
||||
|
||||
# test to-numpy utility
|
||||
# CHECK: [33.]
|
||||
npout = ranked_memref_to_numpy(arg3_memref_ptr[0])
|
||||
log(npout)
|
||||
|
||||
|
||||
run(testF16MemrefAdd)
|
||||
|
||||
|
||||
# Test addition of two complex memrefs
|
||||
# CHECK-LABEL: TEST: testComplexMemrefAdd
|
||||
def testComplexMemrefAdd():
|
||||
|
@ -442,15 +486,15 @@ def testSharedLibLoad():
|
|||
ctypes.pointer(get_ranked_memref_descriptor(arg0)))
|
||||
|
||||
if sys.platform == 'win32':
|
||||
shared_libs = [
|
||||
"../../../../bin/mlir_runner_utils.dll",
|
||||
"../../../../bin/mlir_c_runner_utils.dll"
|
||||
]
|
||||
shared_libs = [
|
||||
"../../../../bin/mlir_runner_utils.dll",
|
||||
"../../../../bin/mlir_c_runner_utils.dll"
|
||||
]
|
||||
else:
|
||||
shared_libs = [
|
||||
"../../../../lib/libmlir_runner_utils.so",
|
||||
"../../../../lib/libmlir_c_runner_utils.so"
|
||||
]
|
||||
shared_libs = [
|
||||
"../../../../lib/libmlir_runner_utils.so",
|
||||
"../../../../lib/libmlir_c_runner_utils.so"
|
||||
]
|
||||
|
||||
execution_engine = ExecutionEngine(
|
||||
lowerToLLVM(module),
|
||||
|
@ -484,15 +528,15 @@ def testNanoTime():
|
|||
}""")
|
||||
|
||||
if sys.platform == 'win32':
|
||||
shared_libs = [
|
||||
"../../../../bin/mlir_runner_utils.dll",
|
||||
"../../../../bin/mlir_c_runner_utils.dll"
|
||||
]
|
||||
shared_libs = [
|
||||
"../../../../bin/mlir_runner_utils.dll",
|
||||
"../../../../bin/mlir_c_runner_utils.dll"
|
||||
]
|
||||
else:
|
||||
shared_libs = [
|
||||
"../../../../lib/libmlir_runner_utils.so",
|
||||
"../../../../lib/libmlir_c_runner_utils.so"
|
||||
]
|
||||
shared_libs = [
|
||||
"../../../../lib/libmlir_runner_utils.so",
|
||||
"../../../../lib/libmlir_c_runner_utils.so"
|
||||
]
|
||||
|
||||
execution_engine = ExecutionEngine(
|
||||
lowerToLLVM(module),
|
||||
|
|
Loading…
Reference in New Issue