forked from OSchip/llvm-project
Use lambdas for nesting edsc constructs.
Using ArrayRef introduces issues with the order of evaluation between a constructor and the arguments of the subsequent calls to the `operator()`. As a consequence, the order of captures is not well-defined and can go wrong with certain compilers (e.g. gcc-6.4). This CL fixes the issue by using lambdas in lieu of ArrayRef. -- PiperOrigin-RevId: 249114775
This commit is contained in:
parent
70f85c0bbf
commit
fdbbb3c274
|
@ -107,8 +107,7 @@ public:
|
|||
llvm::ArrayRef<mlir::edsc::ValueHandle> indexings);
|
||||
LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
|
||||
llvm::ArrayRef<mlir::Value *> indexings);
|
||||
mlir::edsc::ValueHandle
|
||||
operator()(llvm::ArrayRef<mlir::edsc::CapturableHandle> stmts);
|
||||
mlir::edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||
|
||||
private:
|
||||
llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops;
|
||||
|
|
|
@ -59,7 +59,9 @@ linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder(
|
|||
indexings.begin(), indexings.end())) {}
|
||||
|
||||
ValueHandle linalg::common::LoopNestRangeBuilder::operator()(
|
||||
llvm::ArrayRef<CapturableHandle> stmts) {
|
||||
std::function<void(void)> fun) {
|
||||
if (fun)
|
||||
fun();
|
||||
for (auto &lit : llvm::reverse(loops)) {
|
||||
lit({});
|
||||
}
|
||||
|
|
|
@ -112,14 +112,11 @@ void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
|
|||
ScopedContext scope(builder, op->getLoc());
|
||||
IndexHandle i;
|
||||
using linalg::common::LoopNestRangeBuilder;
|
||||
LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
|
||||
LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))(
|
||||
[&i, &vA, &vB, &vC]() {
|
||||
ValueHandle sliceA = slice(vA, i, 0);
|
||||
ValueHandle sliceC = slice(vC, i, 0);
|
||||
dot(sliceA, vB, sliceC);
|
||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||
return IndexHandle();
|
||||
}()
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -188,14 +185,11 @@ void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
|
|||
FuncBuilder builder(op);
|
||||
ScopedContext scope(builder, op->getLoc());
|
||||
IndexHandle j;
|
||||
LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
|
||||
LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))(
|
||||
[&j, &vA, &vB, &vC]() {
|
||||
ValueHandle sliceB = slice(vB, j, 1);
|
||||
ValueHandle sliceC = slice(vC, j, 1);
|
||||
matvec(vA, sliceB, sliceC);
|
||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||
return IndexHandle();
|
||||
}()
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
|
|
@ -177,18 +177,15 @@ writeContractionAsLoops(ContractionOp contraction) {
|
|||
// clang-format off
|
||||
using linalg::common::LoopNestRangeBuilder;
|
||||
ArrayRef<Value *> ranges(loopRanges);
|
||||
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({
|
||||
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({
|
||||
[&contraction, &parallelIvs, &reductionIvs]() {
|
||||
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&]{
|
||||
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))(
|
||||
[&contraction, &parallelIvs, &reductionIvs] {
|
||||
SmallVector<mlir::Value *, 4> parallel(
|
||||
parallelIvs.begin(), parallelIvs.end());
|
||||
SmallVector<mlir::Value *, 4> reduction(
|
||||
reductionIvs.begin(), reductionIvs.end());
|
||||
contraction.emitScalarImplementation(parallel, reduction);
|
||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||
return IndexHandle();
|
||||
}()
|
||||
})
|
||||
});
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -158,15 +158,12 @@ writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction,
|
|||
using linalg::common::LoopNestRangeBuilder;
|
||||
auto ranges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
|
||||
getRanges(contraction), tileSizes);
|
||||
linalg::common::LoopNestRangeBuilder(pivs, ranges)({
|
||||
linalg::common::LoopNestRangeBuilder(pivs, ranges)(
|
||||
[&contraction, &tileSizes, &ivs]() {
|
||||
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
|
||||
auto views = makeTiledViews(contraction, ivValues, tileSizes);
|
||||
ScopedContext::getBuilder()->create<ConcreteOp>(
|
||||
ScopedContext::getLocation(), views);
|
||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||
return IndexHandle();
|
||||
}()
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -109,12 +109,13 @@ public:
|
|||
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
|
||||
IndexHandle i, j, M(vRes.ub(0));
|
||||
if (vRes.rank() == 1) {
|
||||
LoopNestBuilder({&i}, {zero}, {M}, {1})({iRes(i) = iLHS(i) + iRHS(i)});
|
||||
LoopNestBuilder({&i}, {zero}, {M},
|
||||
{1})([&] { iRes(i) = iLHS(i) + iRHS(i); });
|
||||
} else {
|
||||
assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
|
||||
IndexHandle N(vRes.ub(1));
|
||||
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
|
||||
{1, 1})({iRes(i, j) = iLHS(i, j) + iRHS(i, j)});
|
||||
{1, 1})([&] { iRes(i, j) = iLHS(i, j) + iRHS(i, j); });
|
||||
}
|
||||
|
||||
// Return the newly allocated buffer, with a type.cast to preserve the
|
||||
|
@ -156,23 +157,23 @@ public:
|
|||
ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
|
||||
if (vOp.rank() == 1) {
|
||||
// clang-format off
|
||||
LoopBuilder(&i, zero, M, 1)({
|
||||
LoopBuilder(&i, zero, M, 1)([&]{
|
||||
llvmCall(retTy,
|
||||
rewriter.getFunctionAttr(printfFunc),
|
||||
{fmtCst, iOp(i)})
|
||||
{fmtCst, iOp(i)});
|
||||
});
|
||||
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
|
||||
// clang-format on
|
||||
} else {
|
||||
IndexHandle N(vOp.ub(1));
|
||||
// clang-format off
|
||||
LoopBuilder(&i, zero, M, 1)({
|
||||
LoopBuilder(&j, zero, N, 1)({
|
||||
LoopBuilder(&i, zero, M, 1)([&]{
|
||||
LoopBuilder(&j, zero, N, 1)([&]{
|
||||
llvmCall(retTy,
|
||||
rewriter.getFunctionAttr(printfFunc),
|
||||
{fmtCst, iOp(i, j)})
|
||||
}),
|
||||
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol})
|
||||
{fmtCst, iOp(i, j)});
|
||||
});
|
||||
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -295,8 +296,8 @@ public:
|
|||
IndexedValue iRes(result), iOperand(operand);
|
||||
IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
|
||||
// clang-format off
|
||||
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({
|
||||
iRes(i, j) = iOperand(j, i)
|
||||
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
|
||||
iRes(i, j) = iOperand(j, i);
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -95,4 +95,4 @@ The
|
|||
demonstrates how to construct some simple IR snippets that pass through the
|
||||
verifier checks. The example demonstrates how to allocate three memref buffers
|
||||
from `index` function arguments and use those buffers as backing data structures
|
||||
for views that get passed to
|
||||
for views that get passed to `dot`, `matvec` and `matmul` operations.
|
||||
|
|
|
@ -53,15 +53,15 @@ structured loop nests.
|
|||
f13(constant_float(llvm::APFloat(13.0f), f32Type)),
|
||||
i7(constant_int(7, 32)),
|
||||
i13(constant_int(13, 32));
|
||||
LoopBuilder(&i, lb, ub, 3)({
|
||||
lb * index_t(3) + ub,
|
||||
lb + index_t(3),
|
||||
LoopBuilder(&j, lb, ub, 2)({
|
||||
LoopBuilder(&i, lb, ub, 3)([&]{
|
||||
lb * index_t(3) + ub;
|
||||
lb + index_t(3);
|
||||
LoopBuilder(&j, lb, ub, 2)([&]{
|
||||
ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)),
|
||||
index_t(32)),
|
||||
((f7 + f13) / f7) % f13 - f7 * f13,
|
||||
((i7 + i13) / i7) % i13 - i7 * i13,
|
||||
}),
|
||||
index_t(32));
|
||||
((f7 + f13) / f7) % f13 - f7 * f13;
|
||||
((i7 + i13) / i7) % i13 - i7 * i13;
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
|
@ -86,7 +86,8 @@ def AddOp : Op<"x.add">,
|
|||
auto ivs = IndexHandle::makeIndexHandles(view_A.rank());
|
||||
auto pivs = IndexHandle::makePIndexHandles(ivs);
|
||||
IndexedValue A(arg_A), B(arg_B), C(arg_C);
|
||||
LoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())({
|
||||
LoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())(
|
||||
[&]{
|
||||
C(ivs) = A(ivs) + B(ivs)
|
||||
});
|
||||
}];
|
||||
|
|
|
@ -169,12 +169,9 @@ public:
|
|||
LoopBuilder &operator=(LoopBuilder &&) = default;
|
||||
|
||||
/// The only purpose of this operator is to serve as a sequence point so that
|
||||
/// the evaluation of `stmts` (which build IR snippets in a scoped fashion) is
|
||||
/// sequenced strictly after the constructor of LoopBuilder.
|
||||
/// In order to be admissible in a nested ArrayRef<ValueHandle>, operator()
|
||||
/// returns a ValueHandle::null() that cannot be captured.
|
||||
// TODO(ntv): when loops return escaping ssa-values, this should be adapted.
|
||||
ValueHandle operator()(ArrayRef<CapturableHandle> stmts);
|
||||
/// the evaluation of `fun` (which builds IR snippets in a scoped fashion) is
|
||||
/// scoped within a LoopBuilder.
|
||||
ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||
};
|
||||
|
||||
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
|
||||
|
@ -184,15 +181,16 @@ public:
|
|||
/// Usage:
|
||||
///
|
||||
/// ```c++
|
||||
/// LoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, 1})({
|
||||
/// LoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, 1})(
|
||||
/// [&](){
|
||||
/// ...
|
||||
/// });
|
||||
/// ```
|
||||
///
|
||||
/// ```c++
|
||||
/// LoopNestBuilder({&i}, {lb}, {ub}, {1})({
|
||||
/// LoopNestBuilder({&j}, {lb}, {ub}, {1})({
|
||||
/// LoopNestBuilder({&k}, {lb}, {ub}, {1})({
|
||||
/// LoopNestBuilder({&i}, {lb}, {ub}, {1})([&](){
|
||||
/// LoopNestBuilder({&j}, {lb}, {ub}, {1})([&](){
|
||||
/// LoopNestBuilder({&k}, {lb}, {ub}, {1})([&](){
|
||||
/// ...
|
||||
/// }),
|
||||
/// }),
|
||||
|
@ -203,8 +201,7 @@ public:
|
|||
LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
||||
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
|
||||
|
||||
// TODO(ntv): when loops return escaping ssa-values, this should be adapted.
|
||||
ValueHandle operator()(ArrayRef<CapturableHandle> stmts);
|
||||
ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||
|
||||
private:
|
||||
SmallVector<LoopBuilder, 4> loops;
|
||||
|
@ -235,9 +232,9 @@ public:
|
|||
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
|
||||
|
||||
/// The only purpose of this operator is to serve as a sequence point so that
|
||||
/// the evaluation of `stmts` (which build IR snippets in a scoped fashion) is
|
||||
/// sequenced strictly after the constructor of BlockBuilder.
|
||||
void operator()(ArrayRef<CapturableHandle> stmts);
|
||||
/// the evaluation of `fun` (which builds IR snippets in a scoped fashion) is
|
||||
/// scoped within a BlockBuilder.
|
||||
void operator()(std::function<void(void)> fun = nullptr);
|
||||
|
||||
private:
|
||||
BlockBuilder(BlockBuilder &) = delete;
|
||||
|
|
|
@ -34,7 +34,7 @@ public:
|
|||
llvm::ArrayRef<edsc::ValueHandle> ranges);
|
||||
LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
|
||||
llvm::ArrayRef<Value *> ranges);
|
||||
edsc::ValueHandle operator()(llvm::ArrayRef<edsc::CapturableHandle> stmts);
|
||||
edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||
|
||||
private:
|
||||
llvm::SmallVector<edsc::LoopBuilder, 4> loops;
|
||||
|
|
|
@ -184,8 +184,7 @@ mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv,
|
|||
enter(body, /*prev=*/1);
|
||||
}
|
||||
|
||||
ValueHandle
|
||||
mlir::edsc::LoopBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
|
||||
ValueHandle mlir::edsc::LoopBuilder::operator()(std::function<void(void)> fun) {
|
||||
// Call to `exit` must be explicit and asymmetric (cannot happen in the
|
||||
// destructor) because of ordering wrt comma operator.
|
||||
/// The particular use case concerns nested blocks:
|
||||
|
@ -204,6 +203,8 @@ mlir::edsc::LoopBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
|
|||
/// }),
|
||||
/// });
|
||||
/// ```
|
||||
if (fun)
|
||||
fun();
|
||||
exit();
|
||||
return ValueHandle::null();
|
||||
}
|
||||
|
@ -222,14 +223,16 @@ mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
|
|||
}
|
||||
|
||||
ValueHandle
|
||||
mlir::edsc::LoopNestBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
|
||||
mlir::edsc::LoopNestBuilder::operator()(std::function<void(void)> fun) {
|
||||
if (fun)
|
||||
fun();
|
||||
// Iterate on the calling operator() on all the loops in the nest.
|
||||
// The iteration order is from innermost to outermost because enter/exit needs
|
||||
// to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
|
||||
// occurs on calling operator()). The asymmetry is required for properly
|
||||
// nesting imperfectly nested regions (see LoopBuilder::operator()).
|
||||
for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) {
|
||||
(*lit)({});
|
||||
(*lit)();
|
||||
}
|
||||
return ValueHandle::null();
|
||||
}
|
||||
|
@ -258,9 +261,11 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
|
|||
|
||||
/// Only serves as an ordering point between entering nested block and creating
|
||||
/// stmts.
|
||||
void mlir::edsc::BlockBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
|
||||
void mlir::edsc::BlockBuilder::operator()(std::function<void(void)> fun) {
|
||||
// Call to `exit` must be explicit and asymmetric (cannot happen in the
|
||||
// destructor) because of ordering wrt comma operator.
|
||||
if (fun)
|
||||
fun();
|
||||
exit();
|
||||
}
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ OperationHandle mlir::edsc::intrinsics::br(BlockHandle *bh,
|
|||
ArrayRef<ValueHandle> operands) {
|
||||
assert(!*bh && "Unexpected already captured BlockHandle");
|
||||
enforceEmptyCapturesMatchOperands(captures, operands);
|
||||
BlockBuilder(bh, captures)({/* no body */});
|
||||
BlockBuilder(bh, captures)(/* no body */);
|
||||
SmallVector<Value *, 4> ops(operands.begin(), operands.end());
|
||||
return OperationHandle::create<BranchOp>(bh->getBlock(), ops);
|
||||
}
|
||||
|
@ -77,8 +77,8 @@ OperationHandle mlir::edsc::intrinsics::cond_br(
|
|||
assert(!*falseBranch && "Unexpected already captured BlockHandle");
|
||||
enforceEmptyCapturesMatchOperands(trueCaptures, trueOperands);
|
||||
enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands);
|
||||
BlockBuilder(trueBranch, trueCaptures)({/* no body */});
|
||||
BlockBuilder(falseBranch, falseCaptures)({/* no body */});
|
||||
BlockBuilder(trueBranch, trueCaptures)(/* no body */);
|
||||
BlockBuilder(falseBranch, falseCaptures)(/* no body */);
|
||||
SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end());
|
||||
SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end());
|
||||
return OperationHandle::create<CondBranchOp>(
|
||||
|
|
|
@ -109,6 +109,10 @@ ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
|
|||
parser->resolveOperands(sizeInfo, bufferType, result->operands));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// ForOp.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// LoadOp.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -632,6 +636,7 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
|
|||
AffineMap::get(3, 0, {i, j}, {})};
|
||||
llvm_unreachable("Missing loopToOperandRangesMaps for op");
|
||||
}
|
||||
|
||||
// Ideally this should all be Tablegen'd but there is no good story for op
|
||||
// expansion directly in MLIR for now.
|
||||
void mlir::linalg::emitScalarImplementation(
|
||||
|
@ -641,7 +646,7 @@ void mlir::linalg::emitScalarImplementation(
|
|||
using linalg_store = OperationBuilder<linalg::StoreOp>;
|
||||
using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
|
||||
assert(reductionIvs.size() == 1);
|
||||
auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
|
||||
auto innermostLoop = mlir::getForInductionVarOwner(reductionIvs.back());
|
||||
auto *body = innermostLoop.getBody();
|
||||
using edsc::op::operator+;
|
||||
using edsc::op::operator*;
|
||||
|
|
|
@ -76,18 +76,15 @@ static void emitLinalgOpAsLoops(LinalgOp &linalgOp, FunctionConstants &state) {
|
|||
|
||||
// clang-format off
|
||||
ArrayRef<Value *> ranges(loopRanges);
|
||||
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({
|
||||
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({
|
||||
[&linalgOp, &parallelIvs, &reductionIvs]() {
|
||||
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&] {
|
||||
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))(
|
||||
[&linalgOp, &parallelIvs, &reductionIvs] {
|
||||
SmallVector<mlir::Value *, 4> parallel(
|
||||
parallelIvs.begin(), parallelIvs.end());
|
||||
SmallVector<mlir::Value *, 4> reduction(
|
||||
reductionIvs.begin(), reductionIvs.end());
|
||||
emitScalarImplementation(parallel, reduction, linalgOp);
|
||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||
return IndexHandle();
|
||||
}()
|
||||
})
|
||||
mlir::linalg::emitScalarImplementation(parallel, reduction, linalgOp);
|
||||
});
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -101,11 +98,9 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
|
|||
void LowerLinalgToLoopsPass::runOnFunction() {
|
||||
auto &f = getFunction();
|
||||
FunctionConstants state(f);
|
||||
f.walk([&state](Operation *op) {
|
||||
if (auto linalgOp = dyn_cast<LinalgOp>(op)) {
|
||||
f.walk<LinalgOp>([&state](LinalgOp linalgOp) {
|
||||
emitLinalgOpAsLoops(linalgOp, state);
|
||||
op->erase();
|
||||
}
|
||||
linalgOp.getOperation()->erase();
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -202,7 +202,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
|
|||
|
||||
SmallVector<IndexHandle, 4> ivs(loopRanges.size());
|
||||
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
|
||||
LoopNestRangeBuilder(pivs, loopRanges)({[&op, &tileSizes, &ivs, &state]() {
|
||||
LoopNestRangeBuilder(pivs, loopRanges)([&op, &tileSizes, &ivs, &state] {
|
||||
auto *b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
|
||||
|
@ -214,7 +214,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
|
|||
op.create(*b, loc, views);
|
||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||
return IndexHandle();
|
||||
}()});
|
||||
});
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -59,7 +59,9 @@ mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
|
|||
ivs, SmallVector<ValueHandle, 4>(ranges.begin(), ranges.end())) {}
|
||||
|
||||
ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
|
||||
ArrayRef<CapturableHandle> stmts) {
|
||||
std::function<void(void)> fun) {
|
||||
if (fun)
|
||||
fun();
|
||||
for (auto &lit : reverse(loops)) {
|
||||
lit({});
|
||||
}
|
||||
|
|
|
@ -285,9 +285,9 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
|
|||
ValueHandle tmp = alloc(tmpMemRefType(transfer));
|
||||
IndexedValue local(tmp);
|
||||
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
|
||||
LoopNestBuilder(pivs, lbs, ubs, steps)({
|
||||
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
|
||||
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
||||
local(ivs) = remote(clip(transfer, view, ivs)),
|
||||
local(ivs) = remote(clip(transfer, view, ivs));
|
||||
});
|
||||
ValueHandle vectorValue = load(vec, {constant_index(0)});
|
||||
(dealloc(tmp)); // vexing parse
|
||||
|
@ -346,9 +346,9 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
|
|||
IndexedValue local(tmp);
|
||||
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
|
||||
store(vectorValue, vec, {constant_index(0)});
|
||||
LoopNestBuilder(pivs, lbs, ubs, steps)({
|
||||
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
|
||||
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
||||
remote(clip(transfer, view, ivs)) = local(ivs),
|
||||
remote(clip(transfer, view, ivs)) = local(ivs);
|
||||
});
|
||||
(dealloc(tmp)); // vexing parse...
|
||||
|
||||
|
|
|
@ -70,15 +70,15 @@ TEST_FUNC(builder_dynamic_for_func_args) {
|
|||
ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type));
|
||||
ValueHandle i7(constant_int(7, 32));
|
||||
ValueHandle i13(constant_int(13, 32));
|
||||
LoopBuilder(&i, lb, ub, 3)({
|
||||
lb * index_t(3) + ub,
|
||||
lb + index_t(3),
|
||||
LoopBuilder(&j, lb, ub, 2)({
|
||||
LoopBuilder(&i, lb, ub, 3)([&] {
|
||||
lb *index_t(3) + ub;
|
||||
lb + index_t(3);
|
||||
LoopBuilder(&j, lb, ub, 2)([&] {
|
||||
ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)),
|
||||
index_t(32)),
|
||||
((f7 + f13) / f7) % f13 - f7 * f13,
|
||||
((i7 + i13) / i7) % i13 - i7 * i13,
|
||||
}),
|
||||
index_t(32));
|
||||
((f7 + f13) / f7) % f13 - f7 *f13;
|
||||
((i7 + i13) / i7) % i13 - i7 *i13;
|
||||
});
|
||||
});
|
||||
|
||||
// clang-format off
|
||||
|
@ -117,7 +117,7 @@ TEST_FUNC(builder_dynamic_for) {
|
|||
ScopedContext scope(builder, f->getLoc());
|
||||
ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)),
|
||||
c(f->getArgument(2)), d(f->getArgument(3));
|
||||
LoopBuilder(&i, a - b, c + d, 2)({});
|
||||
LoopBuilder(&i, a - b, c + d, 2)();
|
||||
|
||||
// clang-format off
|
||||
// CHECK-LABEL: func @builder_dynamic_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
|
||||
|
@ -140,7 +140,7 @@ TEST_FUNC(builder_max_min_for) {
|
|||
ScopedContext scope(builder, f->getLoc());
|
||||
ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)),
|
||||
ub1(f->getArgument(2)), ub2(f->getArgument(3));
|
||||
LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)({});
|
||||
LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)();
|
||||
ret();
|
||||
|
||||
// clang-format off
|
||||
|
@ -165,24 +165,20 @@ TEST_FUNC(builder_blocks) {
|
|||
arg4(c1.getType()), r(c1.getType());
|
||||
|
||||
BlockHandle b1, b2, functionBlock(&f->front());
|
||||
BlockBuilder(&b1, {&arg1, &arg2})({
|
||||
BlockBuilder(&b1, {&arg1, &arg2})(
|
||||
// b2 has not yet been constructed, need to come back later.
|
||||
// This is a byproduct of non-structured control-flow.
|
||||
});
|
||||
BlockBuilder(&b2, {&arg3, &arg4})({
|
||||
br(b1, {arg3, arg4}),
|
||||
});
|
||||
);
|
||||
BlockBuilder(&b2, {&arg3, &arg4})([&] { br(b1, {arg3, arg4}); });
|
||||
// The insertion point within the toplevel function is now past b2, we will
|
||||
// need to get back the entry block.
|
||||
// This is what happens with unstructured control-flow..
|
||||
BlockBuilder(b1, Append())({
|
||||
r = arg1 + arg2,
|
||||
br(b2, {arg1, r}),
|
||||
BlockBuilder(b1, Append())([&] {
|
||||
r = arg1 + arg2;
|
||||
br(b2, {arg1, r});
|
||||
});
|
||||
// Get back to entry block and add a branch into b1
|
||||
BlockBuilder(functionBlock, Append())({
|
||||
br(b1, {c1, c2}),
|
||||
});
|
||||
BlockBuilder(functionBlock, Append())([&] { br(b1, {c1, c2}); });
|
||||
|
||||
// clang-format off
|
||||
// CHECK-LABEL: @builder_blocks
|
||||
|
@ -218,13 +214,13 @@ TEST_FUNC(builder_blocks_eager) {
|
|||
// Build a new block for b1 eagerly.
|
||||
br(&b1, {&arg1, &arg2}, {c1, c2});
|
||||
// Construct a new block b2 explicitly with a branch into b1.
|
||||
BlockBuilder(&b2, {&arg3, &arg4})({
|
||||
br(b1, {arg3, arg4}),
|
||||
BlockBuilder(&b2, {&arg3, &arg4})([&]{
|
||||
br(b1, {arg3, arg4});
|
||||
});
|
||||
/// And come back to append into b1 once b2 exists.
|
||||
BlockBuilder(b1, Append())({
|
||||
r = arg1 + arg2,
|
||||
br(b2, {arg1, r}),
|
||||
BlockBuilder(b1, Append())([&]{
|
||||
r = arg1 + arg2;
|
||||
br(b2, {arg1, r});
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -257,15 +253,11 @@ TEST_FUNC(builder_cond_branch) {
|
|||
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
|
||||
|
||||
BlockHandle b1, b2, functionBlock(&f->front());
|
||||
BlockBuilder(&b1, {&arg1})({
|
||||
ret(),
|
||||
});
|
||||
BlockBuilder(&b2, {&arg2, &arg3})({
|
||||
ret(),
|
||||
});
|
||||
BlockBuilder(&b1, {&arg1})([&] { ret(); });
|
||||
BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); });
|
||||
// Get back to entry block and add a conditional branch
|
||||
BlockBuilder(functionBlock, Append())({
|
||||
cond_br(funcArg, b1, {c32}, b2, {c64, c42}),
|
||||
BlockBuilder(functionBlock, Append())([&] {
|
||||
cond_br(funcArg, b1, {c32}, b2, {c64, c42});
|
||||
});
|
||||
|
||||
// clang-format off
|
||||
|
@ -300,11 +292,11 @@ TEST_FUNC(builder_cond_branch_eager) {
|
|||
// clang-format off
|
||||
BlockHandle b1, b2;
|
||||
cond_br(funcArg, &b1, {&arg1}, {c32}, &b2, {&arg2, &arg3}, {c64, c42});
|
||||
BlockBuilder(b1, Append())({
|
||||
ret(),
|
||||
BlockBuilder(b1, Append())([]{
|
||||
ret();
|
||||
});
|
||||
BlockBuilder(b2, Append())({
|
||||
ret(),
|
||||
BlockBuilder(b2, Append())([]{
|
||||
ret();
|
||||
});
|
||||
|
||||
// CHECK-LABEL: @builder_cond_branch_eager
|
||||
|
@ -344,13 +336,13 @@ TEST_FUNC(builder_helpers) {
|
|||
lb2 = vA.lb(2);
|
||||
ub2 = vA.ub(2);
|
||||
step2 = vA.step(2);
|
||||
LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})({
|
||||
LoopBuilder(&k1, lb2, ub2, step2)({
|
||||
C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1),
|
||||
}),
|
||||
LoopBuilder(&k2, lb2, ub2, step2)({
|
||||
C(i, j, k2) += A(i, j, k2) + B(i, j, k2),
|
||||
}),
|
||||
LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})([&]{
|
||||
LoopBuilder(&k1, lb2, ub2, step2)([&]{
|
||||
C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1);
|
||||
});
|
||||
LoopBuilder(&k2, lb2, ub2, step2)([&]{
|
||||
C(i, j, k2) += A(i, j, k2) + B(i, j, k2);
|
||||
});
|
||||
});
|
||||
|
||||
// CHECK-LABEL: @builder_helpers
|
||||
|
@ -392,14 +384,14 @@ TEST_FUNC(custom_ops) {
|
|||
OperationHandle ih0, ih2;
|
||||
IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1));
|
||||
IndexHandle ten(index_t(10)), twenty(index_t(20));
|
||||
LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})({
|
||||
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {}),
|
||||
ih0 = MY_CUSTOM_OP_0({m, m + n}, {}),
|
||||
ih2 = MY_CUSTOM_OP_2({m, m + n}, {indexType, indexType}),
|
||||
LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{
|
||||
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {});
|
||||
ih0 = MY_CUSTOM_OP_0({m, m + n}, {});
|
||||
ih2 = MY_CUSTOM_OP_2({m, m + n}, {indexType, indexType});
|
||||
// These captures are verbose for now, can improve when used in practice.
|
||||
vh20 = ValueHandle(ih2.getOperation()->getResult(0)),
|
||||
vh21 = ValueHandle(ih2.getOperation()->getResult(1)),
|
||||
MY_CUSTOM_OP({vh20, vh21}, {indexType}, {}),
|
||||
vh20 = ValueHandle(ih2.getOperation()->getResult(0));
|
||||
vh21 = ValueHandle(ih2.getOperation()->getResult(1));
|
||||
MY_CUSTOM_OP({vh20, vh21}, {indexType}, {});
|
||||
});
|
||||
|
||||
// CHECK-LABEL: @custom_ops
|
||||
|
@ -425,8 +417,8 @@ TEST_FUNC(insertion_in_block) {
|
|||
BlockHandle b1;
|
||||
// clang-format off
|
||||
ValueHandle::create<ConstantIntOp>(0, 32);
|
||||
BlockBuilder(&b1, {})({
|
||||
ValueHandle::create<ConstantIntOp>(1, 32)
|
||||
BlockBuilder(&b1, {})([]{
|
||||
ValueHandle::create<ConstantIntOp>(1, 32);
|
||||
});
|
||||
ValueHandle::create<ConstantIntOp>(2, 32);
|
||||
// CHECK-LABEL: @insertion_in_block
|
||||
|
@ -453,12 +445,12 @@ TEST_FUNC(select_op) {
|
|||
MemRefView vA(f->getArgument(0));
|
||||
IndexedValue A(f->getArgument(0));
|
||||
IndexHandle i, j;
|
||||
LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})({
|
||||
LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
|
||||
// This test exercises IndexedValue::operator Value*.
|
||||
// Without it, one must force conversion to ValueHandle as such:
|
||||
// edsc::intrinsics::select(
|
||||
// i == zero, ValueHandle(A(zero, zero)), ValueHandle(ValueA(i, j)))
|
||||
edsc::intrinsics::select(i == zero, *A(zero, zero), *A(i, j))
|
||||
edsc::intrinsics::select(i == zero, *A(zero, zero), *A(i, j));
|
||||
});
|
||||
|
||||
// CHECK-LABEL: @select_op
|
||||
|
@ -491,13 +483,13 @@ TEST_FUNC(tile_2d) {
|
|||
IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
|
||||
|
||||
// clang-format off
|
||||
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({
|
||||
LoopNestBuilder(&k1, zero, O, 1)({
|
||||
C(i, j, k1) = A(i, j, k1) + B(i, j, k1)
|
||||
}),
|
||||
LoopNestBuilder(&k2, zero, O, 1)({
|
||||
C(i, j, k2) = A(i, j, k2) + B(i, j, k2)
|
||||
}),
|
||||
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
|
||||
LoopNestBuilder(&k1, zero, O, 1)([&]{
|
||||
C(i, j, k1) = A(i, j, k1) + B(i, j, k1);
|
||||
});
|
||||
LoopNestBuilder(&k2, zero, O, 1)([&]{
|
||||
C(i, j, k2) = A(i, j, k2) + B(i, j, k2);
|
||||
});
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
|
@ -566,8 +558,8 @@ TEST_FUNC(vectorize_2d) {
|
|||
|
||||
// clang-format off
|
||||
IndexHandle i, j, k;
|
||||
LoopNestBuilder({&i, &j, &k}, {zero, zero, zero}, {M, N, P}, {1, 1, 1})({
|
||||
C(i, j, k) = A(i, j, k) + B(i, j, k)
|
||||
LoopNestBuilder({&i, &j, &k}, {zero, zero, zero}, {M, N, P}, {1, 1, 1})([&]{
|
||||
C(i, j, k) = A(i, j, k) + B(i, j, k);
|
||||
});
|
||||
ret();
|
||||
|
||||
|
|
Loading…
Reference in New Issue