Extend linalg transformations to allow value operands that are not views

This CL extends the linalg ops that can be tiled and fused to operations that take either views, scalar or vector operands.

PiperOrigin-RevId: 258149291
This commit is contained in:
Nicolas Vasilache 2019-07-15 06:27:21 -07:00 committed by Mehdi Amini
parent ec82e1c907
commit 4de019901b
6 changed files with 57 additions and 2 deletions

View File

@ -66,6 +66,9 @@ LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
{}
// Base Tablegen class for Linalg ops.
// Linalg ops that correspond to library calls operate on linalg::View as their
// first operands. These may be optionally followed by non-view operands
// depending on the specific Linalg op.
class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
: Op<Linalg_Dialect, mnemonic, props> {
let parser = [{ return parseLinalgLibraryOp(parser, result); }];

View File

@ -133,6 +133,11 @@ llvm::SmallVector<PromotionInfo, 8> promoteLinalgViews(OpBuilder &b,
ArrayRef<Value *> views,
OperationFolder &folder);
// Returns all the operands of `linalgOp` that are not views.
// Asserts that these operands are value types to allow transformations like
// tiling to just use the values when cloning `linalgOp`.
llvm::SmallVector<Value *, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
} // namespace linalg
} // namespace mlir

View File

@ -104,6 +104,8 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
// TODO(ntv) opportunities for folding/CSE here rather than build new IR.
clonedViews.push_back(b.create<SubViewOp>(loc, view, viewRanges));
}
auto operands = getAssumedNonViewOperands(op);
clonedViews.append(operands.begin(), operands.end());
return op.create(b, loc, clonedViews, op.getAttrs());
}

View File

@ -383,8 +383,6 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
// If/when the assertion below becomes false, templatize `makeTiledViews`.
assert(op.getNumInputsAndOutputs() == op.getOperation()->getNumOperands());
auto views =
makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes, folder);
@ -394,6 +392,8 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
viewsToPromote.end()),
[](bool b) { return b; });
if (!promote) {
auto operands = getAssumedNonViewOperands(op);
views.append(operands.begin(), operands.end());
res = op.create(b, loc, views, op.getAttrs());
return;
}
@ -424,6 +424,8 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
opViews[i] = views[i];
}
}
auto operands = getAssumedNonViewOperands(op);
opViews.append(operands.begin(), operands.end());
res = op.create(b, loc, opViews, op.getAttrs());
// 6. Emit write-back for the promoted output views: copy the partial view.

View File

@ -140,3 +140,22 @@ SmallVector<Value *, 4> mlir::linalg::applyMapToValues(OpBuilder &b,
}
return res;
}
// Returns all the operands of `linalgOp` that are not views.
// Asserts that these operands are value types to allow transformations like
// tiling to just use the values when cloning `linalgOp`.
SmallVector<Value *, 4>
mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
auto *op = linalgOp.getOperation();
unsigned numViews = linalgOp.getNumInputsAndOutputs();
unsigned nOperands = op->getNumOperands() - numViews;
SmallVector<Value *, 4> res;
res.reserve(nOperands);
for (unsigned i = 0; i < nOperands; ++i) {
res.push_back(op->getOperand(numViews + i));
auto t = res.back()->getType();
assert((t.isIntOrIndexOrFloat() || t.isa<VectorType>()) &&
"expected scalar or vector type");
}
return res;
}

View File

@ -135,3 +135,27 @@ func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg
// TILE-234: %[[b:.*]] = affine.apply #[[UB0]](%{{.*}})
// TILE-234: %[[sBi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[b]], %{{.*}}] : !linalg.view<?xf32>
// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
func @fill(%arg0: !linalg.view<?x?xf32>, %arg1: f32) {
linalg.fill(%arg0, %arg1) : !linalg.view<?x?xf32>, f32
return
}
// TILE-2-LABEL: func @fill
// TILE-2: for
// TILE-2-NOT: for
// TILE-2: fill{{.*}} f32
// TILE-02-LABEL: func @fill
// TILE-02: for
// TILE-02-NOT: for
// TILE-02: fill{{.*}} f32
// TILE-002-LABEL: func @fill
// TILE-002-NOT: for
// TILE-002: fill{{.*}} f32
// TILE-234-LABEL: func @fill
// TILE-234: for
// TILE-234: for
// TILE-234-NOT: for
// TILE-234: fill{{.*}} f32