[mlir][sparse] add shift ops support

Arbitrary shifts have some complications, but shifts by an invariant amount
(viz. the tensor index expression appears only on the left-hand side of the
operator) are easily handled with the conjunction rule.
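For intuition (an editorial sketch, not part of the patch): a shift by an
invariant amount is zero-preserving in the stored operand, since 0 >> c == 0
and 0 << c == 0 for any c, so only the stored entries of the sparse operand
can produce nonzero results and a single loop over them suffices. A minimal
standalone C++ sketch, with SparseVec and ashrByInvariant as illustrative
names rather than anything from the codebase:

#include <cstddef>
#include <cstdint>
#include <vector>

// Coordinate-format sparse vector (illustrative only, not an MLIR type).
struct SparseVec {
  std::vector<std::size_t> indices; // coordinates of stored entries
  std::vector<std::int64_t> values; // values at those coordinates
};

// Computes out = a >> c for an invariant c, writing into a dense,
// zero-initialized output; untouched positions stay zero because 0 >> c == 0.
void ashrByInvariant(const SparseVec &a, std::int64_t c,
                     std::vector<std::int64_t> &out) {
  for (std::size_t k = 0; k < a.indices.size(); ++k)
    out[a.indices[k]] = a.values[k] >> c;
}

This mirrors the single scf.for loop over the sparse_tensor.pointers, indices,
and values buffers in the generated code checked by the tests below.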

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D106002
Aart Bik, 2021-07-14 11:07:39 -07:00
commit 2b6e433230 (parent 9f6ff37a36)
3 changed files with 143 additions and 8 deletions


@@ -47,6 +47,9 @@ enum Kind {
kAndI,
kOrI,
kXorI,
kShrS, // signed
kShrU, // unsigned
kShlI,
};
/// Children subexpressions of tensor operations.
@@ -215,7 +218,8 @@ public:
Value v1);
private:
bool maybeZero(unsigned e);
bool maybeZero(unsigned e) const;
bool isInvariant(unsigned e) const;
/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);


@@ -208,13 +208,16 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
case kFloorF:
case kNegF:
case kNegI:
case Kind::kDivF: // note: x / c only
case Kind::kDivS:
case Kind::kDivU:
case Kind::kShrS: // note: x >> inv only
case Kind::kShrU:
case Kind::kShlI:
return isConjunction(t, tensorExps[e].children.e0);
case Kind::kMulF:
case Kind::kMulI:
case Kind::kAndI:
case Kind::kDivF: // note: x / c only
case Kind::kDivS:
case Kind::kDivU:
return isConjunction(t, tensorExps[e].children.e0) ||
isConjunction(t, tensorExps[e].children.e1);
default:
@@ -228,9 +231,9 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
// Print methods (for debugging).
//
static const char *kOpSymbols[] = {"", "", "abs", "ceil", "floor", "-",
"-", "*", "*", "/", "/", "+",
"+", "-", "-", "&", "|", "^"};
static const char *kOpSymbols[] = {
"", "", "abs", "ceil", "floor", "-", "-", "*", "*", "/", "/",
"+", "+", "-", "-", "&", "|", "^", "a>>", ">>", "<<"};
void Merger::dumpExp(unsigned e) const {
switch (tensorExps[e].kind) {
@@ -383,6 +386,15 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
return takeDisj(kind, // take binary disjunction
buildLattices(tensorExps[e].children.e0, i),
buildLattices(tensorExps[e].children.e1, i));
case Kind::kShrS:
case Kind::kShrU:
case Kind::kShlI:
// A shift operation by an invariant amount (viz. tensor expressions
// can only occur at the left-hand side of the operator) can be handled
// with the conjunction rule.
return takeConj(kind, // take binary conjunction
buildLattices(tensorExps[e].children.e0, i),
buildLattices(tensorExps[e].children.e1, i));
}
llvm_unreachable("unexpected expression kind");
}
@@ -392,7 +404,7 @@ Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
return buildTensorExp(op, yield->getOperand(0));
}
bool Merger::maybeZero(unsigned e) {
bool Merger::maybeZero(unsigned e) const {
if (tensorExps[e].kind == Kind::kInvariant) {
if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
return c.getValue() == 0;
@@ -402,6 +414,10 @@ bool Merger::maybeZero(unsigned e) {
return true;
}
bool Merger::isInvariant(unsigned e) const {
return tensorExps[e].kind == Kind::kInvariant;
}
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
if (auto arg = v.dyn_cast<BlockArgument>()) {
unsigned argN = arg.getArgNumber();
@@ -470,6 +486,12 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(Kind::kOrI, e0, e1);
if (isa<XOrOp>(def))
return addExp(Kind::kXorI, e0, e1);
if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
return addExp(Kind::kShrS, e0, e1);
if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
return addExp(Kind::kShrU, e0, e1);
if (isa<ShiftLeftOp>(def) && isInvariant(e1))
return addExp(Kind::kShlI, e0, e1);
}
}
// Cannot build.
@@ -517,6 +539,12 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
return rewriter.create<OrOp>(loc, v0, v1);
case Kind::kXorI:
return rewriter.create<XOrOp>(loc, v0, v1);
case Kind::kShrS:
return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
case Kind::kShrU:
return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
case Kind::kShlI:
return rewriter.create<ShiftLeftOp>(loc, v0, v1);
}
llvm_unreachable("unexpected expression kind in build");
}


@@ -404,3 +404,106 @@ func @xor(%arga: tensor<32xi64, #SV>,
} -> tensor<32xi64>
return %0 : tensor<32xi64>
}
// CHECK-LABEL: func @ashrbyc(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
// CHECK: %[[VAL_2:.*]] = constant 2 : i64
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 1 : index
// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
// CHECK: %[[VAL_14:.*]] = shift_right_signed %[[VAL_13]], %[[VAL_2]] : i64
// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
// CHECK: return %[[VAL_15]] : tensor<32xi64>
// CHECK: }
func @ashrbyc(%arga: tensor<32xi64, #SV>,
%argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
%c = constant 2 : i64
%0 = linalg.generic #traitc
ins(%arga: tensor<32xi64, #SV>)
outs(%argx: tensor<32xi64>) {
^bb(%a: i64, %x: i64):
%0 = shift_right_signed %a, %c : i64
linalg.yield %0 : i64
} -> tensor<32xi64>
return %0 : tensor<32xi64>
}
// CHECK-LABEL: func @lsrbyc(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
// CHECK: %[[VAL_2:.*]] = constant 2 : i64
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 1 : index
// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
// CHECK: %[[VAL_14:.*]] = shift_right_unsigned %[[VAL_13]], %[[VAL_2]] : i64
// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
// CHECK: return %[[VAL_15]] : tensor<32xi64>
// CHECK: }
func @lsrbyc(%arga: tensor<32xi64, #SV>,
%argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
%c = constant 2 : i64
%0 = linalg.generic #traitc
ins(%arga: tensor<32xi64, #SV>)
outs(%argx: tensor<32xi64>) {
^bb(%a: i64, %x: i64):
%0 = shift_right_unsigned %a, %c : i64
linalg.yield %0 : i64
} -> tensor<32xi64>
return %0 : tensor<32xi64>
}
// CHECK-LABEL: func @lslbyc(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
// CHECK: %[[VAL_2:.*]] = constant 2 : i64
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 1 : index
// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
// CHECK: %[[VAL_14:.*]] = shift_left %[[VAL_13]], %[[VAL_2]] : i64
// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
// CHECK: return %[[VAL_15]] : tensor<32xi64>
// CHECK: }
func @lslbyc(%arga: tensor<32xi64, #SV>,
%argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
%c = constant 2 : i64
%0 = linalg.generic #traitc
ins(%arga: tensor<32xi64, #SV>)
outs(%argx: tensor<32xi64>) {
^bb(%a: i64, %x: i64):
%0 = shift_left %a, %c : i64
linalg.yield %0 : i64
} -> tensor<32xi64>
return %0 : tensor<32xi64>
}