[mlir][sparse] Add more complex operations.

Support the complex operations sqrt, expm1, and tanh.

Add integration tests for the new operations.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D126393
Author: bixia1
Date:   2022-05-24 16:07:31 -07:00
Commit: a14057d4bd (parent 338e76f8ee)

3 changed files with 113 additions and 5 deletions

--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

@@ -35,12 +35,15 @@ enum Kind {
   kCeilF,
   kFloorF,
   kSqrtF,
+  kSqrtC,
   kExpm1F,
+  kExpm1C,
   kLog1pF,
   kLog1pC,
   kSinF,
   kSinC,
   kTanhF,
+  kTanhC,
   kNegF,
   kNegC,
   kNegI,
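A note on the header change: each unary operation gets one Kind tag per element-type family, an F suffix for the floating-point form (math dialect) and a C suffix for the complex form (complex dialect), so the patch only needs to add the three missing C tags. Condensed, with explanatory comments that are mine rather than the patch's:

    // Unary op kinds come in F/C pairs: a math dialect op on floats and
    // the corresponding complex dialect op on complex numbers.
    enum Kind {
      // ...
      kSqrtF,  kSqrtC,   // math.sqrt  / complex.sqrt
      kExpm1F, kExpm1C,  // math.expm1 / complex.expm1
      kTanhF,  kTanhC,   // math.tanh  / complex.tanh
      // ...
    };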

--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

@@ -41,12 +41,15 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -284,12 +287,15 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -360,8 +366,10 @@ static const char *kindToOpSymbol(Kind kind) {
   case kFloorF:
     return "floor";
   case kSqrtF:
+  case kSqrtC:
     return "sqrt";
   case kExpm1F:
+  case kExpm1C:
     return "expm1";
   case kLog1pF:
   case kLog1pC:
@@ -370,6 +378,7 @@ static const char *kindToOpSymbol(Kind kind) {
   case kSinC:
     return "sin";
   case kTanhF:
+  case kTanhC:
     return "tanh";
   case kNegF:
   case kNegC:
@@ -449,10 +458,13 @@ void Merger::dumpExp(unsigned e) const {
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kSinF:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegI:
   case kTruncF:
@@ -555,12 +567,15 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
   case kCRe:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
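These unary kinds all land in the same buildLattices bucket because every op in the list is zero preserving (f(0) = 0), and sqrt, expm1, and tanh all satisfy that: the generated kernel only has to visit the stored entries of the sparse operand, and the implicit zeros stay implicit. A standalone sanity check of that property, in plain C++ and not part of the patch:

    #include <complex>
    #include <cstdio>

    int main() {
      const std::complex<double> z(0.0, 0.0);
      // Each newly supported op keeps zero at zero, the precondition for
      // applying it to the stored (nonzero) values only.
      const std::complex<double> s = std::sqrt(z);      // sqrt(0)  == 0
      const std::complex<double> e = std::exp(z) - 1.0; // expm1(0) == 0
      const std::complex<double> t = std::tanh(z);      // tanh(0)  == 0
      std::printf("sqrt:  (%g, %g)\n", s.real(), s.imag());
      std::printf("expm1: (%g, %g)\n", e.real(), e.imag());
      std::printf("tanh:  (%g, %g)\n", t.real(), t.imag());
      return 0;
    }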
@@ -785,8 +800,12 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       return addExp(kFloorF, e);
     if (isa<math::SqrtOp>(def))
       return addExp(kSqrtF, e);
+    if (isa<complex::SqrtOp>(def))
+      return addExp(kSqrtC, e);
     if (isa<math::ExpM1Op>(def))
       return addExp(kExpm1F, e);
+    if (isa<complex::Expm1Op>(def))
+      return addExp(kExpm1C, e);
     if (isa<math::Log1pOp>(def))
       return addExp(kLog1pF, e);
     if (isa<complex::Log1pOp>(def))
@@ -797,6 +816,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       return addExp(kSinC, e);
     if (isa<math::TanhOp>(def))
       return addExp(kTanhF, e);
+    if (isa<complex::TanhOp>(def))
+      return addExp(kTanhC, e);
     if (isa<arith::NegFOp>(def))
       return addExp(kNegF, e); // no negi in std
     if (isa<complex::NegOp>(def))
@@ -952,8 +973,12 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return rewriter.create<math::FloorOp>(loc, v0);
   case kSqrtF:
     return rewriter.create<math::SqrtOp>(loc, v0);
+  case kSqrtC:
+    return rewriter.create<complex::SqrtOp>(loc, v0);
   case kExpm1F:
     return rewriter.create<math::ExpM1Op>(loc, v0);
+  case kExpm1C:
+    return rewriter.create<complex::Expm1Op>(loc, v0);
   case kLog1pF:
     return rewriter.create<math::Log1pOp>(loc, v0);
   case kLog1pC:
@@ -964,6 +989,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return rewriter.create<complex::SinOp>(loc, v0);
   case kTanhF:
     return rewriter.create<math::TanhOp>(loc, v0);
+  case kTanhC:
+    return rewriter.create<complex::TanhOp>(loc, v0);
   case kNegF:
     return rewriter.create<arith::NegFOp>(loc, v0);
   case kNegC:
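Taken together, the Merger.cpp hunks show the complete recipe for wiring a unary op into the sparse compiler: recognize the op when building tensor expressions, then emit it again during code generation. A condensed sketch of the sqrt pair from the hunks above (a fragment, not self-contained code; the enclosing if-chain and switch live in Merger::buildTensorExp and Merger::buildExp):

    // 1. Recognition in buildTensorExp: 'def' is the operation defining
    //    the value yielded by the linalg.generic region; a match records
    //    a kSqrtC tensor expression over the child expression 'e'.
    if (isa<complex::SqrtOp>(def))
      return addExp(kSqrtC, e);

    // 2. Code generation in buildExp: the kind is mapped back to an op
    //    on 'v0', the scalar value loaded from the sparse storage.
    case kSqrtC:
      return rewriter.create<complex::SqrtOp>(loc, v0);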

--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir

@@ -59,6 +59,54 @@ module {
     return %0 : tensor<?xcomplex<f64>, #SparseVector>
   }
 
+  func.func @complex_sqrt(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                          -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.sqrt %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @complex_tanh(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                          -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.tanh %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @clog1p_expm1(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                          -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.log1p %a : complex<f64>
+          // TODO(bixia): Enable this line after adding complex.expm1 to
+          // the complex-to-standard lowering.
+          // %2 = complex.expm1 %1 : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
   func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>)
                   -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
@@ -131,9 +179,15 @@ module {
                     tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
     %1 = call @csin(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
-    %2 = call @cdiv(%sv1)
+    %2 = call @complex_sqrt(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
-    %3 = call @cabs(%sv1)
+    %3 = call @complex_tanh(%sv2)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %4 = call @clog1p_expm1(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %5 = call @cdiv(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %6 = call @cabs(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
 
     //
@@ -157,15 +211,36 @@ module {
     // CHECK-NEXT: -193.43
     // CHECK-NEXT: 57.2184
     call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 0.433635
+    // CHECK-NEXT: 2.30609
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2.53083
+    // CHECK-NEXT: 1.18538
+    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 0.761594
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: -0.964028
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0.995055
+    // CHECK-NEXT: 0
+    call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 1.52361
+    // CHECK-NEXT: 2.69061
+    // CHECK-NEXT: 1.73287
+    // CHECK-NEXT: 0.785398
+    // CHECK-NEXT: 2.13833
+    // CHECK-NEXT: 0.785398
+    call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
     // CHECK-NEXT: -2.565
     // CHECK-NEXT: 1
     // CHECK-NEXT: 1.5
     // CHECK-NEXT: 2
     // CHECK-NEXT: 2.5
     // CHECK-NEXT: 3
-    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
     // CHECK-NEXT: ( 5.50608, 5, 7.81025 )
-    call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
 
     // Release the resources.
     sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
@@ -173,7 +248,10 @@ module {
     sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
-    sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %3 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %4 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %5 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %6 : tensor<?xf64, #SparseVector>
     return
   }
 }