[MLIR] Add the sqrt operation to mlir.

Summary: Add and pipe through the sqrt operation for Standard and LLVM dialects.

Reviewers: nicolasvasilache, ftynse

Reviewed By: ftynse

Subscribers: frej, ftynse, merge_guards_bot, flaub, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73571
This commit is contained in:
Lubomir Litchev 2020-01-30 07:44:44 -08:00 committed by Frank Laub
parent 38ab3b876b
commit fcabccd3d9
7 changed files with 63 additions and 3 deletions

View File

@ -587,6 +587,25 @@ operand and returns one result of the same type. This type may be a float
scalar type, a vector whose element type is float, or a tensor of floats. It
has no standard attributes.
### 'sqrt' operation
Syntax:
```
operation ::= ssa-id `=` `sqrt` ssa-use `:` type
```
Examples:
```mlir
// Scalar square root value.
%a = sqrt %b : f64
// SIMD vector element-wise square root value.
%f = sqrt %g : vector<4xf32>
// Tensor element-wise square root value.
%x = sqrt %y : tensor<4x?xf32>
```
### 'tanh' operation
Syntax:

View File

@ -716,6 +716,7 @@ def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">;
def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">;
def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
def LLVM_LogOp : LLVM_Op<"intr.log", [NoSideEffect]>,
Arguments<(ins LLVM_Type:$in)>,

View File

@ -1402,6 +1402,16 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
let hasCanonicalizer = 1;
}
def SqrtOp : FloatUnaryOp<"sqrt"> {
let summary = "sqrt of the specified value";
let description = [{
The `sqrt` operation computes the square root. It takes one operand and
returns one result of the same type. This type may be a float scalar type, a
vector whose element type is float, or a tensor of floats. It has no standard
attributes.
}];
}
def TanhOp : FloatUnaryOp<"tanh"> {
let summary = "hyperbolic tangent of the specified value";
let description = [{

View File

@ -807,6 +807,9 @@ struct SignedDivIOpLowering
: public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
using Super::Super;
};
struct SqrtOpLowering : public UnaryOpLLVMOpLowering<SqrtOp, LLVM::SqrtOp> {
using Super::Super;
};
struct UnsignedDivIOpLowering
: public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
using Super::Super;
@ -2108,6 +2111,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
SignedShiftRightOpLowering,
SplatOpLowering,
SplatNdOpLowering,
SqrtOpLowering,
SubFOpLowering,
SubIOpLowering,
TanhOpLowering,

View File

@ -398,8 +398,8 @@ func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>
}
// CHECK-LABEL: @ops
func @ops(f32, f32, i32, i32) -> (f32, i32) {
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32):
func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
// CHECK-NEXT: %0 = llvm.fsub %arg0, %arg1 : !llvm.float
%0 = subf %arg0, %arg1: f32
// CHECK-NEXT: %1 = llvm.sub %arg2, %arg3 : !llvm.i32
@ -440,7 +440,10 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) {
%19 = shift_right_signed %arg2, %arg3 : i32
// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
%20 = shift_right_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
%21 = std.sqrt %arg0 : f32
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
%22 = std.sqrt %arg4 : f64
return %0, %4 : f32, i32
}

View File

@ -494,6 +494,18 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: %{{[0-9]+}} = shift_right_unsigned %cst_4, %cst_4 : tensor<42xi32>
%138 = shift_right_unsigned %tci32, %tci32 : tensor<42 x i32>
// CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
%139 = "std.sqrt"(%f) : (f32) -> f32
// CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
%140 = sqrt %f : f32
// CHECK: %{{[0-9]+}} = sqrt %cst_8 : vector<4xf32>
%141 = sqrt %vcf32 : vector<4xf32>
// CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32>
%142 = sqrt %t : tensor<4x4x?xf32>
return
}

View File

@ -59,6 +59,15 @@ llvm.func @fabs_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
llvm.return
}
// CHECK-LABEL: @sqrt_test
llvm.func @sqrt_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
// CHECK: call float @llvm.sqrt.f32
"llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
// CHECK: call <8 x float> @llvm.sqrt.v8f32
"llvm.intr.sqrt"(%arg1) : (!llvm<"<8 x float>">) -> !llvm<"<8 x float>">
llvm.return
}
// CHECK-LABEL: @ceil_test
llvm.func @ceil_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
// CHECK: call float @llvm.ceil.f32
@ -100,6 +109,8 @@ llvm.func @copysign_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<
// CHECK: declare <8 x float> @llvm.log2.v8f32(<8 x float>) #0
// CHECK: declare float @llvm.fabs.f32(float)
// CHECK: declare <8 x float> @llvm.fabs.v8f32(<8 x float>) #0
// CHECK: declare float @llvm.sqrt.f32(float)
// CHECK: declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0
// CHECK: declare float @llvm.ceil.f32(float)
// CHECK: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0
// CHECK: declare float @llvm.cos.f32(float)