[mlir][Linalg] NFC: Clean up for 0-D abstraction.

Summary:
After D75831 has been landed, both the generic op and indexed_generic op can
handle the 0-D edge case. In the previous patch, only the generic op was updated.
This patch updates the lowering to loops for indexed_generic op. Since they are
almost the same, the patch also refactors the common part.

Differential Revision: https://reviews.llvm.org/D76413
This commit is contained in:
Hanhan Wang 2020-03-20 13:06:21 -07:00
parent 2b52e4e629
commit be4e9db579
1 changed file with 49 additions and 54 deletions

View File

@ -33,6 +33,7 @@ using namespace mlir::linalg;
using edsc::op::operator+;
using edsc::op::operator==;
using mlir::edsc::intrinsics::detail::ValueHandleArray;
static SmallVector<ValueHandle, 8>
makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
@ -81,6 +82,30 @@ SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
return res;
}
// Inlines the single-block region of `op` by cloning its non-terminator
// operations, then emits one std_store per yield operand into the
// corresponding output buffer. Shared by the generic and indexed_generic
// loop-lowering emitters.
//
// `indexedValues`  - values substituted for the region's block arguments
//                    (typically the std_loads emitted by the caller).
// `indexing`       - per-output canonical affine-apply index lists; assumed
//                    to line up 1:1 with `outputBuffers` and with the
//                    terminator's operands (TODO confirm at call sites).
// `outputBuffers`  - destination buffers, one per yielded value.
template <typename OpType>
static void inlineRegionAndEmitStdStore(OpType op,
ArrayRef<Value> indexedValues,
ArrayRef<ValueHandleArray> indexing,
ArrayRef<Value> outputBuffers) {
// Builder comes from the EDSC scoped context set up by the caller.
auto &b = ScopedContext::getBuilder();
auto &block = op.region().front();
BlockAndValueMapping map;
// Substitute the block arguments with the already-loaded values.
map.map(block.getArguments(), indexedValues);
// Clone every op except the terminator, remapping operands through `map`
// and recording each clone's results so later ops see the cloned values.
for (auto &op : block.without_terminator()) {
assert(op.getNumRegions() == 0 && "expected a non-nested region");
auto *newOp = b.clone(op, map);
map.map(op.getResults(), newOp->getResults());
}
// The region must end in a linalg.yield; each of its operands becomes a
// store into the matching output buffer at the precomputed indices.
Operation &terminator = block.back();
assert(isa<YieldOp>(terminator) &&
"expected an yield op in the end of the region");
for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
std_store(map.lookup(terminator.getOperand(i)), outputBuffers[i],
indexing[i]);
}
}
namespace {
template <typename IndexedValueType, typename LinalgOpType>
class LinalgScopedEmitter {};
@ -300,6 +325,8 @@ public:
}
// 1.b. Emit std_load from output views.
// TODO(mravishankar): Avoid the loads if the corresponding argument of the
// region has no uses.
for (unsigned i = 0; i < nOutputs; ++i) {
Value output = genericOp.getOutputBuffer(i);
ValueHandleArray indexing(makeCanonicalAffineApplies(
@ -324,24 +351,16 @@ public:
}
// TODO(ntv): When a region inliner exists, use it.
// 2. Inline region, currently only works for a single basic block.
BlockAndValueMapping map;
auto &block = genericOp.region().front();
map.map(block.getArguments(), indexedValues);
for (auto &op : block.without_terminator()) {
assert(op.getNumRegions() == 0);
auto *newOp = b.clone(op, map);
map.map(op.getResults(), newOp->getResults());
}
// 3. Emit std_store.
auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
assert(yieldOp->getNumOperands() == nOutputs);
SmallVector<ValueHandleArray, 8> indexing;
SmallVector<Value, 8> outputBuffers;
for (unsigned i = 0; i < nOutputs; ++i) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
indexing.emplace_back(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
std_store(map.lookup(yieldOp->getOperand(i)),
genericOp.getOutputBuffer(i), indexing);
outputBuffers.push_back(genericOp.getOutputBuffer(i));
}
inlineRegionAndEmitStdStore(genericOp, indexedValues, indexing,
outputBuffers);
}
};
@ -397,25 +416,17 @@ public:
// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
Value input = indexedGenericOp.getInput(i);
if (input.getType().cast<ShapedType>().getRank()) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
indexedValues[nLoops + i] = std_load(input, indexing);
} else {
indexedValues[nLoops + i] = std_load(input);
}
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
indexedValues[nLoops + i] = std_load(input, indexing);
}
// 1.b. Emit std_load from output views.
for (unsigned i = 0; i < nOutputs; ++i) {
Value output = indexedGenericOp.getOutputBuffer(i);
if (output.getType().cast<ShapedType>().getRank()) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
} else {
indexedValues[nLoops + nInputs + i] = std_load(output);
}
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
}
if (auto funcOp = indexedGenericOp.getFunction()) {
@ -426,40 +437,24 @@ public:
// 3. Emit std_store.
for (unsigned i = 0; i < nOutputs; ++i) {
Value output = indexedGenericOp.getOutputBuffer(i);
if (output.getType().cast<ShapedType>().getRank()) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
std_store(callOp->getResult(i), output, indexing);
} else {
std_store(callOp->getResult(i), output);
}
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
std_store(callOp->getResult(i), output, indexing);
}
return;
}
// TODO(ntv): When a region inliner exists, use it.
// 2. Inline region, currently only works for a single basic block.
BlockAndValueMapping map;
auto &block = indexedGenericOp.region().front();
map.map(block.getArguments(), indexedValues);
for (auto &op : block.without_terminator()) {
assert(op.getNumRegions() == 0);
auto *newOp = b.clone(op, map);
map.map(op.getResults(), newOp->getResults());
}
// 3. Emit std_store.
auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
assert(yieldOp->getNumOperands() == nOutputs);
SmallVector<ValueHandleArray, 8> indexing;
SmallVector<Value, 8> outputBuffers;
for (unsigned i = 0; i < nOutputs; ++i) {
Value output = indexedGenericOp.getOutputBuffer(i);
if (output.getType().cast<ShapedType>().getRank()) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
std_store(map.lookup(yieldOp->getOperand(i)), output, indexing);
} else {
std_store(map.lookup(yieldOp->getOperand(i)), output);
}
indexing.emplace_back(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
}
inlineRegionAndEmitStdStore(indexedGenericOp, indexedValues, indexing,
outputBuffers);
}
};