forked from OSchip/llvm-project
NFC: Refactor Function to be value typed.
Move the data members out of Function and into a new impl storage class 'FunctionStorage'. This allows for Function to become value typed, which will greatly simplify the transition of Function to FuncOp(given that FuncOp is also value typed). PiperOrigin-RevId: 255983022
This commit is contained in:
parent
84bd67fc4f
commit
54cd6a7e97
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
@ -110,13 +111,14 @@ struct PythonValueHandle {
|
|||
struct PythonFunction {
|
||||
PythonFunction() : function{nullptr} {}
|
||||
PythonFunction(mlir_func_t f) : function{f} {}
|
||||
PythonFunction(mlir::Function *f) : function{f} {}
|
||||
PythonFunction(mlir::Function f)
|
||||
: function(const_cast<void *>(f.getAsOpaquePointer())) {}
|
||||
operator mlir_func_t() { return function; }
|
||||
std::string str() {
|
||||
mlir::Function *f = reinterpret_cast<mlir::Function *>(function);
|
||||
mlir::Function f = mlir::Function::getFromOpaquePointer(function);
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
f->print(os);
|
||||
f.print(os);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -124,18 +126,18 @@ struct PythonFunction {
|
|||
// declaration, add the entry block, transforming the declaration into a
|
||||
// definition. Return true if the block was added, false otherwise.
|
||||
bool define() {
|
||||
auto *f = reinterpret_cast<mlir::Function *>(function);
|
||||
if (!f->getBlocks().empty())
|
||||
auto f = mlir::Function::getFromOpaquePointer(function);
|
||||
if (!f.getBlocks().empty())
|
||||
return false;
|
||||
|
||||
f->addEntryBlock();
|
||||
f.addEntryBlock();
|
||||
return true;
|
||||
}
|
||||
|
||||
PythonValueHandle arg(unsigned index) {
|
||||
Function *f = static_cast<Function *>(function);
|
||||
assert(index < f->getNumArguments() && "argument index out of bounds");
|
||||
return PythonValueHandle(ValueHandle(f->getArgument(index)));
|
||||
auto f = mlir::Function::getFromOpaquePointer(function);
|
||||
assert(index < f.getNumArguments() && "argument index out of bounds");
|
||||
return PythonValueHandle(ValueHandle(f.getArgument(index)));
|
||||
}
|
||||
|
||||
mlir_func_t function;
|
||||
|
@ -250,10 +252,9 @@ struct PythonFunctionContext {
|
|||
|
||||
PythonFunction enter() {
|
||||
assert(function.function && "function is not set up");
|
||||
auto *mlirFunc = static_cast<mlir::Function *>(function.function);
|
||||
contextBuilder.emplace(mlirFunc->getBody());
|
||||
context =
|
||||
new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc->getLoc());
|
||||
auto mlirFunc = mlir::Function::getFromOpaquePointer(function.function);
|
||||
contextBuilder.emplace(mlirFunc.getBody());
|
||||
context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc());
|
||||
return function;
|
||||
}
|
||||
|
||||
|
@ -594,7 +595,7 @@ PythonMLIRModule::declareFunction(const std::string &name,
|
|||
}
|
||||
|
||||
// Create the function itself.
|
||||
auto *func = new mlir::Function(
|
||||
auto func = mlir::Function::create(
|
||||
UnknownLoc::get(&mlirContext), name,
|
||||
mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
|
||||
inputAttrs);
|
||||
|
@ -652,9 +653,9 @@ PYBIND11_MODULE(pybind, m) {
|
|||
return ValueHandle::create<ConstantFloatOp>(value, floatType);
|
||||
});
|
||||
m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
|
||||
auto *function = reinterpret_cast<Function *>(func.function);
|
||||
auto attr = FunctionAttr::get(function);
|
||||
return ValueHandle::create<ConstantOp>(function->getType(), attr);
|
||||
auto function = Function::getFromOpaquePointer(func.function);
|
||||
auto attr = FunctionAttr::get(function.getName(), function.getContext());
|
||||
return ValueHandle::create<ConstantOp>(function.getType(), attr);
|
||||
});
|
||||
m.def("appendTo", [](const PythonBlockHandle &handle) {
|
||||
return PythonBlockAppender(handle);
|
||||
|
|
|
@ -57,15 +57,15 @@ inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context,
|
|||
}
|
||||
|
||||
/// A basic function builder
|
||||
inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name,
|
||||
llvm::ArrayRef<mlir::Type> types,
|
||||
llvm::ArrayRef<mlir::Type> resultTypes) {
|
||||
inline mlir::Function makeFunction(mlir::Module &module, llvm::StringRef name,
|
||||
llvm::ArrayRef<mlir::Type> types,
|
||||
llvm::ArrayRef<mlir::Type> resultTypes) {
|
||||
auto *context = module.getContext();
|
||||
auto *function = new mlir::Function(
|
||||
auto function = mlir::Function::create(
|
||||
mlir::UnknownLoc::get(context), name,
|
||||
mlir::FunctionType::get({types}, resultTypes, context));
|
||||
function->addEntryBlock();
|
||||
module.getFunctions().push_back(function);
|
||||
function.addEntryBlock();
|
||||
module.push_back(function);
|
||||
return function;
|
||||
}
|
||||
|
||||
|
@ -83,19 +83,19 @@ inline std::unique_ptr<mlir::PassManager> cleanupPassManager() {
|
|||
/// llvm::outs() for FileCheck'ing.
|
||||
/// If an error occurs, dump to llvm::errs() and do not print to llvm::outs()
|
||||
/// which will make the associated FileCheck test fail.
|
||||
inline void cleanupAndPrintFunction(mlir::Function *f) {
|
||||
inline void cleanupAndPrintFunction(mlir::Function f) {
|
||||
bool printToOuts = true;
|
||||
auto check = [f, &printToOuts](mlir::LogicalResult result) {
|
||||
auto check = [&f, &printToOuts](mlir::LogicalResult result) {
|
||||
if (failed(result)) {
|
||||
f->emitError("Verification and cleanup passes failed");
|
||||
f.emitError("Verification and cleanup passes failed");
|
||||
printToOuts = false;
|
||||
}
|
||||
};
|
||||
auto pm = cleanupPassManager();
|
||||
check(f->getModule()->verify());
|
||||
check(pm->run(f->getModule()));
|
||||
check(f.getModule()->verify());
|
||||
check(pm->run(f.getModule()));
|
||||
if (printToOuts)
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
}
|
||||
|
||||
/// Helper class to sugar building loop nests from indexings that appear in
|
||||
|
|
|
@ -36,14 +36,14 @@ TEST_FUNC(linalg_ops) {
|
|||
MLIRContext context;
|
||||
Module module(&context);
|
||||
auto indexType = mlir::IndexType::get(&context);
|
||||
mlir::Function *f =
|
||||
mlir::Function f =
|
||||
makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
|
||||
// clang-format off
|
||||
ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
|
||||
ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)),
|
||||
rM = range(constant_index(0), M, constant_index(1)),
|
||||
rN = range(constant_index(0), N, constant_index(1)),
|
||||
rK = range(constant_index(0), K, constant_index(1)),
|
||||
|
@ -75,14 +75,14 @@ TEST_FUNC(linalg_ops_folded_slices) {
|
|||
MLIRContext context;
|
||||
Module module(&context);
|
||||
auto indexType = mlir::IndexType::get(&context);
|
||||
mlir::Function *f = makeFunction(module, "linalg_ops_folded_slices",
|
||||
{indexType, indexType, indexType}, {});
|
||||
mlir::Function f = makeFunction(module, "linalg_ops_folded_slices",
|
||||
{indexType, indexType, indexType}, {});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
|
||||
// clang-format off
|
||||
ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
|
||||
ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)),
|
||||
rM = range(constant_index(0), M, constant_index(1)),
|
||||
rN = range(constant_index(0), N, constant_index(1)),
|
||||
rK = range(constant_index(0), K, constant_index(1)),
|
||||
|
@ -104,7 +104,7 @@ TEST_FUNC(linalg_ops_folded_slices) {
|
|||
// CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view<f32>
|
||||
// clang-format on
|
||||
|
||||
f->walk<SliceOp>([](SliceOp slice) {
|
||||
f.walk<SliceOp>([](SliceOp slice) {
|
||||
auto *sliceResult = slice.getResult();
|
||||
auto viewOp = emitAndReturnFullyComposedView(sliceResult);
|
||||
sliceResult->replaceAllUsesWith(viewOp.getResult());
|
||||
|
|
|
@ -37,26 +37,26 @@ using namespace linalg;
|
|||
using namespace linalg::common;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
MLIRContext *context = module.getContext();
|
||||
auto dynamic2DMemRefType = floatMemRefType<2>(context);
|
||||
mlir::Function *f = linalg::common::makeFunction(
|
||||
mlir::Function f = linalg::common::makeFunction(
|
||||
module, name,
|
||||
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
// clang-format off
|
||||
ValueHandle
|
||||
M = dim(f->getArgument(0), 0),
|
||||
N = dim(f->getArgument(2), 1),
|
||||
K = dim(f->getArgument(0), 1),
|
||||
M = dim(f.getArgument(0), 0),
|
||||
N = dim(f.getArgument(2), 1),
|
||||
K = dim(f.getArgument(0), 1),
|
||||
rM = range(constant_index(0), M, constant_index(1)),
|
||||
rN = range(constant_index(0), N, constant_index(1)),
|
||||
rK = range(constant_index(0), K, constant_index(1)),
|
||||
vA = view(f->getArgument(0), {rM, rK}),
|
||||
vB = view(f->getArgument(1), {rK, rN}),
|
||||
vC = view(f->getArgument(2), {rM, rN});
|
||||
vA = view(f.getArgument(0), {rM, rK}),
|
||||
vB = view(f.getArgument(1), {rK, rN}),
|
||||
vC = view(f.getArgument(2), {rM, rN});
|
||||
matmul(vA, vB, vC);
|
||||
ret();
|
||||
// clang-format on
|
||||
|
@ -67,7 +67,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
|||
TEST_FUNC(foo) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
|
||||
lowerToLoops(f);
|
||||
|
||||
convertLinalg3ToLLVM(module);
|
||||
|
|
|
@ -34,26 +34,26 @@ using namespace linalg;
|
|||
using namespace linalg::common;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
MLIRContext *context = module.getContext();
|
||||
auto dynamic2DMemRefType = floatMemRefType<2>(context);
|
||||
mlir::Function *f = linalg::common::makeFunction(
|
||||
mlir::Function f = linalg::common::makeFunction(
|
||||
module, name,
|
||||
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
|
||||
|
||||
mlir::OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
mlir::OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
// clang-format off
|
||||
ValueHandle
|
||||
M = dim(f->getArgument(0), 0),
|
||||
N = dim(f->getArgument(2), 1),
|
||||
K = dim(f->getArgument(0), 1),
|
||||
M = dim(f.getArgument(0), 0),
|
||||
N = dim(f.getArgument(2), 1),
|
||||
K = dim(f.getArgument(0), 1),
|
||||
rM = range(constant_index(0), M, constant_index(1)),
|
||||
rN = range(constant_index(0), N, constant_index(1)),
|
||||
rK = range(constant_index(0), K, constant_index(1)),
|
||||
vA = view(f->getArgument(0), {rM, rK}),
|
||||
vB = view(f->getArgument(1), {rK, rN}),
|
||||
vC = view(f->getArgument(2), {rM, rN});
|
||||
vA = view(f.getArgument(0), {rM, rK}),
|
||||
vB = view(f.getArgument(1), {rK, rN}),
|
||||
vC = view(f.getArgument(2), {rM, rN});
|
||||
matmul(vA, vB, vC);
|
||||
ret();
|
||||
// clang-format on
|
||||
|
@ -64,7 +64,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
|||
TEST_FUNC(matmul_as_matvec) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
composeSliceOps(f);
|
||||
// clang-format off
|
||||
|
@ -82,7 +82,7 @@ TEST_FUNC(matmul_as_matvec) {
|
|||
TEST_FUNC(matmul_as_dot) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
composeSliceOps(f);
|
||||
|
@ -103,7 +103,7 @@ TEST_FUNC(matmul_as_dot) {
|
|||
TEST_FUNC(matmul_as_loops) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
|
||||
lowerToLoops(f);
|
||||
composeSliceOps(f);
|
||||
// clang-format off
|
||||
|
@ -135,7 +135,7 @@ TEST_FUNC(matmul_as_loops) {
|
|||
TEST_FUNC(matmul_as_matvec_as_loops) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f =
|
||||
mlir::Function f =
|
||||
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
lowerToLoops(f);
|
||||
|
@ -166,14 +166,14 @@ TEST_FUNC(matmul_as_matvec_as_loops) {
|
|||
TEST_FUNC(matmul_as_matvec_as_affine) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f =
|
||||
mlir::Function f =
|
||||
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine");
|
||||
lowerToFinerGrainedTensorContraction(f);
|
||||
composeSliceOps(f);
|
||||
lowerToLoops(f);
|
||||
PassManager pm;
|
||||
pm.addPass(createLowerLinalgLoadStorePass());
|
||||
if (succeeded(pm.run(f->getModule())))
|
||||
if (succeeded(pm.run(f.getModule())))
|
||||
cleanupAndPrintFunction(f);
|
||||
|
||||
// clang-format off
|
||||
|
|
|
@ -37,26 +37,26 @@ using namespace linalg;
|
|||
using namespace linalg::common;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
MLIRContext *context = module.getContext();
|
||||
auto dynamic2DMemRefType = floatMemRefType<2>(context);
|
||||
mlir::Function *f = linalg::common::makeFunction(
|
||||
mlir::Function f = linalg::common::makeFunction(
|
||||
module, name,
|
||||
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
|
||||
|
||||
mlir::OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
mlir::OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
// clang-format off
|
||||
ValueHandle
|
||||
M = dim(f->getArgument(0), 0),
|
||||
N = dim(f->getArgument(2), 1),
|
||||
K = dim(f->getArgument(0), 1),
|
||||
M = dim(f.getArgument(0), 0),
|
||||
N = dim(f.getArgument(2), 1),
|
||||
K = dim(f.getArgument(0), 1),
|
||||
rM = range(constant_index(0), M, constant_index(1)),
|
||||
rN = range(constant_index(0), N, constant_index(1)),
|
||||
rK = range(constant_index(0), K, constant_index(1)),
|
||||
vA = view(f->getArgument(0), {rM, rK}),
|
||||
vB = view(f->getArgument(1), {rK, rN}),
|
||||
vC = view(f->getArgument(2), {rM, rN});
|
||||
vA = view(f.getArgument(0), {rM, rK}),
|
||||
vB = view(f.getArgument(1), {rK, rN}),
|
||||
vC = view(f.getArgument(2), {rM, rN});
|
||||
matmul(vA, vB, vC);
|
||||
ret();
|
||||
// clang-format on
|
||||
|
@ -110,7 +110,7 @@ TEST_FUNC(execution) {
|
|||
// dialect through partial conversions.
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
|
||||
lowerToLoops(f);
|
||||
convertLinalg3ToLLVM(module);
|
||||
|
||||
|
|
|
@ -55,11 +55,11 @@ makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps,
|
|||
|
||||
/// Traverses `f` and rewrites linalg.slice, and the operations it depends on,
|
||||
/// to only use linalg.view operations.
|
||||
void composeSliceOps(mlir::Function *f);
|
||||
void composeSliceOps(mlir::Function f);
|
||||
|
||||
/// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec)
|
||||
/// as linalg.matvec(resp. linalg.dot).
|
||||
void lowerToFinerGrainedTensorContraction(mlir::Function *f);
|
||||
void lowerToFinerGrainedTensorContraction(mlir::Function f);
|
||||
|
||||
/// Operation-wise writing of linalg operations to loop form.
|
||||
/// It is the caller's responsibility to erase the `op` if necessary.
|
||||
|
@ -69,7 +69,7 @@ llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 4>>
|
|||
writeAsLoops(mlir::Operation *op);
|
||||
|
||||
/// Traverses `f` and rewrites linalg operations in loop form.
|
||||
void lowerToLoops(mlir::Function *f);
|
||||
void lowerToLoops(mlir::Function f);
|
||||
|
||||
/// Creates a pass that rewrites linalg.load and linalg.store to affine.load and
|
||||
/// affine.store operations.
|
||||
|
|
|
@ -148,7 +148,7 @@ static void populateLinalg3ToLLVMConversionPatterns(
|
|||
|
||||
void linalg::convertLinalg3ToLLVM(Module &module) {
|
||||
// Remove affine constructs.
|
||||
for (auto &func : module) {
|
||||
for (auto func : module) {
|
||||
auto rr = lowerAffineConstructs(func);
|
||||
(void)rr;
|
||||
assert(succeeded(rr) && "affine loop lowering failed");
|
||||
|
|
|
@ -35,8 +35,8 @@ using namespace mlir::edsc::intrinsics;
|
|||
using namespace linalg;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
void linalg::composeSliceOps(mlir::Function *f) {
|
||||
f->walk<SliceOp>([](SliceOp sliceOp) {
|
||||
void linalg::composeSliceOps(mlir::Function f) {
|
||||
f.walk<SliceOp>([](SliceOp sliceOp) {
|
||||
auto *sliceResult = sliceOp.getResult();
|
||||
auto viewOp = emitAndReturnFullyComposedView(sliceResult);
|
||||
sliceResult->replaceAllUsesWith(viewOp.getResult());
|
||||
|
@ -44,8 +44,8 @@ void linalg::composeSliceOps(mlir::Function *f) {
|
|||
});
|
||||
}
|
||||
|
||||
void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
|
||||
f->walk([](Operation *op) {
|
||||
void linalg::lowerToFinerGrainedTensorContraction(mlir::Function f) {
|
||||
f.walk([](Operation *op) {
|
||||
if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
|
||||
matmulOp.writeAsFinerGrainTensorContraction();
|
||||
} else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
|
||||
|
@ -211,8 +211,8 @@ linalg::writeAsLoops(Operation *op) {
|
|||
return llvm::None;
|
||||
}
|
||||
|
||||
void linalg::lowerToLoops(mlir::Function *f) {
|
||||
f->walk([](Operation *op) {
|
||||
void linalg::lowerToLoops(mlir::Function f) {
|
||||
f.walk([](Operation *op) {
|
||||
if (writeAsLoops(op))
|
||||
op->erase();
|
||||
});
|
||||
|
|
|
@ -34,27 +34,27 @@ using namespace linalg;
|
|||
using namespace linalg::common;
|
||||
using namespace linalg::intrinsics;
|
||||
|
||||
Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
||||
MLIRContext *context = module.getContext();
|
||||
auto dynamic2DMemRefType = floatMemRefType<2>(context);
|
||||
mlir::Function *f = linalg::common::makeFunction(
|
||||
mlir::Function f = linalg::common::makeFunction(
|
||||
module, name,
|
||||
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
|
||||
// clang-format off
|
||||
ValueHandle
|
||||
M = dim(f->getArgument(0), 0),
|
||||
N = dim(f->getArgument(2), 1),
|
||||
K = dim(f->getArgument(0), 1),
|
||||
M = dim(f.getArgument(0), 0),
|
||||
N = dim(f.getArgument(2), 1),
|
||||
K = dim(f.getArgument(0), 1),
|
||||
rM = range(constant_index(0), M, constant_index(1)),
|
||||
rN = range(constant_index(0), N, constant_index(1)),
|
||||
rK = range(constant_index(0), K, constant_index(1)),
|
||||
vA = view(f->getArgument(0), {rM, rK}),
|
||||
vB = view(f->getArgument(1), {rK, rN}),
|
||||
vC = view(f->getArgument(2), {rM, rN});
|
||||
vA = view(f.getArgument(0), {rM, rK}),
|
||||
vB = view(f.getArgument(1), {rK, rN}),
|
||||
vC = view(f.getArgument(2), {rM, rN});
|
||||
matmul(vA, vB, vC);
|
||||
ret();
|
||||
// clang-format on
|
||||
|
@ -65,11 +65,11 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
|
|||
TEST_FUNC(matmul_tiled_loops) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops");
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops");
|
||||
lowerToTiledLoops(f, {8, 9});
|
||||
PassManager pm;
|
||||
pm.addPass(createLowerLinalgLoadStorePass());
|
||||
if (succeeded(pm.run(f->getModule())))
|
||||
if (succeeded(pm.run(f.getModule())))
|
||||
cleanupAndPrintFunction(f);
|
||||
|
||||
// clang-format off
|
||||
|
@ -96,10 +96,10 @@ TEST_FUNC(matmul_tiled_loops) {
|
|||
TEST_FUNC(matmul_tiled_views) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views");
|
||||
OpBuilder b(f->getBody());
|
||||
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f->getLoc(), 8),
|
||||
b.create<ConstantIndexOp>(f->getLoc(), 9)});
|
||||
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views");
|
||||
OpBuilder b(f.getBody());
|
||||
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
|
||||
b.create<ConstantIndexOp>(f.getLoc(), 9)});
|
||||
composeSliceOps(f);
|
||||
|
||||
// clang-format off
|
||||
|
@ -125,11 +125,11 @@ TEST_FUNC(matmul_tiled_views) {
|
|||
TEST_FUNC(matmul_tiled_views_as_loops) {
|
||||
MLIRContext context;
|
||||
Module module(&context);
|
||||
mlir::Function *f =
|
||||
mlir::Function f =
|
||||
makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops");
|
||||
OpBuilder b(f->getBody());
|
||||
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f->getLoc(), 8),
|
||||
b.create<ConstantIndexOp>(f->getLoc(), 9)});
|
||||
OpBuilder b(f.getBody());
|
||||
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
|
||||
b.create<ConstantIndexOp>(f.getLoc(), 9)});
|
||||
composeSliceOps(f);
|
||||
lowerToLoops(f);
|
||||
// This cannot lower below linalg.load and linalg.store due to lost
|
||||
|
|
|
@ -34,12 +34,12 @@ writeAsTiledViews(mlir::Operation *op, llvm::ArrayRef<mlir::Value *> tileSizes);
|
|||
/// Apply `writeAsTiledLoops` on all linalg ops. This is a convenience function
|
||||
/// and is not exposed as a pass because a fixed set of tile sizes for all ops
|
||||
/// in a function can generally not be specified.
|
||||
void lowerToTiledLoops(mlir::Function *f, llvm::ArrayRef<uint64_t> tileSizes);
|
||||
void lowerToTiledLoops(mlir::Function f, llvm::ArrayRef<uint64_t> tileSizes);
|
||||
|
||||
/// Apply `writeAsTiledViews` on all linalg ops. This is a convenience function
|
||||
/// and is not exposed as a pass because a fixed set of tile sizes for all ops
|
||||
/// in a function can generally not be specified.
|
||||
void lowerToTiledViews(mlir::Function *f,
|
||||
void lowerToTiledViews(mlir::Function f,
|
||||
llvm::ArrayRef<mlir::Value *> tileSizes);
|
||||
|
||||
} // namespace linalg
|
||||
|
|
|
@ -43,9 +43,8 @@ linalg::writeAsTiledLoops(Operation *op, ArrayRef<uint64_t> tileSizes) {
|
|||
return llvm::None;
|
||||
}
|
||||
|
||||
void linalg::lowerToTiledLoops(mlir::Function *f,
|
||||
ArrayRef<uint64_t> tileSizes) {
|
||||
f->walk([tileSizes](Operation *op) {
|
||||
void linalg::lowerToTiledLoops(mlir::Function f, ArrayRef<uint64_t> tileSizes) {
|
||||
f.walk([tileSizes](Operation *op) {
|
||||
if (writeAsTiledLoops(op, tileSizes).hasValue())
|
||||
op->erase();
|
||||
});
|
||||
|
@ -185,8 +184,8 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef<Value *> tileSizes) {
|
|||
return llvm::None;
|
||||
}
|
||||
|
||||
void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef<Value *> tileSizes) {
|
||||
f->walk([tileSizes](Operation *op) {
|
||||
void linalg::lowerToTiledViews(mlir::Function f, ArrayRef<Value *> tileSizes) {
|
||||
f.walk([tileSizes](Operation *op) {
|
||||
if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
|
||||
writeAsTiledViews(matmulOp, tileSizes);
|
||||
} else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
|
||||
|
|
|
@ -75,7 +75,7 @@ public:
|
|||
auto func = mlirGen(F);
|
||||
if (!func)
|
||||
return nullptr;
|
||||
theModule->getFunctions().push_back(func.release());
|
||||
theModule->push_back(func);
|
||||
}
|
||||
|
||||
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
|
||||
|
@ -129,40 +129,40 @@ private:
|
|||
|
||||
/// Create the prototype for an MLIR function with as many arguments as the
|
||||
/// provided Toy AST prototype.
|
||||
mlir::Function *mlirGen(PrototypeAST &proto) {
|
||||
mlir::Function mlirGen(PrototypeAST &proto) {
|
||||
// This is a generic function, the return type will be inferred later.
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
// Arguments type is uniformly a generic array.
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
||||
getType(VarType{}));
|
||||
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
|
||||
auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
|
||||
// Mark the function as generic: it'll require type specialization for every
|
||||
// call site.
|
||||
if (function->getNumArguments())
|
||||
function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
if (function.getNumArguments())
|
||||
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
/// Emit a new function and add it to the MLIR module.
|
||||
std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
|
||||
mlir::Function mlirGen(FunctionAST &funcAST) {
|
||||
// Create a scope in the symbol table to hold variable declarations.
|
||||
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
|
||||
|
||||
// Create an MLIR function for the given prototype.
|
||||
std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
|
||||
mlir::Function function(mlirGen(*funcAST.getProto()));
|
||||
if (!function)
|
||||
return nullptr;
|
||||
|
||||
// Let's start the body of the function now!
|
||||
// In MLIR the entry block of the function is special: it must have the same
|
||||
// argument list as the function itself.
|
||||
function->addEntryBlock();
|
||||
function.addEntryBlock();
|
||||
|
||||
auto &entryBlock = function->front();
|
||||
auto &entryBlock = function.front();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
|
@ -172,16 +172,18 @@ private:
|
|||
|
||||
// Create a builder for the function, it will be used throughout the codegen
|
||||
// to create operations in this function.
|
||||
builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
|
||||
builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
|
||||
|
||||
// Emit the body of the function.
|
||||
if (!mlirGen(*funcAST.getBody()))
|
||||
if (!mlirGen(*funcAST.getBody())) {
|
||||
function.erase();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Implicitly return void if no return statement was emitted.
|
||||
// FIXME: we may fix the parser instead to always return the last expression
|
||||
// (this would possibly help the REPL case later)
|
||||
if (function->getBlocks().back().back().getName().getStringRef() !=
|
||||
if (function.getBlocks().back().back().getName().getStringRef() !=
|
||||
"toy.return") {
|
||||
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
|
||||
mlirGen(fakeRet);
|
||||
|
|
|
@ -76,7 +76,7 @@ public:
|
|||
auto func = mlirGen(F);
|
||||
if (!func)
|
||||
return nullptr;
|
||||
theModule->getFunctions().push_back(func.release());
|
||||
theModule->push_back(func);
|
||||
}
|
||||
|
||||
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
|
||||
|
@ -130,40 +130,40 @@ private:
|
|||
|
||||
/// Create the prototype for an MLIR function with as many arguments as the
|
||||
/// provided Toy AST prototype.
|
||||
mlir::Function *mlirGen(PrototypeAST &proto) {
|
||||
mlir::Function mlirGen(PrototypeAST &proto) {
|
||||
// This is a generic function, the return type will be inferred later.
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
// Arguments type is uniformly a generic array.
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
||||
getType(VarType{}));
|
||||
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
|
||||
auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
|
||||
// Mark the function as generic: it'll require type specialization for every
|
||||
// call site.
|
||||
if (function->getNumArguments())
|
||||
function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
if (function.getNumArguments())
|
||||
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
/// Emit a new function and add it to the MLIR module.
|
||||
std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
|
||||
mlir::Function mlirGen(FunctionAST &funcAST) {
|
||||
// Create a scope in the symbol table to hold variable declarations.
|
||||
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
|
||||
|
||||
// Create an MLIR function for the given prototype.
|
||||
std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
|
||||
mlir::Function function(mlirGen(*funcAST.getProto()));
|
||||
if (!function)
|
||||
return nullptr;
|
||||
|
||||
// Let's start the body of the function now!
|
||||
// In MLIR the entry block of the function is special: it must have the same
|
||||
// argument list as the function itself.
|
||||
function->addEntryBlock();
|
||||
function.addEntryBlock();
|
||||
|
||||
auto &entryBlock = function->front();
|
||||
auto &entryBlock = function.front();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
|
@ -173,16 +173,18 @@ private:
|
|||
|
||||
// Create a builder for the function, it will be used throughout the codegen
|
||||
// to create operations in this function.
|
||||
builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
|
||||
builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
|
||||
|
||||
// Emit the body of the function.
|
||||
if (!mlirGen(*funcAST.getBody()))
|
||||
if (!mlirGen(*funcAST.getBody())) {
|
||||
function.erase();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Implicitly return void if no return statement was emitted.
|
||||
// FIXME: we may fix the parser instead to always return the last expression
|
||||
// (this would possibly help the REPL case later)
|
||||
if (function->getBlocks().back().back().getName().getStringRef() !=
|
||||
if (function.getBlocks().back().back().getName().getStringRef() !=
|
||||
"toy.return") {
|
||||
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
|
||||
mlirGen(fakeRet);
|
||||
|
|
|
@ -76,7 +76,7 @@ public:
|
|||
auto func = mlirGen(F);
|
||||
if (!func)
|
||||
return nullptr;
|
||||
theModule->getFunctions().push_back(func.release());
|
||||
theModule->push_back(func);
|
||||
}
|
||||
|
||||
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
|
||||
|
@ -130,40 +130,40 @@ private:
|
|||
|
||||
/// Create the prototype for an MLIR function with as many arguments as the
|
||||
/// provided Toy AST prototype.
|
||||
mlir::Function *mlirGen(PrototypeAST &proto) {
|
||||
mlir::Function mlirGen(PrototypeAST &proto) {
|
||||
// This is a generic function, the return type will be inferred later.
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
// Arguments type is uniformly a generic array.
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
||||
getType(VarType{}));
|
||||
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
|
||||
auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
|
||||
// Mark the function as generic: it'll require type specialization for every
|
||||
// call site.
|
||||
if (function->getNumArguments())
|
||||
function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
if (function.getNumArguments())
|
||||
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
/// Emit a new function and add it to the MLIR module.
|
||||
std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
|
||||
mlir::Function mlirGen(FunctionAST &funcAST) {
|
||||
// Create a scope in the symbol table to hold variable declarations.
|
||||
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
|
||||
|
||||
// Create an MLIR function for the given prototype.
|
||||
std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
|
||||
mlir::Function function(mlirGen(*funcAST.getProto()));
|
||||
if (!function)
|
||||
return nullptr;
|
||||
|
||||
// Let's start the body of the function now!
|
||||
// In MLIR the entry block of the function is special: it must have the same
|
||||
// argument list as the function itself.
|
||||
function->addEntryBlock();
|
||||
function.addEntryBlock();
|
||||
|
||||
auto &entryBlock = function->front();
|
||||
auto &entryBlock = function.front();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
|
@ -173,16 +173,18 @@ private:
|
|||
|
||||
// Create a builder for the function, it will be used throughout the codegen
|
||||
// to create operations in this function.
|
||||
builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
|
||||
builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
|
||||
|
||||
// Emit the body of the function.
|
||||
if (!mlirGen(*funcAST.getBody()))
|
||||
if (!mlirGen(*funcAST.getBody())) {
|
||||
function.erase();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Implicitly return void if no return statement was emited.
|
||||
// FIXME: we may fix the parser instead to always return the last expression
|
||||
// (this would possibly help the REPL case later)
|
||||
if (function->getBlocks().back().back().getName().getStringRef() !=
|
||||
if (function.getBlocks().back().back().getName().getStringRef() !=
|
||||
"toy.return") {
|
||||
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
|
||||
mlirGen(fakeRet);
|
||||
|
|
|
@ -113,14 +113,14 @@ public:
|
|||
// function to process, the mangled name for this specialization, and the
|
||||
// types of the arguments on which to specialize.
|
||||
struct FunctionToSpecialize {
|
||||
mlir::Function *function;
|
||||
mlir::Function function;
|
||||
std::string mangledName;
|
||||
SmallVector<mlir::Type, 4> argumentsType;
|
||||
};
|
||||
|
||||
void runOnModule() override {
|
||||
auto &module = getModule();
|
||||
auto *main = module.getNamedFunction("main");
|
||||
auto main = module.getNamedFunction("main");
|
||||
if (!main) {
|
||||
emitError(mlir::UnknownLoc::get(module.getContext()),
|
||||
"Shape inference failed: can't find a main function\n");
|
||||
|
@ -139,7 +139,7 @@ public:
|
|||
|
||||
// Delete any generic function left
|
||||
// FIXME: we may want this as a separate pass.
|
||||
for (mlir::Function &function : llvm::make_early_inc_range(module)) {
|
||||
for (mlir::Function function : llvm::make_early_inc_range(module)) {
|
||||
if (auto genericAttr =
|
||||
function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
|
||||
if (genericAttr.getValue())
|
||||
|
@ -153,7 +153,7 @@ public:
|
|||
mlir::LogicalResult
|
||||
specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
|
||||
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
|
||||
mlir::Function *f = functionToSpecialize.function;
|
||||
mlir::Function f = functionToSpecialize.function;
|
||||
|
||||
// Check if cloning for specialization is needed (usually anything but main)
|
||||
// We will create a new function with the concrete types for the parameters
|
||||
|
@ -169,36 +169,36 @@ public:
|
|||
auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
|
||||
{ToyArrayType::get(&getContext())},
|
||||
&getContext());
|
||||
auto *newFunction = new mlir::Function(
|
||||
f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
|
||||
getModule().getFunctions().push_back(newFunction);
|
||||
auto newFunction = mlir::Function::create(
|
||||
f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs());
|
||||
getModule().push_back(newFunction);
|
||||
|
||||
// Clone the function body
|
||||
mlir::BlockAndValueMapping mapper;
|
||||
f->cloneInto(newFunction, mapper);
|
||||
f.cloneInto(newFunction, mapper);
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "====== Cloned : \n";
|
||||
f->dump();
|
||||
f.dump();
|
||||
llvm::dbgs() << "====== Into : \n";
|
||||
newFunction->dump();
|
||||
newFunction.dump();
|
||||
});
|
||||
f = newFunction;
|
||||
f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
|
||||
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
|
||||
// Remap the entry-block arguments
|
||||
// FIXME: this seems like a bug in `cloneInto()` above?
|
||||
auto &entryBlock = f->getBlocks().front();
|
||||
auto &entryBlock = f.getBlocks().front();
|
||||
int blockArgSize = entryBlock.getArguments().size();
|
||||
assert(blockArgSize == static_cast<int>(f->getType().getInputs().size()));
|
||||
entryBlock.addArguments(f->getType().getInputs());
|
||||
assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
|
||||
entryBlock.addArguments(f.getType().getInputs());
|
||||
auto argList = entryBlock.getArguments();
|
||||
for (int argNum = 0; argNum < blockArgSize; ++argNum) {
|
||||
argList[0]->replaceAllUsesWith(argList[blockArgSize]);
|
||||
entryBlock.eraseArgument(0);
|
||||
}
|
||||
assert(succeeded(f->verify()));
|
||||
assert(succeeded(f.verify()));
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Run shape inference on : '" << f->getName() << "'\n");
|
||||
<< "Run shape inference on : '" << f.getName() << "'\n");
|
||||
|
||||
auto *toyDialect = getContext().getRegisteredDialect("toy");
|
||||
if (!toyDialect) {
|
||||
|
@ -211,7 +211,7 @@ public:
|
|||
// Populate the worklist with the operations that need shape inference:
|
||||
// these are the Toy operations that return a generic array.
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
|
||||
f->walk([&](mlir::Operation *op) {
|
||||
f.walk([&](mlir::Operation *op) {
|
||||
if (op->getDialect() == toyDialect) {
|
||||
if (op->getNumResults() == 1 &&
|
||||
op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
|
||||
|
@ -292,9 +292,9 @@ public:
|
|||
// restart after the callee is processed.
|
||||
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
|
||||
auto calleeName = callOp.getCalleeName();
|
||||
auto *callee = getModule().getNamedFunction(calleeName);
|
||||
auto callee = getModule().getNamedFunction(calleeName);
|
||||
if (!callee) {
|
||||
f->emitError("Shape inference failed, call to unknown '")
|
||||
f.emitError("Shape inference failed, call to unknown '")
|
||||
<< calleeName << "'";
|
||||
signalPassFailure();
|
||||
return mlir::failure();
|
||||
|
@ -302,7 +302,7 @@ public:
|
|||
auto mangledName = mangle(calleeName, op->getOpOperands());
|
||||
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
|
||||
<< "', mangled: '" << mangledName << "'\n");
|
||||
auto *mangledCallee = getModule().getNamedFunction(mangledName);
|
||||
auto mangledCallee = getModule().getNamedFunction(mangledName);
|
||||
if (!mangledCallee) {
|
||||
// Can't find the target, this is where we queue the request for the
|
||||
// callee and stop the inference for the current function now.
|
||||
|
@ -327,7 +327,7 @@ public:
|
|||
// Done with inference on this function, removing it from the worklist.
|
||||
funcWorklist.pop_back();
|
||||
// Mark the function as non-generic now that inference has succeeded
|
||||
f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
|
||||
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
|
||||
|
||||
// If the operation worklist isn't empty, this indicates a failure.
|
||||
if (!opWorklist.empty()) {
|
||||
|
@ -337,31 +337,31 @@ public:
|
|||
<< " operations couldn't be inferred\n";
|
||||
for (auto *ope : opWorklist)
|
||||
errorMsg << " - " << *ope << "\n";
|
||||
f->emitError(errorMsg.str());
|
||||
f.emitError(errorMsg.str());
|
||||
signalPassFailure();
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// Finally, update the return type of the function based on the argument to
|
||||
// the return operation.
|
||||
for (auto &block : f->getBlocks()) {
|
||||
for (auto &block : f.getBlocks()) {
|
||||
auto ret = llvm::cast<ReturnOp>(block.getTerminator());
|
||||
if (!ret)
|
||||
continue;
|
||||
if (ret.getNumOperands() &&
|
||||
f->getType().getResult(0) == ret.getOperand()->getType())
|
||||
f.getType().getResult(0) == ret.getOperand()->getType())
|
||||
// type match, we're done
|
||||
break;
|
||||
SmallVector<mlir::Type, 1> retTy;
|
||||
if (ret.getNumOperands())
|
||||
retTy.push_back(ret.getOperand()->getType());
|
||||
std::vector<mlir::Type> argumentsType;
|
||||
for (auto arg : f->getArguments())
|
||||
for (auto arg : f.getArguments())
|
||||
argumentsType.push_back(arg->getType());
|
||||
auto newType =
|
||||
mlir::FunctionType::get(argumentsType, retTy, &getContext());
|
||||
f->setType(newType);
|
||||
assert(succeeded(f->verify()));
|
||||
f.setType(newType);
|
||||
assert(succeeded(f.verify()));
|
||||
break;
|
||||
}
|
||||
return mlir::success();
|
||||
|
|
|
@ -136,14 +136,14 @@ public:
|
|||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Get or create the declaration of the printf function in the module.
|
||||
Function *printfFunc = getPrintf(*op->getFunction()->getModule());
|
||||
Function printfFunc = getPrintf(*op->getFunction().getModule());
|
||||
|
||||
auto print = cast<toy::PrintOp>(op);
|
||||
auto loc = print.getLoc();
|
||||
// We will operate on a MemRef abstraction, we use a type.cast to get one
|
||||
// if our operand is still a Toy array.
|
||||
Value *operand = memRefTypeCast(rewriter, operands[0]);
|
||||
Type retTy = printfFunc->getType().getResult(0);
|
||||
Type retTy = printfFunc.getType().getResult(0);
|
||||
|
||||
// Create our loop nest now
|
||||
using namespace edsc;
|
||||
|
@ -205,8 +205,8 @@ private:
|
|||
|
||||
/// Return the prototype declaration for printf in the module, create it if
|
||||
/// necessary.
|
||||
Function *getPrintf(Module &module) const {
|
||||
auto *printfFunc = module.getNamedFunction("printf");
|
||||
Function getPrintf(Module &module) const {
|
||||
auto printfFunc = module.getNamedFunction("printf");
|
||||
if (printfFunc)
|
||||
return printfFunc;
|
||||
|
||||
|
@ -218,10 +218,10 @@ private:
|
|||
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
|
||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
|
||||
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
|
||||
printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
|
||||
printfFunc = Function::create(builder.getUnknownLoc(), "printf", printfTy);
|
||||
// It should be variadic, but we don't support it fully just yet.
|
||||
printfFunc->setAttr("std.varargs", builder.getBoolAttr(true));
|
||||
module.getFunctions().push_back(printfFunc);
|
||||
printfFunc.setAttr("std.varargs", builder.getBoolAttr(true));
|
||||
module.push_back(printfFunc);
|
||||
return printfFunc;
|
||||
}
|
||||
};
|
||||
|
@ -369,7 +369,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
|
|||
// affine dialect: they already include conversion to the LLVM dialect.
|
||||
|
||||
// First patch calls type to return memref instead of ToyArray
|
||||
for (auto &function : getModule()) {
|
||||
for (auto function : getModule()) {
|
||||
function.walk([&](Operation *op) {
|
||||
auto callOp = dyn_cast<CallOp>(op);
|
||||
if (!callOp)
|
||||
|
@ -384,7 +384,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
|
|||
});
|
||||
}
|
||||
|
||||
for (auto &function : getModule()) {
|
||||
for (auto function : getModule()) {
|
||||
function.walk([&](Operation *op) {
|
||||
// Turns toy.alloc into sequence of alloc/dealloc (later malloc/free).
|
||||
if (auto allocOp = dyn_cast<toy::AllocOp>(op)) {
|
||||
|
@ -403,8 +403,8 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
|
|||
}
|
||||
|
||||
// Lower Linalg to affine
|
||||
for (auto &function : getModule())
|
||||
linalg::lowerToLoops(&function);
|
||||
for (auto function : getModule())
|
||||
linalg::lowerToLoops(function);
|
||||
|
||||
getModule().dump();
|
||||
|
||||
|
|
|
@ -76,7 +76,7 @@ public:
|
|||
auto func = mlirGen(F);
|
||||
if (!func)
|
||||
return nullptr;
|
||||
theModule->getFunctions().push_back(func.release());
|
||||
theModule->push_back(func);
|
||||
}
|
||||
|
||||
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
|
||||
|
@ -130,40 +130,40 @@ private:
|
|||
|
||||
/// Create the prototype for an MLIR function with as many arguments as the
|
||||
/// provided Toy AST prototype.
|
||||
mlir::Function *mlirGen(PrototypeAST &proto) {
|
||||
mlir::Function mlirGen(PrototypeAST &proto) {
|
||||
// This is a generic function, the return type will be inferred later.
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
// Arguments type is uniformly a generic array.
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
||||
getType(VarType{}));
|
||||
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
|
||||
auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
|
||||
// Mark the function as generic: it'll require type specialization for every
|
||||
// call site.
|
||||
if (function->getNumArguments())
|
||||
function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
if (function.getNumArguments())
|
||||
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
/// Emit a new function and add it to the MLIR module.
|
||||
std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
|
||||
mlir::Function mlirGen(FunctionAST &funcAST) {
|
||||
// Create a scope in the symbol table to hold variable declarations.
|
||||
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
|
||||
|
||||
// Create an MLIR function for the given prototype.
|
||||
std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
|
||||
mlir::Function function(mlirGen(*funcAST.getProto()));
|
||||
if (!function)
|
||||
return nullptr;
|
||||
|
||||
// Let's start the body of the function now!
|
||||
// In MLIR the entry block of the function is special: it must have the same
|
||||
// argument list as the function itself.
|
||||
function->addEntryBlock();
|
||||
function.addEntryBlock();
|
||||
|
||||
auto &entryBlock = function->front();
|
||||
auto &entryBlock = function.front();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
|
@ -173,16 +173,18 @@ private:
|
|||
|
||||
// Create a builder for the function, it will be used throughout the codegen
|
||||
// to create operations in this function.
|
||||
builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
|
||||
builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
|
||||
|
||||
// Emit the body of the function.
|
||||
if (!mlirGen(*funcAST.getBody()))
|
||||
if (!mlirGen(*funcAST.getBody())) {
|
||||
function.erase();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Implicitly return void if no return statement was emited.
|
||||
// FIXME: we may fix the parser instead to always return the last expression
|
||||
// (this would possibly help the REPL case later)
|
||||
if (function->getBlocks().back().back().getName().getStringRef() !=
|
||||
if (function.getBlocks().back().back().getName().getStringRef() !=
|
||||
"toy.return") {
|
||||
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
|
||||
mlirGen(fakeRet);
|
||||
|
|
|
@ -113,7 +113,7 @@ public:
|
|||
// function to process, the mangled name for this specialization, and the
|
||||
// types of the arguments on which to specialize.
|
||||
struct FunctionToSpecialize {
|
||||
mlir::Function *function;
|
||||
mlir::Function function;
|
||||
std::string mangledName;
|
||||
SmallVector<mlir::Type, 4> argumentsType;
|
||||
};
|
||||
|
@ -121,7 +121,7 @@ public:
|
|||
void runOnModule() override {
|
||||
auto &module = getModule();
|
||||
mlir::ModuleManager moduleManager(&module);
|
||||
auto *main = moduleManager.getNamedFunction("main");
|
||||
auto main = moduleManager.getNamedFunction("main");
|
||||
if (!main) {
|
||||
emitError(mlir::UnknownLoc::get(module.getContext()),
|
||||
"Shape inference failed: can't find a main function\n");
|
||||
|
@ -140,7 +140,7 @@ public:
|
|||
|
||||
// Delete any generic function left
|
||||
// FIXME: we may want this as a separate pass.
|
||||
for (mlir::Function &function : llvm::make_early_inc_range(module)) {
|
||||
for (mlir::Function function : llvm::make_early_inc_range(module)) {
|
||||
if (auto genericAttr =
|
||||
function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
|
||||
if (genericAttr.getValue())
|
||||
|
@ -155,7 +155,7 @@ public:
|
|||
specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist,
|
||||
mlir::ModuleManager &moduleManager) {
|
||||
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
|
||||
mlir::Function *f = functionToSpecialize.function;
|
||||
mlir::Function f = functionToSpecialize.function;
|
||||
|
||||
// Check if cloning for specialization is needed (usually anything but main)
|
||||
// We will create a new function with the concrete types for the parameters
|
||||
|
@ -171,36 +171,36 @@ public:
|
|||
auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
|
||||
{ToyArrayType::get(&getContext())},
|
||||
&getContext());
|
||||
auto *newFunction = new mlir::Function(
|
||||
f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
|
||||
auto newFunction = mlir::Function::create(
|
||||
f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs());
|
||||
moduleManager.insert(newFunction);
|
||||
|
||||
// Clone the function body
|
||||
mlir::BlockAndValueMapping mapper;
|
||||
f->cloneInto(newFunction, mapper);
|
||||
f.cloneInto(newFunction, mapper);
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "====== Cloned : \n";
|
||||
f->dump();
|
||||
f.dump();
|
||||
llvm::dbgs() << "====== Into : \n";
|
||||
newFunction->dump();
|
||||
newFunction.dump();
|
||||
});
|
||||
f = newFunction;
|
||||
f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
|
||||
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
|
||||
// Remap the entry-block arguments
|
||||
// FIXME: this seems like a bug in `cloneInto()` above?
|
||||
auto &entryBlock = f->getBlocks().front();
|
||||
auto &entryBlock = f.getBlocks().front();
|
||||
int blockArgSize = entryBlock.getArguments().size();
|
||||
assert(blockArgSize == static_cast<int>(f->getType().getInputs().size()));
|
||||
entryBlock.addArguments(f->getType().getInputs());
|
||||
assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
|
||||
entryBlock.addArguments(f.getType().getInputs());
|
||||
auto argList = entryBlock.getArguments();
|
||||
for (int argNum = 0; argNum < blockArgSize; ++argNum) {
|
||||
argList[0]->replaceAllUsesWith(argList[blockArgSize]);
|
||||
entryBlock.eraseArgument(0);
|
||||
}
|
||||
assert(succeeded(f->verify()));
|
||||
assert(succeeded(f.verify()));
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Run shape inference on : '" << f->getName() << "'\n");
|
||||
<< "Run shape inference on : '" << f.getName() << "'\n");
|
||||
|
||||
auto *toyDialect = getContext().getRegisteredDialect("toy");
|
||||
if (!toyDialect) {
|
||||
|
@ -212,7 +212,7 @@ public:
|
|||
// Populate the worklist with the operations that need shape inference:
|
||||
// these are the Toy operations that return a generic array.
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
|
||||
f->walk([&](mlir::Operation *op) {
|
||||
f.walk([&](mlir::Operation *op) {
|
||||
if (op->getDialect() == toyDialect) {
|
||||
if (op->getNumResults() == 1 &&
|
||||
op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
|
||||
|
@ -295,16 +295,16 @@ public:
|
|||
// restart after the callee is processed.
|
||||
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
|
||||
auto calleeName = callOp.getCalleeName();
|
||||
auto *callee = moduleManager.getNamedFunction(calleeName);
|
||||
auto callee = moduleManager.getNamedFunction(calleeName);
|
||||
if (!callee) {
|
||||
signalPassFailure();
|
||||
return f->emitError("Shape inference failed, call to unknown '")
|
||||
return f.emitError("Shape inference failed, call to unknown '")
|
||||
<< calleeName << "'";
|
||||
}
|
||||
auto mangledName = mangle(calleeName, op->getOpOperands());
|
||||
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
|
||||
<< "', mangled: '" << mangledName << "'\n");
|
||||
auto *mangledCallee = moduleManager.getNamedFunction(mangledName);
|
||||
auto mangledCallee = moduleManager.getNamedFunction(mangledName);
|
||||
if (!mangledCallee) {
|
||||
// Can't find the target, this is where we queue the request for the
|
||||
// callee and stop the inference for the current function now.
|
||||
|
@ -315,7 +315,7 @@ public:
|
|||
// Found a specialized callee! Let's turn this into a normal call
|
||||
// operation.
|
||||
SmallVector<mlir::Value *, 8> operands(op->getOperands());
|
||||
mlir::OpBuilder builder(f->getBody());
|
||||
mlir::OpBuilder builder(f.getBody());
|
||||
builder.setInsertionPoint(op);
|
||||
auto newCall =
|
||||
builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
|
||||
|
@ -330,12 +330,12 @@ public:
|
|||
// Done with inference on this function, removing it from the worklist.
|
||||
funcWorklist.pop_back();
|
||||
// Mark the function as non-generic now that inference has succeeded
|
||||
f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
|
||||
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
|
||||
|
||||
// If the operation worklist isn't empty, this indicates a failure.
|
||||
if (!opWorklist.empty()) {
|
||||
signalPassFailure();
|
||||
auto diag = f->emitError("Shape inference failed, ")
|
||||
auto diag = f.emitError("Shape inference failed, ")
|
||||
<< opWorklist.size() << " operations couldn't be inferred\n";
|
||||
for (auto *ope : opWorklist)
|
||||
diag << " - " << *ope << "\n";
|
||||
|
@ -344,24 +344,24 @@ public:
|
|||
|
||||
// Finally, update the return type of the function based on the argument to
|
||||
// the return operation.
|
||||
for (auto &block : f->getBlocks()) {
|
||||
for (auto &block : f.getBlocks()) {
|
||||
auto ret = llvm::cast<ReturnOp>(block.getTerminator());
|
||||
if (!ret)
|
||||
continue;
|
||||
if (ret.getNumOperands() &&
|
||||
f->getType().getResult(0) == ret.getOperand()->getType())
|
||||
f.getType().getResult(0) == ret.getOperand()->getType())
|
||||
// type match, we're done
|
||||
break;
|
||||
SmallVector<mlir::Type, 1> retTy;
|
||||
if (ret.getNumOperands())
|
||||
retTy.push_back(ret.getOperand()->getType());
|
||||
std::vector<mlir::Type> argumentsType;
|
||||
for (auto arg : f->getArguments())
|
||||
for (auto arg : f.getArguments())
|
||||
argumentsType.push_back(arg->getType());
|
||||
auto newType =
|
||||
mlir::FunctionType::get(argumentsType, retTy, &getContext());
|
||||
f->setType(newType);
|
||||
assert(succeeded(f->verify()));
|
||||
f.setType(newType);
|
||||
assert(succeeded(f.verify()));
|
||||
break;
|
||||
}
|
||||
return mlir::success();
|
||||
|
|
|
@ -34,7 +34,7 @@ template <bool IsPostDom> class DominanceInfoBase {
|
|||
using base = llvm::DominatorTreeBase<Block, IsPostDom>;
|
||||
|
||||
public:
|
||||
DominanceInfoBase(Function *function) { recalculate(function); }
|
||||
DominanceInfoBase(Function function) { recalculate(function); }
|
||||
DominanceInfoBase(Operation *op) { recalculate(op); }
|
||||
DominanceInfoBase(DominanceInfoBase &&) = default;
|
||||
DominanceInfoBase &operator=(DominanceInfoBase &&) = default;
|
||||
|
@ -43,7 +43,7 @@ public:
|
|||
DominanceInfoBase &operator=(const DominanceInfoBase &) = delete;
|
||||
|
||||
/// Recalculate the dominance info.
|
||||
void recalculate(Function *function);
|
||||
void recalculate(Function function);
|
||||
void recalculate(Operation *op);
|
||||
|
||||
/// Get the root dominance node of the given region.
|
||||
|
|
|
@ -104,8 +104,8 @@ struct NestedPattern {
|
|||
NestedPattern &operator=(const NestedPattern &) = default;
|
||||
|
||||
/// Returns all the top-level matches in `func`.
|
||||
void match(Function *func, SmallVectorImpl<NestedMatch> *matches) {
|
||||
func->walk([&](Operation *op) { matchOne(op, matches); });
|
||||
void match(Function func, SmallVectorImpl<NestedMatch> *matches) {
|
||||
func.walk([&](Operation *op) { matchOne(op, matches); });
|
||||
}
|
||||
|
||||
/// Returns all the top-level matches in `op`.
|
||||
|
|
|
@ -44,7 +44,7 @@ struct StaticFloatMemRef {
|
|||
/// each of the arguments, initialize the storage with `initialValue`, and
|
||||
/// return a list of type-erased descriptor pointers.
|
||||
llvm::Expected<SmallVector<void *, 8>>
|
||||
allocateMemRefArguments(Function *func, float initialValue = 0.0);
|
||||
allocateMemRefArguments(Function func, float initialValue = 0.0);
|
||||
|
||||
/// Free a list of type-erased descriptors to statically-shaped memrefs with
|
||||
/// element type f32.
|
||||
|
|
|
@ -44,7 +44,7 @@ public:
|
|||
|
||||
/// Returns whether the given function is a kernel function, i.e., has the
|
||||
/// 'gpu.kernel' attribute.
|
||||
static bool isKernel(Function *function);
|
||||
static bool isKernel(Function function);
|
||||
};
|
||||
|
||||
/// Utility class for the GPU dialect to represent triples of `Value`s
|
||||
|
@ -122,12 +122,12 @@ public:
|
|||
using Op::Op;
|
||||
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
Function *kernelFunc, Value *gridSizeX, Value *gridSizeY,
|
||||
Function kernelFunc, Value *gridSizeX, Value *gridSizeY,
|
||||
Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
|
||||
Value *blockSizeZ, ArrayRef<Value *> kernelOperands);
|
||||
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
Function *kernelFunc, KernelDim3 gridSize,
|
||||
Function kernelFunc, KernelDim3 gridSize,
|
||||
KernelDim3 blockSize, ArrayRef<Value *> kernelOperands);
|
||||
|
||||
/// The kernel function specified by the operation's `kernel` attribute.
|
||||
|
|
|
@ -313,9 +313,8 @@ class FunctionAttr
|
|||
detail::StringAttributeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
using ValueType = Function *;
|
||||
using ValueType = StringRef;
|
||||
|
||||
static FunctionAttr get(Function *value);
|
||||
static FunctionAttr get(StringRef value, MLIRContext *ctx);
|
||||
|
||||
/// Returns the name of the held function reference.
|
||||
|
|
|
@ -101,7 +101,7 @@ public:
|
|||
|
||||
/// Returns the function that this block is part of, even if the block is
|
||||
/// nested under an operation region.
|
||||
Function *getFunction();
|
||||
Function getFunction();
|
||||
|
||||
/// Insert this block (which must not already be in a function) right before
|
||||
/// the specified block.
|
||||
|
|
|
@ -112,7 +112,7 @@ public:
|
|||
AffineMapAttr getAffineMapAttr(AffineMap map);
|
||||
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
|
||||
TypeAttr getTypeAttr(Type type);
|
||||
FunctionAttr getFunctionAttr(Function *value);
|
||||
FunctionAttr getFunctionAttr(Function value);
|
||||
FunctionAttr getFunctionAttr(StringRef value);
|
||||
ElementsAttr getDenseElementsAttr(ShapedType type,
|
||||
ArrayRef<Attribute> values);
|
||||
|
|
|
@ -145,17 +145,13 @@ public:
|
|||
|
||||
/// Verify an attribute from this dialect on the given function. Returns
|
||||
/// failure if the verification failed, success otherwise.
|
||||
virtual LogicalResult verifyFunctionAttribute(Function *, NamedAttribute) {
|
||||
return success();
|
||||
}
|
||||
virtual LogicalResult verifyFunctionAttribute(Function, NamedAttribute);
|
||||
|
||||
/// Verify an attribute from this dialect on the argument at 'argIndex' for
|
||||
/// the given function. Returns failure if the verification failed, success
|
||||
/// otherwise.
|
||||
virtual LogicalResult
|
||||
verifyFunctionArgAttribute(Function *, unsigned argIndex, NamedAttribute) {
|
||||
return success();
|
||||
}
|
||||
virtual LogicalResult verifyFunctionArgAttribute(Function, unsigned argIndex,
|
||||
NamedAttribute);
|
||||
|
||||
/// Verify an attribute from this dialect on the given operation. Returns
|
||||
/// failure if the verification failed, success otherwise.
|
||||
|
|
|
@ -29,29 +29,79 @@
|
|||
namespace mlir {
|
||||
class BlockAndValueMapping;
|
||||
class FunctionType;
|
||||
class Function;
|
||||
class MLIRContext;
|
||||
class Module;
|
||||
|
||||
/// This is the base class for all of the MLIR function types.
|
||||
class Function : public llvm::ilist_node_with_parent<Function, Module> {
|
||||
public:
|
||||
Function(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
Function(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs,
|
||||
ArrayRef<NamedAttributeList> argAttrs);
|
||||
namespace detail {
|
||||
/// This class represents all of the internal state of a Function. This allows
|
||||
/// for the Function class to be value typed.
|
||||
class FunctionStorage
|
||||
: public llvm::ilist_node_with_parent<FunctionStorage, Module> {
|
||||
FunctionStorage(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
FunctionStorage(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs,
|
||||
ArrayRef<NamedAttributeList> argAttrs);
|
||||
/// The name of the function.
|
||||
Identifier name;
|
||||
|
||||
/// The module this function is embedded into.
|
||||
Module *module = nullptr;
|
||||
|
||||
/// The source location the function was defined or derived from.
|
||||
Location getLoc() { return location; }
|
||||
Location location;
|
||||
|
||||
/// The type of the function.
|
||||
FunctionType type;
|
||||
|
||||
/// This holds general named attributes for the function.
|
||||
NamedAttributeList attrs;
|
||||
|
||||
/// The attributes lists for each of the function arguments.
|
||||
std::vector<NamedAttributeList> argAttrs;
|
||||
|
||||
/// The body of the function.
|
||||
Region body;
|
||||
|
||||
friend struct llvm::ilist_traits<FunctionStorage>;
|
||||
friend Function;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// This class represents an MLIR function, or the common unit of computation.
|
||||
/// The region of a function is not allowed to implicitly capture global values,
|
||||
/// and all external references must use Function arguments or attributes.
|
||||
class Function {
|
||||
public:
|
||||
Function(detail::FunctionStorage *impl = nullptr) : impl(impl) {}
|
||||
|
||||
static Function create(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {}) {
|
||||
return new detail::FunctionStorage(location, name, type, attrs);
|
||||
}
|
||||
static Function create(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs,
|
||||
ArrayRef<NamedAttributeList> argAttrs) {
|
||||
return new detail::FunctionStorage(location, name, type, attrs, argAttrs);
|
||||
}
|
||||
|
||||
/// Allow converting a Function to bool for null checks.
|
||||
operator bool() const { return impl; }
|
||||
bool operator==(Function other) const { return impl == other.impl; }
|
||||
bool operator!=(Function other) const { return !(*this == other); }
|
||||
|
||||
/// The source location the function was defined or derived from.
|
||||
Location getLoc() { return impl->location; }
|
||||
|
||||
/// Set the source location this function was defined or derived from.
|
||||
void setLoc(Location loc) { location = loc; }
|
||||
void setLoc(Location loc) { impl->location = loc; }
|
||||
|
||||
/// Return the name of this function, without the @.
|
||||
Identifier getName() { return name; }
|
||||
Identifier getName() { return impl->name; }
|
||||
|
||||
/// Return the type of this function.
|
||||
FunctionType getType() { return type; }
|
||||
FunctionType getType() { return impl->type; }
|
||||
|
||||
/// Change the type of this function in place. This is an extremely dangerous
|
||||
/// operation and it is up to the caller to ensure that this is legal for this
|
||||
|
@ -61,12 +111,12 @@ public:
|
|||
/// parameters we drop the extra attributes, if there are more parameters
|
||||
/// they won't have any attributes.
|
||||
void setType(FunctionType newType) {
|
||||
type = newType;
|
||||
argAttrs.resize(type.getNumInputs());
|
||||
impl->type = newType;
|
||||
impl->argAttrs.resize(newType.getNumInputs());
|
||||
}
|
||||
|
||||
MLIRContext *getContext();
|
||||
Module *getModule() { return module; }
|
||||
Module *getModule() { return impl->module; }
|
||||
|
||||
/// Add an entry block to an empty function, and set up the block arguments
|
||||
/// to match the signature of the function.
|
||||
|
@ -82,28 +132,28 @@ public:
|
|||
// Body Handling
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
Region &getBody() { return body; }
|
||||
void eraseBody() { body.getBlocks().clear(); }
|
||||
Region &getBody() { return impl->body; }
|
||||
void eraseBody() { getBody().getBlocks().clear(); }
|
||||
|
||||
/// This is the list of blocks in the function.
|
||||
using RegionType = Region::RegionType;
|
||||
RegionType &getBlocks() { return body.getBlocks(); }
|
||||
RegionType &getBlocks() { return getBody().getBlocks(); }
|
||||
|
||||
// Iteration over the block in the function.
|
||||
using iterator = RegionType::iterator;
|
||||
using reverse_iterator = RegionType::reverse_iterator;
|
||||
|
||||
iterator begin() { return body.begin(); }
|
||||
iterator end() { return body.end(); }
|
||||
reverse_iterator rbegin() { return body.rbegin(); }
|
||||
reverse_iterator rend() { return body.rend(); }
|
||||
iterator begin() { return getBody().begin(); }
|
||||
iterator end() { return getBody().end(); }
|
||||
reverse_iterator rbegin() { return getBody().rbegin(); }
|
||||
reverse_iterator rend() { return getBody().rend(); }
|
||||
|
||||
bool empty() { return body.empty(); }
|
||||
void push_back(Block *block) { body.push_back(block); }
|
||||
void push_front(Block *block) { body.push_front(block); }
|
||||
bool empty() { return getBody().empty(); }
|
||||
void push_back(Block *block) { getBody().push_back(block); }
|
||||
void push_front(Block *block) { getBody().push_front(block); }
|
||||
|
||||
Block &back() { return body.back(); }
|
||||
Block &front() { return body.front(); }
|
||||
Block &back() { return getBody().back(); }
|
||||
Block &front() { return getBody().front(); }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operation Walkers
|
||||
|
@ -150,53 +200,55 @@ public:
|
|||
/// the lifetime of an function.
|
||||
|
||||
/// Return all of the attributes on this function.
|
||||
ArrayRef<NamedAttribute> getAttrs() { return attrs.getAttrs(); }
|
||||
ArrayRef<NamedAttribute> getAttrs() { return impl->attrs.getAttrs(); }
|
||||
|
||||
/// Return the internal attribute list on this function.
|
||||
NamedAttributeList &getAttrList() { return attrs; }
|
||||
NamedAttributeList &getAttrList() { return impl->attrs; }
|
||||
|
||||
/// Return all of the attributes for the argument at 'index'.
|
||||
ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
return argAttrs[index].getAttrs();
|
||||
return impl->argAttrs[index].getAttrs();
|
||||
}
|
||||
|
||||
/// Set the attributes held by this function.
|
||||
void setAttrs(ArrayRef<NamedAttribute> attributes) {
|
||||
attrs.setAttrs(attributes);
|
||||
impl->attrs.setAttrs(attributes);
|
||||
}
|
||||
|
||||
/// Set the attributes held by the argument at 'index'.
|
||||
void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
argAttrs[index].setAttrs(attributes);
|
||||
impl->argAttrs[index].setAttrs(attributes);
|
||||
}
|
||||
void setArgAttrs(unsigned index, NamedAttributeList attributes) {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
argAttrs[index] = attributes;
|
||||
impl->argAttrs[index] = attributes;
|
||||
}
|
||||
void setAllArgAttrs(ArrayRef<NamedAttributeList> attributes) {
|
||||
assert(attributes.size() == getNumArguments());
|
||||
for (unsigned i = 0, e = attributes.size(); i != e; ++i)
|
||||
argAttrs[i] = attributes[i];
|
||||
impl->argAttrs[i] = attributes[i];
|
||||
}
|
||||
|
||||
/// Return all argument attributes of this function.
|
||||
MutableArrayRef<NamedAttributeList> getAllArgAttrs() { return argAttrs; }
|
||||
MutableArrayRef<NamedAttributeList> getAllArgAttrs() {
|
||||
return impl->argAttrs;
|
||||
}
|
||||
|
||||
/// Return the specified attribute if present, null otherwise.
|
||||
Attribute getAttr(Identifier name) { return attrs.get(name); }
|
||||
Attribute getAttr(StringRef name) { return attrs.get(name); }
|
||||
Attribute getAttr(Identifier name) { return impl->attrs.get(name); }
|
||||
Attribute getAttr(StringRef name) { return impl->attrs.get(name); }
|
||||
|
||||
/// Return the specified attribute, if present, for the argument at 'index',
|
||||
/// null otherwise.
|
||||
Attribute getArgAttr(unsigned index, Identifier name) {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
return argAttrs[index].get(name);
|
||||
return impl->argAttrs[index].get(name);
|
||||
}
|
||||
Attribute getArgAttr(unsigned index, StringRef name) {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
return argAttrs[index].get(name);
|
||||
return impl->argAttrs[index].get(name);
|
||||
}
|
||||
|
||||
template <typename AttrClass> AttrClass getAttrOfType(Identifier name) {
|
||||
|
@ -219,13 +271,15 @@ public:
|
|||
|
||||
/// If the an attribute exists with the specified name, change it to the new
|
||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||
void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
|
||||
void setAttr(Identifier name, Attribute value) {
|
||||
impl->attrs.set(name, value);
|
||||
}
|
||||
void setAttr(StringRef name, Attribute value) {
|
||||
setAttr(Identifier::get(name, getContext()), value);
|
||||
}
|
||||
void setArgAttr(unsigned index, Identifier name, Attribute value) {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
argAttrs[index].set(name, value);
|
||||
impl->argAttrs[index].set(name, value);
|
||||
}
|
||||
void setArgAttr(unsigned index, StringRef name, Attribute value) {
|
||||
setArgAttr(index, Identifier::get(name, getContext()), value);
|
||||
|
@ -234,12 +288,12 @@ public:
|
|||
/// Remove the attribute with the specified name if it exists. The return
|
||||
/// value indicates whether the attribute was present or not.
|
||||
NamedAttributeList::RemoveResult removeAttr(Identifier name) {
|
||||
return attrs.remove(name);
|
||||
return impl->attrs.remove(name);
|
||||
}
|
||||
NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
|
||||
Identifier name) {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
return attrs.remove(name);
|
||||
return impl->attrs.remove(name);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -281,44 +335,37 @@ public:
|
|||
/// contains entries for function arguments, these arguments are not included
|
||||
/// in the new function. Replaces references to cloned sub-values with the
|
||||
/// corresponding value that is copied, and adds those mappings to the mapper.
|
||||
Function *clone(BlockAndValueMapping &mapper);
|
||||
Function *clone();
|
||||
Function clone(BlockAndValueMapping &mapper);
|
||||
Function clone();
|
||||
|
||||
/// Clone the internal blocks and attributes from this function into dest. Any
|
||||
/// cloned blocks are appended to the back of dest. This function asserts that
|
||||
/// the attributes of the current function and dest are compatible.
|
||||
void cloneInto(Function *dest, BlockAndValueMapping &mapper);
|
||||
void cloneInto(Function dest, BlockAndValueMapping &mapper);
|
||||
|
||||
/// Methods for supporting PointerLikeTypeTraits.
|
||||
const void *getAsOpaquePointer() const {
|
||||
return static_cast<const void *>(impl);
|
||||
}
|
||||
static Function getFromOpaquePointer(const void *pointer) {
|
||||
return reinterpret_cast<detail::FunctionStorage *>(
|
||||
const_cast<void *>(pointer));
|
||||
}
|
||||
|
||||
private:
|
||||
/// Set the name of this function.
|
||||
void setName(Identifier newName) { name = newName; }
|
||||
void setName(Identifier newName) { impl->name = newName; }
|
||||
|
||||
/// The name of the function.
|
||||
Identifier name;
|
||||
|
||||
/// The module this function is embedded into.
|
||||
Module *module = nullptr;
|
||||
|
||||
/// The source location the function was defined or derived from.
|
||||
Location location;
|
||||
|
||||
/// The type of the function.
|
||||
FunctionType type;
|
||||
|
||||
/// This holds general named attributes for the function.
|
||||
NamedAttributeList attrs;
|
||||
|
||||
/// The attributes lists for each of the function arguments.
|
||||
std::vector<NamedAttributeList> argAttrs;
|
||||
|
||||
/// The body of the function.
|
||||
Region body;
|
||||
|
||||
void operator=(Function &) = delete;
|
||||
friend struct llvm::ilist_traits<Function>;
|
||||
/// A pointer to the impl storage instance for this function. This allows for
|
||||
/// 'Function' to be treated as a value type.
|
||||
detail::FunctionStorage *impl = nullptr;
|
||||
|
||||
// Allow access to 'setName'.
|
||||
friend class SymbolTable;
|
||||
|
||||
// Allow access to 'impl'.
|
||||
friend class Module;
|
||||
friend class Region;
|
||||
};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -487,21 +534,52 @@ private:
|
|||
namespace llvm {
|
||||
|
||||
template <>
|
||||
struct ilist_traits<::mlir::Function>
|
||||
: public ilist_alloc_traits<::mlir::Function> {
|
||||
using Function = ::mlir::Function;
|
||||
using function_iterator = simple_ilist<Function>::iterator;
|
||||
struct ilist_traits<::mlir::detail::FunctionStorage>
|
||||
: public ilist_alloc_traits<::mlir::detail::FunctionStorage> {
|
||||
using FunctionStorage = ::mlir::detail::FunctionStorage;
|
||||
using function_iterator = simple_ilist<FunctionStorage>::iterator;
|
||||
|
||||
static void deleteNode(Function *function) { delete function; }
|
||||
static void deleteNode(FunctionStorage *function) { delete function; }
|
||||
|
||||
void addNodeToList(Function *function);
|
||||
void removeNodeFromList(Function *function);
|
||||
void transferNodesFromList(ilist_traits<Function> &otherList,
|
||||
void addNodeToList(FunctionStorage *function);
|
||||
void removeNodeFromList(FunctionStorage *function);
|
||||
void transferNodesFromList(ilist_traits<FunctionStorage> &otherList,
|
||||
function_iterator first, function_iterator last);
|
||||
|
||||
private:
|
||||
mlir::Module *getContainingModule();
|
||||
};
|
||||
} // end namespace llvm
|
||||
|
||||
// Functions hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::Function> {
|
||||
static mlir::Function getEmptyKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::Function::getFromOpaquePointer(pointer);
|
||||
}
|
||||
static mlir::Function getTombstoneKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::Function::getFromOpaquePointer(pointer);
|
||||
}
|
||||
static unsigned getHashValue(mlir::Function val) {
|
||||
return hash_value(val.getAsOpaquePointer());
|
||||
}
|
||||
static bool isEqual(mlir::Function LHS, mlir::Function RHS) {
|
||||
return LHS == RHS;
|
||||
}
|
||||
};
|
||||
|
||||
/// Allow stealing the low bits of FunctionStorage.
|
||||
template <> struct PointerLikeTypeTraits<mlir::Function> {
|
||||
public:
|
||||
static inline void *getAsVoidPointer(mlir::Function I) {
|
||||
return const_cast<void *>(I.getAsOpaquePointer());
|
||||
}
|
||||
static inline mlir::Function getFromVoidPointer(void *P) {
|
||||
return mlir::Function::getFromOpaquePointer(P);
|
||||
}
|
||||
enum { NumLowBitsAvailable = 3 };
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
#endif // MLIR_IR_FUNCTION_H
|
||||
|
|
|
@ -34,34 +34,54 @@ public:
|
|||
|
||||
MLIRContext *getContext() { return context; }
|
||||
|
||||
/// An iterator class used to iterate over the held functions.
|
||||
class iterator : public llvm::mapped_iterator<
|
||||
llvm::iplist<detail::FunctionStorage>::iterator,
|
||||
Function (*)(detail::FunctionStorage &)> {
|
||||
static Function unwrap(detail::FunctionStorage &impl) { return &impl; }
|
||||
|
||||
public:
|
||||
using reference = Function;
|
||||
|
||||
/// Initializes the operand type iterator to the specified operand iterator.
|
||||
iterator(llvm::iplist<detail::FunctionStorage>::iterator it)
|
||||
: llvm::mapped_iterator<llvm::iplist<detail::FunctionStorage>::iterator,
|
||||
Function (*)(detail::FunctionStorage &)>(
|
||||
it, &unwrap) {}
|
||||
iterator(Function it)
|
||||
: iterator(llvm::iplist<detail::FunctionStorage>::iterator(it.impl)) {}
|
||||
};
|
||||
|
||||
/// This is the list of functions in the module.
|
||||
using FunctionListType = llvm::iplist<Function>;
|
||||
FunctionListType &getFunctions() { return functions; }
|
||||
llvm::iterator_range<iterator> getFunctions() { return {begin(), end()}; }
|
||||
|
||||
// Iteration over the functions in the module.
|
||||
using iterator = FunctionListType::iterator;
|
||||
using reverse_iterator = FunctionListType::reverse_iterator;
|
||||
|
||||
iterator begin() { return functions.begin(); }
|
||||
iterator end() { return functions.end(); }
|
||||
reverse_iterator rbegin() { return functions.rbegin(); }
|
||||
reverse_iterator rend() { return functions.rend(); }
|
||||
Function front() { return &functions.front(); }
|
||||
Function back() { return &functions.back(); }
|
||||
|
||||
void push_back(Function fn) { functions.push_back(fn.impl); }
|
||||
void insert(iterator insertPt, Function fn) {
|
||||
functions.insert(insertPt.getCurrent(), fn.impl);
|
||||
}
|
||||
|
||||
// Interfaces for working with the symbol table.
|
||||
|
||||
/// Look up a function with the specified name, returning null if no such
|
||||
/// name exists. Function names never include the @ on them. Note: This
|
||||
/// performs a linear scan of held symbols.
|
||||
Function *getNamedFunction(StringRef name) {
|
||||
Function getNamedFunction(StringRef name) {
|
||||
return getNamedFunction(Identifier::get(name, getContext()));
|
||||
}
|
||||
|
||||
/// Look up a function with the specified name, returning null if no such
|
||||
/// name exists. Function names never include the @ on them. Note: This
|
||||
/// performs a linear scan of held symbols.
|
||||
Function *getNamedFunction(Identifier name) {
|
||||
auto it = llvm::find_if(
|
||||
functions, [name](Function &fn) { return fn.getName() == name; });
|
||||
Function getNamedFunction(Identifier name) {
|
||||
auto it = llvm::find_if(functions, [name](detail::FunctionStorage &fn) {
|
||||
return Function(&fn).getName() == name;
|
||||
});
|
||||
return it == functions.end() ? nullptr : &*it;
|
||||
}
|
||||
|
||||
|
@ -74,11 +94,13 @@ public:
|
|||
void dump();
|
||||
|
||||
private:
|
||||
friend struct llvm::ilist_traits<Function>;
|
||||
friend class Function;
|
||||
friend struct llvm::ilist_traits<detail::FunctionStorage>;
|
||||
friend detail::FunctionStorage;
|
||||
friend Function;
|
||||
|
||||
/// getSublistAccess() - Returns pointer to member of function list
|
||||
static FunctionListType Module::*getSublistAccess(Function *) {
|
||||
static llvm::iplist<detail::FunctionStorage> Module::*
|
||||
getSublistAccess(detail::FunctionStorage *) {
|
||||
return &Module::functions;
|
||||
}
|
||||
|
||||
|
@ -86,7 +108,7 @@ private:
|
|||
MLIRContext *context;
|
||||
|
||||
/// This is the actual list of functions the module contains.
|
||||
FunctionListType functions;
|
||||
llvm::iplist<detail::FunctionStorage> functions;
|
||||
};
|
||||
|
||||
/// A class used to manage the symbols held by a module. This class handles
|
||||
|
@ -98,24 +120,24 @@ public:
|
|||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names must never include the @ on them.
|
||||
template <typename NameTy> Function *getNamedFunction(NameTy &&name) const {
|
||||
template <typename NameTy> Function getNamedFunction(NameTy &&name) const {
|
||||
return symbolTable.lookup(name);
|
||||
}
|
||||
|
||||
/// Insert a new symbol into the module, auto-renaming it as necessary.
|
||||
void insert(Function *function) {
|
||||
void insert(Function function) {
|
||||
symbolTable.insert(function);
|
||||
module->getFunctions().push_back(function);
|
||||
module->push_back(function);
|
||||
}
|
||||
void insert(Module::iterator insertPt, Function *function) {
|
||||
void insert(Module::iterator insertPt, Function function) {
|
||||
symbolTable.insert(function);
|
||||
module->getFunctions().insert(insertPt, function);
|
||||
module->insert(insertPt, function);
|
||||
}
|
||||
|
||||
/// Remove the given symbol from the module symbol table and then erase it.
|
||||
void erase(Function *function) {
|
||||
void erase(Function function) {
|
||||
symbolTable.erase(function);
|
||||
function->erase();
|
||||
function.erase();
|
||||
}
|
||||
|
||||
/// Return the internally held module.
|
||||
|
|
|
@ -128,7 +128,7 @@ public:
|
|||
/// Returns the function that this operation is part of.
|
||||
/// The function is determined by traversing the chain of parent operations.
|
||||
/// Returns nullptr if the operation is unlinked.
|
||||
Function *getFunction();
|
||||
Function getFunction();
|
||||
|
||||
/// Replace any uses of 'from' with 'to' within this operation.
|
||||
void replaceUsesOfWith(Value *from, Value *to);
|
||||
|
|
|
@ -420,7 +420,7 @@ private:
|
|||
/// patterns in a greedy work-list driven manner. Return true if no more
|
||||
/// patterns can be matched in the result function.
|
||||
///
|
||||
bool applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns);
|
||||
bool applyPatternsGreedily(Function fn, OwningRewritePatternList &&patterns);
|
||||
|
||||
/// Helper class to create a list of rewrite patterns given a list of their
|
||||
/// types and a list of attributes perfect-forwarded to each of the conversion
|
||||
|
|
|
@ -27,11 +27,16 @@
|
|||
namespace mlir {
|
||||
class BlockAndValueMapping;
|
||||
|
||||
namespace detail {
|
||||
class FunctionStorage;
|
||||
}
|
||||
|
||||
/// This class contains a list of basic blocks and has a notion of the object it
|
||||
/// is part of - a Function or an Operation.
|
||||
class Region {
|
||||
public:
|
||||
explicit Region(Function *container = nullptr);
|
||||
Region() = default;
|
||||
explicit Region(Function container);
|
||||
explicit Region(Operation *container);
|
||||
~Region();
|
||||
|
||||
|
@ -77,7 +82,7 @@ public:
|
|||
|
||||
/// A Region is either a function body or a part of an operation. If it is
|
||||
/// a Function body, then return this function, otherwise return null.
|
||||
Function *getContainingFunction();
|
||||
Function getContainingFunction();
|
||||
|
||||
/// Return true if this region is a proper ancestor of the `other` region.
|
||||
bool isProperAncestor(Region *other);
|
||||
|
@ -118,7 +123,7 @@ private:
|
|||
RegionType blocks;
|
||||
|
||||
/// This is the object we are part of.
|
||||
llvm::PointerUnion<Function *, Operation *> container;
|
||||
llvm::PointerUnion<detail::FunctionStorage *, Operation *> container;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#ifndef MLIR_IR_SYMBOLTABLE_H
|
||||
#define MLIR_IR_SYMBOLTABLE_H
|
||||
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -35,18 +35,18 @@ public:
|
|||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names never include the @ on them.
|
||||
Function *lookup(StringRef name) const;
|
||||
Function lookup(StringRef name) const;
|
||||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names never include the @ on them.
|
||||
Function *lookup(Identifier name) const;
|
||||
Function lookup(Identifier name) const;
|
||||
|
||||
/// Erase the given symbol from the table.
|
||||
void erase(Function *symbol);
|
||||
void erase(Function symbol);
|
||||
|
||||
/// Insert a new symbol into the table, and rename it as necessary to avoid
|
||||
/// collisions.
|
||||
void insert(Function *symbol);
|
||||
void insert(Function symbol);
|
||||
|
||||
/// Returns the context held by this symbol table.
|
||||
MLIRContext *getContext() const { return context; }
|
||||
|
@ -55,7 +55,7 @@ private:
|
|||
MLIRContext *context;
|
||||
|
||||
/// This is a mapping from a name to the function with that name.
|
||||
llvm::DenseMap<Identifier, Function *> symbolTable;
|
||||
llvm::DenseMap<Identifier, Function> symbolTable;
|
||||
|
||||
/// This is used when name conflicts are detected.
|
||||
unsigned uniquingCounter = 0;
|
||||
|
|
|
@ -72,7 +72,7 @@ public:
|
|||
}
|
||||
|
||||
/// Return the function that this Value is defined in.
|
||||
Function *getFunction();
|
||||
Function getFunction();
|
||||
|
||||
/// If this value is the result of an operation, return the operation that
|
||||
/// defines it.
|
||||
|
@ -128,7 +128,7 @@ public:
|
|||
}
|
||||
|
||||
/// Return the function that this argument is defined in.
|
||||
Function *getFunction();
|
||||
Function getFunction();
|
||||
|
||||
Block *getOwner() { return owner; }
|
||||
|
||||
|
|
|
@ -153,7 +153,7 @@ public:
|
|||
|
||||
/// Verify a function argument attribute registered to this dialect.
|
||||
/// Returns failure if the verification failed, success otherwise.
|
||||
LogicalResult verifyFunctionArgAttribute(Function *func, unsigned argIdx,
|
||||
LogicalResult verifyFunctionArgAttribute(Function func, unsigned argIdx,
|
||||
NamedAttribute argAttr) override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -106,7 +106,7 @@ template <typename IRUnitT> class AnalysisMap {
|
|||
}
|
||||
|
||||
public:
|
||||
explicit AnalysisMap(IRUnitT *ir) : ir(ir) {}
|
||||
explicit AnalysisMap(IRUnitT ir) : ir(ir) {}
|
||||
|
||||
/// Get an analysis for the current IR unit, computing it if necessary.
|
||||
template <typename AnalysisT> AnalysisT &getAnalysis(PassInstrumentor *pi) {
|
||||
|
@ -140,8 +140,8 @@ public:
|
|||
}
|
||||
|
||||
/// Returns the IR unit that this analysis map represents.
|
||||
IRUnitT *getIRUnit() { return ir; }
|
||||
const IRUnitT *getIRUnit() const { return ir; }
|
||||
IRUnitT getIRUnit() { return ir; }
|
||||
const IRUnitT getIRUnit() const { return ir; }
|
||||
|
||||
/// Clear any held analyses.
|
||||
void clear() { analyses.clear(); }
|
||||
|
@ -158,7 +158,7 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
IRUnitT *ir;
|
||||
IRUnitT ir;
|
||||
ConceptMap analyses;
|
||||
};
|
||||
|
||||
|
@ -231,14 +231,14 @@ public:
|
|||
/// Query for the analysis of a function. The analysis is computed if it does
|
||||
/// not exist.
|
||||
template <typename AnalysisT>
|
||||
AnalysisT &getFunctionAnalysis(Function *function) {
|
||||
AnalysisT &getFunctionAnalysis(Function function) {
|
||||
return slice(function).getAnalysis<AnalysisT>();
|
||||
}
|
||||
|
||||
/// Query for a cached analysis of a child function, or return null.
|
||||
template <typename AnalysisT>
|
||||
llvm::Optional<std::reference_wrapper<AnalysisT>>
|
||||
getCachedFunctionAnalysis(Function *function) const {
|
||||
getCachedFunctionAnalysis(Function function) const {
|
||||
auto it = functionAnalyses.find(function);
|
||||
if (it == functionAnalyses.end())
|
||||
return llvm::None;
|
||||
|
@ -258,7 +258,7 @@ public:
|
|||
}
|
||||
|
||||
/// Create an analysis slice for the given child function.
|
||||
FunctionAnalysisManager slice(Function *function);
|
||||
FunctionAnalysisManager slice(Function function);
|
||||
|
||||
/// Invalidate any non preserved analyses.
|
||||
void invalidate(const detail::PreservedAnalyses &pa);
|
||||
|
@ -269,11 +269,11 @@ public:
|
|||
|
||||
private:
|
||||
/// The cached analyses for functions within the current module.
|
||||
llvm::DenseMap<Function *, std::unique_ptr<detail::AnalysisMap<Function>>>
|
||||
llvm::DenseMap<Function, std::unique_ptr<detail::AnalysisMap<Function>>>
|
||||
functionAnalyses;
|
||||
|
||||
/// The analyses for the owning module.
|
||||
detail::AnalysisMap<Module> moduleAnalyses;
|
||||
detail::AnalysisMap<Module *> moduleAnalyses;
|
||||
|
||||
/// An optional instrumentation object.
|
||||
PassInstrumentor *passInstrumentor;
|
||||
|
|
|
@ -70,12 +70,12 @@ class ModulePassExecutor;
|
|||
/// interface for accessing and initializing necessary state for pass execution.
|
||||
template <typename IRUnitT, typename AnalysisManagerT>
|
||||
struct PassExecutionState {
|
||||
PassExecutionState(IRUnitT *ir, AnalysisManagerT &analysisManager)
|
||||
PassExecutionState(IRUnitT ir, AnalysisManagerT &analysisManager)
|
||||
: irAndPassFailed(ir, false), analysisManager(analysisManager) {}
|
||||
|
||||
/// The current IR unit being transformed and a bool for if the pass signaled
|
||||
/// a failure.
|
||||
llvm::PointerIntPair<IRUnitT *, 1, bool> irAndPassFailed;
|
||||
llvm::PointerIntPair<IRUnitT, 1, bool> irAndPassFailed;
|
||||
|
||||
/// The analysis manager for the IR unit.
|
||||
AnalysisManagerT &analysisManager;
|
||||
|
@ -107,9 +107,7 @@ protected:
|
|||
virtual FunctionPassBase *clone() const = 0;
|
||||
|
||||
/// Return the current function being transformed.
|
||||
Function &getFunction() {
|
||||
return *getPassState().irAndPassFailed.getPointer();
|
||||
}
|
||||
Function getFunction() { return getPassState().irAndPassFailed.getPointer(); }
|
||||
|
||||
/// Return the MLIR context for the current function being transformed.
|
||||
MLIRContext &getContext() { return *getFunction().getContext(); }
|
||||
|
@ -128,7 +126,7 @@ protected:
|
|||
private:
|
||||
/// Forwarding function to execute this pass.
|
||||
LLVM_NODISCARD
|
||||
LogicalResult run(Function *fn, FunctionAnalysisManager &fam);
|
||||
LogicalResult run(Function fn, FunctionAnalysisManager &fam);
|
||||
|
||||
/// The current execution state for the pass.
|
||||
llvm::Optional<PassStateT> passState;
|
||||
|
@ -140,7 +138,8 @@ private:
|
|||
/// Pass to transform a module. Derived passes should not inherit from this
|
||||
/// class directly, and instead should use the CRTP ModulePass class.
|
||||
class ModulePassBase : public Pass {
|
||||
using PassStateT = detail::PassExecutionState<Module, ModuleAnalysisManager>;
|
||||
using PassStateT =
|
||||
detail::PassExecutionState<Module *, ModuleAnalysisManager>;
|
||||
|
||||
public:
|
||||
static bool classof(const Pass *pass) {
|
||||
|
@ -272,7 +271,7 @@ struct FunctionPass : public detail::PassModel<Function, T, FunctionPassBase> {
|
|||
template <typename T>
|
||||
struct ModulePass : public detail::PassModel<Module, T, ModulePassBase> {
|
||||
/// Returns the analysis for a child function.
|
||||
template <typename AnalysisT> AnalysisT &getFunctionAnalysis(Function *f) {
|
||||
template <typename AnalysisT> AnalysisT &getFunctionAnalysis(Function f) {
|
||||
return this->getAnalysisManager().template getFunctionAnalysis<AnalysisT>(
|
||||
f);
|
||||
}
|
||||
|
@ -280,7 +279,7 @@ struct ModulePass : public detail::PassModel<Module, T, ModulePassBase> {
|
|||
/// Returns an existing analysis for a child function if it exists.
|
||||
template <typename AnalysisT>
|
||||
llvm::Optional<std::reference_wrapper<AnalysisT>>
|
||||
getCachedFunctionAnalysis(Function *f) {
|
||||
getCachedFunctionAnalysis(Function f) {
|
||||
return this->getAnalysisManager()
|
||||
.template getCachedFunctionAnalysis<AnalysisT>(f);
|
||||
}
|
||||
|
|
|
@ -77,29 +77,29 @@ public:
|
|||
~PassInstrumentor();
|
||||
|
||||
/// See PassInstrumentation::runBeforePass for details.
|
||||
template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT *ir) {
|
||||
template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT ir) {
|
||||
runBeforePass(pass, llvm::Any(ir));
|
||||
}
|
||||
|
||||
/// See PassInstrumentation::runAfterPass for details.
|
||||
template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT *ir) {
|
||||
template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT ir) {
|
||||
runAfterPass(pass, llvm::Any(ir));
|
||||
}
|
||||
|
||||
/// See PassInstrumentation::runAfterPassFailed for details.
|
||||
template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT *ir) {
|
||||
template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT ir) {
|
||||
runAfterPassFailed(pass, llvm::Any(ir));
|
||||
}
|
||||
|
||||
/// See PassInstrumentation::runBeforeAnalysis for details.
|
||||
template <typename IRUnitT>
|
||||
void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) {
|
||||
void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
|
||||
runBeforeAnalysis(name, id, llvm::Any(ir));
|
||||
}
|
||||
|
||||
/// See PassInstrumentation::runAfterAnalysis for details.
|
||||
template <typename IRUnitT>
|
||||
void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) {
|
||||
void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
|
||||
runAfterAnalysis(name, id, llvm::Any(ir));
|
||||
}
|
||||
|
||||
|
|
|
@ -214,11 +214,11 @@ def CallOp : Std_Op<"call"> {
|
|||
let results = (outs Variadic<AnyType>);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, Function *callee,"
|
||||
"Builder *builder, OperationState *result, Function callee,"
|
||||
"ArrayRef<Value *> operands = {}", [{
|
||||
result->addOperands(operands);
|
||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||
result->addTypes(callee->getType().getResults());
|
||||
result->addTypes(callee.getType().getResults());
|
||||
}]>, OpBuilder<
|
||||
"Builder *builder, OperationState *result, StringRef callee,"
|
||||
"ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
|
||||
|
|
|
@ -345,7 +345,7 @@ LLVM_NODISCARD LogicalResult applyConversionPatterns(
|
|||
/// Convert the given functions with the provided conversion patterns. This
|
||||
/// function returns failure if a type conversion failed.
|
||||
LLVM_NODISCARD
|
||||
LogicalResult applyConversionPatterns(ArrayRef<Function *> fns,
|
||||
LogicalResult applyConversionPatterns(MutableArrayRef<Function> fns,
|
||||
ConversionTarget &target,
|
||||
TypeConverter &converter,
|
||||
OwningRewritePatternList &&patterns);
|
||||
|
@ -354,7 +354,7 @@ LogicalResult applyConversionPatterns(ArrayRef<Function *> fns,
|
|||
/// convert as many of the operations within 'fn' as possible given the set of
|
||||
/// patterns.
|
||||
LLVM_NODISCARD
|
||||
LogicalResult applyConversionPatterns(Function &fn, ConversionTarget &target,
|
||||
LogicalResult applyConversionPatterns(Function fn, ConversionTarget &target,
|
||||
OwningRewritePatternList &&patterns);
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -37,7 +37,7 @@ Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
|
|||
|
||||
/// Convert from the Affine dialect to the Standard dialect, in particular
|
||||
/// convert structured affine control flow into CFG branch-based control flow.
|
||||
LogicalResult lowerAffineConstructs(Function &function);
|
||||
LogicalResult lowerAffineConstructs(Function function);
|
||||
|
||||
/// Emit code that computes the lower bound of the given affine loop using
|
||||
/// standard arithmetic operations.
|
||||
|
|
|
@ -33,11 +33,11 @@ class FunctionPassBase;
|
|||
|
||||
/// Displays the CFG in a window. This is for use from the debugger and
|
||||
/// depends on Graphviz to generate the graph.
|
||||
void viewGraph(Function &function, const Twine &name, bool shortNames = false,
|
||||
void viewGraph(Function function, const Twine &name, bool shortNames = false,
|
||||
const Twine &title = "",
|
||||
llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
|
||||
|
||||
llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function &function,
|
||||
llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function function,
|
||||
bool shortNames = false, const Twine &title = "");
|
||||
|
||||
/// Creates a pass to print CFG graphs.
|
||||
|
|
|
@ -303,7 +303,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
|
|||
if (inserted) {
|
||||
reorderedDims.push_back(v);
|
||||
}
|
||||
return getAffineDimExpr(iterPos->second, v->getFunction()->getContext())
|
||||
return getAffineDimExpr(iterPos->second, v->getFunction().getContext())
|
||||
.cast<AffineDimExpr>();
|
||||
}
|
||||
|
||||
|
|
|
@ -37,17 +37,16 @@ template class llvm::DomTreeNodeBase<Block>;
|
|||
|
||||
/// Recalculate the dominance info.
|
||||
template <bool IsPostDom>
|
||||
void DominanceInfoBase<IsPostDom>::recalculate(Function *function) {
|
||||
void DominanceInfoBase<IsPostDom>::recalculate(Function function) {
|
||||
dominanceInfos.clear();
|
||||
|
||||
// Build the top level function dominance.
|
||||
auto functionDominance = llvm::make_unique<base>();
|
||||
functionDominance->recalculate(function->getBody());
|
||||
dominanceInfos.try_emplace(&function->getBody(),
|
||||
std::move(functionDominance));
|
||||
functionDominance->recalculate(function.getBody());
|
||||
dominanceInfos.try_emplace(&function.getBody(), std::move(functionDominance));
|
||||
|
||||
/// Build the dominance for each of the operation regions.
|
||||
function->walk([&](Operation *op) {
|
||||
function.walk([&](Operation *op) {
|
||||
for (auto ®ion : op->getRegions()) {
|
||||
// Don't compute dominance if the region is empty.
|
||||
if (region.empty())
|
||||
|
|
|
@ -45,7 +45,7 @@ void PrintOpStatsPass::runOnModule() {
|
|||
opCount.clear();
|
||||
|
||||
// Compute the operation statistics for each function in the module.
|
||||
for (auto &fn : getModule())
|
||||
for (auto fn : getModule())
|
||||
fn.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
|
||||
printSummary();
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() {
|
|||
// Walks the function and emits a note for all 'affine.for' ops detected as
|
||||
// parallel.
|
||||
void TestParallelismDetection::runOnFunction() {
|
||||
Function &f = getFunction();
|
||||
Function f = getFunction();
|
||||
OpBuilder b(f.getBody());
|
||||
f.walk<AffineForOp>([&](AffineForOp forOp) {
|
||||
if (isLoopParallel(forOp))
|
||||
|
|
|
@ -53,7 +53,7 @@ public:
|
|||
: ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {}
|
||||
|
||||
/// Verify the body of the given function.
|
||||
LogicalResult verify(Function &fn);
|
||||
LogicalResult verify(Function fn);
|
||||
|
||||
/// Verify the given operation.
|
||||
LogicalResult verify(Operation &op);
|
||||
|
@ -104,7 +104,7 @@ private:
|
|||
} // end anonymous namespace
|
||||
|
||||
/// Verify the body of the given function.
|
||||
LogicalResult OperationVerifier::verify(Function &fn) {
|
||||
LogicalResult OperationVerifier::verify(Function fn) {
|
||||
// Verify the body first.
|
||||
if (failed(verifyRegion(fn.getBody())))
|
||||
return failure();
|
||||
|
@ -113,7 +113,7 @@ LogicalResult OperationVerifier::verify(Function &fn) {
|
|||
// check. We do this as a second pass since malformed CFG's can cause
|
||||
// dominator analysis constructure to crash and we want the verifier to be
|
||||
// resilient to malformed code.
|
||||
DominanceInfo theDomInfo(&fn);
|
||||
DominanceInfo theDomInfo(fn);
|
||||
domInfo = &theDomInfo;
|
||||
if (failed(verifyDominance(fn.getBody())))
|
||||
return failure();
|
||||
|
@ -313,7 +313,7 @@ LogicalResult Function::verify() {
|
|||
|
||||
// Verify this attribute with the defining dialect.
|
||||
if (auto *dialect = opVerifier.getDialectForAttribute(attr))
|
||||
if (failed(dialect->verifyFunctionAttribute(this, attr)))
|
||||
if (failed(dialect->verifyFunctionAttribute(*this, attr)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -331,7 +331,7 @@ LogicalResult Function::verify() {
|
|||
|
||||
// Verify this attribute with the defining dialect.
|
||||
if (auto *dialect = opVerifier.getDialectForAttribute(attr))
|
||||
if (failed(dialect->verifyFunctionArgAttribute(this, i, attr)))
|
||||
if (failed(dialect->verifyFunctionArgAttribute(*this, i, attr)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
@ -369,7 +369,7 @@ LogicalResult Operation::verify() {
|
|||
LogicalResult Module::verify() {
|
||||
// Check that all functions are uniquely named.
|
||||
llvm::StringMap<Location> nameToOrigLoc;
|
||||
for (auto &fn : *this) {
|
||||
for (auto fn : *this) {
|
||||
auto it = nameToOrigLoc.try_emplace(fn.getName(), fn.getLoc());
|
||||
if (!it.second)
|
||||
return fn.emitError()
|
||||
|
@ -379,7 +379,7 @@ LogicalResult Module::verify() {
|
|||
}
|
||||
|
||||
// Check that each function is correct.
|
||||
for (auto &fn : *this)
|
||||
for (auto fn : *this)
|
||||
if (failed(fn.verify()))
|
||||
return failure();
|
||||
|
||||
|
|
|
@ -64,8 +64,8 @@ public:
|
|||
LLVMInitializeNVPTXTargetMC();
|
||||
LLVMInitializeNVPTXAsmPrinter();
|
||||
|
||||
for (auto &function : getModule()) {
|
||||
if (!gpu::GPUDialect::isKernel(&function) || function.isExternal()) {
|
||||
for (auto function : getModule()) {
|
||||
if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) {
|
||||
continue;
|
||||
}
|
||||
if (failed(translateGpuKernelToCubinAnnotation(function)))
|
||||
|
@ -142,7 +142,7 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) {
|
|||
std::unique_ptr<Module> module(builder.createModule());
|
||||
|
||||
// TODO(herhut): Also handle called functions.
|
||||
module->getFunctions().push_back(function.clone());
|
||||
module->push_back(function.clone());
|
||||
|
||||
auto llvmModule = translateModuleToNVVMIR(*module);
|
||||
auto cubin = convertModuleToCubin(*llvmModule, function);
|
||||
|
|
|
@ -118,7 +118,7 @@ private:
|
|||
|
||||
void declareCudaFunctions(Location loc);
|
||||
Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
|
||||
Value *generateKernelNameConstant(Function *kernelFunction, Location &loc,
|
||||
Value *generateKernelNameConstant(Function kernelFunction, Location &loc,
|
||||
OpBuilder &builder);
|
||||
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
|
||||
|
||||
|
@ -130,7 +130,7 @@ public:
|
|||
// Cache the used LLVM types.
|
||||
initializeCachedTypes();
|
||||
|
||||
for (auto &func : getModule()) {
|
||||
for (auto func : getModule()) {
|
||||
func.walk<mlir::gpu::LaunchFuncOp>(
|
||||
[this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
|
||||
}
|
||||
|
@ -155,66 +155,66 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
|||
Module &module = getModule();
|
||||
Builder builder(&module);
|
||||
if (!module.getNamedFunction(cuModuleLoadName)) {
|
||||
module.getFunctions().push_back(
|
||||
new Function(loc, cuModuleLoadName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerPointerType(), /* CUmodule *module */
|
||||
getPointerType() /* void *cubin */
|
||||
},
|
||||
getCUResultType())));
|
||||
module.push_back(
|
||||
Function::create(loc, cuModuleLoadName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerPointerType(), /* CUmodule *module */
|
||||
getPointerType() /* void *cubin */
|
||||
},
|
||||
getCUResultType())));
|
||||
}
|
||||
if (!module.getNamedFunction(cuModuleGetFunctionName)) {
|
||||
// The helper uses void* instead of CUDA's opaque CUmodule and
|
||||
// CUfunction.
|
||||
module.getFunctions().push_back(
|
||||
new Function(loc, cuModuleGetFunctionName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerPointerType(), /* void **function */
|
||||
getPointerType(), /* void *module */
|
||||
getPointerType() /* char *name */
|
||||
},
|
||||
getCUResultType())));
|
||||
module.push_back(
|
||||
Function::create(loc, cuModuleGetFunctionName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerPointerType(), /* void **function */
|
||||
getPointerType(), /* void *module */
|
||||
getPointerType() /* char *name */
|
||||
},
|
||||
getCUResultType())));
|
||||
}
|
||||
if (!module.getNamedFunction(cuLaunchKernelName)) {
|
||||
// Other than the CUDA api, the wrappers use uintptr_t to match the
|
||||
// LLVM type if MLIR's index type, which the GPU dialect uses.
|
||||
// Furthermore, they use void* instead of CUDA's opaque CUfunction and
|
||||
// CUstream.
|
||||
module.getFunctions().push_back(
|
||||
new Function(loc, cuLaunchKernelName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerType(), /* void* f */
|
||||
getIntPtrType(), /* intptr_t gridXDim */
|
||||
getIntPtrType(), /* intptr_t gridyDim */
|
||||
getIntPtrType(), /* intptr_t gridZDim */
|
||||
getIntPtrType(), /* intptr_t blockXDim */
|
||||
getIntPtrType(), /* intptr_t blockYDim */
|
||||
getIntPtrType(), /* intptr_t blockZDim */
|
||||
getInt32Type(), /* unsigned int sharedMemBytes */
|
||||
getPointerType(), /* void *hstream */
|
||||
getPointerPointerType(), /* void **kernelParams */
|
||||
getPointerPointerType() /* void **extra */
|
||||
},
|
||||
getCUResultType())));
|
||||
module.push_back(Function::create(
|
||||
loc, cuLaunchKernelName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerType(), /* void* f */
|
||||
getIntPtrType(), /* intptr_t gridXDim */
|
||||
getIntPtrType(), /* intptr_t gridyDim */
|
||||
getIntPtrType(), /* intptr_t gridZDim */
|
||||
getIntPtrType(), /* intptr_t blockXDim */
|
||||
getIntPtrType(), /* intptr_t blockYDim */
|
||||
getIntPtrType(), /* intptr_t blockZDim */
|
||||
getInt32Type(), /* unsigned int sharedMemBytes */
|
||||
getPointerType(), /* void *hstream */
|
||||
getPointerPointerType(), /* void **kernelParams */
|
||||
getPointerPointerType() /* void **extra */
|
||||
},
|
||||
getCUResultType())));
|
||||
}
|
||||
if (!module.getNamedFunction(cuGetStreamHelperName)) {
|
||||
// Helper function to get the current CUDA stream. Uses void* instead of
|
||||
// CUDAs opaque CUstream.
|
||||
module.getFunctions().push_back(new Function(
|
||||
module.push_back(Function::create(
|
||||
loc, cuGetStreamHelperName,
|
||||
builder.getFunctionType({}, getPointerType() /* void *stream */)));
|
||||
}
|
||||
if (!module.getNamedFunction(cuStreamSynchronizeName)) {
|
||||
module.getFunctions().push_back(
|
||||
new Function(loc, cuStreamSynchronizeName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerType() /* CUstream stream */
|
||||
},
|
||||
getCUResultType())));
|
||||
module.push_back(
|
||||
Function::create(loc, cuStreamSynchronizeName,
|
||||
builder.getFunctionType(
|
||||
{
|
||||
getPointerType() /* CUstream stream */
|
||||
},
|
||||
getCUResultType())));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -264,14 +264,14 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
|
|||
// %0[n] = constant name[n]
|
||||
// %0[n+1] = 0
|
||||
Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
|
||||
Function *kernelFunction, Location &loc, OpBuilder &builder) {
|
||||
Function kernelFunction, Location &loc, OpBuilder &builder) {
|
||||
// TODO(herhut): Make this a constant once this is supported.
|
||||
auto kernelNameSize = builder.create<LLVM::ConstantOp>(
|
||||
loc, getInt32Type(),
|
||||
builder.getI32IntegerAttr(kernelFunction->getName().size() + 1));
|
||||
builder.getI32IntegerAttr(kernelFunction.getName().size() + 1));
|
||||
auto kernelName =
|
||||
builder.create<LLVM::AllocaOp>(loc, getPointerType(), kernelNameSize);
|
||||
for (auto byte : llvm::enumerate(kernelFunction->getName())) {
|
||||
for (auto byte : llvm::enumerate(kernelFunction.getName())) {
|
||||
auto index = builder.create<LLVM::ConstantOp>(
|
||||
loc, getInt32Type(), builder.getI32IntegerAttr(byte.index()));
|
||||
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
|
||||
|
@ -284,7 +284,7 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
|
|||
// Add trailing zero to terminate string.
|
||||
auto index = builder.create<LLVM::ConstantOp>(
|
||||
loc, getInt32Type(),
|
||||
builder.getI32IntegerAttr(kernelFunction->getName().size()));
|
||||
builder.getI32IntegerAttr(kernelFunction.getName().size()));
|
||||
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
|
||||
ArrayRef<Value *>{index});
|
||||
auto value = builder.create<LLVM::ConstantOp>(
|
||||
|
@ -326,9 +326,9 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
// TODO(herhut): This should rather be a static global once supported.
|
||||
auto kernelFunction = getModule().getNamedFunction(launchOp.kernel());
|
||||
auto cubinGetter =
|
||||
kernelFunction->getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
|
||||
kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
|
||||
if (!cubinGetter) {
|
||||
kernelFunction->emitError("Missing ")
|
||||
kernelFunction.emitError("Missing ")
|
||||
<< kCubinGetterAnnotation << " attribute.";
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
@ -337,7 +337,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
// Emit the load module call to load the module data. Error checking is done
|
||||
// in the called helper function.
|
||||
auto cuModule = allocatePointer(builder, loc);
|
||||
Function *cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName);
|
||||
Function cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getFunctionAttr(cuModuleLoad),
|
||||
ArrayRef<Value *>{cuModule, data.getResult(0)});
|
||||
|
@ -347,14 +347,14 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
|
||||
auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
|
||||
auto cuFunction = allocatePointer(builder, loc);
|
||||
Function *cuModuleGetFunction =
|
||||
Function cuModuleGetFunction =
|
||||
getModule().getNamedFunction(cuModuleGetFunctionName);
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getFunctionAttr(cuModuleGetFunction),
|
||||
ArrayRef<Value *>{cuFunction, cuModuleRef, kernelName});
|
||||
// Grab the global stream needed for execution.
|
||||
Function *cuGetStreamHelper =
|
||||
Function cuGetStreamHelper =
|
||||
getModule().getNamedFunction(cuGetStreamHelperName);
|
||||
auto cuStream = builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getPointerType()},
|
||||
|
|
|
@ -53,15 +53,15 @@ constexpr const char *kMallocHelperName = "mcuMalloc";
|
|||
class GpuGenerateCubinAccessorsPass
|
||||
: public ModulePass<GpuGenerateCubinAccessorsPass> {
|
||||
private:
|
||||
Function *getMallocHelper(Location loc, Builder &builder) {
|
||||
Function *result = getModule().getNamedFunction(kMallocHelperName);
|
||||
Function getMallocHelper(Location loc, Builder &builder) {
|
||||
Function result = getModule().getNamedFunction(kMallocHelperName);
|
||||
if (!result) {
|
||||
result = new Function(
|
||||
result = Function::create(
|
||||
loc, kMallocHelperName,
|
||||
builder.getFunctionType(
|
||||
ArrayRef<Type>{LLVM::LLVMType::getInt32Ty(llvmDialect)},
|
||||
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
|
||||
getModule().getFunctions().push_back(result);
|
||||
getModule().push_back(result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -70,18 +70,18 @@ private:
|
|||
// data from blob. As there are currently no global constants, this uses a
|
||||
// sequence of store operations.
|
||||
// TODO(herhut): Use global constants instead.
|
||||
Function *generateCubinAccessor(Builder &builder, Function &orig,
|
||||
StringAttr blob) {
|
||||
Function generateCubinAccessor(Builder &builder, Function &orig,
|
||||
StringAttr blob) {
|
||||
Location loc = orig.getLoc();
|
||||
SmallString<128> nameBuffer(orig.getName());
|
||||
nameBuffer.append(kCubinGetterSuffix);
|
||||
// Generate a function that returns void*.
|
||||
Function *result = new Function(
|
||||
Function result = Function::create(
|
||||
loc, mlir::Identifier::get(nameBuffer, &getContext()),
|
||||
builder.getFunctionType(ArrayRef<Type>{},
|
||||
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
|
||||
// Insert a body block that just returns the constant.
|
||||
OpBuilder ob(result->getBody());
|
||||
OpBuilder ob(result.getBody());
|
||||
ob.createBlock();
|
||||
auto sizeConstant = ob.create<LLVM::ConstantOp>(
|
||||
loc, LLVM::LLVMType::getInt32Ty(llvmDialect),
|
||||
|
@ -115,18 +115,18 @@ public:
|
|||
void runOnModule() override {
|
||||
llvmDialect =
|
||||
getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
Builder builder(getModule().getContext());
|
||||
auto &module = getModule();
|
||||
Builder builder(&getContext());
|
||||
|
||||
auto &functions = getModule().getFunctions();
|
||||
auto functions = module.getFunctions();
|
||||
for (auto it = functions.begin(); it != functions.end();) {
|
||||
// Move iterator to after the current function so that potential insertion
|
||||
// of the accessor is after the kernel with cubin iself.
|
||||
Function &orig = *it++;
|
||||
Function orig = *it++;
|
||||
StringAttr cubinBlob = orig.getAttrOfType<StringAttr>(kCubinAnnotation);
|
||||
if (!cubinBlob)
|
||||
continue;
|
||||
it =
|
||||
functions.insert(it, generateCubinAccessor(builder, orig, cubinBlob));
|
||||
module.insert(it, generateCubinAccessor(builder, orig, cubinBlob));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -441,13 +441,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
createIndexConstant(rewriter, op->getLoc(), elementSize)});
|
||||
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
Function *mallocFunc =
|
||||
op->getFunction()->getModule()->getNamedFunction("malloc");
|
||||
Function mallocFunc =
|
||||
op->getFunction().getModule()->getNamedFunction("malloc");
|
||||
if (!mallocFunc) {
|
||||
auto mallocType =
|
||||
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
|
||||
mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
|
||||
op->getFunction()->getModule()->getFunctions().push_back(mallocFunc);
|
||||
mallocFunc =
|
||||
Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
|
||||
op->getFunction().getModule()->push_back(mallocFunc);
|
||||
}
|
||||
|
||||
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
||||
|
@ -502,12 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
OperandAdaptor<DeallocOp> transformed(operands);
|
||||
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
Function *freeFunc =
|
||||
op->getFunction()->getModule()->getNamedFunction("free");
|
||||
Function freeFunc = op->getFunction().getModule()->getNamedFunction("free");
|
||||
if (!freeFunc) {
|
||||
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
|
||||
freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
|
||||
op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
|
||||
freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
|
||||
op->getFunction().getModule()->push_back(freeFunc);
|
||||
}
|
||||
|
||||
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
|
||||
|
@ -937,7 +937,7 @@ static void ensureDistinctSuccessors(Block &bb) {
|
|||
}
|
||||
|
||||
void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
|
||||
for (auto &f : *m) {
|
||||
for (auto f : *m) {
|
||||
for (auto &bb : f.getBlocks()) {
|
||||
::ensureDistinctSuccessors(bb);
|
||||
}
|
||||
|
|
|
@ -365,7 +365,7 @@ struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void LowerUniformRealMathPass::runOnFunction() {
|
||||
auto &fn = getFunction();
|
||||
auto fn = getFunction();
|
||||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
|
||||
|
@ -386,7 +386,7 @@ static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void LowerUniformCastsPass::runOnFunction() {
|
||||
auto &fn = getFunction();
|
||||
auto fn = getFunction();
|
||||
OwningRewritePatternList patterns;
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
|
||||
|
|
|
@ -106,7 +106,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
|||
|
||||
void ConvertConstPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto &func = getFunction();
|
||||
auto func = getFunction();
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
|
|
|
@ -95,7 +95,7 @@ public:
|
|||
void ConvertSimulatedQuantPass::runOnFunction() {
|
||||
bool hadFailure = false;
|
||||
OwningRewritePatternList patterns;
|
||||
auto &func = getFunction();
|
||||
auto func = getFunction();
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(
|
||||
llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));
|
||||
|
|
|
@ -67,10 +67,10 @@ allocMemRefDescriptor(Type type, bool allocateData = true,
|
|||
}
|
||||
|
||||
llvm::Expected<SmallVector<void *, 8>>
|
||||
mlir::allocateMemRefArguments(Function *func, float initialValue) {
|
||||
mlir::allocateMemRefArguments(Function func, float initialValue) {
|
||||
SmallVector<void *, 8> args;
|
||||
args.reserve(func->getNumArguments());
|
||||
for (const auto &arg : func->getArguments()) {
|
||||
args.reserve(func.getNumArguments());
|
||||
for (const auto &arg : func.getArguments()) {
|
||||
auto descriptor =
|
||||
allocMemRefDescriptor(arg->getType(),
|
||||
/*allocateData=*/true, initialValue);
|
||||
|
@ -79,10 +79,10 @@ mlir::allocateMemRefArguments(Function *func, float initialValue) {
|
|||
args.push_back(*descriptor);
|
||||
}
|
||||
|
||||
if (func->getType().getNumResults() > 1)
|
||||
if (func.getType().getNumResults() > 1)
|
||||
return make_string_error("functions with more than 1 result not supported");
|
||||
|
||||
for (Type resType : func->getType().getResults()) {
|
||||
for (Type resType : func.getType().getResults()) {
|
||||
auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false);
|
||||
if (!descriptor)
|
||||
return descriptor.takeError();
|
||||
|
|
|
@ -30,9 +30,9 @@ using namespace mlir::gpu;
|
|||
|
||||
StringRef GPUDialect::getDialectName() { return "gpu"; }
|
||||
|
||||
bool GPUDialect::isKernel(Function *function) {
|
||||
bool GPUDialect::isKernel(Function function) {
|
||||
UnitAttr isKernelAttr =
|
||||
function->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
|
||||
function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
|
||||
return static_cast<bool>(isKernelAttr);
|
||||
}
|
||||
|
||||
|
@ -318,7 +318,7 @@ ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void LaunchFuncOp::build(Builder *builder, OperationState *result,
|
||||
Function *kernelFunc, Value *gridSizeX,
|
||||
Function kernelFunc, Value *gridSizeX,
|
||||
Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
|
||||
Value *blockSizeY, Value *blockSizeZ,
|
||||
ArrayRef<Value *> kernelOperands) {
|
||||
|
@ -331,7 +331,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
|
|||
}
|
||||
|
||||
void LaunchFuncOp::build(Builder *builder, OperationState *result,
|
||||
Function *kernelFunc, KernelDim3 gridSize,
|
||||
Function kernelFunc, KernelDim3 gridSize,
|
||||
KernelDim3 blockSize,
|
||||
ArrayRef<Value *> kernelOperands) {
|
||||
build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
|
||||
|
@ -366,23 +366,23 @@ LogicalResult LaunchFuncOp::verify() {
|
|||
return emitOpError("attribute 'kernel' must be a function");
|
||||
}
|
||||
|
||||
auto *module = getOperation()->getFunction()->getModule();
|
||||
Function *kernelFunc = module->getNamedFunction(kernel());
|
||||
auto *module = getOperation()->getFunction().getModule();
|
||||
Function kernelFunc = module->getNamedFunction(kernel());
|
||||
if (!kernelFunc)
|
||||
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
|
||||
|
||||
if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
|
||||
if (!kernelFunc.getAttrOfType<mlir::UnitAttr>(
|
||||
GPUDialect::getKernelFuncAttrName())) {
|
||||
return emitError("kernel function is missing the '")
|
||||
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
|
||||
}
|
||||
unsigned numKernelFuncArgs = kernelFunc->getNumArguments();
|
||||
unsigned numKernelFuncArgs = kernelFunc.getNumArguments();
|
||||
if (getNumKernelOperands() != numKernelFuncArgs) {
|
||||
return emitOpError("got ")
|
||||
<< getNumKernelOperands() << " kernel operands but expected "
|
||||
<< numKernelFuncArgs;
|
||||
}
|
||||
auto functionType = kernelFunc->getType();
|
||||
auto functionType = kernelFunc.getType();
|
||||
for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
|
||||
if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
|
||||
return emitOpError("type of function argument ")
|
||||
|
|
|
@ -40,7 +40,7 @@ static void createForAllDimensions(OpBuilder &builder, Location loc,
|
|||
|
||||
// Add operations generating block/thread ids and gird/block dimensions at the
|
||||
// beginning of `kernelFunc` and replace uses of the respective function args.
|
||||
static void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
|
||||
static void injectGpuIndexOperations(Location loc, Function kernelFunc) {
|
||||
OpBuilder OpBuilder(kernelFunc.getBody());
|
||||
SmallVector<Value *, 12> indexOps;
|
||||
createForAllDimensions<gpu::BlockId>(OpBuilder, loc, indexOps);
|
||||
|
@ -58,20 +58,20 @@ static void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
|
|||
|
||||
// Outline the `gpu.launch` operation body into a kernel function. Replace
|
||||
// `gpu.return` operations by `std.return` in the generated functions.
|
||||
static Function *outlineKernelFunc(gpu::LaunchOp launchOp) {
|
||||
static Function outlineKernelFunc(gpu::LaunchOp launchOp) {
|
||||
Location loc = launchOp.getLoc();
|
||||
SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
|
||||
FunctionType type =
|
||||
FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
|
||||
std::string kernelFuncName =
|
||||
Twine(launchOp.getOperation()->getFunction()->getName(), "_kernel").str();
|
||||
Function *outlinedFunc = new mlir::Function(loc, kernelFuncName, type);
|
||||
outlinedFunc->getBody().takeBody(launchOp.getBody());
|
||||
Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str();
|
||||
Function outlinedFunc = Function::create(loc, kernelFuncName, type);
|
||||
outlinedFunc.getBody().takeBody(launchOp.getBody());
|
||||
Builder builder(launchOp.getContext());
|
||||
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
|
||||
builder.getUnitAttr());
|
||||
injectGpuIndexOperations(loc, *outlinedFunc);
|
||||
outlinedFunc->walk<mlir::gpu::Return>([](mlir::gpu::Return op) {
|
||||
outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
|
||||
builder.getUnitAttr());
|
||||
injectGpuIndexOperations(loc, outlinedFunc);
|
||||
outlinedFunc.walk<mlir::gpu::Return>([](mlir::gpu::Return op) {
|
||||
OpBuilder replacer(op);
|
||||
replacer.create<ReturnOp>(op.getLoc());
|
||||
op.erase();
|
||||
|
@ -82,12 +82,12 @@ static Function *outlineKernelFunc(gpu::LaunchOp launchOp) {
|
|||
// Replace `gpu.launch` operations with an `gpu.launch_func` operation launching
|
||||
// `kernelFunc`.
|
||||
static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp,
|
||||
Function &kernelFunc) {
|
||||
Function kernelFunc) {
|
||||
OpBuilder builder(launchOp);
|
||||
SmallVector<Value *, 4> kernelOperandValues(
|
||||
launchOp.getKernelOperandValues());
|
||||
builder.create<gpu::LaunchFuncOp>(
|
||||
launchOp.getLoc(), &kernelFunc, launchOp.getGridSizeOperandValues(),
|
||||
launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
|
||||
launchOp.getBlockSizeOperandValues(), kernelOperandValues);
|
||||
launchOp.erase();
|
||||
}
|
||||
|
@ -98,11 +98,11 @@ class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
|
|||
public:
|
||||
void runOnModule() override {
|
||||
ModuleManager moduleManager(&getModule());
|
||||
for (auto &func : getModule()) {
|
||||
for (auto func : getModule()) {
|
||||
func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
|
||||
Function *outlinedFunc = outlineKernelFunc(op);
|
||||
Function outlinedFunc = outlineKernelFunc(op);
|
||||
moduleManager.insert(outlinedFunc);
|
||||
convertToLaunchFuncOp(op, *outlinedFunc);
|
||||
convertToLaunchFuncOp(op, outlinedFunc);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -306,7 +306,7 @@ void ModuleState::initialize(Module *module) {
|
|||
initializeSymbolAliases();
|
||||
|
||||
// Walk the module and visit each operation.
|
||||
for (auto &fn : *module) {
|
||||
for (auto fn : *module) {
|
||||
visitType(fn.getType());
|
||||
for (auto attr : fn.getAttrs())
|
||||
ModuleState::visitAttribute(attr.second);
|
||||
|
@ -342,7 +342,7 @@ public:
|
|||
void printAttribute(Attribute attr, bool mayElideType = false);
|
||||
|
||||
void printType(Type type);
|
||||
void print(Function *fn);
|
||||
void print(Function fn);
|
||||
void printLocation(LocationAttr loc);
|
||||
|
||||
void printAffineMap(AffineMap map);
|
||||
|
@ -460,8 +460,8 @@ void ModulePrinter::print(Module *module) {
|
|||
state.printTypeAliases(os);
|
||||
|
||||
// Print the module.
|
||||
for (auto &fn : *module)
|
||||
print(&fn);
|
||||
for (auto fn : *module)
|
||||
print(fn);
|
||||
}
|
||||
|
||||
/// Print a floating point value in a way that the parser will be able to
|
||||
|
@ -1186,7 +1186,7 @@ namespace {
|
|||
// CFG and ML functions.
|
||||
class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
|
||||
public:
|
||||
FunctionPrinter(Function *function, ModulePrinter &other);
|
||||
FunctionPrinter(Function function, ModulePrinter &other);
|
||||
|
||||
// Prints the function as a whole.
|
||||
void print();
|
||||
|
@ -1275,7 +1275,7 @@ protected:
|
|||
void printValueID(Value *value, bool printResultNo = true) const;
|
||||
|
||||
private:
|
||||
Function *function;
|
||||
Function function;
|
||||
|
||||
/// This is the value ID for each SSA value in the current function. If this
|
||||
/// returns ~0, then the valueID has an entry in valueNames.
|
||||
|
@ -1305,10 +1305,10 @@ private:
|
|||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
FunctionPrinter::FunctionPrinter(Function *function, ModulePrinter &other)
|
||||
FunctionPrinter::FunctionPrinter(Function function, ModulePrinter &other)
|
||||
: ModulePrinter(other), function(function) {
|
||||
|
||||
for (auto &block : *function)
|
||||
for (auto &block : function)
|
||||
numberValuesInBlock(block);
|
||||
}
|
||||
|
||||
|
@ -1419,17 +1419,17 @@ void FunctionPrinter::print() {
|
|||
printFunctionSignature();
|
||||
|
||||
// Print out function attributes, if present.
|
||||
auto attrs = function->getAttrs();
|
||||
auto attrs = function.getAttrs();
|
||||
if (!attrs.empty()) {
|
||||
os << "\n attributes ";
|
||||
printOptionalAttrDict(attrs);
|
||||
}
|
||||
|
||||
// Print the trailing location.
|
||||
printTrailingLocation(function->getLoc());
|
||||
printTrailingLocation(function.getLoc());
|
||||
|
||||
if (!function->empty()) {
|
||||
printRegion(function->getBody(), /*printEntryBlockArgs=*/false,
|
||||
if (!function.empty()) {
|
||||
printRegion(function.getBody(), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/true);
|
||||
os << "\n";
|
||||
}
|
||||
|
@ -1437,24 +1437,24 @@ void FunctionPrinter::print() {
|
|||
}
|
||||
|
||||
void FunctionPrinter::printFunctionSignature() {
|
||||
os << "func @" << function->getName() << '(';
|
||||
os << "func @" << function.getName() << '(';
|
||||
|
||||
auto fnType = function->getType();
|
||||
bool isExternal = function->isExternal();
|
||||
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
|
||||
auto fnType = function.getType();
|
||||
bool isExternal = function.isExternal();
|
||||
for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i) {
|
||||
if (i > 0)
|
||||
os << ", ";
|
||||
|
||||
// If this is an external function, don't print argument labels.
|
||||
if (!isExternal) {
|
||||
printOperand(function->getArgument(i));
|
||||
printOperand(function.getArgument(i));
|
||||
os << ": ";
|
||||
}
|
||||
|
||||
printType(fnType.getInput(i));
|
||||
|
||||
// Print the attributes for this argument.
|
||||
printOptionalAttrDict(function->getArgAttrs(i));
|
||||
printOptionalAttrDict(function.getArgAttrs(i));
|
||||
}
|
||||
os << ')';
|
||||
|
||||
|
@ -1662,7 +1662,7 @@ void FunctionPrinter::printSuccessorAndUseList(Operation *term,
|
|||
}
|
||||
|
||||
// Prints function with initialized module state.
|
||||
void ModulePrinter::print(Function *fn) { FunctionPrinter(fn, *this).print(); }
|
||||
void ModulePrinter::print(Function fn) { FunctionPrinter(fn, *this).print(); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// print and dump methods
|
||||
|
@ -1737,13 +1737,13 @@ void Value::print(raw_ostream &os) {
|
|||
void Value::dump() { print(llvm::errs()); }
|
||||
|
||||
void Operation::print(raw_ostream &os) {
|
||||
auto *function = getFunction();
|
||||
auto function = getFunction();
|
||||
if (!function) {
|
||||
os << "<<UNLINKED INSTRUCTION>>\n";
|
||||
return;
|
||||
}
|
||||
|
||||
ModuleState state(function->getContext());
|
||||
ModuleState state(function.getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
FunctionPrinter(function, modulePrinter).print(this);
|
||||
}
|
||||
|
@ -1754,13 +1754,13 @@ void Operation::dump() {
|
|||
}
|
||||
|
||||
void Block::print(raw_ostream &os) {
|
||||
auto *function = getFunction();
|
||||
auto function = getFunction();
|
||||
if (!function) {
|
||||
os << "<<UNLINKED BLOCK>>\n";
|
||||
return;
|
||||
}
|
||||
|
||||
ModuleState state(function->getContext());
|
||||
ModuleState state(function.getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
FunctionPrinter(function, modulePrinter).print(this);
|
||||
}
|
||||
|
@ -1773,14 +1773,14 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
|
|||
os << "<<UNLINKED BLOCK>>\n";
|
||||
return;
|
||||
}
|
||||
ModuleState state(getFunction()->getContext());
|
||||
ModuleState state(getFunction().getContext());
|
||||
ModulePrinter modulePrinter(os, state);
|
||||
FunctionPrinter(getFunction(), modulePrinter).printBlockName(this);
|
||||
}
|
||||
|
||||
void Function::print(raw_ostream &os) {
|
||||
ModuleState state(getContext());
|
||||
ModulePrinter(os, state).print(this);
|
||||
ModulePrinter(os, state).print(*this);
|
||||
}
|
||||
|
||||
void Function::dump() { print(llvm::errs()); }
|
||||
|
|
|
@ -249,11 +249,6 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
|
|||
// FunctionAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FunctionAttr FunctionAttr::get(Function *value) {
|
||||
assert(value && "Cannot get FunctionAttr for a null function");
|
||||
return get(value->getName(), value->getContext());
|
||||
}
|
||||
|
||||
FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
|
||||
return Base::get(ctx, StandardAttributes::Function, value,
|
||||
NoneType::get(ctx));
|
||||
|
|
|
@ -50,7 +50,7 @@ Operation *Block::getContainingOp() {
|
|||
return getParent() ? getParent()->getContainingOp() : nullptr;
|
||||
}
|
||||
|
||||
Function *Block::getFunction() {
|
||||
Function Block::getFunction() {
|
||||
Block *block = this;
|
||||
while (auto *op = block->getContainingOp()) {
|
||||
block = op->getBlock();
|
||||
|
|
|
@ -177,8 +177,8 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
|
|||
|
||||
TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
|
||||
|
||||
FunctionAttr Builder::getFunctionAttr(Function *value) {
|
||||
return FunctionAttr::get(value);
|
||||
FunctionAttr Builder::getFunctionAttr(Function value) {
|
||||
return getFunctionAttr(value.getName());
|
||||
}
|
||||
FunctionAttr Builder::getFunctionAttr(StringRef value) {
|
||||
return FunctionAttr::get(value, getContext());
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/DialectHooks.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
|
@ -68,6 +69,20 @@ Dialect::Dialect(StringRef name, MLIRContext *context)
|
|||
|
||||
Dialect::~Dialect() {}
|
||||
|
||||
/// Verify an attribute from this dialect on the given function. Returns
|
||||
/// failure if the verification failed, success otherwise.
|
||||
LogicalResult Dialect::verifyFunctionAttribute(Function, NamedAttribute) {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verify an attribute from this dialect on the argument at 'argIndex' for
|
||||
/// the given function. Returns failure if the verification failed, success
|
||||
/// otherwise.
|
||||
LogicalResult Dialect::verifyFunctionArgAttribute(Function, unsigned argIndex,
|
||||
NamedAttribute) {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse an attribute registered to this dialect.
|
||||
Attribute Dialect::parseAttribute(StringRef attrData, Location loc) const {
|
||||
emitError(loc) << "dialect '" << getNamespace()
|
||||
|
|
|
@ -27,45 +27,50 @@
|
|||
#include "llvm/ADT/Twine.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
Function::Function(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs)
|
||||
FunctionStorage::FunctionStorage(Location location, StringRef name,
|
||||
FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs)
|
||||
: name(Identifier::get(name, type.getContext())), location(location),
|
||||
type(type), attrs(attrs), argAttrs(type.getNumInputs()), body(this) {}
|
||||
|
||||
Function::Function(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs,
|
||||
ArrayRef<NamedAttributeList> argAttrs)
|
||||
FunctionStorage::FunctionStorage(Location location, StringRef name,
|
||||
FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs,
|
||||
ArrayRef<NamedAttributeList> argAttrs)
|
||||
: name(Identifier::get(name, type.getContext())), location(location),
|
||||
type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
|
||||
|
||||
MLIRContext *Function::getContext() { return getType().getContext(); }
|
||||
|
||||
Module *llvm::ilist_traits<Function>::getContainingModule() {
|
||||
Module *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
|
||||
size_t Offset(
|
||||
size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
|
||||
iplist<Function> *Anchor(static_cast<iplist<Function> *>(this));
|
||||
iplist<FunctionStorage> *Anchor(static_cast<iplist<FunctionStorage> *>(this));
|
||||
return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset);
|
||||
}
|
||||
|
||||
/// This is a trait method invoked when a Function is added to a Module. We
|
||||
/// keep the module pointer and module symbol table up to date.
|
||||
void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
|
||||
assert(!function->getModule() && "already in a module!");
|
||||
void llvm::ilist_traits<FunctionStorage>::addNodeToList(
|
||||
FunctionStorage *function) {
|
||||
assert(!function->module && "already in a module!");
|
||||
function->module = getContainingModule();
|
||||
}
|
||||
|
||||
/// This is a trait method invoked when a Function is removed from a Module.
|
||||
/// We keep the module pointer up to date.
|
||||
void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
|
||||
void llvm::ilist_traits<FunctionStorage>::removeNodeFromList(
|
||||
FunctionStorage *function) {
|
||||
assert(function->module && "not already in a module!");
|
||||
function->module = nullptr;
|
||||
}
|
||||
|
||||
/// This is a trait method invoked when an operation is moved from one block
|
||||
/// to another. We keep the block pointer up to date.
|
||||
void llvm::ilist_traits<Function>::transferNodesFromList(
|
||||
ilist_traits<Function> &otherList, function_iterator first,
|
||||
void llvm::ilist_traits<FunctionStorage>::transferNodesFromList(
|
||||
ilist_traits<FunctionStorage> &otherList, function_iterator first,
|
||||
function_iterator last) {
|
||||
// If we are transferring functions within the same module, the Module
|
||||
// pointer doesn't need to be updated.
|
||||
|
@ -82,8 +87,10 @@ void llvm::ilist_traits<Function>::transferNodesFromList(
|
|||
|
||||
/// Unlink this function from its Module and delete it.
|
||||
void Function::erase() {
|
||||
assert(getModule() && "Function has no parent");
|
||||
getModule()->getFunctions().erase(this);
|
||||
if (auto *module = getModule())
|
||||
getModule()->functions.erase(impl);
|
||||
else
|
||||
delete impl;
|
||||
}
|
||||
|
||||
/// Emit an error about fatal conditions with this function, reporting up to
|
||||
|
@ -111,10 +118,10 @@ InFlightDiagnostic Function::emitRemark(const Twine &message) {
|
|||
|
||||
/// Clone the internal blocks from this function into dest and all attributes
|
||||
/// from this function to dest.
|
||||
void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
|
||||
void Function::cloneInto(Function dest, BlockAndValueMapping &mapper) {
|
||||
// Add the attributes of this function to dest.
|
||||
llvm::MapVector<Identifier, Attribute> newAttrs;
|
||||
for (auto &attr : dest->getAttrs())
|
||||
for (auto &attr : dest.getAttrs())
|
||||
newAttrs.insert(attr);
|
||||
for (auto &attr : getAttrs()) {
|
||||
auto insertPair = newAttrs.insert(attr);
|
||||
|
@ -125,10 +132,10 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
|
|||
assert((insertPair.second || insertPair.first->second == attr.second) &&
|
||||
"the two functions have incompatible attributes");
|
||||
}
|
||||
dest->setAttrs(newAttrs.takeVector());
|
||||
dest.setAttrs(newAttrs.takeVector());
|
||||
|
||||
// Clone the body.
|
||||
body.cloneInto(&dest->body, mapper);
|
||||
impl->body.cloneInto(&dest.impl->body, mapper);
|
||||
}
|
||||
|
||||
/// Create a deep copy of this function and all of its blocks, remapping
|
||||
|
@ -136,8 +143,8 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
|
|||
/// provided (leaving them alone if no entry is present). Replaces references
|
||||
/// to cloned sub-values with the corresponding value that is copied, and adds
|
||||
/// those mappings to the mapper.
|
||||
Function *Function::clone(BlockAndValueMapping &mapper) {
|
||||
FunctionType newType = type;
|
||||
Function Function::clone(BlockAndValueMapping &mapper) {
|
||||
FunctionType newType = impl->type;
|
||||
|
||||
// If the function has a body, then the user might be deleting arguments to
|
||||
// the function by specifying them in the mapper. If so, we don't add the
|
||||
|
@ -147,23 +154,23 @@ Function *Function::clone(BlockAndValueMapping &mapper) {
|
|||
SmallVector<Type, 4> inputTypes;
|
||||
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
|
||||
if (!mapper.contains(getArgument(i)))
|
||||
inputTypes.push_back(type.getInput(i));
|
||||
newType = FunctionType::get(inputTypes, type.getResults(), getContext());
|
||||
inputTypes.push_back(newType.getInput(i));
|
||||
newType = FunctionType::get(inputTypes, newType.getResults(), getContext());
|
||||
}
|
||||
|
||||
// Create the new function.
|
||||
Function *newFunc = new Function(getLoc(), getName(), newType);
|
||||
Function newFunc = Function::create(getLoc(), getName(), newType);
|
||||
|
||||
/// Set the argument attributes for arguments that aren't being replaced.
|
||||
for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
|
||||
if (isExternalFn || !mapper.contains(getArgument(i)))
|
||||
newFunc->setArgAttrs(destI++, getArgAttrs(i));
|
||||
newFunc.setArgAttrs(destI++, getArgAttrs(i));
|
||||
|
||||
/// Clone the current function into the new one and return it.
|
||||
cloneInto(newFunc, mapper);
|
||||
return newFunc;
|
||||
}
|
||||
Function *Function::clone() {
|
||||
Function Function::clone() {
|
||||
BlockAndValueMapping mapper;
|
||||
return clone(mapper);
|
||||
}
|
||||
|
@ -178,7 +185,7 @@ void Function::addEntryBlock() {
|
|||
assert(empty() && "function already has an entry block");
|
||||
auto *entry = new Block();
|
||||
push_back(entry);
|
||||
entry->addArguments(type.getInputs());
|
||||
entry->addArguments(impl->type.getInputs());
|
||||
}
|
||||
|
||||
void Function::walk(const std::function<void(Operation *)> &callback) {
|
||||
|
|
|
@ -281,7 +281,7 @@ Operation *Operation::getParentOp() {
|
|||
return block ? block->getContainingOp() : nullptr;
|
||||
}
|
||||
|
||||
Function *Operation::getFunction() {
|
||||
Function Operation::getFunction() {
|
||||
return block ? block->getFunction() : nullptr;
|
||||
}
|
||||
|
||||
|
@ -861,12 +861,13 @@ static LogicalResult verifyBBArguments(Operation::operand_range operands,
|
|||
}
|
||||
|
||||
static LogicalResult verifyTerminatorSuccessors(Operation *op) {
|
||||
auto *parent = op->getContainingRegion();
|
||||
|
||||
// Verify that the operands lines up with the BB arguments in the successor.
|
||||
Function *fn = op->getFunction();
|
||||
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
|
||||
auto *succ = op->getSuccessor(i);
|
||||
if (succ->getFunction() != fn)
|
||||
return op->emitError("reference to block defined in another function");
|
||||
if (succ->getParent() != parent)
|
||||
return op->emitError("reference to block defined in another region");
|
||||
if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op)))
|
||||
return failure();
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "mlir/IR/Operation.h"
|
||||
using namespace mlir;
|
||||
|
||||
Region::Region(Function *container) : container(container) {}
|
||||
Region::Region(Function container) : container(container.impl) {}
|
||||
|
||||
Region::Region(Operation *container) : container(container) {}
|
||||
|
||||
|
@ -38,7 +38,7 @@ MLIRContext *Region::getContext() {
|
|||
assert(!container.isNull() && "region is not attached to a container");
|
||||
if (auto *inst = getContainingOp())
|
||||
return inst->getContext();
|
||||
return getContainingFunction()->getContext();
|
||||
return getContainingFunction().getContext();
|
||||
}
|
||||
|
||||
/// Return a location for this region. This is the location attached to the
|
||||
|
@ -47,7 +47,7 @@ Location Region::getLoc() {
|
|||
assert(!container.isNull() && "region is not attached to a container");
|
||||
if (auto *inst = getContainingOp())
|
||||
return inst->getLoc();
|
||||
return getContainingFunction()->getLoc();
|
||||
return getContainingFunction().getLoc();
|
||||
}
|
||||
|
||||
Region *Region::getContainingRegion() {
|
||||
|
@ -60,8 +60,8 @@ Operation *Region::getContainingOp() {
|
|||
return container.dyn_cast<Operation *>();
|
||||
}
|
||||
|
||||
Function *Region::getContainingFunction() {
|
||||
return container.dyn_cast<Function *>();
|
||||
Function Region::getContainingFunction() {
|
||||
return container.dyn_cast<detail::FunctionStorage *>();
|
||||
}
|
||||
|
||||
bool Region::isProperAncestor(Region *other) {
|
||||
|
|
|
@ -22,8 +22,8 @@ using namespace mlir;
|
|||
|
||||
/// Build a symbol table with the symbols within the given module.
|
||||
SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
|
||||
for (auto &func : *module) {
|
||||
auto inserted = symbolTable.insert({func.getName(), &func});
|
||||
for (auto func : *module) {
|
||||
auto inserted = symbolTable.insert({func.getName(), func});
|
||||
(void)inserted;
|
||||
assert(inserted.second &&
|
||||
"expected module to contain uniquely named functions");
|
||||
|
@ -32,34 +32,34 @@ SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
|
|||
|
||||
/// Look up a symbol with the specified name, returning null if no such name
|
||||
/// exists. Names never include the @ on them.
|
||||
Function *SymbolTable::lookup(StringRef name) const {
|
||||
Function SymbolTable::lookup(StringRef name) const {
|
||||
return lookup(Identifier::get(name, context));
|
||||
}
|
||||
|
||||
/// Look up a symbol with the specified name, returning null if no such name
|
||||
/// exists. Names never include the @ on them.
|
||||
Function *SymbolTable::lookup(Identifier name) const {
|
||||
Function SymbolTable::lookup(Identifier name) const {
|
||||
return symbolTable.lookup(name);
|
||||
}
|
||||
|
||||
/// Erase the given symbol from the table.
|
||||
void SymbolTable::erase(Function *symbol) {
|
||||
auto it = symbolTable.find(symbol->getName());
|
||||
void SymbolTable::erase(Function symbol) {
|
||||
auto it = symbolTable.find(symbol.getName());
|
||||
if (it != symbolTable.end() && it->second == symbol)
|
||||
symbolTable.erase(it);
|
||||
}
|
||||
|
||||
/// Insert a new symbol into the table, and rename it as necessary to avoid
|
||||
/// collisions.
|
||||
void SymbolTable::insert(Function *symbol) {
|
||||
void SymbolTable::insert(Function symbol) {
|
||||
// Add this symbol to the symbol table, uniquing the name if a conflict is
|
||||
// detected.
|
||||
if (symbolTable.insert({symbol->getName(), symbol}).second)
|
||||
if (symbolTable.insert({symbol.getName(), symbol}).second)
|
||||
return;
|
||||
|
||||
// If a conflict was detected, then the function will not have been added to
|
||||
// the symbol table. Try suffixes until we get to a unique name that works.
|
||||
SmallString<128> nameBuffer(symbol->getName());
|
||||
SmallString<128> nameBuffer(symbol.getName());
|
||||
unsigned originalLength = nameBuffer.size();
|
||||
|
||||
// Iteratively try suffixes until we find one that isn't used. We use a
|
||||
|
@ -68,6 +68,6 @@ void SymbolTable::insert(Function *symbol) {
|
|||
nameBuffer.resize(originalLength);
|
||||
nameBuffer += '_';
|
||||
nameBuffer += std::to_string(uniquingCounter++);
|
||||
symbol->setName(Identifier::get(nameBuffer, context));
|
||||
} while (!symbolTable.insert({symbol->getName(), symbol}).second);
|
||||
symbol.setName(Identifier::get(nameBuffer, context));
|
||||
} while (!symbolTable.insert({symbol.getName(), symbol}).second);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ Operation *Value::getDefiningOp() {
|
|||
}
|
||||
|
||||
/// Return the function that this Value is defined in.
|
||||
Function *Value::getFunction() {
|
||||
Function Value::getFunction() {
|
||||
switch (getKind()) {
|
||||
case Value::Kind::BlockArgument:
|
||||
return cast<BlockArgument>(this)->getFunction();
|
||||
|
@ -84,7 +84,7 @@ void IRObjectWithUseList::dropAllUses() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the function that this argument is defined in.
|
||||
Function *BlockArgument::getFunction() {
|
||||
Function BlockArgument::getFunction() {
|
||||
if (auto *owner = getOwner())
|
||||
return owner->getFunction();
|
||||
return nullptr;
|
||||
|
@ -92,6 +92,6 @@ Function *BlockArgument::getFunction() {
|
|||
|
||||
/// Returns if the current argument is a function argument.
|
||||
bool BlockArgument::isFunctionArgument() {
|
||||
auto *containingFn = getFunction();
|
||||
return containingFn && &containingFn->front() == getOwner();
|
||||
auto containingFn = getFunction();
|
||||
return containingFn && &containingFn.front() == getOwner();
|
||||
}
|
||||
|
|
|
@ -816,12 +816,12 @@ void LLVMDialect::printType(Type type, raw_ostream &os) const {
|
|||
}
|
||||
|
||||
/// Verify LLVMIR function argument attributes.
|
||||
LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func,
|
||||
LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function func,
|
||||
unsigned argIdx,
|
||||
NamedAttribute argAttr) {
|
||||
// Check that llvm.noalias is a boolean attribute.
|
||||
if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>())
|
||||
return func->emitError()
|
||||
return func.emitError()
|
||||
<< "llvm.noalias argument attribute of non boolean type";
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -209,7 +209,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
|
|||
return true;
|
||||
}
|
||||
|
||||
static void fuseLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
|
||||
static void fuseLinalgOps(Function f, ArrayRef<int64_t> tileSizes) {
|
||||
OperationFolder state;
|
||||
DenseSet<Operation *> eraseSet;
|
||||
|
||||
|
|
|
@ -170,12 +170,13 @@ public:
|
|||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto *module = op->getFunction()->getModule();
|
||||
Function *mallocFunc = module->getNamedFunction("malloc");
|
||||
auto *module = op->getFunction().getModule();
|
||||
Function mallocFunc = module->getNamedFunction("malloc");
|
||||
if (!mallocFunc) {
|
||||
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
|
||||
mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
|
||||
module->getFunctions().push_back(mallocFunc);
|
||||
mallocFunc =
|
||||
Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
|
||||
module->push_back(mallocFunc);
|
||||
}
|
||||
|
||||
// Get MLIR types for injecting element pointer.
|
||||
|
@ -230,12 +231,12 @@ public:
|
|||
auto voidPtrTy =
|
||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
auto *module = op->getFunction()->getModule();
|
||||
Function *freeFunc = module->getNamedFunction("free");
|
||||
auto *module = op->getFunction().getModule();
|
||||
Function freeFunc = module->getNamedFunction("free");
|
||||
if (!freeFunc) {
|
||||
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
|
||||
freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
|
||||
module->getFunctions().push_back(freeFunc);
|
||||
freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
|
||||
module->push_back(freeFunc);
|
||||
}
|
||||
|
||||
// Get MLIR types for extracting element pointer.
|
||||
|
@ -572,37 +573,37 @@ public:
|
|||
|
||||
// Create a function definition which takes as argument pointers to the input
|
||||
// types and returns pointers to the output types.
|
||||
static Function *getLLVMLibraryCallImplDefinition(Function *libFn) {
|
||||
auto implFnName = (libFn->getName().str() + "_impl");
|
||||
auto module = libFn->getModule();
|
||||
if (auto *f = module->getNamedFunction(implFnName)) {
|
||||
static Function getLLVMLibraryCallImplDefinition(Function libFn) {
|
||||
auto implFnName = (libFn.getName().str() + "_impl");
|
||||
auto module = libFn.getModule();
|
||||
if (auto f = module->getNamedFunction(implFnName)) {
|
||||
return f;
|
||||
}
|
||||
SmallVector<Type, 4> fnArgTypes;
|
||||
for (auto t : libFn->getType().getInputs()) {
|
||||
for (auto t : libFn.getType().getInputs()) {
|
||||
assert(t.isa<LLVMType>() &&
|
||||
"Expected LLVM Type for argument while generating library Call "
|
||||
"Implementation Definition");
|
||||
fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
|
||||
}
|
||||
auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext());
|
||||
auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext());
|
||||
|
||||
// Insert the implementation function definition.
|
||||
auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType);
|
||||
module->getFunctions().push_back(implFnDefn);
|
||||
auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType);
|
||||
module->push_back(implFnDefn);
|
||||
return implFnDefn;
|
||||
}
|
||||
|
||||
// Get function definition for the LinalgOp. If it doesn't exist, insert a
|
||||
// definition.
|
||||
template <typename LinalgOp>
|
||||
static Function *getLLVMLibraryCallDeclaration(Operation *op,
|
||||
LLVMTypeConverter &lowering,
|
||||
PatternRewriter &rewriter) {
|
||||
static Function getLLVMLibraryCallDeclaration(Operation *op,
|
||||
LLVMTypeConverter &lowering,
|
||||
PatternRewriter &rewriter) {
|
||||
assert(isa<LinalgOp>(op));
|
||||
auto fnName = LinalgOp::getLibraryCallName();
|
||||
auto module = op->getFunction()->getModule();
|
||||
if (auto *f = module->getNamedFunction(fnName)) {
|
||||
auto module = op->getFunction().getModule();
|
||||
if (auto f = module->getNamedFunction(fnName)) {
|
||||
return f;
|
||||
}
|
||||
|
||||
|
@ -618,29 +619,29 @@ static Function *getLLVMLibraryCallDeclaration(Operation *op,
|
|||
"Library call for linalg operation can be generated only for ops that "
|
||||
"have void return types");
|
||||
auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
|
||||
auto libFn = new Function(op->getLoc(), fnName, libFnType);
|
||||
module->getFunctions().push_back(libFn);
|
||||
auto libFn = Function::create(op->getLoc(), fnName, libFnType);
|
||||
module->push_back(libFn);
|
||||
// Return after creating the function definition. The body will be created
|
||||
// later.
|
||||
return libFn;
|
||||
}
|
||||
|
||||
static void getLLVMLibraryCallDefinition(Function *fn,
|
||||
static void getLLVMLibraryCallDefinition(Function fn,
|
||||
LLVMTypeConverter &lowering) {
|
||||
// Generate the implementation function definition.
|
||||
auto implFn = getLLVMLibraryCallImplDefinition(fn);
|
||||
|
||||
// Generate the function body.
|
||||
fn->addEntryBlock();
|
||||
fn.addEntryBlock();
|
||||
|
||||
OpBuilder builder(fn->getBody());
|
||||
edsc::ScopedContext scope(builder, fn->getLoc());
|
||||
OpBuilder builder(fn.getBody());
|
||||
edsc::ScopedContext scope(builder, fn.getLoc());
|
||||
SmallVector<Value *, 4> implFnArgs;
|
||||
|
||||
// Create a constant 1.
|
||||
auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
|
||||
IntegerAttr::get(IndexType::get(fn->getContext()), 1));
|
||||
for (auto arg : fn->getArguments()) {
|
||||
IntegerAttr::get(IndexType::get(fn.getContext()), 1));
|
||||
for (auto arg : fn.getArguments()) {
|
||||
// Allocate a stack for storing the argument value. The stack is passed to
|
||||
// the implementation function.
|
||||
auto alloca =
|
||||
|
@ -665,17 +666,17 @@ public:
|
|||
return convertLinalgType(t, *this);
|
||||
}
|
||||
|
||||
void addLibraryFnDeclaration(Function *fn) {
|
||||
void addLibraryFnDeclaration(Function fn) {
|
||||
libraryFnDeclarations.push_back(fn);
|
||||
}
|
||||
|
||||
ArrayRef<Function *> getLibraryFnDeclarations() {
|
||||
ArrayRef<Function> getLibraryFnDeclarations() {
|
||||
return libraryFnDeclarations;
|
||||
}
|
||||
|
||||
private:
|
||||
/// List of library functions declarations needed during dialect conversion
|
||||
SmallVector<Function *, 2> libraryFnDeclarations;
|
||||
SmallVector<Function, 2> libraryFnDeclarations;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -692,7 +693,7 @@ public:
|
|||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only emit library call declaration. Fill in the body later.
|
||||
auto *f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
|
||||
auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
|
||||
static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
|
||||
|
||||
auto fAttr = rewriter.getFunctionAttr(f);
|
||||
|
@ -803,7 +804,7 @@ static void lowerLinalgForToCFG(Function &f) {
|
|||
void LowerLinalgToLLVMPass::runOnModule() {
|
||||
auto &module = getModule();
|
||||
|
||||
for (auto &f : module.getFunctions()) {
|
||||
for (auto f : module.getFunctions()) {
|
||||
lowerLinalgSubViewOps(f);
|
||||
lowerLinalgForToCFG(f);
|
||||
if (failed(lowerAffineConstructs(f)))
|
||||
|
|
|
@ -104,9 +104,8 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
|
|||
} // namespace
|
||||
|
||||
void LowerLinalgToLoopsPass::runOnFunction() {
|
||||
auto &f = getFunction();
|
||||
OperationFolder state;
|
||||
f.walk<LinalgOp>([&state](LinalgOp linalgOp) {
|
||||
getFunction().walk<LinalgOp>([&state](LinalgOp linalgOp) {
|
||||
emitLinalgOpAsLoops(linalgOp, state);
|
||||
linalgOp.getOperation()->erase();
|
||||
});
|
||||
|
|
|
@ -259,7 +259,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
|
|||
return tileLinalgOp(op, tileSizeValues, state);
|
||||
}
|
||||
|
||||
static void tileLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
|
||||
static void tileLinalgOps(Function f, ArrayRef<int64_t> tileSizes) {
|
||||
OperationFolder state;
|
||||
f.walk<LinalgOp>([tileSizes, &state](LinalgOp op) {
|
||||
auto opLoopsPair = tileLinalgOp(op, tileSizes, state);
|
||||
|
|
|
@ -254,7 +254,7 @@ public:
|
|||
/// trailing-location ::= location?
|
||||
///
|
||||
template <typename Owner>
|
||||
ParseResult parseOptionalTrailingLocation(Owner *owner) {
|
||||
ParseResult parseOptionalTrailingLocation(Owner &owner) {
|
||||
// If there is a 'loc' we parse a trailing location.
|
||||
if (!getToken().is(Token::kw_loc))
|
||||
return success();
|
||||
|
@ -263,7 +263,7 @@ public:
|
|||
LocationAttr directLoc;
|
||||
if (parseLocation(directLoc))
|
||||
return failure();
|
||||
owner->setLoc(directLoc);
|
||||
owner.setLoc(directLoc);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -2472,8 +2472,8 @@ namespace {
|
|||
/// operations.
|
||||
class OperationParser : public Parser {
|
||||
public:
|
||||
OperationParser(ParserState &state, Function *function)
|
||||
: Parser(state), function(function), opBuilder(function->getBody()) {}
|
||||
OperationParser(ParserState &state, Function function)
|
||||
: Parser(state), function(function), opBuilder(function.getBody()) {}
|
||||
|
||||
~OperationParser();
|
||||
|
||||
|
@ -2588,7 +2588,7 @@ public:
|
|||
Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing);
|
||||
|
||||
private:
|
||||
Function *function;
|
||||
Function function;
|
||||
|
||||
/// Returns the info for a block at the current scope for the given name.
|
||||
std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) {
|
||||
|
@ -2690,7 +2690,7 @@ ParseResult OperationParser::popSSANameScope() {
|
|||
for (auto entry : forwardRefInCurrentScope) {
|
||||
errors.push_back({entry.second.getPointer(), entry.first});
|
||||
// Add this block to the top-level region to allow for automatic cleanup.
|
||||
function->push_back(entry.first);
|
||||
function.push_back(entry.first);
|
||||
}
|
||||
llvm::array_pod_sort(errors.begin(), errors.end());
|
||||
|
||||
|
@ -2984,7 +2984,7 @@ ParseResult OperationParser::parseOperation() {
|
|||
}
|
||||
|
||||
// Try to parse the optional trailing location.
|
||||
if (parseOptionalTrailingLocation(op))
|
||||
if (parseOptionalTrailingLocation(*op))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
|
@ -4049,17 +4049,17 @@ ParseResult ModuleParser::parseFunc(Module *module) {
|
|||
}
|
||||
|
||||
// Okay, the function signature was parsed correctly, create the function now.
|
||||
auto *function =
|
||||
new Function(getEncodedSourceLocation(loc), name, type, attrs);
|
||||
module->getFunctions().push_back(function);
|
||||
auto function =
|
||||
Function::create(getEncodedSourceLocation(loc), name, type, attrs);
|
||||
module->push_back(function);
|
||||
|
||||
// Parse an optional trailing location.
|
||||
if (parseOptionalTrailingLocation(function))
|
||||
return failure();
|
||||
|
||||
// Add the attributes to the function arguments.
|
||||
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i)
|
||||
function->setArgAttrs(i, argAttrs[i]);
|
||||
for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i)
|
||||
function.setArgAttrs(i, argAttrs[i]);
|
||||
|
||||
// External functions have no body.
|
||||
if (getToken().isNot(Token::l_brace))
|
||||
|
@ -4076,11 +4076,11 @@ ParseResult ModuleParser::parseFunc(Module *module) {
|
|||
|
||||
// Parse the function body.
|
||||
auto parser = OperationParser(getState(), function);
|
||||
if (parser.parseRegion(function->getBody(), entryArgs))
|
||||
if (parser.parseRegion(function.getBody(), entryArgs))
|
||||
return failure();
|
||||
|
||||
// Verify that a valid function body was parsed.
|
||||
if (function->empty())
|
||||
if (function.empty())
|
||||
return emitError(braceLoc, "function must have a body");
|
||||
|
||||
return parser.finalize(braceLoc);
|
||||
|
|
|
@ -61,12 +61,12 @@ private:
|
|||
static void printIR(const llvm::Any &ir, bool printModuleScope,
|
||||
raw_ostream &out) {
|
||||
// Check for printing at module scope.
|
||||
if (printModuleScope && llvm::any_isa<Function *>(ir)) {
|
||||
Function *function = llvm::any_cast<Function *>(ir);
|
||||
if (printModuleScope && llvm::any_isa<Function>(ir)) {
|
||||
Function function = llvm::any_cast<Function>(ir);
|
||||
|
||||
// Print the function name and a newline before the Module.
|
||||
out << " (function: " << function->getName() << ")\n";
|
||||
function->getModule()->print(out);
|
||||
out << " (function: " << function.getName() << ")\n";
|
||||
function.getModule()->print(out);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -74,8 +74,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope,
|
|||
out << "\n";
|
||||
|
||||
// Print the given function.
|
||||
if (llvm::any_isa<Function *>(ir)) {
|
||||
llvm::any_cast<Function *>(ir)->print(out);
|
||||
if (llvm::any_isa<Function>(ir)) {
|
||||
llvm::any_cast<Function>(ir).print(out);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -46,8 +46,7 @@ static llvm::cl::opt<bool>
|
|||
void Pass::anchor() {}
|
||||
|
||||
/// Forwarding function to execute this pass.
|
||||
LogicalResult FunctionPassBase::run(Function *fn,
|
||||
FunctionAnalysisManager &fam) {
|
||||
LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) {
|
||||
// Initialize the pass state.
|
||||
passState.emplace(fn, fam);
|
||||
|
||||
|
@ -115,7 +114,7 @@ FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs)
|
|||
}
|
||||
|
||||
/// Run all of the passes in this manager over the current function.
|
||||
LogicalResult detail::FunctionPassExecutor::run(Function *function,
|
||||
LogicalResult detail::FunctionPassExecutor::run(Function function,
|
||||
FunctionAnalysisManager &fam) {
|
||||
// Run each of the held passes.
|
||||
for (auto &pass : passes)
|
||||
|
@ -141,7 +140,7 @@ LogicalResult detail::ModulePassExecutor::run(Module *module,
|
|||
/// Utility to run the given function and analysis manager on a provided
|
||||
/// function pass executor.
|
||||
static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe,
|
||||
Function *func,
|
||||
Function func,
|
||||
FunctionAnalysisManager &fam) {
|
||||
// Run the function pipeline over the provided function.
|
||||
auto result = fpe.run(func, fam);
|
||||
|
@ -158,14 +157,14 @@ static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe,
|
|||
/// module.
|
||||
void ModuleToFunctionPassAdaptor::runOnModule() {
|
||||
ModuleAnalysisManager &mam = getAnalysisManager();
|
||||
for (auto &func : getModule()) {
|
||||
for (auto func : getModule()) {
|
||||
// Skip external functions.
|
||||
if (func.isExternal())
|
||||
continue;
|
||||
|
||||
// Run the held function pipeline over the current function.
|
||||
auto fam = mam.slice(&func);
|
||||
if (failed(runFunctionPipeline(fpe, &func, fam)))
|
||||
auto fam = mam.slice(func);
|
||||
if (failed(runFunctionPipeline(fpe, func, fam)))
|
||||
return signalPassFailure();
|
||||
|
||||
// Clear out any computed function analyses. These analyses won't be used
|
||||
|
@ -189,10 +188,10 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() {
|
|||
// Run a prepass over the module to collect the functions to execute a over.
|
||||
// This ensures that an analysis manager exists for each function, as well as
|
||||
// providing a queue of functions to execute over.
|
||||
std::vector<std::pair<Function *, FunctionAnalysisManager>> funcAMPairs;
|
||||
for (auto &func : getModule())
|
||||
std::vector<std::pair<Function, FunctionAnalysisManager>> funcAMPairs;
|
||||
for (auto func : getModule())
|
||||
if (!func.isExternal())
|
||||
funcAMPairs.emplace_back(&func, mam.slice(&func));
|
||||
funcAMPairs.emplace_back(func, mam.slice(func));
|
||||
|
||||
// A parallel diagnostic handler that provides deterministic diagnostic
|
||||
// ordering.
|
||||
|
@ -340,8 +339,8 @@ PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const {
|
|||
}
|
||||
|
||||
/// Create an analysis slice for the given child function.
|
||||
FunctionAnalysisManager ModuleAnalysisManager::slice(Function *func) {
|
||||
assert(func->getModule() == moduleAnalyses.getIRUnit() &&
|
||||
FunctionAnalysisManager ModuleAnalysisManager::slice(Function func) {
|
||||
assert(func.getModule() == moduleAnalyses.getIRUnit() &&
|
||||
"function has a different parent module");
|
||||
auto it = functionAnalyses.find(func);
|
||||
if (it == functionAnalyses.end()) {
|
||||
|
|
|
@ -48,7 +48,7 @@ public:
|
|||
FunctionPassExecutor(const FunctionPassExecutor &rhs);
|
||||
|
||||
/// Run the executor on the given function.
|
||||
LogicalResult run(Function *function, FunctionAnalysisManager &fam);
|
||||
LogicalResult run(Function function, FunctionAnalysisManager &fam);
|
||||
|
||||
/// Add a pass to the current executor. This takes ownership over the provided
|
||||
/// pass pointer.
|
||||
|
|
|
@ -71,7 +71,7 @@ void AddDefaultStatsPass::runOnFunction() {
|
|||
|
||||
void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
|
||||
const TargetConfiguration &config) {
|
||||
auto &func = getFunction();
|
||||
auto func = getFunction();
|
||||
|
||||
// Insert stats for each argument.
|
||||
for (auto *arg : func.getArguments()) {
|
||||
|
|
|
@ -129,7 +129,7 @@ void InferQuantizedTypesPass::runOnModule() {
|
|||
void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
|
||||
const TargetConfiguration &config) {
|
||||
CAGSlice cag(solverContext);
|
||||
for (auto &f : getModule()) {
|
||||
for (auto f : getModule()) {
|
||||
f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
|
||||
}
|
||||
config.finalizeAnchors(cag);
|
||||
|
|
|
@ -58,7 +58,7 @@ public:
|
|||
|
||||
void RemoveInstrumentationPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto &func = getFunction();
|
||||
auto func = getFunction();
|
||||
auto *context = &getContext();
|
||||
patterns.push_back(
|
||||
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
|
||||
|
|
|
@ -36,11 +36,11 @@ using namespace mlir;
|
|||
// block. The created block will be terminated by `std.return`.
|
||||
Block *createOneBlockFunction(Builder builder, Module *module) {
|
||||
auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
|
||||
auto *fn = new Function(builder.getUnknownLoc(), "spirv_module", fnType);
|
||||
module->getFunctions().push_back(fn);
|
||||
auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType);
|
||||
module->push_back(fn);
|
||||
|
||||
auto *block = new Block();
|
||||
fn->push_back(block);
|
||||
fn.push_back(block);
|
||||
|
||||
OperationState state(builder.getUnknownLoc(), ReturnOp::getOperationName());
|
||||
ReturnOp::build(&builder, &state);
|
||||
|
|
|
@ -45,7 +45,7 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) {
|
|||
// wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed
|
||||
// to take in the SPIR-V ModuleOp directly after module and function are
|
||||
// migrated to be general ops.
|
||||
for (auto &fn : *module) {
|
||||
for (auto fn : *module) {
|
||||
fn.walk<spirv::ModuleOp>([&](spirv::ModuleOp spirvModule) {
|
||||
if (done) {
|
||||
spirvModule.emitError("found more than one 'spv.module' op");
|
||||
|
|
|
@ -42,7 +42,7 @@ class StdOpsToSPIRVConversionPass
|
|||
|
||||
void StdOpsToSPIRVConversionPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto &func = getFunction();
|
||||
auto func = getFunction();
|
||||
|
||||
populateWithGenerated(func.getContext(), &patterns);
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
|
|
|
@ -440,14 +440,14 @@ static LogicalResult verify(CallOp op) {
|
|||
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return op.emitOpError("requires a 'callee' function attribute");
|
||||
auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
|
||||
auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
|
||||
fnAttr.getValue());
|
||||
if (!fn)
|
||||
return op.emitOpError() << "'" << fnAttr.getValue()
|
||||
<< "' does not reference a valid function";
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
auto fnType = fn->getType();
|
||||
auto fnType = fn.getType();
|
||||
if (fnType.getNumInputs() != op.getNumOperands())
|
||||
return op.emitOpError("incorrect number of operands for callee");
|
||||
|
||||
|
@ -1107,13 +1107,13 @@ static LogicalResult verify(ConstantOp &op) {
|
|||
return op.emitOpError("requires 'value' to be a function reference");
|
||||
|
||||
// Try to find the referenced function.
|
||||
auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
|
||||
auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
|
||||
fnAttr.getValue());
|
||||
if (!fn)
|
||||
return op.emitOpError("reference to undefined function 'bar'");
|
||||
|
||||
// Check that the referenced function has the correct type.
|
||||
if (fn->getType() != type)
|
||||
if (fn.getType() != type)
|
||||
return op.emitOpError("reference to function with mismatched type");
|
||||
|
||||
return success();
|
||||
|
@ -1876,10 +1876,10 @@ static void print(OpAsmPrinter *p, ReturnOp op) {
|
|||
}
|
||||
|
||||
static LogicalResult verify(ReturnOp op) {
|
||||
auto *function = op.getOperation()->getFunction();
|
||||
auto function = op.getOperation()->getFunction();
|
||||
|
||||
// The operand number and types must match the function signature.
|
||||
const auto &results = function->getType().getResults();
|
||||
const auto &results = function.getType().getResults();
|
||||
if (op.getNumOperands() != results.size())
|
||||
return op.emitOpError("has ")
|
||||
<< op.getNumOperands()
|
||||
|
|
|
@ -69,7 +69,7 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
|
|||
|
||||
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
|
||||
// function as a kernel.
|
||||
for (Function &func : m) {
|
||||
for (Function func : m) {
|
||||
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
|
||||
continue;
|
||||
|
||||
|
@ -89,20 +89,21 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
|
|||
return llvmModule;
|
||||
}
|
||||
|
||||
static TranslateFromMLIRRegistration registration(
|
||||
"mlir-to-nvvmir", [](Module *module, llvm::StringRef outputFilename) {
|
||||
if (!module)
|
||||
return true;
|
||||
static TranslateFromMLIRRegistration
|
||||
registration("mlir-to-nvvmir",
|
||||
[](Module *module, llvm::StringRef outputFilename) {
|
||||
if (!module)
|
||||
return true;
|
||||
|
||||
auto llvmModule = mlir::translateModuleToNVVMIR(*module);
|
||||
if (!llvmModule)
|
||||
return true;
|
||||
auto llvmModule = mlir::translateModuleToNVVMIR(*module);
|
||||
if (!llvmModule)
|
||||
return true;
|
||||
|
||||
auto file = openOutputFile(outputFilename);
|
||||
if (!file)
|
||||
return true;
|
||||
auto file = openOutputFile(outputFilename);
|
||||
if (!file)
|
||||
return true;
|
||||
|
||||
llvmModule->print(file->os(), nullptr);
|
||||
file->keep();
|
||||
return false;
|
||||
});
|
||||
llvmModule->print(file->os(), nullptr);
|
||||
file->keep();
|
||||
return false;
|
||||
});
|
||||
|
|
|
@ -375,7 +375,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) {
|
|||
bool ModuleTranslation::convertFunctions() {
|
||||
// Declare all functions first because there may be function calls that form a
|
||||
// call graph with cycles.
|
||||
for (Function &function : mlirModule) {
|
||||
for (Function function : mlirModule) {
|
||||
mlir::BoolAttr isVarArgsAttr =
|
||||
function.getAttrOfType<BoolAttr>("std.varargs");
|
||||
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
|
||||
|
@ -392,7 +392,7 @@ bool ModuleTranslation::convertFunctions() {
|
|||
}
|
||||
|
||||
// Convert functions.
|
||||
for (Function &function : mlirModule) {
|
||||
for (Function function : mlirModule) {
|
||||
// Ignore external functions.
|
||||
if (function.isExternal())
|
||||
continue;
|
||||
|
|
|
@ -40,7 +40,7 @@ struct Canonicalizer : public FunctionPass<Canonicalizer> {
|
|||
|
||||
void Canonicalizer::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto &func = getFunction();
|
||||
auto func = getFunction();
|
||||
|
||||
// TODO: Instead of adding all known patterns from the whole system lazily add
|
||||
// and cache the canonicalization patterns for ops we see in practice when
|
||||
|
|
|
@ -849,7 +849,7 @@ struct FunctionConverter {
|
|||
/// error, success otherwise. If 'signatureConversion' is provided, the
|
||||
/// arguments of the entry block are updated accordingly.
|
||||
LogicalResult
|
||||
convertFunction(Function *f,
|
||||
convertFunction(Function f,
|
||||
TypeConverter::SignatureConversion *signatureConversion);
|
||||
|
||||
/// Converts the given region starting from the entry block and following the
|
||||
|
@ -957,22 +957,22 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
|
|||
}
|
||||
|
||||
LogicalResult FunctionConverter::convertFunction(
|
||||
Function *f, TypeConverter::SignatureConversion *signatureConversion) {
|
||||
Function f, TypeConverter::SignatureConversion *signatureConversion) {
|
||||
// If this is an external function, there is nothing else to do.
|
||||
if (f->isExternal())
|
||||
if (f.isExternal())
|
||||
return success();
|
||||
|
||||
DialectConversionRewriter rewriter(f->getBody(), typeConverter);
|
||||
DialectConversionRewriter rewriter(f.getBody(), typeConverter);
|
||||
|
||||
// Update the signature of the entry block.
|
||||
if (signatureConversion) {
|
||||
rewriter.argConverter.convertSignature(
|
||||
&f->getBody().front(), *signatureConversion, rewriter.mapping);
|
||||
&f.getBody().front(), *signatureConversion, rewriter.mapping);
|
||||
}
|
||||
|
||||
// Rewrite the function body.
|
||||
if (failed(
|
||||
convertRegion(rewriter, f->getBody(), /*convertEntryTypes=*/false))) {
|
||||
convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) {
|
||||
// Reset any of the generated rewrites.
|
||||
rewriter.discardRewrites();
|
||||
return failure();
|
||||
|
@ -1124,24 +1124,6 @@ auto ConversionTarget::getOpAction(OperationName op) const
|
|||
// applyConversionPatterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class represents a function to be converted. It allows for converting
|
||||
/// the body of functions and the signature in two phases.
|
||||
struct ConvertedFunction {
|
||||
ConvertedFunction(Function *fn, FunctionType newType,
|
||||
ArrayRef<NamedAttributeList> newFunctionArgAttrs)
|
||||
: fn(fn), newType(newType),
|
||||
newFunctionArgAttrs(newFunctionArgAttrs.begin(),
|
||||
newFunctionArgAttrs.end()) {}
|
||||
|
||||
/// The function to convert.
|
||||
Function *fn;
|
||||
/// The new type and argument attributes for the function.
|
||||
FunctionType newType;
|
||||
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Convert the given module with the provided conversion patterns and type
|
||||
/// conversion object. If conversion fails for specific functions, those
|
||||
/// functions remains unmodified.
|
||||
|
@ -1149,37 +1131,33 @@ LogicalResult
|
|||
mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
|
||||
TypeConverter &converter,
|
||||
OwningRewritePatternList &&patterns) {
|
||||
std::vector<Function *> allFunctions;
|
||||
allFunctions.reserve(module.getFunctions().size());
|
||||
for (auto &func : module)
|
||||
allFunctions.push_back(&func);
|
||||
SmallVector<Function, 32> allFunctions(module.getFunctions());
|
||||
return applyConversionPatterns(allFunctions, target, converter,
|
||||
std::move(patterns));
|
||||
}
|
||||
|
||||
/// Convert the given functions with the provided conversion patterns.
|
||||
LogicalResult mlir::applyConversionPatterns(
|
||||
ArrayRef<Function *> fns, ConversionTarget &target,
|
||||
MutableArrayRef<Function> fns, ConversionTarget &target,
|
||||
TypeConverter &converter, OwningRewritePatternList &&patterns) {
|
||||
if (fns.empty())
|
||||
return success();
|
||||
|
||||
// Build the function converter.
|
||||
FunctionConverter funcConverter(fns.front()->getContext(), target, patterns,
|
||||
&converter);
|
||||
auto *ctx = fns.front().getContext();
|
||||
FunctionConverter funcConverter(ctx, target, patterns, &converter);
|
||||
|
||||
// Try to convert each of the functions within the module.
|
||||
auto *ctx = fns.front()->getContext();
|
||||
for (auto *func : fns) {
|
||||
for (auto func : fns) {
|
||||
// Convert the function type using the type converter.
|
||||
auto conversion =
|
||||
converter.convertSignature(func->getType(), func->getAllArgAttrs());
|
||||
converter.convertSignature(func.getType(), func.getAllArgAttrs());
|
||||
if (!conversion)
|
||||
return failure();
|
||||
|
||||
// Update the function signature.
|
||||
func->setType(conversion->getConvertedType(ctx));
|
||||
func->setAllArgAttrs(conversion->getConvertedArgAttrs());
|
||||
func.setType(conversion->getConvertedType(ctx));
|
||||
func.setAllArgAttrs(conversion->getConvertedArgAttrs());
|
||||
|
||||
// Convert the body of this function.
|
||||
if (failed(funcConverter.convertFunction(func, &*conversion)))
|
||||
|
@ -1193,9 +1171,9 @@ LogicalResult mlir::applyConversionPatterns(
|
|||
/// convert as many of the operations within 'fn' as possible given the set of
|
||||
/// patterns.
|
||||
LogicalResult
|
||||
mlir::applyConversionPatterns(Function &fn, ConversionTarget &target,
|
||||
mlir::applyConversionPatterns(Function fn, ConversionTarget &target,
|
||||
OwningRewritePatternList &&patterns) {
|
||||
// Convert the body of this function.
|
||||
FunctionConverter converter(fn.getContext(), target, patterns);
|
||||
return converter.convertFunction(&fn, /*signatureConversion=*/nullptr);
|
||||
return converter.convertFunction(fn, /*signatureConversion=*/nullptr);
|
||||
}
|
||||
|
|
|
@ -214,7 +214,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
|
|||
static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
|
||||
emitRemarkForBlock(Block &block) {
|
||||
auto *op = block.getContainingOp();
|
||||
return op ? op->emitRemark() : block.getFunction()->emitRemark();
|
||||
return op ? op->emitRemark() : block.getFunction().emitRemark();
|
||||
}
|
||||
|
||||
/// Creates a buffer in the faster memory space for the specified region;
|
||||
|
@ -246,8 +246,8 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block,
|
|||
OpBuilder &b = region.isWrite() ? epilogue : prologue;
|
||||
|
||||
// Builder to create constants at the top level.
|
||||
auto *func = block->getFunction();
|
||||
OpBuilder top(func->getBody());
|
||||
auto func = block->getFunction();
|
||||
OpBuilder top(func.getBody());
|
||||
|
||||
auto loc = region.loc;
|
||||
auto *memref = region.memref;
|
||||
|
@ -751,14 +751,14 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
|
|||
if (auto *op = block->getContainingOp())
|
||||
op->emitError(str);
|
||||
else
|
||||
block->getFunction()->emitError(str);
|
||||
block->getFunction().emitError(str);
|
||||
}
|
||||
|
||||
return totalDmaBuffersSizeInBytes;
|
||||
}
|
||||
|
||||
void DmaGeneration::runOnFunction() {
|
||||
Function &f = getFunction();
|
||||
Function f = getFunction();
|
||||
OpBuilder topBuilder(f.getBody());
|
||||
zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
|
||||
|
||||
|
|
|
@ -257,7 +257,7 @@ public:
|
|||
|
||||
// Initializes the dependence graph based on operations in 'f'.
|
||||
// Returns true on success, false otherwise.
|
||||
bool init(Function &f);
|
||||
bool init(Function f);
|
||||
|
||||
// Returns the graph node for 'id'.
|
||||
Node *getNode(unsigned id) {
|
||||
|
@ -637,7 +637,7 @@ public:
|
|||
// Assigns each node in the graph a node id based on program order in 'f'.
|
||||
// TODO(andydavis) Add support for taking a Block arg to construct the
|
||||
// dependence graph at a different depth.
|
||||
bool MemRefDependenceGraph::init(Function &f) {
|
||||
bool MemRefDependenceGraph::init(Function f) {
|
||||
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
|
||||
|
||||
// TODO: support multi-block functions.
|
||||
|
@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
|
|||
// Create builder to insert alloc op just before 'forOp'.
|
||||
OpBuilder b(forInst);
|
||||
// Builder to create constants at the top level.
|
||||
OpBuilder top(forInst->getFunction()->getBody());
|
||||
OpBuilder top(forInst->getFunction().getBody());
|
||||
// Create new memref type based on slice bounds.
|
||||
auto *oldMemRef = cast<StoreOp>(srcStoreOpInst).getMemRef();
|
||||
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
|
||||
|
@ -1750,9 +1750,9 @@ public:
|
|||
};
|
||||
|
||||
// Search for siblings which load the same memref function argument.
|
||||
auto *fn = dstNode->op->getFunction();
|
||||
for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) {
|
||||
for (auto *user : fn->getArgument(i)->getUsers()) {
|
||||
auto fn = dstNode->op->getFunction();
|
||||
for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
|
||||
for (auto *user : fn.getArgument(i)->getUsers()) {
|
||||
if (auto loadOp = dyn_cast<LoadOp>(user)) {
|
||||
// Gather loops surrounding 'use'.
|
||||
SmallVector<AffineForOp, 4> loops;
|
||||
|
|
|
@ -261,7 +261,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
|
|||
// Identify valid and profitable bands of loops to tile. This is currently just
|
||||
// a temporary placeholder to test the mechanics of tiled code generation.
|
||||
// Returns all maximal outermost perfect loop nests to tile.
|
||||
static void getTileableBands(Function &f,
|
||||
static void getTileableBands(Function f,
|
||||
std::vector<SmallVector<AffineForOp, 6>> *bands) {
|
||||
// Get maximal perfect nest of 'affine.for' insts starting from root
|
||||
// (inclusive).
|
||||
|
|
|
@ -92,8 +92,8 @@ void LoopUnroll::runOnFunction() {
|
|||
// Store innermost loops as we walk.
|
||||
std::vector<AffineForOp> loops;
|
||||
|
||||
void walkPostOrder(Function *f) {
|
||||
for (auto &b : *f)
|
||||
void walkPostOrder(Function f) {
|
||||
for (auto &b : f)
|
||||
walkPostOrder(b.begin(), b.end());
|
||||
}
|
||||
|
||||
|
@ -142,10 +142,10 @@ void LoopUnroll::runOnFunction() {
|
|||
? clUnrollNumRepetitions
|
||||
: 1;
|
||||
// If the call back is provided, we will recurse until no loops are found.
|
||||
Function &func = getFunction();
|
||||
Function func = getFunction();
|
||||
for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
|
||||
InnermostLoopGatherer ilg;
|
||||
ilg.walkPostOrder(&func);
|
||||
ilg.walkPostOrder(func);
|
||||
auto &loops = ilg.loops;
|
||||
if (loops.empty())
|
||||
break;
|
||||
|
|
|
@ -726,7 +726,7 @@ public:
|
|||
|
||||
} // end namespace
|
||||
|
||||
LogicalResult mlir::lowerAffineConstructs(Function &function) {
|
||||
LogicalResult mlir::lowerAffineConstructs(Function function) {
|
||||
OwningRewritePatternList patterns;
|
||||
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
|
||||
AffineDmaWaitLowering, AffineLoadLowering,
|
||||
|
|
|
@ -636,7 +636,7 @@ static bool emitSlice(MaterializationState *state,
|
|||
}
|
||||
|
||||
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
|
||||
LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs()));
|
||||
LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs()));
|
||||
|
||||
// slice are topologically sorted, we can just erase them in reverse
|
||||
// order. Reverse iterator does not just work simply with an operator*
|
||||
|
@ -667,7 +667,7 @@ static bool emitSlice(MaterializationState *state,
|
|||
/// because we currently disallow vectorization of defs that come from another
|
||||
/// scope.
|
||||
/// TODO(ntv): please document return value.
|
||||
static bool materialize(Function *f, const SetVector<Operation *> &terminators,
|
||||
static bool materialize(Function f, const SetVector<Operation *> &terminators,
|
||||
MaterializationState *state) {
|
||||
DenseSet<Operation *> seen;
|
||||
DominanceInfo domInfo(f);
|
||||
|
@ -721,7 +721,7 @@ static bool materialize(Function *f, const SetVector<Operation *> &terminators,
|
|||
return true;
|
||||
}
|
||||
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
|
||||
LLVM_DEBUG(f->print(dbgs()));
|
||||
LLVM_DEBUG(f.print(dbgs()));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -731,13 +731,13 @@ void MaterializeVectorsPass::runOnFunction() {
|
|||
NestedPatternContext mlContext;
|
||||
|
||||
// TODO(ntv): Check to see if this supports arbitrary top-level code.
|
||||
Function *f = &getFunction();
|
||||
if (f->getBlocks().size() != 1)
|
||||
Function f = getFunction();
|
||||
if (f.getBlocks().size() != 1)
|
||||
return;
|
||||
|
||||
using matcher::Op;
|
||||
LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n");
|
||||
LLVM_DEBUG(f->print(dbgs()));
|
||||
LLVM_DEBUG(f.print(dbgs()));
|
||||
|
||||
MaterializationState state(hwVectorSize);
|
||||
// Get the hardware vector type.
|
||||
|
|
|
@ -212,7 +212,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) {
|
|||
|
||||
void MemRefDataFlowOpt::runOnFunction() {
|
||||
// Only supports single block functions at the moment.
|
||||
Function &f = getFunction();
|
||||
Function f = getFunction();
|
||||
if (f.getBlocks().size() != 1) {
|
||||
markAllAnalysesPreserved();
|
||||
return;
|
||||
|
|
|
@ -29,7 +29,7 @@ struct StripDebugInfo : public FunctionPass<StripDebugInfo> {
|
|||
} // end anonymous namespace
|
||||
|
||||
void StripDebugInfo::runOnFunction() {
|
||||
Function &func = getFunction();
|
||||
Function func = getFunction();
|
||||
auto unknownLoc = UnknownLoc::get(&getContext());
|
||||
|
||||
// Strip the debug info from the function and its operations.
|
||||
|
|
|
@ -44,7 +44,7 @@ namespace {
|
|||
/// applies the locally optimal patterns in a roughly "bottom up" way.
|
||||
class GreedyPatternRewriteDriver : public PatternRewriter {
|
||||
public:
|
||||
explicit GreedyPatternRewriteDriver(Function &fn,
|
||||
explicit GreedyPatternRewriteDriver(Function fn,
|
||||
OwningRewritePatternList &&patterns)
|
||||
: PatternRewriter(fn.getBody()), matcher(std::move(patterns)) {
|
||||
worklist.reserve(64);
|
||||
|
@ -213,7 +213,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
|
|||
/// patterns in a greedy work-list driven manner. Return true if no more
|
||||
/// patterns can be matched in the result function.
|
||||
///
|
||||
bool mlir::applyPatternsGreedily(Function &fn,
|
||||
bool mlir::applyPatternsGreedily(Function fn,
|
||||
OwningRewritePatternList &&patterns) {
|
||||
GreedyPatternRewriteDriver driver(fn, std::move(patterns));
|
||||
bool converged = driver.simplifyFunction(maxPatternMatchIterations);
|
||||
|
|
|
@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
|
|||
Operation *op = forOp.getOperation();
|
||||
if (!iv->use_empty()) {
|
||||
if (forOp.hasConstantLowerBound()) {
|
||||
OpBuilder topBuilder(op->getFunction()->getBody());
|
||||
OpBuilder topBuilder(op->getFunction().getBody());
|
||||
auto constOp = topBuilder.create<ConstantIndexOp>(
|
||||
forOp.getLoc(), forOp.getConstantLowerBound());
|
||||
iv->replaceAllUsesWith(constOp);
|
||||
|
|
|
@ -1194,7 +1194,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m,
|
|||
/// Applies vectorization to the current Function by searching over a bunch of
|
||||
/// predetermined patterns.
|
||||
void Vectorize::runOnFunction() {
|
||||
Function &f = getFunction();
|
||||
Function f = getFunction();
|
||||
if (!fastestVaryingPattern.empty() &&
|
||||
fastestVaryingPattern.size() != vectorSizes.size()) {
|
||||
f.emitRemark("Fastest varying pattern specified with different size than "
|
||||
|
@ -1220,7 +1220,7 @@ void Vectorize::runOnFunction() {
|
|||
unsigned patternDepth = pat.getDepth();
|
||||
|
||||
SmallVector<NestedMatch, 8> matches;
|
||||
pat.match(&f, &matches);
|
||||
pat.match(f, &matches);
|
||||
// Iterate over all the top-level matches and vectorize eagerly.
|
||||
// This automatically prunes intersecting matches.
|
||||
for (auto m : matches) {
|
||||
|
|
|
@ -53,13 +53,13 @@ std::string DOTGraphTraits<Function *>::getNodeLabel(Block *Block, Function *) {
|
|||
|
||||
} // end namespace llvm
|
||||
|
||||
void mlir::viewGraph(Function &function, const llvm::Twine &name,
|
||||
void mlir::viewGraph(Function function, const llvm::Twine &name,
|
||||
bool shortNames, const llvm::Twine &title,
|
||||
llvm::GraphProgram::Name program) {
|
||||
llvm::ViewGraph(&function, name, shortNames, title, program);
|
||||
}
|
||||
|
||||
llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function,
|
||||
llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function function,
|
||||
bool shortNames, const llvm::Twine &title) {
|
||||
return llvm::WriteGraph(os, &function, shortNames, title);
|
||||
}
|
||||
|
|
|
@ -43,13 +43,12 @@ static MLIRContext &globalContext() {
|
|||
return context;
|
||||
}
|
||||
|
||||
static std::unique_ptr<Function> makeFunction(StringRef name,
|
||||
ArrayRef<Type> results = {},
|
||||
ArrayRef<Type> args = {}) {
|
||||
static Function makeFunction(StringRef name, ArrayRef<Type> results = {},
|
||||
ArrayRef<Type> args = {}) {
|
||||
auto &ctx = globalContext();
|
||||
auto function = llvm::make_unique<Function>(
|
||||
UnknownLoc::get(&ctx), name, FunctionType::get(args, results, &ctx));
|
||||
function->addEntryBlock();
|
||||
auto function = Function::create(UnknownLoc::get(&ctx), name,
|
||||
FunctionType::get(args, results, &ctx));
|
||||
function.addEntryBlock();
|
||||
return function;
|
||||
}
|
||||
|
||||
|
@ -62,10 +61,10 @@ TEST_FUNC(builder_dynamic_for_func_args) {
|
|||
auto f =
|
||||
makeFunction("builder_dynamic_for_func_args", {}, {indexType, indexType});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
ValueHandle i(indexType), j(indexType), lb(f->getArgument(0)),
|
||||
ub(f->getArgument(1));
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)),
|
||||
ub(f.getArgument(1));
|
||||
ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type));
|
||||
ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type));
|
||||
ValueHandle i7(constant_int(7, 32));
|
||||
|
@ -102,7 +101,8 @@ TEST_FUNC(builder_dynamic_for_func_args) {
|
|||
// CHECK-DAG: [[ri4:%[0-9]+]] = muli {{.*}}, {{.*}} : i32
|
||||
// CHECK: {{.*}} = subi [[ri3]], [[ri4]] : i32
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(builder_dynamic_for) {
|
||||
|
@ -113,10 +113,10 @@ TEST_FUNC(builder_dynamic_for) {
|
|||
auto f = makeFunction("builder_dynamic_for", {},
|
||||
{indexType, indexType, indexType, indexType});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)),
|
||||
c(f->getArgument(2)), d(f->getArgument(3));
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
|
||||
c(f.getArgument(2)), d(f.getArgument(3));
|
||||
LoopBuilder(&i, a - b, c + d, 2)();
|
||||
|
||||
// clang-format off
|
||||
|
@ -125,7 +125,8 @@ TEST_FUNC(builder_dynamic_for) {
|
|||
// CHECK-DAG: [[r1:%[0-9]+]] = affine.apply ()[s0, s1] -> (s0 + s1)()[%arg2, %arg3]
|
||||
// CHECK-NEXT: affine.for %i0 = (d0) -> (d0)([[r0]]) to (d0) -> (d0)([[r1]]) step 2 {
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(builder_max_min_for) {
|
||||
|
@ -136,10 +137,10 @@ TEST_FUNC(builder_max_min_for) {
|
|||
auto f = makeFunction("builder_max_min_for", {},
|
||||
{indexType, indexType, indexType, indexType});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)),
|
||||
ub1(f->getArgument(2)), ub2(f->getArgument(3));
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)),
|
||||
ub1(f.getArgument(2)), ub2(f.getArgument(3));
|
||||
LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)();
|
||||
ret();
|
||||
|
||||
|
@ -148,7 +149,8 @@ TEST_FUNC(builder_max_min_for) {
|
|||
// CHECK: affine.for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) {
|
||||
// CHECK: return
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(builder_blocks) {
|
||||
|
@ -157,14 +159,14 @@ TEST_FUNC(builder_blocks) {
|
|||
using namespace edsc::op;
|
||||
auto f = makeFunction("builder_blocks");
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
|
||||
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
|
||||
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
|
||||
arg4(c1.getType()), r(c1.getType());
|
||||
|
||||
BlockHandle b1, b2, functionBlock(&f->front());
|
||||
BlockHandle b1, b2, functionBlock(&f.front());
|
||||
BlockBuilder(&b1, {&arg1, &arg2})(
|
||||
// b2 has not yet been constructed, need to come back later.
|
||||
// This is a byproduct of non-structured control-flow.
|
||||
|
@ -192,7 +194,8 @@ TEST_FUNC(builder_blocks) {
|
|||
// CHECK-NEXT: br ^bb1(%3, %4 : i32, i32)
|
||||
// CHECK-NEXT: }
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(builder_blocks_eager) {
|
||||
|
@ -201,8 +204,8 @@ TEST_FUNC(builder_blocks_eager) {
|
|||
using namespace edsc::op;
|
||||
auto f = makeFunction("builder_blocks_eager");
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
|
||||
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
|
||||
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
|
||||
|
@ -235,7 +238,8 @@ TEST_FUNC(builder_blocks_eager) {
|
|||
// CHECK-NEXT: br ^bb1(%3, %4 : i32, i32)
|
||||
// CHECK-NEXT: }
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(builder_cond_branch) {
|
||||
|
@ -244,15 +248,15 @@ TEST_FUNC(builder_cond_branch) {
|
|||
auto f = makeFunction("builder_cond_branch", {},
|
||||
{IntegerType::get(1, &globalContext())});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
ValueHandle funcArg(f->getArgument(0));
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle funcArg(f.getArgument(0));
|
||||
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
|
||||
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
|
||||
c42(ValueHandle::create<ConstantIntOp>(42, 32));
|
||||
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
|
||||
|
||||
BlockHandle b1, b2, functionBlock(&f->front());
|
||||
BlockHandle b1, b2, functionBlock(&f.front());
|
||||
BlockBuilder(&b1, {&arg1})([&] { ret(); });
|
||||
BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); });
|
||||
// Get back to entry block and add a conditional branch
|
||||
|
@ -271,7 +275,8 @@ TEST_FUNC(builder_cond_branch) {
|
|||
// CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0
|
||||
// CHECK-NEXT: return
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(builder_cond_branch_eager) {
|
||||
|
@ -281,9 +286,9 @@ TEST_FUNC(builder_cond_branch_eager) {
|
|||
auto f = makeFunction("builder_cond_branch_eager", {},
|
||||
{IntegerType::get(1, &globalContext())});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
ValueHandle funcArg(f->getArgument(0));
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle funcArg(f.getArgument(0));
|
||||
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
|
||||
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
|
||||
c42(ValueHandle::create<ConstantIntOp>(42, 32));
|
||||
|
@ -309,7 +314,8 @@ TEST_FUNC(builder_cond_branch_eager) {
|
|||
// CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0
|
||||
// CHECK-NEXT: return
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(builder_helpers) {
|
||||
|
@ -321,14 +327,14 @@ TEST_FUNC(builder_helpers) {
|
|||
auto f =
|
||||
makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
// clang-format off
|
||||
ValueHandle f7(
|
||||
ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type));
|
||||
MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
|
||||
vC(f->getArgument(2));
|
||||
IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
|
||||
MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)),
|
||||
vC(f.getArgument(2));
|
||||
IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
|
||||
IndexHandle i, j, k1, k2, lb0, lb1, lb2, ub0, ub1, ub2;
|
||||
int64_t step0, step1, step2;
|
||||
std::tie(lb0, ub0, step0) = vA.range(0);
|
||||
|
@ -363,7 +369,8 @@ TEST_FUNC(builder_helpers) {
|
|||
// CHECK-DAG: [[e:%.*]] = addf [[d]], [[c]] : f32
|
||||
// CHECK-NEXT: store [[e]], %arg2[%i0, %i1, %i3] : memref<?x?x?xf32>
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(custom_ops) {
|
||||
|
@ -373,8 +380,8 @@ TEST_FUNC(custom_ops) {
|
|||
auto indexType = IndexType::get(&globalContext());
|
||||
auto f = makeFunction("custom_ops", {}, {indexType, indexType});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
CustomOperation<ValueHandle> MY_CUSTOM_OP("my_custom_op");
|
||||
CustomOperation<OperationHandle> MY_CUSTOM_OP_0("my_custom_op_0");
|
||||
CustomOperation<OperationHandle> MY_CUSTOM_OP_2("my_custom_op_2");
|
||||
|
@ -382,7 +389,7 @@ TEST_FUNC(custom_ops) {
|
|||
// clang-format off
|
||||
ValueHandle vh(indexType), vh20(indexType), vh21(indexType);
|
||||
OperationHandle ih0, ih2;
|
||||
IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1));
|
||||
IndexHandle m, n, M(f.getArgument(0)), N(f.getArgument(1));
|
||||
IndexHandle ten(index_t(10)), twenty(index_t(20));
|
||||
LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{
|
||||
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {});
|
||||
|
@ -402,7 +409,8 @@ TEST_FUNC(custom_ops) {
|
|||
// CHECK: [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index)
|
||||
// CHECK: {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(insertion_in_block) {
|
||||
|
@ -412,8 +420,8 @@ TEST_FUNC(insertion_in_block) {
|
|||
auto indexType = IndexType::get(&globalContext());
|
||||
auto f = makeFunction("insertion_in_block", {}, {indexType, indexType});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
BlockHandle b1;
|
||||
// clang-format off
|
||||
ValueHandle::create<ConstantIntOp>(0, 32);
|
||||
|
@ -427,7 +435,8 @@ TEST_FUNC(insertion_in_block) {
|
|||
// CHECK: ^bb1: // no predecessors
|
||||
// CHECK: {{.*}} = constant 1 : i32
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
TEST_FUNC(select_op) {
|
||||
|
@ -438,12 +447,12 @@ TEST_FUNC(select_op) {
|
|||
auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0);
|
||||
auto f = makeFunction("select_op", {}, {memrefType});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
// clang-format off
|
||||
ValueHandle zero = constant_index(0), one = constant_index(1);
|
||||
MemRefView vA(f->getArgument(0));
|
||||
IndexedValue A(f->getArgument(0));
|
||||
MemRefView vA(f.getArgument(0));
|
||||
IndexedValue A(f.getArgument(0));
|
||||
IndexHandle i, j;
|
||||
LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
|
||||
// This test exercises IndexedValue::operator Value*.
|
||||
|
@ -461,7 +470,8 @@ TEST_FUNC(select_op) {
|
|||
// CHECK-DAG: {{.*}} = load
|
||||
// CHECK-NEXT: {{.*}} = select
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
// Inject an EDSC-constructed computation to exercise imperfectly nested 2-d
|
||||
|
@ -474,12 +484,11 @@ TEST_FUNC(tile_2d) {
|
|||
MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0);
|
||||
auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType});
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle zero = constant_index(0);
|
||||
MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
|
||||
vC(f->getArgument(2));
|
||||
IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
|
||||
MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2));
|
||||
IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
|
||||
IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
|
||||
|
||||
// clang-format off
|
||||
|
@ -531,7 +540,8 @@ TEST_FUNC(tile_2d) {
|
|||
// CHECK-NEXT: {{.*}}= addf {{.*}}, {{.*}} : f32
|
||||
// CHECK-NEXT: store {{.*}}, {{.*}}[%i8, %i9, %i7] : memref<?x?x?xf32>
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
// Inject an EDSC-constructed computation to exercise 2-d vectorization.
|
||||
|
@ -544,16 +554,15 @@ TEST_FUNC(vectorize_2d) {
|
|||
auto owningF =
|
||||
makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType});
|
||||
|
||||
mlir::Function *f = owningF.release();
|
||||
mlir::Function f = owningF;
|
||||
mlir::Module module(&globalContext());
|
||||
module.getFunctions().push_back(f);
|
||||
module.push_back(f);
|
||||
|
||||
OpBuilder builder(f->getBody());
|
||||
ScopedContext scope(builder, f->getLoc());
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
ValueHandle zero = constant_index(0);
|
||||
MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
|
||||
vC(f->getArgument(2));
|
||||
IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
|
||||
MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2));
|
||||
IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
|
||||
IndexHandle M(vA.ub(0)), N(vA.ub(1)), P(vA.ub(2));
|
||||
|
||||
// clang-format off
|
||||
|
@ -580,9 +589,10 @@ TEST_FUNC(vectorize_2d) {
|
|||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
SmallVector<int64_t, 2> vectorSizes{4, 4};
|
||||
pm.addPass(mlir::createVectorizePass(vectorSizes));
|
||||
auto result = pm.run(f->getModule());
|
||||
auto result = pm.run(f.getModule());
|
||||
if (succeeded(result))
|
||||
f->print(llvm::outs());
|
||||
f.print(llvm::outs());
|
||||
f.erase();
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue