forked from OSchip/llvm-project
Use lambdas for nesting edsc constructs.
Using ArrayRef introduces issues with the order of evaluation between a constructor and the arguments of the subsequent calls to the `operator()`. As a consequence the order of captures is not well-defined can go wrong with certain compilers (e.g. gcc-6.4). This CL fixes the issue by using lambdas in lieu of ArrayRef. -- PiperOrigin-RevId: 249114775
This commit is contained in:
parent
70f85c0bbf
commit
fdbbb3c274
|
@ -107,8 +107,7 @@ public:
|
||||||
llvm::ArrayRef<mlir::edsc::ValueHandle> indexings);
|
llvm::ArrayRef<mlir::edsc::ValueHandle> indexings);
|
||||||
LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
|
LoopNestRangeBuilder(llvm::ArrayRef<mlir::edsc::ValueHandle *> ivs,
|
||||||
llvm::ArrayRef<mlir::Value *> indexings);
|
llvm::ArrayRef<mlir::Value *> indexings);
|
||||||
mlir::edsc::ValueHandle
|
mlir::edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||||
operator()(llvm::ArrayRef<mlir::edsc::CapturableHandle> stmts);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops;
|
llvm::SmallVector<mlir::edsc::LoopBuilder, 4> loops;
|
||||||
|
|
|
@ -59,7 +59,9 @@ linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder(
|
||||||
indexings.begin(), indexings.end())) {}
|
indexings.begin(), indexings.end())) {}
|
||||||
|
|
||||||
ValueHandle linalg::common::LoopNestRangeBuilder::operator()(
|
ValueHandle linalg::common::LoopNestRangeBuilder::operator()(
|
||||||
llvm::ArrayRef<CapturableHandle> stmts) {
|
std::function<void(void)> fun) {
|
||||||
|
if (fun)
|
||||||
|
fun();
|
||||||
for (auto &lit : llvm::reverse(loops)) {
|
for (auto &lit : llvm::reverse(loops)) {
|
||||||
lit({});
|
lit({});
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,14 +112,11 @@ void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
|
||||||
ScopedContext scope(builder, op->getLoc());
|
ScopedContext scope(builder, op->getLoc());
|
||||||
IndexHandle i;
|
IndexHandle i;
|
||||||
using linalg::common::LoopNestRangeBuilder;
|
using linalg::common::LoopNestRangeBuilder;
|
||||||
LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
|
LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))(
|
||||||
[&i, &vA, &vB, &vC]() {
|
[&i, &vA, &vB, &vC]() {
|
||||||
ValueHandle sliceA = slice(vA, i, 0);
|
ValueHandle sliceA = slice(vA, i, 0);
|
||||||
ValueHandle sliceC = slice(vC, i, 0);
|
ValueHandle sliceC = slice(vC, i, 0);
|
||||||
dot(sliceA, vB, sliceC);
|
dot(sliceA, vB, sliceC);
|
||||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
|
||||||
return IndexHandle();
|
|
||||||
}()
|
|
||||||
});
|
});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
@ -188,14 +185,11 @@ void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
|
||||||
FuncBuilder builder(op);
|
FuncBuilder builder(op);
|
||||||
ScopedContext scope(builder, op->getLoc());
|
ScopedContext scope(builder, op->getLoc());
|
||||||
IndexHandle j;
|
IndexHandle j;
|
||||||
LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
|
LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))(
|
||||||
[&j, &vA, &vB, &vC]() {
|
[&j, &vA, &vB, &vC]() {
|
||||||
ValueHandle sliceB = slice(vB, j, 1);
|
ValueHandle sliceB = slice(vB, j, 1);
|
||||||
ValueHandle sliceC = slice(vC, j, 1);
|
ValueHandle sliceC = slice(vC, j, 1);
|
||||||
matvec(vA, sliceB, sliceC);
|
matvec(vA, sliceB, sliceC);
|
||||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
|
||||||
return IndexHandle();
|
|
||||||
}()
|
|
||||||
});
|
});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
|
@ -177,18 +177,15 @@ writeContractionAsLoops(ContractionOp contraction) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
using linalg::common::LoopNestRangeBuilder;
|
using linalg::common::LoopNestRangeBuilder;
|
||||||
ArrayRef<Value *> ranges(loopRanges);
|
ArrayRef<Value *> ranges(loopRanges);
|
||||||
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({
|
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&]{
|
||||||
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({
|
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))(
|
||||||
[&contraction, ¶llelIvs, &reductionIvs]() {
|
[&contraction, ¶llelIvs, &reductionIvs] {
|
||||||
SmallVector<mlir::Value *, 4> parallel(
|
SmallVector<mlir::Value *, 4> parallel(
|
||||||
parallelIvs.begin(), parallelIvs.end());
|
parallelIvs.begin(), parallelIvs.end());
|
||||||
SmallVector<mlir::Value *, 4> reduction(
|
SmallVector<mlir::Value *, 4> reduction(
|
||||||
reductionIvs.begin(), reductionIvs.end());
|
reductionIvs.begin(), reductionIvs.end());
|
||||||
contraction.emitScalarImplementation(parallel, reduction);
|
contraction.emitScalarImplementation(parallel, reduction);
|
||||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
});
|
||||||
return IndexHandle();
|
|
||||||
}()
|
|
||||||
})
|
|
||||||
});
|
});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
|
|
@ -158,15 +158,12 @@ writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction,
|
||||||
using linalg::common::LoopNestRangeBuilder;
|
using linalg::common::LoopNestRangeBuilder;
|
||||||
auto ranges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
|
auto ranges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
|
||||||
getRanges(contraction), tileSizes);
|
getRanges(contraction), tileSizes);
|
||||||
linalg::common::LoopNestRangeBuilder(pivs, ranges)({
|
linalg::common::LoopNestRangeBuilder(pivs, ranges)(
|
||||||
[&contraction, &tileSizes, &ivs]() {
|
[&contraction, &tileSizes, &ivs]() {
|
||||||
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
|
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
|
||||||
auto views = makeTiledViews(contraction, ivValues, tileSizes);
|
auto views = makeTiledViews(contraction, ivValues, tileSizes);
|
||||||
ScopedContext::getBuilder()->create<ConcreteOp>(
|
ScopedContext::getBuilder()->create<ConcreteOp>(
|
||||||
ScopedContext::getLocation(), views);
|
ScopedContext::getLocation(), views);
|
||||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
|
||||||
return IndexHandle();
|
|
||||||
}()
|
|
||||||
});
|
});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
|
|
@ -109,12 +109,13 @@ public:
|
||||||
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
|
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
|
||||||
IndexHandle i, j, M(vRes.ub(0));
|
IndexHandle i, j, M(vRes.ub(0));
|
||||||
if (vRes.rank() == 1) {
|
if (vRes.rank() == 1) {
|
||||||
LoopNestBuilder({&i}, {zero}, {M}, {1})({iRes(i) = iLHS(i) + iRHS(i)});
|
LoopNestBuilder({&i}, {zero}, {M},
|
||||||
|
{1})([&] { iRes(i) = iLHS(i) + iRHS(i); });
|
||||||
} else {
|
} else {
|
||||||
assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
|
assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
|
||||||
IndexHandle N(vRes.ub(1));
|
IndexHandle N(vRes.ub(1));
|
||||||
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
|
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
|
||||||
{1, 1})({iRes(i, j) = iLHS(i, j) + iRHS(i, j)});
|
{1, 1})([&] { iRes(i, j) = iLHS(i, j) + iRHS(i, j); });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the newly allocated buffer, with a type.cast to preserve the
|
// Return the newly allocated buffer, with a type.cast to preserve the
|
||||||
|
@ -156,23 +157,23 @@ public:
|
||||||
ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
|
ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
|
||||||
if (vOp.rank() == 1) {
|
if (vOp.rank() == 1) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
LoopBuilder(&i, zero, M, 1)({
|
LoopBuilder(&i, zero, M, 1)([&]{
|
||||||
llvmCall(retTy,
|
llvmCall(retTy,
|
||||||
rewriter.getFunctionAttr(printfFunc),
|
rewriter.getFunctionAttr(printfFunc),
|
||||||
{fmtCst, iOp(i)})
|
{fmtCst, iOp(i)});
|
||||||
});
|
});
|
||||||
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
|
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
} else {
|
} else {
|
||||||
IndexHandle N(vOp.ub(1));
|
IndexHandle N(vOp.ub(1));
|
||||||
// clang-format off
|
// clang-format off
|
||||||
LoopBuilder(&i, zero, M, 1)({
|
LoopBuilder(&i, zero, M, 1)([&]{
|
||||||
LoopBuilder(&j, zero, N, 1)({
|
LoopBuilder(&j, zero, N, 1)([&]{
|
||||||
llvmCall(retTy,
|
llvmCall(retTy,
|
||||||
rewriter.getFunctionAttr(printfFunc),
|
rewriter.getFunctionAttr(printfFunc),
|
||||||
{fmtCst, iOp(i, j)})
|
{fmtCst, iOp(i, j)});
|
||||||
}),
|
});
|
||||||
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol})
|
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
|
||||||
});
|
});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
@ -295,8 +296,8 @@ public:
|
||||||
IndexedValue iRes(result), iOperand(operand);
|
IndexedValue iRes(result), iOperand(operand);
|
||||||
IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
|
IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
|
||||||
// clang-format off
|
// clang-format off
|
||||||
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({
|
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
|
||||||
iRes(i, j) = iOperand(j, i)
|
iRes(i, j) = iOperand(j, i);
|
||||||
});
|
});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
|
|
@ -95,4 +95,4 @@ The
|
||||||
demonstrates how to construct some simple IR snippets that pass through the
|
demonstrates how to construct some simple IR snippets that pass through the
|
||||||
verifier checks. The example demonstrate how to allocate three memref buffers
|
verifier checks. The example demonstrate how to allocate three memref buffers
|
||||||
from `index` function arguments and use those buffers as backing data structures
|
from `index` function arguments and use those buffers as backing data structures
|
||||||
for views that get passed to
|
for views that get passed to `dot`, `matvec` and `matmul` operations.
|
||||||
|
|
|
@ -53,15 +53,15 @@ structured loop nests.
|
||||||
f13(constant_float(llvm::APFloat(13.0f), f32Type)),
|
f13(constant_float(llvm::APFloat(13.0f), f32Type)),
|
||||||
i7(constant_int(7, 32)),
|
i7(constant_int(7, 32)),
|
||||||
i13(constant_int(13, 32));
|
i13(constant_int(13, 32));
|
||||||
LoopBuilder(&i, lb, ub, 3)({
|
LoopBuilder(&i, lb, ub, 3)([&]{
|
||||||
lb * index_t(3) + ub,
|
lb * index_t(3) + ub;
|
||||||
lb + index_t(3),
|
lb + index_t(3);
|
||||||
LoopBuilder(&j, lb, ub, 2)({
|
LoopBuilder(&j, lb, ub, 2)([&]{
|
||||||
ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)),
|
ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)),
|
||||||
index_t(32)),
|
index_t(32));
|
||||||
((f7 + f13) / f7) % f13 - f7 * f13,
|
((f7 + f13) / f7) % f13 - f7 * f13;
|
||||||
((i7 + i13) / i7) % i13 - i7 * i13,
|
((i7 + i13) / i7) % i13 - i7 * i13;
|
||||||
}),
|
});
|
||||||
});
|
});
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -86,7 +86,8 @@ def AddOp : Op<"x.add">,
|
||||||
auto ivs = IndexHandle::makeIndexHandles(view_A.rank());
|
auto ivs = IndexHandle::makeIndexHandles(view_A.rank());
|
||||||
auto pivs = IndexHandle::makePIndexHandles(ivs);
|
auto pivs = IndexHandle::makePIndexHandles(ivs);
|
||||||
IndexedValue A(arg_A), B(arg_B), C(arg_C);
|
IndexedValue A(arg_A), B(arg_B), C(arg_C);
|
||||||
LoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())({
|
LoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())(
|
||||||
|
[&]{
|
||||||
C(ivs) = A(ivs) + B(ivs)
|
C(ivs) = A(ivs) + B(ivs)
|
||||||
});
|
});
|
||||||
}];
|
}];
|
||||||
|
|
|
@ -169,12 +169,9 @@ public:
|
||||||
LoopBuilder &operator=(LoopBuilder &&) = default;
|
LoopBuilder &operator=(LoopBuilder &&) = default;
|
||||||
|
|
||||||
/// The only purpose of this operator is to serve as a sequence point so that
|
/// The only purpose of this operator is to serve as a sequence point so that
|
||||||
/// the evaluation of `stmts` (which build IR snippets in a scoped fashion) is
|
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
|
||||||
/// sequenced strictly after the constructor of LoopBuilder.
|
/// scoped within a LoopBuilder.
|
||||||
/// In order to be admissible in a nested ArrayRef<ValueHandle>, operator()
|
ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||||
/// returns a ValueHandle::null() that cannot be captured.
|
|
||||||
// TODO(ntv): when loops return escaping ssa-values, this should be adapted.
|
|
||||||
ValueHandle operator()(ArrayRef<CapturableHandle> stmts);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
|
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
|
||||||
|
@ -184,15 +181,16 @@ public:
|
||||||
/// Usage:
|
/// Usage:
|
||||||
///
|
///
|
||||||
/// ```c++
|
/// ```c++
|
||||||
/// LoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, 1})({
|
/// LoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, 1})(
|
||||||
|
/// [&](){
|
||||||
/// ...
|
/// ...
|
||||||
/// });
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
/// ```c++
|
/// ```c++
|
||||||
/// LoopNestBuilder({&i}, {lb}, {ub}, {1})({
|
/// LoopNestBuilder({&i}, {lb}, {ub}, {1})([&](){
|
||||||
/// LoopNestBuilder({&j}, {lb}, {ub}, {1})({
|
/// LoopNestBuilder({&j}, {lb}, {ub}, {1})([&](){
|
||||||
/// LoopNestBuilder({&k}, {lb}, {ub}, {1})({
|
/// LoopNestBuilder({&k}, {lb}, {ub}, {1})([&](){
|
||||||
/// ...
|
/// ...
|
||||||
/// }),
|
/// }),
|
||||||
/// }),
|
/// }),
|
||||||
|
@ -203,8 +201,7 @@ public:
|
||||||
LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
|
||||||
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
|
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
|
||||||
|
|
||||||
// TODO(ntv): when loops return escaping ssa-values, this should be adapted.
|
ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||||
ValueHandle operator()(ArrayRef<CapturableHandle> stmts);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SmallVector<LoopBuilder, 4> loops;
|
SmallVector<LoopBuilder, 4> loops;
|
||||||
|
@ -235,9 +232,9 @@ public:
|
||||||
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
|
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
|
||||||
|
|
||||||
/// The only purpose of this operator is to serve as a sequence point so that
|
/// The only purpose of this operator is to serve as a sequence point so that
|
||||||
/// the evaluation of `stmts` (which build IR snippets in a scoped fashion) is
|
/// the evaluation of `fun` (which build IR snippets in a scoped fashion) is
|
||||||
/// sequenced strictly after the constructor of BlockBuilder.
|
/// scoped within a BlockBuilder.
|
||||||
void operator()(ArrayRef<CapturableHandle> stmts);
|
void operator()(std::function<void(void)> fun = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
BlockBuilder(BlockBuilder &) = delete;
|
BlockBuilder(BlockBuilder &) = delete;
|
||||||
|
|
|
@ -34,7 +34,7 @@ public:
|
||||||
llvm::ArrayRef<edsc::ValueHandle> ranges);
|
llvm::ArrayRef<edsc::ValueHandle> ranges);
|
||||||
LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
|
LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
|
||||||
llvm::ArrayRef<Value *> ranges);
|
llvm::ArrayRef<Value *> ranges);
|
||||||
edsc::ValueHandle operator()(llvm::ArrayRef<edsc::CapturableHandle> stmts);
|
edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::SmallVector<edsc::LoopBuilder, 4> loops;
|
llvm::SmallVector<edsc::LoopBuilder, 4> loops;
|
||||||
|
|
|
@ -184,8 +184,7 @@ mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv,
|
||||||
enter(body, /*prev=*/1);
|
enter(body, /*prev=*/1);
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueHandle
|
ValueHandle mlir::edsc::LoopBuilder::operator()(std::function<void(void)> fun) {
|
||||||
mlir::edsc::LoopBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
|
|
||||||
// Call to `exit` must be explicit and asymmetric (cannot happen in the
|
// Call to `exit` must be explicit and asymmetric (cannot happen in the
|
||||||
// destructor) because of ordering wrt comma operator.
|
// destructor) because of ordering wrt comma operator.
|
||||||
/// The particular use case concerns nested blocks:
|
/// The particular use case concerns nested blocks:
|
||||||
|
@ -204,6 +203,8 @@ mlir::edsc::LoopBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
|
||||||
/// }),
|
/// }),
|
||||||
/// });
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
|
if (fun)
|
||||||
|
fun();
|
||||||
exit();
|
exit();
|
||||||
return ValueHandle::null();
|
return ValueHandle::null();
|
||||||
}
|
}
|
||||||
|
@ -222,14 +223,16 @@ mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueHandle
|
ValueHandle
|
||||||
mlir::edsc::LoopNestBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
|
mlir::edsc::LoopNestBuilder::operator()(std::function<void(void)> fun) {
|
||||||
|
if (fun)
|
||||||
|
fun();
|
||||||
// Iterate on the calling operator() on all the loops in the nest.
|
// Iterate on the calling operator() on all the loops in the nest.
|
||||||
// The iteration order is from innermost to outermost because enter/exit needs
|
// The iteration order is from innermost to outermost because enter/exit needs
|
||||||
// to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
|
// to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
|
||||||
// occurs on calling operator()). The asymmetry is required for properly
|
// occurs on calling operator()). The asymmetry is required for properly
|
||||||
// nesting imperfectly nested regions (see LoopBuilder::operator()).
|
// nesting imperfectly nested regions (see LoopBuilder::operator()).
|
||||||
for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) {
|
for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) {
|
||||||
(*lit)({});
|
(*lit)();
|
||||||
}
|
}
|
||||||
return ValueHandle::null();
|
return ValueHandle::null();
|
||||||
}
|
}
|
||||||
|
@ -258,9 +261,11 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
|
||||||
|
|
||||||
/// Only serves as an ordering point between entering nested block and creating
|
/// Only serves as an ordering point between entering nested block and creating
|
||||||
/// stmts.
|
/// stmts.
|
||||||
void mlir::edsc::BlockBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
|
void mlir::edsc::BlockBuilder::operator()(std::function<void(void)> fun) {
|
||||||
// Call to `exit` must be explicit and asymmetric (cannot happen in the
|
// Call to `exit` must be explicit and asymmetric (cannot happen in the
|
||||||
// destructor) because of ordering wrt comma operator.
|
// destructor) because of ordering wrt comma operator.
|
||||||
|
if (fun)
|
||||||
|
fun();
|
||||||
exit();
|
exit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -52,7 +52,7 @@ OperationHandle mlir::edsc::intrinsics::br(BlockHandle *bh,
|
||||||
ArrayRef<ValueHandle> operands) {
|
ArrayRef<ValueHandle> operands) {
|
||||||
assert(!*bh && "Unexpected already captured BlockHandle");
|
assert(!*bh && "Unexpected already captured BlockHandle");
|
||||||
enforceEmptyCapturesMatchOperands(captures, operands);
|
enforceEmptyCapturesMatchOperands(captures, operands);
|
||||||
BlockBuilder(bh, captures)({/* no body */});
|
BlockBuilder(bh, captures)(/* no body */);
|
||||||
SmallVector<Value *, 4> ops(operands.begin(), operands.end());
|
SmallVector<Value *, 4> ops(operands.begin(), operands.end());
|
||||||
return OperationHandle::create<BranchOp>(bh->getBlock(), ops);
|
return OperationHandle::create<BranchOp>(bh->getBlock(), ops);
|
||||||
}
|
}
|
||||||
|
@ -77,8 +77,8 @@ OperationHandle mlir::edsc::intrinsics::cond_br(
|
||||||
assert(!*falseBranch && "Unexpected already captured BlockHandle");
|
assert(!*falseBranch && "Unexpected already captured BlockHandle");
|
||||||
enforceEmptyCapturesMatchOperands(trueCaptures, trueOperands);
|
enforceEmptyCapturesMatchOperands(trueCaptures, trueOperands);
|
||||||
enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands);
|
enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands);
|
||||||
BlockBuilder(trueBranch, trueCaptures)({/* no body */});
|
BlockBuilder(trueBranch, trueCaptures)(/* no body */);
|
||||||
BlockBuilder(falseBranch, falseCaptures)({/* no body */});
|
BlockBuilder(falseBranch, falseCaptures)(/* no body */);
|
||||||
SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end());
|
SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end());
|
||||||
SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end());
|
SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end());
|
||||||
return OperationHandle::create<CondBranchOp>(
|
return OperationHandle::create<CondBranchOp>(
|
||||||
|
|
|
@ -109,6 +109,10 @@ ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
|
||||||
parser->resolveOperands(sizeInfo, bufferType, result->operands));
|
parser->resolveOperands(sizeInfo, bufferType, result->operands));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// ForOp.
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
// LoadOp.
|
// LoadOp.
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -632,6 +636,7 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
|
||||||
AffineMap::get(3, 0, {i, j}, {})};
|
AffineMap::get(3, 0, {i, j}, {})};
|
||||||
llvm_unreachable("Missing loopToOperandRangesMaps for op");
|
llvm_unreachable("Missing loopToOperandRangesMaps for op");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ideally this should all be Tablegen'd but there is no good story for op
|
// Ideally this should all be Tablegen'd but there is no good story for op
|
||||||
// expansion directly in MLIR for now.
|
// expansion directly in MLIR for now.
|
||||||
void mlir::linalg::emitScalarImplementation(
|
void mlir::linalg::emitScalarImplementation(
|
||||||
|
@ -641,7 +646,7 @@ void mlir::linalg::emitScalarImplementation(
|
||||||
using linalg_store = OperationBuilder<linalg::StoreOp>;
|
using linalg_store = OperationBuilder<linalg::StoreOp>;
|
||||||
using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
|
using IndexedValue = TemplatedIndexedValue<linalg_load, linalg_store>;
|
||||||
assert(reductionIvs.size() == 1);
|
assert(reductionIvs.size() == 1);
|
||||||
auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
|
auto innermostLoop = mlir::getForInductionVarOwner(reductionIvs.back());
|
||||||
auto *body = innermostLoop.getBody();
|
auto *body = innermostLoop.getBody();
|
||||||
using edsc::op::operator+;
|
using edsc::op::operator+;
|
||||||
using edsc::op::operator*;
|
using edsc::op::operator*;
|
||||||
|
|
|
@ -76,18 +76,15 @@ static void emitLinalgOpAsLoops(LinalgOp &linalgOp, FunctionConstants &state) {
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
ArrayRef<Value *> ranges(loopRanges);
|
ArrayRef<Value *> ranges(loopRanges);
|
||||||
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({
|
LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))([&] {
|
||||||
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({
|
LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))(
|
||||||
[&linalgOp, ¶llelIvs, &reductionIvs]() {
|
[&linalgOp, ¶llelIvs, &reductionIvs] {
|
||||||
SmallVector<mlir::Value *, 4> parallel(
|
SmallVector<mlir::Value *, 4> parallel(
|
||||||
parallelIvs.begin(), parallelIvs.end());
|
parallelIvs.begin(), parallelIvs.end());
|
||||||
SmallVector<mlir::Value *, 4> reduction(
|
SmallVector<mlir::Value *, 4> reduction(
|
||||||
reductionIvs.begin(), reductionIvs.end());
|
reductionIvs.begin(), reductionIvs.end());
|
||||||
emitScalarImplementation(parallel, reduction, linalgOp);
|
mlir::linalg::emitScalarImplementation(parallel, reduction, linalgOp);
|
||||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
});
|
||||||
return IndexHandle();
|
|
||||||
}()
|
|
||||||
})
|
|
||||||
});
|
});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
@ -101,11 +98,9 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
|
||||||
void LowerLinalgToLoopsPass::runOnFunction() {
|
void LowerLinalgToLoopsPass::runOnFunction() {
|
||||||
auto &f = getFunction();
|
auto &f = getFunction();
|
||||||
FunctionConstants state(f);
|
FunctionConstants state(f);
|
||||||
f.walk([&state](Operation *op) {
|
f.walk<LinalgOp>([&state](LinalgOp linalgOp) {
|
||||||
if (auto linalgOp = dyn_cast<LinalgOp>(op)) {
|
|
||||||
emitLinalgOpAsLoops(linalgOp, state);
|
emitLinalgOpAsLoops(linalgOp, state);
|
||||||
op->erase();
|
linalgOp.getOperation()->erase();
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -202,7 +202,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
|
||||||
|
|
||||||
SmallVector<IndexHandle, 4> ivs(loopRanges.size());
|
SmallVector<IndexHandle, 4> ivs(loopRanges.size());
|
||||||
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
|
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
|
||||||
LoopNestRangeBuilder(pivs, loopRanges)({[&op, &tileSizes, &ivs, &state]() {
|
LoopNestRangeBuilder(pivs, loopRanges)([&op, &tileSizes, &ivs, &state] {
|
||||||
auto *b = ScopedContext::getBuilder();
|
auto *b = ScopedContext::getBuilder();
|
||||||
auto loc = ScopedContext::getLocation();
|
auto loc = ScopedContext::getLocation();
|
||||||
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
|
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
|
||||||
|
@ -214,7 +214,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
|
||||||
op.create(*b, loc, views);
|
op.create(*b, loc, views);
|
||||||
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
/// NestedBuilders expect handles, we thus return an IndexHandle.
|
||||||
return IndexHandle();
|
return IndexHandle();
|
||||||
}()});
|
});
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,7 +59,9 @@ mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
|
||||||
ivs, SmallVector<ValueHandle, 4>(ranges.begin(), ranges.end())) {}
|
ivs, SmallVector<ValueHandle, 4>(ranges.begin(), ranges.end())) {}
|
||||||
|
|
||||||
ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
|
ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
|
||||||
ArrayRef<CapturableHandle> stmts) {
|
std::function<void(void)> fun) {
|
||||||
|
if (fun)
|
||||||
|
fun();
|
||||||
for (auto &lit : reverse(loops)) {
|
for (auto &lit : reverse(loops)) {
|
||||||
lit({});
|
lit({});
|
||||||
}
|
}
|
||||||
|
|
|
@ -285,9 +285,9 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
|
||||||
ValueHandle tmp = alloc(tmpMemRefType(transfer));
|
ValueHandle tmp = alloc(tmpMemRefType(transfer));
|
||||||
IndexedValue local(tmp);
|
IndexedValue local(tmp);
|
||||||
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
|
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
|
||||||
LoopNestBuilder(pivs, lbs, ubs, steps)({
|
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
|
||||||
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
||||||
local(ivs) = remote(clip(transfer, view, ivs)),
|
local(ivs) = remote(clip(transfer, view, ivs));
|
||||||
});
|
});
|
||||||
ValueHandle vectorValue = load(vec, {constant_index(0)});
|
ValueHandle vectorValue = load(vec, {constant_index(0)});
|
||||||
(dealloc(tmp)); // vexing parse
|
(dealloc(tmp)); // vexing parse
|
||||||
|
@ -346,9 +346,9 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
|
||||||
IndexedValue local(tmp);
|
IndexedValue local(tmp);
|
||||||
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
|
ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
|
||||||
store(vectorValue, vec, {constant_index(0)});
|
store(vectorValue, vec, {constant_index(0)});
|
||||||
LoopNestBuilder(pivs, lbs, ubs, steps)({
|
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
|
||||||
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
|
||||||
remote(clip(transfer, view, ivs)) = local(ivs),
|
remote(clip(transfer, view, ivs)) = local(ivs);
|
||||||
});
|
});
|
||||||
(dealloc(tmp)); // vexing parse...
|
(dealloc(tmp)); // vexing parse...
|
||||||
|
|
||||||
|
|
|
@ -70,15 +70,15 @@ TEST_FUNC(builder_dynamic_for_func_args) {
|
||||||
ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type));
|
ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type));
|
||||||
ValueHandle i7(constant_int(7, 32));
|
ValueHandle i7(constant_int(7, 32));
|
||||||
ValueHandle i13(constant_int(13, 32));
|
ValueHandle i13(constant_int(13, 32));
|
||||||
LoopBuilder(&i, lb, ub, 3)({
|
LoopBuilder(&i, lb, ub, 3)([&] {
|
||||||
lb * index_t(3) + ub,
|
lb *index_t(3) + ub;
|
||||||
lb + index_t(3),
|
lb + index_t(3);
|
||||||
LoopBuilder(&j, lb, ub, 2)({
|
LoopBuilder(&j, lb, ub, 2)([&] {
|
||||||
ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)),
|
ceilDiv(index_t(31) * floorDiv(i + j * index_t(3), index_t(32)),
|
||||||
index_t(32)),
|
index_t(32));
|
||||||
((f7 + f13) / f7) % f13 - f7 * f13,
|
((f7 + f13) / f7) % f13 - f7 *f13;
|
||||||
((i7 + i13) / i7) % i13 - i7 * i13,
|
((i7 + i13) / i7) % i13 - i7 *i13;
|
||||||
}),
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
@ -117,7 +117,7 @@ TEST_FUNC(builder_dynamic_for) {
|
||||||
ScopedContext scope(builder, f->getLoc());
|
ScopedContext scope(builder, f->getLoc());
|
||||||
ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)),
|
ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)),
|
||||||
c(f->getArgument(2)), d(f->getArgument(3));
|
c(f->getArgument(2)), d(f->getArgument(3));
|
||||||
LoopBuilder(&i, a - b, c + d, 2)({});
|
LoopBuilder(&i, a - b, c + d, 2)();
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
// CHECK-LABEL: func @builder_dynamic_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
|
// CHECK-LABEL: func @builder_dynamic_for(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
|
||||||
|
@ -140,7 +140,7 @@ TEST_FUNC(builder_max_min_for) {
|
||||||
ScopedContext scope(builder, f->getLoc());
|
ScopedContext scope(builder, f->getLoc());
|
||||||
ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)),
|
ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)),
|
||||||
ub1(f->getArgument(2)), ub2(f->getArgument(3));
|
ub1(f->getArgument(2)), ub2(f->getArgument(3));
|
||||||
LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)({});
|
LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)();
|
||||||
ret();
|
ret();
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
@ -165,24 +165,20 @@ TEST_FUNC(builder_blocks) {
|
||||||
arg4(c1.getType()), r(c1.getType());
|
arg4(c1.getType()), r(c1.getType());
|
||||||
|
|
||||||
BlockHandle b1, b2, functionBlock(&f->front());
|
BlockHandle b1, b2, functionBlock(&f->front());
|
||||||
BlockBuilder(&b1, {&arg1, &arg2})({
|
BlockBuilder(&b1, {&arg1, &arg2})(
|
||||||
// b2 has not yet been constructed, need to come back later.
|
// b2 has not yet been constructed, need to come back later.
|
||||||
// This is a byproduct of non-structured control-flow.
|
// This is a byproduct of non-structured control-flow.
|
||||||
});
|
);
|
||||||
BlockBuilder(&b2, {&arg3, &arg4})({
|
BlockBuilder(&b2, {&arg3, &arg4})([&] { br(b1, {arg3, arg4}); });
|
||||||
br(b1, {arg3, arg4}),
|
|
||||||
});
|
|
||||||
// The insertion point within the toplevel function is now past b2, we will
|
// The insertion point within the toplevel function is now past b2, we will
|
||||||
// need to get back the entry block.
|
// need to get back the entry block.
|
||||||
// This is what happens with unstructured control-flow..
|
// This is what happens with unstructured control-flow..
|
||||||
BlockBuilder(b1, Append())({
|
BlockBuilder(b1, Append())([&] {
|
||||||
r = arg1 + arg2,
|
r = arg1 + arg2;
|
||||||
br(b2, {arg1, r}),
|
br(b2, {arg1, r});
|
||||||
});
|
});
|
||||||
// Get back to entry block and add a branch into b1
|
// Get back to entry block and add a branch into b1
|
||||||
BlockBuilder(functionBlock, Append())({
|
BlockBuilder(functionBlock, Append())([&] { br(b1, {c1, c2}); });
|
||||||
br(b1, {c1, c2}),
|
|
||||||
});
|
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
// CHECK-LABEL: @builder_blocks
|
// CHECK-LABEL: @builder_blocks
|
||||||
|
@ -218,13 +214,13 @@ TEST_FUNC(builder_blocks_eager) {
|
||||||
// Build a new block for b1 eagerly.
|
// Build a new block for b1 eagerly.
|
||||||
br(&b1, {&arg1, &arg2}, {c1, c2});
|
br(&b1, {&arg1, &arg2}, {c1, c2});
|
||||||
// Construct a new block b2 explicitly with a branch into b1.
|
// Construct a new block b2 explicitly with a branch into b1.
|
||||||
BlockBuilder(&b2, {&arg3, &arg4})({
|
BlockBuilder(&b2, {&arg3, &arg4})([&]{
|
||||||
br(b1, {arg3, arg4}),
|
br(b1, {arg3, arg4});
|
||||||
});
|
});
|
||||||
/// And come back to append into b1 once b2 exists.
|
/// And come back to append into b1 once b2 exists.
|
||||||
BlockBuilder(b1, Append())({
|
BlockBuilder(b1, Append())([&]{
|
||||||
r = arg1 + arg2,
|
r = arg1 + arg2;
|
||||||
br(b2, {arg1, r}),
|
br(b2, {arg1, r});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,15 +253,11 @@ TEST_FUNC(builder_cond_branch) {
|
||||||
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
|
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
|
||||||
|
|
||||||
BlockHandle b1, b2, functionBlock(&f->front());
|
BlockHandle b1, b2, functionBlock(&f->front());
|
||||||
BlockBuilder(&b1, {&arg1})({
|
BlockBuilder(&b1, {&arg1})([&] { ret(); });
|
||||||
ret(),
|
BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); });
|
||||||
});
|
|
||||||
BlockBuilder(&b2, {&arg2, &arg3})({
|
|
||||||
ret(),
|
|
||||||
});
|
|
||||||
// Get back to entry block and add a conditional branch
|
// Get back to entry block and add a conditional branch
|
||||||
BlockBuilder(functionBlock, Append())({
|
BlockBuilder(functionBlock, Append())([&] {
|
||||||
cond_br(funcArg, b1, {c32}, b2, {c64, c42}),
|
cond_br(funcArg, b1, {c32}, b2, {c64, c42});
|
||||||
});
|
});
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
@ -300,11 +292,11 @@ TEST_FUNC(builder_cond_branch_eager) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
BlockHandle b1, b2;
|
BlockHandle b1, b2;
|
||||||
cond_br(funcArg, &b1, {&arg1}, {c32}, &b2, {&arg2, &arg3}, {c64, c42});
|
cond_br(funcArg, &b1, {&arg1}, {c32}, &b2, {&arg2, &arg3}, {c64, c42});
|
||||||
BlockBuilder(b1, Append())({
|
BlockBuilder(b1, Append())([]{
|
||||||
ret(),
|
ret();
|
||||||
});
|
});
|
||||||
BlockBuilder(b2, Append())({
|
BlockBuilder(b2, Append())([]{
|
||||||
ret(),
|
ret();
|
||||||
});
|
});
|
||||||
|
|
||||||
// CHECK-LABEL: @builder_cond_branch_eager
|
// CHECK-LABEL: @builder_cond_branch_eager
|
||||||
|
@ -344,13 +336,13 @@ TEST_FUNC(builder_helpers) {
|
||||||
lb2 = vA.lb(2);
|
lb2 = vA.lb(2);
|
||||||
ub2 = vA.ub(2);
|
ub2 = vA.ub(2);
|
||||||
step2 = vA.step(2);
|
step2 = vA.step(2);
|
||||||
LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})({
|
LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})([&]{
|
||||||
LoopBuilder(&k1, lb2, ub2, step2)({
|
LoopBuilder(&k1, lb2, ub2, step2)([&]{
|
||||||
C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1),
|
C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1);
|
||||||
}),
|
});
|
||||||
LoopBuilder(&k2, lb2, ub2, step2)({
|
LoopBuilder(&k2, lb2, ub2, step2)([&]{
|
||||||
C(i, j, k2) += A(i, j, k2) + B(i, j, k2),
|
C(i, j, k2) += A(i, j, k2) + B(i, j, k2);
|
||||||
}),
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
// CHECK-LABEL: @builder_helpers
|
// CHECK-LABEL: @builder_helpers
|
||||||
|
@ -392,14 +384,14 @@ TEST_FUNC(custom_ops) {
|
||||||
OperationHandle ih0, ih2;
|
OperationHandle ih0, ih2;
|
||||||
IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1));
|
IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1));
|
||||||
IndexHandle ten(index_t(10)), twenty(index_t(20));
|
IndexHandle ten(index_t(10)), twenty(index_t(20));
|
||||||
LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})({
|
LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{
|
||||||
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {}),
|
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {});
|
||||||
ih0 = MY_CUSTOM_OP_0({m, m + n}, {}),
|
ih0 = MY_CUSTOM_OP_0({m, m + n}, {});
|
||||||
ih2 = MY_CUSTOM_OP_2({m, m + n}, {indexType, indexType}),
|
ih2 = MY_CUSTOM_OP_2({m, m + n}, {indexType, indexType});
|
||||||
// These captures are verbose for now, can improve when used in practice.
|
// These captures are verbose for now, can improve when used in practice.
|
||||||
vh20 = ValueHandle(ih2.getOperation()->getResult(0)),
|
vh20 = ValueHandle(ih2.getOperation()->getResult(0));
|
||||||
vh21 = ValueHandle(ih2.getOperation()->getResult(1)),
|
vh21 = ValueHandle(ih2.getOperation()->getResult(1));
|
||||||
MY_CUSTOM_OP({vh20, vh21}, {indexType}, {}),
|
MY_CUSTOM_OP({vh20, vh21}, {indexType}, {});
|
||||||
});
|
});
|
||||||
|
|
||||||
// CHECK-LABEL: @custom_ops
|
// CHECK-LABEL: @custom_ops
|
||||||
|
@ -425,8 +417,8 @@ TEST_FUNC(insertion_in_block) {
|
||||||
BlockHandle b1;
|
BlockHandle b1;
|
||||||
// clang-format off
|
// clang-format off
|
||||||
ValueHandle::create<ConstantIntOp>(0, 32);
|
ValueHandle::create<ConstantIntOp>(0, 32);
|
||||||
BlockBuilder(&b1, {})({
|
BlockBuilder(&b1, {})([]{
|
||||||
ValueHandle::create<ConstantIntOp>(1, 32)
|
ValueHandle::create<ConstantIntOp>(1, 32);
|
||||||
});
|
});
|
||||||
ValueHandle::create<ConstantIntOp>(2, 32);
|
ValueHandle::create<ConstantIntOp>(2, 32);
|
||||||
// CHECK-LABEL: @insertion_in_block
|
// CHECK-LABEL: @insertion_in_block
|
||||||
|
@ -453,12 +445,12 @@ TEST_FUNC(select_op) {
|
||||||
MemRefView vA(f->getArgument(0));
|
MemRefView vA(f->getArgument(0));
|
||||||
IndexedValue A(f->getArgument(0));
|
IndexedValue A(f->getArgument(0));
|
||||||
IndexHandle i, j;
|
IndexHandle i, j;
|
||||||
LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})({
|
LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
|
||||||
// This test exercises IndexedValue::operator Value*.
|
// This test exercises IndexedValue::operator Value*.
|
||||||
// Without it, one must force conversion to ValueHandle as such:
|
// Without it, one must force conversion to ValueHandle as such:
|
||||||
// edsc::intrinsics::select(
|
// edsc::intrinsics::select(
|
||||||
// i == zero, ValueHandle(A(zero, zero)), ValueHandle(ValueA(i, j)))
|
// i == zero, ValueHandle(A(zero, zero)), ValueHandle(ValueA(i, j)))
|
||||||
edsc::intrinsics::select(i == zero, *A(zero, zero), *A(i, j))
|
edsc::intrinsics::select(i == zero, *A(zero, zero), *A(i, j));
|
||||||
});
|
});
|
||||||
|
|
||||||
// CHECK-LABEL: @select_op
|
// CHECK-LABEL: @select_op
|
||||||
|
@ -491,13 +483,13 @@ TEST_FUNC(tile_2d) {
|
||||||
IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
|
IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({
|
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
|
||||||
LoopNestBuilder(&k1, zero, O, 1)({
|
LoopNestBuilder(&k1, zero, O, 1)([&]{
|
||||||
C(i, j, k1) = A(i, j, k1) + B(i, j, k1)
|
C(i, j, k1) = A(i, j, k1) + B(i, j, k1);
|
||||||
}),
|
});
|
||||||
LoopNestBuilder(&k2, zero, O, 1)({
|
LoopNestBuilder(&k2, zero, O, 1)([&]{
|
||||||
C(i, j, k2) = A(i, j, k2) + B(i, j, k2)
|
C(i, j, k2) = A(i, j, k2) + B(i, j, k2);
|
||||||
}),
|
});
|
||||||
});
|
});
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
@ -566,8 +558,8 @@ TEST_FUNC(vectorize_2d) {
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
IndexHandle i, j, k;
|
IndexHandle i, j, k;
|
||||||
LoopNestBuilder({&i, &j, &k}, {zero, zero, zero}, {M, N, P}, {1, 1, 1})({
|
LoopNestBuilder({&i, &j, &k}, {zero, zero, zero}, {M, N, P}, {1, 1, 1})([&]{
|
||||||
C(i, j, k) = A(i, j, k) + B(i, j, k)
|
C(i, j, k) = A(i, j, k) + B(i, j, k);
|
||||||
});
|
});
|
||||||
ret();
|
ret();
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue