Pipe Linalg to a cblas call via mlir-cpu-runner

This CL extends the execution engine to allow the additional resolution of symbols names
    that have been registered explicitly. This allows linking static library symbols that have not been explicitly exported with the -rdynamic linking flag (which is deemed too intrusive).

--

PiperOrigin-RevId: 247969504
This commit is contained in:
Nicolas Vasilache 2019-05-13 10:59:04 -07:00 committed by Mehdi Amini
parent 9cc5747a7b
commit 5c64d2a6c4
8 changed files with 456 additions and 275 deletions

View File

@ -75,6 +75,52 @@ private:
namespace mlir {
namespace impl {
/// Wrapper class around DynamicLibrarySearchGenerator to allow searching
/// in-process symbols that have not been explicitly exported.
/// This first tries to resolve a symbol by using DynamicLibrarySearchGenerator.
/// For symbols that are not found this way, it then uses
/// `llvm::sys::DynamicLibrary::SearchForAddressOfSymbol` to extract symbols
/// that have been explicitly added with `llvm::sys::DynamicLibrary::AddSymbol`,
/// previously.
class SearchGenerator {
public:
SearchGenerator(char GlobalPrefix)
: defaultGenerator(cantFail(
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
GlobalPrefix))) {}
// This function forwards to DynamicLibrarySearchGenerator::operator() and
// adds an extra resolution for names explicitly registered via
// `llvm::sys::DynamicLibrary::AddSymbol`.
Expected<llvm::orc::SymbolNameSet>
operator()(llvm::orc::JITDylib &JD, const llvm::orc::SymbolNameSet &Names) {
auto res = defaultGenerator(JD, Names);
if (!res)
return res;
llvm::orc::SymbolMap newSymbols;
for (auto &Name : Names) {
if (res.get().count(Name) > 0)
continue;
res.get().insert(Name);
auto addedSymbolAddress =
llvm::sys::DynamicLibrary::SearchForAddressOfSymbol(*Name);
if (!addedSymbolAddress)
continue;
llvm::JITEvaluatedSymbol Sym(
reinterpret_cast<uintptr_t>(addedSymbolAddress),
llvm::JITSymbolFlags::Exported);
newSymbols[Name] = Sym;
}
if (!newSymbols.empty())
cantFail(JD.define(absoluteSymbols(std::move(newSymbols))));
return res;
}
private:
llvm::orc::DynamicLibrarySearchGenerator defaultGenerator;
};
// Simple layered Orc JIT compilation engine.
class OrcJIT {
public:
@ -82,8 +128,8 @@ public:
// Construct a JIT engine for the target host defined by `machineBuilder`,
// using the data layout provided as `dataLayout`.
// Setup the object layer to use our custom memory manager in order to resolve
// calls to library functions present in the process.
// Setup the object layer to use our custom memory manager in order to
// resolve calls to library functions present in the process.
OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder,
llvm::DataLayout layout, IRTransformer transform)
: irTransformer(transform),
@ -97,8 +143,7 @@ public:
dataLayout(layout), mangler(session, this->dataLayout),
threadSafeCtx(llvm::make_unique<llvm::LLVMContext>()) {
session.getMainJITDylib().setGenerator(
cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
layout.getGlobalPrefix())));
SearchGenerator(layout.getGlobalPrefix()));
}
// Create a JIT engine for the current host.
@ -130,8 +175,8 @@ public:
private:
// Wrap the `irTransformer` into a function that can be called by the
// IRTranformLayer. If `irTransformer` is not set up, return the module as is
// without errors.
// IRTranformLayer. If `irTransformer` is not set up, return the module as
// is without errors.
llvm::orc::IRTransformLayer::TransformFunction makeIRTransformFunction() {
return [this](llvm::orc::ThreadSafeModule module,
const llvm::orc::MaterializationResponsibility &resp)

View File

@ -1,8 +1,9 @@
add_subdirectory(mlir-cpu-runner)
llvm_canonicalize_cmake_booleans(
LLVM_BUILD_EXAMPLES
)
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
@ -19,6 +20,7 @@ configure_lit_site_cfg(
set(MLIR_TEST_DEPENDS
FileCheck count not
MLIRUnitTests
mlir-blas-cpu-runner
mlir-cpu-runner
mlir-opt
mlir-tblgen

View File

@ -0,0 +1,27 @@
set(LIBS
MLIRAffineOps
MLIRAnalysis
MLIREDSC
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR
MLIRParser
MLIRTargetLLVMIR
MLIRTransforms
MLIRSupport
MLIRCPURunnerLib
LLVMCore
LLVMSupport
)
add_executable(mlir-blas-cpu-runner
mlir-blas-cpu-runner.cpp
)
llvm_update_compile_flags(mlir-blas-cpu-runner)
whole_archive_link(mlir-blas-cpu-runner
MLIRLLVMIR
MLIRStandardOps
MLIRTargetLLVMIR
MLIRTransforms
MLIRTranslation
)
target_link_libraries(mlir-blas-cpu-runner MLIRIR ${LIBS})

View File

@ -0,0 +1,47 @@
//===- mlir-blas-cpu-runner.cpp - MLIR CPU Execution Driver + Blas Support ===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Main entry point.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DynamicLibrary.h"
#ifdef WITH_LAPACK
#include "lapack/cblas.h"
#else
extern "C" float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY) {
float res = 0.0f;
for (int i = 0; i < N; ++i)
res += X[i * incX] * Y[i * incY];
return res;
}
#endif
extern int run(int argc, char **argv);
void addSymbols() {
using llvm::sys::DynamicLibrary;
DynamicLibrary::AddSymbol("cblas_sdot", (void *)(&cblas_sdot));
}
int main(int argc, char **argv) {
addSymbols();
return run(argc, argv);
}

View File

@ -1,8 +1,29 @@
// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s
// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-blas-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s
func @linalg_dot(!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
!llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
func @cblas_sdot(!llvm.i64, !llvm<"float*">, !llvm.i64, !llvm<"float*">, !llvm.i64) -> !llvm.float
func @linalg_dot(%arg0 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
%arg1 : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">,
%arg2 : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
%n = llvm.extractvalue %arg0[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
%x0 = llvm.extractvalue %arg0[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
%x1 = llvm.extractvalue %arg0[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
%x = llvm.getelementptr %x0[%x1] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
%inc_x = llvm.extractvalue %arg0[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
%y0 = llvm.extractvalue %arg1[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
%y1 = llvm.extractvalue %arg1[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
%y = llvm.getelementptr %y0[%y1] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
%inc_y = llvm.extractvalue %arg1[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
%res = llvm.call @cblas_sdot(%n, %x, %inc_x, %y, %inc_y) : (!llvm.i64, !llvm<"float*">, !llvm.i64, !llvm<"float*">, !llvm.i64) -> (!llvm.float)
%0 = llvm.extractvalue %arg2[0] : !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">
%old = llvm.load %0 : !llvm<"float*">
%new = llvm.fadd %res, %old : !llvm.float
llvm.store %new, %0 : !llvm<"float*">
return
}
@ -41,18 +62,21 @@ func @entry1() -> f32 {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c16 = constant 16 : index
%f0 = constant 0.00000e+00 : f32
%f1 = constant 0.00000e+00 : f32
%f10 = constant 10.00000e+00 : f32
%f1 = constant 1.00000e+00 : f32
%f2 = constant 2.00000e+00 : f32
%A = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer<f32>)
%B = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer<f32>)
%C = call @alloc_filled_f32(%c1, %f0) : (index, f32) -> (!linalg.buffer<f32>)
%C = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer<f32>)
%res = call @dot(%A, %B, %C) : (!linalg.buffer<f32>, !linalg.buffer<f32>, !linalg.buffer<f32>) -> (f32)
linalg.buffer_dealloc %C : !linalg.buffer<f32>
linalg.buffer_dealloc %B : !linalg.buffer<f32>
linalg.buffer_dealloc %A : !linalg.buffer<f32>
return %res : f32
}
// CHECK: 0.{{0+}}e+00
// CHECK: 4.2{{0+}}e+01

View File

@ -1,3 +1,8 @@
set(LLVM_OPTIONAL_SOURCES
mlir-cpu-runner-lib.cpp
mlir-cpu-runner.cpp
)
set(LIBS
MLIRAffineOps
MLIRAnalysis
@ -12,9 +17,15 @@ set(LIBS
LLVMCore
LLVMSupport
)
add_llvm_library(MLIRCPURunnerLib
mlir-cpu-runner-lib.cpp
)
target_link_libraries(MLIRCPURunnerLib ${LIBS})
add_executable(mlir-cpu-runner
mlir-cpu-runner.cpp
)
llvm_update_compile_flags(mlir-cpu-runner)
whole_archive_link(mlir-cpu-runner MLIRLLVMIR MLIRStandardOps MLIRTargetLLVMIR MLIRTransforms MLIRTranslation)
target_link_libraries(mlir-cpu-runner MLIRIR ${LIBS})
target_link_libraries(mlir-cpu-runner MLIRIR ${LIBS} MLIRCPURunnerLib)

View File

@ -0,0 +1,279 @@
//===- mlir-cpu-runner-lib.cpp - MLIR CPU Execution Driver Library --------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is a command line utility that executes an MLIR file on the CPU by
// translating MLIR to LLVM IR before JIT-compiling and executing the latter.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include <numeric>
using namespace mlir;
using llvm::Error;
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string>
initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
static llvm::cl::opt<std::string>
mainFuncName("e", llvm::cl::desc("The function to be called"),
llvm::cl::value_desc("<function name>"),
llvm::cl::init("main"));
static llvm::cl::opt<std::string> mainFuncType(
"entry-point-result",
llvm::cl::desc("Textual description of the function type to be called"),
llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
static llvm::cl::OptionCategory optFlags("opt-like flags");
// CLI list of pass information
static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"),
llvm::cl::cat(optFlags));
// CLI variables for -On options.
static llvm::cl::opt<bool> optO0("O0", llvm::cl::desc("Run opt O0 passes"),
llvm::cl::cat(optFlags));
static llvm::cl::opt<bool> optO1("O1", llvm::cl::desc("Run opt O1 passes"),
llvm::cl::cat(optFlags));
static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
llvm::cl::cat(optFlags));
static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
llvm::cl::cat(optFlags));
static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
MLIRContext *context) {
// Set up the input file.
std::string errorMessage;
auto file = openInputFile(inputFilename, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return nullptr;
}
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
return std::unique_ptr<Module>(parseSourceFile(sourceMgr, context));
}
// Initialize the relevant subsystems of LLVM.
static void initializeLLVM() {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
}
static inline Error make_string_error(const llvm::Twine &message) {
return llvm::make_error<llvm::StringError>(message.str(),
llvm::inconvertibleErrorCode());
}
static void printOneMemRef(Type t, void *val) {
auto memRefType = t.cast<MemRefType>();
auto shape = memRefType.getShape();
int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
for (int64_t i = 0; i < size; ++i) {
llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
}
llvm::outs() << '\n';
}
static void printMemRefArguments(ArrayRef<Type> argTypes,
ArrayRef<Type> resTypes,
ArrayRef<void *> args) {
auto properArgs = args.take_front(argTypes.size());
for (const auto &kvp : llvm::zip(argTypes, properArgs)) {
auto type = std::get<0>(kvp);
auto val = std::get<1>(kvp);
printOneMemRef(type, val);
}
auto results = args.drop_front(argTypes.size());
for (const auto &kvp : llvm::zip(resTypes, results)) {
auto type = std::get<0>(kvp);
auto val = std::get<1>(kvp);
printOneMemRef(type, val);
}
}
static Error compileAndExecuteFunctionWithMemRefs(
Module *module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
Function *mainFunction = module->getNamedFunction(entryPoint);
if (!mainFunction || mainFunction->getBlocks().empty()) {
return make_string_error("entry point not found");
}
// Store argument and result types of the original function necessary to
// pretty print the results, because the function itself will be rewritten
// to use the LLVM dialect.
SmallVector<Type, 8> argTypes =
llvm::to_vector<8>(mainFunction->getType().getInputs());
SmallVector<Type, 8> resTypes =
llvm::to_vector<8>(mainFunction->getType().getResults());
float init = std::stof(initValue.getValue());
auto expectedArguments = allocateMemRefArguments(mainFunction, init);
if (!expectedArguments)
return expectedArguments.takeError();
auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
if (!expectedEngine)
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
void (*fptr)(void **) = *expectedFPtr;
(*fptr)(expectedArguments->data());
printMemRefArguments(argTypes, resTypes, *expectedArguments);
freeMemRefArguments(*expectedArguments);
return Error::success();
}
static Error compileAndExecuteSingleFloatReturnFunction(
Module *module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
Function *mainFunction = module->getNamedFunction(entryPoint);
if (!mainFunction || mainFunction->isExternal()) {
return make_string_error("entry point not found");
}
if (!mainFunction->getType().getInputs().empty())
return make_string_error("function inputs not supported");
if (mainFunction->getType().getResults().size() != 1)
return make_string_error("only single f32 function result supported");
auto t = mainFunction->getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
if (!t)
return make_string_error("only single llvm.f32 function result supported");
auto *llvmTy = t.getUnderlyingType();
if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
return make_string_error("only single llvm.f32 function result supported");
auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
if (!expectedEngine)
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
void (*fptr)(void **) = *expectedFPtr;
float res;
struct {
void *data;
} data;
data.data = &res;
(*fptr)((void **)&data);
// Intentional printing of the output so we can test.
llvm::outs() << res;
return Error::success();
}
int run(int argc, char **argv) {
llvm::PrettyStackTraceProgram x(argc, argv);
llvm::InitLLVM y(argc, argv);
initializeLLVM();
mlir::initializeLLVMPasses();
llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
optO0, optO1, optO2, optO3};
llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
llvm::SmallVector<const llvm::PassInfo *, 4> passes;
llvm::Optional<unsigned> optLevel;
unsigned optCLIPosition = 0;
// Determine if there is an optimization flag present, and its CLI position
// (optCLIPosition).
for (unsigned j = 0; j < 4; ++j) {
auto &flag = optFlags[j].get();
if (flag) {
optLevel = j;
optCLIPosition = flag.getPosition();
break;
}
}
// Generate vector of pass information, plus the index at which we should
// insert any optimization passes in that vector (optPosition).
unsigned optPosition = 0;
for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
passes.push_back(llvmPasses[i]);
if (optCLIPosition < llvmPasses.getPosition(i)) {
optPosition = i;
optCLIPosition = UINT_MAX; // To ensure we never insert again
}
}
MLIRContext context;
auto m = parseMLIRInput(inputFilename, &context);
if (!m) {
llvm::errs() << "could not parse the input IR\n";
return 1;
}
auto transformer =
mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
auto error = mainFuncType.getValue() == "f32"
? compileAndExecuteSingleFloatReturnFunction(
m.get(), mainFuncName.getValue(), transformer)
: compileAndExecuteFunctionWithMemRefs(
m.get(), mainFuncName.getValue(), transformer);
int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error),
[&exitCode](const llvm::ErrorInfoBase &info) {
llvm::errs() << "Error: ";
info.log(llvm::errs());
llvm::errs() << '\n';
exitCode = EXIT_FAILURE;
});
return exitCode;
}

View File

@ -15,265 +15,11 @@
// limitations under the License.
// =============================================================================
//
// This is a command line utility that executes an MLIR file on the CPU by
// translating MLIR to LLVM IR before JIT-compiling and executing the latter.
// Main entry point to a command line utility that executes an MLIR file on the
// CPU by translating MLIR to LLVM IR before JIT-compiling and executing the
// latter.
//
//===----------------------------------------------------------------------===//
extern int run(int argc, char **argv);
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include <numeric>
using namespace mlir;
using llvm::Error;
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string>
initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
static llvm::cl::opt<std::string>
mainFuncName("e", llvm::cl::desc("The function to be called"),
llvm::cl::value_desc("<function name>"),
llvm::cl::init("main"));
static llvm::cl::opt<std::string> mainFuncType(
"entry-point-result",
llvm::cl::desc("Textual description of the function type to be called"),
llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
static llvm::cl::OptionCategory optFlags("opt-like flags");
// CLI list of pass information
static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"),
llvm::cl::cat(optFlags));
// CLI variables for -On options.
static llvm::cl::opt<bool> optO0("O0", llvm::cl::desc("Run opt O0 passes"),
llvm::cl::cat(optFlags));
static llvm::cl::opt<bool> optO1("O1", llvm::cl::desc("Run opt O1 passes"),
llvm::cl::cat(optFlags));
static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
llvm::cl::cat(optFlags));
static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
llvm::cl::cat(optFlags));
static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
MLIRContext *context) {
// Set up the input file.
std::string errorMessage;
auto file = openInputFile(inputFilename, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return nullptr;
}
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
return std::unique_ptr<Module>(parseSourceFile(sourceMgr, context));
}
// Initialize the relevant subsystems of LLVM.
static void initializeLLVM() {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
}
static inline Error make_string_error(const llvm::Twine &message) {
return llvm::make_error<llvm::StringError>(message.str(),
llvm::inconvertibleErrorCode());
}
static void printOneMemRef(Type t, void *val) {
auto memRefType = t.cast<MemRefType>();
auto shape = memRefType.getShape();
int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
for (int64_t i = 0; i < size; ++i) {
llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
}
llvm::outs() << '\n';
}
static void printMemRefArguments(ArrayRef<Type> argTypes,
ArrayRef<Type> resTypes,
ArrayRef<void *> args) {
auto properArgs = args.take_front(argTypes.size());
for (const auto &kvp : llvm::zip(argTypes, properArgs)) {
auto type = std::get<0>(kvp);
auto val = std::get<1>(kvp);
printOneMemRef(type, val);
}
auto results = args.drop_front(argTypes.size());
for (const auto &kvp : llvm::zip(resTypes, results)) {
auto type = std::get<0>(kvp);
auto val = std::get<1>(kvp);
printOneMemRef(type, val);
}
}
static Error compileAndExecuteFunctionWithMemRefs(
Module *module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
Function *mainFunction = module->getNamedFunction(entryPoint);
if (!mainFunction || mainFunction->getBlocks().empty()) {
return make_string_error("entry point not found");
}
// Store argument and result types of the original function necessary to
// pretty print the results, because the function itself will be rewritten
// to use the LLVM dialect.
SmallVector<Type, 8> argTypes =
llvm::to_vector<8>(mainFunction->getType().getInputs());
SmallVector<Type, 8> resTypes =
llvm::to_vector<8>(mainFunction->getType().getResults());
float init = std::stof(initValue.getValue());
auto expectedArguments = allocateMemRefArguments(mainFunction, init);
if (!expectedArguments)
return expectedArguments.takeError();
auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
if (!expectedEngine)
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
void (*fptr)(void **) = *expectedFPtr;
(*fptr)(expectedArguments->data());
printMemRefArguments(argTypes, resTypes, *expectedArguments);
freeMemRefArguments(*expectedArguments);
return Error::success();
}
static Error compileAndExecuteSingleFloatReturnFunction(
Module *module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
Function *mainFunction = module->getNamedFunction(entryPoint);
if (!mainFunction || mainFunction->isExternal()) {
return make_string_error("entry point not found");
}
if (!mainFunction->getType().getInputs().empty())
return make_string_error("function inputs not supported");
if (mainFunction->getType().getResults().size() != 1)
return make_string_error("only single f32 function result supported");
auto t = mainFunction->getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
if (!t)
return make_string_error("only single llvm.f32 function result supported");
auto *llvmTy = t.getUnderlyingType();
if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
return make_string_error("only single llvm.f32 function result supported");
auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
if (!expectedEngine)
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
void (*fptr)(void **) = *expectedFPtr;
float res;
struct {
void *data;
} data;
data.data = &res;
(*fptr)((void **)&data);
// Intentional printing of the output so we can test.
llvm::outs() << res;
return Error::success();
}
int main(int argc, char **argv) {
llvm::PrettyStackTraceProgram x(argc, argv);
llvm::InitLLVM y(argc, argv);
initializeLLVM();
mlir::initializeLLVMPasses();
llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
optO0, optO1, optO2, optO3};
llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
llvm::SmallVector<const llvm::PassInfo *, 4> passes;
llvm::Optional<unsigned> optLevel;
unsigned optCLIPosition = 0;
// Determine if there is an optimization flag present, and its CLI position
// (optCLIPosition).
for (unsigned j = 0; j < 4; ++j) {
auto &flag = optFlags[j].get();
if (flag) {
optLevel = j;
optCLIPosition = flag.getPosition();
break;
}
}
// Generate vector of pass information, plus the index at which we should
// insert any optimization passes in that vector (optPosition).
unsigned optPosition = 0;
for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
passes.push_back(llvmPasses[i]);
if (optCLIPosition < llvmPasses.getPosition(i)) {
optPosition = i;
optCLIPosition = UINT_MAX; // To ensure we never insert again
}
}
MLIRContext context;
auto m = parseMLIRInput(inputFilename, &context);
if (!m) {
llvm::errs() << "could not parse the input IR\n";
return 1;
}
auto transformer =
mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
auto error = mainFuncType.getValue() == "f32"
? compileAndExecuteSingleFloatReturnFunction(
m.get(), mainFuncName.getValue(), transformer)
: compileAndExecuteFunctionWithMemRefs(
m.get(), mainFuncName.getValue(), transformer);
int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error),
[&exitCode](const llvm::ErrorInfoBase &info) {
llvm::errs() << "Error: ";
info.log(llvm::errs());
llvm::errs() << '\n';
exitCode = EXIT_FAILURE;
});
return exitCode;
}
int main(int argc, char **argv) { return run(argc, argv); }