[mlir][sparse] Factoring out predicates on DimLevelTypes

This way the predicates can be reused elsewhere, and can more easily be kept in sync with changes to the enum.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134926
This commit is contained in:
wren romano 2022-09-29 17:34:09 -07:00
parent 865406d21e
commit 1d4d1c99c5
2 changed files with 67 additions and 47 deletions

View File

@ -157,6 +157,63 @@ enum class MLIR_SPARSETENSOR_EXPORT DimLevelType : uint8_t {
kSingletonNuNo = 8,
};
/// Check if the `DimLevelType` is dense.
constexpr MLIR_SPARSETENSOR_EXPORT bool isDenseDLT(DimLevelType dlt) {
return dlt == DimLevelType::kDense;
}
/// Check if the `DimLevelType` is compressed (regardless of properties).
constexpr MLIR_SPARSETENSOR_EXPORT bool isCompressedDLT(DimLevelType dlt) {
switch (dlt) {
case DimLevelType::kCompressed:
case DimLevelType::kCompressedNu:
case DimLevelType::kCompressedNo:
case DimLevelType::kCompressedNuNo:
return true;
default:
return false;
}
}
/// Check if the `DimLevelType` is singleton (regardless of properties).
constexpr MLIR_SPARSETENSOR_EXPORT bool isSingletonDLT(DimLevelType dlt) {
switch (dlt) {
case DimLevelType::kSingleton:
case DimLevelType::kSingletonNu:
case DimLevelType::kSingletonNo:
case DimLevelType::kSingletonNuNo:
return true;
default:
return false;
}
}
/// Check if the `DimLevelType` is ordered (regardless of storage format).
constexpr MLIR_SPARSETENSOR_EXPORT bool isOrderedDLT(DimLevelType dlt) {
switch (dlt) {
case DimLevelType::kCompressedNo:
case DimLevelType::kCompressedNuNo:
case DimLevelType::kSingletonNo:
case DimLevelType::kSingletonNuNo:
return false;
default:
return true;
}
}
/// Check if the `DimLevelType` is unique (regardless of storage format).
constexpr MLIR_SPARSETENSOR_EXPORT bool isUniqueDLT(DimLevelType dlt) {
switch (dlt) {
case DimLevelType::kCompressedNu:
case DimLevelType::kCompressedNuNo:
case DimLevelType::kSingletonNu:
case DimLevelType::kSingletonNuNo:
return false;
default:
return true;
}
}
} // namespace sparse_tensor
} // namespace mlir

View File

@ -102,67 +102,30 @@ public:
/// Get the dimension-types array, in storage-order.
const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
/// Safely check if the (storage-order) dimension uses dense storage.
bool isDenseDim(uint64_t d) const {
/// Safely lookup the level-type of the given (storage-order) dimension.
DimLevelType getDimType(uint64_t d) const {
ASSERT_VALID_DIM(d);
return dimTypes[d] == DimLevelType::kDense;
return dimTypes[d];
}
/// Safely check if the (storage-order) dimension uses dense storage.
bool isDenseDim(uint64_t d) const { return isDenseDLT(getDimType(d)); }
/// Safely check if the (storage-order) dimension uses compressed storage.
bool isCompressedDim(uint64_t d) const {
ASSERT_VALID_DIM(d);
switch (dimTypes[d]) {
case DimLevelType::kCompressed:
case DimLevelType::kCompressedNu:
case DimLevelType::kCompressedNo:
case DimLevelType::kCompressedNuNo:
return true;
default:
return false;
}
return isCompressedDLT(getDimType(d));
}
/// Safely check if the (storage-order) dimension uses singleton storage.
bool isSingletonDim(uint64_t d) const {
ASSERT_VALID_DIM(d);
switch (dimTypes[d]) {
case DimLevelType::kSingleton:
case DimLevelType::kSingletonNu:
case DimLevelType::kSingletonNo:
case DimLevelType::kSingletonNuNo:
return true;
default:
return false;
}
return isSingletonDLT(getDimType(d));
}
/// Safely check if the (storage-order) dimension is ordered.
bool isOrderedDim(uint64_t d) const {
ASSERT_VALID_DIM(d);
switch (dimTypes[d]) {
case DimLevelType::kCompressedNo:
case DimLevelType::kCompressedNuNo:
case DimLevelType::kSingletonNo:
case DimLevelType::kSingletonNuNo:
return false;
default:
return true;
}
}
bool isOrderedDim(uint64_t d) const { return isOrderedDLT(getDimType(d)); }
/// Safely check if the (storage-order) dimension is unique.
bool isUniqueDim(uint64_t d) const {
ASSERT_VALID_DIM(d);
switch (dimTypes[d]) {
case DimLevelType::kCompressedNu:
case DimLevelType::kCompressedNuNo:
case DimLevelType::kSingletonNu:
case DimLevelType::kSingletonNuNo:
return false;
default:
return true;
}
}
bool isUniqueDim(uint64_t d) const { return isUniqueDLT(getDimType(d)); }
/// Allocate a new enumerator.
#define DECL_NEWENUMERATOR(VNAME, V) \