[mlir][sparse] Moved a conditional from the RT library to the generated MLIR.

When generating code to add an element to SparseTensorCOO (e.g., when doing
dense=>sparse conversion), we used to check for nonzero values on the runtime
side, whereas now we generate MLIR code to do that check.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D110121
commit 221856f5cd (parent 1286bbc85f)
@@ -182,11 +182,27 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
   return call.getResult(0);
 }
 
+/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
+/// For floating types, we use the "unordered" comparator (i.e., returns
+/// true if `v` is NaN).
+static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
+                          Type t, Value v) {
+  Value zero = rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(t));
+  if (t.isa<FloatType>())
+    return rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, v, zero);
+  if (t.isIntOrIndex())
+    return rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, v, zero);
+  llvm_unreachable("Unknown element type");
+}
+
 /// Generates a call that adds one element to a coordinate scheme.
+/// In particular, this generates code like the following:
+///   val = a[i1,..,ik];
+///   if val != 0
+///     t->add(val, [i1,..,ik], [p1,..,pk]);
 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                           Value ptr, Value tensor, Value ind, Value perm,
                           ValueRange ivs) {
-  Location loc = op->getLoc();
   StringRef name;
   Type eltType = tensor.getType().cast<ShapedType>().getElementType();
   if (eltType.isF64())
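A note on the predicate choice above: CmpFPredicate::UNE means "unordered or
not equal", so the comparison is true when `v` is NaN as well as when `v` is an
ordinary nonzero value; NaN elements therefore still get inserted. A host-side
C++ analogue of the two candidate predicates (my illustration, not code from
this patch):

#include <cmath>

// UNE ("unordered or not equal"): true for NaN (unordered) or v != 0.
// This mirrors the semantics genIsNonzero requests via CmpFPredicate::UNE.
bool isNonzeroUNE(double v) { return std::isnan(v) || v != 0.0; }

// The ordered alternative ONE would silently drop NaN elements:
bool isNonzeroONE(double v) { return !std::isnan(v) && v != 0.0; }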
@@ -203,8 +219,11 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
     name = "addEltI8";
   else
     llvm_unreachable("Unknown element type");
+  Location loc = op->getLoc();
   Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
-  // TODO: add if here?
+  Value cond = genIsNonzero(rewriter, loc, eltType, val);
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
+  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
   unsigned i = 0;
   for (auto iv : ivs) {
     Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
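The three new lines above are the whole mechanism: genIsNonzero produces the i1
condition, an scf.if without an else region is created from it, and moving the
insertion point into the then-block makes every op built afterwards (the index
stores and the addElt call) conditional. Sketched in isolation, assuming a
rewriter, loc, and cond are in scope (illustrative, not the patch verbatim):

// Create `scf.if %cond { ... }` with no else region.
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*withElseRegion=*/false);
// Redirect the builder into the then-block; ops created from here on
// only execute when %cond is true.
rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
// ...build the index stores and the runtime call here...
//
// The generated IR then has roughly this shape:
//   %cond = cmpf une, %val, %zero : f64
//   scf.if %cond {
//     call @addEltF64(...)
//   }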
@@ -321,6 +340,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     // Note that the dense tensor traversal code is actually implemented
     // using MLIR IR to avoid having to expose too much low-level
     // memref traversal details to the runtime support library.
+    // Also note that the code below only generates the "new" ops and
+    // the loop-nest per se; whereas the entire body of the innermost
+    // loop is generated by genAddElt().
     Location loc = op->getLoc();
     ShapedType shape = resType.cast<ShapedType>();
     auto memTp =
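To make the division of labor in that comment concrete: the conversion emits
one loop per tensor dimension, and only the innermost body comes from
genAddEltCall. A hedged sketch of that structure using scf::buildLoopNest
(the names rank, tensor, ptr, ind, and perm are assumed to be in scope; this
is not the pass's exact code):

SmallVector<Value> lo, hi, st;
for (unsigned d = 0; d < rank; d++) {
  lo.push_back(rewriter.create<ConstantIndexOp>(loc, 0));        // lower bound 0
  hi.push_back(rewriter.create<tensor::DimOp>(loc, tensor, d));  // bound = dim size
  st.push_back(rewriter.create<ConstantIndexOp>(loc, 1));        // unit step
}
scf::buildLoopNest(rewriter, loc, lo, hi, st,
                   [&](OpBuilder &builder, Location loc, ValueRange ivs) {
                     // Innermost body: extract, test, and conditionally add.
                     genAddEltCall(rewriter, op, ptr, tensor, ind, perm, ivs);
                   });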
@@ -1,4 +1,4 @@
-//===- SparsificationPass.cpp - Pass for autogen spares tensor code -------===//
+//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -114,7 +114,8 @@ struct SparseTensorConversionPass
     });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
-    target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp>();
+    target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp, CmpFOp,
+                      CmpIOp>();
     target.addLegalDialect<scf::SCFDialect, LLVM::LLVMDialect,
                            memref::MemRefDialect>();
     // Populate with rules and apply rewriting rules.
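The legality update is functional, not cosmetic: in the dialect conversion
framework, every op kind a pattern creates must be legal for the
ConversionTarget (or itself convertible), otherwise applyPartialConversion
reports a legalization failure. Since genIsNonzero now creates CmpFOp and
CmpIOp, they join the legal set. The usual shape of such a pass driver, as a
sketch under those assumptions:

ConversionTarget target(getContext());
// Ops and dialects the rewrite patterns are allowed to produce:
target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp, CmpFOp,
                  CmpIOp>();
target.addLegalDialect<scf::SCFDialect, LLVM::LLVMDialect,
                       memref::MemRefDialect>();
// Fail the pass if anything illegal survives rewriting:
if (failed(applyPartialConversion(getOperation(), target,
                                  std::move(patterns))))
  signalPassFailure();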
@@ -548,8 +548,6 @@ char *getTensorFilename(uint64_t id) {
   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                        \
                             StridedMemRefType<uint64_t, 1> *iref,            \
                             StridedMemRefType<uint64_t, 1> *pref) {          \
-    if (!value)                                                              \
-      return tensor;                                                         \
     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                  \
     assert(iref->sizes[0] == pref->sizes[0]);                                \
     const uint64_t *indx = iref->data + iref->offset;                        \
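End-to-end effect on the runtime entry points: with the early return gone,
`_mlir_ciface_addElt*` inserts unconditionally, and zero filtering happens
exactly once, in the generated `scf.if` at the call site. A simplified,
self-contained sketch of the resulting contract (types and names hypothetical,
not the actual runtime code):

#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for the runtime's coordinate-scheme (COO) storage.
struct Coo {
  std::vector<std::pair<std::vector<uint64_t>, double>> elts;
  void add(std::vector<uint64_t> ind, double val) {
    elts.emplace_back(std::move(ind), val); // unconditional insert
  }
};

// After this change, the zero guard lives in the generated IR, not here.
void *addEltF64(void *tensor, double value, const uint64_t *indx,
                uint64_t rank) {
  // Previously: if (!value) return tensor;
  static_cast<Coo *>(tensor)->add(std::vector<uint64_t>(indx, indx + rank),
                                  value);
  return tensor;
}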