NFC: Refactor Function to be value typed.

Move the data members out of Function and into a new impl storage class 'FunctionStorage'. This allows Function to become value typed, which will greatly simplify the transition of Function to FuncOp (given that FuncOp is also value typed).

PiperOrigin-RevId: 255983022
River Riddle 2019-07-01 10:29:09 -07:00 committed by jpienaar
parent 84bd67fc4f
commit 54cd6a7e97
103 changed files with 987 additions and 875 deletions
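
The core of the change is splitting Function into an owning storage object (FunctionStorage) and a lightweight value-typed handle that wraps a pointer to it. The following is a minimal standalone C++ sketch of that handle-over-storage pattern, not MLIR code; the names Storage and Handle are placeholders for FunctionStorage and Function.

#include <cassert>
#include <iostream>
#include <string>

// Stand-in for FunctionStorage: owns all of the function's state.
class Storage {
public:
  explicit Storage(std::string name) : name(std::move(name)) {}
  std::string name;
};

// Stand-in for the value-typed Function: a cheap-to-copy handle that
// forwards every accessor to the underlying storage object.
class Handle {
public:
  Handle(Storage *impl = nullptr) : impl(impl) {}

  // Creation allocates the storage; the handle itself stays copyable.
  static Handle create(std::string name) {
    return Handle(new Storage(std::move(name)));
  }

  // Null checks and identity compare on the impl pointer, mirroring the
  // operator bool / operator== that this commit adds to Function.
  explicit operator bool() const { return impl != nullptr; }
  bool operator==(Handle other) const { return impl == other.impl; }

  const std::string &getName() const { return impl->name; }

  // Without RAII ownership, deletion becomes an explicit operation.
  void erase() { delete impl; }

private:
  Storage *impl;
};

int main() {
  Handle f = Handle::create("main");
  Handle copy = f; // copying the handle does not copy the function
  assert(copy == f);
  std::cout << copy.getName() << "\n";
  f.erase(); // in MLIR the module/creator, not the handle, owns the storage
}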

View File

@ -17,6 +17,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
@ -110,13 +111,14 @@ struct PythonValueHandle {
struct PythonFunction {
PythonFunction() : function{nullptr} {}
PythonFunction(mlir_func_t f) : function{f} {}
PythonFunction(mlir::Function *f) : function{f} {}
PythonFunction(mlir::Function f)
: function(const_cast<void *>(f.getAsOpaquePointer())) {}
operator mlir_func_t() { return function; }
std::string str() {
mlir::Function *f = reinterpret_cast<mlir::Function *>(function);
mlir::Function f = mlir::Function::getFromOpaquePointer(function);
std::string res;
llvm::raw_string_ostream os(res);
f->print(os);
f.print(os);
return res;
}
@ -124,18 +126,18 @@ struct PythonFunction {
// declaration, add the entry block, transforming the declaration into a
// definition. Return true if the block was added, false otherwise.
bool define() {
auto *f = reinterpret_cast<mlir::Function *>(function);
if (!f->getBlocks().empty())
auto f = mlir::Function::getFromOpaquePointer(function);
if (!f.getBlocks().empty())
return false;
f->addEntryBlock();
f.addEntryBlock();
return true;
}
PythonValueHandle arg(unsigned index) {
Function *f = static_cast<Function *>(function);
assert(index < f->getNumArguments() && "argument index out of bounds");
return PythonValueHandle(ValueHandle(f->getArgument(index)));
auto f = mlir::Function::getFromOpaquePointer(function);
assert(index < f.getNumArguments() && "argument index out of bounds");
return PythonValueHandle(ValueHandle(f.getArgument(index)));
}
mlir_func_t function;
@ -250,10 +252,9 @@ struct PythonFunctionContext {
PythonFunction enter() {
assert(function.function && "function is not set up");
auto *mlirFunc = static_cast<mlir::Function *>(function.function);
contextBuilder.emplace(mlirFunc->getBody());
context =
new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc->getLoc());
auto mlirFunc = mlir::Function::getFromOpaquePointer(function.function);
contextBuilder.emplace(mlirFunc.getBody());
context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc());
return function;
}
@ -594,7 +595,7 @@ PythonMLIRModule::declareFunction(const std::string &name,
}
// Create the function itself.
auto *func = new mlir::Function(
auto func = mlir::Function::create(
UnknownLoc::get(&mlirContext), name,
mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
inputAttrs);
@ -652,9 +653,9 @@ PYBIND11_MODULE(pybind, m) {
return ValueHandle::create<ConstantFloatOp>(value, floatType);
});
m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
auto *function = reinterpret_cast<Function *>(func.function);
auto attr = FunctionAttr::get(function);
return ValueHandle::create<ConstantOp>(function->getType(), attr);
auto function = Function::getFromOpaquePointer(func.function);
auto attr = FunctionAttr::get(function.getName(), function.getContext());
return ValueHandle::create<ConstantOp>(function.getType(), attr);
});
m.def("appendTo", [](const PythonBlockHandle &handle) {
return PythonBlockAppender(handle);
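
In the pybind layer above, the value-typed Function is still carried through the C-style mlir_func_t (a void *), so the bindings now round-trip it with getAsOpaquePointer() / getFromOpaquePointer() instead of reinterpret_cast on a raw Function *. Below is a minimal standalone sketch of that round trip; Handle and Storage are placeholders, not MLIR API.

#include <cassert>

struct Storage { int dummy = 0; };

// A value-typed handle whose identity is its storage pointer, so it can be
// smuggled through an opaque void* ABI and reconstructed losslessly.
class Handle {
public:
  Handle(Storage *impl = nullptr) : impl(impl) {}

  const void *getAsOpaquePointer() const { return impl; }
  static Handle getFromOpaquePointer(const void *ptr) {
    return Handle(reinterpret_cast<Storage *>(const_cast<void *>(ptr)));
  }
  bool operator==(Handle other) const { return impl == other.impl; }

private:
  Storage *impl;
};

int main() {
  Storage storage;
  Handle original(&storage);

  // Simulates storing the function in PythonFunction::function (mlir_func_t).
  void *opaque = const_cast<void *>(original.getAsOpaquePointer());

  Handle roundTripped = Handle::getFromOpaquePointer(opaque);
  assert(roundTripped == original);
}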

View File

@ -57,15 +57,15 @@ inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context,
}
/// A basic function builder
inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name,
llvm::ArrayRef<mlir::Type> types,
llvm::ArrayRef<mlir::Type> resultTypes) {
inline mlir::Function makeFunction(mlir::Module &module, llvm::StringRef name,
llvm::ArrayRef<mlir::Type> types,
llvm::ArrayRef<mlir::Type> resultTypes) {
auto *context = module.getContext();
auto *function = new mlir::Function(
auto function = mlir::Function::create(
mlir::UnknownLoc::get(context), name,
mlir::FunctionType::get({types}, resultTypes, context));
function->addEntryBlock();
module.getFunctions().push_back(function);
function.addEntryBlock();
module.push_back(function);
return function;
}
@ -83,19 +83,19 @@ inline std::unique_ptr<mlir::PassManager> cleanupPassManager() {
/// llvm::outs() for FileCheck'ing.
/// If an error occurs, dump to llvm::errs() and do not print to llvm::outs()
/// which will make the associated FileCheck test fail.
inline void cleanupAndPrintFunction(mlir::Function *f) {
inline void cleanupAndPrintFunction(mlir::Function f) {
bool printToOuts = true;
auto check = [f, &printToOuts](mlir::LogicalResult result) {
auto check = [&f, &printToOuts](mlir::LogicalResult result) {
if (failed(result)) {
f->emitError("Verification and cleanup passes failed");
f.emitError("Verification and cleanup passes failed");
printToOuts = false;
}
};
auto pm = cleanupPassManager();
check(f->getModule()->verify());
check(pm->run(f->getModule()));
check(f.getModule()->verify());
check(pm->run(f.getModule()));
if (printToOuts)
f->print(llvm::outs());
f.print(llvm::outs());
}
/// Helper class to sugar building loop nests from indexings that appear in

View File

@ -36,14 +36,14 @@ TEST_FUNC(linalg_ops) {
MLIRContext context;
Module module(&context);
auto indexType = mlir::IndexType::get(&context);
mlir::Function *f =
mlir::Function f =
makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
@ -75,14 +75,14 @@ TEST_FUNC(linalg_ops_folded_slices) {
MLIRContext context;
Module module(&context);
auto indexType = mlir::IndexType::get(&context);
mlir::Function *f = makeFunction(module, "linalg_ops_folded_slices",
{indexType, indexType, indexType}, {});
mlir::Function f = makeFunction(module, "linalg_ops_folded_slices",
{indexType, indexType, indexType}, {});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)),
ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
@ -104,7 +104,7 @@ TEST_FUNC(linalg_ops_folded_slices) {
// CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view<f32>
// clang-format on
f->walk<SliceOp>([](SliceOp slice) {
f.walk<SliceOp>([](SliceOp slice) {
auto *sliceResult = slice.getResult();
auto viewOp = emitAndReturnFullyComposedView(sliceResult);
sliceResult->replaceAllUsesWith(viewOp.getResult());

View File

@ -37,26 +37,26 @@ using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;
Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
mlir::Function *f = linalg::common::makeFunction(
mlir::Function f = linalg::common::makeFunction(
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle
M = dim(f->getArgument(0), 0),
N = dim(f->getArgument(2), 1),
K = dim(f->getArgument(0), 1),
M = dim(f.getArgument(0), 0),
N = dim(f.getArgument(2), 1),
K = dim(f.getArgument(0), 1),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
vA = view(f->getArgument(0), {rM, rK}),
vB = view(f->getArgument(1), {rK, rN}),
vC = view(f->getArgument(2), {rM, rN});
vA = view(f.getArgument(0), {rM, rK}),
vB = view(f.getArgument(1), {rK, rN}),
vC = view(f.getArgument(2), {rM, rN});
matmul(vA, vB, vC);
ret();
// clang-format on
@ -67,7 +67,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
TEST_FUNC(foo) {
MLIRContext context;
Module module(&context);
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
lowerToLoops(f);
convertLinalg3ToLLVM(module);

View File

@ -34,26 +34,26 @@ using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;
Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
mlir::Function *f = linalg::common::makeFunction(
mlir::Function f = linalg::common::makeFunction(
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
mlir::OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
mlir::OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle
M = dim(f->getArgument(0), 0),
N = dim(f->getArgument(2), 1),
K = dim(f->getArgument(0), 1),
M = dim(f.getArgument(0), 0),
N = dim(f.getArgument(2), 1),
K = dim(f.getArgument(0), 1),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
vA = view(f->getArgument(0), {rM, rK}),
vB = view(f->getArgument(1), {rK, rN}),
vC = view(f->getArgument(2), {rM, rN});
vA = view(f.getArgument(0), {rM, rK}),
vB = view(f.getArgument(1), {rK, rN}),
vC = view(f.getArgument(2), {rM, rN});
matmul(vA, vB, vC);
ret();
// clang-format on
@ -64,7 +64,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
TEST_FUNC(matmul_as_matvec) {
MLIRContext context;
Module module(&context);
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
lowerToFinerGrainedTensorContraction(f);
composeSliceOps(f);
// clang-format off
@ -82,7 +82,7 @@ TEST_FUNC(matmul_as_matvec) {
TEST_FUNC(matmul_as_dot) {
MLIRContext context;
Module module(&context);
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
lowerToFinerGrainedTensorContraction(f);
lowerToFinerGrainedTensorContraction(f);
composeSliceOps(f);
@ -103,7 +103,7 @@ TEST_FUNC(matmul_as_dot) {
TEST_FUNC(matmul_as_loops) {
MLIRContext context;
Module module(&context);
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
lowerToLoops(f);
composeSliceOps(f);
// clang-format off
@ -135,7 +135,7 @@ TEST_FUNC(matmul_as_loops) {
TEST_FUNC(matmul_as_matvec_as_loops) {
MLIRContext context;
Module module(&context);
mlir::Function *f =
mlir::Function f =
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
lowerToFinerGrainedTensorContraction(f);
lowerToLoops(f);
@ -166,14 +166,14 @@ TEST_FUNC(matmul_as_matvec_as_loops) {
TEST_FUNC(matmul_as_matvec_as_affine) {
MLIRContext context;
Module module(&context);
mlir::Function *f =
mlir::Function f =
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine");
lowerToFinerGrainedTensorContraction(f);
composeSliceOps(f);
lowerToLoops(f);
PassManager pm;
pm.addPass(createLowerLinalgLoadStorePass());
if (succeeded(pm.run(f->getModule())))
if (succeeded(pm.run(f.getModule())))
cleanupAndPrintFunction(f);
// clang-format off

View File

@ -37,26 +37,26 @@ using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;
Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
mlir::Function *f = linalg::common::makeFunction(
mlir::Function f = linalg::common::makeFunction(
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
mlir::OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
mlir::OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle
M = dim(f->getArgument(0), 0),
N = dim(f->getArgument(2), 1),
K = dim(f->getArgument(0), 1),
M = dim(f.getArgument(0), 0),
N = dim(f.getArgument(2), 1),
K = dim(f.getArgument(0), 1),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
vA = view(f->getArgument(0), {rM, rK}),
vB = view(f->getArgument(1), {rK, rN}),
vC = view(f->getArgument(2), {rM, rN});
vA = view(f.getArgument(0), {rM, rK}),
vB = view(f.getArgument(1), {rK, rN}),
vC = view(f.getArgument(2), {rM, rN});
matmul(vA, vB, vC);
ret();
// clang-format on
@ -110,7 +110,7 @@ TEST_FUNC(execution) {
// dialect through partial conversions.
MLIRContext context;
Module module(&context);
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
lowerToLoops(f);
convertLinalg3ToLLVM(module);

View File

@ -55,11 +55,11 @@ makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps,
/// Traverses `f` and rewrites linalg.slice, and the operations it depends on,
/// to only use linalg.view operations.
void composeSliceOps(mlir::Function *f);
void composeSliceOps(mlir::Function f);
/// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec)
/// as linalg.matvec(resp. linalg.dot).
void lowerToFinerGrainedTensorContraction(mlir::Function *f);
void lowerToFinerGrainedTensorContraction(mlir::Function f);
/// Operation-wise writing of linalg operations to loop form.
/// It is the caller's responsibility to erase the `op` if necessary.
@ -69,7 +69,7 @@ llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 4>>
writeAsLoops(mlir::Operation *op);
/// Traverses `f` and rewrites linalg operations in loop form.
void lowerToLoops(mlir::Function *f);
void lowerToLoops(mlir::Function f);
/// Creates a pass that rewrites linalg.load and linalg.store to affine.load and
/// affine.store operations.

View File

@ -148,7 +148,7 @@ static void populateLinalg3ToLLVMConversionPatterns(
void linalg::convertLinalg3ToLLVM(Module &module) {
// Remove affine constructs.
for (auto &func : module) {
for (auto func : module) {
auto rr = lowerAffineConstructs(func);
(void)rr;
assert(succeeded(rr) && "affine loop lowering failed");

View File

@ -35,8 +35,8 @@ using namespace mlir::edsc::intrinsics;
using namespace linalg;
using namespace linalg::intrinsics;
void linalg::composeSliceOps(mlir::Function *f) {
f->walk<SliceOp>([](SliceOp sliceOp) {
void linalg::composeSliceOps(mlir::Function f) {
f.walk<SliceOp>([](SliceOp sliceOp) {
auto *sliceResult = sliceOp.getResult();
auto viewOp = emitAndReturnFullyComposedView(sliceResult);
sliceResult->replaceAllUsesWith(viewOp.getResult());
@ -44,8 +44,8 @@ void linalg::composeSliceOps(mlir::Function *f) {
});
}
void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
f->walk([](Operation *op) {
void linalg::lowerToFinerGrainedTensorContraction(mlir::Function f) {
f.walk([](Operation *op) {
if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
matmulOp.writeAsFinerGrainTensorContraction();
} else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {
@ -211,8 +211,8 @@ linalg::writeAsLoops(Operation *op) {
return llvm::None;
}
void linalg::lowerToLoops(mlir::Function *f) {
f->walk([](Operation *op) {
void linalg::lowerToLoops(mlir::Function f) {
f.walk([](Operation *op) {
if (writeAsLoops(op))
op->erase();
});

View File

@ -34,27 +34,27 @@ using namespace linalg;
using namespace linalg::common;
using namespace linalg::intrinsics;
Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
mlir::Function *f = linalg::common::makeFunction(
mlir::Function f = linalg::common::makeFunction(
module, name,
{dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle
M = dim(f->getArgument(0), 0),
N = dim(f->getArgument(2), 1),
K = dim(f->getArgument(0), 1),
M = dim(f.getArgument(0), 0),
N = dim(f.getArgument(2), 1),
K = dim(f.getArgument(0), 1),
rM = range(constant_index(0), M, constant_index(1)),
rN = range(constant_index(0), N, constant_index(1)),
rK = range(constant_index(0), K, constant_index(1)),
vA = view(f->getArgument(0), {rM, rK}),
vB = view(f->getArgument(1), {rK, rN}),
vC = view(f->getArgument(2), {rM, rN});
vA = view(f.getArgument(0), {rM, rK}),
vB = view(f.getArgument(1), {rK, rN}),
vC = view(f.getArgument(2), {rM, rN});
matmul(vA, vB, vC);
ret();
// clang-format on
@ -65,11 +65,11 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
TEST_FUNC(matmul_tiled_loops) {
MLIRContext context;
Module module(&context);
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops");
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops");
lowerToTiledLoops(f, {8, 9});
PassManager pm;
pm.addPass(createLowerLinalgLoadStorePass());
if (succeeded(pm.run(f->getModule())))
if (succeeded(pm.run(f.getModule())))
cleanupAndPrintFunction(f);
// clang-format off
@ -96,10 +96,10 @@ TEST_FUNC(matmul_tiled_loops) {
TEST_FUNC(matmul_tiled_views) {
MLIRContext context;
Module module(&context);
mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views");
OpBuilder b(f->getBody());
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f->getLoc(), 8),
b.create<ConstantIndexOp>(f->getLoc(), 9)});
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views");
OpBuilder b(f.getBody());
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
b.create<ConstantIndexOp>(f.getLoc(), 9)});
composeSliceOps(f);
// clang-format off
@ -125,11 +125,11 @@ TEST_FUNC(matmul_tiled_views) {
TEST_FUNC(matmul_tiled_views_as_loops) {
MLIRContext context;
Module module(&context);
mlir::Function *f =
mlir::Function f =
makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops");
OpBuilder b(f->getBody());
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f->getLoc(), 8),
b.create<ConstantIndexOp>(f->getLoc(), 9)});
OpBuilder b(f.getBody());
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
b.create<ConstantIndexOp>(f.getLoc(), 9)});
composeSliceOps(f);
lowerToLoops(f);
// This cannot lower below linalg.load and linalg.store due to lost

View File

@ -34,12 +34,12 @@ writeAsTiledViews(mlir::Operation *op, llvm::ArrayRef<mlir::Value *> tileSizes);
/// Apply `writeAsTiledLoops` on all linalg ops. This is a convenience function
/// and is not exposed as a pass because a fixed set of tile sizes for all ops
/// in a function can generally not be specified.
void lowerToTiledLoops(mlir::Function *f, llvm::ArrayRef<uint64_t> tileSizes);
void lowerToTiledLoops(mlir::Function f, llvm::ArrayRef<uint64_t> tileSizes);
/// Apply `writeAsTiledViews` on all linalg ops. This is a convenience function
/// and is not exposed as a pass because a fixed set of tile sizes for all ops
/// in a function can generally not be specified.
void lowerToTiledViews(mlir::Function *f,
void lowerToTiledViews(mlir::Function f,
llvm::ArrayRef<mlir::Value *> tileSizes);
} // namespace linalg

View File

@ -43,9 +43,8 @@ linalg::writeAsTiledLoops(Operation *op, ArrayRef<uint64_t> tileSizes) {
return llvm::None;
}
void linalg::lowerToTiledLoops(mlir::Function *f,
ArrayRef<uint64_t> tileSizes) {
f->walk([tileSizes](Operation *op) {
void linalg::lowerToTiledLoops(mlir::Function f, ArrayRef<uint64_t> tileSizes) {
f.walk([tileSizes](Operation *op) {
if (writeAsTiledLoops(op, tileSizes).hasValue())
op->erase();
});
@ -185,8 +184,8 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef<Value *> tileSizes) {
return llvm::None;
}
void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef<Value *> tileSizes) {
f->walk([tileSizes](Operation *op) {
void linalg::lowerToTiledViews(mlir::Function f, ArrayRef<Value *> tileSizes) {
f.walk([tileSizes](Operation *op) {
if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) {
writeAsTiledViews(matmulOp, tileSizes);
} else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) {

View File

@ -75,7 +75,7 @@ public:
auto func = mlirGen(F);
if (!func)
return nullptr;
theModule->getFunctions().push_back(func.release());
theModule->push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
@ -129,40 +129,40 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::Function *mlirGen(PrototypeAST &proto) {
mlir::Function mlirGen(PrototypeAST &proto) {
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
// Mark the function as generic: it'll require type specialization for every
// call site.
if (function->getNumArguments())
function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
if (function.getNumArguments())
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
return function;
}
/// Emit a new function and add it to the MLIR module.
std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
mlir::Function mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
// Create an MLIR function for the given prototype.
std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
mlir::Function function(mlirGen(*funcAST.getProto()));
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
function->addEntryBlock();
function.addEntryBlock();
auto &entryBlock = function->front();
auto &entryBlock = function.front();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
@ -172,16 +172,18 @@ private:
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
// Emit the body of the function.
if (!mlirGen(*funcAST.getBody()))
if (!mlirGen(*funcAST.getBody())) {
function.erase();
return nullptr;
}
// Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
if (function->getBlocks().back().back().getName().getStringRef() !=
if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);
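
A side effect of dropping std::unique_ptr<mlir::Function> in the Toy mlirGen path is that nothing deletes a half-built function automatically anymore, so the failure path above now calls function.erase() explicitly before returning null. A standalone sketch of that ownership discipline, with placeholder names, follows.

#include <iostream>
#include <string>

struct Storage { std::string name; };

// Non-owning, value-typed handle: the creator is responsible for cleanup.
class Handle {
public:
  Handle(Storage *impl = nullptr) : impl(impl) {}
  static Handle create(std::string name) {
    return Handle(new Storage{std::move(name)});
  }
  explicit operator bool() const { return impl != nullptr; }
  void erase() { delete impl; impl = nullptr; }
private:
  Storage *impl;
};

// Mirrors mlirGen(FunctionAST&): on a body-generation failure the partially
// constructed function must be erased by hand, since no smart pointer owns it.
Handle emitFunction(bool bodyGenSucceeds) {
  Handle function = Handle::create("f");
  if (!bodyGenSucceeds) {
    function.erase(); // explicit cleanup replaces unique_ptr's destructor
    return nullptr;
  }
  return function;
}

int main() {
  Handle ok = emitFunction(true);
  Handle failed = emitFunction(false);
  std::cout << std::boolalpha << static_cast<bool>(ok) << " "
            << static_cast<bool>(failed) << "\n";
  ok.erase();
}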

View File

@ -76,7 +76,7 @@ public:
auto func = mlirGen(F);
if (!func)
return nullptr;
theModule->getFunctions().push_back(func.release());
theModule->push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
@ -130,40 +130,40 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::Function *mlirGen(PrototypeAST &proto) {
mlir::Function mlirGen(PrototypeAST &proto) {
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
// Mark the function as generic: it'll require type specialization for every
// call site.
if (function->getNumArguments())
function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
if (function.getNumArguments())
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
return function;
}
/// Emit a new function and add it to the MLIR module.
std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
mlir::Function mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
// Create an MLIR function for the given prototype.
std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
mlir::Function function(mlirGen(*funcAST.getProto()));
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
function->addEntryBlock();
function.addEntryBlock();
auto &entryBlock = function->front();
auto &entryBlock = function.front();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
@ -173,16 +173,18 @@ private:
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
// Emit the body of the function.
if (!mlirGen(*funcAST.getBody()))
if (!mlirGen(*funcAST.getBody())) {
function.erase();
return nullptr;
}
// Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
if (function->getBlocks().back().back().getName().getStringRef() !=
if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);

View File

@ -76,7 +76,7 @@ public:
auto func = mlirGen(F);
if (!func)
return nullptr;
theModule->getFunctions().push_back(func.release());
theModule->push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
@ -130,40 +130,40 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::Function *mlirGen(PrototypeAST &proto) {
mlir::Function mlirGen(PrototypeAST &proto) {
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
// Mark the function as generic: it'll require type specialization for every
// call site.
if (function->getNumArguments())
function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
if (function.getNumArguments())
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
return function;
}
/// Emit a new function and add it to the MLIR module.
std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
mlir::Function mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
// Create an MLIR function for the given prototype.
std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
mlir::Function function(mlirGen(*funcAST.getProto()));
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
function->addEntryBlock();
function.addEntryBlock();
auto &entryBlock = function->front();
auto &entryBlock = function.front();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
@ -173,16 +173,18 @@ private:
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
// Emit the body of the function.
if (!mlirGen(*funcAST.getBody()))
if (!mlirGen(*funcAST.getBody())) {
function.erase();
return nullptr;
}
// Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
if (function->getBlocks().back().back().getName().getStringRef() !=
if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);

View File

@ -113,14 +113,14 @@ public:
// function to process, the mangled name for this specialization, and the
// types of the arguments on which to specialize.
struct FunctionToSpecialize {
mlir::Function *function;
mlir::Function function;
std::string mangledName;
SmallVector<mlir::Type, 4> argumentsType;
};
void runOnModule() override {
auto &module = getModule();
auto *main = module.getNamedFunction("main");
auto main = module.getNamedFunction("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
"Shape inference failed: can't find a main function\n");
@ -139,7 +139,7 @@ public:
// Delete any generic function left
// FIXME: we may want this as a separate pass.
for (mlir::Function &function : llvm::make_early_inc_range(module)) {
for (mlir::Function function : llvm::make_early_inc_range(module)) {
if (auto genericAttr =
function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
if (genericAttr.getValue())
@ -153,7 +153,7 @@ public:
mlir::LogicalResult
specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
mlir::Function *f = functionToSpecialize.function;
mlir::Function f = functionToSpecialize.function;
// Check if cloning for specialization is needed (usually anything but main)
// We will create a new function with the concrete types for the parameters
@ -169,36 +169,36 @@ public:
auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
{ToyArrayType::get(&getContext())},
&getContext());
auto *newFunction = new mlir::Function(
f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
getModule().getFunctions().push_back(newFunction);
auto newFunction = mlir::Function::create(
f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs());
getModule().push_back(newFunction);
// Clone the function body
mlir::BlockAndValueMapping mapper;
f->cloneInto(newFunction, mapper);
f.cloneInto(newFunction, mapper);
LLVM_DEBUG({
llvm::dbgs() << "====== Cloned : \n";
f->dump();
f.dump();
llvm::dbgs() << "====== Into : \n";
newFunction->dump();
newFunction.dump();
});
f = newFunction;
f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// Remap the entry-block arguments
// FIXME: this seems like a bug in `cloneInto()` above?
auto &entryBlock = f->getBlocks().front();
auto &entryBlock = f.getBlocks().front();
int blockArgSize = entryBlock.getArguments().size();
assert(blockArgSize == static_cast<int>(f->getType().getInputs().size()));
entryBlock.addArguments(f->getType().getInputs());
assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
entryBlock.addArguments(f.getType().getInputs());
auto argList = entryBlock.getArguments();
for (int argNum = 0; argNum < blockArgSize; ++argNum) {
argList[0]->replaceAllUsesWith(argList[blockArgSize]);
entryBlock.eraseArgument(0);
}
assert(succeeded(f->verify()));
assert(succeeded(f.verify()));
}
LLVM_DEBUG(llvm::dbgs()
<< "Run shape inference on : '" << f->getName() << "'\n");
<< "Run shape inference on : '" << f.getName() << "'\n");
auto *toyDialect = getContext().getRegisteredDialect("toy");
if (!toyDialect) {
@ -211,7 +211,7 @@ public:
// Populate the worklist with the operations that need shape inference:
// these are the Toy operations that return a generic array.
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
f->walk([&](mlir::Operation *op) {
f.walk([&](mlir::Operation *op) {
if (op->getDialect() == toyDialect) {
if (op->getNumResults() == 1 &&
op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
@ -292,9 +292,9 @@ public:
// restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName();
auto *callee = getModule().getNamedFunction(calleeName);
auto callee = getModule().getNamedFunction(calleeName);
if (!callee) {
f->emitError("Shape inference failed, call to unknown '")
f.emitError("Shape inference failed, call to unknown '")
<< calleeName << "'";
signalPassFailure();
return mlir::failure();
@ -302,7 +302,7 @@ public:
auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n");
auto *mangledCallee = getModule().getNamedFunction(mangledName);
auto mangledCallee = getModule().getNamedFunction(mangledName);
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
@ -327,7 +327,7 @@ public:
// Done with inference on this function, removing it from the worklist.
funcWorklist.pop_back();
// Mark the function as non-generic now that inference has succeeded
f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// If the operation worklist isn't empty, this indicates a failure.
if (!opWorklist.empty()) {
@ -337,31 +337,31 @@ public:
<< " operations couldn't be inferred\n";
for (auto *ope : opWorklist)
errorMsg << " - " << *ope << "\n";
f->emitError(errorMsg.str());
f.emitError(errorMsg.str());
signalPassFailure();
return mlir::failure();
}
// Finally, update the return type of the function based on the argument to
// the return operation.
for (auto &block : f->getBlocks()) {
for (auto &block : f.getBlocks()) {
auto ret = llvm::cast<ReturnOp>(block.getTerminator());
if (!ret)
continue;
if (ret.getNumOperands() &&
f->getType().getResult(0) == ret.getOperand()->getType())
f.getType().getResult(0) == ret.getOperand()->getType())
// type match, we're done
break;
SmallVector<mlir::Type, 1> retTy;
if (ret.getNumOperands())
retTy.push_back(ret.getOperand()->getType());
std::vector<mlir::Type> argumentsType;
for (auto arg : f->getArguments())
for (auto arg : f.getArguments())
argumentsType.push_back(arg->getType());
auto newType =
mlir::FunctionType::get(argumentsType, retTy, &getContext());
f->setType(newType);
assert(succeeded(f->verify()));
f.setType(newType);
assert(succeeded(f.verify()));
break;
}
return mlir::success();

View File

@ -136,14 +136,14 @@ public:
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Get or create the declaration of the printf function in the module.
Function *printfFunc = getPrintf(*op->getFunction()->getModule());
Function printfFunc = getPrintf(*op->getFunction().getModule());
auto print = cast<toy::PrintOp>(op);
auto loc = print.getLoc();
// We will operate on a MemRef abstraction, we use a type.cast to get one
// if our operand is still a Toy array.
Value *operand = memRefTypeCast(rewriter, operands[0]);
Type retTy = printfFunc->getType().getResult(0);
Type retTy = printfFunc.getType().getResult(0);
// Create our loop nest now
using namespace edsc;
@ -205,8 +205,8 @@ private:
/// Return the prototype declaration for printf in the module, create it if
/// necessary.
Function *getPrintf(Module &module) const {
auto *printfFunc = module.getNamedFunction("printf");
Function getPrintf(Module &module) const {
auto printfFunc = module.getNamedFunction("printf");
if (printfFunc)
return printfFunc;
@ -218,10 +218,10 @@ private:
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
printfFunc = Function::create(builder.getUnknownLoc(), "printf", printfTy);
// It should be variadic, but we don't support it fully just yet.
printfFunc->setAttr("std.varargs", builder.getBoolAttr(true));
module.getFunctions().push_back(printfFunc);
printfFunc.setAttr("std.varargs", builder.getBoolAttr(true));
module.push_back(printfFunc);
return printfFunc;
}
};
@ -369,7 +369,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
// affine dialect: they already include conversion to the LLVM dialect.
// First patch calls type to return memref instead of ToyArray
for (auto &function : getModule()) {
for (auto function : getModule()) {
function.walk([&](Operation *op) {
auto callOp = dyn_cast<CallOp>(op);
if (!callOp)
@ -384,7 +384,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
});
}
for (auto &function : getModule()) {
for (auto function : getModule()) {
function.walk([&](Operation *op) {
// Turns toy.alloc into sequence of alloc/dealloc (later malloc/free).
if (auto allocOp = dyn_cast<toy::AllocOp>(op)) {
@ -403,8 +403,8 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> {
}
// Lower Linalg to affine
for (auto &function : getModule())
linalg::lowerToLoops(&function);
for (auto function : getModule())
linalg::lowerToLoops(function);
getModule().dump();

View File

@ -76,7 +76,7 @@ public:
auto func = mlirGen(F);
if (!func)
return nullptr;
theModule->getFunctions().push_back(func.release());
theModule->push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
@ -130,40 +130,40 @@ private:
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::Function *mlirGen(PrototypeAST &proto) {
mlir::Function mlirGen(PrototypeAST &proto) {
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
auto function = mlir::Function::create(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
// Mark the function as generic: it'll require type specialization for every
// call site.
if (function->getNumArguments())
function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
if (function.getNumArguments())
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
return function;
}
/// Emit a new function and add it to the MLIR module.
std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
mlir::Function mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
// Create an MLIR function for the given prototype.
std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
mlir::Function function(mlirGen(*funcAST.getProto()));
if (!function)
return nullptr;
// Let's start the body of the function now!
// In MLIR the entry block of the function is special: it must have the same
// argument list as the function itself.
function->addEntryBlock();
function.addEntryBlock();
auto &entryBlock = function->front();
auto &entryBlock = function.front();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
@ -173,16 +173,18 @@ private:
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
// Emit the body of the function.
if (!mlirGen(*funcAST.getBody()))
if (!mlirGen(*funcAST.getBody())) {
function.erase();
return nullptr;
}
// Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
if (function->getBlocks().back().back().getName().getStringRef() !=
if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);

View File

@ -113,7 +113,7 @@ public:
// function to process, the mangled name for this specialization, and the
// types of the arguments on which to specialize.
struct FunctionToSpecialize {
mlir::Function *function;
mlir::Function function;
std::string mangledName;
SmallVector<mlir::Type, 4> argumentsType;
};
@ -121,7 +121,7 @@ public:
void runOnModule() override {
auto &module = getModule();
mlir::ModuleManager moduleManager(&module);
auto *main = moduleManager.getNamedFunction("main");
auto main = moduleManager.getNamedFunction("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
"Shape inference failed: can't find a main function\n");
@ -140,7 +140,7 @@ public:
// Delete any generic function left
// FIXME: we may want this as a separate pass.
for (mlir::Function &function : llvm::make_early_inc_range(module)) {
for (mlir::Function function : llvm::make_early_inc_range(module)) {
if (auto genericAttr =
function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
if (genericAttr.getValue())
@ -155,7 +155,7 @@ public:
specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist,
mlir::ModuleManager &moduleManager) {
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
mlir::Function *f = functionToSpecialize.function;
mlir::Function f = functionToSpecialize.function;
// Check if cloning for specialization is needed (usually anything but main)
// We will create a new function with the concrete types for the parameters
@ -171,36 +171,36 @@ public:
auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
{ToyArrayType::get(&getContext())},
&getContext());
auto *newFunction = new mlir::Function(
f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
auto newFunction = mlir::Function::create(
f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs());
moduleManager.insert(newFunction);
// Clone the function body
mlir::BlockAndValueMapping mapper;
f->cloneInto(newFunction, mapper);
f.cloneInto(newFunction, mapper);
LLVM_DEBUG({
llvm::dbgs() << "====== Cloned : \n";
f->dump();
f.dump();
llvm::dbgs() << "====== Into : \n";
newFunction->dump();
newFunction.dump();
});
f = newFunction;
f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// Remap the entry-block arguments
// FIXME: this seems like a bug in `cloneInto()` above?
auto &entryBlock = f->getBlocks().front();
auto &entryBlock = f.getBlocks().front();
int blockArgSize = entryBlock.getArguments().size();
assert(blockArgSize == static_cast<int>(f->getType().getInputs().size()));
entryBlock.addArguments(f->getType().getInputs());
assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
entryBlock.addArguments(f.getType().getInputs());
auto argList = entryBlock.getArguments();
for (int argNum = 0; argNum < blockArgSize; ++argNum) {
argList[0]->replaceAllUsesWith(argList[blockArgSize]);
entryBlock.eraseArgument(0);
}
assert(succeeded(f->verify()));
assert(succeeded(f.verify()));
}
LLVM_DEBUG(llvm::dbgs()
<< "Run shape inference on : '" << f->getName() << "'\n");
<< "Run shape inference on : '" << f.getName() << "'\n");
auto *toyDialect = getContext().getRegisteredDialect("toy");
if (!toyDialect) {
@ -212,7 +212,7 @@ public:
// Populate the worklist with the operations that need shape inference:
// these are the Toy operations that return a generic array.
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
f->walk([&](mlir::Operation *op) {
f.walk([&](mlir::Operation *op) {
if (op->getDialect() == toyDialect) {
if (op->getNumResults() == 1 &&
op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
@ -295,16 +295,16 @@ public:
// restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName();
auto *callee = moduleManager.getNamedFunction(calleeName);
auto callee = moduleManager.getNamedFunction(calleeName);
if (!callee) {
signalPassFailure();
return f->emitError("Shape inference failed, call to unknown '")
return f.emitError("Shape inference failed, call to unknown '")
<< calleeName << "'";
}
auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n");
auto *mangledCallee = moduleManager.getNamedFunction(mangledName);
auto mangledCallee = moduleManager.getNamedFunction(mangledName);
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
@ -315,7 +315,7 @@ public:
// Found a specialized callee! Let's turn this into a normal call
// operation.
SmallVector<mlir::Value *, 8> operands(op->getOperands());
mlir::OpBuilder builder(f->getBody());
mlir::OpBuilder builder(f.getBody());
builder.setInsertionPoint(op);
auto newCall =
builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
@ -330,12 +330,12 @@ public:
// Done with inference on this function, removing it from the worklist.
funcWorklist.pop_back();
// Mark the function as non-generic now that inference has succeeded
f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// If the operation worklist isn't empty, this indicates a failure.
if (!opWorklist.empty()) {
signalPassFailure();
auto diag = f->emitError("Shape inference failed, ")
auto diag = f.emitError("Shape inference failed, ")
<< opWorklist.size() << " operations couldn't be inferred\n";
for (auto *ope : opWorklist)
diag << " - " << *ope << "\n";
@ -344,24 +344,24 @@ public:
// Finally, update the return type of the function based on the argument to
// the return operation.
for (auto &block : f->getBlocks()) {
for (auto &block : f.getBlocks()) {
auto ret = llvm::cast<ReturnOp>(block.getTerminator());
if (!ret)
continue;
if (ret.getNumOperands() &&
f->getType().getResult(0) == ret.getOperand()->getType())
f.getType().getResult(0) == ret.getOperand()->getType())
// type match, we're done
break;
SmallVector<mlir::Type, 1> retTy;
if (ret.getNumOperands())
retTy.push_back(ret.getOperand()->getType());
std::vector<mlir::Type> argumentsType;
for (auto arg : f->getArguments())
for (auto arg : f.getArguments())
argumentsType.push_back(arg->getType());
auto newType =
mlir::FunctionType::get(argumentsType, retTy, &getContext());
f->setType(newType);
assert(succeeded(f->verify()));
f.setType(newType);
assert(succeeded(f.verify()));
break;
}
return mlir::success();

View File

@ -34,7 +34,7 @@ template <bool IsPostDom> class DominanceInfoBase {
using base = llvm::DominatorTreeBase<Block, IsPostDom>;
public:
DominanceInfoBase(Function *function) { recalculate(function); }
DominanceInfoBase(Function function) { recalculate(function); }
DominanceInfoBase(Operation *op) { recalculate(op); }
DominanceInfoBase(DominanceInfoBase &&) = default;
DominanceInfoBase &operator=(DominanceInfoBase &&) = default;
@ -43,7 +43,7 @@ public:
DominanceInfoBase &operator=(const DominanceInfoBase &) = delete;
/// Recalculate the dominance info.
void recalculate(Function *function);
void recalculate(Function function);
void recalculate(Operation *op);
/// Get the root dominance node of the given region.

View File

@ -104,8 +104,8 @@ struct NestedPattern {
NestedPattern &operator=(const NestedPattern &) = default;
/// Returns all the top-level matches in `func`.
void match(Function *func, SmallVectorImpl<NestedMatch> *matches) {
func->walk([&](Operation *op) { matchOne(op, matches); });
void match(Function func, SmallVectorImpl<NestedMatch> *matches) {
func.walk([&](Operation *op) { matchOne(op, matches); });
}
/// Returns all the top-level matches in `op`.

View File

@ -44,7 +44,7 @@ struct StaticFloatMemRef {
/// each of the arguments, initialize the storage with `initialValue`, and
/// return a list of type-erased descriptor pointers.
llvm::Expected<SmallVector<void *, 8>>
allocateMemRefArguments(Function *func, float initialValue = 0.0);
allocateMemRefArguments(Function func, float initialValue = 0.0);
/// Free a list of type-erased descriptors to statically-shaped memrefs with
/// element type f32.

View File

@ -44,7 +44,7 @@ public:
/// Returns whether the given function is a kernel function, i.e., has the
/// 'gpu.kernel' attribute.
static bool isKernel(Function *function);
static bool isKernel(Function function);
};
/// Utility class for the GPU dialect to represent triples of `Value`s
@ -122,12 +122,12 @@ public:
using Op::Op;
static void build(Builder *builder, OperationState *result,
Function *kernelFunc, Value *gridSizeX, Value *gridSizeY,
Function kernelFunc, Value *gridSizeX, Value *gridSizeY,
Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
Value *blockSizeZ, ArrayRef<Value *> kernelOperands);
static void build(Builder *builder, OperationState *result,
Function *kernelFunc, KernelDim3 gridSize,
Function kernelFunc, KernelDim3 gridSize,
KernelDim3 blockSize, ArrayRef<Value *> kernelOperands);
/// The kernel function specified by the operation's `kernel` attribute.

View File

@ -313,9 +313,8 @@ class FunctionAttr
detail::StringAttributeStorage> {
public:
using Base::Base;
using ValueType = Function *;
using ValueType = StringRef;
static FunctionAttr get(Function *value);
static FunctionAttr get(StringRef value, MLIRContext *ctx);
/// Returns the name of the held function reference.
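
FunctionAttr no longer holds a Function *; it now stores the referenced function's name and is resolved back to a function through the enclosing module, as the pybind change earlier in this commit shows with FunctionAttr::get(function.getName(), function.getContext()). Below is a minimal standalone sketch of such a by-name symbol reference; all names are placeholders rather than MLIR API.

#include <cassert>
#include <map>
#include <string>

struct Storage { std::string name; };

class Handle {
public:
  Handle(Storage *impl = nullptr) : impl(impl) {}
  explicit operator bool() const { return impl != nullptr; }
  const std::string &getName() const { return impl->name; }
private:
  Storage *impl;
};

// Stand-in for a module's symbol table: functions are looked up by name.
class Module {
public:
  void insert(Handle f) { table[f.getName()] = f; }
  Handle getNamedFunction(const std::string &name) {
    auto it = table.find(name);
    return it == table.end() ? Handle() : it->second;
  }
private:
  std::map<std::string, Handle> table;
};

// Stand-in for FunctionAttr after this commit: a reference by name only.
struct FunctionRefAttr {
  std::string value;
};

int main() {
  Storage callee{"callee"};
  Module module;
  module.insert(Handle(&callee));

  FunctionRefAttr attr{"callee"}; // built from getName(), not from a pointer
  Handle resolved = module.getNamedFunction(attr.value);
  assert(resolved && resolved.getName() == "callee");
}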

View File

@ -101,7 +101,7 @@ public:
/// Returns the function that this block is part of, even if the block is
/// nested under an operation region.
Function *getFunction();
Function getFunction();
/// Insert this block (which must not already be in a function) right before
/// the specified block.

View File

@ -112,7 +112,7 @@ public:
AffineMapAttr getAffineMapAttr(AffineMap map);
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
TypeAttr getTypeAttr(Type type);
FunctionAttr getFunctionAttr(Function *value);
FunctionAttr getFunctionAttr(Function value);
FunctionAttr getFunctionAttr(StringRef value);
ElementsAttr getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values);

View File

@ -145,17 +145,13 @@ public:
/// Verify an attribute from this dialect on the given function. Returns
/// failure if the verification failed, success otherwise.
virtual LogicalResult verifyFunctionAttribute(Function *, NamedAttribute) {
return success();
}
virtual LogicalResult verifyFunctionAttribute(Function, NamedAttribute);
/// Verify an attribute from this dialect on the argument at 'argIndex' for
/// the given function. Returns failure if the verification failed, success
/// otherwise.
virtual LogicalResult
verifyFunctionArgAttribute(Function *, unsigned argIndex, NamedAttribute) {
return success();
}
virtual LogicalResult verifyFunctionArgAttribute(Function, unsigned argIndex,
NamedAttribute);
/// Verify an attribute from this dialect on the given operation. Returns
/// failure if the verification failed, success otherwise.

View File

@ -29,29 +29,79 @@
namespace mlir {
class BlockAndValueMapping;
class FunctionType;
class Function;
class MLIRContext;
class Module;
/// This is the base class for all of the MLIR function types.
class Function : public llvm::ilist_node_with_parent<Function, Module> {
public:
Function(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
Function(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<NamedAttributeList> argAttrs);
namespace detail {
/// This class represents all of the internal state of a Function. This allows
/// for the Function class to be value typed.
class FunctionStorage
: public llvm::ilist_node_with_parent<FunctionStorage, Module> {
FunctionStorage(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
FunctionStorage(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<NamedAttributeList> argAttrs);
/// The name of the function.
Identifier name;
/// The module this function is embedded into.
Module *module = nullptr;
/// The source location the function was defined or derived from.
Location getLoc() { return location; }
Location location;
/// The type of the function.
FunctionType type;
/// This holds general named attributes for the function.
NamedAttributeList attrs;
/// The attributes lists for each of the function arguments.
std::vector<NamedAttributeList> argAttrs;
/// The body of the function.
Region body;
friend struct llvm::ilist_traits<FunctionStorage>;
friend Function;
};
} // namespace detail
/// This class represents an MLIR function, or the common unit of computation.
/// The region of a function is not allowed to implicitly capture global values,
/// and all external references must use Function arguments or attributes.
class Function {
public:
Function(detail::FunctionStorage *impl = nullptr) : impl(impl) {}
static Function create(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {}) {
return new detail::FunctionStorage(location, name, type, attrs);
}
static Function create(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<NamedAttributeList> argAttrs) {
return new detail::FunctionStorage(location, name, type, attrs, argAttrs);
}
/// Allow converting a Function to bool for null checks.
operator bool() const { return impl; }
bool operator==(Function other) const { return impl == other.impl; }
bool operator!=(Function other) const { return !(*this == other); }
/// The source location the function was defined or derived from.
Location getLoc() { return impl->location; }
/// Set the source location this function was defined or derived from.
void setLoc(Location loc) { location = loc; }
void setLoc(Location loc) { impl->location = loc; }
/// Return the name of this function, without the @.
Identifier getName() { return name; }
Identifier getName() { return impl->name; }
/// Return the type of this function.
FunctionType getType() { return type; }
FunctionType getType() { return impl->type; }
/// Change the type of this function in place. This is an extremely dangerous
/// operation and it is up to the caller to ensure that this is legal for this
@ -61,12 +111,12 @@ public:
/// parameters we drop the extra attributes, if there are more parameters
/// they won't have any attributes.
void setType(FunctionType newType) {
type = newType;
argAttrs.resize(type.getNumInputs());
impl->type = newType;
impl->argAttrs.resize(newType.getNumInputs());
}
MLIRContext *getContext();
Module *getModule() { return module; }
Module *getModule() { return impl->module; }
/// Add an entry block to an empty function, and set up the block arguments
/// to match the signature of the function.
@ -82,28 +132,28 @@ public:
// Body Handling
//===--------------------------------------------------------------------===//
Region &getBody() { return body; }
void eraseBody() { body.getBlocks().clear(); }
Region &getBody() { return impl->body; }
void eraseBody() { getBody().getBlocks().clear(); }
/// This is the list of blocks in the function.
using RegionType = Region::RegionType;
RegionType &getBlocks() { return body.getBlocks(); }
RegionType &getBlocks() { return getBody().getBlocks(); }
// Iteration over the block in the function.
using iterator = RegionType::iterator;
using reverse_iterator = RegionType::reverse_iterator;
iterator begin() { return body.begin(); }
iterator end() { return body.end(); }
reverse_iterator rbegin() { return body.rbegin(); }
reverse_iterator rend() { return body.rend(); }
iterator begin() { return getBody().begin(); }
iterator end() { return getBody().end(); }
reverse_iterator rbegin() { return getBody().rbegin(); }
reverse_iterator rend() { return getBody().rend(); }
bool empty() { return body.empty(); }
void push_back(Block *block) { body.push_back(block); }
void push_front(Block *block) { body.push_front(block); }
bool empty() { return getBody().empty(); }
void push_back(Block *block) { getBody().push_back(block); }
void push_front(Block *block) { getBody().push_front(block); }
Block &back() { return body.back(); }
Block &front() { return body.front(); }
Block &back() { return getBody().back(); }
Block &front() { return getBody().front(); }
//===--------------------------------------------------------------------===//
// Operation Walkers
@ -150,53 +200,55 @@ public:
/// the lifetime of a function.
/// Return all of the attributes on this function.
ArrayRef<NamedAttribute> getAttrs() { return attrs.getAttrs(); }
ArrayRef<NamedAttribute> getAttrs() { return impl->attrs.getAttrs(); }
/// Return the internal attribute list on this function.
NamedAttributeList &getAttrList() { return attrs; }
NamedAttributeList &getAttrList() { return impl->attrs; }
/// Return all of the attributes for the argument at 'index'.
ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
assert(index < getNumArguments() && "invalid argument number");
return argAttrs[index].getAttrs();
return impl->argAttrs[index].getAttrs();
}
/// Set the attributes held by this function.
void setAttrs(ArrayRef<NamedAttribute> attributes) {
attrs.setAttrs(attributes);
impl->attrs.setAttrs(attributes);
}
/// Set the attributes held by the argument at 'index'.
void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
assert(index < getNumArguments() && "invalid argument number");
argAttrs[index].setAttrs(attributes);
impl->argAttrs[index].setAttrs(attributes);
}
void setArgAttrs(unsigned index, NamedAttributeList attributes) {
assert(index < getNumArguments() && "invalid argument number");
argAttrs[index] = attributes;
impl->argAttrs[index] = attributes;
}
void setAllArgAttrs(ArrayRef<NamedAttributeList> attributes) {
assert(attributes.size() == getNumArguments());
for (unsigned i = 0, e = attributes.size(); i != e; ++i)
argAttrs[i] = attributes[i];
impl->argAttrs[i] = attributes[i];
}
/// Return all argument attributes of this function.
MutableArrayRef<NamedAttributeList> getAllArgAttrs() { return argAttrs; }
MutableArrayRef<NamedAttributeList> getAllArgAttrs() {
return impl->argAttrs;
}
/// Return the specified attribute if present, null otherwise.
Attribute getAttr(Identifier name) { return attrs.get(name); }
Attribute getAttr(StringRef name) { return attrs.get(name); }
Attribute getAttr(Identifier name) { return impl->attrs.get(name); }
Attribute getAttr(StringRef name) { return impl->attrs.get(name); }
/// Return the specified attribute, if present, for the argument at 'index',
/// null otherwise.
Attribute getArgAttr(unsigned index, Identifier name) {
assert(index < getNumArguments() && "invalid argument number");
return argAttrs[index].get(name);
return impl->argAttrs[index].get(name);
}
Attribute getArgAttr(unsigned index, StringRef name) {
assert(index < getNumArguments() && "invalid argument number");
return argAttrs[index].get(name);
return impl->argAttrs[index].get(name);
}
template <typename AttrClass> AttrClass getAttrOfType(Identifier name) {
@ -219,13 +271,15 @@ public:
/// If an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
void setAttr(Identifier name, Attribute value) {
impl->attrs.set(name, value);
}
void setAttr(StringRef name, Attribute value) {
setAttr(Identifier::get(name, getContext()), value);
}
void setArgAttr(unsigned index, Identifier name, Attribute value) {
assert(index < getNumArguments() && "invalid argument number");
argAttrs[index].set(name, value);
impl->argAttrs[index].set(name, value);
}
void setArgAttr(unsigned index, StringRef name, Attribute value) {
setArgAttr(index, Identifier::get(name, getContext()), value);
@ -234,12 +288,12 @@ public:
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
NamedAttributeList::RemoveResult removeAttr(Identifier name) {
return attrs.remove(name);
return impl->attrs.remove(name);
}
NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
Identifier name) {
assert(index < getNumArguments() && "invalid argument number");
return attrs.remove(name);
return impl->attrs.remove(name);
}
//===--------------------------------------------------------------------===//
@ -281,44 +335,37 @@ public:
/// contains entries for function arguments, these arguments are not included
/// in the new function. Replaces references to cloned sub-values with the
/// corresponding value that is copied, and adds those mappings to the mapper.
Function *clone(BlockAndValueMapping &mapper);
Function *clone();
Function clone(BlockAndValueMapping &mapper);
Function clone();
/// Clone the internal blocks and attributes from this function into dest. Any
/// cloned blocks are appended to the back of dest. This function asserts that
/// the attributes of the current function and dest are compatible.
void cloneInto(Function *dest, BlockAndValueMapping &mapper);
void cloneInto(Function dest, BlockAndValueMapping &mapper);
/// Methods for supporting PointerLikeTypeTraits.
const void *getAsOpaquePointer() const {
return static_cast<const void *>(impl);
}
static Function getFromOpaquePointer(const void *pointer) {
return reinterpret_cast<detail::FunctionStorage *>(
const_cast<void *>(pointer));
}
private:
/// Set the name of this function.
void setName(Identifier newName) { name = newName; }
void setName(Identifier newName) { impl->name = newName; }
/// The name of the function.
Identifier name;
/// The module this function is embedded into.
Module *module = nullptr;
/// The source location the function was defined or derived from.
Location location;
/// The type of the function.
FunctionType type;
/// This holds general named attributes for the function.
NamedAttributeList attrs;
/// The attributes lists for each of the function arguments.
std::vector<NamedAttributeList> argAttrs;
/// The body of the function.
Region body;
void operator=(Function &) = delete;
friend struct llvm::ilist_traits<Function>;
/// A pointer to the impl storage instance for this function. This allows for
/// 'Function' to be treated as a value type.
detail::FunctionStorage *impl = nullptr;
// Allow access to 'setName'.
friend class SymbolTable;
// Allow access to 'impl'.
friend class Module;
friend class Region;
};
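
For illustration, a minimal sketch (not part of this change; the helper name and surrounding setup are assumptions) of how call sites work with the value-typed handle:

void exampleUsage(mlir::Module *module, mlir::FunctionType fnType,
                  mlir::Location loc) {
  // Function::create allocates the underlying FunctionStorage and returns a
  // lightweight handle to it.
  mlir::Function fn = mlir::Function::create(loc, "example", fnType);
  module->push_back(fn);

  // Null checks and equality operate on the handle itself.
  if (mlir::Function found = module->getNamedFunction("example")) {
    // Lookup returns a handle to the same underlying storage.
    (void)(found == fn);
  }

  // Round-trip through an opaque pointer, e.g. for C-style bindings.
  const void *opaque = fn.getAsOpaquePointer();
  mlir::Function restored = mlir::Function::getFromOpaquePointer(opaque);
  (void)restored;
}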
//===--------------------------------------------------------------------===//
@ -487,21 +534,52 @@ private:
namespace llvm {
template <>
struct ilist_traits<::mlir::Function>
: public ilist_alloc_traits<::mlir::Function> {
using Function = ::mlir::Function;
using function_iterator = simple_ilist<Function>::iterator;
struct ilist_traits<::mlir::detail::FunctionStorage>
: public ilist_alloc_traits<::mlir::detail::FunctionStorage> {
using FunctionStorage = ::mlir::detail::FunctionStorage;
using function_iterator = simple_ilist<FunctionStorage>::iterator;
static void deleteNode(Function *function) { delete function; }
static void deleteNode(FunctionStorage *function) { delete function; }
void addNodeToList(Function *function);
void removeNodeFromList(Function *function);
void transferNodesFromList(ilist_traits<Function> &otherList,
void addNodeToList(FunctionStorage *function);
void removeNodeFromList(FunctionStorage *function);
void transferNodesFromList(ilist_traits<FunctionStorage> &otherList,
function_iterator first, function_iterator last);
private:
mlir::Module *getContainingModule();
};
} // end namespace llvm
// Functions hash just like pointers.
template <> struct DenseMapInfo<mlir::Function> {
static mlir::Function getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::Function::getFromOpaquePointer(pointer);
}
static mlir::Function getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::Function::getFromOpaquePointer(pointer);
}
static unsigned getHashValue(mlir::Function val) {
return hash_value(val.getAsOpaquePointer());
}
static bool isEqual(mlir::Function LHS, mlir::Function RHS) {
return LHS == RHS;
}
};
/// Allow stealing the low bits of FunctionStorage.
template <> struct PointerLikeTypeTraits<mlir::Function> {
public:
static inline void *getAsVoidPointer(mlir::Function I) {
return const_cast<void *>(I.getAsOpaquePointer());
}
static inline mlir::Function getFromVoidPointer(void *P) {
return mlir::Function::getFromOpaquePointer(P);
}
enum { NumLowBitsAvailable = 3 };
};
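
As an illustration of why these specializations exist (assumed usage, not from the patch), Function now drops into LLVM's pointer-oriented containers just as Function * did:

// DenseMapInfo<Function> hashes and compares the underlying storage pointer.
llvm::DenseMap<mlir::Function, unsigned> opCountPerFunction;
// PointerLikeTypeTraits<Function> exposes low bits for tag packing.
llvm::PointerIntPair<mlir::Function, 1, bool> functionAndFlag;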
} // namespace llvm
#endif // MLIR_IR_FUNCTION_H


@ -34,34 +34,54 @@ public:
MLIRContext *getContext() { return context; }
/// An iterator class used to iterate over the held functions.
class iterator : public llvm::mapped_iterator<
llvm::iplist<detail::FunctionStorage>::iterator,
Function (*)(detail::FunctionStorage &)> {
static Function unwrap(detail::FunctionStorage &impl) { return &impl; }
public:
using reference = Function;
/// Initializes the function iterator from the underlying storage-list iterator.
iterator(llvm::iplist<detail::FunctionStorage>::iterator it)
: llvm::mapped_iterator<llvm::iplist<detail::FunctionStorage>::iterator,
Function (*)(detail::FunctionStorage &)>(
it, &unwrap) {}
iterator(Function it)
: iterator(llvm::iplist<detail::FunctionStorage>::iterator(it.impl)) {}
};
/// This is the list of functions in the module.
using FunctionListType = llvm::iplist<Function>;
FunctionListType &getFunctions() { return functions; }
llvm::iterator_range<iterator> getFunctions() { return {begin(), end()}; }
// Iteration over the functions in the module.
using iterator = FunctionListType::iterator;
using reverse_iterator = FunctionListType::reverse_iterator;
iterator begin() { return functions.begin(); }
iterator end() { return functions.end(); }
reverse_iterator rbegin() { return functions.rbegin(); }
reverse_iterator rend() { return functions.rend(); }
Function front() { return &functions.front(); }
Function back() { return &functions.back(); }
void push_back(Function fn) { functions.push_back(fn.impl); }
void insert(iterator insertPt, Function fn) {
functions.insert(insertPt.getCurrent(), fn.impl);
}
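
A small usage sketch (assumed, not part of the diff; 'module', 'loc', and 'fnType' are presumed in scope): iteration now yields Function values, and insertion takes the handle directly:

for (mlir::Function fn : module->getFunctions()) // by value, no '&'
  llvm::outs() << "func @" << fn.getName() << "\n";
module->push_back(mlir::Function::create(loc, "helper", fnType));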
// Interfaces for working with the symbol table.
/// Look up a function with the specified name, returning null if no such
/// name exists. Function names never include the @ on them. Note: This
/// performs a linear scan of held symbols.
Function *getNamedFunction(StringRef name) {
Function getNamedFunction(StringRef name) {
return getNamedFunction(Identifier::get(name, getContext()));
}
/// Look up a function with the specified name, returning null if no such
/// name exists. Function names never include the @ on them. Note: This
/// performs a linear scan of held symbols.
Function *getNamedFunction(Identifier name) {
auto it = llvm::find_if(
functions, [name](Function &fn) { return fn.getName() == name; });
Function getNamedFunction(Identifier name) {
auto it = llvm::find_if(functions, [name](detail::FunctionStorage &fn) {
return Function(&fn).getName() == name;
});
return it == functions.end() ? nullptr : &*it;
}
@ -74,11 +94,13 @@ public:
void dump();
private:
friend struct llvm::ilist_traits<Function>;
friend class Function;
friend struct llvm::ilist_traits<detail::FunctionStorage>;
friend detail::FunctionStorage;
friend Function;
/// getSublistAccess() - Returns pointer to member of function list
static FunctionListType Module::*getSublistAccess(Function *) {
static llvm::iplist<detail::FunctionStorage> Module::*
getSublistAccess(detail::FunctionStorage *) {
return &Module::functions;
}
@ -86,7 +108,7 @@ private:
MLIRContext *context;
/// This is the actual list of functions the module contains.
FunctionListType functions;
llvm::iplist<detail::FunctionStorage> functions;
};
/// A class used to manage the symbols held by a module. This class handles
@ -98,24 +120,24 @@ public:
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names must never include the @ on them.
template <typename NameTy> Function *getNamedFunction(NameTy &&name) const {
template <typename NameTy> Function getNamedFunction(NameTy &&name) const {
return symbolTable.lookup(name);
}
/// Insert a new symbol into the module, auto-renaming it as necessary.
void insert(Function *function) {
void insert(Function function) {
symbolTable.insert(function);
module->getFunctions().push_back(function);
module->push_back(function);
}
void insert(Module::iterator insertPt, Function *function) {
void insert(Module::iterator insertPt, Function function) {
symbolTable.insert(function);
module->getFunctions().insert(insertPt, function);
module->insert(insertPt, function);
}
/// Remove the given symbol from the module symbol table and then erase it.
void erase(Function *function) {
void erase(Function function) {
symbolTable.erase(function);
function->erase();
function.erase();
}
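
A hedged sketch of the symbol-management flow (the symbol name and surrounding variables are assumptions):

mlir::ModuleManager manager(&module);
if (mlir::Function existing = manager.getNamedFunction("helper"))
  manager.erase(existing); // drops the symbol table entry and erases the IR
manager.insert(mlir::Function::create(loc, "helper", fnType));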
/// Return the internally held module.


@ -128,7 +128,7 @@ public:
/// Returns the function that this operation is part of.
/// The function is determined by traversing the chain of parent operations.
/// Returns nullptr if the operation is unlinked.
Function *getFunction();
Function getFunction();
/// Replace any uses of 'from' with 'to' within this operation.
void replaceUsesOfWith(Value *from, Value *to);


@ -420,7 +420,7 @@ private:
/// patterns in a greedy work-list driven manner. Return true if no more
/// patterns can be matched in the result function.
///
bool applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns);
bool applyPatternsGreedily(Function fn, OwningRewritePatternList &&patterns);
/// Helper class to create a list of rewrite patterns given a list of their
/// types and a list of attributes perfect-forwarded to each of the conversion


@ -27,11 +27,16 @@
namespace mlir {
class BlockAndValueMapping;
namespace detail {
class FunctionStorage;
}
/// This class contains a list of basic blocks and has a notion of the object it
/// is part of - a Function or an Operation.
class Region {
public:
explicit Region(Function *container = nullptr);
Region() = default;
explicit Region(Function container);
explicit Region(Operation *container);
~Region();
@ -77,7 +82,7 @@ public:
/// A Region is either a function body or a part of an operation. If it is
/// a Function body, then return this function, otherwise return null.
Function *getContainingFunction();
Function getContainingFunction();
/// Return true if this region is a proper ancestor of the `other` region.
bool isProperAncestor(Region *other);
@ -118,7 +123,7 @@ private:
RegionType blocks;
/// This is the object we are part of.
llvm::PointerUnion<Function *, Operation *> container;
llvm::PointerUnion<detail::FunctionStorage *, Operation *> container;
};
} // end namespace mlir
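
For example (a sketch; 'region' is assumed to be a Region in scope), callers test the returned handle directly rather than a pointer:

if (mlir::Function parentFn = region.getContainingFunction())
  llvm::errs() << "region belongs to @" << parentFn.getName() << "\n";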


@ -18,7 +18,7 @@
#ifndef MLIR_IR_SYMBOLTABLE_H
#define MLIR_IR_SYMBOLTABLE_H
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Function.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
@ -35,18 +35,18 @@ public:
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
Function *lookup(StringRef name) const;
Function lookup(StringRef name) const;
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
Function *lookup(Identifier name) const;
Function lookup(Identifier name) const;
/// Erase the given symbol from the table.
void erase(Function *symbol);
void erase(Function symbol);
/// Insert a new symbol into the table, and rename it as necessary to avoid
/// collisions.
void insert(Function *symbol);
void insert(Function symbol);
/// Returns the context held by this symbol table.
MLIRContext *getContext() const { return context; }
@ -55,7 +55,7 @@ private:
MLIRContext *context;
/// This is a mapping from a name to the function with that name.
llvm::DenseMap<Identifier, Function *> symbolTable;
llvm::DenseMap<Identifier, Function> symbolTable;
/// This is used when name conflicts are detected.
unsigned uniquingCounter = 0;


@ -72,7 +72,7 @@ public:
}
/// Return the function that this Value is defined in.
Function *getFunction();
Function getFunction();
/// If this value is the result of an operation, return the operation that
/// defines it.
@ -128,7 +128,7 @@ public:
}
/// Return the function that this argument is defined in.
Function *getFunction();
Function getFunction();
Block *getOwner() { return owner; }


@ -153,7 +153,7 @@ public:
/// Verify a function argument attribute registered to this dialect.
/// Returns failure if the verification failed, success otherwise.
LogicalResult verifyFunctionArgAttribute(Function *func, unsigned argIdx,
LogicalResult verifyFunctionArgAttribute(Function func, unsigned argIdx,
NamedAttribute argAttr) override;
private:


@ -106,7 +106,7 @@ template <typename IRUnitT> class AnalysisMap {
}
public:
explicit AnalysisMap(IRUnitT *ir) : ir(ir) {}
explicit AnalysisMap(IRUnitT ir) : ir(ir) {}
/// Get an analysis for the current IR unit, computing it if necessary.
template <typename AnalysisT> AnalysisT &getAnalysis(PassInstrumentor *pi) {
@ -140,8 +140,8 @@ public:
}
/// Returns the IR unit that this analysis map represents.
IRUnitT *getIRUnit() { return ir; }
const IRUnitT *getIRUnit() const { return ir; }
IRUnitT getIRUnit() { return ir; }
const IRUnitT getIRUnit() const { return ir; }
/// Clear any held analyses.
void clear() { analyses.clear(); }
@ -158,7 +158,7 @@ public:
}
private:
IRUnitT *ir;
IRUnitT ir;
ConceptMap analyses;
};
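
Illustrative instantiations (assumed usage): function analyses are now keyed on the value-typed handle, while module analyses still go through a pointer:

mlir::detail::AnalysisMap<mlir::Function> functionAnalyses(fn);
mlir::detail::AnalysisMap<mlir::Module *> moduleAnalyses(&module);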
@ -231,14 +231,14 @@ public:
/// Query for the analysis of a function. The analysis is computed if it does
/// not exist.
template <typename AnalysisT>
AnalysisT &getFunctionAnalysis(Function *function) {
AnalysisT &getFunctionAnalysis(Function function) {
return slice(function).getAnalysis<AnalysisT>();
}
/// Query for a cached analysis of a child function, or return null.
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>>
getCachedFunctionAnalysis(Function *function) const {
getCachedFunctionAnalysis(Function function) const {
auto it = functionAnalyses.find(function);
if (it == functionAnalyses.end())
return llvm::None;
@ -258,7 +258,7 @@ public:
}
/// Create an analysis slice for the given child function.
FunctionAnalysisManager slice(Function *function);
FunctionAnalysisManager slice(Function function);
/// Invalidate any non preserved analyses.
void invalidate(const detail::PreservedAnalyses &pa);
@ -269,11 +269,11 @@ public:
private:
/// The cached analyses for functions within the current module.
llvm::DenseMap<Function *, std::unique_ptr<detail::AnalysisMap<Function>>>
llvm::DenseMap<Function, std::unique_ptr<detail::AnalysisMap<Function>>>
functionAnalyses;
/// The analyses for the owning module.
detail::AnalysisMap<Module> moduleAnalyses;
detail::AnalysisMap<Module *> moduleAnalyses;
/// An optional instrumentation object.
PassInstrumentor *passInstrumentor;


@ -70,12 +70,12 @@ class ModulePassExecutor;
/// interface for accessing and initializing necessary state for pass execution.
template <typename IRUnitT, typename AnalysisManagerT>
struct PassExecutionState {
PassExecutionState(IRUnitT *ir, AnalysisManagerT &analysisManager)
PassExecutionState(IRUnitT ir, AnalysisManagerT &analysisManager)
: irAndPassFailed(ir, false), analysisManager(analysisManager) {}
/// The current IR unit being transformed and a bool for if the pass signaled
/// a failure.
llvm::PointerIntPair<IRUnitT *, 1, bool> irAndPassFailed;
llvm::PointerIntPair<IRUnitT, 1, bool> irAndPassFailed;
/// The analysis manager for the IR unit.
AnalysisManagerT &analysisManager;
@ -107,9 +107,7 @@ protected:
virtual FunctionPassBase *clone() const = 0;
/// Return the current function being transformed.
Function &getFunction() {
return *getPassState().irAndPassFailed.getPointer();
}
Function getFunction() { return getPassState().irAndPassFailed.getPointer(); }
/// Return the MLIR context for the current function being transformed.
MLIRContext &getContext() { return *getFunction().getContext(); }
@ -128,7 +126,7 @@ protected:
private:
/// Forwarding function to execute this pass.
LLVM_NODISCARD
LogicalResult run(Function *fn, FunctionAnalysisManager &fam);
LogicalResult run(Function fn, FunctionAnalysisManager &fam);
/// The current execution state for the pass.
llvm::Optional<PassStateT> passState;
@ -140,7 +138,8 @@ private:
/// Pass to transform a module. Derived passes should not inherit from this
/// class directly, and instead should use the CRTP ModulePass class.
class ModulePassBase : public Pass {
using PassStateT = detail::PassExecutionState<Module, ModuleAnalysisManager>;
using PassStateT =
detail::PassExecutionState<Module *, ModuleAnalysisManager>;
public:
static bool classof(const Pass *pass) {
@ -272,7 +271,7 @@ struct FunctionPass : public detail::PassModel<Function, T, FunctionPassBase> {
template <typename T>
struct ModulePass : public detail::PassModel<Module, T, ModulePassBase> {
/// Returns the analysis for a child function.
template <typename AnalysisT> AnalysisT &getFunctionAnalysis(Function *f) {
template <typename AnalysisT> AnalysisT &getFunctionAnalysis(Function f) {
return this->getAnalysisManager().template getFunctionAnalysis<AnalysisT>(
f);
}
@ -280,7 +279,7 @@ struct ModulePass : public detail::PassModel<Module, T, ModulePassBase> {
/// Returns an existing analysis for a child function if it exists.
template <typename AnalysisT>
llvm::Optional<std::reference_wrapper<AnalysisT>>
getCachedFunctionAnalysis(Function *f) {
getCachedFunctionAnalysis(Function f) {
return this->getAnalysisManager()
.template getCachedFunctionAnalysis<AnalysisT>(f);
}
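
A minimal sketch (the pass and analysis choice are assumptions) of a module pass built on these value-typed hooks:

struct ExampleModulePass : public mlir::ModulePass<ExampleModulePass> {
  void runOnModule() override {
    for (mlir::Function fn : getModule()) {
      // Child-function analyses are now keyed on the Function handle.
      auto &domInfo = getFunctionAnalysis<mlir::DominanceInfo>(fn);
      (void)domInfo;
    }
  }
};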


@ -77,29 +77,29 @@ public:
~PassInstrumentor();
/// See PassInstrumentation::runBeforePass for details.
template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT *ir) {
template <typename IRUnitT> void runBeforePass(Pass *pass, IRUnitT ir) {
runBeforePass(pass, llvm::Any(ir));
}
/// See PassInstrumentation::runAfterPass for details.
template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT *ir) {
template <typename IRUnitT> void runAfterPass(Pass *pass, IRUnitT ir) {
runAfterPass(pass, llvm::Any(ir));
}
/// See PassInstrumentation::runAfterPassFailed for details.
template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT *ir) {
template <typename IRUnitT> void runAfterPassFailed(Pass *pass, IRUnitT ir) {
runAfterPassFailed(pass, llvm::Any(ir));
}
/// See PassInstrumentation::runBeforeAnalysis for details.
template <typename IRUnitT>
void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) {
void runBeforeAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
runBeforeAnalysis(name, id, llvm::Any(ir));
}
/// See PassInstrumentation::runAfterAnalysis for details.
template <typename IRUnitT>
void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT *ir) {
void runAfterAnalysis(llvm::StringRef name, AnalysisID *id, IRUnitT ir) {
runAfterAnalysis(name, id, llvm::Any(ir));
}


@ -214,11 +214,11 @@ def CallOp : Std_Op<"call"> {
let results = (outs Variadic<AnyType>);
let builders = [OpBuilder<
"Builder *builder, OperationState *result, Function *callee,"
"Builder *builder, OperationState *result, Function callee,"
"ArrayRef<Value *> operands = {}", [{
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(callee->getType().getResults());
result->addTypes(callee.getType().getResults());
}]>, OpBuilder<
"Builder *builder, OperationState *result, StringRef callee,"
"ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{


@ -345,7 +345,7 @@ LLVM_NODISCARD LogicalResult applyConversionPatterns(
/// Convert the given functions with the provided conversion patterns. This
/// function returns failure if a type conversion failed.
LLVM_NODISCARD
LogicalResult applyConversionPatterns(ArrayRef<Function *> fns,
LogicalResult applyConversionPatterns(MutableArrayRef<Function> fns,
ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns);
@ -354,7 +354,7 @@ LogicalResult applyConversionPatterns(ArrayRef<Function *> fns,
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LLVM_NODISCARD
LogicalResult applyConversionPatterns(Function &fn, ConversionTarget &target,
LogicalResult applyConversionPatterns(Function fn, ConversionTarget &target,
OwningRewritePatternList &&patterns);
} // end namespace mlir


@ -37,7 +37,7 @@ Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
/// Convert from the Affine dialect to the Standard dialect, in particular
/// convert structured affine control flow into CFG branch-based control flow.
LogicalResult lowerAffineConstructs(Function &function);
LogicalResult lowerAffineConstructs(Function function);
/// Emit code that computes the lower bound of the given affine loop using
/// standard arithmetic operations.


@ -33,11 +33,11 @@ class FunctionPassBase;
/// Displays the CFG in a window. This is for use from the debugger and
/// depends on Graphviz to generate the graph.
void viewGraph(Function &function, const Twine &name, bool shortNames = false,
void viewGraph(Function function, const Twine &name, bool shortNames = false,
const Twine &title = "",
llvm::GraphProgram::Name program = llvm::GraphProgram::DOT);
llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function &function,
llvm::raw_ostream &writeGraph(llvm::raw_ostream &os, Function function,
bool shortNames = false, const Twine &title = "");
/// Creates a pass to print CFG graphs.


@ -303,7 +303,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
if (inserted) {
reorderedDims.push_back(v);
}
return getAffineDimExpr(iterPos->second, v->getFunction()->getContext())
return getAffineDimExpr(iterPos->second, v->getFunction().getContext())
.cast<AffineDimExpr>();
}


@ -37,17 +37,16 @@ template class llvm::DomTreeNodeBase<Block>;
/// Recalculate the dominance info.
template <bool IsPostDom>
void DominanceInfoBase<IsPostDom>::recalculate(Function *function) {
void DominanceInfoBase<IsPostDom>::recalculate(Function function) {
dominanceInfos.clear();
// Build the top level function dominance.
auto functionDominance = llvm::make_unique<base>();
functionDominance->recalculate(function->getBody());
dominanceInfos.try_emplace(&function->getBody(),
std::move(functionDominance));
functionDominance->recalculate(function.getBody());
dominanceInfos.try_emplace(&function.getBody(), std::move(functionDominance));
/// Build the dominance for each of the operation regions.
function->walk([&](Operation *op) {
function.walk([&](Operation *op) {
for (auto &region : op->getRegions()) {
// Don't compute dominance if the region is empty.
if (region.empty())


@ -45,7 +45,7 @@ void PrintOpStatsPass::runOnModule() {
opCount.clear();
// Compute the operation statistics for each function in the module.
for (auto &fn : getModule())
for (auto fn : getModule())
fn.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
printSummary();
}


@ -43,7 +43,7 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() {
// Walks the function and emits a note for all 'affine.for' ops detected as
// parallel.
void TestParallelismDetection::runOnFunction() {
Function &f = getFunction();
Function f = getFunction();
OpBuilder b(f.getBody());
f.walk<AffineForOp>([&](AffineForOp forOp) {
if (isLoopParallel(forOp))


@ -53,7 +53,7 @@ public:
: ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {}
/// Verify the body of the given function.
LogicalResult verify(Function &fn);
LogicalResult verify(Function fn);
/// Verify the given operation.
LogicalResult verify(Operation &op);
@ -104,7 +104,7 @@ private:
} // end anonymous namespace
/// Verify the body of the given function.
LogicalResult OperationVerifier::verify(Function &fn) {
LogicalResult OperationVerifier::verify(Function fn) {
// Verify the body first.
if (failed(verifyRegion(fn.getBody())))
return failure();
@ -113,7 +113,7 @@ LogicalResult OperationVerifier::verify(Function &fn) {
// check. We do this as a second pass since malformed CFGs can cause
// dominator analysis construction to crash and we want the verifier to be
// resilient to malformed code.
DominanceInfo theDomInfo(&fn);
DominanceInfo theDomInfo(fn);
domInfo = &theDomInfo;
if (failed(verifyDominance(fn.getBody())))
return failure();
@ -313,7 +313,7 @@ LogicalResult Function::verify() {
// Verify this attribute with the defining dialect.
if (auto *dialect = opVerifier.getDialectForAttribute(attr))
if (failed(dialect->verifyFunctionAttribute(this, attr)))
if (failed(dialect->verifyFunctionAttribute(*this, attr)))
return failure();
}
@ -331,7 +331,7 @@ LogicalResult Function::verify() {
// Verify this attribute with the defining dialect.
if (auto *dialect = opVerifier.getDialectForAttribute(attr))
if (failed(dialect->verifyFunctionArgAttribute(this, i, attr)))
if (failed(dialect->verifyFunctionArgAttribute(*this, i, attr)))
return failure();
}
}
@ -369,7 +369,7 @@ LogicalResult Operation::verify() {
LogicalResult Module::verify() {
// Check that all functions are uniquely named.
llvm::StringMap<Location> nameToOrigLoc;
for (auto &fn : *this) {
for (auto fn : *this) {
auto it = nameToOrigLoc.try_emplace(fn.getName(), fn.getLoc());
if (!it.second)
return fn.emitError()
@ -379,7 +379,7 @@ LogicalResult Module::verify() {
}
// Check that each function is correct.
for (auto &fn : *this)
for (auto fn : *this)
if (failed(fn.verify()))
return failure();


@ -64,8 +64,8 @@ public:
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
for (auto &function : getModule()) {
if (!gpu::GPUDialect::isKernel(&function) || function.isExternal()) {
for (auto function : getModule()) {
if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) {
continue;
}
if (failed(translateGpuKernelToCubinAnnotation(function)))
@ -142,7 +142,7 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) {
std::unique_ptr<Module> module(builder.createModule());
// TODO(herhut): Also handle called functions.
module->getFunctions().push_back(function.clone());
module->push_back(function.clone());
auto llvmModule = translateModuleToNVVMIR(*module);
auto cubin = convertModuleToCubin(*llvmModule, function);


@ -118,7 +118,7 @@ private:
void declareCudaFunctions(Location loc);
Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
Value *generateKernelNameConstant(Function *kernelFunction, Location &loc,
Value *generateKernelNameConstant(Function kernelFunction, Location &loc,
OpBuilder &builder);
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
@ -130,7 +130,7 @@ public:
// Cache the used LLVM types.
initializeCachedTypes();
for (auto &func : getModule()) {
for (auto func : getModule()) {
func.walk<mlir::gpu::LaunchFuncOp>(
[this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
}
@ -155,66 +155,66 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
Module &module = getModule();
Builder builder(&module);
if (!module.getNamedFunction(cuModuleLoadName)) {
module.getFunctions().push_back(
new Function(loc, cuModuleLoadName,
builder.getFunctionType(
{
getPointerPointerType(), /* CUmodule *module */
getPointerType() /* void *cubin */
},
getCUResultType())));
module.push_back(
Function::create(loc, cuModuleLoadName,
builder.getFunctionType(
{
getPointerPointerType(), /* CUmodule *module */
getPointerType() /* void *cubin */
},
getCUResultType())));
}
if (!module.getNamedFunction(cuModuleGetFunctionName)) {
// The helper uses void* instead of CUDA's opaque CUmodule and
// CUfunction.
module.getFunctions().push_back(
new Function(loc, cuModuleGetFunctionName,
builder.getFunctionType(
{
getPointerPointerType(), /* void **function */
getPointerType(), /* void *module */
getPointerType() /* char *name */
},
getCUResultType())));
module.push_back(
Function::create(loc, cuModuleGetFunctionName,
builder.getFunctionType(
{
getPointerPointerType(), /* void **function */
getPointerType(), /* void *module */
getPointerType() /* char *name */
},
getCUResultType())));
}
if (!module.getNamedFunction(cuLaunchKernelName)) {
// Other than the CUDA API, the wrappers use uintptr_t to match the
// LLVM type of MLIR's index type, which the GPU dialect uses.
// Furthermore, they use void* instead of CUDA's opaque CUfunction and
// CUstream.
module.getFunctions().push_back(
new Function(loc, cuLaunchKernelName,
builder.getFunctionType(
{
getPointerType(), /* void* f */
getIntPtrType(), /* intptr_t gridXDim */
getIntPtrType(), /* intptr_t gridyDim */
getIntPtrType(), /* intptr_t gridZDim */
getIntPtrType(), /* intptr_t blockXDim */
getIntPtrType(), /* intptr_t blockYDim */
getIntPtrType(), /* intptr_t blockZDim */
getInt32Type(), /* unsigned int sharedMemBytes */
getPointerType(), /* void *hstream */
getPointerPointerType(), /* void **kernelParams */
getPointerPointerType() /* void **extra */
},
getCUResultType())));
module.push_back(Function::create(
loc, cuLaunchKernelName,
builder.getFunctionType(
{
getPointerType(), /* void* f */
getIntPtrType(), /* intptr_t gridXDim */
getIntPtrType(), /* intptr_t gridyDim */
getIntPtrType(), /* intptr_t gridZDim */
getIntPtrType(), /* intptr_t blockXDim */
getIntPtrType(), /* intptr_t blockYDim */
getIntPtrType(), /* intptr_t blockZDim */
getInt32Type(), /* unsigned int sharedMemBytes */
getPointerType(), /* void *hstream */
getPointerPointerType(), /* void **kernelParams */
getPointerPointerType() /* void **extra */
},
getCUResultType())));
}
if (!module.getNamedFunction(cuGetStreamHelperName)) {
// Helper function to get the current CUDA stream. Uses void* instead of
// CUDA's opaque CUstream.
module.getFunctions().push_back(new Function(
module.push_back(Function::create(
loc, cuGetStreamHelperName,
builder.getFunctionType({}, getPointerType() /* void *stream */)));
}
if (!module.getNamedFunction(cuStreamSynchronizeName)) {
module.getFunctions().push_back(
new Function(loc, cuStreamSynchronizeName,
builder.getFunctionType(
{
getPointerType() /* CUstream stream */
},
getCUResultType())));
module.push_back(
Function::create(loc, cuStreamSynchronizeName,
builder.getFunctionType(
{
getPointerType() /* CUstream stream */
},
getCUResultType())));
}
}
@ -264,14 +264,14 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
// %0[n] = constant name[n]
// %0[n+1] = 0
Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
Function *kernelFunction, Location &loc, OpBuilder &builder) {
Function kernelFunction, Location &loc, OpBuilder &builder) {
// TODO(herhut): Make this a constant once this is supported.
auto kernelNameSize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
builder.getI32IntegerAttr(kernelFunction->getName().size() + 1));
builder.getI32IntegerAttr(kernelFunction.getName().size() + 1));
auto kernelName =
builder.create<LLVM::AllocaOp>(loc, getPointerType(), kernelNameSize);
for (auto byte : llvm::enumerate(kernelFunction->getName())) {
for (auto byte : llvm::enumerate(kernelFunction.getName())) {
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(byte.index()));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
@ -284,7 +284,7 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
// Add trailing zero to terminate string.
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
builder.getI32IntegerAttr(kernelFunction->getName().size()));
builder.getI32IntegerAttr(kernelFunction.getName().size()));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
ArrayRef<Value *>{index});
auto value = builder.create<LLVM::ConstantOp>(
@ -326,9 +326,9 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// TODO(herhut): This should rather be a static global once supported.
auto kernelFunction = getModule().getNamedFunction(launchOp.kernel());
auto cubinGetter =
kernelFunction->getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
if (!cubinGetter) {
kernelFunction->emitError("Missing ")
kernelFunction.emitError("Missing ")
<< kCubinGetterAnnotation << " attribute.";
return signalPassFailure();
}
@ -337,7 +337,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// Emit the load module call to load the module data. Error checking is done
// in the called helper function.
auto cuModule = allocatePointer(builder, loc);
Function *cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName);
Function cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleLoad),
ArrayRef<Value *>{cuModule, data.getResult(0)});
@ -347,14 +347,14 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
auto cuFunction = allocatePointer(builder, loc);
Function *cuModuleGetFunction =
Function cuModuleGetFunction =
getModule().getNamedFunction(cuModuleGetFunctionName);
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleGetFunction),
ArrayRef<Value *>{cuFunction, cuModuleRef, kernelName});
// Grab the global stream needed for execution.
Function *cuGetStreamHelper =
Function cuGetStreamHelper =
getModule().getNamedFunction(cuGetStreamHelperName);
auto cuStream = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()},


@ -53,15 +53,15 @@ constexpr const char *kMallocHelperName = "mcuMalloc";
class GpuGenerateCubinAccessorsPass
: public ModulePass<GpuGenerateCubinAccessorsPass> {
private:
Function *getMallocHelper(Location loc, Builder &builder) {
Function *result = getModule().getNamedFunction(kMallocHelperName);
Function getMallocHelper(Location loc, Builder &builder) {
Function result = getModule().getNamedFunction(kMallocHelperName);
if (!result) {
result = new Function(
result = Function::create(
loc, kMallocHelperName,
builder.getFunctionType(
ArrayRef<Type>{LLVM::LLVMType::getInt32Ty(llvmDialect)},
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
getModule().getFunctions().push_back(result);
getModule().push_back(result);
}
return result;
}
@ -70,18 +70,18 @@ private:
// data from blob. As there are currently no global constants, this uses a
// sequence of store operations.
// TODO(herhut): Use global constants instead.
Function *generateCubinAccessor(Builder &builder, Function &orig,
StringAttr blob) {
Function generateCubinAccessor(Builder &builder, Function &orig,
StringAttr blob) {
Location loc = orig.getLoc();
SmallString<128> nameBuffer(orig.getName());
nameBuffer.append(kCubinGetterSuffix);
// Generate a function that returns void*.
Function *result = new Function(
Function result = Function::create(
loc, mlir::Identifier::get(nameBuffer, &getContext()),
builder.getFunctionType(ArrayRef<Type>{},
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
// Insert a body block that just returns the constant.
OpBuilder ob(result->getBody());
OpBuilder ob(result.getBody());
ob.createBlock();
auto sizeConstant = ob.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt32Ty(llvmDialect),
@ -115,18 +115,18 @@ public:
void runOnModule() override {
llvmDialect =
getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
Builder builder(getModule().getContext());
auto &module = getModule();
Builder builder(&getContext());
auto &functions = getModule().getFunctions();
auto functions = module.getFunctions();
for (auto it = functions.begin(); it != functions.end();) {
// Move iterator to after the current function so that potential insertion
// of the accessor is after the kernel with the cubin itself.
Function &orig = *it++;
Function orig = *it++;
StringAttr cubinBlob = orig.getAttrOfType<StringAttr>(kCubinAnnotation);
if (!cubinBlob)
continue;
it =
functions.insert(it, generateCubinAccessor(builder, orig, cubinBlob));
module.insert(it, generateCubinAccessor(builder, orig, cubinBlob));
}
}


@ -441,13 +441,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
createIndexConstant(rewriter, op->getLoc(), elementSize)});
// Insert the `malloc` declaration if it is not already present.
Function *mallocFunc =
op->getFunction()->getModule()->getNamedFunction("malloc");
Function mallocFunc =
op->getFunction().getModule()->getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType =
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
op->getFunction()->getModule()->getFunctions().push_back(mallocFunc);
mallocFunc =
Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
op->getFunction().getModule()->push_back(mallocFunc);
}
// Allocate the underlying buffer and store a pointer to it in the MemRef
@ -502,12 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
Function *freeFunc =
op->getFunction()->getModule()->getNamedFunction("free");
Function freeFunc = op->getFunction().getModule()->getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
op->getFunction().getModule()->push_back(freeFunc);
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
@ -937,7 +937,7 @@ static void ensureDistinctSuccessors(Block &bb) {
}
void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
for (auto &f : *m) {
for (auto f : *m) {
for (auto &bb : f.getBlocks()) {
::ensureDistinctSuccessors(bb);
}


@ -365,7 +365,7 @@ struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
//===----------------------------------------------------------------------===//
void LowerUniformRealMathPass::runOnFunction() {
auto &fn = getFunction();
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
@ -386,7 +386,7 @@ static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
//===----------------------------------------------------------------------===//
void LowerUniformCastsPass::runOnFunction() {
auto &fn = getFunction();
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));


@ -106,7 +106,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
void ConvertConstPass::runOnFunction() {
OwningRewritePatternList patterns;
auto &func = getFunction();
auto func = getFunction();
auto *context = &getContext();
patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
applyPatternsGreedily(func, std::move(patterns));


@ -95,7 +95,7 @@ public:
void ConvertSimulatedQuantPass::runOnFunction() {
bool hadFailure = false;
OwningRewritePatternList patterns;
auto &func = getFunction();
auto func = getFunction();
auto *context = &getContext();
patterns.push_back(
llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure));


@ -67,10 +67,10 @@ allocMemRefDescriptor(Type type, bool allocateData = true,
}
llvm::Expected<SmallVector<void *, 8>>
mlir::allocateMemRefArguments(Function *func, float initialValue) {
mlir::allocateMemRefArguments(Function func, float initialValue) {
SmallVector<void *, 8> args;
args.reserve(func->getNumArguments());
for (const auto &arg : func->getArguments()) {
args.reserve(func.getNumArguments());
for (const auto &arg : func.getArguments()) {
auto descriptor =
allocMemRefDescriptor(arg->getType(),
/*allocateData=*/true, initialValue);
@ -79,10 +79,10 @@ mlir::allocateMemRefArguments(Function *func, float initialValue) {
args.push_back(*descriptor);
}
if (func->getType().getNumResults() > 1)
if (func.getType().getNumResults() > 1)
return make_string_error("functions with more than 1 result not supported");
for (Type resType : func->getType().getResults()) {
for (Type resType : func.getType().getResults()) {
auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false);
if (!descriptor)
return descriptor.takeError();


@ -30,9 +30,9 @@ using namespace mlir::gpu;
StringRef GPUDialect::getDialectName() { return "gpu"; }
bool GPUDialect::isKernel(Function *function) {
bool GPUDialect::isKernel(Function function) {
UnitAttr isKernelAttr =
function->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
return static_cast<bool>(isKernelAttr);
}
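
For illustration (assuming a Function 'fn' is in scope), a function is marked as a kernel by attaching the unit attribute that isKernel tests for:

mlir::Builder builder(fn.getContext());
fn.setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(), builder.getUnitAttr());
// GPUDialect::isKernel(fn) now returns true.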
@ -318,7 +318,7 @@ ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
//===----------------------------------------------------------------------===//
void LaunchFuncOp::build(Builder *builder, OperationState *result,
Function *kernelFunc, Value *gridSizeX,
Function kernelFunc, Value *gridSizeX,
Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
Value *blockSizeY, Value *blockSizeZ,
ArrayRef<Value *> kernelOperands) {
@ -331,7 +331,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
}
void LaunchFuncOp::build(Builder *builder, OperationState *result,
Function *kernelFunc, KernelDim3 gridSize,
Function kernelFunc, KernelDim3 gridSize,
KernelDim3 blockSize,
ArrayRef<Value *> kernelOperands) {
build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
@ -366,23 +366,23 @@ LogicalResult LaunchFuncOp::verify() {
return emitOpError("attribute 'kernel' must be a function");
}
auto *module = getOperation()->getFunction()->getModule();
Function *kernelFunc = module->getNamedFunction(kernel());
auto *module = getOperation()->getFunction().getModule();
Function kernelFunc = module->getNamedFunction(kernel());
if (!kernelFunc)
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
if (!kernelFunc.getAttrOfType<mlir::UnitAttr>(
GPUDialect::getKernelFuncAttrName())) {
return emitError("kernel function is missing the '")
<< GPUDialect::getKernelFuncAttrName() << "' attribute";
}
unsigned numKernelFuncArgs = kernelFunc->getNumArguments();
unsigned numKernelFuncArgs = kernelFunc.getNumArguments();
if (getNumKernelOperands() != numKernelFuncArgs) {
return emitOpError("got ")
<< getNumKernelOperands() << " kernel operands but expected "
<< numKernelFuncArgs;
}
auto functionType = kernelFunc->getType();
auto functionType = kernelFunc.getType();
for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
return emitOpError("type of function argument ")


@ -40,7 +40,7 @@ static void createForAllDimensions(OpBuilder &builder, Location loc,
// Add operations generating block/thread ids and grid/block dimensions at the
// beginning of `kernelFunc` and replace uses of the respective function args.
static void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
static void injectGpuIndexOperations(Location loc, Function kernelFunc) {
OpBuilder OpBuilder(kernelFunc.getBody());
SmallVector<Value *, 12> indexOps;
createForAllDimensions<gpu::BlockId>(OpBuilder, loc, indexOps);
@ -58,20 +58,20 @@ static void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
// Outline the `gpu.launch` operation body into a kernel function. Replace
// `gpu.return` operations by `std.return` in the generated functions.
static Function *outlineKernelFunc(gpu::LaunchOp launchOp) {
static Function outlineKernelFunc(gpu::LaunchOp launchOp) {
Location loc = launchOp.getLoc();
SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
FunctionType type =
FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
std::string kernelFuncName =
Twine(launchOp.getOperation()->getFunction()->getName(), "_kernel").str();
Function *outlinedFunc = new mlir::Function(loc, kernelFuncName, type);
outlinedFunc->getBody().takeBody(launchOp.getBody());
Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str();
Function outlinedFunc = Function::create(loc, kernelFuncName, type);
outlinedFunc.getBody().takeBody(launchOp.getBody());
Builder builder(launchOp.getContext());
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
injectGpuIndexOperations(loc, *outlinedFunc);
outlinedFunc->walk<mlir::gpu::Return>([](mlir::gpu::Return op) {
outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
injectGpuIndexOperations(loc, outlinedFunc);
outlinedFunc.walk<mlir::gpu::Return>([](mlir::gpu::Return op) {
OpBuilder replacer(op);
replacer.create<ReturnOp>(op.getLoc());
op.erase();
@ -82,12 +82,12 @@ static Function *outlineKernelFunc(gpu::LaunchOp launchOp) {
// Replace `gpu.launch` operations with a `gpu.launch_func` operation launching
// `kernelFunc`.
static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp,
Function &kernelFunc) {
Function kernelFunc) {
OpBuilder builder(launchOp);
SmallVector<Value *, 4> kernelOperandValues(
launchOp.getKernelOperandValues());
builder.create<gpu::LaunchFuncOp>(
launchOp.getLoc(), &kernelFunc, launchOp.getGridSizeOperandValues(),
launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
launchOp.getBlockSizeOperandValues(), kernelOperandValues);
launchOp.erase();
}
@ -98,11 +98,11 @@ class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
public:
void runOnModule() override {
ModuleManager moduleManager(&getModule());
for (auto &func : getModule()) {
for (auto func : getModule()) {
func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
Function *outlinedFunc = outlineKernelFunc(op);
Function outlinedFunc = outlineKernelFunc(op);
moduleManager.insert(outlinedFunc);
convertToLaunchFuncOp(op, *outlinedFunc);
convertToLaunchFuncOp(op, outlinedFunc);
});
}
}


@ -306,7 +306,7 @@ void ModuleState::initialize(Module *module) {
initializeSymbolAliases();
// Walk the module and visit each operation.
for (auto &fn : *module) {
for (auto fn : *module) {
visitType(fn.getType());
for (auto attr : fn.getAttrs())
ModuleState::visitAttribute(attr.second);
@ -342,7 +342,7 @@ public:
void printAttribute(Attribute attr, bool mayElideType = false);
void printType(Type type);
void print(Function *fn);
void print(Function fn);
void printLocation(LocationAttr loc);
void printAffineMap(AffineMap map);
@ -460,8 +460,8 @@ void ModulePrinter::print(Module *module) {
state.printTypeAliases(os);
// Print the module.
for (auto &fn : *module)
print(&fn);
for (auto fn : *module)
print(fn);
}
/// Print a floating point value in a way that the parser will be able to
@ -1186,7 +1186,7 @@ namespace {
// CFG and ML functions.
class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
public:
FunctionPrinter(Function *function, ModulePrinter &other);
FunctionPrinter(Function function, ModulePrinter &other);
// Prints the function as a whole.
void print();
@ -1275,7 +1275,7 @@ protected:
void printValueID(Value *value, bool printResultNo = true) const;
private:
Function *function;
Function function;
/// This is the value ID for each SSA value in the current function. If this
/// returns ~0, then the valueID has an entry in valueNames.
@ -1305,10 +1305,10 @@ private:
};
} // end anonymous namespace
FunctionPrinter::FunctionPrinter(Function *function, ModulePrinter &other)
FunctionPrinter::FunctionPrinter(Function function, ModulePrinter &other)
: ModulePrinter(other), function(function) {
for (auto &block : *function)
for (auto &block : function)
numberValuesInBlock(block);
}
@ -1419,17 +1419,17 @@ void FunctionPrinter::print() {
printFunctionSignature();
// Print out function attributes, if present.
auto attrs = function->getAttrs();
auto attrs = function.getAttrs();
if (!attrs.empty()) {
os << "\n attributes ";
printOptionalAttrDict(attrs);
}
// Print the trailing location.
printTrailingLocation(function->getLoc());
printTrailingLocation(function.getLoc());
if (!function->empty()) {
printRegion(function->getBody(), /*printEntryBlockArgs=*/false,
if (!function.empty()) {
printRegion(function.getBody(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
os << "\n";
}
@ -1437,24 +1437,24 @@ void FunctionPrinter::print() {
}
void FunctionPrinter::printFunctionSignature() {
os << "func @" << function->getName() << '(';
os << "func @" << function.getName() << '(';
auto fnType = function->getType();
bool isExternal = function->isExternal();
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
auto fnType = function.getType();
bool isExternal = function.isExternal();
for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i) {
if (i > 0)
os << ", ";
// If this is an external function, don't print argument labels.
if (!isExternal) {
printOperand(function->getArgument(i));
printOperand(function.getArgument(i));
os << ": ";
}
printType(fnType.getInput(i));
// Print the attributes for this argument.
printOptionalAttrDict(function->getArgAttrs(i));
printOptionalAttrDict(function.getArgAttrs(i));
}
os << ')';
@ -1662,7 +1662,7 @@ void FunctionPrinter::printSuccessorAndUseList(Operation *term,
}
// Prints function with initialized module state.
void ModulePrinter::print(Function *fn) { FunctionPrinter(fn, *this).print(); }
void ModulePrinter::print(Function fn) { FunctionPrinter(fn, *this).print(); }
//===----------------------------------------------------------------------===//
// print and dump methods
@ -1737,13 +1737,13 @@ void Value::print(raw_ostream &os) {
void Value::dump() { print(llvm::errs()); }
void Operation::print(raw_ostream &os) {
auto *function = getFunction();
auto function = getFunction();
if (!function) {
os << "<<UNLINKED INSTRUCTION>>\n";
return;
}
ModuleState state(function->getContext());
ModuleState state(function.getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(function, modulePrinter).print(this);
}
@ -1754,13 +1754,13 @@ void Operation::dump() {
}
void Block::print(raw_ostream &os) {
auto *function = getFunction();
auto function = getFunction();
if (!function) {
os << "<<UNLINKED BLOCK>>\n";
return;
}
ModuleState state(function->getContext());
ModuleState state(function.getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(function, modulePrinter).print(this);
}
@ -1773,14 +1773,14 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
os << "<<UNLINKED BLOCK>>\n";
return;
}
ModuleState state(getFunction()->getContext());
ModuleState state(getFunction().getContext());
ModulePrinter modulePrinter(os, state);
FunctionPrinter(getFunction(), modulePrinter).printBlockName(this);
}
void Function::print(raw_ostream &os) {
ModuleState state(getContext());
ModulePrinter(os, state).print(this);
ModulePrinter(os, state).print(*this);
}
void Function::dump() { print(llvm::errs()); }

View File

@ -249,11 +249,6 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
// FunctionAttr
//===----------------------------------------------------------------------===//
FunctionAttr FunctionAttr::get(Function *value) {
assert(value && "Cannot get FunctionAttr for a null function");
return get(value->getName(), value->getContext());
}
FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
return Base::get(ctx, StandardAttributes::Function, value,
NoneType::get(ctx));

View File

@ -50,7 +50,7 @@ Operation *Block::getContainingOp() {
return getParent() ? getParent()->getContainingOp() : nullptr;
}
Function *Block::getFunction() {
Function Block::getFunction() {
Block *block = this;
while (auto *op = block->getContainingOp()) {
block = op->getBlock();

View File

@ -177,8 +177,8 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
FunctionAttr Builder::getFunctionAttr(Function *value) {
return FunctionAttr::get(value);
FunctionAttr Builder::getFunctionAttr(Function value) {
return getFunctionAttr(value.getName());
}
FunctionAttr Builder::getFunctionAttr(StringRef value) {
return FunctionAttr::get(value, getContext());

View File

@ -18,6 +18,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectHooks.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ManagedStatic.h"
@ -68,6 +69,20 @@ Dialect::Dialect(StringRef name, MLIRContext *context)
Dialect::~Dialect() {}
/// Verify an attribute from this dialect on the given function. Returns
/// failure if the verification failed, success otherwise.
LogicalResult Dialect::verifyFunctionAttribute(Function, NamedAttribute) {
return success();
}
/// Verify an attribute from this dialect on the argument at 'argIndex' for
/// the given function. Returns failure if the verification failed, success
/// otherwise.
LogicalResult Dialect::verifyFunctionArgAttribute(Function, unsigned argIndex,
NamedAttribute) {
return success();
}
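These defaults take Function by value, so dialect-specific checks now override the same value-typed signatures. A rough sketch for a hypothetical dialect (invented names; this assumes the hooks stay virtual, which the LLVM dialect override later in this commit suggests):
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
using namespace mlir;
// Hypothetical dialect rejecting a made-up "my.noinline" attribute on
// function declarations.
class MyDialect : public Dialect {
public:
  explicit MyDialect(MLIRContext *context) : Dialect("my", context) {}
  LogicalResult verifyFunctionAttribute(Function fn,
                                        NamedAttribute attr) override {
    if (attr.first == "my.noinline" && fn.isExternal())
      return fn.emitError("'my.noinline' cannot be used on declarations");
    return success();
  }
};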
/// Parse an attribute registered to this dialect.
Attribute Dialect::parseAttribute(StringRef attrData, Location loc) const {
emitError(loc) << "dialect '" << getNamespace()

View File

@ -27,45 +27,50 @@
#include "llvm/ADT/Twine.h"
using namespace mlir;
using namespace mlir::detail;
Function::Function(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
FunctionStorage::FunctionStorage(Location location, StringRef name,
FunctionType type,
ArrayRef<NamedAttribute> attrs)
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(attrs), argAttrs(type.getNumInputs()), body(this) {}
Function::Function(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<NamedAttributeList> argAttrs)
FunctionStorage::FunctionStorage(Location location, StringRef name,
FunctionType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<NamedAttributeList> argAttrs)
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
MLIRContext *Function::getContext() { return getType().getContext(); }
Module *llvm::ilist_traits<Function>::getContainingModule() {
Module *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
size_t Offset(
size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
iplist<Function> *Anchor(static_cast<iplist<Function> *>(this));
iplist<FunctionStorage> *Anchor(static_cast<iplist<FunctionStorage> *>(this));
return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset);
}
/// This is a trait method invoked when a Function is added to a Module. We
/// keep the module pointer and module symbol table up to date.
void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
assert(!function->getModule() && "already in a module!");
void llvm::ilist_traits<FunctionStorage>::addNodeToList(
FunctionStorage *function) {
assert(!function->module && "already in a module!");
function->module = getContainingModule();
}
/// This is a trait method invoked when a Function is removed from a Module.
/// We keep the module pointer up to date.
void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
void llvm::ilist_traits<FunctionStorage>::removeNodeFromList(
FunctionStorage *function) {
assert(function->module && "not already in a module!");
function->module = nullptr;
}
/// This is a trait method invoked when a function is moved from one module
/// to another. We keep the module pointer up to date.
void llvm::ilist_traits<Function>::transferNodesFromList(
ilist_traits<Function> &otherList, function_iterator first,
void llvm::ilist_traits<FunctionStorage>::transferNodesFromList(
ilist_traits<FunctionStorage> &otherList, function_iterator first,
function_iterator last) {
// If we are transferring functions within the same module, the Module
// pointer doesn't need to be updated.
@ -82,8 +87,10 @@ void llvm::ilist_traits<Function>::transferNodesFromList(
/// Unlink this function from its Module and delete it.
void Function::erase() {
assert(getModule() && "Function has no parent");
getModule()->getFunctions().erase(this);
if (auto *module = getModule())
getModule()->functions.erase(impl);
else
delete impl;
}
/// Emit an error about fatal conditions with this function, reporting up to
@ -111,10 +118,10 @@ InFlightDiagnostic Function::emitRemark(const Twine &message) {
/// Clone the internal blocks from this function into dest and all attributes
/// from this function to dest.
void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
void Function::cloneInto(Function dest, BlockAndValueMapping &mapper) {
// Add the attributes of this function to dest.
llvm::MapVector<Identifier, Attribute> newAttrs;
for (auto &attr : dest->getAttrs())
for (auto &attr : dest.getAttrs())
newAttrs.insert(attr);
for (auto &attr : getAttrs()) {
auto insertPair = newAttrs.insert(attr);
@ -125,10 +132,10 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
assert((insertPair.second || insertPair.first->second == attr.second) &&
"the two functions have incompatible attributes");
}
dest->setAttrs(newAttrs.takeVector());
dest.setAttrs(newAttrs.takeVector());
// Clone the body.
body.cloneInto(&dest->body, mapper);
impl->body.cloneInto(&dest.impl->body, mapper);
}
/// Create a deep copy of this function and all of its blocks, remapping
@ -136,8 +143,8 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) {
/// provided (leaving them alone if no entry is present). Replaces references
/// to cloned sub-values with the corresponding value that is copied, and adds
/// those mappings to the mapper.
Function *Function::clone(BlockAndValueMapping &mapper) {
FunctionType newType = type;
Function Function::clone(BlockAndValueMapping &mapper) {
FunctionType newType = impl->type;
// If the function has a body, then the user might be deleting arguments to
// the function by specifying them in the mapper. If so, we don't add the
@ -147,23 +154,23 @@ Function *Function::clone(BlockAndValueMapping &mapper) {
SmallVector<Type, 4> inputTypes;
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
if (!mapper.contains(getArgument(i)))
inputTypes.push_back(type.getInput(i));
newType = FunctionType::get(inputTypes, type.getResults(), getContext());
inputTypes.push_back(newType.getInput(i));
newType = FunctionType::get(inputTypes, newType.getResults(), getContext());
}
// Create the new function.
Function *newFunc = new Function(getLoc(), getName(), newType);
Function newFunc = Function::create(getLoc(), getName(), newType);
/// Set the argument attributes for arguments that aren't being replaced.
for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
if (isExternalFn || !mapper.contains(getArgument(i)))
newFunc->setArgAttrs(destI++, getArgAttrs(i));
newFunc.setArgAttrs(destI++, getArgAttrs(i));
/// Clone the current function into the new one and return it.
cloneInto(newFunc, mapper);
return newFunc;
}
Function *Function::clone() {
Function Function::clone() {
BlockAndValueMapping mapper;
return clone(mapper);
}
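With creation and cloning now returning value-typed handles, the basic lifecycle reads roughly as follows; this is a minimal sketch under the API shown in this commit, and functions that are never inserted into a Module must be erased by hand:
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
using namespace mlir;
// Sketch only: the value-typed create / clone / erase flow.
static void functionLifecycleSketch(MLIRContext &ctx) {
  Builder builder(&ctx);
  auto type = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
  // Function::create replaces 'new Function(...)'; the result is a cheap
  // handle over the underlying FunctionStorage.
  Function fn = Function::create(builder.getUnknownLoc(), "example", type);
  fn.addEntryBlock();
  // clone() still deep-copies the blocks but hands back another handle.
  BlockAndValueMapping mapper;
  Function copy = fn.clone(mapper);
  // Neither function was added to a Module, so release the storage manually;
  // erase() now handles module-less functions, as the hunk above shows.
  copy.erase();
  fn.erase();
}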
@ -178,7 +185,7 @@ void Function::addEntryBlock() {
assert(empty() && "function already has an entry block");
auto *entry = new Block();
push_back(entry);
entry->addArguments(type.getInputs());
entry->addArguments(impl->type.getInputs());
}
void Function::walk(const std::function<void(Operation *)> &callback) {

View File

@ -281,7 +281,7 @@ Operation *Operation::getParentOp() {
return block ? block->getContainingOp() : nullptr;
}
Function *Operation::getFunction() {
Function Operation::getFunction() {
return block ? block->getFunction() : nullptr;
}
@ -861,12 +861,13 @@ static LogicalResult verifyBBArguments(Operation::operand_range operands,
}
static LogicalResult verifyTerminatorSuccessors(Operation *op) {
auto *parent = op->getContainingRegion();
// Verify that the operands line up with the BB arguments in the successor.

Function *fn = op->getFunction();
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
auto *succ = op->getSuccessor(i);
if (succ->getFunction() != fn)
return op->emitError("reference to block defined in another function");
if (succ->getParent() != parent)
return op->emitError("reference to block defined in another region");
if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op)))
return failure();
}

View File

@ -21,7 +21,7 @@
#include "mlir/IR/Operation.h"
using namespace mlir;
Region::Region(Function *container) : container(container) {}
Region::Region(Function container) : container(container.impl) {}
Region::Region(Operation *container) : container(container) {}
@ -38,7 +38,7 @@ MLIRContext *Region::getContext() {
assert(!container.isNull() && "region is not attached to a container");
if (auto *inst = getContainingOp())
return inst->getContext();
return getContainingFunction()->getContext();
return getContainingFunction().getContext();
}
/// Return a location for this region. This is the location attached to the
@ -47,7 +47,7 @@ Location Region::getLoc() {
assert(!container.isNull() && "region is not attached to a container");
if (auto *inst = getContainingOp())
return inst->getLoc();
return getContainingFunction()->getLoc();
return getContainingFunction().getLoc();
}
Region *Region::getContainingRegion() {
@ -60,8 +60,8 @@ Operation *Region::getContainingOp() {
return container.dyn_cast<Operation *>();
}
Function *Region::getContainingFunction() {
return container.dyn_cast<Function *>();
Function Region::getContainingFunction() {
return container.dyn_cast<detail::FunctionStorage *>();
}
bool Region::isProperAncestor(Region *other) {

View File

@ -22,8 +22,8 @@ using namespace mlir;
/// Build a symbol table with the symbols within the given module.
SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
for (auto &func : *module) {
auto inserted = symbolTable.insert({func.getName(), &func});
for (auto func : *module) {
auto inserted = symbolTable.insert({func.getName(), func});
(void)inserted;
assert(inserted.second &&
"expected module to contain uniquely named functions");
@ -32,34 +32,34 @@ SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
/// Look up a symbol with the specified name, returning null if no such name
/// exists. Names never include the @ on them.
Function *SymbolTable::lookup(StringRef name) const {
Function SymbolTable::lookup(StringRef name) const {
return lookup(Identifier::get(name, context));
}
/// Look up a symbol with the specified name, returning null if no such name
/// exists. Names never include the @ on them.
Function *SymbolTable::lookup(Identifier name) const {
Function SymbolTable::lookup(Identifier name) const {
return symbolTable.lookup(name);
}
/// Erase the given symbol from the table.
void SymbolTable::erase(Function *symbol) {
auto it = symbolTable.find(symbol->getName());
void SymbolTable::erase(Function symbol) {
auto it = symbolTable.find(symbol.getName());
if (it != symbolTable.end() && it->second == symbol)
symbolTable.erase(it);
}
/// Insert a new symbol into the table, and rename it as necessary to avoid
/// collisions.
void SymbolTable::insert(Function *symbol) {
void SymbolTable::insert(Function symbol) {
// Add this symbol to the symbol table, uniquing the name if a conflict is
// detected.
if (symbolTable.insert({symbol->getName(), symbol}).second)
if (symbolTable.insert({symbol.getName(), symbol}).second)
return;
// If a conflict was detected, then the function will not have been added to
// the symbol table. Try suffixes until we get to a unique name that works.
SmallString<128> nameBuffer(symbol->getName());
SmallString<128> nameBuffer(symbol.getName());
unsigned originalLength = nameBuffer.size();
// Iteratively try suffixes until we find one that isn't used. We use a
@ -68,6 +68,6 @@ void SymbolTable::insert(Function *symbol) {
nameBuffer.resize(originalLength);
nameBuffer += '_';
nameBuffer += std::to_string(uniquingCounter++);
symbol->setName(Identifier::get(nameBuffer, context));
} while (!symbolTable.insert({symbol->getName(), symbol}).second);
symbol.setName(Identifier::get(nameBuffer, context));
} while (!symbolTable.insert({symbol.getName(), symbol}).second);
}
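A condensed usage sketch of the value-typed symbol table, where 'module', 'loc', and 'fnType' stand in for values available at the call site:
// Build the table from an existing Module and resolve a function by name.
SymbolTable symbols(module);
Function callee = symbols.lookup("callee");
if (!callee) {
  callee = Function::create(loc, "callee", fnType);
  module->push_back(callee);
  // insert() renames the function if "callee" is already taken.
  symbols.insert(callee);
}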

View File

@ -30,7 +30,7 @@ Operation *Value::getDefiningOp() {
}
/// Return the function that this Value is defined in.
Function *Value::getFunction() {
Function Value::getFunction() {
switch (getKind()) {
case Value::Kind::BlockArgument:
return cast<BlockArgument>(this)->getFunction();
@ -84,7 +84,7 @@ void IRObjectWithUseList::dropAllUses() {
//===----------------------------------------------------------------------===//
/// Return the function that this argument is defined in.
Function *BlockArgument::getFunction() {
Function BlockArgument::getFunction() {
if (auto *owner = getOwner())
return owner->getFunction();
return nullptr;
@ -92,6 +92,6 @@ Function *BlockArgument::getFunction() {
/// Returns if the current argument is a function argument.
bool BlockArgument::isFunctionArgument() {
auto *containingFn = getFunction();
return containingFn && &containingFn->front() == getOwner();
auto containingFn = getFunction();
return containingFn && &containingFn.front() == getOwner();
}

View File

@ -816,12 +816,12 @@ void LLVMDialect::printType(Type type, raw_ostream &os) const {
}
/// Verify LLVMIR function argument attributes.
LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func,
LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function func,
unsigned argIdx,
NamedAttribute argAttr) {
// Check that llvm.noalias is a boolean attribute.
if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>())
return func->emitError()
return func.emitError()
<< "llvm.noalias argument attribute of non boolean type";
return success();
}

View File

@ -209,7 +209,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
return true;
}
static void fuseLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
static void fuseLinalgOps(Function f, ArrayRef<int64_t> tileSizes) {
OperationFolder state;
DenseSet<Operation *> eraseSet;

View File

@ -170,12 +170,13 @@ public:
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Insert the `malloc` declaration if it is not already present.
auto *module = op->getFunction()->getModule();
Function *mallocFunc = module->getNamedFunction("malloc");
auto *module = op->getFunction().getModule();
Function mallocFunc = module->getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
module->getFunctions().push_back(mallocFunc);
mallocFunc =
Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
module->push_back(mallocFunc);
}
// Get MLIR types for injecting element pointer.
@ -230,12 +231,12 @@ public:
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
auto *module = op->getFunction()->getModule();
Function *freeFunc = module->getNamedFunction("free");
auto *module = op->getFunction().getModule();
Function freeFunc = module->getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
module->getFunctions().push_back(freeFunc);
freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
module->push_back(freeFunc);
}
// Get MLIR types for extracting element pointer.
@ -572,37 +573,37 @@ public:
// Create a function definition which takes as argument pointers to the input
// types and returns pointers to the output types.
static Function *getLLVMLibraryCallImplDefinition(Function *libFn) {
auto implFnName = (libFn->getName().str() + "_impl");
auto module = libFn->getModule();
if (auto *f = module->getNamedFunction(implFnName)) {
static Function getLLVMLibraryCallImplDefinition(Function libFn) {
auto implFnName = (libFn.getName().str() + "_impl");
auto module = libFn.getModule();
if (auto f = module->getNamedFunction(implFnName)) {
return f;
}
SmallVector<Type, 4> fnArgTypes;
for (auto t : libFn->getType().getInputs()) {
for (auto t : libFn.getType().getInputs()) {
assert(t.isa<LLVMType>() &&
"Expected LLVM Type for argument while generating library Call "
"Implementation Definition");
fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
}
auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext());
auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext());
// Insert the implementation function definition.
auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType);
module->getFunctions().push_back(implFnDefn);
auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType);
module->push_back(implFnDefn);
return implFnDefn;
}
// Get function definition for the LinalgOp. If it doesn't exist, insert a
// definition.
template <typename LinalgOp>
static Function *getLLVMLibraryCallDeclaration(Operation *op,
LLVMTypeConverter &lowering,
PatternRewriter &rewriter) {
static Function getLLVMLibraryCallDeclaration(Operation *op,
LLVMTypeConverter &lowering,
PatternRewriter &rewriter) {
assert(isa<LinalgOp>(op));
auto fnName = LinalgOp::getLibraryCallName();
auto module = op->getFunction()->getModule();
if (auto *f = module->getNamedFunction(fnName)) {
auto module = op->getFunction().getModule();
if (auto f = module->getNamedFunction(fnName)) {
return f;
}
@ -618,29 +619,29 @@ static Function *getLLVMLibraryCallDeclaration(Operation *op,
"Library call for linalg operation can be generated only for ops that "
"have void return types");
auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
auto libFn = new Function(op->getLoc(), fnName, libFnType);
module->getFunctions().push_back(libFn);
auto libFn = Function::create(op->getLoc(), fnName, libFnType);
module->push_back(libFn);
// Return after creating the function definition. The body will be created
// later.
return libFn;
}
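The same get-or-create sequence appears above for 'malloc', 'free', and the library call declaration; with the value-typed API it condenses to a small helper, sketched here under the same assumptions as the hunks above (the helper name is invented):
// Reuse an existing function if the module has one, otherwise create an
// external declaration and append it to the module.
static Function getOrCreateFunction(Module *module, StringRef name,
                                    FunctionType type, Location loc) {
  if (Function existing = module->getNamedFunction(name))
    return existing;
  Function decl = Function::create(loc, name, type);
  module->push_back(decl);
  return decl;
}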
static void getLLVMLibraryCallDefinition(Function *fn,
static void getLLVMLibraryCallDefinition(Function fn,
LLVMTypeConverter &lowering) {
// Generate the implementation function definition.
auto implFn = getLLVMLibraryCallImplDefinition(fn);
// Generate the function body.
fn->addEntryBlock();
fn.addEntryBlock();
OpBuilder builder(fn->getBody());
edsc::ScopedContext scope(builder, fn->getLoc());
OpBuilder builder(fn.getBody());
edsc::ScopedContext scope(builder, fn.getLoc());
SmallVector<Value *, 4> implFnArgs;
// Create a constant 1.
auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
IntegerAttr::get(IndexType::get(fn->getContext()), 1));
for (auto arg : fn->getArguments()) {
IntegerAttr::get(IndexType::get(fn.getContext()), 1));
for (auto arg : fn.getArguments()) {
// Allocate a stack for storing the argument value. The stack is passed to
// the implementation function.
auto alloca =
@ -665,17 +666,17 @@ public:
return convertLinalgType(t, *this);
}
void addLibraryFnDeclaration(Function *fn) {
void addLibraryFnDeclaration(Function fn) {
libraryFnDeclarations.push_back(fn);
}
ArrayRef<Function *> getLibraryFnDeclarations() {
ArrayRef<Function> getLibraryFnDeclarations() {
return libraryFnDeclarations;
}
private:
/// List of library functions declarations needed during dialect conversion
SmallVector<Function *, 2> libraryFnDeclarations;
SmallVector<Function, 2> libraryFnDeclarations;
};
} // end anonymous namespace
@ -692,7 +693,7 @@ public:
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Only emit library call declaration. Fill in the body later.
auto *f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
auto fAttr = rewriter.getFunctionAttr(f);
@ -803,7 +804,7 @@ static void lowerLinalgForToCFG(Function &f) {
void LowerLinalgToLLVMPass::runOnModule() {
auto &module = getModule();
for (auto &f : module.getFunctions()) {
for (auto f : module.getFunctions()) {
lowerLinalgSubViewOps(f);
lowerLinalgForToCFG(f);
if (failed(lowerAffineConstructs(f)))

View File

@ -104,9 +104,8 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
} // namespace
void LowerLinalgToLoopsPass::runOnFunction() {
auto &f = getFunction();
OperationFolder state;
f.walk<LinalgOp>([&state](LinalgOp linalgOp) {
getFunction().walk<LinalgOp>([&state](LinalgOp linalgOp) {
emitLinalgOpAsLoops(linalgOp, state);
linalgOp.getOperation()->erase();
});

View File

@ -259,7 +259,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
return tileLinalgOp(op, tileSizeValues, state);
}
static void tileLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
static void tileLinalgOps(Function f, ArrayRef<int64_t> tileSizes) {
OperationFolder state;
f.walk<LinalgOp>([tileSizes, &state](LinalgOp op) {
auto opLoopsPair = tileLinalgOp(op, tileSizes, state);

View File

@ -254,7 +254,7 @@ public:
/// trailing-location ::= location?
///
template <typename Owner>
ParseResult parseOptionalTrailingLocation(Owner *owner) {
ParseResult parseOptionalTrailingLocation(Owner &owner) {
// If there is a 'loc' we parse a trailing location.
if (!getToken().is(Token::kw_loc))
return success();
@ -263,7 +263,7 @@ public:
LocationAttr directLoc;
if (parseLocation(directLoc))
return failure();
owner->setLoc(directLoc);
owner.setLoc(directLoc);
return success();
}
@ -2472,8 +2472,8 @@ namespace {
/// operations.
class OperationParser : public Parser {
public:
OperationParser(ParserState &state, Function *function)
: Parser(state), function(function), opBuilder(function->getBody()) {}
OperationParser(ParserState &state, Function function)
: Parser(state), function(function), opBuilder(function.getBody()) {}
~OperationParser();
@ -2588,7 +2588,7 @@ public:
Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing);
private:
Function *function;
Function function;
/// Returns the info for a block at the current scope for the given name.
std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) {
@ -2690,7 +2690,7 @@ ParseResult OperationParser::popSSANameScope() {
for (auto entry : forwardRefInCurrentScope) {
errors.push_back({entry.second.getPointer(), entry.first});
// Add this block to the top-level region to allow for automatic cleanup.
function->push_back(entry.first);
function.push_back(entry.first);
}
llvm::array_pod_sort(errors.begin(), errors.end());
@ -2984,7 +2984,7 @@ ParseResult OperationParser::parseOperation() {
}
// Try to parse the optional trailing location.
if (parseOptionalTrailingLocation(op))
if (parseOptionalTrailingLocation(*op))
return failure();
return success();
@ -4049,17 +4049,17 @@ ParseResult ModuleParser::parseFunc(Module *module) {
}
// Okay, the function signature was parsed correctly, create the function now.
auto *function =
new Function(getEncodedSourceLocation(loc), name, type, attrs);
module->getFunctions().push_back(function);
auto function =
Function::create(getEncodedSourceLocation(loc), name, type, attrs);
module->push_back(function);
// Parse an optional trailing location.
if (parseOptionalTrailingLocation(function))
return failure();
// Add the attributes to the function arguments.
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i)
function->setArgAttrs(i, argAttrs[i]);
for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i)
function.setArgAttrs(i, argAttrs[i]);
// External functions have no body.
if (getToken().isNot(Token::l_brace))
@ -4076,11 +4076,11 @@ ParseResult ModuleParser::parseFunc(Module *module) {
// Parse the function body.
auto parser = OperationParser(getState(), function);
if (parser.parseRegion(function->getBody(), entryArgs))
if (parser.parseRegion(function.getBody(), entryArgs))
return failure();
// Verify that a valid function body was parsed.
if (function->empty())
if (function.empty())
return emitError(braceLoc, "function must have a body");
return parser.finalize(braceLoc);

View File

@ -61,12 +61,12 @@ private:
static void printIR(const llvm::Any &ir, bool printModuleScope,
raw_ostream &out) {
// Check for printing at module scope.
if (printModuleScope && llvm::any_isa<Function *>(ir)) {
Function *function = llvm::any_cast<Function *>(ir);
if (printModuleScope && llvm::any_isa<Function>(ir)) {
Function function = llvm::any_cast<Function>(ir);
// Print the function name and a newline before the Module.
out << " (function: " << function->getName() << ")\n";
function->getModule()->print(out);
out << " (function: " << function.getName() << ")\n";
function.getModule()->print(out);
return;
}
@ -74,8 +74,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope,
out << "\n";
// Print the given function.
if (llvm::any_isa<Function *>(ir)) {
llvm::any_cast<Function *>(ir)->print(out);
if (llvm::any_isa<Function>(ir)) {
llvm::any_cast<Function>(ir).print(out);
return;
}

View File

@ -46,8 +46,7 @@ static llvm::cl::opt<bool>
void Pass::anchor() {}
/// Forwarding function to execute this pass.
LogicalResult FunctionPassBase::run(Function *fn,
FunctionAnalysisManager &fam) {
LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) {
// Initialize the pass state.
passState.emplace(fn, fam);
@ -115,7 +114,7 @@ FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs)
}
/// Run all of the passes in this manager over the current function.
LogicalResult detail::FunctionPassExecutor::run(Function *function,
LogicalResult detail::FunctionPassExecutor::run(Function function,
FunctionAnalysisManager &fam) {
// Run each of the held passes.
for (auto &pass : passes)
@ -141,7 +140,7 @@ LogicalResult detail::ModulePassExecutor::run(Module *module,
/// Utility to run the given function and analysis manager on a provided
/// function pass executor.
static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe,
Function *func,
Function func,
FunctionAnalysisManager &fam) {
// Run the function pipeline over the provided function.
auto result = fpe.run(func, fam);
@ -158,14 +157,14 @@ static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe,
/// module.
void ModuleToFunctionPassAdaptor::runOnModule() {
ModuleAnalysisManager &mam = getAnalysisManager();
for (auto &func : getModule()) {
for (auto func : getModule()) {
// Skip external functions.
if (func.isExternal())
continue;
// Run the held function pipeline over the current function.
auto fam = mam.slice(&func);
if (failed(runFunctionPipeline(fpe, &func, fam)))
auto fam = mam.slice(func);
if (failed(runFunctionPipeline(fpe, func, fam)))
return signalPassFailure();
// Clear out any computed function analyses. These analyses won't be used
@ -189,10 +188,10 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() {
// Run a prepass over the module to collect the functions to execute over.
// This ensures that an analysis manager exists for each function, as well as
// providing a queue of functions to execute over.
std::vector<std::pair<Function *, FunctionAnalysisManager>> funcAMPairs;
for (auto &func : getModule())
std::vector<std::pair<Function, FunctionAnalysisManager>> funcAMPairs;
for (auto func : getModule())
if (!func.isExternal())
funcAMPairs.emplace_back(&func, mam.slice(&func));
funcAMPairs.emplace_back(func, mam.slice(func));
// A parallel diagnostic handler that provides deterministic diagnostic
// ordering.
@ -340,8 +339,8 @@ PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const {
}
/// Create an analysis slice for the given child function.
FunctionAnalysisManager ModuleAnalysisManager::slice(Function *func) {
assert(func->getModule() == moduleAnalyses.getIRUnit() &&
FunctionAnalysisManager ModuleAnalysisManager::slice(Function func) {
assert(func.getModule() == moduleAnalyses.getIRUnit() &&
"function has a different parent module");
auto it = functionAnalyses.find(func);
if (it == functionAnalyses.end()) {

View File

@ -48,7 +48,7 @@ public:
FunctionPassExecutor(const FunctionPassExecutor &rhs);
/// Run the executor on the given function.
LogicalResult run(Function *function, FunctionAnalysisManager &fam);
LogicalResult run(Function function, FunctionAnalysisManager &fam);
/// Add a pass to the current executor. This takes ownership over the provided
/// pass pointer.

View File

@ -71,7 +71,7 @@ void AddDefaultStatsPass::runOnFunction() {
void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
const TargetConfiguration &config) {
auto &func = getFunction();
auto func = getFunction();
// Insert stats for each argument.
for (auto *arg : func.getArguments()) {

View File

@ -129,7 +129,7 @@ void InferQuantizedTypesPass::runOnModule() {
void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
const TargetConfiguration &config) {
CAGSlice cag(solverContext);
for (auto &f : getModule()) {
for (auto f : getModule()) {
f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
}
config.finalizeAnchors(cag);

View File

@ -58,7 +58,7 @@ public:
void RemoveInstrumentationPass::runOnFunction() {
OwningRewritePatternList patterns;
auto &func = getFunction();
auto func = getFunction();
auto *context = &getContext();
patterns.push_back(
llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));

View File

@ -36,11 +36,11 @@ using namespace mlir;
// block. The created block will be terminated by `std.return`.
Block *createOneBlockFunction(Builder builder, Module *module) {
auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
auto *fn = new Function(builder.getUnknownLoc(), "spirv_module", fnType);
module->getFunctions().push_back(fn);
auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType);
module->push_back(fn);
auto *block = new Block();
fn->push_back(block);
fn.push_back(block);
OperationState state(builder.getUnknownLoc(), ReturnOp::getOperationName());
ReturnOp::build(&builder, &state);

View File

@ -45,7 +45,7 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) {
// wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed
// to take in the SPIR-V ModuleOp directly after module and function are
// migrated to be general ops.
for (auto &fn : *module) {
for (auto fn : *module) {
fn.walk<spirv::ModuleOp>([&](spirv::ModuleOp spirvModule) {
if (done) {
spirvModule.emitError("found more than one 'spv.module' op");

View File

@ -42,7 +42,7 @@ class StdOpsToSPIRVConversionPass
void StdOpsToSPIRVConversionPass::runOnFunction() {
OwningRewritePatternList patterns;
auto &func = getFunction();
auto func = getFunction();
populateWithGenerated(func.getContext(), &patterns);
applyPatternsGreedily(func, std::move(patterns));

View File

@ -440,14 +440,14 @@ static LogicalResult verify(CallOp op) {
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' function attribute");
auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
fnAttr.getValue());
if (!fn)
return op.emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
// Verify that the operand and result types match the callee.
auto fnType = fn->getType();
auto fnType = fn.getType();
if (fnType.getNumInputs() != op.getNumOperands())
return op.emitOpError("incorrect number of operands for callee");
@ -1107,13 +1107,13 @@ static LogicalResult verify(ConstantOp &op) {
return op.emitOpError("requires 'value' to be a function reference");
// Try to find the referenced function.
auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
fnAttr.getValue());
if (!fn)
return op.emitOpError("reference to undefined function 'bar'");
// Check that the referenced function has the correct type.
if (fn->getType() != type)
if (fn.getType() != type)
return op.emitOpError("reference to function with mismatched type");
return success();
@ -1876,10 +1876,10 @@ static void print(OpAsmPrinter *p, ReturnOp op) {
}
static LogicalResult verify(ReturnOp op) {
auto *function = op.getOperation()->getFunction();
auto function = op.getOperation()->getFunction();
// The operand number and types must match the function signature.
const auto &results = function->getType().getResults();
const auto &results = function.getType().getResults();
if (op.getNumOperands() != results.size())
return op.emitOpError("has ")
<< op.getNumOperands()

View File

@ -69,7 +69,7 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
// Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
// function as a kernel.
for (Function &func : m) {
for (Function func : m) {
if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName()))
continue;
@ -89,20 +89,21 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
return llvmModule;
}
static TranslateFromMLIRRegistration registration(
"mlir-to-nvvmir", [](Module *module, llvm::StringRef outputFilename) {
if (!module)
return true;
static TranslateFromMLIRRegistration
registration("mlir-to-nvvmir",
[](Module *module, llvm::StringRef outputFilename) {
if (!module)
return true;
auto llvmModule = mlir::translateModuleToNVVMIR(*module);
if (!llvmModule)
return true;
auto llvmModule = mlir::translateModuleToNVVMIR(*module);
if (!llvmModule)
return true;
auto file = openOutputFile(outputFilename);
if (!file)
return true;
auto file = openOutputFile(outputFilename);
if (!file)
return true;
llvmModule->print(file->os(), nullptr);
file->keep();
return false;
});
llvmModule->print(file->os(), nullptr);
file->keep();
return false;
});

View File

@ -375,7 +375,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) {
bool ModuleTranslation::convertFunctions() {
// Declare all functions first because there may be function calls that form a
// call graph with cycles.
for (Function &function : mlirModule) {
for (Function function : mlirModule) {
mlir::BoolAttr isVarArgsAttr =
function.getAttrOfType<BoolAttr>("std.varargs");
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
@ -392,7 +392,7 @@ bool ModuleTranslation::convertFunctions() {
}
// Convert functions.
for (Function &function : mlirModule) {
for (Function function : mlirModule) {
// Ignore external functions.
if (function.isExternal())
continue;

View File

@ -40,7 +40,7 @@ struct Canonicalizer : public FunctionPass<Canonicalizer> {
void Canonicalizer::runOnFunction() {
OwningRewritePatternList patterns;
auto &func = getFunction();
auto func = getFunction();
// TODO: Instead of adding all known patterns from the whole system lazily add
// and cache the canonicalization patterns for ops we see in practice when

View File

@ -849,7 +849,7 @@ struct FunctionConverter {
/// error, success otherwise. If 'signatureConversion' is provided, the
/// arguments of the entry block are updated accordingly.
LogicalResult
convertFunction(Function *f,
convertFunction(Function f,
TypeConverter::SignatureConversion *signatureConversion);
/// Converts the given region starting from the entry block and following the
@ -957,22 +957,22 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
}
LogicalResult FunctionConverter::convertFunction(
Function *f, TypeConverter::SignatureConversion *signatureConversion) {
Function f, TypeConverter::SignatureConversion *signatureConversion) {
// If this is an external function, there is nothing else to do.
if (f->isExternal())
if (f.isExternal())
return success();
DialectConversionRewriter rewriter(f->getBody(), typeConverter);
DialectConversionRewriter rewriter(f.getBody(), typeConverter);
// Update the signature of the entry block.
if (signatureConversion) {
rewriter.argConverter.convertSignature(
&f->getBody().front(), *signatureConversion, rewriter.mapping);
&f.getBody().front(), *signatureConversion, rewriter.mapping);
}
// Rewrite the function body.
if (failed(
convertRegion(rewriter, f->getBody(), /*convertEntryTypes=*/false))) {
convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) {
// Reset any of the generated rewrites.
rewriter.discardRewrites();
return failure();
@ -1124,24 +1124,6 @@ auto ConversionTarget::getOpAction(OperationName op) const
// applyConversionPatterns
//===----------------------------------------------------------------------===//
namespace {
/// This class represents a function to be converted. It allows for converting
/// the body of functions and the signature in two phases.
struct ConvertedFunction {
ConvertedFunction(Function *fn, FunctionType newType,
ArrayRef<NamedAttributeList> newFunctionArgAttrs)
: fn(fn), newType(newType),
newFunctionArgAttrs(newFunctionArgAttrs.begin(),
newFunctionArgAttrs.end()) {}
/// The function to convert.
Function *fn;
/// The new type and argument attributes for the function.
FunctionType newType;
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
};
} // end anonymous namespace
/// Convert the given module with the provided conversion patterns and type
/// conversion object. If conversion fails for specific functions, those
/// functions remains unmodified.
@ -1149,37 +1131,33 @@ LogicalResult
mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns) {
std::vector<Function *> allFunctions;
allFunctions.reserve(module.getFunctions().size());
for (auto &func : module)
allFunctions.push_back(&func);
SmallVector<Function, 32> allFunctions(module.getFunctions());
return applyConversionPatterns(allFunctions, target, converter,
std::move(patterns));
}
/// Convert the given functions with the provided conversion patterns.
LogicalResult mlir::applyConversionPatterns(
ArrayRef<Function *> fns, ConversionTarget &target,
MutableArrayRef<Function> fns, ConversionTarget &target,
TypeConverter &converter, OwningRewritePatternList &&patterns) {
if (fns.empty())
return success();
// Build the function converter.
FunctionConverter funcConverter(fns.front()->getContext(), target, patterns,
&converter);
auto *ctx = fns.front().getContext();
FunctionConverter funcConverter(ctx, target, patterns, &converter);
// Try to convert each of the functions within the module.
auto *ctx = fns.front()->getContext();
for (auto *func : fns) {
for (auto func : fns) {
// Convert the function type using the type converter.
auto conversion =
converter.convertSignature(func->getType(), func->getAllArgAttrs());
converter.convertSignature(func.getType(), func.getAllArgAttrs());
if (!conversion)
return failure();
// Update the function signature.
func->setType(conversion->getConvertedType(ctx));
func->setAllArgAttrs(conversion->getConvertedArgAttrs());
func.setType(conversion->getConvertedType(ctx));
func.setAllArgAttrs(conversion->getConvertedArgAttrs());
// Convert the body of this function.
if (failed(funcConverter.convertFunction(func, &*conversion)))
@ -1193,9 +1171,9 @@ LogicalResult mlir::applyConversionPatterns(
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LogicalResult
mlir::applyConversionPatterns(Function &fn, ConversionTarget &target,
mlir::applyConversionPatterns(Function fn, ConversionTarget &target,
OwningRewritePatternList &&patterns) {
// Convert the body of this function.
FunctionConverter converter(fn.getContext(), target, patterns);
return converter.convertFunction(&fn, /*signatureConversion=*/nullptr);
return converter.convertFunction(fn, /*signatureConversion=*/nullptr);
}
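For orientation, a hedged sketch of driving this overload from a function pass; only the applyConversionPatterns signature comes from the hunk above, while the ConversionTarget setup and the FunctionPass boilerplate are assumptions:
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include <utility>
using namespace mlir;
namespace {
// Hypothetical pass converting one value-typed function at a time. Pattern
// population and legality setup are elided.
struct MyLoweringPass : public FunctionPass<MyLoweringPass> {
  void runOnFunction() override {
    Function fn = getFunction();
    OwningRewritePatternList patterns;
    // ... populate 'patterns' with the rewrites to apply ...
    ConversionTarget target(getContext());
    // ... declare which operations remain legal after conversion ...
    if (failed(applyConversionPatterns(fn, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace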

View File

@ -214,7 +214,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
emitRemarkForBlock(Block &block) {
auto *op = block.getContainingOp();
return op ? op->emitRemark() : block.getFunction()->emitRemark();
return op ? op->emitRemark() : block.getFunction().emitRemark();
}
/// Creates a buffer in the faster memory space for the specified region;
@ -246,8 +246,8 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, Block *block,
OpBuilder &b = region.isWrite() ? epilogue : prologue;
// Builder to create constants at the top level.
auto *func = block->getFunction();
OpBuilder top(func->getBody());
auto func = block->getFunction();
OpBuilder top(func.getBody());
auto loc = region.loc;
auto *memref = region.memref;
@ -751,14 +751,14 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) {
if (auto *op = block->getContainingOp())
op->emitError(str);
else
block->getFunction()->emitError(str);
block->getFunction().emitError(str);
}
return totalDmaBuffersSizeInBytes;
}
void DmaGeneration::runOnFunction() {
Function &f = getFunction();
Function f = getFunction();
OpBuilder topBuilder(f.getBody());
zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);

View File

@ -257,7 +257,7 @@ public:
// Initializes the dependence graph based on operations in 'f'.
// Returns true on success, false otherwise.
bool init(Function &f);
bool init(Function f);
// Returns the graph node for 'id'.
Node *getNode(unsigned id) {
@ -637,7 +637,7 @@ public:
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function &f) {
bool MemRefDependenceGraph::init(Function f) {
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
// TODO: support multi-block functions.
@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
// Create builder to insert alloc op just before 'forOp'.
OpBuilder b(forInst);
// Builder to create constants at the top level.
OpBuilder top(forInst->getFunction()->getBody());
OpBuilder top(forInst->getFunction().getBody());
// Create new memref type based on slice bounds.
auto *oldMemRef = cast<StoreOp>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
@ -1750,9 +1750,9 @@ public:
};
// Search for siblings which load the same memref function argument.
auto *fn = dstNode->op->getFunction();
for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) {
for (auto *user : fn->getArgument(i)->getUsers()) {
auto fn = dstNode->op->getFunction();
for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
for (auto *user : fn.getArgument(i)->getUsers()) {
if (auto loadOp = dyn_cast<LoadOp>(user)) {
// Gather loops surrounding 'use'.
SmallVector<AffineForOp, 4> loops;

View File

@ -261,7 +261,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
// Identify valid and profitable bands of loops to tile. This is currently just
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
static void getTileableBands(Function &f,
static void getTileableBands(Function f,
std::vector<SmallVector<AffineForOp, 6>> *bands) {
// Get maximal perfect nest of 'affine.for' insts starting from root
// (inclusive).

View File

@ -92,8 +92,8 @@ void LoopUnroll::runOnFunction() {
// Store innermost loops as we walk.
std::vector<AffineForOp> loops;
void walkPostOrder(Function *f) {
for (auto &b : *f)
void walkPostOrder(Function f) {
for (auto &b : f)
walkPostOrder(b.begin(), b.end());
}
@ -142,10 +142,10 @@ void LoopUnroll::runOnFunction() {
? clUnrollNumRepetitions
: 1;
// If the call back is provided, we will recurse until no loops are found.
Function &func = getFunction();
Function func = getFunction();
for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
InnermostLoopGatherer ilg;
ilg.walkPostOrder(&func);
ilg.walkPostOrder(func);
auto &loops = ilg.loops;
if (loops.empty())
break;

View File

@ -726,7 +726,7 @@ public:
} // end namespace
LogicalResult mlir::lowerAffineConstructs(Function &function) {
LogicalResult mlir::lowerAffineConstructs(Function function) {
OwningRewritePatternList patterns;
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
AffineDmaWaitLowering, AffineLoadLowering,

View File

@ -636,7 +636,7 @@ static bool emitSlice(MaterializationState *state,
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs()));
LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs()));
// slice are topologically sorted, we can just erase them in reverse
// order. Reverse iterator does not just work simply with an operator*
@ -667,7 +667,7 @@ static bool emitSlice(MaterializationState *state,
/// because we currently disallow vectorization of defs that come from another
/// scope.
/// TODO(ntv): please document return value.
static bool materialize(Function *f, const SetVector<Operation *> &terminators,
static bool materialize(Function f, const SetVector<Operation *> &terminators,
MaterializationState *state) {
DenseSet<Operation *> seen;
DominanceInfo domInfo(f);
@ -721,7 +721,7 @@ static bool materialize(Function *f, const SetVector<Operation *> &terminators,
return true;
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
LLVM_DEBUG(f->print(dbgs()));
LLVM_DEBUG(f.print(dbgs()));
}
return false;
}
@ -731,13 +731,13 @@ void MaterializeVectorsPass::runOnFunction() {
NestedPatternContext mlContext;
// TODO(ntv): Check to see if this supports arbitrary top-level code.
Function *f = &getFunction();
if (f->getBlocks().size() != 1)
Function f = getFunction();
if (f.getBlocks().size() != 1)
return;
using matcher::Op;
LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n");
LLVM_DEBUG(f->print(dbgs()));
LLVM_DEBUG(f.print(dbgs()));
MaterializationState state(hwVectorSize);
// Get the hardware vector type.

View File

@ -212,7 +212,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) {
void MemRefDataFlowOpt::runOnFunction() {
// Only supports single block functions at the moment.
Function &f = getFunction();
Function f = getFunction();
if (f.getBlocks().size() != 1) {
markAllAnalysesPreserved();
return;

View File

@ -29,7 +29,7 @@ struct StripDebugInfo : public FunctionPass<StripDebugInfo> {
} // end anonymous namespace
void StripDebugInfo::runOnFunction() {
Function &func = getFunction();
Function func = getFunction();
auto unknownLoc = UnknownLoc::get(&getContext());
// Strip the debug info from the function and its operations.

View File

@ -44,7 +44,7 @@ namespace {
/// applies the locally optimal patterns in a roughly "bottom up" way.
class GreedyPatternRewriteDriver : public PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(Function &fn,
explicit GreedyPatternRewriteDriver(Function fn,
OwningRewritePatternList &&patterns)
: PatternRewriter(fn.getBody()), matcher(std::move(patterns)) {
worklist.reserve(64);
@ -213,7 +213,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) {
/// patterns in a greedy work-list driven manner. Return true if no more
/// patterns can be matched in the result function.
///
bool mlir::applyPatternsGreedily(Function &fn,
bool mlir::applyPatternsGreedily(Function fn,
OwningRewritePatternList &&patterns) {
GreedyPatternRewriteDriver driver(fn, std::move(patterns));
bool converged = driver.simplifyFunction(maxPatternMatchIterations);

View File

@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
Operation *op = forOp.getOperation();
if (!iv->use_empty()) {
if (forOp.hasConstantLowerBound()) {
OpBuilder topBuilder(op->getFunction()->getBody());
OpBuilder topBuilder(op->getFunction().getBody());
auto constOp = topBuilder.create<ConstantIndexOp>(
forOp.getLoc(), forOp.getConstantLowerBound());
iv->replaceAllUsesWith(constOp);

View File

@ -1194,7 +1194,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m,
/// Applies vectorization to the current Function by searching over a bunch of
/// predetermined patterns.
void Vectorize::runOnFunction() {
Function &f = getFunction();
Function f = getFunction();
if (!fastestVaryingPattern.empty() &&
fastestVaryingPattern.size() != vectorSizes.size()) {
f.emitRemark("Fastest varying pattern specified with different size than "
@ -1220,7 +1220,7 @@ void Vectorize::runOnFunction() {
unsigned patternDepth = pat.getDepth();
SmallVector<NestedMatch, 8> matches;
pat.match(&f, &matches);
pat.match(f, &matches);
// Iterate over all the top-level matches and vectorize eagerly.
// This automatically prunes intersecting matches.
for (auto m : matches) {

View File

@ -53,13 +53,13 @@ std::string DOTGraphTraits<Function *>::getNodeLabel(Block *Block, Function *) {
} // end namespace llvm
void mlir::viewGraph(Function &function, const llvm::Twine &name,
void mlir::viewGraph(Function function, const llvm::Twine &name,
bool shortNames, const llvm::Twine &title,
llvm::GraphProgram::Name program) {
llvm::ViewGraph(&function, name, shortNames, title, program);
}
llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function,
llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function function,
bool shortNames, const llvm::Twine &title) {
return llvm::WriteGraph(os, &function, shortNames, title);
}

View File

@ -43,13 +43,12 @@ static MLIRContext &globalContext() {
return context;
}
static std::unique_ptr<Function> makeFunction(StringRef name,
ArrayRef<Type> results = {},
ArrayRef<Type> args = {}) {
static Function makeFunction(StringRef name, ArrayRef<Type> results = {},
ArrayRef<Type> args = {}) {
auto &ctx = globalContext();
auto function = llvm::make_unique<Function>(
UnknownLoc::get(&ctx), name, FunctionType::get(args, results, &ctx));
function->addEntryBlock();
auto function = Function::create(UnknownLoc::get(&ctx), name,
FunctionType::get(args, results, &ctx));
function.addEntryBlock();
return function;
}
@ -62,10 +61,10 @@ TEST_FUNC(builder_dynamic_for_func_args) {
auto f =
makeFunction("builder_dynamic_for_func_args", {}, {indexType, indexType});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
ValueHandle i(indexType), j(indexType), lb(f->getArgument(0)),
ub(f->getArgument(1));
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)),
ub(f.getArgument(1));
ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type));
ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type));
ValueHandle i7(constant_int(7, 32));
@ -102,7 +101,8 @@ TEST_FUNC(builder_dynamic_for_func_args) {
// CHECK-DAG: [[ri4:%[0-9]+]] = muli {{.*}}, {{.*}} : i32
// CHECK: {{.*}} = subi [[ri3]], [[ri4]] : i32
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(builder_dynamic_for) {
@ -113,10 +113,10 @@ TEST_FUNC(builder_dynamic_for) {
auto f = makeFunction("builder_dynamic_for", {},
{indexType, indexType, indexType, indexType});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)),
c(f->getArgument(2)), d(f->getArgument(3));
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
c(f.getArgument(2)), d(f.getArgument(3));
LoopBuilder(&i, a - b, c + d, 2)();
// clang-format off
@ -125,7 +125,8 @@ TEST_FUNC(builder_dynamic_for) {
// CHECK-DAG: [[r1:%[0-9]+]] = affine.apply ()[s0, s1] -> (s0 + s1)()[%arg2, %arg3]
// CHECK-NEXT: affine.for %i0 = (d0) -> (d0)([[r0]]) to (d0) -> (d0)([[r1]]) step 2 {
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(builder_max_min_for) {
@ -136,10 +137,10 @@ TEST_FUNC(builder_max_min_for) {
auto f = makeFunction("builder_max_min_for", {},
{indexType, indexType, indexType, indexType});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)),
ub1(f->getArgument(2)), ub2(f->getArgument(3));
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)),
ub1(f.getArgument(2)), ub2(f.getArgument(3));
LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)();
ret();
@ -148,7 +149,8 @@ TEST_FUNC(builder_max_min_for) {
// CHECK: affine.for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) {
// CHECK: return
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(builder_blocks) {
@ -157,14 +159,14 @@ TEST_FUNC(builder_blocks) {
using namespace edsc::op;
auto f = makeFunction("builder_blocks");
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
arg4(c1.getType()), r(c1.getType());
BlockHandle b1, b2, functionBlock(&f->front());
BlockHandle b1, b2, functionBlock(&f.front());
BlockBuilder(&b1, {&arg1, &arg2})(
// b2 has not yet been constructed, need to come back later.
// This is a byproduct of non-structured control-flow.
@ -192,7 +194,8 @@ TEST_FUNC(builder_blocks) {
// CHECK-NEXT: br ^bb1(%3, %4 : i32, i32)
// CHECK-NEXT: }
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(builder_blocks_eager) {
@ -201,8 +204,8 @@ TEST_FUNC(builder_blocks_eager) {
using namespace edsc::op;
auto f = makeFunction("builder_blocks_eager");
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
@ -235,7 +238,8 @@ TEST_FUNC(builder_blocks_eager) {
// CHECK-NEXT: br ^bb1(%3, %4 : i32, i32)
// CHECK-NEXT: }
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(builder_cond_branch) {
@ -244,15 +248,15 @@ TEST_FUNC(builder_cond_branch) {
auto f = makeFunction("builder_cond_branch", {},
{IntegerType::get(1, &globalContext())});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
ValueHandle funcArg(f->getArgument(0));
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle funcArg(f.getArgument(0));
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
c42(ValueHandle::create<ConstantIntOp>(42, 32));
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
BlockHandle b1, b2, functionBlock(&f->front());
BlockHandle b1, b2, functionBlock(&f.front());
BlockBuilder(&b1, {&arg1})([&] { ret(); });
BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); });
// Get back to entry block and add a conditional branch
@ -271,7 +275,8 @@ TEST_FUNC(builder_cond_branch) {
// CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0
// CHECK-NEXT: return
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(builder_cond_branch_eager) {
@ -281,9 +286,9 @@ TEST_FUNC(builder_cond_branch_eager) {
auto f = makeFunction("builder_cond_branch_eager", {},
{IntegerType::get(1, &globalContext())});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
ValueHandle funcArg(f->getArgument(0));
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle funcArg(f.getArgument(0));
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
c42(ValueHandle::create<ConstantIntOp>(42, 32));
@ -309,7 +314,8 @@ TEST_FUNC(builder_cond_branch_eager) {
// CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0
// CHECK-NEXT: return
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(builder_helpers) {
@ -321,14 +327,14 @@ TEST_FUNC(builder_helpers) {
auto f =
makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle f7(
ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type));
MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
vC(f->getArgument(2));
IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)),
vC(f.getArgument(2));
IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
IndexHandle i, j, k1, k2, lb0, lb1, lb2, ub0, ub1, ub2;
int64_t step0, step1, step2;
std::tie(lb0, ub0, step0) = vA.range(0);
@ -363,7 +369,8 @@ TEST_FUNC(builder_helpers) {
// CHECK-DAG: [[e:%.*]] = addf [[d]], [[c]] : f32
// CHECK-NEXT: store [[e]], %arg2[%i0, %i1, %i3] : memref<?x?x?xf32>
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(custom_ops) {
@ -373,8 +380,8 @@ TEST_FUNC(custom_ops) {
auto indexType = IndexType::get(&globalContext());
auto f = makeFunction("custom_ops", {}, {indexType, indexType});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
CustomOperation<ValueHandle> MY_CUSTOM_OP("my_custom_op");
CustomOperation<OperationHandle> MY_CUSTOM_OP_0("my_custom_op_0");
CustomOperation<OperationHandle> MY_CUSTOM_OP_2("my_custom_op_2");
@ -382,7 +389,7 @@ TEST_FUNC(custom_ops) {
// clang-format off
ValueHandle vh(indexType), vh20(indexType), vh21(indexType);
OperationHandle ih0, ih2;
IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1));
IndexHandle m, n, M(f.getArgument(0)), N(f.getArgument(1));
IndexHandle ten(index_t(10)), twenty(index_t(20));
LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {});
@ -402,7 +409,8 @@ TEST_FUNC(custom_ops) {
// CHECK: [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index)
// CHECK: {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(insertion_in_block) {
@ -412,8 +420,8 @@ TEST_FUNC(insertion_in_block) {
auto indexType = IndexType::get(&globalContext());
auto f = makeFunction("insertion_in_block", {}, {indexType, indexType});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
BlockHandle b1;
// clang-format off
ValueHandle::create<ConstantIntOp>(0, 32);
@ -427,7 +435,8 @@ TEST_FUNC(insertion_in_block) {
// CHECK: ^bb1: // no predecessors
// CHECK: {{.*}} = constant 1 : i32
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
TEST_FUNC(select_op) {
@ -438,12 +447,12 @@ TEST_FUNC(select_op) {
auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0);
auto f = makeFunction("select_op", {}, {memrefType});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle zero = constant_index(0), one = constant_index(1);
MemRefView vA(f->getArgument(0));
IndexedValue A(f->getArgument(0));
MemRefView vA(f.getArgument(0));
IndexedValue A(f.getArgument(0));
IndexHandle i, j;
LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
// This test exercises IndexedValue::operator Value*.
@ -461,7 +470,8 @@ TEST_FUNC(select_op) {
// CHECK-DAG: {{.*}} = load
// CHECK-NEXT: {{.*}} = select
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
// Inject an EDSC-constructed computation to exercise imperfectly nested 2-d
@ -474,12 +484,11 @@ TEST_FUNC(tile_2d) {
MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0);
auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType});
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle zero = constant_index(0);
MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
vC(f->getArgument(2));
IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2));
IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
// clang-format off
@ -531,7 +540,8 @@ TEST_FUNC(tile_2d) {
// CHECK-NEXT: {{.*}}= addf {{.*}}, {{.*}} : f32
// CHECK-NEXT: store {{.*}}, {{.*}}[%i8, %i9, %i7] : memref<?x?x?xf32>
// clang-format on
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
// Inject an EDSC-constructed computation to exercise 2-d vectorization.
@ -544,16 +554,15 @@ TEST_FUNC(vectorize_2d) {
auto owningF =
makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType});
mlir::Function *f = owningF.release();
mlir::Function f = owningF;
mlir::Module module(&globalContext());
module.getFunctions().push_back(f);
module.push_back(f);
OpBuilder builder(f->getBody());
ScopedContext scope(builder, f->getLoc());
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
ValueHandle zero = constant_index(0);
MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
vC(f->getArgument(2));
IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2));
IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
IndexHandle M(vA.ub(0)), N(vA.ub(1)), P(vA.ub(2));
// clang-format off
@ -580,9 +589,10 @@ TEST_FUNC(vectorize_2d) {
pm.addPass(mlir::createCanonicalizerPass());
SmallVector<int64_t, 2> vectorSizes{4, 4};
pm.addPass(mlir::createVectorizePass(vectorSizes));
auto result = pm.run(f->getModule());
auto result = pm.run(f.getModule());
if (succeeded(result))
f->print(llvm::outs());
f.print(llvm::outs());
f.erase();
}
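The vectorize_2d hunk also shows the ownership change: instead of releasing the function from an owning pointer and appending it to module.getFunctions(), the test copies the value-typed Function into the module with module.push_back and erases it itself once the pass manager has run. A rough sketch of that flow; the caller-provided PassManager, the three-argument Function::create call (assuming its attribute parameters default), and the helper name are assumptions for illustration, not code from the commit.
// Hedged sketch only -- standalone creation, module insertion, and cleanup.
static void runOnValueTypedFunction(mlir::PassManager &pm) {
  auto funcType = FunctionType::get(/*inputs=*/{}, /*results=*/{}, &globalContext());
  // Left as a declaration (no entry block) so the IR stays trivially valid.
  auto f = mlir::Function::create(UnknownLoc::get(&globalContext()), "sketch", funcType);
  mlir::Module module(&globalContext());
  module.push_back(f);                  // the value is handed to the module directly
  auto result = pm.run(f.getModule());  // module reached back through the function
  if (succeeded(result))
    f.print(llvm::outs());
  f.erase();                            // the test owns the cleanup explicitly
}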
int main() {
