[mlir][sparse] add ability to select pointer/index storage type

This change gives sparse compiler clients more control over selecting
individual types for the pointers and indices in the sparse storage schemes.
A narrower width obviously yields a smaller memory footprint, but the chosen
width must still suffice to represent the maximum number of stored entries
(for the pointers) and the maximum index value (for the indices).

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D92126
Author: Aart Bik
Date:   2020-11-25 12:29:05 -08:00
Commit: d5f0d0c0c4 (parent da0aaedcd0)

4 changed files with 181 additions and 23 deletions
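
To make that range constraint concrete, the following self-contained C++
sketch (illustrative only, not part of this patch; the helper is hypothetical)
picks the narrowest sufficient storage type from the two maxima named above:

#include <cstdint>

// Mirrors the SparseIntType enum introduced by this patch.
enum class SparseIntType { kNative, kI64, kI32 };

// Hypothetical helper: returns the narrowest storage type whose unsigned
// range covers both the maximum number of stored entries (the "pointer"
// values) and the maximum tensor index (the "index" values). The generated
// loops compare these values with unsigned predicates (CmpIPredicate::ult),
// so the full unsigned range is usable.
static SparseIntType pickStorageType(uint64_t maxEntries, uint64_t maxIndex) {
  uint64_t required = maxEntries > maxIndex ? maxEntries : maxIndex;
  return required <= UINT32_MAX ? SparseIntType::kI32  // halve the footprint
                                : SparseIntType::kI64; // safe wide default
}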

@@ -821,18 +821,31 @@ enum class SparseVectorizationStrategy {
   kAnyStorageInnerLoop
 };
 
+/// Defines a type for "pointer" and "index" storage in the sparse storage
+/// scheme, with a choice between the native platform-dependent index width,
+/// 64-bit integers, or 32-bit integers. A narrow width obviously reduces
+/// the memory footprint of the sparse storage scheme, but the width should
+/// suffice to define the total required range (viz. the maximum number of
+/// stored entries per indirection level for the "pointers" and the maximum
+/// value of each tensor index over all dimensions for the "indices").
+enum class SparseIntType { kNative, kI64, kI32 };
+
 /// Sparsification options.
 struct SparsificationOptions {
   SparsificationOptions(SparseParallelizationStrategy p,
-                        SparseVectorizationStrategy v, unsigned vl)
-      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl) {
-  }
+                        SparseVectorizationStrategy v, unsigned vl,
+                        SparseIntType pt, SparseIntType it)
+      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
+        ptrType(pt), indType(it) {}
   SparsificationOptions()
       : SparsificationOptions(SparseParallelizationStrategy::kNone,
-                              SparseVectorizationStrategy::kNone, 1u) {}
+                              SparseVectorizationStrategy::kNone, 1u,
+                              SparseIntType::kNative, SparseIntType::kNative) {}
   SparseParallelizationStrategy parallelizationStrategy;
   SparseVectorizationStrategy vectorizationStrategy;
   unsigned vectorLength;
+  SparseIntType ptrType;
+  SparseIntType indType;
 };
 
 /// Set up sparsification rewriting rules with the given options.
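
A usage sketch of the extended API (hypothetical client code; it assumes the
usual pass setup around it): a client whose tensors stay within 32-bit ranges
can now request narrow storage for both pointers and indices:

// Hypothetical client code: request 32-bit pointers and 32-bit indices.
linalg::SparsificationOptions options(
    linalg::SparseParallelizationStrategy::kNone,
    linalg::SparseVectorizationStrategy::kNone, /*vectorLength=*/1u,
    linalg::SparseIntType::kI32, linalg::SparseIntType::kI32);
linalg::populateSparsificationPatterns(ctx, patterns, options);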

@@ -420,16 +420,27 @@ static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
   }
 }
 
+/// Maps sparse integer option to actual integral storage type.
+static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
+  switch (tp) {
+  case linalg::SparseIntType::kNative:
+    return rewriter.getIndexType();
+  case linalg::SparseIntType::kI64:
+    return rewriter.getIntegerType(64);
+  case linalg::SparseIntType::kI32:
+    return rewriter.getIntegerType(32);
+  }
+}
+
 /// Local bufferization of all dense and sparse data structures.
 /// This code enables testing the first prototype sparse compiler.
 // TODO: replace this with a proliferated bufferization strategy
-void genBuffers(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
-                linalg::GenericOp op) {
+static void genBuffers(Merger &merger, CodeGen &codegen,
+                       PatternRewriter &rewriter, linalg::GenericOp op) {
   Location loc = op.getLoc();
   unsigned numTensors = op.getNumInputsAndOutputs();
   unsigned numInputs = op.getNumInputs();
   assert(numTensors == numInputs + 1);
-  Type indexType = rewriter.getIndexType();
 
   // For now, set all unknown dimensions to 999.
   // TODO: compute these values (using sparsity or by reading tensor)
@@ -450,9 +461,13 @@ void genBuffers(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
       // Handle sparse storage schemes.
       if (merger.isSparseAccess(t, i)) {
         allDense = false;
-        auto dynTp = MemRefType::get({ShapedType::kDynamicSize}, indexType);
-        codegen.pointers[t][i] = rewriter.create<AllocaOp>(loc, dynTp, unknown);
-        codegen.indices[t][i] = rewriter.create<AllocaOp>(loc, dynTp, unknown);
+        auto dynShape = {ShapedType::kDynamicSize};
+        auto ptrTp = MemRefType::get(
+            dynShape, genIntType(rewriter, codegen.options.ptrType));
+        auto indTp = MemRefType::get(
+            dynShape, genIntType(rewriter, codegen.options.indType));
+        codegen.pointers[t][i] = rewriter.create<AllocaOp>(loc, ptrTp, unknown);
+        codegen.indices[t][i] = rewriter.create<AllocaOp>(loc, indTp, unknown);
       }
       // Find lower and upper bound in current dimension.
       Value up;
@@ -516,6 +531,15 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
   rewriter.create<StoreOp>(op.getLoc(), rhs, codegen.buffers[tensor], args);
 }
 
+/// Generates a pointer/index load from the sparse storage scheme.
+static Value genIntLoad(PatternRewriter &rewriter, Location loc, Value ptr,
+                        Value s) {
+  Value load = rewriter.create<LoadOp>(loc, ptr, s);
+  return load.getType().isa<IndexType>()
+             ? load
+             : rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
+}
+
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
@@ -551,7 +575,6 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   unsigned idx = topSort[at];
 
   // Initialize sparse positions.
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
   for (unsigned b = 0, be = inits.size(); b < be; b++) {
     if (inits[b]) {
       unsigned tensor = merger.tensor(b);
@@ -564,11 +587,12 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
-       Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
-                            : codegen.pidxs[tensor][topSort[pat - 1]];
-       codegen.pidxs[tensor][idx] = rewriter.create<LoadOp>(loc, ptr, p);
-       p = rewriter.create<AddIOp>(loc, p, one);
-       codegen.highs[tensor][idx] = rewriter.create<LoadOp>(loc, ptr, p);
+       Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+       Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
+                             : codegen.pidxs[tensor][topSort[pat - 1]];
+       codegen.pidxs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p0);
+       Value p1 = rewriter.create<AddIOp>(loc, p0, one);
+       codegen.highs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
@@ -723,15 +747,17 @@ static void genLocals(Merger &merger, CodeGen &codegen,
     if (locals[b] && merger.isSparseBit(b)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
-      Value ld = rewriter.create<LoadOp>(loc, codegen.indices[tensor][idx],
-                                         codegen.pidxs[tensor][idx]);
-      codegen.idxs[tensor][idx] = ld;
+      Value ptr = codegen.indices[tensor][idx];
+      Value s = codegen.pidxs[tensor][idx];
+      Value load = genIntLoad(rewriter, loc, ptr, s);
+      codegen.idxs[tensor][idx] = load;
       if (!needsUniv) {
         if (min) {
-          Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, ld, min);
-          min = rewriter.create<SelectOp>(loc, cmp, ld, min);
+          Value cmp =
+              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
+          min = rewriter.create<SelectOp>(loc, cmp, load, min);
         } else {
-          min = ld;
+          min = load;
         }
       }
     }
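
For readers new to the storage scheme these helpers traverse: per indirection
level, the "pointers" array holds offsets into the "indices" array, which
holds the coordinates of the stored entries; these are exactly the values
genInit and genLocals load and widen above. A self-contained C++ sketch
(illustrative, independent of MLIR) of a CSR-style traversal with 32-bit
storage:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // 4x4 sparse matrix in CSR form with 32-bit pointers/indices:
  // nonzeros a[0][1]=1, a[1][0]=2, a[1][3]=3, a[3][2]=4.
  std::vector<double> values = {1, 2, 3, 4};
  // pointers[i]..pointers[i+1] brackets row i; bounded by #stored entries.
  std::vector<uint32_t> pointers = {0, 1, 3, 3, 4};
  // indices[p] is the column of the p-th entry; bounded by the max column.
  std::vector<uint32_t> indices = {1, 0, 3, 2};
  for (unsigned i = 0; i < 4; i++)
    for (uint32_t p = pointers[i]; p < pointers[i + 1]; p++)
      std::printf("a[%u][%u] = %g\n", i, (unsigned)indices[p], values[p]);
  return 0;
}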

@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s -test-sparsification="ptr-type=1 ind-type=1" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-TYPE0
+// RUN: mlir-opt %s -test-sparsification="ptr-type=1 ind-type=2" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-TYPE1
+// RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=1" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-TYPE2
+// RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=2" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-TYPE3
+
+#trait_mul_1d = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a
+    affine_map<(i) -> (i)>, // b
+    affine_map<(i) -> (i)>  // x (out)
+  ],
+  sparse = [
+    [ "S" ], // a
+    [ "D" ], // b
+    [ "D" ]  // x
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) * b(i)"
+}
+
+// CHECK-TYPE0-LABEL: func @mul_dd(
+// CHECK-TYPE0: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE0: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE0: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi64>
+// CHECK-TYPE0: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
+// CHECK-TYPE0: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi64>
+// CHECK-TYPE0: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
+// CHECK-TYPE0: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE0: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi64>
+// CHECK-TYPE0: %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
+// CHECK-TYPE0: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE0: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE0: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE0: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE0: }
+
+// CHECK-TYPE1-LABEL: func @mul_dd(
+// CHECK-TYPE1: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE1: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE1: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi64>
+// CHECK-TYPE1: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
+// CHECK-TYPE1: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi64>
+// CHECK-TYPE1: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
+// CHECK-TYPE1: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE1: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi32>
+// CHECK-TYPE1: %[[INDC:.*]] = index_cast %[[IND0]] : i32 to index
+// CHECK-TYPE1: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE1: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE1: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE1: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE1: }
+
+// CHECK-TYPE2-LABEL: func @mul_dd(
+// CHECK-TYPE2: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE2: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE2: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi32>
+// CHECK-TYPE2: %[[B0:.*]] = index_cast %[[P0]] : i32 to index
+// CHECK-TYPE2: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi32>
+// CHECK-TYPE2: %[[B1:.*]] = index_cast %[[P1]] : i32 to index
+// CHECK-TYPE2: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE2: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi64>
+// CHECK-TYPE2: %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
+// CHECK-TYPE2: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE2: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE2: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE2: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE2: }
+
+// CHECK-TYPE3-LABEL: func @mul_dd(
+// CHECK-TYPE3: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE3: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE3: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi32>
+// CHECK-TYPE3: %[[B0:.*]] = index_cast %[[P0]] : i32 to index
+// CHECK-TYPE3: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi32>
+// CHECK-TYPE3: %[[B1:.*]] = index_cast %[[P1]] : i32 to index
+// CHECK-TYPE3: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE3: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi32>
+// CHECK-TYPE3: %[[INDC:.*]] = index_cast %[[IND0]] : i32 to index
+// CHECK-TYPE3: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE3: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE3: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE3: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE3: }
+
+func @mul_dd(%arga: tensor<32xf64>, %argb: tensor<32xf64>) -> tensor<32xf64> {
+  %0 = linalg.generic #trait_mul_1d
+    ins(%arga, %argb: tensor<32xf64>, tensor<32xf64>) {
+      ^bb(%a: f64, %b: f64):
+        %0 = mulf %a, %b : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}

@@ -31,6 +31,14 @@ struct TestSparsification
   Option<int32_t> vectorLength{
       *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
 
+  Option<int32_t> ptrType{*this, "ptr-type",
+                          llvm::cl::desc("Set the pointer type"),
+                          llvm::cl::init(0)};
+
+  Option<int32_t> indType{*this, "ind-type",
+                          llvm::cl::desc("Set the index type"),
+                          llvm::cl::init(0)};
+
   /// Registers all dialects required by testing.
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<scf::SCFDialect, vector::VectorDialect>();
@@ -64,13 +72,26 @@ struct TestSparsification
     }
   }
 
+  /// Returns the requested integer type.
+  linalg::SparseIntType typeOption(int32_t option) {
+    switch (option) {
+    default:
+      return linalg::SparseIntType::kNative;
+    case 1:
+      return linalg::SparseIntType::kI64;
+    case 2:
+      return linalg::SparseIntType::kI32;
+    }
+  }
+
   /// Runs the test on a function.
   void runOnFunction() override {
     auto *ctx = &getContext();
     OwningRewritePatternList patterns;
     // Translate strategy flags to strategy options.
-    linalg::SparsificationOptions options(parallelOption(), vectorOption(),
-                                          vectorLength);
+    linalg::SparsificationOptions options(parallelOption(), vectorOption(),
+                                          vectorLength, typeOption(ptrType),
+                                          typeOption(indType));
     // Apply rewriting.
     linalg::populateSparsificationPatterns(ctx, patterns, options);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
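
From the command line, the two new options compose with the existing ones;
for example (over a hypothetical input file sparse.mlir, mirroring the RUN
lines in the new test):

  mlir-opt sparse.mlir -test-sparsification="ptr-type=2 ind-type=1"

requests 32-bit pointers and 64-bit indices, allocating memref<?xi32> for the
pointers and memref<?xi64> for the indices, which is the combination the
CHECK-TYPE2 block above verifies.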