[mlir][sparse] add restrictive versions of division support
Right now, we only accept x/c with nonzero c, since conceptually this can be treated as an x*(1/c) conjunction for both FP and INT as far as the lattice computations are concerned. The generated code keeps the actual division, though, to preserve precise semantics.

See discussion: https://llvm.discourse.group/t/sparse-tensors-in-mlir/3389/28

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D105731
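The rule this patch enforces can be summarized in a short sketch. This is illustrative only and not part of the patch: the helper name is hypothetical, the real check is the new Merger::maybeZero() guard used during expression building in the diff below, and the include path simply reflects where ConstantIntOp/ConstantFloatOp lived at the time of this commit.

#include "mlir/Dialect/StandardOps/IR/Ops.h"

// Illustrative sketch (hypothetical helper, modeled on Merger::maybeZero below).
// A division x/c is admitted only when the divisor is a constant that is
// provably nonzero, so lattice construction may treat it like the conjunction
// x*(1/c); any other SSA value is conservatively treated as "maybe zero" and
// the division is rejected during expression building.
static bool divisorProvablyNonzero(mlir::Value divisor) {
  if (auto c = divisor.getDefiningOp<mlir::ConstantIntOp>())
    return c.getValue() != 0;
  if (auto c = divisor.getDefiningOp<mlir::ConstantFloatOp>())
    return !c.getValue().isZero();
  return false; // unknown value: must assume it could be zero
}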
parent ec1cdee6aa
commit 622eb169f6
@@ -32,6 +32,9 @@ enum class Kind {
  // Operation.
  kMulF,
  kMulI,
  kDivF,
  kDivS, // signed
  kDivU, // unsigned
  kAddF,
  kAddI,
  kSubF,
@@ -197,6 +200,8 @@ public:
  Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);

private:
  bool maybeZero(unsigned e);

  /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
  Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value val);
@@ -646,6 +646,12 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
    return rewriter.create<MulFOp>(loc, v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(loc, v0, v1);
  case Kind::kDivF:
    return rewriter.create<DivFOp>(loc, v0, v1);
  case Kind::kDivS:
    return rewriter.create<SignedDivIOp>(loc, v0, v1);
  case Kind::kDivU:
    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(loc, v0, v1);
  case Kind::kAddI:
@@ -201,6 +201,10 @@ static char kindToOpSymbol(Kind kind) {
  case Kind::kMulF:
  case Kind::kMulI:
    return '*';
  case Kind::kDivF:
  case Kind::kDivS:
  case Kind::kDivU:
    return '/';
  case Kind::kAddF:
  case Kind::kAddI:
    return '+';
@@ -302,17 +306,51 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
  }
  case Kind::kMulF:
  case Kind::kMulI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, idx),
                    buildLattices(tensorExps[e].children.e1, idx));
  case Kind::kDivF:
  case Kind::kDivS:
  case Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, idx),
                    buildLattices(tensorExps[e].children.e1, idx));
  case Kind::kSubF:
  case Kind::kSubI:
    // Special case: 0-y is -y.
    if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
      return mapZero(kind, // maps to 0-y with just y's lattices
                     buildLattices(tensorExps[e].children.e1, idx));
    LLVM_FALLTHROUGH;
  case Kind::kAddF:
  case Kind::kAddI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, idx),
                    buildLattices(tensorExps[e].children.e1, idx));
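To make the division lattice above concrete, here is a small standalone illustration (hypothetical example, not part of the patch). Dividing a sparse vector by a provably nonzero constant only has to visit the stored entries, since the implicit zeros satisfy 0/c == 0; this is exactly why the restricted x/c case fits the conjunction (sparse) iteration, while division by another sparse operand would force a traversal of the full iteration space.

#include <cstdio>
#include <vector>

// Hypothetical standalone example, not part of the patch.
// Sparse vector a = [0, 3, 0, 6] stored in coordinate form; constant c = 2.
// Only the stored entries are divided; the implicit zeros remain zero.
int main() {
  std::vector<int> idx = {1, 3};        // positions of stored entries
  std::vector<double> val = {3.0, 6.0}; // their values
  const double c = 2.0;                 // provably nonzero divisor
  std::vector<double> x(4, 0.0);        // dense result; zeros stay zero
  for (size_t k = 0; k < idx.size(); ++k)
    x[idx[k]] = val[k] / c; // the actual division is kept, as in the codegen
  for (double v : x)
    std::printf("%g ", v);  // prints: 0 1.5 0 3
  std::printf("\n");
  return 0;
}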
@@ -325,6 +363,16 @@ Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  return buildTensorExp(op, yield->getOperand(0));
}

bool Merger::maybeZero(unsigned e) {
  if (tensorExps[e].kind == Kind::kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
      return c.getValue() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
      return c.getValue().isZero();
  }
  return true;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
@@ -357,6 +405,7 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
    }
  }
  // Construct binary operations if subexpressions can be built.
  // TODO: see buildLattices() for an explanation of rejecting certain divisions
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
@@ -367,6 +416,12 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
        return addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(Kind::kMulI, e0, e1);
      if (isa<DivFOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivF, e0, e1);
      if (isa<SignedDivIOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivS, e0, e1);
      if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivU, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
@@ -22,6 +22,15 @@
  doc = "x(i) = a(i) OP b(i)"
}

#traitc = {
  indexing_maps = [
    affine_map<(i) -> (i)>, // a
    affine_map<(i) -> (i)>  // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) OP c"
}

// CHECK-LABEL: func @neg(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
@@ -213,3 +222,38 @@ func @mul(%arga: tensor<32xf64, #SV>,
  } -> tensor<32xf64>
  return %0 : tensor<32xf64>
}

// CHECK-LABEL: func @divbyc(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
// CHECK: %[[VAL_2:.*]] = constant 2.000000e+00 : f64
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 1 : index
// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf64>
// CHECK: %[[VAL_14:.*]] = divf %[[VAL_13]], %[[VAL_2]] : f64
// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xf64>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xf64>
// CHECK: return %[[VAL_15]] : tensor<32xf64>
// CHECK: }
func @divbyc(%arga: tensor<32xf64, #SV>,
             %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
  %c = constant 2.0 : f64
  %0 = linalg.generic #traitc
     ins(%arga: tensor<32xf64, #SV>)
    outs(%argx: tensor<32xf64>) {
      ^bb(%a: f64, %x: f64):
        %0 = divf %a, %c : f64
        linalg.yield %0 : f64
  } -> tensor<32xf64>
  return %0 : tensor<32xf64>
}
@@ -13,6 +13,15 @@
  doc = "x(i) = a(i) OP b(i)"
}

#traitc = {
  indexing_maps = [
    affine_map<(i) -> (i)>, // a
    affine_map<(i) -> (i)>  // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) OP c"
}

// CHECK-LABEL: func @add(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64>,
@@ -171,3 +180,71 @@ func @mul(%arga: tensor<32xi64, #SV>,
  } -> tensor<32xi64>
  return %0 : tensor<32xi64>
}

// CHECK-LABEL: func @divsbyc(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
// CHECK: %[[VAL_2:.*]] = constant 2 : i64
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 1 : index
// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
// CHECK: %[[VAL_14:.*]] = divi_signed %[[VAL_13]], %[[VAL_2]] : i64
// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
// CHECK: return %[[VAL_15]] : tensor<32xi64>
// CHECK: }
func @divsbyc(%arga: tensor<32xi64, #SV>,
              %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
  %c = constant 2 : i64
  %0 = linalg.generic #traitc
     ins(%arga: tensor<32xi64, #SV>)
    outs(%argx: tensor<32xi64>) {
      ^bb(%a: i64, %x: i64):
        %0 = divi_signed %a, %c : i64
        linalg.yield %0 : i64
  } -> tensor<32xi64>
  return %0 : tensor<32xi64>
}

// CHECK-LABEL: func @divubyc(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
// CHECK: %[[VAL_2:.*]] = constant 2 : i64
// CHECK: %[[VAL_3:.*]] = constant 0 : index
// CHECK: %[[VAL_4:.*]] = constant 1 : index
// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
// CHECK: %[[VAL_14:.*]] = divi_unsigned %[[VAL_13]], %[[VAL_2]] : i64
// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
// CHECK: return %[[VAL_15]] : tensor<32xi64>
// CHECK: }
func @divubyc(%arga: tensor<32xi64, #SV>,
              %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
  %c = constant 2 : i64
  %0 = linalg.generic #traitc
     ins(%arga: tensor<32xi64, #SV>)
    outs(%argx: tensor<32xi64>) {
      ^bb(%a: i64, %x: i64):
        %0 = divi_unsigned %a, %c : i64
        linalg.yield %0 : i64
  } -> tensor<32xi64>
  return %0 : tensor<32xi64>
}