forked from OSchip/llvm-project
[mlir][sparse] Factoring out getZero() and avoiding unnecessary Type params
This is preliminary work towards D110790 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D110882
This commit is contained in:
parent
b084b98abe
commit
ca01034714
|
@ -182,12 +182,19 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
return call.getResult(0);
|
||||
}
|
||||
|
||||
/// Generates a constant zero of the given type.
|
||||
static Value getZero(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type t) {
|
||||
return rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(t));
|
||||
}
|
||||
|
||||
/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
|
||||
/// For floating types, we use the "unordered" comparator (i.e., returns
|
||||
/// true if `v` is NaN).
|
||||
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type t, Value v) {
|
||||
Value zero = rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(t));
|
||||
Value v) {
|
||||
Type t = v.getType();
|
||||
Value zero = getZero(rewriter, loc, t);
|
||||
if (t.isa<FloatType>())
|
||||
return rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, v, zero);
|
||||
if (t.isIntOrIndex())
|
||||
|
@ -203,11 +210,11 @@ static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
|
|||
/// if (tensor[ivs]!=0) {
|
||||
/// ind = ivs
|
||||
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Type eltType, Value tensor,
|
||||
Value ind, ValueRange ivs) {
|
||||
Operation *op, Value tensor, Value ind,
|
||||
ValueRange ivs) {
|
||||
Location loc = op->getLoc();
|
||||
Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
|
||||
Value cond = genIsNonzero(rewriter, loc, eltType, val);
|
||||
Value cond = genIsNonzero(rewriter, loc, val);
|
||||
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
|
||||
rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
|
||||
unsigned i = 0;
|
||||
|
@ -446,8 +453,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
|||
val = genIndexAndValueForSparse(
|
||||
rewriter, op, indices, values, ind, ivs, rank);
|
||||
else
|
||||
val = genIndexAndValueForDense(rewriter, op, eltType,
|
||||
tensor, ind, ivs);
|
||||
val = genIndexAndValueForDense(rewriter, op, tensor,
|
||||
ind, ivs);
|
||||
genAddEltCall(rewriter, op, eltType, ptr, val, ind,
|
||||
perm);
|
||||
return {};
|
||||
|
|
Loading…
Reference in New Issue