[mlir][sparse] Support more complex operations.

Add complex operations abs, neg, sin, log1p, sub and div.

Add test cases.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D126027
This commit is contained in:
Bixia Zheng 2022-05-19 15:18:50 -07:00 committed by bixia1
parent ecf5b78053
commit d390035b46
3 changed files with 240 additions and 1 deletion

View File

@ -31,14 +31,18 @@ enum Kind {
kIndex,
// Unary operations.
kAbsF,
kAbsC,
kCeilF,
kFloorF,
kSqrtF,
kExpm1F,
kLog1pF,
kLog1pC,
kSinF,
kSinC,
kTanhF,
kNegF,
kNegC,
kNegI,
kTruncF,
kExtF,
@ -60,12 +64,14 @@ enum Kind {
kMulC,
kMulI,
kDivF,
kDivC, // complex
kDivS, // signed
kDivU, // unsigned
kAddF,
kAddC,
kAddI,
kSubF,
kSubC,
kSubI,
kAndI,
kOrI,

View File

@ -37,14 +37,18 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
index = x;
break;
case kAbsF:
case kAbsC:
case kCeilF:
case kFloorF:
case kSqrtF:
case kExpm1F:
case kLog1pF:
case kLog1pC:
case kSinF:
case kSinC:
case kTanhF:
case kNegF:
case kNegC:
case kNegI:
case kCIm:
case kCRe:
@ -151,6 +155,8 @@ unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
// TODO: move this if-else logic into buildLattices
if (kind == kSubF)
s1 = mapSet(kNegF, s1);
else if (kind == kSubC)
s1 = mapSet(kNegC, s1);
else if (kind == kSubI)
s1 = mapSet(kNegI, s1);
// Followed by all in s1.
@ -274,14 +280,18 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kTensor:
return tensorExps[e].tensor == t;
case kAbsF:
case kAbsC:
case kCeilF:
case kFloorF:
case kSqrtF:
case kExpm1F:
case kLog1pF:
case kLog1pC:
case kSinF:
case kSinC:
case kTanhF:
case kNegF:
case kNegC:
case kNegI:
case kTruncF:
case kExtF:
@ -298,6 +308,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kBitCast:
return isSingleCondition(t, tensorExps[e].children.e0);
case kDivF: // note: x / c only
case kDivC:
case kDivS:
case kDivU:
assert(!maybeZero(tensorExps[e].children.e1));
@ -342,6 +353,7 @@ static const char *kindToOpSymbol(Kind kind) {
case kIndex:
return "index";
case kAbsF:
case kAbsC:
return "abs";
case kCeilF:
return "ceil";
@ -352,13 +364,15 @@ static const char *kindToOpSymbol(Kind kind) {
case kExpm1F:
return "expm1";
case kLog1pF:
case kLog1pC:
return "log1p";
case kSinF:
case kSinC:
return "sin";
case kTanhF:
return "tanh";
case kNegF:
return "-";
case kNegC:
case kNegI:
return "-";
case kTruncF:
@ -386,6 +400,7 @@ static const char *kindToOpSymbol(Kind kind) {
case kMulI:
return "*";
case kDivF:
case kDivC:
case kDivS:
case kDivU:
return "/";
@ -394,6 +409,7 @@ static const char *kindToOpSymbol(Kind kind) {
case kAddI:
return "+";
case kSubF:
case kSubC:
case kSubI:
return "-";
case kAndI:
@ -533,6 +549,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
return s;
}
case kAbsF:
case kAbsC:
case kCeilF:
case kCIm:
case kCRe:
@ -540,9 +557,12 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
case kSqrtF:
case kExpm1F:
case kLog1pF:
case kLog1pC:
case kSinF:
case kSinC:
case kTanhF:
case kNegF:
case kNegC:
case kNegI:
case kTruncF:
case kExtF:
@ -607,6 +627,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
buildLattices(tensorExps[e].children.e0, i),
buildLattices(tensorExps[e].children.e1, i));
case kDivF:
case kDivC:
case kDivS:
case kDivU:
// A division is tricky, since 0/0, 0/c, c/0 all have
@ -630,6 +651,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
case kAddC:
case kAddI:
case kSubF:
case kSubC:
case kSubI:
case kOrI:
case kXorI:
@ -696,6 +718,11 @@ Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    // A complex constant is certainly nonzero iff at least one of its
    // components is nonzero, so report "maybe zero" only when BOTH the
    // real part (arrayAttr[0]) and the imaginary part (arrayAttr[1])
    // are zero. Checking arrayAttr[0] twice would wrongly treat e.g.
    // (0.0, 5.0) as possibly zero and pessimize division handling.
    if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
             arrayAttr[1].cast<FloatAttr>().getValue().isZero();
    }
    // Integer constant: zero exactly when its value is 0.
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    // Floating-point constant: handled analogously below.
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
@ -750,6 +777,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
unsigned e = x.getValue();
if (isa<math::AbsOp>(def))
return addExp(kAbsF, e);
if (isa<complex::AbsOp>(def))
return addExp(kAbsC, e);
if (isa<math::CeilOp>(def))
return addExp(kCeilF, e);
if (isa<math::FloorOp>(def))
@ -760,12 +789,18 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kExpm1F, e);
if (isa<math::Log1pOp>(def))
return addExp(kLog1pF, e);
if (isa<complex::Log1pOp>(def))
return addExp(kLog1pC, e);
if (isa<math::SinOp>(def))
return addExp(kSinF, e);
if (isa<complex::SinOp>(def))
return addExp(kSinC, e);
if (isa<math::TanhOp>(def))
return addExp(kTanhF, e);
if (isa<arith::NegFOp>(def))
return addExp(kNegF, e); // no negi in std
if (isa<complex::NegOp>(def))
return addExp(kNegC, e);
if (isa<arith::TruncFOp>(def))
return addExp(kTruncF, e, v);
if (isa<arith::ExtFOp>(def))
@ -813,6 +848,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kMulI, e0, e1);
if (isa<arith::DivFOp>(def) && !maybeZero(e1))
return addExp(kDivF, e0, e1);
if (isa<complex::DivOp>(def) && !maybeZero(e1))
return addExp(kDivC, e0, e1);
if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
return addExp(kDivS, e0, e1);
if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
@ -825,6 +862,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kAddI, e0, e1);
if (isa<arith::SubFOp>(def))
return addExp(kSubF, e0, e1);
if (isa<complex::SubOp>(def))
return addExp(kSubC, e0, e1);
if (isa<arith::SubIOp>(def))
return addExp(kSubI, e0, e1);
if (isa<arith::AndIOp>(def))
@ -902,6 +941,11 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
// Unary ops.
case kAbsF:
return rewriter.create<math::AbsOp>(loc, v0);
case kAbsC: {
auto type = v0.getType().template cast<ComplexType>();
auto eltType = type.getElementType().template cast<FloatType>();
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
}
case kCeilF:
return rewriter.create<math::CeilOp>(loc, v0);
case kFloorF:
@ -912,12 +956,18 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<math::ExpM1Op>(loc, v0);
case kLog1pF:
return rewriter.create<math::Log1pOp>(loc, v0);
case kLog1pC:
return rewriter.create<complex::Log1pOp>(loc, v0);
case kSinF:
return rewriter.create<math::SinOp>(loc, v0);
case kSinC:
return rewriter.create<complex::SinOp>(loc, v0);
case kTanhF:
return rewriter.create<math::TanhOp>(loc, v0);
case kNegF:
return rewriter.create<arith::NegFOp>(loc, v0);
case kNegC:
return rewriter.create<complex::NegOp>(loc, v0);
case kNegI: // no negi in std
return rewriter.create<arith::SubIOp>(
loc,
@ -964,6 +1014,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::MulIOp>(loc, v0, v1);
case kDivF:
return rewriter.create<arith::DivFOp>(loc, v0, v1);
case kDivC:
return rewriter.create<complex::DivOp>(loc, v0, v1);
case kDivS:
return rewriter.create<arith::DivSIOp>(loc, v0, v1);
case kDivU:
@ -976,6 +1028,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::AddIOp>(loc, v0, v1);
case kSubF:
return rewriter.create<arith::SubFOp>(loc, v0, v1);
case kSubC:
return rewriter.create<complex::SubOp>(loc, v0, v1);
case kSubI:
return rewriter.create<arith::SubIOp>(loc, v0, v1);
case kAndI:

View File

@ -0,0 +1,179 @@
// RUN: mlir-opt %s --sparse-compiler | \
// RUN: mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
#trait_op1 = {
indexing_maps = [
affine_map<(i) -> (i)>, // a (in)
affine_map<(i) -> (i)> // x (out)
],
iterator_types = ["parallel"],
doc = "x(i) = OP a(i)"
}
#trait_op2 = {
indexing_maps = [
affine_map<(i) -> (i)>, // a (in)
affine_map<(i) -> (i)>, // b (in)
affine_map<(i) -> (i)> // x (out)
],
iterator_types = ["parallel"],
doc = "x(i) = a(i) OP b(i)"
}
// Integration test module: exercises the complex-number operations that
// the sparse compiler now supports (complex.neg, complex.sub, complex.sin,
// complex.div, complex.abs) on sparse vectors, and verifies the runtime
// results via the FileCheck directives in @entry.
module {
// Binary kernel: x(i) = a(i) - (-b(i)); exercises complex.neg + complex.sub
// together so both the unary and binary subtraction paths are covered.
func.func @cops(%arga: tensor<?xcomplex<f64>, #SparseVector>,
%argb: tensor<?xcomplex<f64>, #SparseVector>)
-> tensor<?xcomplex<f64>, #SparseVector> {
%c0 = arith.constant 0 : index
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
%0 = linalg.generic #trait_op2
ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
tensor<?xcomplex<f64>, #SparseVector>)
outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
^bb(%a: complex<f64>, %b: complex<f64>, %x: complex<f64>):
%1 = complex.neg %b : complex<f64>
%2 = complex.sub %a, %1 : complex<f64>
linalg.yield %2 : complex<f64>
} -> tensor<?xcomplex<f64>, #SparseVector>
return %0 : tensor<?xcomplex<f64>, #SparseVector>
}
// Unary kernel: x(i) = sin(a(i)); exercises complex.sin.
func.func @csin(%arga: tensor<?xcomplex<f64>, #SparseVector>)
-> tensor<?xcomplex<f64>, #SparseVector> {
%c0 = arith.constant 0 : index
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
%0 = linalg.generic #trait_op1
ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
^bb(%a: complex<f64>, %x: complex<f64>):
%1 = complex.sin %a : complex<f64>
linalg.yield %1 : complex<f64>
} -> tensor<?xcomplex<f64>, #SparseVector>
return %0 : tensor<?xcomplex<f64>, #SparseVector>
}
// Unary kernel: x(i) = a(i) / (2.0 + 0.0i); exercises complex.div with a
// constant (certainly nonzero) divisor so the x/c single-condition path in
// the merger is taken.
func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>)
-> tensor<?xcomplex<f64>, #SparseVector> {
%c0 = arith.constant 0 : index
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
%c = complex.constant [2.0 : f64, 0.0 : f64] : complex<f64>
%0 = linalg.generic #trait_op1
ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
^bb(%a: complex<f64>, %x: complex<f64>):
%1 = complex.div %a, %c : complex<f64>
linalg.yield %1 : complex<f64>
} -> tensor<?xcomplex<f64>, #SparseVector>
return %0 : tensor<?xcomplex<f64>, #SparseVector>
}
// Unary kernel: x(i) = |a(i)|; exercises complex.abs, whose result element
// type (f64) differs from the complex input element type.
func.func @cabs(%arga: tensor<?xcomplex<f64>, #SparseVector>)
-> tensor<?xf64, #SparseVector> {
%c0 = arith.constant 0 : index
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
%0 = linalg.generic #trait_op1
ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
outs(%xv: tensor<?xf64, #SparseVector>) {
^bb(%a: complex<f64>, %x: f64):
%1 = complex.abs %a : complex<f64>
linalg.yield %1 : f64
} -> tensor<?xf64, #SparseVector>
return %0 : tensor<?xf64, #SparseVector>
}
// Prints the first %d stored complex values of a sparse vector as
// alternating real/imaginary lines (matched by the CHECK lines in @entry).
func.func @dumpc(%arg0: tensor<?xcomplex<f64>, #SparseVector>, %d: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f64>, #SparseVector> to memref<?xcomplex<f64>>
scf.for %i = %c0 to %d step %c1 {
%v = memref.load %mem[%i] : memref<?xcomplex<f64>>
%real = complex.re %v : complex<f64>
%imag = complex.im %v : complex<f64>
vector.print %real : f64
vector.print %imag : f64
}
return
}
// Prints the stored f64 values of a sparse vector as one vector<3xf64>
// line; assumes exactly 3 stored entries (padded with 0.0 otherwise).
func.func @dumpf(%arg0: tensor<?xf64, #SparseVector>) {
%c0 = arith.constant 0 : index
%d0 = arith.constant 0.0 : f64
%values = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
%0 = vector.transfer_read %values[%c0], %d0: memref<?xf64>, vector<3xf64>
vector.print %0 : vector<3xf64>
return
}
// Driver method to call and verify complex kernels.
func.func @entry() {
// Setup sparse vectors.
// %v1 and %v2 overlap at indices 28 and 31; index 0 is only in %v1 and
// index 1 only in %v2, so @cops (a disjunction) yields 4 stored entries.
%v1 = arith.constant sparse<
[ [0], [28], [31] ],
[ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f64>>
%v2 = arith.constant sparse<
[ [1], [28], [31] ],
[ (1.0, 0.0), (-2.0, 0.0), (3.0, 0.0) ] > : tensor<32xcomplex<f64>>
%sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
%sv2 = sparse_tensor.convert %v2 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
// Call sparse vector kernels.
%0 = call @cops(%sv1, %sv2)
: (tensor<?xcomplex<f64>, #SparseVector>,
tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
%1 = call @csin(%sv1)
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
%2 = call @cdiv(%sv1)
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
%3 = call @cabs(%sv1)
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
//
// Verify the results.
//
%d3 = arith.constant 3 : index
%d4 = arith.constant 4 : index
// CHECK: -5.13
// CHECK-NEXT: 2
// CHECK-NEXT: 1
// CHECK-NEXT: 0
// CHECK-NEXT: 1
// CHECK-NEXT: 4
// CHECK-NEXT: 8
// CHECK-NEXT: 6
call @dumpc(%0, %d4) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
// CHECK-NEXT: 3.43887
// CHECK-NEXT: 1.47097
// CHECK-NEXT: 3.85374
// CHECK-NEXT: -27.0168
// CHECK-NEXT: -193.43
// CHECK-NEXT: 57.2184
call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
// CHECK-NEXT: -2.565
// CHECK-NEXT: 1
// CHECK-NEXT: 1.5
// CHECK-NEXT: 2
// CHECK-NEXT: 2.5
// CHECK-NEXT: 3
call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
// CHECK-NEXT: ( 5.50608, 5, 7.81025 )
call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
// Release the resources.
sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
sparse_tensor.release %sv2 : tensor<?xcomplex<f64>, #SparseVector>
sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
return
}
}