[mlir][sparse] refactor handling of merger leafs and ops

Using "default:" in the switch statemements that handle all our
merger ops has become a bit cumbersome since it is easy to overlook
parts of the code that need to handle ops specifically. By enforcing
full switch statements without "default:", we get a compiler warning
when cases are overlooked.

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D127263
This commit is contained in:
Aart Bik 2022-06-07 15:51:17 -07:00
parent 4badd4d40d
commit 06aa6ec87d
2 changed files with 154 additions and 47 deletions

View File

@ -25,6 +25,7 @@ namespace sparse_tensor {
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
: kind(k), val(v), op(o) {
switch (kind) {
// Leaf.
case kTensor:
assert(x != -1u && y == -1u && !v && !o);
tensor = x;
@ -36,6 +37,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
assert(x != -1u && y == -1u && !v && !o);
index = x;
break;
// Unary operations.
case kAbsF:
case kAbsC:
case kCeilF:
@ -86,13 +88,32 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
children.e0 = x;
children.e1 = y;
break;
case kBinary:
assert(x != -1u && y != -1u && !v && o);
// Binary operations.
case kMulF:
case kMulC:
case kMulI:
case kDivF:
case kDivC:
case kDivS:
case kDivU:
case kAddF:
case kAddC:
case kAddI:
case kSubF:
case kSubC:
case kSubI:
case kAndI:
case kOrI:
case kXorI:
case kShrS:
case kShrU:
case kShlI:
assert(x != -1u && y != -1u && !v && !o);
children.e0 = x;
children.e1 = y;
break;
default:
assert(x != -1u && y != -1u && !v && !o);
case kBinary:
assert(x != -1u && y != -1u && !v && o);
children.e0 = x;
children.e1 = y;
break;
@ -280,8 +301,13 @@ bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
switch (tensorExps[e].kind) {
// Leaf.
case kTensor:
return tensorExps[e].tensor == t;
case kInvariant:
case kIndex:
return false;
// Unary operations.
case kAbsF:
case kAbsC:
case kCeilF:
@ -313,6 +339,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kCRe:
case kBitCast:
return isSingleCondition(t, tensorExps[e].children.e0);
case kBinaryBranch:
case kUnary:
return false;
// Binary operations.
case kDivF: // note: x / c only
case kDivC:
case kDivS:
@ -339,7 +369,12 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kAddI:
return isSingleCondition(t, tensorExps[e].children.e0) &&
isSingleCondition(t, tensorExps[e].children.e1);
default:
case kSubF:
case kSubC:
case kSubI:
case kOrI:
case kXorI:
case kBinary:
return false;
}
}
@ -352,12 +387,14 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
static const char *kindToOpSymbol(Kind kind) {
switch (kind) {
// Leaf.
case kTensor:
return "tensor";
case kInvariant:
return "invariant";
case kIndex:
return "index";
// Unary operations.
case kAbsF:
case kAbsC:
return "abs";
@ -404,6 +441,7 @@ static const char *kindToOpSymbol(Kind kind) {
return "binary_branch";
case kUnary:
return "unary";
// Binary operations.
case kMulF:
case kMulC:
case kMulI:
@ -441,6 +479,7 @@ static const char *kindToOpSymbol(Kind kind) {
void Merger::dumpExp(unsigned e) const {
switch (tensorExps[e].kind) {
// Leaf.
case kTensor:
if (tensorExps[e].tensor == syntheticTensor)
llvm::dbgs() << "synthetic_";
@ -454,7 +493,9 @@ void Merger::dumpExp(unsigned e) const {
case kIndex:
llvm::dbgs() << "index_" << tensorExps[e].index;
break;
// Unary operations.
case kAbsF:
case kAbsC:
case kCeilF:
case kFloorF:
case kSqrtF:
@ -462,10 +503,13 @@ void Merger::dumpExp(unsigned e) const {
case kExpm1F:
case kExpm1C:
case kLog1pF:
case kLog1pC:
case kSinF:
case kSinC:
case kTanhF:
case kTanhC:
case kNegF:
case kNegC:
case kNegI:
case kTruncF:
case kExtF:
@ -477,11 +521,35 @@ void Merger::dumpExp(unsigned e) const {
case kCastU:
case kCastIdx:
case kTruncI:
case kCIm:
case kCRe:
case kBitCast:
case kBinaryBranch:
case kUnary:
llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
dumpExp(tensorExps[e].children.e0);
break;
default:
// Binary operations.
case kMulF:
case kMulC:
case kMulI:
case kDivF:
case kDivC:
case kDivS:
case kDivU:
case kAddF:
case kAddC:
case kAddI:
case kSubF:
case kSubC:
case kSubI:
case kAndI:
case kOrI:
case kXorI:
case kShrS:
case kShrU:
case kShlI:
case kBinary:
llvm::dbgs() << "(";
dumpExp(tensorExps[e].children.e0);
llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
@ -542,6 +610,7 @@ void Merger::dumpBits(const BitVector &bits) const {
unsigned Merger::buildLattices(unsigned e, unsigned i) {
Kind kind = tensorExps[e].kind;
switch (kind) {
// Leaf.
case kTensor:
case kInvariant:
case kIndex: {
@ -560,11 +629,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
latSets[s].push_back(addLat(t, i, e));
return s;
}
// Unary operations.
case kAbsF:
case kAbsC:
case kCeilF:
case kCIm:
case kCRe:
case kFloorF:
case kSqrtF:
case kSqrtC:
@ -589,6 +657,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
case kCastU:
case kCastIdx:
case kTruncI:
case kCIm:
case kCRe:
case kBitCast:
// A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
// lattice set of the operand through the operator into a new set.
@ -625,6 +695,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
unsigned rhs = addExp(kInvariant, absentVal);
return takeDisj(kind, child0, buildLattices(rhs, i), unop);
}
// Binary operations.
case kMulF:
case kMulC:
case kMulI:
@ -955,16 +1026,17 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
Value v0, Value v1) {
switch (tensorExps[e].kind) {
// Leaf.
case kTensor:
case kInvariant:
case kIndex:
llvm_unreachable("unexpected non-op");
// Unary ops.
// Unary operations.
case kAbsF:
return rewriter.create<math::AbsOp>(loc, v0);
case kAbsC: {
auto type = v0.getType().template cast<ComplexType>();
auto eltType = type.getElementType().template cast<FloatType>();
auto type = v0.getType().cast<ComplexType>();
auto eltType = type.getElementType().cast<FloatType>();
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
}
case kCeilF:
@ -1021,18 +1093,19 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
case kTruncI:
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
case kCIm:
case kCIm: {
auto type = v0.getType().cast<ComplexType>();
auto eltType = type.getElementType().cast<FloatType>();
return rewriter.create<complex::ImOp>(loc, eltType, v0);
}
case kCRe: {
auto type = v0.getType().template cast<ComplexType>();
auto eltType = type.getElementType().template cast<FloatType>();
if (tensorExps[e].kind == kCIm)
return rewriter.create<complex::ImOp>(loc, eltType, v0);
auto type = v0.getType().cast<ComplexType>();
auto eltType = type.getElementType().cast<FloatType>();
return rewriter.create<complex::ReOp>(loc, eltType, v0);
}
case kBitCast:
return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
// Binary ops.
// Binary operations.
case kMulF:
return rewriter.create<arith::MulFOp>(loc, v0, v1);
case kMulC:
@ -1071,8 +1144,7 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
case kShlI:
return rewriter.create<arith::ShLIOp>(loc, v0, v1);
// Semiring ops with custom logic.
case kBinaryBranch:
case kBinaryBranch: // semi-ring ops with custom logic.
return insertYieldOp(rewriter, loc,
*tensorExps[e].op->getBlock()->getParent(), {v0});
case kUnary:

View File

@ -136,43 +136,78 @@ protected:
}
/// Compares expressions for equality. Equality is defined recursively as:
/// - Two expressions can only be equal if they have the same Kind.
/// - Two binary expressions are equal if they have the same Kind and their
/// children are equal.
/// - Expressions with Kind invariant or tensor are equal if they have the
/// same expression id.
/// - Operations are equal if they have the same kind and children.
/// - Leaf tensors are equal if they refer to the same tensor.
bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
auto tensorExp = merger.exp(e);
if (tensorExp.kind != pattern->kind)
return false;
assert(tensorExp.kind != Kind::kInvariant &&
"Invariant comparison not yet supported");
switch (tensorExp.kind) {
case Kind::kTensor:
// Leaf.
case kTensor:
return tensorExp.tensor == pattern->tensorNum;
case Kind::kAbsF:
case Kind::kCeilF:
case Kind::kFloorF:
case Kind::kNegF:
case Kind::kNegI:
case kInvariant:
case kIndex:
llvm_unreachable("invariant not handled yet");
// Unary operations.
case kAbsF:
case kAbsC:
case kCeilF:
case kFloorF:
case kSqrtF:
case kSqrtC:
case kExpm1F:
case kExpm1C:
case kLog1pF:
case kLog1pC:
case kSinF:
case kSinC:
case kTanhF:
case kTanhC:
case kNegF:
case kNegC:
case kNegI:
case kTruncF:
case kExtF:
case kCastFS:
case kCastFU:
case kCastSF:
case kCastUF:
case kCastS:
case kCastU:
case kCastIdx:
case kTruncI:
case kCIm:
case kCRe:
case kBitCast:
case kBinaryBranch:
case kUnary:
case kShlI:
case kBinary:
return compareExpression(tensorExp.children.e0, pattern->e0);
case Kind::kMulF:
case Kind::kMulI:
case Kind::kDivF:
case Kind::kDivS:
case Kind::kDivU:
case Kind::kAddF:
case Kind::kAddI:
case Kind::kSubF:
case Kind::kSubI:
case Kind::kAndI:
case Kind::kOrI:
case Kind::kXorI:
// Binary operations.
case kMulF:
case kMulC:
case kMulI:
case kDivF:
case kDivC:
case kDivS:
case kDivU:
case kAddF:
case kAddC:
case kAddI:
case kSubF:
case kSubC:
case kSubI:
case kAndI:
case kOrI:
case kXorI:
case kShrS:
case kShrU:
return compareExpression(tensorExp.children.e0, pattern->e0) &&
compareExpression(tensorExp.children.e1, pattern->e1);
default:
llvm_unreachable("Unhandled Kind");
}
llvm_unreachable("unexpected kind");
}
unsigned numTensors;