forked from OSchip/llvm-project
[mlir][sparse] refactor handling of merger leafs and ops
Using "default:" in the switch statemements that handle all our merger ops has become a bit cumbersome since it is easy to overlook parts of the code that need to handle ops specifically. By enforcing full switch statements without "default:", we get a compiler warning when cases are overlooked. Reviewed By: wrengr Differential Revision: https://reviews.llvm.org/D127263
This commit is contained in:
parent
4badd4d40d
commit
06aa6ec87d
|
@ -25,6 +25,7 @@ namespace sparse_tensor {
|
|||
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
|
||||
: kind(k), val(v), op(o) {
|
||||
switch (kind) {
|
||||
// Leaf.
|
||||
case kTensor:
|
||||
assert(x != -1u && y == -1u && !v && !o);
|
||||
tensor = x;
|
||||
|
@ -36,6 +37,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
|
|||
assert(x != -1u && y == -1u && !v && !o);
|
||||
index = x;
|
||||
break;
|
||||
// Unary operations.
|
||||
case kAbsF:
|
||||
case kAbsC:
|
||||
case kCeilF:
|
||||
|
@ -86,13 +88,32 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
|
|||
children.e0 = x;
|
||||
children.e1 = y;
|
||||
break;
|
||||
case kBinary:
|
||||
assert(x != -1u && y != -1u && !v && o);
|
||||
// Binary operations.
|
||||
case kMulF:
|
||||
case kMulC:
|
||||
case kMulI:
|
||||
case kDivF:
|
||||
case kDivC:
|
||||
case kDivS:
|
||||
case kDivU:
|
||||
case kAddF:
|
||||
case kAddC:
|
||||
case kAddI:
|
||||
case kSubF:
|
||||
case kSubC:
|
||||
case kSubI:
|
||||
case kAndI:
|
||||
case kOrI:
|
||||
case kXorI:
|
||||
case kShrS:
|
||||
case kShrU:
|
||||
case kShlI:
|
||||
assert(x != -1u && y != -1u && !v && !o);
|
||||
children.e0 = x;
|
||||
children.e1 = y;
|
||||
break;
|
||||
default:
|
||||
assert(x != -1u && y != -1u && !v && !o);
|
||||
case kBinary:
|
||||
assert(x != -1u && y != -1u && !v && o);
|
||||
children.e0 = x;
|
||||
children.e1 = y;
|
||||
break;
|
||||
|
@ -280,8 +301,13 @@ bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
|
|||
|
||||
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
||||
switch (tensorExps[e].kind) {
|
||||
// Leaf.
|
||||
case kTensor:
|
||||
return tensorExps[e].tensor == t;
|
||||
case kInvariant:
|
||||
case kIndex:
|
||||
return false;
|
||||
// Unary operations.
|
||||
case kAbsF:
|
||||
case kAbsC:
|
||||
case kCeilF:
|
||||
|
@ -313,6 +339,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
|||
case kCRe:
|
||||
case kBitCast:
|
||||
return isSingleCondition(t, tensorExps[e].children.e0);
|
||||
case kBinaryBranch:
|
||||
case kUnary:
|
||||
return false;
|
||||
// Binary operations.
|
||||
case kDivF: // note: x / c only
|
||||
case kDivC:
|
||||
case kDivS:
|
||||
|
@ -339,7 +369,12 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
|||
case kAddI:
|
||||
return isSingleCondition(t, tensorExps[e].children.e0) &&
|
||||
isSingleCondition(t, tensorExps[e].children.e1);
|
||||
default:
|
||||
case kSubF:
|
||||
case kSubC:
|
||||
case kSubI:
|
||||
case kOrI:
|
||||
case kXorI:
|
||||
case kBinary:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -352,12 +387,14 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
|||
|
||||
static const char *kindToOpSymbol(Kind kind) {
|
||||
switch (kind) {
|
||||
// Leaf.
|
||||
case kTensor:
|
||||
return "tensor";
|
||||
case kInvariant:
|
||||
return "invariant";
|
||||
case kIndex:
|
||||
return "index";
|
||||
// Unary operations.
|
||||
case kAbsF:
|
||||
case kAbsC:
|
||||
return "abs";
|
||||
|
@ -404,6 +441,7 @@ static const char *kindToOpSymbol(Kind kind) {
|
|||
return "binary_branch";
|
||||
case kUnary:
|
||||
return "unary";
|
||||
// Binary operations.
|
||||
case kMulF:
|
||||
case kMulC:
|
||||
case kMulI:
|
||||
|
@ -441,6 +479,7 @@ static const char *kindToOpSymbol(Kind kind) {
|
|||
|
||||
void Merger::dumpExp(unsigned e) const {
|
||||
switch (tensorExps[e].kind) {
|
||||
// Leaf.
|
||||
case kTensor:
|
||||
if (tensorExps[e].tensor == syntheticTensor)
|
||||
llvm::dbgs() << "synthetic_";
|
||||
|
@ -454,7 +493,9 @@ void Merger::dumpExp(unsigned e) const {
|
|||
case kIndex:
|
||||
llvm::dbgs() << "index_" << tensorExps[e].index;
|
||||
break;
|
||||
// Unary operations.
|
||||
case kAbsF:
|
||||
case kAbsC:
|
||||
case kCeilF:
|
||||
case kFloorF:
|
||||
case kSqrtF:
|
||||
|
@ -462,10 +503,13 @@ void Merger::dumpExp(unsigned e) const {
|
|||
case kExpm1F:
|
||||
case kExpm1C:
|
||||
case kLog1pF:
|
||||
case kLog1pC:
|
||||
case kSinF:
|
||||
case kSinC:
|
||||
case kTanhF:
|
||||
case kTanhC:
|
||||
case kNegF:
|
||||
case kNegC:
|
||||
case kNegI:
|
||||
case kTruncF:
|
||||
case kExtF:
|
||||
|
@ -477,11 +521,35 @@ void Merger::dumpExp(unsigned e) const {
|
|||
case kCastU:
|
||||
case kCastIdx:
|
||||
case kTruncI:
|
||||
case kCIm:
|
||||
case kCRe:
|
||||
case kBitCast:
|
||||
case kBinaryBranch:
|
||||
case kUnary:
|
||||
llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
|
||||
dumpExp(tensorExps[e].children.e0);
|
||||
break;
|
||||
default:
|
||||
// Binary operations.
|
||||
case kMulF:
|
||||
case kMulC:
|
||||
case kMulI:
|
||||
case kDivF:
|
||||
case kDivC:
|
||||
case kDivS:
|
||||
case kDivU:
|
||||
case kAddF:
|
||||
case kAddC:
|
||||
case kAddI:
|
||||
case kSubF:
|
||||
case kSubC:
|
||||
case kSubI:
|
||||
case kAndI:
|
||||
case kOrI:
|
||||
case kXorI:
|
||||
case kShrS:
|
||||
case kShrU:
|
||||
case kShlI:
|
||||
case kBinary:
|
||||
llvm::dbgs() << "(";
|
||||
dumpExp(tensorExps[e].children.e0);
|
||||
llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
|
||||
|
@ -542,6 +610,7 @@ void Merger::dumpBits(const BitVector &bits) const {
|
|||
unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
||||
Kind kind = tensorExps[e].kind;
|
||||
switch (kind) {
|
||||
// Leaf.
|
||||
case kTensor:
|
||||
case kInvariant:
|
||||
case kIndex: {
|
||||
|
@ -560,11 +629,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
latSets[s].push_back(addLat(t, i, e));
|
||||
return s;
|
||||
}
|
||||
// Unary operations.
|
||||
case kAbsF:
|
||||
case kAbsC:
|
||||
case kCeilF:
|
||||
case kCIm:
|
||||
case kCRe:
|
||||
case kFloorF:
|
||||
case kSqrtF:
|
||||
case kSqrtC:
|
||||
|
@ -589,6 +657,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
case kCastU:
|
||||
case kCastIdx:
|
||||
case kTruncI:
|
||||
case kCIm:
|
||||
case kCRe:
|
||||
case kBitCast:
|
||||
// A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
|
||||
// lattice set of the operand through the operator into a new set.
|
||||
|
@ -625,6 +695,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
unsigned rhs = addExp(kInvariant, absentVal);
|
||||
return takeDisj(kind, child0, buildLattices(rhs, i), unop);
|
||||
}
|
||||
// Binary operations.
|
||||
case kMulF:
|
||||
case kMulC:
|
||||
case kMulI:
|
||||
|
@ -955,16 +1026,17 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
|
|||
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
||||
Value v0, Value v1) {
|
||||
switch (tensorExps[e].kind) {
|
||||
// Leaf.
|
||||
case kTensor:
|
||||
case kInvariant:
|
||||
case kIndex:
|
||||
llvm_unreachable("unexpected non-op");
|
||||
// Unary ops.
|
||||
// Unary operations.
|
||||
case kAbsF:
|
||||
return rewriter.create<math::AbsOp>(loc, v0);
|
||||
case kAbsC: {
|
||||
auto type = v0.getType().template cast<ComplexType>();
|
||||
auto eltType = type.getElementType().template cast<FloatType>();
|
||||
auto type = v0.getType().cast<ComplexType>();
|
||||
auto eltType = type.getElementType().cast<FloatType>();
|
||||
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
|
||||
}
|
||||
case kCeilF:
|
||||
|
@ -1021,18 +1093,19 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
|||
return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
|
||||
case kTruncI:
|
||||
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
|
||||
case kCIm:
|
||||
case kCIm: {
|
||||
auto type = v0.getType().cast<ComplexType>();
|
||||
auto eltType = type.getElementType().cast<FloatType>();
|
||||
return rewriter.create<complex::ImOp>(loc, eltType, v0);
|
||||
}
|
||||
case kCRe: {
|
||||
auto type = v0.getType().template cast<ComplexType>();
|
||||
auto eltType = type.getElementType().template cast<FloatType>();
|
||||
if (tensorExps[e].kind == kCIm)
|
||||
return rewriter.create<complex::ImOp>(loc, eltType, v0);
|
||||
|
||||
auto type = v0.getType().cast<ComplexType>();
|
||||
auto eltType = type.getElementType().cast<FloatType>();
|
||||
return rewriter.create<complex::ReOp>(loc, eltType, v0);
|
||||
}
|
||||
case kBitCast:
|
||||
return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
|
||||
// Binary ops.
|
||||
// Binary operations.
|
||||
case kMulF:
|
||||
return rewriter.create<arith::MulFOp>(loc, v0, v1);
|
||||
case kMulC:
|
||||
|
@ -1071,8 +1144,7 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
|||
return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
|
||||
case kShlI:
|
||||
return rewriter.create<arith::ShLIOp>(loc, v0, v1);
|
||||
// Semiring ops with custom logic.
|
||||
case kBinaryBranch:
|
||||
case kBinaryBranch: // semi-ring ops with custom logic.
|
||||
return insertYieldOp(rewriter, loc,
|
||||
*tensorExps[e].op->getBlock()->getParent(), {v0});
|
||||
case kUnary:
|
||||
|
|
|
@ -136,43 +136,78 @@ protected:
|
|||
}
|
||||
|
||||
/// Compares expressions for equality. Equality is defined recursively as:
|
||||
/// - Two expressions can only be equal if they have the same Kind.
|
||||
/// - Two binary expressions are equal if they have the same Kind and their
|
||||
/// children are equal.
|
||||
/// - Expressions with Kind invariant or tensor are equal if they have the
|
||||
/// same expression id.
|
||||
/// - Operations are equal if they have the same kind and children.
|
||||
/// - Leaf tensors are equal if they refer to the same tensor.
|
||||
bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
|
||||
auto tensorExp = merger.exp(e);
|
||||
if (tensorExp.kind != pattern->kind)
|
||||
return false;
|
||||
assert(tensorExp.kind != Kind::kInvariant &&
|
||||
"Invariant comparison not yet supported");
|
||||
switch (tensorExp.kind) {
|
||||
case Kind::kTensor:
|
||||
// Leaf.
|
||||
case kTensor:
|
||||
return tensorExp.tensor == pattern->tensorNum;
|
||||
case Kind::kAbsF:
|
||||
case Kind::kCeilF:
|
||||
case Kind::kFloorF:
|
||||
case Kind::kNegF:
|
||||
case Kind::kNegI:
|
||||
case kInvariant:
|
||||
case kIndex:
|
||||
llvm_unreachable("invariant not handled yet");
|
||||
// Unary operations.
|
||||
case kAbsF:
|
||||
case kAbsC:
|
||||
case kCeilF:
|
||||
case kFloorF:
|
||||
case kSqrtF:
|
||||
case kSqrtC:
|
||||
case kExpm1F:
|
||||
case kExpm1C:
|
||||
case kLog1pF:
|
||||
case kLog1pC:
|
||||
case kSinF:
|
||||
case kSinC:
|
||||
case kTanhF:
|
||||
case kTanhC:
|
||||
case kNegF:
|
||||
case kNegC:
|
||||
case kNegI:
|
||||
case kTruncF:
|
||||
case kExtF:
|
||||
case kCastFS:
|
||||
case kCastFU:
|
||||
case kCastSF:
|
||||
case kCastUF:
|
||||
case kCastS:
|
||||
case kCastU:
|
||||
case kCastIdx:
|
||||
case kTruncI:
|
||||
case kCIm:
|
||||
case kCRe:
|
||||
case kBitCast:
|
||||
case kBinaryBranch:
|
||||
case kUnary:
|
||||
case kShlI:
|
||||
case kBinary:
|
||||
return compareExpression(tensorExp.children.e0, pattern->e0);
|
||||
case Kind::kMulF:
|
||||
case Kind::kMulI:
|
||||
case Kind::kDivF:
|
||||
case Kind::kDivS:
|
||||
case Kind::kDivU:
|
||||
case Kind::kAddF:
|
||||
case Kind::kAddI:
|
||||
case Kind::kSubF:
|
||||
case Kind::kSubI:
|
||||
case Kind::kAndI:
|
||||
case Kind::kOrI:
|
||||
case Kind::kXorI:
|
||||
// Binary operations.
|
||||
case kMulF:
|
||||
case kMulC:
|
||||
case kMulI:
|
||||
case kDivF:
|
||||
case kDivC:
|
||||
case kDivS:
|
||||
case kDivU:
|
||||
case kAddF:
|
||||
case kAddC:
|
||||
case kAddI:
|
||||
case kSubF:
|
||||
case kSubC:
|
||||
case kSubI:
|
||||
case kAndI:
|
||||
case kOrI:
|
||||
case kXorI:
|
||||
case kShrS:
|
||||
case kShrU:
|
||||
return compareExpression(tensorExp.children.e0, pattern->e0) &&
|
||||
compareExpression(tensorExp.children.e1, pattern->e1);
|
||||
default:
|
||||
llvm_unreachable("Unhandled Kind");
|
||||
}
|
||||
llvm_unreachable("unexpected kind");
|
||||
}
|
||||
|
||||
unsigned numTensors;
|
||||
|
|
Loading…
Reference in New Issue