[mlir][sparse][taco] Support f16.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D128105
Author: bixia1
Date: 2022-06-17 16:14:36 -07:00
Parent: 3561ee586e
Commit: bdeae1f57b
4 changed files with 14 additions and 6 deletions
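
What the change enables, end to end: PyTACO programs can now declare and compute with half-precision tensors. A minimal sketch modeled on the unit test diffed below; the insert calls, sample values, and the sys.path assumption are illustrative, not part of this commit:

import numpy as np
import mlir_pytaco_api as pt  # assumes the taco tools directory is on sys.path

compressed = pt.compressed
i, j = pt.get_index_vars(2)

A = pt.tensor([2, 3], dtype=pt.float16)              # f16 dtype now accepted
B = pt.tensor([2, 3], dtype=pt.float16)
C = pt.tensor([2, 3], compressed, dtype=pt.float16)  # compressed output

A.insert([0, 0], 10.0)                               # illustrative values
B.insert([0, 0], 10.0)
C[i, j] = A[i, j] + B[i, j]

indices, values = C.get_coordinates_and_values()
assert np.allclose(values, [20.0])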


@@ -12,7 +12,9 @@ compressed = pt.compressed
 dense = pt.dense
 passed = 0
 
-all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float32, pt.float64]
+all_types = [
+    pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64
+]
 for t in all_types:
   i, j = pt.get_index_vars(2)
   A = pt.tensor([2, 3], dtype=t)
@@ -29,5 +31,5 @@ for t in all_types:
   passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
   passed += np.allclose(values, [20.0, 10.0, 70.0])
 
-# CHECK: Number of passed: 18
+# CHECK: Number of passed: 21
 print("Number of passed:", passed)


@@ -72,7 +72,7 @@ class Type(enum.Enum):
   INT16 = np.int16
   INT32 = np.int32
   INT64 = np.int64
-  # numpy _ctype_from_dtype_scalar can't handle np.float16 yet.
+  FLOAT16 = np.float16
   FLOAT32 = np.float32
   FLOAT64 = np.float64
   COMPLEX64 = np.complex64
@@ -80,15 +80,15 @@ class Type(enum.Enum):
 
 # All floating point type enums.
-_FLOAT_TYPES = (Type.FLOAT32, Type.FLOAT64)
+_FLOAT_TYPES = (Type.FLOAT16, Type.FLOAT32, Type.FLOAT64)
 # All integral type enums.
 _INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64)
 # All complex type enums.
 _COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128)
 
 # Type alias for any numpy type used to implement the runtime support for the
 # enum data types.
-_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float32,
-                        np.float64, np.complex64, np.complex128]
+_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float16,
+                        np.float32, np.float64, np.complex64, np.complex128]
 
 
 @dataclasses.dataclass(frozen=True)
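
Beyond the new enum member, extending _FLOAT_TYPES is what makes f16 behave as a floating-point kind wherever this module branches on it. A self-contained sketch, assuming a DType.is_float predicate defined over _FLOAT_TYPES in the style of this file:

import dataclasses
import enum

import numpy as np

class Type(enum.Enum):
  FLOAT16 = np.float16
  FLOAT32 = np.float32
  FLOAT64 = np.float64

_FLOAT_TYPES = (Type.FLOAT16, Type.FLOAT32, Type.FLOAT64)

@dataclasses.dataclass(frozen=True)
class DType:
  kind: Type = Type.FLOAT32

  def is_float(self) -> bool:
    # True for f16/f32/f64 once FLOAT16 joins _FLOAT_TYPES.
    return self.kind in _FLOAT_TYPES

assert DType(Type.FLOAT16).is_float()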
@@ -132,6 +132,7 @@ def _dtype_to_mlir_str(dtype: DType) -> str:
       Type.INT16: "i16",
       Type.INT32: "i32",
       Type.INT64: "i64",
+      Type.FLOAT16: "f16",
       Type.FLOAT32: "f32",
       Type.FLOAT64: "f64",
       Type.COMPLEX64: "complex<f32>",
@@ -147,6 +148,7 @@ def _nptype_to_taco_type(ty: np.dtype) -> DType:
       np.int16: Type.INT16,
       np.int32: Type.INT32,
       np.int64: Type.INT64,
+      np.float16: Type.FLOAT16,
       np.float32: Type.FLOAT32,
       np.float64: Type.FLOAT64,
       np.complex64: Type.COMPLEX64,
@@ -162,6 +164,7 @@ def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
       Type.INT16: ir.IntegerType.get_signless(16),
       Type.INT32: ir.IntegerType.get_signless(32),
       Type.INT64: ir.IntegerType.get_signless(64),
+      Type.FLOAT16: ir.F16Type.get(),
       Type.FLOAT32: ir.F32Type.get(),
       Type.FLOAT64: ir.F64Type.get(),
       Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()),
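
The three tables above give f16 a consistent path: np.float16 maps to Type.FLOAT16, which stringifies as "f16" and lowers to ir.F16Type. A standalone sketch of the first two hops (the ir.F16Type.get() hop needs a live MLIR context, so it stays a comment here):

import enum

import numpy as np

class Type(enum.Enum):
  FLOAT16 = np.float16

# np.float16 -> Type.FLOAT16 -> "f16", mirroring the two dict entries above.
_NP_TO_TACO = {np.float16: Type.FLOAT16}
_TACO_TO_MLIR_STR = {Type.FLOAT16: "f16"}

assert _TACO_TO_MLIR_STR[_NP_TO_TACO[np.float16]] == "f16"
# Within an MLIR context, ir.F16Type.get() would be the lowered element type.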


@@ -39,6 +39,7 @@ int8 = mlir_pytaco.DType(mlir_pytaco.Type.INT8)
 int16 = mlir_pytaco.DType(mlir_pytaco.Type.INT16)
 int32 = mlir_pytaco.DType(mlir_pytaco.Type.INT32)
 int64 = mlir_pytaco.DType(mlir_pytaco.Type.INT64)
+float16 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT16)
 float32 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32)
 float64 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64)
 complex64 = mlir_pytaco.DType(mlir_pytaco.Type.COMPLEX64)
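
With the module-level float16 alias exported, user code can name the dtype directly; combined with the _nptype_to_taco_type entry above, dtype inference from a half-precision numpy array should land on it too. A hedged sketch, assuming this API module also exports from_array and that tensors expose a dtype property:

import numpy as np
import mlir_pytaco_api as pt

# Dense input in half precision; the dtype should be inferred as pt.float16.
x = pt.from_array(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16))
assert x.dtype == pt.float16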


@@ -89,6 +89,8 @@ def _get_support_func_locator() -> _SupportFuncLocator:
                      c_lib.convertFromMLIRSparseTensorI32),
                     (np.int64, c_lib.convertToMLIRSparseTensorI64,
                      c_lib.convertFromMLIRSparseTensorI64),
+                    (np.float16, c_lib.convertToMLIRSparseTensorF16,
+                     c_lib.convertFromMLIRSparseTensorF16),
                     (np.float32, c_lib.convertToMLIRSparseTensorF32,
                      c_lib.convertFromMLIRSparseTensorF32),
                     (np.float64, c_lib.convertToMLIRSparseTensorF64,
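
The locator resolves these names against the sparse runtime support library loaded via ctypes, so the F16 pair must be exported there. A quick standalone probe for the symbols; the library name is an assumption and varies by platform and build tree:

import ctypes

# Assumption: path to the built MLIR runner-utils library; adjust for your build.
c_lib = ctypes.CDLL("libmlir_c_runner_utils.so")

for name in ("convertToMLIRSparseTensorF16", "convertFromMLIRSparseTensorF16"):
  # CDLL attribute lookup fails (AttributeError) if the symbol is not exported.
  assert hasattr(c_lib, name), f"missing runtime symbol: {name}"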