Use lambdas for nesting edsc constructs.

Using ArrayRef introduces issues with the order of evaluation between a constructor and
    the arguments of the subsequent calls to the `operator()`.
    As a consequence the order of captures is not well-defined can go wrong with certain compilers (e.g. gcc-6.4).
    This CL fixes the issue by using lambdas in lieu of ArrayRef.

--

PiperOrigin-RevId: 249114775
This commit is contained in:
Nicolas Vasilache 2019-05-20 13:32:35 -07:00 committed by Mehdi Amini
parent 70f85c0bbf
commit fdbbb3c274
18 changed files with 150 additions and 163 deletions

View File

@ -107,8 +107,7 @@ public:
llvm::ArrayRef<mlir::edsc::ValueHandle> indexings); llvm::ArrayRef<mlir::edsc::ValueHandle> indexings);
LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs, LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
llvm::ArrayRef<mlir::Value *> indexings); llvm::ArrayRef<mlir::Value *> indexings);
mlir::edsc::ValueHandle mlir::edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
operator()(llvm::ArrayRef<mlir::edsc::CapturableHandle> stmts);
private: private:
llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops; llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops;

View File

@ -59,7 +59,9 @@ linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder(
indexings.begin(), indexings.end())) {} indexings.begin(), indexings.end())) {}
ValueHandle linalg::common::LoopNestRangeBuilder::operator()( ValueHandle linalg::common::LoopNestRangeBuilder::operator()(
llvm::ArrayRef<CapturableHandle> stmts) { std::function<void(void)> fun) {
if (fun)
fun();
for (auto &lit : llvm::reverse(loops)) { for (auto &lit : llvm::reverse(loops)) {
lit({}); lit({});
} }

View File

@ -112,14 +112,11 @@ void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
ScopedContext scope(builder, op->getLoc()); ScopedContext scope(builder, op->getLoc());
IndexHandle i; IndexHandle i;
using linalg::common::LoopNestRangeBuilder; using linalg::common::LoopNestRangeBuilder;
LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({ LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))(
[&i, &vA, &vB, &vC]() { [&i, &vA, &vB, &vC]() {
ValueHandle sliceA = slice(vA, i, 0); ValueHandle sliceA = slice(vA, i, 0);
ValueHandle sliceC = slice(vC, i, 0); ValueHandle sliceC = slice(vC, i, 0);
dot(sliceA, vB, sliceC); dot(sliceA, vB, sliceC);
/// NestedBuilders expect handles, we thus return an IndexHandle.
return IndexHandle();
}()
}); });
// clang-format on // clang-format on
} }
@ -188,14 +185,11 @@ void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
FuncBuilder builder(op); FuncBuilder builder(op);
ScopedContext scope(builder, op->getLoc()); ScopedContext scope(builder, op->getLoc());
IndexHandle j; IndexHandle j;
LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({ LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))(
[&j, &vA, &vB, &vC]() { [&j, &vA, &vB, &vC]() {
ValueHandle sliceB = slice(vB, j, 1); ValueHandle sliceB = slice(vB, j, 1);
ValueHandle sliceC = slice(vC, j, 1); ValueHandle sliceC = slice(vC, j, 1);
matvec(vA, sliceB, sliceC); matvec(vA, sliceB, sliceC);
/// NestedBuilders expect handles, we thus return an IndexHandle.
return IndexHandle();
}()
}); });
// clang-format on // clang-format on
} }

View File

@ -177,18 +177,15 @@ writeContractionAsLoops(ContractionOp contraction) {
// clang-format off // clang-format off
using linalg::common::LoopNestRangeBuilder; using linalg::common::LoopNestRangeBuilder;
ArrayRef<Value *> ranges(loopRanges); ArrayRef<Value *> ranges(loopRanges);
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({ LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&]{
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({ LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))(
[&contraction, &parallelIvs, &reductionIvs]() { [&contraction, &parallelIvs, &reductionIvs] {
SmallVector<mlir::Value *, 4> parallel( SmallVector<mlir::Value *, 4> parallel(
parallelIvs.begin(), parallelIvs.end()); parallelIvs.begin(), parallelIvs.end());
SmallVector<mlir::Value *, 4> reduction( SmallVector<mlir::Value *, 4> reduction(
reductionIvs.begin(), reductionIvs.end()); reductionIvs.begin(), reductionIvs.end());
contraction.emitScalarImplementation(parallel, reduction); contraction.emitScalarImplementation(parallel, reduction);
/// NestedBuilders expect handles, we thus return an IndexHandle. });
return IndexHandle();
}()
})
}); });
// clang-format on // clang-format on

View File

@ -158,15 +158,12 @@ writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction,
using linalg::common::LoopNestRangeBuilder; using linalg::common::LoopNestRangeBuilder;
auto ranges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction), auto ranges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
getRanges(contraction), tileSizes); getRanges(contraction), tileSizes);
linalg::common::LoopNestRangeBuilder(pivs, ranges)({ linalg::common::LoopNestRangeBuilder(pivs, ranges)(
[&contraction, &tileSizes, &ivs]() { [&contraction, &tileSizes, &ivs]() {
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end()); SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
auto views = makeTiledViews(contraction, ivValues, tileSizes); auto views = makeTiledViews(contraction, ivValues, tileSizes);
ScopedContext::getBuilder()->create<ConcreteOp>( ScopedContext::getBuilder()->create<ConcreteOp>(
ScopedContext::getLocation(), views); ScopedContext::getLocation(), views);
/// NestedBuilders expect handles, we thus return an IndexHandle.
return IndexHandle();
}()
}); });
// clang-format on // clang-format on

View File

@ -109,12 +109,13 @@ public:
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs); IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
IndexHandle i, j, M(vRes.ub(0)); IndexHandle i, j, M(vRes.ub(0));
if (vRes.rank() == 1) { if (vRes.rank() == 1) {
LoopNestBuilder({&i}, {zero}, {M}, {1})({iRes(i) = iLHS(i) + iRHS(i)}); LoopNestBuilder({&i}, {zero}, {M},
{1})([&] { iRes(i) = iLHS(i) + iRHS(i); });
} else { } else {
assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now"); assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
IndexHandle N(vRes.ub(1)); IndexHandle N(vRes.ub(1));
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
{1, 1})({iRes(i, j) = iLHS(i, j) + iRHS(i, j)}); {1, 1})([&] { iRes(i, j) = iLHS(i, j) + iRHS(i, j); });
} }
// Return the newly allocated buffer, with a type.cast to preserve the // Return the newly allocated buffer, with a type.cast to preserve the
@ -156,23 +157,23 @@ public:
ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n")); ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
if (vOp.rank() == 1) { if (vOp.rank() == 1) {
// clang-format off // clang-format off
LoopBuilder(&i, zero, M, 1)({ LoopBuilder(&i, zero, M, 1)([&]{
llvmCall(retTy, llvmCall(retTy,
rewriter.getFunctionAttr(printfFunc), rewriter.getFunctionAttr(printfFunc),
{fmtCst, iOp(i)}) {fmtCst, iOp(i)});
}); });
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol}); llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
// clang-format on // clang-format on
} else { } else {
IndexHandle N(vOp.ub(1)); IndexHandle N(vOp.ub(1));
// clang-format off // clang-format off
LoopBuilder(&i, zero, M, 1)({ LoopBuilder(&i, zero, M, 1)([&]{
LoopBuilder(&j, zero, N, 1)({ LoopBuilder(&j, zero, N, 1)([&]{
llvmCall(retTy, llvmCall(retTy,
rewriter.getFunctionAttr(printfFunc), rewriter.getFunctionAttr(printfFunc),
{fmtCst, iOp(i, j)}) {fmtCst, iOp(i, j)});
}), });
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol}) llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
}); });
// clang-format on // clang-format on
} }
@ -295,8 +296,8 @@ public:
IndexedValue iRes(result), iOperand(operand); IndexedValue iRes(result), iOperand(operand);
IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1)); IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
// clang-format off // clang-format off
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({ LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
iRes(i, j) = iOperand(j, i) iRes(i, j) = iOperand(j, i);
}); });
// clang-format on // clang-format on

View File

@ -95,4 +95,4 @@ The
demonstrates how to construct some simple IR snippets that pass through the demonstrates how to construct some simple IR snippets that pass through the
verifier checks. The example demonstrate how to allocate three memref buffers verifier checks. The example demonstrate how to allocate three memref buffers
from `index` function arguments and use those buffers as backing data structures from `index` function arguments and use those buffers as backing data structures
for views that get passed to for views that get passed to `dot`, `matvec` and `matmul` operations.

View File

@ -53,15 +53,15 @@ structured loop nests.
f13(constant_float(llvm::APFloat(13.0f), f32Type)), f13(constant_float(llvm::APFloat(13.0f), f32Type)),
i7(constant_int(7, 32)), i7(constant_int(7, 32)),
i13(constant_int(13, 32)); i13(constant_int(13, 32));
LoopBuilder(&i, lb, ub, 3)({ LoopBuilder(&i, lb, ub, 3)([&]{
lb * index_t(3) + ub, lb * index_t(3) + ub;
lb + index_t(3), lb + index_t(3);
LoopBuilder(&j, lb, ub, 2)({ LoopBuilder(&j, lb, ub, 2)([&]{
ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)), ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)),
index_t(32)), index_t(32));
((f7 + f13) / f7) % f13 - f7 * f13, ((f7 + f13) / f7) % f13 - f7 * f13;
((i7 + i13) / i7) % i13 - i7 * i13, ((i7 + i13) / i7) % i13 - i7 * i13;
}), });
}); });
``` ```
@ -86,7 +86,8 @@ def AddOp : Op<"x.add">,
auto ivs = IndexHandle::makeIndexHandles(view_A.rank()); auto ivs = IndexHandle::makeIndexHandles(view_A.rank());
auto pivs = IndexHandle::makePIndexHandles(ivs); auto pivs = IndexHandle::makePIndexHandles(ivs);
IndexedValue A(arg_A), B(arg_B), C(arg_C); IndexedValue A(arg_A), B(arg_B), C(arg_C);
LoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())({ LoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())(
[&]{
C(ivs) = A(ivs) + B(ivs) C(ivs) = A(ivs) + B(ivs)
}); });
}]; }];

View File

@ -169,12 +169,9 @@ public:
LoopBuilder &operator=(LoopBuilder &&) = default; LoopBuilder &operator=(LoopBuilder &&) = default;
/// The only purpose of this operator is to serve as a sequence point so that /// The only purpose of this operator is to serve as a sequence point so that
/// the evaluation of `stmts` (which build IR snippets in a scoped fashion) is /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
/// sequenced strictly after the constructor of LoopBuilder. /// scoped within a LoopBuilder.
/// In order to be admissible in a nested ArrayRef<ValueHandle>, operator() ValueHandle operator()(std::function<void(void)> fun = nullptr);
/// returns a ValueHandle::null() that cannot be captured.
// TODO(ntv): when loops return escaping ssa-values, this should be adapted.
ValueHandle operator()(ArrayRef<CapturableHandle> stmts);
}; };
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid /// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
@ -184,15 +181,16 @@ public:
/// Usage: /// Usage:
/// ///
/// ```c++ /// ```c++
/// LoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, 1})({ /// LoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, 1})(
/// [&](){
/// ... /// ...
/// }); /// });
/// ``` /// ```
/// ///
/// ```c++ /// ```c++
/// LoopNestBuilder({&i}, {lb}, {ub}, {1})({ /// LoopNestBuilder({&i}, {lb}, {ub}, {1})([&](){
/// LoopNestBuilder({&j}, {lb}, {ub}, {1})({ /// LoopNestBuilder({&j}, {lb}, {ub}, {1})([&](){
/// LoopNestBuilder({&k}, {lb}, {ub}, {1})({ /// LoopNestBuilder({&k}, {lb}, {ub}, {1})([&](){
/// ... /// ...
/// }), /// }),
/// }), /// }),
@ -203,8 +201,7 @@ public:
LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs, LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps); ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
// TODO(ntv): when loops return escaping ssa-values, this should be adapted. ValueHandle operator()(std::function<void(void)> fun = nullptr);
ValueHandle operator()(ArrayRef<CapturableHandle> stmts);
private: private:
SmallVector<LoopBuilder, 4> loops; SmallVector<LoopBuilder, 4> loops;
@ -235,9 +232,9 @@ public:
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args); BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
/// The only purpose of this operator is to serve as a sequence point so that /// The only purpose of this operator is to serve as a sequence point so that
/// the evaluation of `stmts` (which build IR snippets in a scoped fashion) is /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
/// sequenced strictly after the constructor of BlockBuilder. /// scoped within a BlockBuilder.
void operator()(ArrayRef<CapturableHandle> stmts); void operator()(std::function<void(void)> fun = nullptr);
private: private:
BlockBuilder(BlockBuilder &) = delete; BlockBuilder(BlockBuilder &) = delete;

View File

@ -34,7 +34,7 @@ public:
llvm::ArrayRef<edsc::ValueHandle> ranges); llvm::ArrayRef<edsc::ValueHandle> ranges);
LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs, LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
llvm::ArrayRef<Value *> ranges); llvm::ArrayRef<Value *> ranges);
edsc::ValueHandle operator()(llvm::ArrayRef<edsc::CapturableHandle> stmts); edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
private: private:
llvm::SmallVector<edsc::LoopBuilder, 4> loops; llvm::SmallVector<edsc::LoopBuilder, 4> loops;

View File

@ -184,8 +184,7 @@ mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv,
enter(body, /*prev=*/1); enter(body, /*prev=*/1);
} }
ValueHandle ValueHandle mlir::edsc::LoopBuilder::operator()(std::function<void(void)> fun) {
mlir::edsc::LoopBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
// Call to `exit` must be explicit and asymmetric (cannot happen in the // Call to `exit` must be explicit and asymmetric (cannot happen in the
// destructor) because of ordering wrt comma operator. // destructor) because of ordering wrt comma operator.
/// The particular use case concerns nested blocks: /// The particular use case concerns nested blocks:
@ -204,6 +203,8 @@ mlir::edsc::LoopBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
/// }), /// }),
/// }); /// });
/// ``` /// ```
if (fun)
fun();
exit(); exit();
return ValueHandle::null(); return ValueHandle::null();
} }
@ -222,14 +223,16 @@ mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
} }
ValueHandle ValueHandle
mlir::edsc::LoopNestBuilder::operator()(ArrayRef<CapturableHandle> stmts) { mlir::edsc::LoopNestBuilder::operator()(std::function<void(void)> fun) {
if (fun)
fun();
// Iterate on the calling operator() on all the loops in the nest. // Iterate on the calling operator() on all the loops in the nest.
// The iteration order is from innermost to outermost because enter/exit needs // The iteration order is from innermost to outermost because enter/exit needs
// to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit() // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
// occurs on calling operator()). The asymmetry is required for properly // occurs on calling operator()). The asymmetry is required for properly
// nesting imperfectly nested regions (see LoopBuilder::operator()). // nesting imperfectly nested regions (see LoopBuilder::operator()).
for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) { for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) {
(*lit)({}); (*lit)();
} }
return ValueHandle::null(); return ValueHandle::null();
} }
@ -258,9 +261,11 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
/// Only serves as an ordering point between entering nested block and creating /// Only serves as an ordering point between entering nested block and creating
/// stmts. /// stmts.
void mlir::edsc::BlockBuilder::operator()(ArrayRef<CapturableHandle> stmts) { void mlir::edsc::BlockBuilder::operator()(std::function<void(void)> fun) {
// Call to `exit` must be explicit and asymmetric (cannot happen in the // Call to `exit` must be explicit and asymmetric (cannot happen in the
// destructor) because of ordering wrt comma operator. // destructor) because of ordering wrt comma operator.
if (fun)
fun();
exit(); exit();
} }

View File

@ -52,7 +52,7 @@ OperationHandle mlir::edsc::intrinsics::br(BlockHandle *bh,
ArrayRef<ValueHandle> operands) { ArrayRef<ValueHandle> operands) {
assert(!*bh && "Unexpected already captured BlockHandle"); assert(!*bh && "Unexpected already captured BlockHandle");
enforceEmptyCapturesMatchOperands(captures, operands); enforceEmptyCapturesMatchOperands(captures, operands);
BlockBuilder(bh, captures)({/* no body */}); BlockBuilder(bh, captures)(/* no body */);
SmallVector<Value *, 4> ops(operands.begin(), operands.end()); SmallVector<Value *, 4> ops(operands.begin(), operands.end());
return OperationHandle::create<BranchOp>(bh->getBlock(), ops); return OperationHandle::create<BranchOp>(bh->getBlock(), ops);
} }
@ -77,8 +77,8 @@ OperationHandle mlir::edsc::intrinsics::cond_br(
assert(!*falseBranch && "Unexpected already captured BlockHandle"); assert(!*falseBranch && "Unexpected already captured BlockHandle");
enforceEmptyCapturesMatchOperands(trueCaptures, trueOperands); enforceEmptyCapturesMatchOperands(trueCaptures, trueOperands);
enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands); enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands);
BlockBuilder(trueBranch, trueCaptures)({/* no body */}); BlockBuilder(trueBranch, trueCaptures)(/* no body */);
BlockBuilder(falseBranch, falseCaptures)({/* no body */}); BlockBuilder(falseBranch, falseCaptures)(/* no body */);
SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end()); SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end());
SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end()); SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end());
return OperationHandle::create<CondBranchOp>( return OperationHandle::create<CondBranchOp>(

View File

@ -109,6 +109,10 @@ ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
parser->resolveOperands(sizeInfo, bufferType, result->operands)); parser->resolveOperands(sizeInfo, bufferType, result->operands));
} }
////////////////////////////////////////////////////////////////////////////////
// ForOp.
////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// LoadOp. // LoadOp.
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -632,6 +636,7 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
AffineMap::get(3, 0, {i, j}, {})}; AffineMap::get(3, 0, {i, j}, {})};
llvm_unreachable("Missing loopToOperandRangesMaps for op"); llvm_unreachable("Missing loopToOperandRangesMaps for op");
} }
// Ideally this should all be Tablegen'd but there is no good story for op // Ideally this should all be Tablegen'd but there is no good story for op
// expansion directly in MLIR for now. // expansion directly in MLIR for now.
void mlir::linalg::emitScalarImplementation( void mlir::linalg::emitScalarImplementation(
@ -641,7 +646,7 @@ void mlir::linalg::emitScalarImplementation(
using linalg_store = OperationBuilder<linalg::StoreOp>; using linalg_store = OperationBuilder<linalg::StoreOp>;
using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>; using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
assert(reductionIvs.size() == 1); assert(reductionIvs.size() == 1);
auto innermostLoop = getForInductionVarOwner(reductionIvs.back()); auto innermostLoop = mlir::getForInductionVarOwner(reductionIvs.back());
auto *body = innermostLoop.getBody(); auto *body = innermostLoop.getBody();
using edsc::op::operator+; using edsc::op::operator+;
using edsc::op::operator*; using edsc::op::operator*;

View File

@ -76,18 +76,15 @@ static void emitLinalgOpAsLoops(LinalgOp &linalgOp, FunctionConstants &state) {
// clang-format off // clang-format off
ArrayRef<Value *> ranges(loopRanges); ArrayRef<Value *> ranges(loopRanges);
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({ LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&] {
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({ LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))(
[&linalgOp, &parallelIvs, &reductionIvs]() { [&linalgOp, &parallelIvs, &reductionIvs] {
SmallVector<mlir::Value *, 4> parallel( SmallVector<mlir::Value *, 4> parallel(
parallelIvs.begin(), parallelIvs.end()); parallelIvs.begin(), parallelIvs.end());
SmallVector<mlir::Value *, 4> reduction( SmallVector<mlir::Value *, 4> reduction(
reductionIvs.begin(), reductionIvs.end()); reductionIvs.begin(), reductionIvs.end());
emitScalarImplementation(parallel, reduction, linalgOp); mlir::linalg::emitScalarImplementation(parallel, reduction, linalgOp);
/// NestedBuilders expect handles, we thus return an IndexHandle. });
return IndexHandle();
}()
})
}); });
// clang-format on // clang-format on
} }
@ -101,11 +98,9 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
void LowerLinalgToLoopsPass::runOnFunction() { void LowerLinalgToLoopsPass::runOnFunction() {
auto &f = getFunction(); auto &f = getFunction();
FunctionConstants state(f); FunctionConstants state(f);
f.walk([&state](Operation *op) { f.walk<LinalgOp>([&state](LinalgOp linalgOp) {
if (auto linalgOp = dyn_cast<LinalgOp>(op)) {
emitLinalgOpAsLoops(linalgOp, state); emitLinalgOpAsLoops(linalgOp, state);
op->erase(); linalgOp.getOperation()->erase();
}
}); });
} }

View File

@ -202,7 +202,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
SmallVector<IndexHandle, 4> ivs(loopRanges.size()); SmallVector<IndexHandle, 4> ivs(loopRanges.size());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs); auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
LoopNestRangeBuilder(pivs, loopRanges)({[&op, &tileSizes, &ivs, &state]() { LoopNestRangeBuilder(pivs, loopRanges)([&op, &tileSizes, &ivs, &state] {
auto *b = ScopedContext::getBuilder(); auto *b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation(); auto loc = ScopedContext::getLocation();
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end()); SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
@ -214,7 +214,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
op.create(*b, loc, views); op.create(*b, loc, views);
/// NestedBuilders expect handles, we thus return an IndexHandle. /// NestedBuilders expect handles, we thus return an IndexHandle.
return IndexHandle(); return IndexHandle();
}()}); });
return success(); return success();
} }

View File

@ -59,7 +59,9 @@ mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
ivs, SmallVector<ValueHandle, 4>(ranges.begin(), ranges.end())) {} ivs, SmallVector<ValueHandle, 4>(ranges.begin(), ranges.end())) {}
ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()( ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
ArrayRef<CapturableHandle> stmts) { std::function<void(void)> fun) {
if (fun)
fun();
for (auto &lit : reverse(loops)) { for (auto &lit : reverse(loops)) {
lit({}); lit({});
} }

View File

@ -285,9 +285,9 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
ValueHandle tmp = alloc(tmpMemRefType(transfer)); ValueHandle tmp = alloc(tmpMemRefType(transfer));
IndexedValue local(tmp); IndexedValue local(tmp);
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer)); ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
LoopNestBuilder(pivs, lbs, ubs, steps)({ LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
local(ivs) = remote(clip(transfer, view, ivs)), local(ivs) = remote(clip(transfer, view, ivs));
}); });
ValueHandle vectorValue = load(vec, {constant_index(0)}); ValueHandle vectorValue = load(vec, {constant_index(0)});
(dealloc(tmp)); // vexing parse (dealloc(tmp)); // vexing parse
@ -346,9 +346,9 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
IndexedValue local(tmp); IndexedValue local(tmp);
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer)); ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
store(vectorValue, vec, {constant_index(0)}); store(vectorValue, vec, {constant_index(0)});
LoopNestBuilder(pivs, lbs, ubs, steps)({ LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
remote(clip(transfer, view, ivs)) = local(ivs), remote(clip(transfer, view, ivs)) = local(ivs);
}); });
(dealloc(tmp)); // vexing parse... (dealloc(tmp)); // vexing parse...

View File

@ -70,15 +70,15 @@ TEST_FUNC(builder_dynamic_for_func_args) {
ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type)); ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type));
ValueHandle i7(constant_int(7, 32)); ValueHandle i7(constant_int(7, 32));
ValueHandle i13(constant_int(13, 32)); ValueHandle i13(constant_int(13, 32));
LoopBuilder(&i, lb, ub, 3)({ LoopBuilder(&i, lb, ub, 3)([&] {
lb * index_t(3) + ub, lb *index_t(3) + ub;
lb + index_t(3), lb + index_t(3);
LoopBuilder(&j, lb, ub, 2)({ LoopBuilder(&j, lb, ub, 2)([&] {
ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)), ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)),
index_t(32)), index_t(32));
((f7 + f13) / f7) % f13 - f7 * f13, ((f7 + f13) / f7) % f13 - f7 *f13;
((i7 + i13) / i7) % i13 - i7 * i13, ((i7 + i13) / i7) % i13 - i7 *i13;
}), });
}); });
// clang-format off // clang-format off
@ -117,7 +117,7 @@ TEST_FUNC(builder_dynamic_for) {
ScopedContext scope(builder, f->getLoc()); ScopedContext scope(builder, f->getLoc());
ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)), ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)),
c(f->getArgument(2)), d(f->getArgument(3)); c(f->getArgument(2)), d(f->getArgument(3));
LoopBuilder(&i, a - b, c + d, 2)({}); LoopBuilder(&i, a - b, c + d, 2)();
// clang-format off // clang-format off
// CHECK-LABEL: func @builder_dynamic_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { // CHECK-LABEL: func @builder_dynamic_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
@ -140,7 +140,7 @@ TEST_FUNC(builder_max_min_for) {
ScopedContext scope(builder, f->getLoc()); ScopedContext scope(builder, f->getLoc());
ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)), ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)),
ub1(f->getArgument(2)), ub2(f->getArgument(3)); ub1(f->getArgument(2)), ub2(f->getArgument(3));
LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)({}); LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)();
ret(); ret();
// clang-format off // clang-format off
@ -165,24 +165,20 @@ TEST_FUNC(builder_blocks) {
arg4(c1.getType()), r(c1.getType()); arg4(c1.getType()), r(c1.getType());
BlockHandle b1, b2, functionBlock(&f->front()); BlockHandle b1, b2, functionBlock(&f->front());
BlockBuilder(&b1, {&arg1, &arg2})({ BlockBuilder(&b1, {&arg1, &arg2})(
// b2 has not yet been constructed, need to come back later. // b2 has not yet been constructed, need to come back later.
// This is a byproduct of non-structured control-flow. // This is a byproduct of non-structured control-flow.
}); );
BlockBuilder(&b2, {&arg3, &arg4})({ BlockBuilder(&b2, {&arg3, &arg4})([&] { br(b1, {arg3, arg4}); });
br(b1, {arg3, arg4}),
});
// The insertion point within the toplevel function is now past b2, we will // The insertion point within the toplevel function is now past b2, we will
// need to get back the entry block. // need to get back the entry block.
// This is what happens with unstructured control-flow.. // This is what happens with unstructured control-flow..
BlockBuilder(b1, Append())({ BlockBuilder(b1, Append())([&] {
r = arg1 + arg2, r = arg1 + arg2;
br(b2, {arg1, r}), br(b2, {arg1, r});
}); });
// Get back to entry block and add a branch into b1 // Get back to entry block and add a branch into b1
BlockBuilder(functionBlock, Append())({ BlockBuilder(functionBlock, Append())([&] { br(b1, {c1, c2}); });
br(b1, {c1, c2}),
});
// clang-format off // clang-format off
// CHECK-LABEL: @builder_blocks // CHECK-LABEL: @builder_blocks
@ -218,13 +214,13 @@ TEST_FUNC(builder_blocks_eager) {
// Build a new block for b1 eagerly. // Build a new block for b1 eagerly.
br(&b1, {&arg1, &arg2}, {c1, c2}); br(&b1, {&arg1, &arg2}, {c1, c2});
// Construct a new block b2 explicitly with a branch into b1. // Construct a new block b2 explicitly with a branch into b1.
BlockBuilder(&b2, {&arg3, &arg4})({ BlockBuilder(&b2, {&arg3, &arg4})([&]{
br(b1, {arg3, arg4}), br(b1, {arg3, arg4});
}); });
/// And come back to append into b1 once b2 exists. /// And come back to append into b1 once b2 exists.
BlockBuilder(b1, Append())({ BlockBuilder(b1, Append())([&]{
r = arg1 + arg2, r = arg1 + arg2;
br(b2, {arg1, r}), br(b2, {arg1, r});
}); });
} }
@ -257,15 +253,11 @@ TEST_FUNC(builder_cond_branch) {
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType()); ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
BlockHandle b1, b2, functionBlock(&f->front()); BlockHandle b1, b2, functionBlock(&f->front());
BlockBuilder(&b1, {&arg1})({ BlockBuilder(&b1, {&arg1})([&] { ret(); });
ret(), BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); });
});
BlockBuilder(&b2, {&arg2, &arg3})({
ret(),
});
// Get back to entry block and add a conditional branch // Get back to entry block and add a conditional branch
BlockBuilder(functionBlock, Append())({ BlockBuilder(functionBlock, Append())([&] {
cond_br(funcArg, b1, {c32}, b2, {c64, c42}), cond_br(funcArg, b1, {c32}, b2, {c64, c42});
}); });
// clang-format off // clang-format off
@ -300,11 +292,11 @@ TEST_FUNC(builder_cond_branch_eager) {
// clang-format off // clang-format off
BlockHandle b1, b2; BlockHandle b1, b2;
cond_br(funcArg, &b1, {&arg1}, {c32}, &b2, {&arg2, &arg3}, {c64, c42}); cond_br(funcArg, &b1, {&arg1}, {c32}, &b2, {&arg2, &arg3}, {c64, c42});
BlockBuilder(b1, Append())({ BlockBuilder(b1, Append())([]{
ret(), ret();
}); });
BlockBuilder(b2, Append())({ BlockBuilder(b2, Append())([]{
ret(), ret();
}); });
// CHECK-LABEL: @builder_cond_branch_eager // CHECK-LABEL: @builder_cond_branch_eager
@ -344,13 +336,13 @@ TEST_FUNC(builder_helpers) {
lb2 = vA.lb(2); lb2 = vA.lb(2);
ub2 = vA.ub(2); ub2 = vA.ub(2);
step2 = vA.step(2); step2 = vA.step(2);
LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})({ LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})([&]{
LoopBuilder(&k1, lb2, ub2, step2)({ LoopBuilder(&k1, lb2, ub2, step2)([&]{
C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1), C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1);
}), });
LoopBuilder(&k2, lb2, ub2, step2)({ LoopBuilder(&k2, lb2, ub2, step2)([&]{
C(i, j, k2) += A(i, j, k2) + B(i, j, k2), C(i, j, k2) += A(i, j, k2) + B(i, j, k2);
}), });
}); });
// CHECK-LABEL: @builder_helpers // CHECK-LABEL: @builder_helpers
@ -392,14 +384,14 @@ TEST_FUNC(custom_ops) {
OperationHandle ih0, ih2; OperationHandle ih0, ih2;
IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1)); IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1));
IndexHandle ten(index_t(10)), twenty(index_t(20)); IndexHandle ten(index_t(10)), twenty(index_t(20));
LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})({ LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {}), vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {});
ih0 = MY_CUSTOM_OP_0({m, m + n}, {}), ih0 = MY_CUSTOM_OP_0({m, m + n}, {});
ih2 = MY_CUSTOM_OP_2({m, m + n}, {indexType, indexType}), ih2 = MY_CUSTOM_OP_2({m, m + n}, {indexType, indexType});
// These captures are verbose for now, can improve when used in practice. // These captures are verbose for now, can improve when used in practice.
vh20 = ValueHandle(ih2.getOperation()->getResult(0)), vh20 = ValueHandle(ih2.getOperation()->getResult(0));
vh21 = ValueHandle(ih2.getOperation()->getResult(1)), vh21 = ValueHandle(ih2.getOperation()->getResult(1));
MY_CUSTOM_OP({vh20, vh21}, {indexType}, {}), MY_CUSTOM_OP({vh20, vh21}, {indexType}, {});
}); });
// CHECK-LABEL: @custom_ops // CHECK-LABEL: @custom_ops
@ -425,8 +417,8 @@ TEST_FUNC(insertion_in_block) {
BlockHandle b1; BlockHandle b1;
// clang-format off // clang-format off
ValueHandle::create<ConstantIntOp>(0, 32); ValueHandle::create<ConstantIntOp>(0, 32);
BlockBuilder(&b1, {})({ BlockBuilder(&b1, {})([]{
ValueHandle::create<ConstantIntOp>(1, 32) ValueHandle::create<ConstantIntOp>(1, 32);
}); });
ValueHandle::create<ConstantIntOp>(2, 32); ValueHandle::create<ConstantIntOp>(2, 32);
// CHECK-LABEL: @insertion_in_block // CHECK-LABEL: @insertion_in_block
@ -453,12 +445,12 @@ TEST_FUNC(select_op) {
MemRefView vA(f->getArgument(0)); MemRefView vA(f->getArgument(0));
IndexedValue A(f->getArgument(0)); IndexedValue A(f->getArgument(0));
IndexHandle i, j; IndexHandle i, j;
LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})({ LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
// This test exercises IndexedValue::operator Value*. // This test exercises IndexedValue::operator Value*.
// Without it, one must force conversion to ValueHandle as such: // Without it, one must force conversion to ValueHandle as such:
// edsc::intrinsics::select( // edsc::intrinsics::select(
// i == zero, ValueHandle(A(zero, zero)), ValueHandle(ValueA(i, j))) // i == zero, ValueHandle(A(zero, zero)), ValueHandle(ValueA(i, j)))
edsc::intrinsics::select(i == zero, *A(zero, zero), *A(i, j)) edsc::intrinsics::select(i == zero, *A(zero, zero), *A(i, j));
}); });
// CHECK-LABEL: @select_op // CHECK-LABEL: @select_op
@ -491,13 +483,13 @@ TEST_FUNC(tile_2d) {
IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2)); IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
// clang-format off // clang-format off
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({ LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
LoopNestBuilder(&k1, zero, O, 1)({ LoopNestBuilder(&k1, zero, O, 1)([&]{
C(i, j, k1) = A(i, j, k1) + B(i, j, k1) C(i, j, k1) = A(i, j, k1) + B(i, j, k1);
}), });
LoopNestBuilder(&k2, zero, O, 1)({ LoopNestBuilder(&k2, zero, O, 1)([&]{
C(i, j, k2) = A(i, j, k2) + B(i, j, k2) C(i, j, k2) = A(i, j, k2) + B(i, j, k2);
}), });
}); });
// clang-format on // clang-format on
@ -566,8 +558,8 @@ TEST_FUNC(vectorize_2d) {
// clang-format off // clang-format off
IndexHandle i, j, k; IndexHandle i, j, k;
LoopNestBuilder({&i, &j, &k}, {zero, zero, zero}, {M, N, P}, {1, 1, 1})({ LoopNestBuilder({&i, &j, &k}, {zero, zero, zero}, {M, N, P}, {1, 1, 1})([&]{
C(i, j, k) = A(i, j, k) + B(i, j, k) C(i, j, k) = A(i, j, k) + B(i, j, k);
}); });
ret(); ret();