diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h new file mode 100644 index 000000000000..aefb049b1e6d --- /dev/null +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -0,0 +1,92 @@ +//===- ExecutionEngine.h - MLIR Execution engine and utils -----*- C++ -*--===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file provides a JIT-backed execution engine for MLIR modules. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_ +#define MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/Error.h" + +#include + +namespace llvm { +template class Expected; +} + +namespace mlir { + +class Module; + +namespace impl { +class OrcJIT; +} // end namespace impl + +/// JIT-backed execution engine for MLIR modules. Assumes the module can be +/// converted to LLVM IR. For each function, creates a wrapper function with +/// the fixed interface +/// +/// void _mlir_funcName(void **) +/// +/// where the only argument is interpreted as a list of pointers to the actual +/// arguments of the function, followed by a pointer to the result. This allows +/// the engine to provide the caller with a generic function pointer that can +/// be used to invoke the JIT-compiled function. +class ExecutionEngine { +public: + ~ExecutionEngine(); + + /// Creates an execution engine for the given module. + static llvm::Expected> create(Module *m); + + /// Looks up a packed-argument function with the given name and returns a + /// pointer to it. Propagates errors in case of failure. + llvm::Expected lookup(StringRef name) const; + + /// Invokes the function with the given name passing it the list of arguments. + /// The arguments are accepted by lvalue-reference since the packed function + /// interface expects a list of non-null pointers. + template + llvm::Error invoke(StringRef name, Args &... args); + +private: + // FIXME: we may want a `unique_ptr` here if impl::OrcJIT decides to provide + // a default constructor. + impl::OrcJIT *jit; + llvm::LLVMContext llvmContext; +}; + +template +llvm::Error ExecutionEngine::invoke(StringRef name, Args &... args) { + auto expectedFPtr = lookup(name); + if (!expectedFPtr) + return expectedFPtr.takeError(); + auto fptr = *expectedFPtr; + + llvm::SmallVector packedArgs{static_cast(&args)...}; + (*fptr)(packedArgs.data()); + + return llvm::Error::success(); +} + +} // end namespace mlir + +#endif // MLIR_EXECUTIONENGINE_EXECUTIONENGINE_H_ diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h new file mode 100644 index 000000000000..95e8686bcec9 --- /dev/null +++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h @@ -0,0 +1,55 @@ +//===- MemRefUtils.h - MLIR runtime utilities for memrefs -------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is a set of utilities to working with objects of memref type in an JIT +// context using the MLIR execution engine. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_ +#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_ + +#include "mlir/Support/LLVM.h" + +namespace llvm { +template class Expected; +} + +namespace mlir { + +class Function; + +/// Simple memref descriptor class compatible with the ABI of functions emitted +/// by MLIR to LLVM IR conversion for statically-shaped memrefs of float type. +struct StaticFloatMemRef { + float *data; +}; + +/// Given an MLIR function that takes only statically-shaped memrefs with +/// element type f32, allocate the memref descriptor and the data storage for +/// each of the arguments, initialize the storage with `initialValue`, and +/// return a list of type-erased descriptor pointers. +llvm::Expected> +allocateMemRefArguments(const Function *func, float initialValue = 0.0); + +/// Free a list of type-erased descriptors to statically-shaped memrefs with +/// element type f32. +void freeMemRefArguments(ArrayRef args); + +} // namespace mlir + +#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_ diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp new file mode 100644 index 000000000000..b35e13ab349b --- /dev/null +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -0,0 +1,299 @@ +//===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the execution engine for MLIR modules based on LLVM Orc +// JIT engine. +// +//===----------------------------------------------------------------------===// +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/Pass.h" +#include "mlir/Target/LLVMIR.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/TargetRegistry.h" + +using namespace mlir; +using llvm::Error; +using llvm::Expected; + +namespace { +// Memory manager for the JIT's objectLayer. Its main goal is to fallback to +// resolving functions in the current process if they cannot be resolved in the +// JIT-compiled modules. +class MemoryManager : public llvm::SectionMemoryManager { +public: + MemoryManager(llvm::orc::ExecutionSession &execSession) + : session(execSession) {} + + // Resolve the named symbol. First, try looking it up in the main library of + // the execution session. If there is no such symbol, try looking it up in + // the current process (for example, if it is a standard library function). + // Return `nullptr` if lookup fails. + llvm::JITSymbol findSymbol(const std::string &name) override { + auto mainLibSymbol = session.lookup({&session.getMainJITDylib()}, name); + if (mainLibSymbol) + return mainLibSymbol.get(); + auto address = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name); + if (!address) { + llvm::errs() << "Could not look up: " << name << '\n'; + return nullptr; + } + return llvm::JITSymbol(address, llvm::JITSymbolFlags::Exported); + } + +private: + llvm::orc::ExecutionSession &session; +}; +} // end anonymous namespace + +namespace mlir { +namespace impl { +// Simple layered Orc JIT compilation engine. +class OrcJIT { +public: + // Construct a JIT engine for the target host defined by `machineBuilder`, + // using the data layout provided as `dataLayout`. + // Setup the object layer to use our custom memory manager in order to resolve + // calls to library functions present in the process. + OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder, + llvm::DataLayout layout) + : objectLayer( + session, + [this]() { return llvm::make_unique(session); }), + compileLayer( + session, objectLayer, + llvm::orc::ConcurrentIRCompiler(std::move(machineBuilder))), + dataLayout(layout), mangler(session, this->dataLayout), + threadSafeCtx(llvm::make_unique()) { + session.getMainJITDylib().setGenerator( + cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( + layout))); + } + + // Create a JIT engine for the current host. + static Expected> createDefault() { + auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!machineBuilder) + return machineBuilder.takeError(); + + auto dataLayout = machineBuilder->getDefaultDataLayoutForTarget(); + if (!dataLayout) + return dataLayout.takeError(); + + return llvm::make_unique(std::move(*machineBuilder), + std::move(*dataLayout)); + } + + // Add an LLVM module to the main library managed by the JIT engine. + Error addModule(std::unique_ptr M) { + return compileLayer.add( + session.getMainJITDylib(), + llvm::orc::ThreadSafeModule(std::move(M), threadSafeCtx)); + } + + // Lookup a symbol in the main library managed by the JIT engine. + Expected lookup(StringRef Name) { + return session.lookup({&session.getMainJITDylib()}, mangler(Name.str())); + } + +private: + llvm::orc::ExecutionSession session; + llvm::orc::RTDyldObjectLinkingLayer objectLayer; + llvm::orc::IRCompileLayer compileLayer; + llvm::DataLayout dataLayout; + llvm::orc::MangleAndInterner mangler; + llvm::orc::ThreadSafeContext threadSafeCtx; +}; +} // end namespace impl +} // namespace mlir + +// Wrap a string into an llvm::StringError. +static inline Error make_string_error(const llvm::Twine &message) { + return llvm::make_error(message.str(), + llvm::inconvertibleErrorCode()); +} + +// Given a list of PassInfo coming from a higher level, creates the passes to +// run as an owning vector and appends the extra required passes to lower to +// LLVMIR. Currently, these extra passes are: +// - constant folding +// - CSE +// - canonicalization +// - affine lowering +static std::vector> +getDefaultPasses(const std::vector &mlirPassInfoList) { + std::vector> passList; + passList.reserve(mlirPassInfoList.size() + 4); + // Run each of the passes that were selected. + for (const auto *passInfo : mlirPassInfoList) { + passList.emplace_back(passInfo->createPass()); + } + // Append the extra passes for lowering to MLIR. + passList.emplace_back(mlir::createConstantFoldPass()); + passList.emplace_back(mlir::createCSEPass()); + passList.emplace_back(mlir::createCanonicalizerPass()); + passList.emplace_back(mlir::createLowerAffinePass()); + return passList; +} + +// Run the passes sequentially on the given module. +// Return `nullptr` immediately if any of the passes fails. +static bool runPasses(const std::vector> &passes, + Module *module) { + for (const auto &pass : passes) { + mlir::PassResult result = pass->runOnModule(module); + if (result == mlir::PassResult::Failure || module->verify()) { + llvm::errs() << "Pass failed\n"; + return true; + } + } + return false; +} + +// Setup LLVM target triple from the current machine. +static bool setupTargetTriple(llvm::Module *llvmModule) { + // Setup the machine properties from the current architecture. + auto targetTriple = llvm::sys::getDefaultTargetTriple(); + std::string errorMessage; + auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage); + if (!target) { + llvm::errs() << "NO target: " << errorMessage << "\n"; + return true; + } + auto machine = + target->createTargetMachine(targetTriple, "generic", "", {}, {}); + llvmModule->setDataLayout(machine->createDataLayout()); + llvmModule->setTargetTriple(targetTriple); + return false; +} + +static std::string makePackedFunctionName(StringRef name) { + return "_mlir_" + name.str(); +} + +// For each function in the LLVM module, define an interface function that wraps +// all the arguments of the original function and all its results into an i8** +// pointer to provide a unified invocation interface. +void packFunctionArguments(llvm::Module *module) { + auto &ctx = module->getContext(); + llvm::IRBuilder<> builder(ctx); + llvm::DenseSet interfaceFunctions; + for (auto &func : module->getFunctionList()) { + if (func.isDeclaration()) { + continue; + } + if (interfaceFunctions.count(&func)) { + continue; + } + + // Given a function `foo(<...>)`, define the interface function + // `mlir_foo(i8**)`. + auto newType = llvm::FunctionType::get( + builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(), + /*isVarArg=*/false); + auto newName = makePackedFunctionName(func.getName()); + llvm::Constant *funcCst = module->getOrInsertFunction(newName, newType); + llvm::Function *interfaceFunc = llvm::cast(funcCst); + interfaceFunctions.insert(interfaceFunc); + + // Extract the arguments from the type-erased argument list and cast them to + // the proper types. + auto bb = llvm::BasicBlock::Create(ctx); + bb->insertInto(interfaceFunc); + builder.SetInsertPoint(bb); + llvm::Value *argList = interfaceFunc->arg_begin(); + llvm::SmallVector args; + args.reserve(llvm::size(func.args())); + for (auto &indexedArg : llvm::enumerate(func.args())) { + llvm::Value *argIndex = llvm::Constant::getIntegerValue( + builder.getInt64Ty(), llvm::APInt(64, indexedArg.index())); + llvm::Value *argPtrPtr = builder.CreateGEP(argList, argIndex); + llvm::Value *argPtr = builder.CreateLoad(argPtrPtr); + argPtr = builder.CreateBitCast( + argPtr, indexedArg.value().getType()->getPointerTo()); + llvm::Value *arg = builder.CreateLoad(argPtr); + args.push_back(arg); + } + + // Call the implementation function with the extracted arguments. + llvm::Value *result = builder.CreateCall(&func, args); + + // Assuming the result is one value, potentially of type `void`. + if (!result->getType()->isVoidTy()) { + llvm::Value *retIndex = llvm::Constant::getIntegerValue( + builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args()))); + llvm::Value *retPtrPtr = builder.CreateGEP(argList, retIndex); + llvm::Value *retPtr = builder.CreateLoad(retPtrPtr); + retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo()); + builder.CreateStore(result, retPtr); + } + + // The interface function returns void. + builder.CreateRetVoid(); + } +} + +ExecutionEngine::~ExecutionEngine() { + if (jit) + delete jit; +} + +Expected> ExecutionEngine::create(Module *m) { + auto engine = llvm::make_unique(); + auto expectedJIT = impl::OrcJIT::createDefault(); + if (!expectedJIT) + return expectedJIT.takeError(); + + if (runPasses(getDefaultPasses({}), m)) + return make_string_error("passes failed"); + + auto llvmModule = convertModuleToLLVMIR(*m, engine->llvmContext); + if (!llvmModule) + return make_string_error("could not convert to LLVM IR"); + // FIXME: the triple should be passed to the translation or dialect conversion + // instead of this. Currently, the LLVM module created above has no triple + // associated with it. + setupTargetTriple(llvmModule.get()); + packFunctionArguments(llvmModule.get()); + + engine->jit = std::move(*expectedJIT).release(); + if (auto err = engine->jit->addModule(std::move(llvmModule))) + return std::move(err); + + return engine; +} + +Expected ExecutionEngine::lookup(StringRef name) const { + auto expectedSymbol = jit->lookup(makePackedFunctionName(name)); + if (!expectedSymbol) + return expectedSymbol.takeError(); + auto rawFPtr = expectedSymbol->getAddress(); + auto fptr = reinterpret_cast(rawFPtr); + if (!fptr) + return make_string_error("looked up function is null"); + return fptr; +} diff --git a/mlir/lib/ExecutionEngine/MemRefUtils.cpp b/mlir/lib/ExecutionEngine/MemRefUtils.cpp new file mode 100644 index 000000000000..b2b301d9d20d --- /dev/null +++ b/mlir/lib/ExecutionEngine/MemRefUtils.cpp @@ -0,0 +1,106 @@ +//===- MemRefUtils.cpp - MLIR runtime utilities for memrefs ---------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is a set of utilities to working with objects of memref type in an JIT +// context using the MLIR execution engine. +// +//===----------------------------------------------------------------------===// + +#include "mlir/ExecutionEngine/MemRefUtils.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/Support/Error.h" + +using namespace mlir; + +static inline llvm::Error make_string_error(const llvm::Twine &message) { + return llvm::make_error(message.str(), + llvm::inconvertibleErrorCode()); +} + +static llvm::Expected +allocMemRefDescriptor(Type type, bool allocateData = true, + float initialValue = 0.0) { + auto memRefType = type.dyn_cast(); + if (!memRefType) + return make_string_error("non-memref argument not supported"); + if (memRefType.getNumDynamicDims() != 0) + return make_string_error("memref with dynamic shapes not supported"); + + auto elementType = memRefType.getElementType(); + if (!elementType.isF32()) + return make_string_error( + "memref with element other than f32 not supported"); + + auto *descriptor = + reinterpret_cast(malloc(sizeof(StaticFloatMemRef))); + if (!allocateData) { + descriptor->data = nullptr; + return descriptor; + } + + auto shape = memRefType.getShape(); + int64_t size = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); + descriptor->data = reinterpret_cast(malloc(sizeof(float) * size)); + for (int64_t i = 0; i < size; ++i) { + descriptor->data[i] = initialValue; + } + return descriptor; +} + +llvm::Expected> +mlir::allocateMemRefArguments(const Function *func, float initialValue) { + SmallVector args; + args.reserve(func->getNumArguments()); + for (const auto &arg : func->getArguments()) { + auto descriptor = + allocMemRefDescriptor(arg->getType(), + /*allocateData=*/true, initialValue); + if (!descriptor) + return descriptor.takeError(); + args.push_back(*descriptor); + } + + if (func->getType().getNumResults() > 1) + return make_string_error("functions with more than 1 result not supported"); + + for (Type resType : func->getType().getResults()) { + auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false); + if (!descriptor) + return descriptor.takeError(); + args.push_back(*descriptor); + } + + return args; +} + +// Because the function can return the same descriptor as passed in arguments, +// we check that we don't attempt to free the underlying data twice. +void mlir::freeMemRefArguments(ArrayRef args) { + llvm::DenseSet dataPointers; + for (void *arg : args) { + float *dataPtr = reinterpret_cast(arg)->data; + if (dataPointers.count(dataPtr) == 0) { + free(dataPtr); + dataPointers.insert(dataPtr); + } + free(arg); + } +} diff --git a/mlir/test/mlir-cpu-runner/simple.mlir b/mlir/test/mlir-cpu-runner/simple.mlir new file mode 100644 index 000000000000..8ad76ccac36d --- /dev/null +++ b/mlir/test/mlir-cpu-runner/simple.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-cpu-runner %s | FileCheck %s +// RUN: mlir-cpu-runner -e foo -init-value 1000 %s | FileCheck -check-prefix=NOMAIN %s + +func @fabsf(f32) -> f32 + +func @main(%a : memref<2xf32>, %b : memref<1xf32>) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = constant -420.0 : f32 + %1 = load %a[%c0] : memref<2xf32> + %2 = load %a[%c1] : memref<2xf32> + %3 = addf %0, %1 : f32 + %4 = addf %3, %2 : f32 + %5 = call @fabsf(%4) : (f32) -> f32 + store %5, %b[%c0] : memref<1xf32> + return +} +// CHECK: 0.000000e+00 0.000000e+00 +// CHECK-NEXT: 4.200000e+02 + +func @foo(%a : memref<1x1xf32>) -> memref<1x1xf32> { + %c0 = constant 0 : index + %0 = constant 1234.0 : f32 + %1 = load %a[%c0, %c0] : memref<1x1xf32> + %2 = addf %1, %0 : f32 + store %2, %a[%c0, %c0] : memref<1x1xf32> + return %a : memref<1x1xf32> +} +// NOMAIN: 2.234000e+03 +// NOMAIN-NEXT: 2.234000e+03 \ No newline at end of file diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp new file mode 100644 index 000000000000..f338c1386f7f --- /dev/null +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp @@ -0,0 +1,162 @@ +//===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is a command line utility that executes an MLIR file on the CPU by +// translating MLIR to LLVM IR before JIT-compiling and executing the latter. +// +//===----------------------------------------------------------------------===// + +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/MemRefUtils.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Parser.h" +#include "mlir/Support/FileUtilities.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileUtilities.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/ToolOutputFile.h" + +using namespace mlir; +using llvm::Error; + +static llvm::cl::opt inputFilename(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); +static llvm::cl::opt + initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"), + llvm::cl::value_desc(""), llvm::cl::init("0.0")); +static llvm::cl::opt + mainFuncName("e", llvm::cl::desc("The function to be called"), + llvm::cl::value_desc(""), + llvm::cl::init("main")); + +static std::unique_ptr parseMLIRInput(StringRef inputFilename, + MLIRContext *context) { + // Set up the input file. + std::string errorMessage; + auto file = openInputFile(inputFilename, &errorMessage); + if (!file) { + llvm::errs() << errorMessage << "\n"; + return nullptr; + } + + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); + return std::unique_ptr(parseSourceFile(sourceMgr, context)); +} + +// Initialize the relevant subsystems of LLVM. +static void initializeLLVM() { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); +} + +static inline Error make_string_error(const llvm::Twine &message) { + return llvm::make_error(message.str(), + llvm::inconvertibleErrorCode()); +} + +static void printOneMemRef(Type t, void *val) { + auto memRefType = t.cast(); + auto shape = memRefType.getShape(); + int64_t size = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); + for (int64_t i = 0; i < size; ++i) { + llvm::outs() << reinterpret_cast(val)->data[i] << ' '; + } + llvm::outs() << '\n'; +} + +static void printMemRefArguments(const Function *func, ArrayRef args) { + auto properArgs = args.take_front(func->getNumArguments()); + for (const auto &kvp : llvm::zip(func->getArguments(), properArgs)) { + auto arg = std::get<0>(kvp); + auto val = std::get<1>(kvp); + printOneMemRef(arg->getType(), val); + } + + auto results = args.drop_front(func->getNumArguments()); + for (const auto &kvp : llvm::zip(func->getType().getResults(), results)) { + auto type = std::get<0>(kvp); + auto val = std::get<1>(kvp); + printOneMemRef(type, val); + } +} + +static Error compileAndExecute(Module *module, StringRef entryPoint) { + Function *mainFunction = module->getNamedFunction(entryPoint); + if (!mainFunction || mainFunction->getBlocks().empty()) { + return make_string_error("entry point not found"); + } + + float init = std::stof(initValue.getValue()); + + auto expectedArguments = allocateMemRefArguments(mainFunction, init); + if (!expectedArguments) + return expectedArguments.takeError(); + + auto expectedEngine = mlir::ExecutionEngine::create(module); + if (!expectedEngine) + return expectedEngine.takeError(); + + auto engine = std::move(*expectedEngine); + auto expectedFPtr = engine->lookup(entryPoint); + if (!expectedFPtr) + return expectedFPtr.takeError(); + void (*fptr)(void **) = *expectedFPtr; + (*fptr)(expectedArguments->data()); + printMemRefArguments(mainFunction, *expectedArguments); + freeMemRefArguments(*expectedArguments); + + return Error::success(); +} + +int main(int argc, char **argv) { + llvm::PrettyStackTraceProgram x(argc, argv); + llvm::InitLLVM y(argc, argv); + + llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); + + initializeLLVM(); + + MLIRContext context; + auto m = parseMLIRInput(inputFilename, &context); + if (!m) { + llvm::errs() << "could not parse the input IR\n"; + return 1; + } + auto error = compileAndExecute(m.get(), mainFuncName.getValue()); + int exitCode = EXIT_SUCCESS; + llvm::handleAllErrors(std::move(error), + [&exitCode](const llvm::ErrorInfoBase &info) { + llvm::errs() << "Error: "; + info.log(llvm::errs()); + llvm::errs() << '\n'; + exitCode = EXIT_FAILURE; + }); + + return exitCode; +}