[mlir][sparse] add "sort" to the compress op codegen

This revision also adds convenience methods to test the
dim level type/property (with the codegen being the first client).

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D134776
Aart Bik 2022-09-27 17:06:20 -07:00
parent c5983963de
commit 4d06861950
5 changed files with 203 additions and 56 deletions
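
To make the intent of the new codegen concrete, here is a minimal C++ sketch of what the lowered compress operation does conceptually. It is illustrative only: the function name and the direct use of std::sort are assumptions, the real rewrite emits a sparse_tensor.sort op plus scf/memref ops, and the actual insertion into the sparse tensor remains a TODO in the codegen.

#include <algorithm>
#include <cstdint>

// Conceptual model of sparse_tensor.compress codegen after this change.
void compressSketch(double *values, bool *filled, uint64_t *added,
                    uint64_t count, bool innermostIsOrdered) {
  // New in this revision: sort the gathered indices first when the
  // innermost dimension is ordered (e.g. plain "compressed"); unordered
  // formats such as "compressed-no" skip the sort.
  if (innermostIsOrdered)
    std::sort(added, added + count);
  for (uint64_t i = 0; i < count; i++) {
    uint64_t index = added[i];
    // TODO in the real codegen: insert values[index] into the tensor.
    // Reset only the touched entries so the work stays proportional to
    // the number of set elements rather than the dimension size.
    values[index] = 0.0;
    filled[index] = false;
  }
}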

@@ -27,9 +27,43 @@
namespace mlir {
namespace sparse_tensor {
/// Convenience method to get a sparse encoding attribute from a type.
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
//
// Dimension level types.
//
bool isDenseDim(SparseTensorEncodingAttr::DimLevelType dltp);
bool isCompressedDim(SparseTensorEncodingAttr::DimLevelType dltp);
bool isSingletonDim(SparseTensorEncodingAttr::DimLevelType dltp);
/// Convenience method to test for dense dimension (0 <= d < rank).
bool isDenseDim(RankedTensorType type, uint64_t d);
/// Convenience method to test for compressed dimension (0 <= d < rank).
bool isCompressedDim(RankedTensorType type, uint64_t d);
/// Convenience method to test for singleton dimension (0 <= d < rank).
bool isSingletonDim(RankedTensorType type, uint64_t d);
//
// Dimension level properties.
//
bool isOrderedDim(SparseTensorEncodingAttr::DimLevelType dltp);
bool isUniqueDim(SparseTensorEncodingAttr::DimLevelType dltp);
/// Convenience method to test for ordered property in the
/// given dimension (0 <= d < rank).
bool isOrderedDim(RankedTensorType type, uint64_t d);
/// Convenience method to test for unique property in the
/// given dimension (0 <= d < rank).
bool isUniqueDim(RankedTensorType type, uint64_t d);
} // namespace sparse_tensor
} // namespace mlir
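
A rough usage sketch for the new predicates follows. It is illustrative only: classifyDims is a hypothetical helper, and the include path assumes the declarations live in the dialect's usual SparseTensor.h header.

#include <cassert>
#include <cstdint>

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Walk all dimensions of a (possibly annotated) ranked tensor type and
// classify each level with the convenience predicates declared above.
static void classifyDims(RankedTensorType rtp) {
  for (uint64_t d = 0, rank = static_cast<uint64_t>(rtp.getRank()); d < rank;
       d++) {
    if (isCompressedDim(rtp, d)) {
      // Compressed levels may additionally be unordered and/or non-unique.
      bool ordered = isOrderedDim(rtp, d);
      bool unique = isUniqueDim(rtp, d);
      (void)ordered;
      (void)unique;
    } else if (isSingletonDim(rtp, d)) {
      // Singleton level: indices only, no pointer array.
    } else {
      // Dense level, or an unannotated tensor (treated as all-dense).
      assert(isDenseDim(rtp, d));
    }
  }
}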

@@ -216,6 +216,10 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
return success();
}
//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//
SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
if (auto ttp = type.dyn_cast<RankedTensorType>())
@@ -223,6 +227,98 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
return nullptr;
}
bool mlir::sparse_tensor::isDenseDim(
SparseTensorEncodingAttr::DimLevelType dltp) {
return dltp == SparseTensorEncodingAttr::DimLevelType::Dense;
}
bool mlir::sparse_tensor::isCompressedDim(
SparseTensorEncodingAttr::DimLevelType dltp) {
switch (dltp) {
case SparseTensorEncodingAttr::DimLevelType::Compressed:
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
return true;
default:
return false;
}
}
bool mlir::sparse_tensor::isSingletonDim(
SparseTensorEncodingAttr::DimLevelType dltp) {
switch (dltp) {
case SparseTensorEncodingAttr::DimLevelType::Singleton:
case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
return true;
default:
return false;
}
}
bool mlir::sparse_tensor::isDenseDim(RankedTensorType type, uint64_t d) {
assert(d < static_cast<uint64_t>(type.getRank()));
if (auto enc = getSparseTensorEncoding(type))
return isDenseDim(enc.getDimLevelType()[d]);
return true; // unannotated tensor is dense
}
bool mlir::sparse_tensor::isCompressedDim(RankedTensorType type, uint64_t d) {
assert(d < static_cast<uint64_t>(type.getRank()));
if (auto enc = getSparseTensorEncoding(type))
return isCompressedDim(enc.getDimLevelType()[d]);
return false; // unannotated tensor is dense
}
bool mlir::sparse_tensor::isSingletonDim(RankedTensorType type, uint64_t d) {
assert(d < static_cast<uint64_t>(type.getRank()));
if (auto enc = getSparseTensorEncoding(type))
return isSingletonDim(enc.getDimLevelType()[d]);
return false; // unannotated tensor is dense
}
bool mlir::sparse_tensor::isOrderedDim(
SparseTensorEncodingAttr::DimLevelType dltp) {
switch (dltp) {
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
return false;
default:
return true;
}
}
bool mlir::sparse_tensor::isUniqueDim(
SparseTensorEncodingAttr::DimLevelType dltp) {
switch (dltp) {
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
return false;
default:
return true;
}
}
bool mlir::sparse_tensor::isOrderedDim(RankedTensorType type, uint64_t d) {
assert(d < static_cast<uint64_t>(type.getRank()));
if (auto enc = getSparseTensorEncoding(type))
return isOrderedDim(enc.getDimLevelType()[d]);
return true; // unannotated tensor is dense (and thus ordered)
}
bool mlir::sparse_tensor::isUniqueDim(RankedTensorType type, uint64_t d) {
assert(d < static_cast<uint64_t>(type.getRank()));
if (auto enc = getSparseTensorEncoding(type))
return isUniqueDim(enc.getDimLevelType()[d]);
return true; // unannotated tensor is dense (and thus unique)
}
//===----------------------------------------------------------------------===//
// TensorDialect Operations.
//===----------------------------------------------------------------------===//
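
The switch statements above encode a naming convention for the dimension level type variants. The following sanity-check sketch is a hedged summary of that convention; checkSuffixConvention is a hypothetical helper, with the same dialect header assumed as before.

#include <cassert>

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

using namespace mlir::sparse_tensor;

// Suffix convention captured by the predicates above:
//   (no suffix) : ordered and unique
//   Nu          : not unique (duplicate indices allowed)
//   No          : not ordered (indices may appear out of order)
//   NuNo        : neither unique nor ordered
static void checkSuffixConvention() {
  using DLT = SparseTensorEncodingAttr::DimLevelType;
  assert(isCompressedDim(DLT::Compressed) && isOrderedDim(DLT::Compressed) &&
         isUniqueDim(DLT::Compressed));
  assert(isOrderedDim(DLT::CompressedNu) && !isUniqueDim(DLT::CompressedNu));
  assert(!isOrderedDim(DLT::CompressedNo) && isUniqueDim(DLT::CompressedNo));
  assert(!isOrderedDim(DLT::SingletonNuNo) && !isUniqueDim(DLT::SingletonNuNo));
}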

@@ -103,37 +103,28 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
/// Returns field index of sparse tensor type for pointers/indices, when set.
static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
auto enc = getSparseTensorEncoding(type);
assert(enc);
assert(getSparseTensorEncoding(type));
RankedTensorType rType = type.cast<RankedTensorType>();
unsigned field = 2; // start past sizes
unsigned ptr = 0;
unsigned idx = 0;
for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
switch (enc.getDimLevelType()[r]) {
case SparseTensorEncodingAttr::DimLevelType::Dense:
break; // no fields
case SparseTensorEncodingAttr::DimLevelType::Compressed:
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
if (isCompressedDim(rType, r)) {
if (ptr++ == ptrDim)
return field;
field++;
if (idx++ == idxDim)
return field;
field++;
break;
case SparseTensorEncodingAttr::DimLevelType::Singleton:
case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
} else if (isSingletonDim(rType, r)) {
if (idx++ == idxDim)
return field;
field++;
break;
} else {
assert(isDenseDim(rType, r)); // no fields
}
}
assert(ptrDim == -1u && idxDim == -1u);
return field + 1; // return values field index
}
@@ -176,7 +167,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
// The dimSizes array.
fields.push_back(MemRefType::get({rank}, indexType));
// The memSizes array.
unsigned lastField = getFieldIndex(type, -1, -1);
unsigned lastField = getFieldIndex(type, -1u, -1u);
fields.push_back(MemRefType::get({lastField - 2}, indexType));
// Per-dimension storage.
for (unsigned r = 0; r < rank; r++) {
@@ -184,22 +175,13 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
// As a result, the compound type can be constructed directly in the given
// order. Clients of this type know what field is what from the sparse
// tensor type.
switch (enc.getDimLevelType()[r]) {
case SparseTensorEncodingAttr::DimLevelType::Dense:
break; // no fields
case SparseTensorEncodingAttr::DimLevelType::Compressed:
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
if (isCompressedDim(rType, r)) {
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
break;
case SparseTensorEncodingAttr::DimLevelType::Singleton:
case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
} else if (isSingletonDim(rType, r)) {
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
break;
} else {
assert(isDenseDim(rType, r)); // no fields
}
}
// The values array.
@@ -254,7 +236,7 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
fields.push_back(dimSizes);
// The sizes array.
unsigned lastField = getFieldIndex(type, -1, -1);
unsigned lastField = getFieldIndex(type, -1u, -1u);
Value memSizes = builder.create<memref::AllocOp>(
loc, MemRefType::get({lastField - 2}, indexType));
fields.push_back(memSizes);
@@ -265,25 +247,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
builder.create<memref::StoreOp>(loc, sizes[ro], dimSizes,
constantIndex(builder, loc, r));
linear = builder.create<arith::MulIOp>(loc, linear, sizes[ro]);
// Allocate fiels.
switch (enc.getDimLevelType()[r]) {
case SparseTensorEncodingAttr::DimLevelType::Dense:
break; // no fields
case SparseTensorEncodingAttr::DimLevelType::Compressed:
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
// Allocate fields.
if (isCompressedDim(rType, r)) {
fields.push_back(createAllocation(builder, loc, ptrType, heuristic));
fields.push_back(createAllocation(builder, loc, idxType, heuristic));
allDense = false;
break;
case SparseTensorEncodingAttr::DimLevelType::Singleton:
case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
} else if (isSingletonDim(rType, r)) {
fields.push_back(createAllocation(builder, loc, idxType, heuristic));
allDense = false;
break;
} else {
assert(isDenseDim(rType, r)); // no fields
}
}
// The values array. For all-dense, the full length is required.
@@ -507,7 +480,8 @@ public:
matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
RankedTensorType srcType =
op.getTensor().getType().cast<RankedTensorType>();
Type eltType = srcType.getElementType();
Type boolType = rewriter.getIntegerType(1);
Type idxType = rewriter.getIndexType();
@@ -561,17 +535,18 @@ public:
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
Type eltType = srcType.getElementType();
RankedTensorType dstType =
op.getTensor().getType().cast<RankedTensorType>();
Type eltType = dstType.getElementType();
Value values = adaptor.getValues();
Value filled = adaptor.getFilled();
Value added = adaptor.getAdded();
Value count = adaptor.getCount();
//
// TODO: need to implement "std::sort(added, added + count);" for ordered
//
// If the innermost dimension is ordered, we need to sort the indices
// in the "added" array prior to applying the compression.
unsigned rank = dstType.getShape().size();
if (isOrderedDim(dstType, rank - 1))
rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{});
// While performing the insertions, we also need to reset the elements
// of the values/filled-switch by only iterating over the set elements,
// to ensure that the runtime complexity remains proportional to the
@@ -699,7 +674,7 @@ public:
static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
ToPointersOp op) {
uint64_t dim = op.getDimension().getZExtValue();
return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1);
return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u);
}
};
@@ -712,7 +687,7 @@ public:
static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
ToIndicesOp op) {
uint64_t dim = op.getDimension().getZExtValue();
return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/dim);
return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim);
}
};

@@ -156,8 +156,9 @@ struct SparseTensorCodegenPass
RewritePatternSet patterns(ctx);
SparseTensorTypeToBufferConverter converter;
ConversionTarget target(*ctx);
// Everything in the sparse dialect must go!
// Most ops in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
target.addLegalOp<SortOp>();
// All dynamic rules below accept new function, call, return, and various
// tensor and bufferization operations as legal output of the rewriting
// provided that all sparse tensor types have been fully rewritten.

@@ -24,6 +24,10 @@
pointerBitWidth = 32
}>
#UCSR = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed-no" ]
}>
#CSC = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed" ],
dimOrdering = affine_map<(i, j) -> (j, i)>
@@ -363,7 +367,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// TODO: sort
// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
// TODO: insert
@@ -385,6 +389,43 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
return
}
// CHECK-LABEL: func @sparse_compression_unordered(
// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
// CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
// CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
// CHECK-SAME: %[[A8:.*8]]: index,
// CHECK-SAME: %[[A9:.*9]]: index)
// CHECK-DAG: %[[B0:.*]] = arith.constant false
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NOT: sparse_tensor.sort
// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
// TODO: insert
// CHECK-DAG: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
// CHECK-DAG: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
// CHECK-NEXT: }
// CHECK-DAG: memref.dealloc %[[A5]] : memref<?xf64>
// CHECK-DAG: memref.dealloc %[[A6]] : memref<?xi1>
// CHECK-DAG: memref.dealloc %[[A7]] : memref<?xindex>
// CHECK: return
func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
%values: memref<?xf64>,
%filled: memref<?xi1>,
%added: memref<?xindex>,
%count: index,
%i: index) {
sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
return
}
// CHECK-LABEL: func @sparse_push_back(
// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,