Fix include guards and add tests for OpToFuncCallLowering.

PiperOrigin-RevId: 276859463
Alexander Belyaev 2019-10-26 08:20:59 -07:00 committed by A. Unique TensorFlower
parent cde337cfde
commit 780a108d31
3 changed files with 49 additions and 7 deletions


@@ -14,8 +14,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
#define THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
@@ -26,6 +26,15 @@
namespace mlir {
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`,
/// depending on the element type that the op operates upon. The function
/// declaration is added if it was not already present.
///
/// Example with NVVM:
/// %exp_f32 = std.exp %arg_f32 : f32
///
/// will be transformed into
/// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float
template <typename SourceOp>
struct OpToFuncCallLowering : public LLVMOpLowering {
public:
@@ -48,10 +57,9 @@ public:
LLVMType resultType = lowering.convertType(op->getResult(0)->getType())
.template cast<LLVM::LLVMType>();
LLVMType funcType = getFunctionType(resultType, operands);
const std::string funcName = getFunctionName(resultType);
if (funcName.empty()) {
StringRef funcName = getFunctionName(resultType);
if (funcName.empty())
return matchFailure();
}
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp = rewriter.create<LLVM::CallOp>(
@@ -100,4 +108,4 @@ private:
} // namespace mlir
#endif // THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
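
For reference, below is a minimal, hypothetical sketch of how a conversion pass could register the pattern described in the doc comment above so that std.exp lowers to the libdevice calls checked in the NVVM test that follows. The helper name populateExpToNVVMCallPatterns is an illustrative assumption, and the constructor argument order (type converter, f32 function name, f64 function name) is inferred from this header rather than taken from this commit; the actual registration lives in the GPU-to-NVVM lowering pass and may differ.

// Illustrative sketch only; not part of this commit. Assumes
// OpToFuncCallLowering's constructor takes (LLVMTypeConverter &, StringRef
// f32Func, StringRef f64Func) and that std.exp is represented by ExpOp.
#include "../GPUCommon/OpToFuncCallLowering.h" // path relative to lib/Conversion
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
// Hypothetical helper: rewrite std.exp on f32 to llvm.call @__nv_expf and on
// f64 to llvm.call @__nv_exp, matching the CHECK lines in the test below.
inline void populateExpToNVVMCallPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns) {
  patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
                                               "__nv_exp");
}
} // namespace mlir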


@@ -83,3 +83,20 @@ module attributes {gpu.kernel_module} {
std.return
}
}
// -----
module attributes {gpu.kernel_module} {
// CHECK: llvm.func @__nv_expf(!llvm.float) -> !llvm.float
// CHECK: llvm.func @__nv_exp(!llvm.double) -> !llvm.double
// CHECK-LABEL: func @gpu_exp
func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) {
%exp_f32 = std.exp %arg_f32 : f32
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
%result_f32 = std.exp %exp_f32 : f32
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
%result64 = std.exp %arg_f64 : f64
// CHECK: llvm.call @__nv_exp(%{{.*}}) : (!llvm.double) -> !llvm.double
std.return
}
}


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -lower-gpu-ops-to-rocdl-ops | FileCheck %s
// RUN: mlir-opt %s -lower-gpu-ops-to-rocdl-ops -split-input-file | FileCheck %s
module attributes {gpu.kernel_module} {
// CHECK-LABEL: func @gpu_index_ops()
@@ -35,3 +35,20 @@ module attributes {gpu.kernel_module} {
std.return
}
}
// -----
module attributes {gpu.kernel_module} {
// CHECK: llvm.func @_ocml_exp_f32(!llvm.float) -> !llvm.float
// CHECK: llvm.func @_ocml_exp_f64(!llvm.double) -> !llvm.double
// CHECK-LABEL: func @gpu_exp
func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) {
%exp_f32 = std.exp %arg_f32 : f32
// CHECK: llvm.call @_ocml_exp_f32(%{{.*}}) : (!llvm.float) -> !llvm.float
%result_f32 = std.exp %exp_f32 : f32
// CHECK: llvm.call @_ocml_exp_f32(%{{.*}}) : (!llvm.float) -> !llvm.float
%result64 = std.exp %arg_f64 : f64
// CHECK: llvm.call @_ocml_exp_f64(%{{.*}}) : (!llvm.double) -> !llvm.double
std.return
}
}