[mlir][sparse] add shift ops support

Arbitrary shifts have some complications, but a shift by an invariant amount (viz. the tensor index expression may appear only on the left-hand side of the operator) can easily be handled with the conjunctive rule.

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D106002
This commit is contained in:
parent 9f6ff37a36
commit 2b6e433230
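For intuition, here is a minimal plain-C++ sketch (mine, not part of the change) of why the shift amount must be invariant: shifting an implicit zero by any invariant amount yields zero, so the result is confined to the stored entries of the shifted operand, which is exactly the conjunctive rule; if the amount itself came from a sparse tensor, x << 0 == x at the amount's implicit zeros, and the result escapes the intersection.

#include <cassert>
#include <cstdint>

int main() {
  const int64_t c = 3; // loop-invariant shift amount
  // Implicit zeros stay zero under a shift by an invariant, so only the
  // stored entries of the sparse operand ever need to be computed.
  assert((int64_t{0} >> c) == 0);
  assert((int64_t{0} << c) == 0);
  // If the amount were a tensor expression, an implicit-zero amount would
  // leave a nonzero operand unchanged (x << 0 == x), producing nonzeros
  // outside the intersection; that is the complication the commit message
  // alludes to for arbitrary shifts.
  int64_t x = 5;
  assert((x << 0) == x);
  return 0;
}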
@@ -47,6 +47,9 @@ enum Kind {
   kAndI,
   kOrI,
   kXorI,
+  kShrS, // signed
+  kShrU, // unsigned
+  kShlI,
 };

 /// Children subexpressions of tensor operations.
@@ -215,7 +218,8 @@ public:
               Value v1);

 private:
-  bool maybeZero(unsigned e);
+  bool maybeZero(unsigned e) const;
+  bool isInvariant(unsigned e) const;

   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
   Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);
@@ -208,13 +208,16 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
   case kFloorF:
   case kNegF:
   case kNegI:
+  case Kind::kDivF: // note: x / c only
+  case Kind::kDivS:
+  case Kind::kDivU:
+  case Kind::kShrS: // note: x >> inv only
+  case Kind::kShrU:
+  case Kind::kShlI:
     return isConjunction(t, tensorExps[e].children.e0);
   case Kind::kMulF:
   case Kind::kMulI:
   case Kind::kAndI:
-  case Kind::kDivF: // note: x / c only
-  case Kind::kDivS:
-  case Kind::kDivU:
     return isConjunction(t, tensorExps[e].children.e0) ||
            isConjunction(t, tensorExps[e].children.e1);
   default:
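A hedged aside on the regrouping above: division and the new shift kinds recurse only into the left child because the right-hand operand is required to be an invariant, while multiplication and bitwise-and may find the sparse tensor on either side. A toy stand-in for that recursion shape (the Node type and its fields are illustrative, not Merger's actual tensorExps representation):

// Toy expression node; names are hypothetical stand-ins for the entries
// of Merger's tensorExps array.
struct Node {
  enum class K { Tensor, MulI, ShrS } kind;
  const Node *e0 = nullptr; // left child
  const Node *e1 = nullptr; // right child
  bool isConjunction() const {
    switch (kind) {
    case K::Tensor:
      return true; // reached a sparse tensor operand
    case K::ShrS: // right operand is invariant: only the left side counts
      return e0->isConjunction();
    case K::MulI: // symmetric: the tensor may sit on either side
      return e0->isConjunction() || e1->isConjunction();
    }
    return false;
  }
};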
@@ -228,9 +231,9 @@ bool Merger::isConjunction(unsigned t, unsigned e) const {
 // Print methods (for debugging).
 //

-static const char *kOpSymbols[] = {"", "", "abs", "ceil", "floor", "-",
-                                   "-", "*", "*", "/", "/", "+",
-                                   "+", "-", "-", "&", "|", "^"};
+static const char *kOpSymbols[] = {
+    "",  "",  "abs", "ceil", "floor", "-", "-", "*",   "*",  "/",  "/",
+    "+", "+", "-",   "-",    "&",     "|", "^", "a>>", ">>", "<<"};

 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
@@ -383,6 +386,15 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     return takeDisj(kind, // take binary disjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
+  case Kind::kShrS:
+  case Kind::kShrU:
+  case Kind::kShlI:
+    // A shift operation by an invariant amount (viz. tensor expressions
+    // can only occur at the left-hand side of the operator) can be
+    // handled with the conjunction rule.
+    return takeConj(kind, // take binary conjunction
+                    buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i));
   }
   llvm_unreachable("unexpected expression kind");
 }
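A rough sketch of what the conjunction rule does with the children's lattice sets, under the assumption (mine, not spelled out in this diff) that a lattice point records which sparse operands must be present; takeConj then pairs up points from both sides:

#include <bitset>
#include <vector>

using Point = std::bitset<8>;      // which sparse operands contribute
using LatSet = std::vector<Point>; // a set of lattice points

// Conjunction: the result can be nonzero only where *both* children are,
// so each pair of points is merged into one requiring both tensor sets.
LatSet takeConjSketch(const LatSet &a, const LatSet &b) {
  LatSet out;
  for (const Point &p : a)
    for (const Point &q : b)
      out.push_back(p | q);
  return out;
}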
@@ -392,7 +404,7 @@ Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
   return buildTensorExp(op, yield->getOperand(0));
 }

-bool Merger::maybeZero(unsigned e) {
+bool Merger::maybeZero(unsigned e) const {
   if (tensorExps[e].kind == Kind::kInvariant) {
     if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
       return c.getValue() == 0;
@@ -402,6 +414,10 @@ bool Merger::maybeZero(unsigned e) {
   return true;
 }

+bool Merger::isInvariant(unsigned e) const {
+  return tensorExps[e].kind == Kind::kInvariant;
+}
+
 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   if (auto arg = v.dyn_cast<BlockArgument>()) {
     unsigned argN = arg.getArgNumber();
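The two helpers above share a conservative flavor: maybeZero answers true unless a nonzero constant can be proven, and isInvariant is a plain kind check. A self-contained sketch of the same pattern (the Expr struct and its fields are hypothetical, not the MLIR types):

#include <cstdint>
#include <optional>

struct Expr {
  bool invariant = false;          // hypothetical: kind == kInvariant
  std::optional<int64_t> constant; // hypothetical: known constant value
};

// "May this expression be zero?" Err on the side of yes: only a provably
// known constant settles the question.
bool maybeZeroSketch(const Expr &e) {
  if (e.invariant && e.constant)
    return *e.constant == 0;
  return true; // unknown: assume it may be zero
}

bool isInvariantSketch(const Expr &e) { return e.invariant; }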
@@ -470,6 +486,12 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(Kind::kOrI, e0, e1);
       if (isa<XOrOp>(def))
         return addExp(Kind::kXorI, e0, e1);
+      if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
+        return addExp(Kind::kShrS, e0, e1);
+      if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
+        return addExp(Kind::kShrU, e0, e1);
+      if (isa<ShiftLeftOp>(def) && isInvariant(e1))
+        return addExp(Kind::kShlI, e0, e1);
     }
   }
   // Cannot build.
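The effect of the isInvariant(e1) guards: a shift like a >> %c with %c defined outside the linalg body is admitted into the expression tree, while a shift whose amount is itself a tensor expression makes buildTensorExp fail, so that kernel is simply left unsparsified. A small illustrative predicate in the same spirit (not the actual API):

// Hypothetical mirror of the guards above: shifts are admitted only when
// the shift amount is invariant; other binary ops have their own rules.
enum class OpKind { ShrS, ShrU, ShlI, Other };

bool admitsOp(OpKind k, bool rhsIsInvariant) {
  switch (k) {
  case OpKind::ShrS:
  case OpKind::ShrU:
  case OpKind::ShlI:
    return rhsIsInvariant; // a >> c: yes; a >> b(i): no
  default:
    return true;
  }
}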
@@ -517,6 +539,12 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
     return rewriter.create<OrOp>(loc, v0, v1);
   case Kind::kXorI:
     return rewriter.create<XOrOp>(loc, v0, v1);
+  case Kind::kShrS:
+    return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
+  case Kind::kShrU:
+    return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
+  case Kind::kShlI:
+    return rewriter.create<ShiftLeftOp>(loc, v0, v1);
   }
   llvm_unreachable("unexpected expression kind in build");
 }
@@ -404,3 +404,106 @@ func @xor(%arga: tensor<32xi64, #SV>,
   } -> tensor<32xi64>
   return %0 : tensor<32xi64>
 }
+
+// CHECK-LABEL: func @ashrbyc(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK: %[[VAL_2:.*]] = constant 2 : i64
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
+// CHECK: %[[VAL_14:.*]] = shift_right_signed %[[VAL_13]], %[[VAL_2]] : i64
+// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
+// CHECK: }
+// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
+// CHECK: return %[[VAL_15]] : tensor<32xi64>
+// CHECK: }
+func @ashrbyc(%arga: tensor<32xi64, #SV>,
+              %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %c = constant 2 : i64
+  %0 = linalg.generic #traitc
+     ins(%arga: tensor<32xi64, #SV>)
+    outs(%argx: tensor<32xi64>) {
+    ^bb(%a: i64, %x: i64):
+      %0 = shift_right_signed %a, %c : i64
+      linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL: func @lsrbyc(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK: %[[VAL_2:.*]] = constant 2 : i64
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
+// CHECK: %[[VAL_14:.*]] = shift_right_unsigned %[[VAL_13]], %[[VAL_2]] : i64
+// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
+// CHECK: }
+// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
+// CHECK: return %[[VAL_15]] : tensor<32xi64>
+// CHECK: }
+func @lsrbyc(%arga: tensor<32xi64, #SV>,
+             %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %c = constant 2 : i64
+  %0 = linalg.generic #traitc
+     ins(%arga: tensor<32xi64, #SV>)
+    outs(%argx: tensor<32xi64>) {
+    ^bb(%a: i64, %x: i64):
+      %0 = shift_right_unsigned %a, %c : i64
+      linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL: func @lslbyc(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK: %[[VAL_2:.*]] = constant 2 : i64
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
+// CHECK: %[[VAL_14:.*]] = shift_left %[[VAL_13]], %[[VAL_2]] : i64
+// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
+// CHECK: }
+// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
+// CHECK: return %[[VAL_15]] : tensor<32xi64>
+// CHECK: }
+func @lslbyc(%arga: tensor<32xi64, #SV>,
+             %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %c = constant 2 : i64
+  %0 = linalg.generic #traitc
+     ins(%arga: tensor<32xi64, #SV>)
+    outs(%argx: tensor<32xi64>) {
+    ^bb(%a: i64, %x: i64):
+      %0 = shift_left %a, %c : i64
+      linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+