forked from OSchip/llvm-project
[mlir][sparse] minor cleanup of Merger
Removed inconsistent name prefixes, added consistency checks on debug strings, added more assertions to verify assumptions that may be lifted in the future. Reviewed By: gussmith23 Differential Revision: https://reviews.llvm.org/D106108
This commit is contained in:
parent
0bf4b81d57
commit
8fe65972cb
|
@ -21,11 +21,11 @@ namespace sparse_tensor {
|
|||
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
|
||||
: kind(k), val(v) {
|
||||
switch (kind) {
|
||||
case Kind::kTensor:
|
||||
case kTensor:
|
||||
assert(x != -1u && y == -1u && !v);
|
||||
tensor = x;
|
||||
break;
|
||||
case Kind::kInvariant:
|
||||
case kInvariant:
|
||||
assert(x == -1u && y == -1u && v);
|
||||
break;
|
||||
case kAbsF:
|
||||
|
@ -99,10 +99,10 @@ unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
|
|||
for (unsigned p : latSets[s0])
|
||||
latSets[s].push_back(p);
|
||||
// Map binary 0-y to unary -y.
|
||||
if (kind == Kind::kSubF)
|
||||
s1 = mapSet(Kind::kNegF, s1);
|
||||
else if (kind == Kind::kSubI)
|
||||
s1 = mapSet(Kind::kNegI, s1);
|
||||
if (kind == kSubF)
|
||||
s1 = mapSet(kNegF, s1);
|
||||
else if (kind == kSubI)
|
||||
s1 = mapSet(kNegI, s1);
|
||||
// Followed by all in s1.
|
||||
for (unsigned p : latSets[s1])
|
||||
latSets[s].push_back(p);
|
||||
|
@ -110,7 +110,7 @@ unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
|
|||
}
|
||||
|
||||
unsigned Merger::mapSet(Kind kind, unsigned s0) {
|
||||
assert(Kind::kAbsF <= kind && kind <= Kind::kNegI);
|
||||
assert(kAbsF <= kind && kind <= kNegI);
|
||||
unsigned s = addSet();
|
||||
for (unsigned p : latSets[s0]) {
|
||||
unsigned e = addExp(kind, latPoints[p].exp);
|
||||
|
@ -129,8 +129,7 @@ unsigned Merger::optimizeSet(unsigned s0) {
|
|||
if (p0 != p1) {
|
||||
// Is this a straightforward copy?
|
||||
unsigned e = latPoints[p1].exp;
|
||||
if (tensorExps[e].kind == Kind::kTensor &&
|
||||
tensorExps[e].tensor == outTensor)
|
||||
if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
|
||||
continue;
|
||||
// Conjunction already covered?
|
||||
for (unsigned p2 : latSets[s]) {
|
||||
|
@ -162,9 +161,9 @@ llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
|
|||
}
|
||||
// Now apply the two basic rules.
|
||||
llvm::BitVector simple = latPoints[p0].bits;
|
||||
bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
|
||||
bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
|
||||
for (unsigned b = 0, be = simple.size(); b < be; b++) {
|
||||
if (simple[b] && !isDim(b, Dim::kSparse)) {
|
||||
if (simple[b] && !isDim(b, kSparse)) {
|
||||
if (reset)
|
||||
simple.reset(b);
|
||||
reset = true;
|
||||
|
@ -189,7 +188,7 @@ bool Merger::latGT(unsigned i, unsigned j) const {
|
|||
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
|
||||
llvm::BitVector tmp = latPoints[j].bits;
|
||||
tmp ^= latPoints[i].bits;
|
||||
return !hasAnyDimOf(tmp, Dim::kSparse);
|
||||
return !hasAnyDimOf(tmp, kSparse);
|
||||
}
|
||||
|
||||
bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
|
||||
|
@ -201,23 +200,27 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
|
|||
|
||||
bool Merger::isConjunction(unsigned t, unsigned e) const {
|
||||
switch (tensorExps[e].kind) {
|
||||
case Kind::kTensor:
|
||||
case kTensor:
|
||||
return tensorExps[e].tensor == t;
|
||||
case kAbsF:
|
||||
case kCeilF:
|
||||
case kFloorF:
|
||||
case kNegF:
|
||||
case kNegI:
|
||||
case Kind::kDivF: // note: x / c only
|
||||
case Kind::kDivS:
|
||||
case Kind::kDivU:
|
||||
case Kind::kShrS: // note: x >> inv only
|
||||
case Kind::kShrU:
|
||||
case Kind::kShlI:
|
||||
return isConjunction(t, tensorExps[e].children.e0);
|
||||
case Kind::kMulF:
|
||||
case Kind::kMulI:
|
||||
case Kind::kAndI:
|
||||
case kDivF: // note: x / c only
|
||||
case kDivS:
|
||||
case kDivU:
|
||||
assert(!maybeZero(tensorExps[e].children.e1));
|
||||
return isConjunction(t, tensorExps[e].children.e0);
|
||||
case kShrS: // note: x >> inv only
|
||||
case kShrU:
|
||||
case kShlI:
|
||||
assert(isInvariant(tensorExps[e].children.e1));
|
||||
return isConjunction(t, tensorExps[e].children.e0);
|
||||
case kMulF:
|
||||
case kMulI:
|
||||
case kAndI:
|
||||
return isConjunction(t, tensorExps[e].children.e0) ||
|
||||
isConjunction(t, tensorExps[e].children.e1);
|
||||
default:
|
||||
|
@ -231,20 +234,66 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
|
|||
// Print methods (for debugging).
|
||||
//
|
||||
|
||||
static const char *kOpSymbols[] = {
|
||||
"", "", "abs", "ceil", "floor", "-", "-", "*", "*", "/", "/",
|
||||
"+", "+", "-", "-", "&", "|", "^", "a>>", ">>", "<<"};
|
||||
static const char *kindToOpSymbol(Kind kind) {
|
||||
switch (kind) {
|
||||
case kTensor:
|
||||
return "tensor";
|
||||
case kInvariant:
|
||||
return "invariant";
|
||||
case kAbsF:
|
||||
return "abs";
|
||||
case kCeilF:
|
||||
return "ceil";
|
||||
case kFloorF:
|
||||
return "floor";
|
||||
case kNegF:
|
||||
return "-";
|
||||
case kNegI:
|
||||
return "-";
|
||||
case kMulF:
|
||||
return "*";
|
||||
case kMulI:
|
||||
return "*";
|
||||
case kDivF:
|
||||
return "/";
|
||||
case kDivS:
|
||||
return "/";
|
||||
case kDivU:
|
||||
return "/";
|
||||
case kAddF:
|
||||
return "+";
|
||||
case kAddI:
|
||||
return "+";
|
||||
case kSubF:
|
||||
return "-";
|
||||
case kSubI:
|
||||
return "-";
|
||||
case kAndI:
|
||||
return "&";
|
||||
case kOrI:
|
||||
return "|";
|
||||
case kXorI:
|
||||
return "^";
|
||||
case kShrS:
|
||||
return "a>>";
|
||||
case kShrU:
|
||||
return ">>";
|
||||
case kShlI:
|
||||
return "<<";
|
||||
}
|
||||
llvm_unreachable("unexpected kind for symbol");
|
||||
}
|
||||
|
||||
void Merger::dumpExp(unsigned e) const {
|
||||
switch (tensorExps[e].kind) {
|
||||
case Kind::kTensor:
|
||||
case kTensor:
|
||||
if (tensorExps[e].tensor == syntheticTensor)
|
||||
llvm::dbgs() << "synthetic_";
|
||||
else if (tensorExps[e].tensor == outTensor)
|
||||
llvm::dbgs() << "output_";
|
||||
llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
|
||||
break;
|
||||
case Kind::kInvariant:
|
||||
case kInvariant:
|
||||
llvm::dbgs() << "invariant";
|
||||
break;
|
||||
case kAbsF:
|
||||
|
@ -252,13 +301,13 @@ void Merger::dumpExp(unsigned e) const {
|
|||
case kFloorF:
|
||||
case kNegF:
|
||||
case kNegI:
|
||||
llvm::dbgs() << kOpSymbols[tensorExps[e].kind] << " ";
|
||||
llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
|
||||
dumpExp(tensorExps[e].children.e0);
|
||||
break;
|
||||
default:
|
||||
llvm::dbgs() << "(";
|
||||
dumpExp(tensorExps[e].children.e0);
|
||||
llvm::dbgs() << " " << kOpSymbols[tensorExps[e].kind] << " ";
|
||||
llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
|
||||
dumpExp(tensorExps[e].children.e1);
|
||||
llvm::dbgs() << ")";
|
||||
}
|
||||
|
@ -290,16 +339,16 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
|
|||
unsigned i = index(b);
|
||||
llvm::dbgs() << " i_" << t << "_" << i << "_";
|
||||
switch (dims[t][i]) {
|
||||
case Dim::kSparse:
|
||||
case kSparse:
|
||||
llvm::dbgs() << "S";
|
||||
break;
|
||||
case Dim::kDense:
|
||||
case kDense:
|
||||
llvm::dbgs() << "D";
|
||||
break;
|
||||
case Dim::kSingle:
|
||||
case kSingle:
|
||||
llvm::dbgs() << "T";
|
||||
break;
|
||||
case Dim::kUndef:
|
||||
case kUndef:
|
||||
llvm::dbgs() << "U";
|
||||
break;
|
||||
}
|
||||
|
@ -316,13 +365,13 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
|
|||
unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
||||
Kind kind = tensorExps[e].kind;
|
||||
switch (kind) {
|
||||
case Kind::kTensor:
|
||||
case Kind::kInvariant: {
|
||||
case kTensor:
|
||||
case kInvariant: {
|
||||
// Either the index is really used in the tensor expression, or it is
|
||||
// set to the undefined index in that dimension. An invariant expression
|
||||
// is set to a synthetic tensor with undefined indices only.
|
||||
unsigned s = addSet();
|
||||
unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
|
||||
unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
|
||||
latSets[s].push_back(addLat(t, i, e));
|
||||
return s;
|
||||
}
|
||||
|
@ -338,9 +387,9 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
// --+---+---+
|
||||
// | 0 |-y |
|
||||
return mapSet(kind, buildLattices(tensorExps[e].children.e0, i));
|
||||
case Kind::kMulF:
|
||||
case Kind::kMulI:
|
||||
case Kind::kAndI:
|
||||
case kMulF:
|
||||
case kMulI:
|
||||
case kAndI:
|
||||
// A multiplicative operation only needs to be performed
|
||||
// for the conjunction of sparse iteration spaces.
|
||||
//
|
||||
|
@ -351,9 +400,9 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
return takeConj(kind, // take binary conjunction
|
||||
buildLattices(tensorExps[e].children.e0, i),
|
||||
buildLattices(tensorExps[e].children.e1, i));
|
||||
case Kind::kDivF:
|
||||
case Kind::kDivS:
|
||||
case Kind::kDivU:
|
||||
case kDivF:
|
||||
case kDivS:
|
||||
case kDivU:
|
||||
// A division is tricky, since 0/0, 0/c, c/0 all have
|
||||
// specific outcomes for floating-point and integers.
|
||||
// Thus, we need to traverse the full iteration space.
|
||||
|
@ -367,15 +416,16 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
// during expression building, so that the conjunction
|
||||
// rules applies (viz. x/c = x*(1/c) as far as lattice
|
||||
// construction is concerned).
|
||||
assert(!maybeZero(tensorExps[e].children.e1));
|
||||
return takeConj(kind, // take binary conjunction
|
||||
buildLattices(tensorExps[e].children.e0, i),
|
||||
buildLattices(tensorExps[e].children.e1, i));
|
||||
case Kind::kAddF:
|
||||
case Kind::kAddI:
|
||||
case Kind::kSubF:
|
||||
case Kind::kSubI:
|
||||
case Kind::kOrI:
|
||||
case Kind::kXorI:
|
||||
case kAddF:
|
||||
case kAddI:
|
||||
case kSubF:
|
||||
case kSubI:
|
||||
case kOrI:
|
||||
case kXorI:
|
||||
// An additive operation needs to be performed
|
||||
// for the disjunction of sparse iteration spaces.
|
||||
//
|
||||
|
@ -386,12 +436,13 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
return takeDisj(kind, // take binary disjunction
|
||||
buildLattices(tensorExps[e].children.e0, i),
|
||||
buildLattices(tensorExps[e].children.e1, i));
|
||||
case Kind::kShrS:
|
||||
case Kind::kShrU:
|
||||
case Kind::kShlI:
|
||||
case kShrS:
|
||||
case kShrU:
|
||||
case kShlI:
|
||||
// A shift operation by an invariant amount (viz. tensor expressions
|
||||
// can only occur at the left-hand-side of the operator) can be handled
|
||||
// with the conjuction rule.
|
||||
assert(isInvariant(tensorExps[e].children.e1));
|
||||
return takeConj(kind, // take binary conjunction
|
||||
buildLattices(tensorExps[e].children.e0, i),
|
||||
buildLattices(tensorExps[e].children.e1, i));
|
||||
|
@ -405,7 +456,7 @@ Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
|
|||
}
|
||||
|
||||
bool Merger::maybeZero(unsigned e) const {
|
||||
if (tensorExps[e].kind == Kind::kInvariant) {
|
||||
if (tensorExps[e].kind == kInvariant) {
|
||||
if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
|
||||
return c.getValue() == 0;
|
||||
if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
|
||||
|
@ -415,7 +466,7 @@ bool Merger::maybeZero(unsigned e) const {
|
|||
}
|
||||
|
||||
bool Merger::isInvariant(unsigned e) const {
|
||||
return tensorExps[e].kind == Kind::kInvariant;
|
||||
return tensorExps[e].kind == kInvariant;
|
||||
}
|
||||
|
||||
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
|
||||
|
@ -427,30 +478,30 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
|
|||
if (arg.getOwner()->getParentOp() == op) {
|
||||
OpOperand *t = op.getInputAndOutputOperands()[argN];
|
||||
if (!op.isScalar(t))
|
||||
return addExp(Kind::kTensor, argN);
|
||||
return addExp(kTensor, argN);
|
||||
v = t->get(); // get scalar value
|
||||
}
|
||||
// Any other argument (marked as scalar argument for the generic op
|
||||
// or belonging to an enveloping op) is considered invariant.
|
||||
return addExp(Kind::kInvariant, v);
|
||||
return addExp(kInvariant, v);
|
||||
}
|
||||
// Something defined outside is invariant.
|
||||
Operation *def = v.getDefiningOp();
|
||||
if (def->getBlock() != &op.region().front())
|
||||
return addExp(Kind::kInvariant, v);
|
||||
return addExp(kInvariant, v);
|
||||
// Construct unary operations if subexpression can be built.
|
||||
if (def->getNumOperands() == 1) {
|
||||
auto x = buildTensorExp(op, def->getOperand(0));
|
||||
if (x.hasValue()) {
|
||||
unsigned e = x.getValue();
|
||||
if (isa<AbsFOp>(def))
|
||||
return addExp(Kind::kAbsF, e);
|
||||
return addExp(kAbsF, e);
|
||||
if (isa<CeilFOp>(def))
|
||||
return addExp(Kind::kCeilF, e);
|
||||
return addExp(kCeilF, e);
|
||||
if (isa<FloorFOp>(def))
|
||||
return addExp(Kind::kFloorF, e);
|
||||
return addExp(kFloorF, e);
|
||||
if (isa<NegFOp>(def))
|
||||
return addExp(Kind::kNegF, e);
|
||||
return addExp(kNegF, e);
|
||||
// TODO: no negi in std?
|
||||
}
|
||||
}
|
||||
|
@ -463,35 +514,35 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
|
|||
unsigned e0 = x.getValue();
|
||||
unsigned e1 = y.getValue();
|
||||
if (isa<MulFOp>(def))
|
||||
return addExp(Kind::kMulF, e0, e1);
|
||||
return addExp(kMulF, e0, e1);
|
||||
if (isa<MulIOp>(def))
|
||||
return addExp(Kind::kMulI, e0, e1);
|
||||
return addExp(kMulI, e0, e1);
|
||||
if (isa<DivFOp>(def) && !maybeZero(e1))
|
||||
return addExp(Kind::kDivF, e0, e1);
|
||||
return addExp(kDivF, e0, e1);
|
||||
if (isa<SignedDivIOp>(def) && !maybeZero(e1))
|
||||
return addExp(Kind::kDivS, e0, e1);
|
||||
return addExp(kDivS, e0, e1);
|
||||
if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
|
||||
return addExp(Kind::kDivU, e0, e1);
|
||||
return addExp(kDivU, e0, e1);
|
||||
if (isa<AddFOp>(def))
|
||||
return addExp(Kind::kAddF, e0, e1);
|
||||
return addExp(kAddF, e0, e1);
|
||||
if (isa<AddIOp>(def))
|
||||
return addExp(Kind::kAddI, e0, e1);
|
||||
return addExp(kAddI, e0, e1);
|
||||
if (isa<SubFOp>(def))
|
||||
return addExp(Kind::kSubF, e0, e1);
|
||||
return addExp(kSubF, e0, e1);
|
||||
if (isa<SubIOp>(def))
|
||||
return addExp(Kind::kSubI, e0, e1);
|
||||
return addExp(kSubI, e0, e1);
|
||||
if (isa<AndOp>(def))
|
||||
return addExp(Kind::kAndI, e0, e1);
|
||||
return addExp(kAndI, e0, e1);
|
||||
if (isa<OrOp>(def))
|
||||
return addExp(Kind::kOrI, e0, e1);
|
||||
return addExp(kOrI, e0, e1);
|
||||
if (isa<XOrOp>(def))
|
||||
return addExp(Kind::kXorI, e0, e1);
|
||||
return addExp(kXorI, e0, e1);
|
||||
if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
|
||||
return addExp(Kind::kShrS, e0, e1);
|
||||
return addExp(kShrS, e0, e1);
|
||||
if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
|
||||
return addExp(Kind::kShrU, e0, e1);
|
||||
return addExp(kShrU, e0, e1);
|
||||
if (isa<ShiftLeftOp>(def) && isInvariant(e1))
|
||||
return addExp(Kind::kShlI, e0, e1);
|
||||
return addExp(kShlI, e0, e1);
|
||||
}
|
||||
}
|
||||
// Cannot build.
|
||||
|
@ -501,8 +552,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
|
|||
Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
|
||||
Value v0, Value v1) {
|
||||
switch (tensorExps[e].kind) {
|
||||
case Kind::kTensor:
|
||||
case Kind::kInvariant:
|
||||
case kTensor:
|
||||
case kInvariant:
|
||||
llvm_unreachable("unexpected non-op");
|
||||
case kAbsF:
|
||||
return rewriter.create<AbsFOp>(loc, v0);
|
||||
|
@ -515,35 +566,35 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
|
|||
case kNegI:
|
||||
assert(v1); // no negi in std
|
||||
return rewriter.create<SubIOp>(loc, v0, v1);
|
||||
case Kind::kMulF:
|
||||
case kMulF:
|
||||
return rewriter.create<MulFOp>(loc, v0, v1);
|
||||
case Kind::kMulI:
|
||||
case kMulI:
|
||||
return rewriter.create<MulIOp>(loc, v0, v1);
|
||||
case Kind::kDivF:
|
||||
case kDivF:
|
||||
return rewriter.create<DivFOp>(loc, v0, v1);
|
||||
case Kind::kDivS:
|
||||
case kDivS:
|
||||
return rewriter.create<SignedDivIOp>(loc, v0, v1);
|
||||
case Kind::kDivU:
|
||||
case kDivU:
|
||||
return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
|
||||
case Kind::kAddF:
|
||||
case kAddF:
|
||||
return rewriter.create<AddFOp>(loc, v0, v1);
|
||||
case Kind::kAddI:
|
||||
case kAddI:
|
||||
return rewriter.create<AddIOp>(loc, v0, v1);
|
||||
case Kind::kSubF:
|
||||
case kSubF:
|
||||
return rewriter.create<SubFOp>(loc, v0, v1);
|
||||
case Kind::kSubI:
|
||||
case kSubI:
|
||||
return rewriter.create<SubIOp>(loc, v0, v1);
|
||||
case Kind::kAndI:
|
||||
case kAndI:
|
||||
return rewriter.create<AndOp>(loc, v0, v1);
|
||||
case Kind::kOrI:
|
||||
case kOrI:
|
||||
return rewriter.create<OrOp>(loc, v0, v1);
|
||||
case Kind::kXorI:
|
||||
case kXorI:
|
||||
return rewriter.create<XOrOp>(loc, v0, v1);
|
||||
case Kind::kShrS:
|
||||
case kShrS:
|
||||
return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
|
||||
case Kind::kShrU:
|
||||
case kShrU:
|
||||
return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
|
||||
case Kind::kShlI:
|
||||
case kShlI:
|
||||
return rewriter.create<ShiftLeftOp>(loc, v0, v1);
|
||||
}
|
||||
llvm_unreachable("unexpected expression kind in build");
|
||||
|
|
Loading…
Reference in New Issue