[mlir][Linalg] NFC: Clean up for 0-D abstraction.
Summary: Now that D75831 has landed, both the generic op and the indexed_generic op can handle the 0-D edge case. The previous patch updated only the generic op; this patch updates the lowering to loops for the indexed_generic op as well. Since the two lowerings are almost the same, the patch also refactors the common part into a shared helper.

Differential Revision: https://reviews.llvm.org/D76413
This commit is contained in:
parent 2b52e4e629
commit be4e9db579
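The crux of the 0-D case: a rank-0 buffer holds exactly one element, so a load from it takes an empty index list rather than a separate code path. Below is a minimal standalone sketch of that idea in plain C++; the Buffer type and its row-major layout are illustrative stand-ins, not MLIR API.

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical row-major buffer; `shape` is empty in the 0-D case.
struct Buffer {
  std::vector<int> shape;
  std::vector<float> data;

  float load(const std::vector<int> &indices) const {
    assert(indices.size() == shape.size() && "rank mismatch");
    int offset = 0;
    for (std::size_t d = 0; d < indices.size(); ++d)
      offset = offset * shape[d] + indices[d];
    return data[offset]; // offset stays 0 when `indices` is empty (0-D)
  }
};

int main() {
  Buffer scalar{{}, {42.0f}};          // rank 0: one element, no indices
  Buffer matrix{{2, 2}, {1, 2, 3, 4}}; // rank 2: indexed normally
  std::printf("%g %g\n", scalar.load({}), matrix.load({1, 0}));
}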
@@ -33,6 +33,7 @@ using namespace mlir::linalg;
 using edsc::op::operator+;
 using edsc::op::operator==;
+using mlir::edsc::intrinsics::detail::ValueHandleArray;
 
 static SmallVector<ValueHandle, 8>
 makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
@@ -81,6 +82,30 @@ SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
   return res;
 }
 
+template <typename OpType>
+static void inlineRegionAndEmitStdStore(OpType op,
+                                        ArrayRef<Value> indexedValues,
+                                        ArrayRef<ValueHandleArray> indexing,
+                                        ArrayRef<Value> outputBuffers) {
+  auto &b = ScopedContext::getBuilder();
+  auto &block = op.region().front();
+  BlockAndValueMapping map;
+  map.map(block.getArguments(), indexedValues);
+  for (auto &op : block.without_terminator()) {
+    assert(op.getNumRegions() == 0 && "expected a non-nested region");
+    auto *newOp = b.clone(op, map);
+    map.map(op.getResults(), newOp->getResults());
+  }
+
+  Operation &terminator = block.back();
+  assert(isa<YieldOp>(terminator) &&
+         "expected a yield op at the end of the region");
+  for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
+    std_store(map.lookup(terminator.getOperand(i)), outputBuffers[i],
+              indexing[i]);
+  }
+}
+
 namespace {
 template <typename IndexedValueType, typename LinalgOpType>
 class LinalgScopedEmitter {};
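The new inlineRegionAndEmitStdStore helper uses the standard clone-and-remap pattern: block arguments are mapped to the values loaded in step 1, each op in the region body is cloned with its operands looked up through the value map, and each cloned result is recorded so that later ops (and finally the terminator operands) resolve to the new values. Here is a standalone sketch of the pattern, with a toy Op type and integer value ids standing in for MLIR operations and values.

#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for an operation: a name, operand value ids, one result id.
struct Op {
  std::string name;
  std::vector<int> operands;
  int result;
};

int main() {
  // Region body: %2 = add %0, %1 ; %3 = mul %2, %0 ; terminator yields %3.
  std::vector<Op> body = {{"add", {0, 1}, 2}, {"mul", {2, 0}, 3}};
  int yieldOperand = 3;

  // Map the block arguments %0, %1 to the values loaded in step 1.
  std::map<int, int> map = {{0, 10}, {1, 11}};
  int nextId = 20;
  for (const Op &op : body) {
    Op clone = op;
    for (int &v : clone.operands)
      v = map.at(v);               // remap operands through the value map
    clone.result = nextId++;
    map[op.result] = clone.result; // later clones see the new result
    std::printf("%%%d = %s %%%d, %%%d\n", clone.result, clone.name.c_str(),
                clone.operands[0], clone.operands[1]);
  }
  // The store uses the terminator operand looked up through the map.
  std::printf("store %%%d\n", map.at(yieldOperand));
}

Running this prints the cloned ops with remapped operands (%10 and %11 in place of the block arguments) and a final store of the remapped yield operand, which is exactly the shape of the helper's store loop.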
@@ -300,6 +325,8 @@ public:
     }
 
     // 1.b. Emit std_load from output views.
+    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
+    // region has no uses.
     for (unsigned i = 0; i < nOutputs; ++i) {
       Value output = genericOp.getOutputBuffer(i);
       ValueHandleArray indexing(makeCanonicalAffineApplies(
@@ -324,24 +351,16 @@ public:
     }
-    // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
-    BlockAndValueMapping map;
-    auto &block = genericOp.region().front();
-    map.map(block.getArguments(), indexedValues);
-    for (auto &op : block.without_terminator()) {
-      assert(op.getNumRegions() == 0);
-      auto *newOp = b.clone(op, map);
-      map.map(op.getResults(), newOp->getResults());
-    }
 
     // 3. Emit std_store.
-    auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
-    assert(yieldOp->getNumOperands() == nOutputs);
+    SmallVector<ValueHandleArray, 8> indexing;
+    SmallVector<Value, 8> outputBuffers;
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
+      indexing.emplace_back(makeCanonicalAffineApplies(
           b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)),
-                genericOp.getOutputBuffer(i), indexing);
+      outputBuffers.push_back(genericOp.getOutputBuffer(i));
     }
+    inlineRegionAndEmitStdStore(genericOp, indexedValues, indexing,
+                                outputBuffers);
   }
 };
 
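With the helper in place, the generic-op emitter only gathers two vectors in lockstep, one entry per output (the canonical affine applies for that output's indexing map, and the output buffer itself), then delegates once. A small sketch of that gather-then-delegate shape; emitStores and the string buffer names are hypothetical stand-ins for the helper's store loop.

#include <cstdio>
#include <string>
#include <vector>

using Indices = std::vector<int>;

// Hypothetical stand-in for the shared helper's store loop: one store per
// output buffer, using the index list gathered for it.
void emitStores(const std::vector<std::string> &outputBuffers,
                const std::vector<Indices> &indexing) {
  for (std::size_t i = 0; i < outputBuffers.size(); ++i) {
    std::printf("std_store -> %s [", outputBuffers[i].c_str());
    for (std::size_t j = 0; j < indexing[i].size(); ++j)
      std::printf(j ? ", %d" : "%d", indexing[i][j]);
    std::printf("]\n");
  }
}

int main() {
  std::vector<std::string> outputBuffers;
  std::vector<Indices> indexing;
  // Gather the two vectors in lockstep, one entry per output...
  indexing.push_back({1, 2});       // ranked output, two indices
  outputBuffers.push_back("%out0");
  indexing.push_back({});           // 0-D output, no indices
  outputBuffers.push_back("%out1");
  // ...then delegate once.
  emitStores(outputBuffers, indexing);
}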
@@ -397,25 +416,17 @@ public:
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
       Value input = indexedGenericOp.getInput(i);
-      if (input.getType().cast<ShapedType>().getRank()) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
-        indexedValues[nLoops + i] = std_load(input, indexing);
-      } else {
-        indexedValues[nLoops + i] = std_load(input);
-      }
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
+      indexedValues[nLoops + i] = std_load(input, indexing);
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
       Value output = indexedGenericOp.getOutputBuffer(i);
-      if (output.getType().cast<ShapedType>().getRank()) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
-      } else {
-        indexedValues[nLoops + nInputs + i] = std_load(output);
-      }
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+      indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
     }
 
     if (auto funcOp = indexedGenericOp.getFunction()) {
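The rank-0 branches can simply be deleted because makeCanonicalAffineApplies on a zero-result indexing map produces an empty index list, which is exactly what a 0-D std_load takes. A sketch of that degenerate case; applyMap is a hypothetical model that treats an indexing map as the subset of loop dimensions an operand uses.

#include <cstdio>
#include <vector>

// Model an indexing map as the subset of loop dimensions an operand uses;
// a 0-D operand's map has zero results, so it selects no dimensions.
std::vector<int> applyMap(const std::vector<int> &mapDims,
                          const std::vector<int> &ivs) {
  std::vector<int> indices;
  for (int d : mapDims)
    indices.push_back(ivs[d]);
  return indices;
}

int main() {
  std::vector<int> ivs = {7, 3};       // current loop induction variables
  auto ranked = applyMap({0, 1}, ivs); // 2-D operand -> indices {7, 3}
  auto zeroD = applyMap({}, ivs);      // 0-D operand -> no indices at all
  std::printf("ranked: %zu indices, 0-D: %zu indices\n", ranked.size(),
              zeroD.size());
}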
@@ -426,40 +437,24 @@ public:
       // 3. Emit std_store.
       for (unsigned i = 0; i < nOutputs; ++i) {
         Value output = indexedGenericOp.getOutputBuffer(i);
-        if (output.getType().cast<ShapedType>().getRank()) {
-          ValueHandleArray indexing(makeCanonicalAffineApplies(
-              b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-          std_store(callOp->getResult(i), output, indexing);
-        } else {
-          std_store(callOp->getResult(i), output);
-        }
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        std_store(callOp->getResult(i), output, indexing);
       }
       return;
     }
-    // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
-    BlockAndValueMapping map;
-    auto &block = indexedGenericOp.region().front();
-    map.map(block.getArguments(), indexedValues);
-    for (auto &op : block.without_terminator()) {
-      assert(op.getNumRegions() == 0);
-      auto *newOp = b.clone(op, map);
-      map.map(op.getResults(), newOp->getResults());
-    }
 
     // 3. Emit std_store.
-    auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
-    assert(yieldOp->getNumOperands() == nOutputs);
+    SmallVector<ValueHandleArray, 8> indexing;
+    SmallVector<Value, 8> outputBuffers;
     for (unsigned i = 0; i < nOutputs; ++i) {
-      Value output = indexedGenericOp.getOutputBuffer(i);
-      if (output.getType().cast<ShapedType>().getRank()) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        std_store(map.lookup(yieldOp->getOperand(i)), output, indexing);
-      } else {
-        std_store(map.lookup(yieldOp->getOperand(i)), output);
-      }
+      indexing.emplace_back(makeCanonicalAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+      outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
     }
+    inlineRegionAndEmitStdStore(indexedGenericOp, indexedValues, indexing,
+                                outputBuffers);
   }
 };
 