forked from OSchip/llvm-project
Fix include guards and add tests for OpToFuncCallLowering.
PiperOrigin-RevId: 276859463
This commit is contained in:
parent
cde337cfde
commit
780a108d31
|
@ -14,8 +14,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#ifndef THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
|
||||
#define THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
|
||||
#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
|
||||
#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
|
||||
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
|
@ -26,6 +26,15 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`
|
||||
/// depending on the element type that Op operates upon. The function
|
||||
/// declaration is added in case it was not added before.
|
||||
///
|
||||
/// Example with NVVM:
|
||||
/// %exp_f32 = std.exp %arg_f32 : f32
|
||||
///
|
||||
/// will be transformed into
|
||||
/// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float
|
||||
template <typename SourceOp>
|
||||
struct OpToFuncCallLowering : public LLVMOpLowering {
|
||||
public:
|
||||
|
@ -48,10 +57,9 @@ public:
|
|||
LLVMType resultType = lowering.convertType(op->getResult(0)->getType())
|
||||
.template cast<LLVM::LLVMType>();
|
||||
LLVMType funcType = getFunctionType(resultType, operands);
|
||||
const std::string funcName = getFunctionName(resultType);
|
||||
if (funcName.empty()) {
|
||||
StringRef funcName = getFunctionName(resultType);
|
||||
if (funcName.empty())
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
|
||||
auto callOp = rewriter.create<LLVM::CallOp>(
|
||||
|
@ -100,4 +108,4 @@ private:
|
|||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // THIRD_PARTY_LLVM_LLVM_PROJECTS_GOOGLE_MLIR_LIB_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
|
||||
#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
|
||||
|
|
|
@ -83,3 +83,20 @@ module attributes {gpu.kernel_module} {
|
|||
std.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {gpu.kernel_module} {
|
||||
// CHECK: llvm.func @__nv_expf(!llvm.float) -> !llvm.float
|
||||
// CHECK: llvm.func @__nv_exp(!llvm.double) -> !llvm.double
|
||||
// CHECK-LABEL: func @gpu_exp
|
||||
func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) {
|
||||
%exp_f32 = std.exp %arg_f32 : f32
|
||||
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
|
||||
%result_f32 = std.exp %exp_f32 : f32
|
||||
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
|
||||
%result64 = std.exp %arg_f64 : f64
|
||||
// CHECK: llvm.call @__nv_exp(%{{.*}}) : (!llvm.double) -> !llvm.double
|
||||
std.return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -lower-gpu-ops-to-rocdl-ops | FileCheck %s
|
||||
// RUN: mlir-opt %s -lower-gpu-ops-to-rocdl-ops -split-input-file | FileCheck %s
|
||||
|
||||
module attributes {gpu.kernel_module} {
|
||||
// CHECK-LABEL: func @gpu_index_ops()
|
||||
|
@ -35,3 +35,20 @@ module attributes {gpu.kernel_module} {
|
|||
std.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {gpu.kernel_module} {
|
||||
// CHECK: llvm.func @_ocml_exp_f32(!llvm.float) -> !llvm.float
|
||||
// CHECK: llvm.func @_ocml_exp_f64(!llvm.double) -> !llvm.double
|
||||
// CHECK-LABEL: func @gpu_exp
|
||||
func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) {
|
||||
%exp_f32 = std.exp %arg_f32 : f32
|
||||
// CHECK: llvm.call @_ocml_exp_f32(%{{.*}}) : (!llvm.float) -> !llvm.float
|
||||
%result_f32 = std.exp %exp_f32 : f32
|
||||
// CHECK: llvm.call @_ocml_exp_f32(%{{.*}}) : (!llvm.float) -> !llvm.float
|
||||
%result64 = std.exp %arg_f64 : f64
|
||||
// CHECK: llvm.call @_ocml_exp_f64(%{{.*}}) : (!llvm.double) -> !llvm.double
|
||||
std.return
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue