Add support for promoting Linalg views into new buffers.
This CL uses the generic CopyOp to promote a subview (constructed during tiling) into a new buffer plus a copy, by:

1. creating a new buffer for the subview;
2. taking a view into the buffer and copying into it;
3. adapting the linalg op to operate on the view from point 2.

Tiling is extended with a boolean flag to enable promoting views (all or nothing for now).

More specifically, the current implementation creates a buffer that is always of the full size of the ranges of the subview. This produces a buffer whose size may be bigger than the actual size of the `subView` at the boundaries, and is related to the full/partial tile problem. In practice, we introduce a `buffer`, a `fullLocalView` and a `partialLocalView` such that:

* `buffer` is always the size of the subview in the full tile case;
* `fullLocalView` is a dense contiguous view into that buffer;
* `partialLocalView` is a dense non-contiguous slice of `fullLocalView` that corresponds to the size of `subView`, accounting for boundary effects.

The point of the full-tile buffer is that constant static tile sizes are folded and result in a buffer type with statically known size and alignment properties. Padding is introduced on the boundary tiles with a `fill` op followed by a partial `copy` op. A worked size example follows the commit metadata below.

These behaviors will be refined later, on a per-need basis.

PiperOrigin-RevId: 256237319
parent c73edeec13
commit 516188bf1c
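To make the full/partial distinction concrete, here is a small worked example; the 5x7 view and 2x4 tile sizes are hypothetical, not taken from this change.

// Hypothetical illustration: a 5x7 view tiled by 2x4. The boundary tile
// starting at (4, 4) has full-tile shape 2x4 but actual shape 1x3.
#include <algorithm>
#include <cstdio>

int main() {
  int n0 = 5, n1 = 7;   // view sizes (assumed)
  int t0 = 2, t1 = 4;   // tile sizes (assumed)
  int lb0 = 4, lb1 = 4; // lower bounds of a boundary tile
  // `buffer` and `fullLocalView` always use the full tile shape:
  int fullElems = t0 * t1; // 2 * 4 = 8 elements, statically known
  // `partialLocalView` is clipped to the view boundary:
  int p0 = std::min(t0, n0 - lb0); // 1
  int p1 = std::min(t1, n1 - lb1); // 3
  std::printf("full tile: %d elems, partial tile: %dx%d\n", fullElems, p0, p1);
  return 0;
}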
@@ -32,7 +32,8 @@ class ModulePassBase;
 namespace linalg {
 FunctionPassBase *createLinalgFusionPass(ArrayRef<int64_t> tileSizes = {});

-FunctionPassBase *createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
+FunctionPassBase *createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {},
+                                         bool promoteViews = false);

 FunctionPassBase *createLowerLinalgToLoopsPass();

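For orientation, the new overload keeps promotion off by default; a minimal sketch of constructing the pass programmatically (the surrounding pass-pipeline code is assumed, not part of this diff):

// Minimal sketch: opting into promotion is explicit.
FunctionPassBase *tilingPass =
    mlir::linalg::createLinalgTilingPass({2, 3, 4}, /*promoteViews=*/true);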
@@ -22,12 +22,21 @@

 namespace mlir {
 namespace linalg {
+class BufferAllocOp;
+class BufferDeallocOp;
+class CopyOp;
 class DimOp;
+class FillOp;
 class RangeOp;
 class SliceOp;
 class ViewOp;
 namespace intrinsics {
+using buffer_alloc = mlir::edsc::intrinsics::ValueBuilder<BufferAllocOp>;
+using buffer_dealloc =
+    mlir::edsc::intrinsics::OperationBuilder<BufferDeallocOp>;
+using copy = mlir::edsc::intrinsics::OperationBuilder<CopyOp>;
 using dim = mlir::edsc::intrinsics::ValueBuilder<linalg::DimOp>;
+using fill = mlir::edsc::intrinsics::OperationBuilder<FillOp>;
 using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>;
 using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>;
 using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>;
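The new builders mirror the existing EDSC intrinsics; a rough sketch (a hypothetical helper, not part of this change; an active edsc::ScopedContext and pre-computed range lists are assumed) of how they compose for one promoted view:

// Hypothetical helper sketch, assuming an active edsc::ScopedContext.
static mlir::Value *promoteOneView(mlir::Value *src,
                                   mlir::linalg::BufferType bufferType,
                                   llvm::ArrayRef<mlir::Value *> fullRanges,
                                   llvm::ArrayRef<mlir::Value *> partialRanges,
                                   mlir::Value *zeroVal) {
  using namespace mlir::linalg::intrinsics;
  auto buf = buffer_alloc(bufferType);       // linalg.buffer_alloc
  auto full = view(buf, fullRanges);         // dense, contiguous full-tile view
  auto partial = slice(full, partialRanges); // clipped to the actual subview
  fill(full, zeroVal);                       // pad boundary tiles first
  copy(src, partial);                        // then copy the live data in
  return full;                               // the op computes on the full view
}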
@@ -86,11 +86,47 @@ struct TiledLinalgOp {
   SmallVector<ForOp, 8> loops;
 };

-llvm::Optional<TiledLinalgOp>
-tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes, OperationFolder &state);
+/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
+/// Inserts scoped local buffers and copies tiled views into/from those buffers
+/// when the corresponding entry in `viewsToPromote` is true.
+/// Returns a struct containing the tiled loops and the cloned op if successful,
+/// llvm::None otherwise.
+// TODO(ntv) implement a heuristic for view promotion.
+llvm::Optional<TiledLinalgOp> tileLinalgOp(LinalgOp op,
+                                           ArrayRef<Value *> tileSizes,
+                                           OperationFolder &folder,
+                                           ArrayRef<bool> viewsToPromote = {});

-llvm::Optional<TiledLinalgOp>
-tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes, OperationFolder &state);
+/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
+/// Inserts scoped local buffers and copies tiled views into/from those buffers
+/// when the corresponding entry in `viewsToPromote` is true.
+/// Returns a struct containing the tiled loops and the cloned op if successful,
+/// llvm::None otherwise.
+// TODO(ntv) implement a heuristic for view promotion.
+llvm::Optional<TiledLinalgOp> tileLinalgOp(LinalgOp op,
+                                           ArrayRef<int64_t> tileSizes,
+                                           OperationFolder &folder,
+                                           ArrayRef<bool> viewsToPromote = {});
+
+struct PromotionInfo {
+  Value *buffer;
+  Value *fullLocalView;
+  Value *partialLocalView;
+};
+
+/// Promotes the `views` into a new buffer allocated at the insertion point `b`.
+/// For now, promotion occurs in 3 steps:
+///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
+///   2. Take a full view on the buffer and `linalg.fill` it with zeros (use
+///      float zero for now).
+///   3. Take a partial slice of the full view in step 2. and copy into it.
+///
+/// Returns a list of PromotionInfo which hold the promoted buffer and the
+/// full and partial views indexing into the buffer.
+llvm::SmallVector<PromotionInfo, 8> promoteLinalgViews(OpBuilder &b,
+                                                       Location loc,
+                                                       ArrayRef<Value *> views,
+                                                       OperationFolder &folder);

 } // namespace linalg
 } // namespace mlir
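A caller-side sketch of the promoting overload (it mirrors the `tileLinalgOps` driver changed later in this diff; the all-true mask is an assumption):

// Sketch: tile by 2x3x4 and promote every input/output view of `op`.
OperationFolder folder;
SmallVector<bool, 8> promoteAll(op.getNumInputsAndOutputs(), true);
if (auto tiled = tileLinalgOp(op, {2, 3, 4}, folder, promoteAll))
  op.erase(); // on success, the tiled clone replaces the original op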
@@ -49,6 +49,10 @@ static llvm::cl::list<unsigned>
                 llvm::cl::desc("Tile sizes by which to tile linalg operations"),
                 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
                 llvm::cl::cat(clOptionsCategory));
+static llvm::cl::opt<bool> clPromoteFullTileViews(
+    "linalg-tile-promote-full-tile-views",
+    llvm::cl::desc("Create scoped local buffers for tiled views"),
+    llvm::cl::init(false), llvm::cl::cat(clOptionsCategory));

 static bool isZero(Value *v) {
   return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
@@ -65,10 +69,10 @@ static bool isZero(Value *v) {
 static SmallVector<Value *, 4>
 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
                     ArrayRef<Value *> allViewSizes,
-                    ArrayRef<Value *> allTileSizes, OperationFolder &state) {
+                    ArrayRef<Value *> allTileSizes, OperationFolder &folder) {
   assert(allTileSizes.size() == map.getNumResults());
   // Apply `map` to get view sizes in loop order.
-  auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, state);
+  auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder);
   SmallVector<Value *, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());

   // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
@@ -83,7 +87,7 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
   SmallVector<Value *, 4> res;
   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
     res.push_back(b.create<RangeOp>(loc,
-                                    state.create<ConstantIndexOp>(b, loc, 0),
+                                    folder.create<ConstantIndexOp>(b, loc, 0),
                                     viewSizes[idx], tileSizes[idx]));
   }
   return res;
@@ -120,7 +124,7 @@ static SmallVector<Value *, 4> makeTiledViews(OpBuilder &b, Location loc,
                                               LinalgOp linalgOp,
                                               ArrayRef<Value *> ivs,
                                               ArrayRef<Value *> tileSizes,
-                                              OperationFolder &state) {
+                                              OperationFolder &folder) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
                            [](Value *v) { return !isZero(v); })) &&
@@ -154,14 +158,14 @@ static SmallVector<Value *, 4> makeTiledViews(OpBuilder &b, Location loc,
     auto tileSize = tileSizes[pos];
     if (isZero(tileSize)) {
       subViewOperands.push_back(
-          SubViewOp::Range{state.create<ConstantIndexOp>(b, loc, 0),
+          SubViewOp::Range{folder.create<ConstantIndexOp>(b, loc, 0),
                            linalg::intrinsics::dim(view, r),
-                           state.create<ConstantIndexOp>(b, loc, 1)});
+                           folder.create<ConstantIndexOp>(b, loc, 1)});
       continue;
     }

     // `tileSizes` of `0` don't have an induction variable counterpart. So
-    // we count the number of zeros ot align the index in `ivs` to pos.
+    // we count the number of zeros to align the index in `ivs` to pos.
     auto count = llvm::count_if(
         llvm::make_range(tileSizes.begin(), tileSizes.begin() + pos),
         [](Value *v) { return isZero(v); });
@@ -177,17 +181,155 @@ static SmallVector<Value *, 4> makeTiledViews(OpBuilder &b, Location loc,
       // Tiling creates a new slice at the proper index, the slice step is 1
       // (i.e. the slice view does not subsample, stepping occurs in the loop).
       subViewOperands.push_back(SubViewOp::Range{
-          iv, steppedLb, state.create<ConstantIndexOp>(b, loc, 1)});
+          iv, steppedLb, folder.create<ConstantIndexOp>(b, loc, 1)});
     }
     res.push_back(b.create<SubViewOp>(loc, view, subViewOperands));
   }
   return res;
 }

+static AffineMap getAffineDifferenceMap(MLIRContext *context) {
+  AffineExpr d0(getAffineDimExpr(0, context)), d1(getAffineDimExpr(1, context));
+  return AffineMap::get(2, 0, {d0 - d1});
+}
+
+static Value *allocBuffer(Type elementType, Value *size) {
+  if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
+    return buffer_alloc(
+        BufferType::get(size->getContext(), elementType, cst.getValue()));
+  return buffer_alloc(BufferType::get(size->getContext(), elementType), size);
+}
+
+// Performs promotion of a `subView` into a local buffer of the size of the
+// *ranges* of the `subView`. This produces a buffer whose size may be bigger
+// than the actual size of the `subView` at the boundaries.
+// This is related to the full/partial tile problem.
+// Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
+// `partialLocalView` such that:
+//   * `buffer` is always the size of the full tile.
+//   * `fullLocalView` is a dense contiguous view into that buffer.
+//   * `partialLocalView` is a dense non-contiguous slice of `fullLocalView`
+//     that corresponds to the size of `subView` and accounts for boundary
+//     effects.
+// The point of the full tile buffer is that constant static tile sizes are
+// folded and result in a buffer type with statically known size and alignment
+// properties.
+// To account for general boundary effects, padding must be performed on the
+// boundary tiles. For now this is done with an unconditional `fill` op followed
+// by a partial `copy` op.
+static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
+                                           SubViewOp subView,
+                                           OperationFolder &folder) {
+  auto zero = constant_index(folder, 0);
+  auto one = constant_index(folder, 1);
+
+  auto viewType = subView.getViewType();
+  auto rank = viewType.getRank();
+  Value *allocSize = one;
+  SmallVector<Value *, 8> fullRanges, partialRanges;
+  fullRanges.reserve(rank);
+  partialRanges.reserve(rank);
+  for (auto en : llvm::enumerate(subView.getRanges())) {
+    auto rank = en.index();
+    auto rangeValue = en.value();
+    Value *d =
+        isa<linalg::DimOp>(rangeValue.max->getDefiningOp())
+            ? rangeValue.max
+            : applyMapToValues(b, loc, getAffineDifferenceMap(b.getContext()),
+                               {rangeValue.max, rangeValue.min}, folder)
+                  .front();
+    allocSize = muli(folder, allocSize, d).getValue();
+    fullRanges.push_back(range(folder, zero, d, one));
+    partialRanges.push_back(
+        range(folder, zero, linalg::intrinsics::dim(subView, rank), one));
+  }
+  auto *buffer = allocBuffer(viewType.getElementType(), allocSize);
+  auto fullLocalView = view(buffer, fullRanges);
+  auto partialLocalView = slice(fullLocalView, partialRanges);
+  return PromotionInfo{buffer, fullLocalView, partialLocalView};
+}
+
+// Performs promotion of a view `v` into a local buffer of the size of the
+// view. This produces a buffer whose size is exactly the size of `v`.
+// Returns a PromotionInfo containing a `buffer`, `fullLocalView` and
+// `partialLocalView` such that:
+//   * `buffer` is always the size of the view.
+//   * `partialLocalView` is a dense contiguous view into that buffer.
+//   * `fullLocalView` is equal to `partialLocalView`.
+// The point of the full tile buffer is that constant static tile sizes are
+// folded and result in a buffer type with statically known size and alignment
+// properties.
+static PromotionInfo promotePartialTileBuffer(OpBuilder &b, Location loc,
+                                              Value *v,
+                                              OperationFolder &folder) {
+  auto zero = constant_index(folder, 0);
+  auto one = constant_index(folder, 1);
+
+  auto viewType = v->getType().cast<ViewType>();
+  auto rank = viewType.getRank();
+  Value *allocSize = one;
+  SmallVector<Value *, 8> partialRanges;
+  partialRanges.reserve(rank);
+  for (unsigned r = 0; r < rank; ++r) {
+    Value *d = linalg::intrinsics::dim(v, r);
+    allocSize = muli(folder, allocSize, d).getValue();
+    partialRanges.push_back(range(folder, zero, d, one));
+  }
+  auto *buffer = allocBuffer(viewType.getElementType(), allocSize);
+  auto partialLocalView = view(folder, buffer, partialRanges);
+  return PromotionInfo{buffer, partialLocalView, partialLocalView};
+}
+
+SmallVector<PromotionInfo, 8>
+mlir::linalg::promoteLinalgViews(OpBuilder &b, Location loc,
+                                 ArrayRef<Value *> views,
+                                 OperationFolder &folder) {
+  if (views.empty())
+    return {};
+
+  ScopedContext scope(b, loc);
+  SmallVector<PromotionInfo, 8> res;
+  res.reserve(views.size());
+  DenseMap<Value *, PromotionInfo> promotionInfo;
+  for (auto *v : views) {
+    PromotionInfo pi;
+    if (auto subView = dyn_cast<SubViewOp>(v->getDefiningOp()))
+      pi = promoteFullTileBuffer(b, loc, subView, folder);
+    else
+      pi = promotePartialTileBuffer(b, loc, v, folder);
+    promotionInfo.insert(std::make_pair(v, pi));
+    res.push_back(pi);
+  }
+
+  for (auto *v : views) {
+    auto info = promotionInfo.find(v);
+    if (info == promotionInfo.end())
+      continue;
+    auto viewType = v->getType().cast<ViewType>();
+    // TODO(ntv): value to fill with should be related to the operation.
+    // For now, just use APFloat(0.0f).
+    auto t = viewType.getElementType().cast<FloatType>();
+    Value *fillVal = constant_float(folder, APFloat(0.0f), t);
+    // TODO(ntv): fill is only necessary if `promotionInfo` has a full local
+    // view that is different from the partial local view and we are on the
+    // boundary.
+    fill(info->second.fullLocalView, fillVal);
+  }
+
+  for (auto *v : views) {
+    auto info = promotionInfo.find(v);
+    if (info == promotionInfo.end())
+      continue;
+    copy(v, info->second.partialLocalView);
+  }
+  return res;
+}
+
 llvm::Optional<TiledLinalgOp>
 mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
-                           OperationFolder &state) {
-  // Enforce the convention that "tiling by zero" skips tiling a particular
+                           OperationFolder &folder,
+                           ArrayRef<bool> viewsToPromote) {
+  // 1. Enforce the convention that "tiling by zero" skips tiling a particular
   // dimension. This convention is significantly simpler to handle instead of
   // adjusting affine maps to account for missing dimensions.
   assert(op.getNumParallelLoops() + op.getNumReductionLoops() +
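The payoff of accumulating `allocSize` from the range extents is constant folding: when the tile sizes are static, the buffer type is static too. The numbers below match the matmul test added at the end of this commit.

// With tile sizes 2x3x4 for matmul(A, B, C), the full-tile shapes are
// A: 2x4, B: 4x3 and C: 2x3, so allocSize folds to 8, 12 and 6 elements,
// i.e. !linalg.buffer<8xf32>, !linalg.buffer<12xf32>, !linalg.buffer<6xf32>.
static_assert(2 * 4 == 8 && 4 * 3 == 12 && 2 * 3 == 6,
              "full-tile element counts");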
@@ -197,37 +339,92 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,

   OpBuilder builder(op.getOperation());
   ScopedContext scope(builder, op.getLoc());
+  // 2. Build the tiled loop ranges.
   auto loopRanges = makeTiledLoopRanges(
       scope.getBuilder(), scope.getLocation(),
       // The flattened loopToOperandRangesMaps is expected to be an invertible
       // permutation map (which is asserted in the inverse calculation).
       inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op))),
-      getViewSizes(op), tileSizes, state);
+      getViewSizes(op), tileSizes, folder);

+  // 3. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<IndexHandle, 4> ivs(loopRanges.size());
   auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
-  LoopNestRangeBuilder(pivs, loopRanges)([&op, &tileSizes, &ivs, &res, &state] {
+  LoopNestRangeBuilder(pivs, loopRanges)([&] {
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
     // If/when the assertion below becomes false, we will have to templatize
     // `makeTiledViews`.
     assert(op.getNumInputsAndOutputs() == op.getOperation()->getNumOperands());
-    auto views = makeTiledViews(b, loc, op, ivValues, tileSizes, state);
-    res = op.create(b, loc, views);
+    auto views = makeTiledViews(b, loc, op, ivValues, tileSizes, folder);
+
+    // If no promotion, we are done.
+    auto promote = !viewsToPromote.empty() &&
+                   llvm::any_of(llvm::make_range(viewsToPromote.begin(),
+                                                 viewsToPromote.end()),
+                                [](bool b) { return b; });
+    if (!promote) {
+      res = op.create(b, loc, views);
+      return;
+    }
+
+    // 4. Filter the subset of views that need to be promoted.
+    SmallVector<Value *, 8> filteredViews;
+    filteredViews.reserve(views.size());
+    assert(
+        viewsToPromote.empty() ||
+        views.size() == viewsToPromote.size() &&
+            "expected viewsToPromote to be empty or of the same size as views");
+    for (auto it : llvm::zip(views, viewsToPromote)) {
+      if (!std::get<1>(it))
+        continue;
+      filteredViews.push_back(std::get<0>(it));
+    }
+
+    // 5. Promote the specified views and use them in the new op.
+    auto promotedBufferAndViews =
+        promoteLinalgViews(b, loc, filteredViews, folder);
+    SmallVector<Value *, 8> opViews(views.size(), nullptr);
+    SmallVector<Value *, 8> writebackViews(views.size(), nullptr);
+    for (unsigned i = 0, promotedIdx = 0, e = opViews.size(); i < e; ++i) {
+      if (viewsToPromote[i]) {
+        opViews[i] = promotedBufferAndViews[promotedIdx].fullLocalView;
+        writebackViews[i] =
+            promotedBufferAndViews[promotedIdx].partialLocalView;
+        promotedIdx++;
+      } else {
+        opViews[i] = views[i];
+      }
+    }
+    res = op.create(b, loc, opViews);
+
+    // 6. Emit write-back for the promoted output views: copy the partial view.
+    for (unsigned i = 0, e = writebackViews.size(); i < e; ++i) {
+      bool isOutput = res.getIndexOfOutput(opViews[i]).hasValue();
+      if (writebackViews[i] && isOutput)
+        copy(writebackViews[i], views[i]);
+    }
+
+    // 7. Dealloc local buffers.
+    for (const auto &pi : promotedBufferAndViews)
+      buffer_dealloc(pi.buffer);
   });

-  // Gather the newly created loops and return them with the new op.
+  // 8. Gather the newly created loops and return them with the new op.
   SmallVector<ForOp, 8> loops;
   loops.reserve(ivs.size());
   for (auto iv : ivs)
     loops.push_back(linalg::getForInductionVarOwner(iv));

   return TiledLinalgOp{res, loops};
 }

 llvm::Optional<TiledLinalgOp>
 mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,
-                           OperationFolder &state) {
+                           OperationFolder &folder,
+                           ArrayRef<bool> viewsToPromote) {
   if (tileSizes.empty())
     return llvm::None;
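To illustrate the mask semantics, a hypothetical selection (not exercised by the test below):

// Promote only A and C of matmul(A, B, C); B keeps its original subview.
// The write-back copy is emitted only for promoted *output* views, i.e. C.
SmallVector<bool, 8> mask = {true, false, true};
auto tiled = tileLinalgOp(matmulOp, tileSizeValues, folder, mask);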
@@ -243,26 +440,31 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes,

   // Create a builder for tile size constants.
   OpBuilder builder(op);
-  auto loc = op.getLoc();
+  ScopedContext scope(builder, op.getLoc());

   // Materialize concrete tile size values to pass the generic tiling function.
   SmallVector<Value *, 8> tileSizeValues;
   tileSizeValues.reserve(tileSizes.size());
   for (auto ts : tileSizes)
-    tileSizeValues.push_back(state.create<ConstantIndexOp>(builder, loc, ts));
+    tileSizeValues.push_back(constant_index(folder, ts));
   // Pad tile sizes with zero values to enforce our convention.
   if (tileSizeValues.size() < nLoops) {
     for (unsigned i = tileSizeValues.size(); i < nLoops; ++i)
-      tileSizeValues.push_back(state.create<ConstantIndexOp>(builder, loc, 0));
+      tileSizeValues.push_back(constant_index(folder, 0));
   }

-  return tileLinalgOp(op, tileSizeValues, state);
+  return tileLinalgOp(op, tileSizeValues, folder, viewsToPromote);
 }

-static void tileLinalgOps(Function f, ArrayRef<int64_t> tileSizes) {
-  OperationFolder state;
-  f.walk<LinalgOp>([tileSizes, &state](LinalgOp op) {
-    auto opLoopsPair = tileLinalgOp(op, tileSizes, state);
+static void tileLinalgOps(Function f, ArrayRef<int64_t> tileSizes,
+                          bool promoteViews) {
+  OperationFolder folder;
+  f.walk<LinalgOp>([promoteViews, tileSizes, &folder](LinalgOp op) {
+    // TODO(ntv) some heuristic here to decide what to promote. Atm it is all
+    // or nothing.
+    SmallVector<bool, 8> viewsToPromote(op.getNumInputsAndOutputs(),
+                                        promoteViews);
+    auto opLoopsPair = tileLinalgOp(op, tileSizes, folder, viewsToPromote);
     // If tiling occurred successfully, erase old op.
     if (opLoopsPair)
       op.erase();
@@ -271,28 +473,39 @@ static void tileLinalgOps(Function f, ArrayRef<int64_t> tileSizes) {

 namespace {
 struct LinalgTilingPass : public FunctionPass<LinalgTilingPass> {
-  LinalgTilingPass();
-  LinalgTilingPass(ArrayRef<int64_t> sizes);
+  LinalgTilingPass(ArrayRef<int64_t> sizes, bool promoteViews);

-  void runOnFunction() { tileLinalgOps(getFunction(), tileSizes); }
+  void runOnFunction() {
+    tileLinalgOps(getFunction(), tileSizes, promoteViews);
+  }

   SmallVector<int64_t, 8> tileSizes;
+  bool promoteViews;
+
+protected:
+  LinalgTilingPass() {}
+};
+
+struct LinalgTilingPassCLI : public LinalgTilingPass {
+  LinalgTilingPassCLI();
 };
 } // namespace

-LinalgTilingPass::LinalgTilingPass()
-    : tileSizes(clTileSizes.begin(), clTileSizes.end()) {}
+LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes, bool promoteViews) {
+  this->tileSizes.assign(sizes.begin(), sizes.end());
+  this->promoteViews = promoteViews;
+}

-LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes)
-    : LinalgTilingPass() {
-  if (!sizes.empty())
-    this->tileSizes.assign(sizes.begin(), sizes.end());
+LinalgTilingPassCLI::LinalgTilingPassCLI() : LinalgTilingPass() {
+  this->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
+  this->promoteViews = clPromoteFullTileViews;
 }

 FunctionPassBase *
-mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
-  return new LinalgTilingPass(tileSizes);
+mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes,
+                                     bool promoteViews) {
+  return new LinalgTilingPass(tileSizes, promoteViews);
 }

-static PassRegistration<LinalgTilingPass>
+static PassRegistration<LinalgTilingPassCLI>
     pass("linalg-tile", "Tile operations in the linalg dialect");
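With registration moved to `LinalgTilingPassCLI`, promotion is driven from the command line, e.g. `mlir-opt foo.mlir -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-tile-promote-full-tile-views=true`; this is exactly the invocation exercised by the new test below.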
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-tile-promote-full-tile-views=true | FileCheck %s -check-prefix=TILE-1D
+
+func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %I = linalg.range %c0:%arg1:%c1 : !linalg.range
+  %J = linalg.range %c0:%arg2:%c1 : !linalg.range
+  %K = linalg.range %c0:%arg3:%c1 : !linalg.range
+  %A = linalg.view %arg0[%I, %K] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+  %B = linalg.view %arg0[%K, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+  %C = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+  linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  return
+}
+// TILE-1D-LABEL: func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
+// TILE-1D: linalg.for %i0 = %c0 to %6 step %c2 {
+// TILE-1D: linalg.for %i1 = %c0 to %9 step %c3 {
+// TILE-1D: linalg.for %i2 = %c0 to %7 step %c4 {
+// TILE-1D: %[[vA:.*]] = linalg.subview {{.*}} : !linalg.view<?x?xf32>
+// TILE-1D: %[[vB:.*]] = linalg.subview {{.*}} : !linalg.view<?x?xf32>
+// TILE-1D: %[[vC:.*]] = linalg.subview {{.*}} : !linalg.view<?x?xf32>
+///
+// TILE-1D: %[[tmpA:.*]] = linalg.buffer_alloc : !linalg.buffer<8xf32>
+// TILE-1D: %[[fullA:.*]] = linalg.view %[[tmpA]][{{.*}}] : !linalg.buffer<8xf32> -> !linalg.view<?x?xf32>
+// TILE-1D: %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+///
+// TILE-1D: %[[tmpB:.*]] = linalg.buffer_alloc : !linalg.buffer<12xf32>
+// TILE-1D: %[[fullB:.*]] = linalg.view %[[tmpB]][{{.*}}] : !linalg.buffer<12xf32> -> !linalg.view<?x?xf32>
+// TILE-1D: %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+///
+// TILE-1D: %[[tmpC:.*]] = linalg.buffer_alloc : !linalg.buffer<6xf32>
+// TILE-1D: %[[fullC:.*]] = linalg.view %[[tmpC]][{{.*}}] : !linalg.buffer<6xf32> -> !linalg.view<?x?xf32>
+// TILE-1D: %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+
+// TILE-1D: linalg.fill(%[[fullA]], {{.*}}) : !linalg.view<?x?xf32>, f32
+// TILE-1D: linalg.fill(%[[fullB]], {{.*}}) : !linalg.view<?x?xf32>, f32
+// TILE-1D: linalg.fill(%[[fullC]], {{.*}}) : !linalg.view<?x?xf32>, f32
+// TILE-1D: linalg.copy(%[[vA]], %[[partialA]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+// TILE-1D: linalg.copy(%[[vB]], %[[partialB]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+// TILE-1D: linalg.copy(%[[vC]], %[[partialC]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+//
+// TILE-1D: linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+//
+// TILE-1D: linalg.copy(%[[partialC]], %[[vC]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+//
+// TILE-1D: linalg.buffer_dealloc %[[tmpA]] : !linalg.buffer<8xf32>
+// TILE-1D: linalg.buffer_dealloc %[[tmpB]] : !linalg.buffer<12xf32>
+// TILE-1D: linalg.buffer_dealloc %[[tmpC]] : !linalg.buffer<6xf32>