forked from OSchip/llvm-project
[MLIR] Add the sqrt operation to mlir.
Summary: Add and pipe through the sqrt operation for Standard and LLVM dialects. Reviewers: nicolasvasilache, ftynse Reviewed By: ftynse Subscribers: frej, ftynse, merge_guards_bot, flaub, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D73571
This commit is contained in:
parent
38ab3b876b
commit
fcabccd3d9
|
@ -587,6 +587,25 @@ operand and returns one result of the same type. This type may be a float
|
|||
scalar type, a vector whose element type is float, or a tensor of floats. It
|
||||
has no standard attributes.
|
||||
|
||||
### 'sqrt' operation
|
||||
|
||||
Syntax:
|
||||
|
||||
```
|
||||
operation ::= ssa-id `=` `sqrt` ssa-use `:` type
|
||||
```
|
||||
|
||||
Examples:
|
||||
|
||||
```mlir
|
||||
// Scalar square root value.
|
||||
%a = sqrt %b : f64
|
||||
// SIMD vector element-wise square root value.
|
||||
%f = sqrt %g : vector<4xf32>
|
||||
// Tensor element-wise square root value.
|
||||
%x = sqrt %y : tensor<4x?xf32>
|
||||
```
|
||||
|
||||
### 'tanh' operation
|
||||
|
||||
Syntax:
|
||||
|
|
|
@ -716,6 +716,7 @@ def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">;
|
|||
def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
|
||||
def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
|
||||
def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">;
|
||||
def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
|
||||
|
||||
def LLVM_LogOp : LLVM_Op<"intr.log", [NoSideEffect]>,
|
||||
Arguments<(ins LLVM_Type:$in)>,
|
||||
|
|
|
@ -1402,6 +1402,16 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def SqrtOp : FloatUnaryOp<"sqrt"> {
|
||||
let summary = "sqrt of the specified value";
|
||||
let description = [{
|
||||
The `sqrt` operation computes the square root. It takes one operand and
|
||||
returns one result of the same type. This type may be a float scalar type, a
|
||||
vector whose element type is float, or a tensor of floats. It has no standard
|
||||
attributes.
|
||||
}];
|
||||
}
|
||||
|
||||
def TanhOp : FloatUnaryOp<"tanh"> {
|
||||
let summary = "hyperbolic tangent of the specified value";
|
||||
let description = [{
|
||||
|
|
|
@ -807,6 +807,9 @@ struct SignedDivIOpLowering
|
|||
: public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct SqrtOpLowering : public UnaryOpLLVMOpLowering<SqrtOp, LLVM::SqrtOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct UnsignedDivIOpLowering
|
||||
: public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
|
||||
using Super::Super;
|
||||
|
@ -2108,6 +2111,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
|
|||
SignedShiftRightOpLowering,
|
||||
SplatOpLowering,
|
||||
SplatNdOpLowering,
|
||||
SqrtOpLowering,
|
||||
SubFOpLowering,
|
||||
SubIOpLowering,
|
||||
TanhOpLowering,
|
||||
|
|
|
@ -398,8 +398,8 @@ func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @ops
|
||||
func @ops(f32, f32, i32, i32) -> (f32, i32) {
|
||||
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32):
|
||||
func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
|
||||
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
|
||||
// CHECK-NEXT: %0 = llvm.fsub %arg0, %arg1 : !llvm.float
|
||||
%0 = subf %arg0, %arg1: f32
|
||||
// CHECK-NEXT: %1 = llvm.sub %arg2, %arg3 : !llvm.i32
|
||||
|
@ -440,7 +440,10 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) {
|
|||
%19 = shift_right_signed %arg2, %arg3 : i32
|
||||
// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
|
||||
%20 = shift_right_unsigned %arg2, %arg3 : i32
|
||||
|
||||
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
|
||||
%21 = std.sqrt %arg0 : f32
|
||||
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
|
||||
%22 = std.sqrt %arg4 : f64
|
||||
return %0, %4 : f32, i32
|
||||
}
|
||||
|
||||
|
|
|
@ -494,6 +494,18 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
|
|||
// CHECK: %{{[0-9]+}} = shift_right_unsigned %cst_4, %cst_4 : tensor<42xi32>
|
||||
%138 = shift_right_unsigned %tci32, %tci32 : tensor<42 x i32>
|
||||
|
||||
// CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
|
||||
%139 = "std.sqrt"(%f) : (f32) -> f32
|
||||
|
||||
// CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
|
||||
%140 = sqrt %f : f32
|
||||
|
||||
// CHECK: %{{[0-9]+}} = sqrt %cst_8 : vector<4xf32>
|
||||
%141 = sqrt %vcf32 : vector<4xf32>
|
||||
|
||||
// CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32>
|
||||
%142 = sqrt %t : tensor<4x4x?xf32>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -59,6 +59,15 @@ llvm.func @fabs_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
|
|||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sqrt_test
|
||||
llvm.func @sqrt_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
|
||||
// CHECK: call float @llvm.sqrt.f32
|
||||
"llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
|
||||
// CHECK: call <8 x float> @llvm.sqrt.v8f32
|
||||
"llvm.intr.sqrt"(%arg1) : (!llvm<"<8 x float>">) -> !llvm<"<8 x float>">
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ceil_test
|
||||
llvm.func @ceil_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
|
||||
// CHECK: call float @llvm.ceil.f32
|
||||
|
@ -100,6 +109,8 @@ llvm.func @copysign_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<
|
|||
// CHECK: declare <8 x float> @llvm.log2.v8f32(<8 x float>) #0
|
||||
// CHECK: declare float @llvm.fabs.f32(float)
|
||||
// CHECK: declare <8 x float> @llvm.fabs.v8f32(<8 x float>) #0
|
||||
// CHECK: declare float @llvm.sqrt.f32(float)
|
||||
// CHECK: declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0
|
||||
// CHECK: declare float @llvm.ceil.f32(float)
|
||||
// CHECK: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0
|
||||
// CHECK: declare float @llvm.cos.f32(float)
|
||||
|
|
Loading…
Reference in New Issue