Enable TTI for host TargetMachine in JitRunner

This commit improves JitRunner so that it creates a target machine
for the host CPU, which is then used to properly initialize LLVM's
TargetTransformInfo for that target. This enables LLVM optimizations
such as vectorization when using JitRunner. Note that, as part of this
work, JITTargetMachineBuilder::detectHost() has been extended to include
the host CPU name and sub-target features in the host CPU detection
(https://reviews.llvm.org/D65760).

Closes tensorflow/mlir#71

PiperOrigin-RevId: 262452525
This commit is contained in:
Diego Caballero 2019-08-08 16:02:50 -07:00 committed by A. Unique TensorFlower
parent f525a497ea
commit 96371d25c3
1 changed files with 13 additions and 1 deletions

View File

@@ -40,6 +40,7 @@
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h" #include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h" #include "llvm/IR/LegacyPassNameParser.h"
@@ -308,8 +309,19 @@ int mlir::JitRunnerMain(
if (failed(mlirTransformer(m.get()))) if (failed(mlirTransformer(m.get())))
return EXIT_FAILURE; return EXIT_FAILURE;
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
if (!tmBuilderOrError) {
llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
return EXIT_FAILURE;
}
auto tmOrError = tmBuilderOrError->createTargetMachine();
if (!tmOrError) {
llvm::errs() << "Failed to create a TargetMachine for the host\n";
return EXIT_FAILURE;
}
auto transformer = mlir::makeLLVMPassesTransformer( auto transformer = mlir::makeLLVMPassesTransformer(
passes, optLevel, /*targetMachine=*/nullptr, optPosition); passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
auto error = mainFuncType.getValue() == "f32" auto error = mainFuncType.getValue() == "f32"
? compileAndExecuteSingleFloatReturnFunction( ? compileAndExecuteSingleFloatReturnFunction(
m.get(), mainFuncName.getValue(), transformer) m.get(), mainFuncName.getValue(), transformer)