forked from OSchip/llvm-project
[mlir][sparse][taco] Support f16.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D128105
This commit is contained in:
parent
3561ee586e
commit
bdeae1f57b
|
@ -12,7 +12,9 @@ compressed = pt.compressed
|
|||
dense = pt.dense
|
||||
|
||||
passed = 0
|
||||
all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float32, pt.float64]
|
||||
all_types = [
|
||||
pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64
|
||||
]
|
||||
for t in all_types:
|
||||
i, j = pt.get_index_vars(2)
|
||||
A = pt.tensor([2, 3], dtype=t)
|
||||
|
@ -29,5 +31,5 @@ for t in all_types:
|
|||
passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
|
||||
passed += np.allclose(values, [20.0, 10.0, 70.0])
|
||||
|
||||
# CHECK: Number of passed: 18
|
||||
# CHECK: Number of passed: 21
|
||||
print("Number of passed:", passed)
|
||||
|
|
|
@ -72,7 +72,7 @@ class Type(enum.Enum):
|
|||
INT16 = np.int16
|
||||
INT32 = np.int32
|
||||
INT64 = np.int64
|
||||
# numpy _ctype_from_dtype_scalar can't handle np.float16 yet.
|
||||
FLOAT16 = np.float16
|
||||
FLOAT32 = np.float32
|
||||
FLOAT64 = np.float64
|
||||
COMPLEX64 = np.complex64
|
||||
|
@ -80,15 +80,15 @@ class Type(enum.Enum):
|
|||
|
||||
|
||||
# All floating point type enums.
|
||||
_FLOAT_TYPES = (Type.FLOAT32, Type.FLOAT64)
|
||||
_FLOAT_TYPES = (Type.FLOAT16, Type.FLOAT32, Type.FLOAT64)
|
||||
# All integral type enums.
|
||||
_INT_TYPES = (Type.INT8, Type.INT16, Type.INT32, Type.INT64)
|
||||
# All complex type enums.
|
||||
_COMPLEX_TYPES = (Type.COMPLEX64, Type.COMPLEX128)
|
||||
# Type alias for any numpy type used to implement the runtime support for the
|
||||
# enum data types.
|
||||
_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float32,
|
||||
np.float64, np.complex64, np.complex128]
|
||||
_AnyRuntimeType = Union[np.int8, np.int16, np.int32, np.int64, np.float16,
|
||||
np.float32, np.float64, np.complex64, np.complex128]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
|
@ -132,6 +132,7 @@ def _dtype_to_mlir_str(dtype: DType) -> str:
|
|||
Type.INT16: "i16",
|
||||
Type.INT32: "i32",
|
||||
Type.INT64: "i64",
|
||||
Type.FLOAT16: "f16",
|
||||
Type.FLOAT32: "f32",
|
||||
Type.FLOAT64: "f64",
|
||||
Type.COMPLEX64: "complex<f32>",
|
||||
|
@ -147,6 +148,7 @@ def _nptype_to_taco_type(ty: np.dtype) -> DType:
|
|||
np.int16: Type.INT16,
|
||||
np.int32: Type.INT32,
|
||||
np.int64: Type.INT64,
|
||||
np.float16: Type.FLOAT16,
|
||||
np.float32: Type.FLOAT32,
|
||||
np.float64: Type.FLOAT64,
|
||||
np.complex64: Type.COMPLEX64,
|
||||
|
@ -162,6 +164,7 @@ def _mlir_type_from_taco_type(dtype: DType) -> ir.Type:
|
|||
Type.INT16: ir.IntegerType.get_signless(16),
|
||||
Type.INT32: ir.IntegerType.get_signless(32),
|
||||
Type.INT64: ir.IntegerType.get_signless(64),
|
||||
Type.FLOAT16: ir.F16Type.get(),
|
||||
Type.FLOAT32: ir.F32Type.get(),
|
||||
Type.FLOAT64: ir.F64Type.get(),
|
||||
Type.COMPLEX64: ir.ComplexType.get(ir.F32Type.get()),
|
||||
|
|
|
@ -39,6 +39,7 @@ int8 = mlir_pytaco.DType(mlir_pytaco.Type.INT8)
|
|||
int16 = mlir_pytaco.DType(mlir_pytaco.Type.INT16)
|
||||
int32 = mlir_pytaco.DType(mlir_pytaco.Type.INT32)
|
||||
int64 = mlir_pytaco.DType(mlir_pytaco.Type.INT64)
|
||||
float16 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT16)
|
||||
float32 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32)
|
||||
float64 = mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64)
|
||||
complex64 = mlir_pytaco.DType(mlir_pytaco.Type.COMPLEX64)
|
||||
|
|
|
@ -89,6 +89,8 @@ def _get_support_func_locator() -> _SupportFuncLocator:
|
|||
c_lib.convertFromMLIRSparseTensorI32),
|
||||
(np.int64, c_lib.convertToMLIRSparseTensorI64,
|
||||
c_lib.convertFromMLIRSparseTensorI64),
|
||||
(np.float16, c_lib.convertToMLIRSparseTensorF16,
|
||||
c_lib.convertFromMLIRSparseTensorF16),
|
||||
(np.float32, c_lib.convertToMLIRSparseTensorF32,
|
||||
c_lib.convertFromMLIRSparseTensorF32),
|
||||
(np.float64, c_lib.convertToMLIRSparseTensorF64,
|
||||
|
|
Loading…
Reference in New Issue