forked from OSchip/llvm-project
[mlir][sparse] Support more complex operations.
Add complex operations abs, neg, sin, log1p, sub and div. Add test cases. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D126027
This commit is contained in:
parent
ecf5b78053
commit
d390035b46
|
@ -31,14 +31,18 @@ enum Kind {
|
|||
kIndex,
|
||||
// Unary operations.
|
||||
kAbsF,
|
||||
kAbsC,
|
||||
kCeilF,
|
||||
kFloorF,
|
||||
kSqrtF,
|
||||
kExpm1F,
|
||||
kLog1pF,
|
||||
kLog1pC,
|
||||
kSinF,
|
||||
kSinC,
|
||||
kTanhF,
|
||||
kNegF,
|
||||
kNegC,
|
||||
kNegI,
|
||||
kTruncF,
|
||||
kExtF,
|
||||
|
@ -60,12 +64,14 @@ enum Kind {
|
|||
kMulC,
|
||||
kMulI,
|
||||
kDivF,
|
||||
kDivC, // complex
|
||||
kDivS, // signed
|
||||
kDivU, // unsigned
|
||||
kAddF,
|
||||
kAddC,
|
||||
kAddI,
|
||||
kSubF,
|
||||
kSubC,
|
||||
kSubI,
|
||||
kAndI,
|
||||
kOrI,
|
||||
|
|
|
@ -37,14 +37,18 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
|
|||
index = x;
|
||||
break;
|
||||
case kAbsF:
|
||||
case kAbsC:
|
||||
case kCeilF:
|
||||
case kFloorF:
|
||||
case kSqrtF:
|
||||
case kExpm1F:
|
||||
case kLog1pF:
|
||||
case kLog1pC:
|
||||
case kSinF:
|
||||
case kSinC:
|
||||
case kTanhF:
|
||||
case kNegF:
|
||||
case kNegC:
|
||||
case kNegI:
|
||||
case kCIm:
|
||||
case kCRe:
|
||||
|
@ -151,6 +155,8 @@ unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
|
|||
// TODO: move this if-else logic into buildLattices
|
||||
if (kind == kSubF)
|
||||
s1 = mapSet(kNegF, s1);
|
||||
else if (kind == kSubC)
|
||||
s1 = mapSet(kNegC, s1);
|
||||
else if (kind == kSubI)
|
||||
s1 = mapSet(kNegI, s1);
|
||||
// Followed by all in s1.
|
||||
|
@ -274,14 +280,18 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
|||
case kTensor:
|
||||
return tensorExps[e].tensor == t;
|
||||
case kAbsF:
|
||||
case kAbsC:
|
||||
case kCeilF:
|
||||
case kFloorF:
|
||||
case kSqrtF:
|
||||
case kExpm1F:
|
||||
case kLog1pF:
|
||||
case kLog1pC:
|
||||
case kSinF:
|
||||
case kSinC:
|
||||
case kTanhF:
|
||||
case kNegF:
|
||||
case kNegC:
|
||||
case kNegI:
|
||||
case kTruncF:
|
||||
case kExtF:
|
||||
|
@ -298,6 +308,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
|||
case kBitCast:
|
||||
return isSingleCondition(t, tensorExps[e].children.e0);
|
||||
case kDivF: // note: x / c only
|
||||
case kDivC:
|
||||
case kDivS:
|
||||
case kDivU:
|
||||
assert(!maybeZero(tensorExps[e].children.e1));
|
||||
|
@ -342,6 +353,7 @@ static const char *kindToOpSymbol(Kind kind) {
|
|||
case kIndex:
|
||||
return "index";
|
||||
case kAbsF:
|
||||
case kAbsC:
|
||||
return "abs";
|
||||
case kCeilF:
|
||||
return "ceil";
|
||||
|
@ -352,13 +364,15 @@ static const char *kindToOpSymbol(Kind kind) {
|
|||
case kExpm1F:
|
||||
return "expm1";
|
||||
case kLog1pF:
|
||||
case kLog1pC:
|
||||
return "log1p";
|
||||
case kSinF:
|
||||
case kSinC:
|
||||
return "sin";
|
||||
case kTanhF:
|
||||
return "tanh";
|
||||
case kNegF:
|
||||
return "-";
|
||||
case kNegC:
|
||||
case kNegI:
|
||||
return "-";
|
||||
case kTruncF:
|
||||
|
@ -386,6 +400,7 @@ static const char *kindToOpSymbol(Kind kind) {
|
|||
case kMulI:
|
||||
return "*";
|
||||
case kDivF:
|
||||
case kDivC:
|
||||
case kDivS:
|
||||
case kDivU:
|
||||
return "/";
|
||||
|
@ -394,6 +409,7 @@ static const char *kindToOpSymbol(Kind kind) {
|
|||
case kAddI:
|
||||
return "+";
|
||||
case kSubF:
|
||||
case kSubC:
|
||||
case kSubI:
|
||||
return "-";
|
||||
case kAndI:
|
||||
|
@ -533,6 +549,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
return s;
|
||||
}
|
||||
case kAbsF:
|
||||
case kAbsC:
|
||||
case kCeilF:
|
||||
case kCIm:
|
||||
case kCRe:
|
||||
|
@ -540,9 +557,12 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
case kSqrtF:
|
||||
case kExpm1F:
|
||||
case kLog1pF:
|
||||
case kLog1pC:
|
||||
case kSinF:
|
||||
case kSinC:
|
||||
case kTanhF:
|
||||
case kNegF:
|
||||
case kNegC:
|
||||
case kNegI:
|
||||
case kTruncF:
|
||||
case kExtF:
|
||||
|
@ -607,6 +627,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
buildLattices(tensorExps[e].children.e0, i),
|
||||
buildLattices(tensorExps[e].children.e1, i));
|
||||
case kDivF:
|
||||
case kDivC:
|
||||
case kDivS:
|
||||
case kDivU:
|
||||
// A division is tricky, since 0/0, 0/c, c/0 all have
|
||||
|
@ -630,6 +651,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
case kAddC:
|
||||
case kAddI:
|
||||
case kSubF:
|
||||
case kSubC:
|
||||
case kSubI:
|
||||
case kOrI:
|
||||
case kXorI:
|
||||
|
@ -696,6 +718,11 @@ Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
|
|||
/// Only returns false if we are certain this is a nonzero.
|
||||
bool Merger::maybeZero(unsigned e) const {
|
||||
if (tensorExps[e].kind == kInvariant) {
|
||||
if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
|
||||
ArrayAttr arrayAttr = c.getValue();
|
||||
return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
|
||||
arrayAttr[1].cast<FloatAttr>().getValue().isZero(); // element 1 = imaginary part; the constant is zero only if both parts are
|
||||
}
|
||||
if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
|
||||
return c.value() == 0;
|
||||
if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
|
||||
|
@ -750,6 +777,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
|
|||
unsigned e = x.getValue();
|
||||
if (isa<math::AbsOp>(def))
|
||||
return addExp(kAbsF, e);
|
||||
if (isa<complex::AbsOp>(def))
|
||||
return addExp(kAbsC, e);
|
||||
if (isa<math::CeilOp>(def))
|
||||
return addExp(kCeilF, e);
|
||||
if (isa<math::FloorOp>(def))
|
||||
|
@ -760,12 +789,18 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
|
|||
return addExp(kExpm1F, e);
|
||||
if (isa<math::Log1pOp>(def))
|
||||
return addExp(kLog1pF, e);
|
||||
if (isa<complex::Log1pOp>(def))
|
||||
return addExp(kLog1pC, e);
|
||||
if (isa<math::SinOp>(def))
|
||||
return addExp(kSinF, e);
|
||||
if (isa<complex::SinOp>(def))
|
||||
return addExp(kSinC, e);
|
||||
if (isa<math::TanhOp>(def))
|
||||
return addExp(kTanhF, e);
|
||||
if (isa<arith::NegFOp>(def))
|
||||
return addExp(kNegF, e); // no negi in std
|
||||
if (isa<complex::NegOp>(def))
|
||||
return addExp(kNegC, e);
|
||||
if (isa<arith::TruncFOp>(def))
|
||||
return addExp(kTruncF, e, v);
|
||||
if (isa<arith::ExtFOp>(def))
|
||||
|
@ -813,6 +848,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
|
|||
return addExp(kMulI, e0, e1);
|
||||
if (isa<arith::DivFOp>(def) && !maybeZero(e1))
|
||||
return addExp(kDivF, e0, e1);
|
||||
if (isa<complex::DivOp>(def) && !maybeZero(e1))
|
||||
return addExp(kDivC, e0, e1);
|
||||
if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
|
||||
return addExp(kDivS, e0, e1);
|
||||
if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
|
||||
|
@ -825,6 +862,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
|
|||
return addExp(kAddI, e0, e1);
|
||||
if (isa<arith::SubFOp>(def))
|
||||
return addExp(kSubF, e0, e1);
|
||||
if (isa<complex::SubOp>(def))
|
||||
return addExp(kSubC, e0, e1);
|
||||
if (isa<arith::SubIOp>(def))
|
||||
return addExp(kSubI, e0, e1);
|
||||
if (isa<arith::AndIOp>(def))
|
||||
|
@ -902,6 +941,11 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
|||
// Unary ops.
|
||||
case kAbsF:
|
||||
return rewriter.create<math::AbsOp>(loc, v0);
|
||||
case kAbsC: {
|
||||
auto type = v0.getType().template cast<ComplexType>();
|
||||
auto eltType = type.getElementType().template cast<FloatType>();
|
||||
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
|
||||
}
|
||||
case kCeilF:
|
||||
return rewriter.create<math::CeilOp>(loc, v0);
|
||||
case kFloorF:
|
||||
|
@ -912,12 +956,18 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
|||
return rewriter.create<math::ExpM1Op>(loc, v0);
|
||||
case kLog1pF:
|
||||
return rewriter.create<math::Log1pOp>(loc, v0);
|
||||
case kLog1pC:
|
||||
return rewriter.create<complex::Log1pOp>(loc, v0);
|
||||
case kSinF:
|
||||
return rewriter.create<math::SinOp>(loc, v0);
|
||||
case kSinC:
|
||||
return rewriter.create<complex::SinOp>(loc, v0);
|
||||
case kTanhF:
|
||||
return rewriter.create<math::TanhOp>(loc, v0);
|
||||
case kNegF:
|
||||
return rewriter.create<arith::NegFOp>(loc, v0);
|
||||
case kNegC:
|
||||
return rewriter.create<complex::NegOp>(loc, v0);
|
||||
case kNegI: // no negi in std
|
||||
return rewriter.create<arith::SubIOp>(
|
||||
loc,
|
||||
|
@ -964,6 +1014,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
|||
return rewriter.create<arith::MulIOp>(loc, v0, v1);
|
||||
case kDivF:
|
||||
return rewriter.create<arith::DivFOp>(loc, v0, v1);
|
||||
case kDivC:
|
||||
return rewriter.create<complex::DivOp>(loc, v0, v1);
|
||||
case kDivS:
|
||||
return rewriter.create<arith::DivSIOp>(loc, v0, v1);
|
||||
case kDivU:
|
||||
|
@ -976,6 +1028,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
|||
return rewriter.create<arith::AddIOp>(loc, v0, v1);
|
||||
case kSubF:
|
||||
return rewriter.create<arith::SubFOp>(loc, v0, v1);
|
||||
case kSubC:
|
||||
return rewriter.create<complex::SubOp>(loc, v0, v1);
|
||||
case kSubI:
|
||||
return rewriter.create<arith::SubIOp>(loc, v0, v1);
|
||||
case kAndI:
|
||||
|
|
|
@ -0,0 +1,179 @@
|
|||
// RUN: mlir-opt %s --sparse-compiler | \
|
||||
// RUN: mlir-cpu-runner \
|
||||
// RUN: -e entry -entry-point-result=void \
|
||||
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
|
||||
// RUN: FileCheck %s
|
||||
|
||||
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
|
||||
|
||||
#trait_op1 = {
|
||||
indexing_maps = [
|
||||
affine_map<(i) -> (i)>, // a (in)
|
||||
affine_map<(i) -> (i)> // x (out)
|
||||
],
|
||||
iterator_types = ["parallel"],
|
||||
doc = "x(i) = OP a(i)"
|
||||
}
|
||||
|
||||
#trait_op2 = {
|
||||
indexing_maps = [
|
||||
affine_map<(i) -> (i)>, // a (in)
|
||||
affine_map<(i) -> (i)>, // b (in)
|
||||
affine_map<(i) -> (i)> // x (out)
|
||||
],
|
||||
iterator_types = ["parallel"],
|
||||
doc = "x(i) = a(i) OP b(i)"
|
||||
}
|
||||
|
||||
module {
|
||||
// Kernel x(i) = a(i) - (-b(i)) on two sparse complex vectors: each b element
// is negated with complex.neg and the result is subtracted from a with
// complex.sub. The output vector is initialized with the dynamic size of %arga.
func.func @cops(%arga: tensor<?xcomplex<f64>, #SparseVector>,
|
||||
%argb: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
-> tensor<?xcomplex<f64>, #SparseVector> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%0 = linalg.generic #trait_op2
|
||||
ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
|
||||
tensor<?xcomplex<f64>, #SparseVector>)
|
||||
outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
|
||||
^bb(%a: complex<f64>, %b: complex<f64>, %x: complex<f64>):
|
||||
%1 = complex.neg %b : complex<f64>
|
||||
%2 = complex.sub %a, %1 : complex<f64>
|
||||
linalg.yield %2 : complex<f64>
|
||||
} -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
return %0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
}
|
||||
|
||||
// Kernel x(i) = sin(a(i)) on a sparse complex vector, using complex.sin.
// The output vector is initialized with the dynamic size of %arga.
func.func @csin(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
-> tensor<?xcomplex<f64>, #SparseVector> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%0 = linalg.generic #trait_op1
|
||||
ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
|
||||
^bb(%a: complex<f64>, %x: complex<f64>):
|
||||
%1 = complex.sin %a : complex<f64>
|
||||
linalg.yield %1 : complex<f64>
|
||||
} -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
return %0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
}
|
||||
|
||||
// Kernel x(i) = a(i) / c on a sparse complex vector, where c is the complex
// constant (2.0, 0.0) defined outside the linalg.generic region, so the
// divisor is a loop-invariant value rather than a tensor operand.
func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
-> tensor<?xcomplex<f64>, #SparseVector> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%c = complex.constant [2.0 : f64, 0.0 : f64] : complex<f64>
|
||||
%0 = linalg.generic #trait_op1
|
||||
ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
|
||||
^bb(%a: complex<f64>, %x: complex<f64>):
|
||||
%1 = complex.div %a, %c : complex<f64>
|
||||
linalg.yield %1 : complex<f64>
|
||||
} -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
return %0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
}
|
||||
|
||||
// Kernel x(i) = |a(i)| using complex.abs: the input is a sparse complex<f64>
// vector and the output is a sparse f64 vector of the same dynamic size.
func.func @cabs(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
-> tensor<?xf64, #SparseVector> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
|
||||
%0 = linalg.generic #trait_op1
|
||||
ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
outs(%xv: tensor<?xf64, #SparseVector>) {
|
||||
^bb(%a: complex<f64>, %x: f64):
|
||||
%1 = complex.abs %a : complex<f64>
|
||||
linalg.yield %1 : f64
|
||||
} -> tensor<?xf64, #SparseVector>
|
||||
return %0 : tensor<?xf64, #SparseVector>
|
||||
}
|
||||
|
||||
// Prints the first %d entries of the values array of a sparse complex vector.
// Each entry is printed as two separate f64 lines: real part, then imaginary.
func.func @dumpc(%arg0: tensor<?xcomplex<f64>, #SparseVector>, %d: index) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f64>, #SparseVector> to memref<?xcomplex<f64>>
|
||||
scf.for %i = %c0 to %d step %c1 {
|
||||
%v = memref.load %mem[%i] : memref<?xcomplex<f64>>
|
||||
%real = complex.re %v : complex<f64>
|
||||
%imag = complex.im %v : complex<f64>
|
||||
vector.print %real : f64
|
||||
vector.print %imag : f64
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Prints the first three entries of the values array of a sparse f64 vector
// as a single vector<3xf64>. %d0 (0.0) is the padding operand passed to
// vector.transfer_read.
func.func @dumpf(%arg0: tensor<?xf64, #SparseVector>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%d0 = arith.constant 0.0 : f64
|
||||
%values = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
|
||||
%0 = vector.transfer_read %values[%c0], %d0: memref<?xf64>, vector<3xf64>
|
||||
vector.print %0 : vector<3xf64>
|
||||
return
|
||||
}
|
||||
|
||||
// Driver method to call and verify complex kernels.
|
||||
// Builds two sparse complex vectors of length 32, runs the @cops, @csin,
// @cdiv and @cabs kernels on them, dumps each result for FileCheck, and
// releases all sparse tensors.
func.func @entry() {
|
||||
// Setup sparse vectors.
|
||||
%v1 = arith.constant sparse<
|
||||
[ [0], [28], [31] ],
|
||||
[ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f64>>
|
||||
%v2 = arith.constant sparse<
|
||||
[ [1], [28], [31] ],
|
||||
[ (1.0, 0.0), (-2.0, 0.0), (3.0, 0.0) ] > : tensor<32xcomplex<f64>>
|
||||
%sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
|
||||
%sv2 = sparse_tensor.convert %v2 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
|
||||
|
||||
// Call sparse vector kernels.
|
||||
%0 = call @cops(%sv1, %sv2)
|
||||
: (tensor<?xcomplex<f64>, #SparseVector>,
|
||||
tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
%1 = call @csin(%sv1)
|
||||
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
%2 = call @cdiv(%sv1)
|
||||
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
%3 = call @cabs(%sv1)
|
||||
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
|
||||
|
||||
//
|
||||
// Verify the results.
|
||||
//
|
||||
%d3 = arith.constant 3 : index
|
||||
%d4 = arith.constant 4 : index
|
||||
// @cops: four entries (%d4), each printed as real then imaginary; the
|
||||
// values equal a(i) + b(i) at indices 0, 1, 28 and 31.
// CHECK: -5.13
|
||||
// CHECK-NEXT: 2
|
||||
// CHECK-NEXT: 1
|
||||
// CHECK-NEXT: 0
|
||||
// CHECK-NEXT: 1
|
||||
// CHECK-NEXT: 4
|
||||
// CHECK-NEXT: 8
|
||||
// CHECK-NEXT: 6
|
||||
call @dumpc(%0, %d4) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
|
||||
// @csin: sin of the three entries of %sv1.
// CHECK-NEXT: 3.43887
|
||||
// CHECK-NEXT: 1.47097
|
||||
// CHECK-NEXT: 3.85374
|
||||
// CHECK-NEXT: -27.0168
|
||||
// CHECK-NEXT: -193.43
|
||||
// CHECK-NEXT: 57.2184
|
||||
call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
|
||||
// @cdiv: the three entries of %sv1 divided by (2.0, 0.0).
// CHECK-NEXT: -2.565
|
||||
// CHECK-NEXT: 1
|
||||
// CHECK-NEXT: 1.5
|
||||
// CHECK-NEXT: 2
|
||||
// CHECK-NEXT: 2.5
|
||||
// CHECK-NEXT: 3
|
||||
call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
|
||||
// @cabs: magnitudes of the three entries of %sv1.
// CHECK-NEXT: ( 5.50608, 5, 7.81025 )
|
||||
call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
|
||||
|
||||
// Release the resources.
|
||||
sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %sv2 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
|
||||
return
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue