forked from OSchip/llvm-project
[mlir][sparse] Add more complex operations.
Support complex operations sqrt, expm1, and tanh. Add tests. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D126393
This commit is contained in:
parent
338e76f8ee
commit
a14057d4bd
|
@ -35,12 +35,15 @@ enum Kind {
|
|||
kCeilF,
|
||||
kFloorF,
|
||||
kSqrtF,
|
||||
kSqrtC,
|
||||
kExpm1F,
|
||||
kExpm1C,
|
||||
kLog1pF,
|
||||
kLog1pC,
|
||||
kSinF,
|
||||
kSinC,
|
||||
kTanhF,
|
||||
kTanhC,
|
||||
kNegF,
|
||||
kNegC,
|
||||
kNegI,
|
||||
|
|
|
@ -41,12 +41,15 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
|
|||
case kCeilF:
|
||||
case kFloorF:
|
||||
case kSqrtF:
|
||||
case kSqrtC:
|
||||
case kExpm1F:
|
||||
case kExpm1C:
|
||||
case kLog1pF:
|
||||
case kLog1pC:
|
||||
case kSinF:
|
||||
case kSinC:
|
||||
case kTanhF:
|
||||
case kTanhC:
|
||||
case kNegF:
|
||||
case kNegC:
|
||||
case kNegI:
|
||||
|
@ -284,12 +287,15 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
|
|||
case kCeilF:
|
||||
case kFloorF:
|
||||
case kSqrtF:
|
||||
case kSqrtC:
|
||||
case kExpm1F:
|
||||
case kExpm1C:
|
||||
case kLog1pF:
|
||||
case kLog1pC:
|
||||
case kSinF:
|
||||
case kSinC:
|
||||
case kTanhF:
|
||||
case kTanhC:
|
||||
case kNegF:
|
||||
case kNegC:
|
||||
case kNegI:
|
||||
|
@ -360,8 +366,10 @@ static const char *kindToOpSymbol(Kind kind) {
|
|||
case kFloorF:
|
||||
return "floor";
|
||||
case kSqrtF:
|
||||
case kSqrtC:
|
||||
return "sqrt";
|
||||
case kExpm1F:
|
||||
case kExpm1C:
|
||||
return "expm1";
|
||||
case kLog1pF:
|
||||
case kLog1pC:
|
||||
|
@ -370,6 +378,7 @@ static const char *kindToOpSymbol(Kind kind) {
|
|||
case kSinC:
|
||||
return "sin";
|
||||
case kTanhF:
|
||||
case kTanhC:
|
||||
return "tanh";
|
||||
case kNegF:
|
||||
case kNegC:
|
||||
|
@ -449,10 +458,13 @@ void Merger::dumpExp(unsigned e) const {
|
|||
case kCeilF:
|
||||
case kFloorF:
|
||||
case kSqrtF:
|
||||
case kSqrtC:
|
||||
case kExpm1F:
|
||||
case kExpm1C:
|
||||
case kLog1pF:
|
||||
case kSinF:
|
||||
case kTanhF:
|
||||
case kTanhC:
|
||||
case kNegF:
|
||||
case kNegI:
|
||||
case kTruncF:
|
||||
|
@ -555,12 +567,15 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
|
|||
case kCRe:
|
||||
case kFloorF:
|
||||
case kSqrtF:
|
||||
case kSqrtC:
|
||||
case kExpm1F:
|
||||
case kExpm1C:
|
||||
case kLog1pF:
|
||||
case kLog1pC:
|
||||
case kSinF:
|
||||
case kSinC:
|
||||
case kTanhF:
|
||||
case kTanhC:
|
||||
case kNegF:
|
||||
case kNegC:
|
||||
case kNegI:
|
||||
|
@ -785,8 +800,12 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
|
|||
return addExp(kFloorF, e);
|
||||
if (isa<math::SqrtOp>(def))
|
||||
return addExp(kSqrtF, e);
|
||||
if (isa<complex::SqrtOp>(def))
|
||||
return addExp(kSqrtC, e);
|
||||
if (isa<math::ExpM1Op>(def))
|
||||
return addExp(kExpm1F, e);
|
||||
if (isa<complex::Expm1Op>(def))
|
||||
return addExp(kExpm1C, e);
|
||||
if (isa<math::Log1pOp>(def))
|
||||
return addExp(kLog1pF, e);
|
||||
if (isa<complex::Log1pOp>(def))
|
||||
|
@ -797,6 +816,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
|
|||
return addExp(kSinC, e);
|
||||
if (isa<math::TanhOp>(def))
|
||||
return addExp(kTanhF, e);
|
||||
if (isa<complex::TanhOp>(def))
|
||||
return addExp(kTanhC, e);
|
||||
if (isa<arith::NegFOp>(def))
|
||||
return addExp(kNegF, e); // no negi in std
|
||||
if (isa<complex::NegOp>(def))
|
||||
|
@ -952,8 +973,12 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
|||
return rewriter.create<math::FloorOp>(loc, v0);
|
||||
case kSqrtF:
|
||||
return rewriter.create<math::SqrtOp>(loc, v0);
|
||||
case kSqrtC:
|
||||
return rewriter.create<complex::SqrtOp>(loc, v0);
|
||||
case kExpm1F:
|
||||
return rewriter.create<math::ExpM1Op>(loc, v0);
|
||||
case kExpm1C:
|
||||
return rewriter.create<complex::Expm1Op>(loc, v0);
|
||||
case kLog1pF:
|
||||
return rewriter.create<math::Log1pOp>(loc, v0);
|
||||
case kLog1pC:
|
||||
|
@ -964,6 +989,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
|
|||
return rewriter.create<complex::SinOp>(loc, v0);
|
||||
case kTanhF:
|
||||
return rewriter.create<math::TanhOp>(loc, v0);
|
||||
case kTanhC:
|
||||
return rewriter.create<complex::TanhOp>(loc, v0);
|
||||
case kNegF:
|
||||
return rewriter.create<arith::NegFOp>(loc, v0);
|
||||
case kNegC:
|
||||
|
|
|
@ -59,6 +59,54 @@ module {
|
|||
return %0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
}
|
||||
|
||||
func.func @complex_sqrt(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
-> tensor<?xcomplex<f64>, #SparseVector> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%0 = linalg.generic #trait_op1
|
||||
ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
|
||||
^bb(%a: complex<f64>, %x: complex<f64>):
|
||||
%1 = complex.sqrt %a : complex<f64>
|
||||
linalg.yield %1 : complex<f64>
|
||||
} -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
return %0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
}
|
||||
|
||||
func.func @complex_tanh(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
-> tensor<?xcomplex<f64>, #SparseVector> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%0 = linalg.generic #trait_op1
|
||||
ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
|
||||
^bb(%a: complex<f64>, %x: complex<f64>):
|
||||
%1 = complex.tanh %a : complex<f64>
|
||||
linalg.yield %1 : complex<f64>
|
||||
} -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
return %0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
}
|
||||
|
||||
func.func @clog1p_expm1(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
-> tensor<?xcomplex<f64>, #SparseVector> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
|
||||
%0 = linalg.generic #trait_op1
|
||||
ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
|
||||
^bb(%a: complex<f64>, %x: complex<f64>):
|
||||
%1 = complex.log1p %a : complex<f64>
|
||||
// TODO(bixia): Enable this line after adding complex.expm1 to
|
||||
// complex to standard lowering.
|
||||
// %2 = complex.expm1 %1 : complex<f64>
|
||||
linalg.yield %1 : complex<f64>
|
||||
} -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
return %0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
}
|
||||
|
||||
func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>)
|
||||
-> tensor<?xcomplex<f64>, #SparseVector> {
|
||||
%c0 = arith.constant 0 : index
|
||||
|
@ -131,9 +179,15 @@ module {
|
|||
tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
%1 = call @csin(%sv1)
|
||||
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
%2 = call @cdiv(%sv1)
|
||||
%2 = call @complex_sqrt(%sv1)
|
||||
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
%3 = call @cabs(%sv1)
|
||||
%3 = call @complex_tanh(%sv2)
|
||||
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
%4 = call @clog1p_expm1(%sv1)
|
||||
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
%5 = call @cdiv(%sv1)
|
||||
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
|
||||
%6 = call @cabs(%sv1)
|
||||
: (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
|
||||
|
||||
//
|
||||
|
@ -157,15 +211,36 @@ module {
|
|||
// CHECK-NEXT: -193.43
|
||||
// CHECK-NEXT: 57.2184
|
||||
call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
|
||||
// CHECK-NEXT: 0.433635
|
||||
// CHECK-NEXT: 2.30609
|
||||
// CHECK-NEXT: 2
|
||||
// CHECK-NEXT: 1
|
||||
// CHECK-NEXT: 2.53083
|
||||
// CHECK-NEXT: 1.18538
|
||||
call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
|
||||
// CHECK-NEXT: 0.761594
|
||||
// CHECK-NEXT: 0
|
||||
// CHECK-NEXT: -0.964028
|
||||
// CHECK-NEXT: 0
|
||||
// CHECK-NEXT: 0.995055
|
||||
// CHECK-NEXT: 0
|
||||
call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
|
||||
// CHECK-NEXT: 1.52361
|
||||
// CHECK-NEXT: 2.69061
|
||||
// CHECK-NEXT: 1.73287
|
||||
// CHECK-NEXT: 0.785398
|
||||
// CHECK-NEXT: 2.13833
|
||||
// CHECK-NEXT: 0.785398
|
||||
call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
|
||||
// CHECK-NEXT: -2.565
|
||||
// CHECK-NEXT: 1
|
||||
// CHECK-NEXT: 1.5
|
||||
// CHECK-NEXT: 2
|
||||
// CHECK-NEXT: 2.5
|
||||
// CHECK-NEXT: 3
|
||||
call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
|
||||
call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
|
||||
// CHECK-NEXT: ( 5.50608, 5, 7.81025 )
|
||||
call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
|
||||
call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
|
||||
|
||||
// Release the resources.
|
||||
sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
|
@ -173,7 +248,10 @@ module {
|
|||
sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
|
||||
sparse_tensor.release %3 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %4 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %5 : tensor<?xcomplex<f64>, #SparseVector>
|
||||
sparse_tensor.release %6 : tensor<?xf64, #SparseVector>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue