Add higher-level linalg.view_slice operation.

This will be useful to simplify the IR emitted during transformations as well as lowering to affine.

PiperOrigin-RevId: 254757641
Nicolas Vasilache 2019-06-24 07:59:59 -07:00 committed by jpienaar
parent 3df510bf42
commit 2ff1c01063
5 changed files with 139 additions and 0 deletions

View File

@@ -134,4 +134,51 @@ def TerminatorOp :
let verifier = ?;
}
def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
    Arguments<(ins View:$view, Variadic<Index>:$ranges)>,
    Results<(outs View)> {
  let summary = "subview operation";
  let description = [{
    The "linalg.subview" operation takes a linalg.view and a list of indices,
    and returns a new linalg.view of the same type that is contained within
    the operand view.
    This operation is equivalent to a non-rank-reducing slice operation. The
    main difference is that the operands are all of type `index` and no
    intermediate linalg.range operations are required. A "linalg.subview" is
    thus a specialized linalg.slice with a higher level of abstraction.

    Example:

      %1 = linalg.subview %0[%2, %3, %4, %5, %6, %7] : view<?x?xf32>
  }];
  // TODO(ntv) evolve syntax towards:
  //   linalg.subview %0[%1:%2:%3][%4:%5:%6] : view<?x?xf32>
  let verifier = [{
    // Expect one view operand followed by 3 index operands (min, max, step)
    // per view dimension, i.e. 1 + 3 * rank operands in total.
    auto numRanges = (getNumOperands() - 1) / 3;
    if (getNumOperands() != 3 * numRanges + 1 ||
        numRanges != getViewType().getRank())
      return emitOpError("expected a view followed by 3 indices specifying ")
             << "a range for each dimension";
    return success();
  }];
  let extraClassDeclaration = [{
    Value *getView() { return getOperand(0); }
    ViewType getViewType() { return getView()->getType().cast<ViewType>(); }

    // A (min, max, step) triple describing the subview along one dimension.
    struct Range { Value *min; Value *max; Value *step; };

    // Returns the i-th range, stored as 3 consecutive index operands after
    // the view operand.
    Range getRange(unsigned i) {
      return Range{
          getOperand(1 + 3 * i), getOperand(1 + 3 * i + 1),
          getOperand(1 + 3 * i + 2)};
    }
    SmallVector<Range, 8> getRanges() {
      SmallVector<Range, 8> res;
      unsigned rank = getViewType().getRank();
      res.reserve(rank);
      for (unsigned i = 0; i < rank; ++i)
        res.push_back(getRange(i));
      return res;
    }
  }];
}
#endif // LINALG_OPS
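A minimal sketch of what a single linalg.subview stands for, assuming the linalg.range/linalg.slice forms the description refers to (the value names are made up and the exact slice type list is approximated, not taken from this commit):

    // Hypothetical rank-2 view %v : !linalg.view<?x?xf32>, with index values
    // %min0, %max0, %step0, %min1, %max1, %step1 defined earlier.
    //
    // Expressed with explicit ranges feeding a non-rank-reducing slice:
    //   %r0 = linalg.range %min0:%max0:%step0 : !linalg.range
    //   %r1 = linalg.range %min1:%max1:%step1 : !linalg.range
    //   %s  = linalg.slice %v[%r0, %r1]
    //       : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
    //
    // The same region expressed directly from index operands with the new op:
    %sv = linalg.subview %v[%min0, %max0, %step0, %min1, %max1, %step1] : !linalg.view<?x?xf32>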

View File

@@ -662,6 +662,40 @@ static ParseResult parseRangeIntersectOp(OpAsmParser *parser,
parser->addTypeToList(type, result->types));
}
static void print(OpAsmPrinter *p, SubViewOp op) {
  *p << op.getOperationName() << " " << *op.getOperand(0) << "[";
  auto ranges = op.getRanges();
  interleaveComma(ranges, *p, [&p](const SubViewOp::Range &i) {
    *p << *i.min << ", " << *i.max << ", " << *i.step;
  });
  *p << "]";
  p->printOptionalAttrDict(op.getAttrs());
  *p << " : " << op.getViewType();
}
static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType inputView;
  Type viewType;
  if (parser->parseOperand(inputView))
    return failure();

  SmallVector<OpAsmParser::OperandType, 12> ops;
  // TODO(ntv) evolve parsing from
  //   linalg.subview %0[%1, %2, %3, %4, %5, %6]
  // to something resembling
  //   linalg.subview %0[%1:%2:%3][%4:%5:%6]
  if (parser->parseOperandList(ops, OpAsmParser::Delimiter::Square) ||
      parser->parseOptionalAttributeDict(result->attributes) ||
      parser->parseColonType(viewType))
    return failure();

  // The view operand has the declared view type; all bracketed operands are
  // resolved as `index` values.
  auto indexTy = parser->getBuilder().getIndexType();
  return failure(
      parser->resolveOperand(inputView, viewType, result->operands) ||
      parser->resolveOperands(ops, indexTy, result->operands) ||
      parser->addTypeToList(viewType, result->types));
}
/////// Operations corresponding to library calls defined with Tablegen ////////
// For operations that correspond to library calls (i.e. those defined in
// LinalgLibraryOps.td), we define an overloaded `print` function and a

View File

@@ -31,6 +31,7 @@
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
#include "mlir/Linalg/Passes.h"
#include "mlir/Linalg/Utils/Intrinsics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
@@ -48,6 +49,7 @@ using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;
using add = ValueBuilder<mlir::LLVM::AddOp>;
using addi = ValueBuilder<mlir::AddIOp>;
@@ -716,6 +718,30 @@ struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
};
} // namespace
// This is currently written as a standalone function because the lowering to
// affine will look different from the lowering to LLVM, and it is still
// unclear how everything will eventually be structured.
static void lowerLinalgSubViewOps(Function &f) {
  f.walk<SubViewOp>([&](SubViewOp op) {
    OpBuilder b(op);
    ScopedContext scope(b, op.getLoc());
    auto *view = op.getView();
    SmallVector<Value *, 8> ranges;
    for (auto en : llvm::enumerate(op.getRanges())) {
      using edsc::op::operator<;
      using linalg::intrinsics::dim;
      unsigned dimension = en.index();
      auto sliceRange = en.value();
      // Clamp the requested upper bound against the view size in this
      // dimension so the resulting range stays within the base view.
      auto size = dim(view, dimension);
      ValueHandle ub(sliceRange.max);
      auto max = edsc::intrinsics::select(size < ub, size, ub);
      ranges.push_back(range(sliceRange.min, max, sliceRange.step));
    }
    // Rewrite the subview as a slice over the clamped ranges.
    op.replaceAllUsesWith(slice(view, ranges));
    op.erase();
  });
}
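A minimal before/after sketch of the IR this rewrite produces, assuming a rank-1 view and the dialect's dim/range/slice forms (value names and the exact slice type list are illustrative; the LLVM test below checks the further-lowered form of the same clamp-select-range pattern):

    // Before:
    //   %sv = linalg.subview %v[%min, %max, %step] : !linalg.view<?xf32>
    //
    // After (per dimension: clamp the requested max against the view size,
    // then feed a range into a slice):
    //   %size = linalg.dim %v, 0 : !linalg.view<?xf32>
    //   %cond = cmpi "slt", %size, %max : index
    //   %ub   = select %cond, %size, %max : index
    //   %r    = linalg.range %min:%ub:%step : !linalg.range
    //   %s    = linalg.slice %v[%r] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>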
// Converts a `linalg.for` op to CFG form before actual conversion to the LLVM
// dialect starts.
static void lowerLinalgForToCFG(Function &f) {
@@ -773,6 +799,7 @@ void LowerLinalgToLLVMPass::runOnModule() {
auto &module = getModule();
for (auto &f : module.getFunctions()) {
lowerLinalgSubViewOps(f);
lowerLinalgForToCFG(f);
if (failed(lowerAffineConstructs(f)))
signalPassFailure();

View File

@@ -171,3 +171,25 @@ func @linalg_for_2(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK: %9 = llvm.mul %0, %7 : !llvm.i64
// CHECK: %10 = llvm.add %7, %arg2 : !llvm.i64
// CHECK: llvm.br ^bb8(%10 : !llvm.i64)
func @subview(%arg0: !linalg.view<?x?xf32>) {
%c0 = constant 0 : index
%0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view<?x?xf32>
return
}
// CHECK-LABEL: func @subview(%arg0: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
// CHECK: %0 = llvm.constant(0 : index) : !llvm.i64
// CHECK: %1 = llvm.extractvalue %arg0[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %2 = llvm.icmp "slt" %1, %0 : !llvm.i64
// CHECK: %3 = llvm.select %2, %1, %0 : !llvm.i1, !llvm.i64
// CHECK: %4 = llvm.undef : !llvm<"{ i64, i64, i64 }">
// CHECK: %5 = llvm.insertvalue %0, %4[0] : !llvm<"{ i64, i64, i64 }">
// CHECK: %6 = llvm.insertvalue %3, %5[1] : !llvm<"{ i64, i64, i64 }">
// CHECK: %7 = llvm.insertvalue %0, %6[2] : !llvm<"{ i64, i64, i64 }">
// CHECK: %8 = llvm.extractvalue %arg0[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// CHECK: %9 = llvm.icmp "slt" %8, %0 : !llvm.i64
// CHECK: %10 = llvm.select %9, %8, %0 : !llvm.i1, !llvm.i64
// CHECK: %11 = llvm.undef : !llvm<"{ i64, i64, i64 }">
// CHECK: %12 = llvm.insertvalue %0, %11[0] : !llvm<"{ i64, i64, i64 }">
// CHECK: %13 = llvm.insertvalue %10, %12[1] : !llvm<"{ i64, i64, i64 }">
// CHECK: %14 = llvm.insertvalue %0, %13[2] : !llvm<"{ i64, i64, i64 }">

View File

@@ -144,3 +144,12 @@ func @conv_view6(%arg0: !linalg.view<?x?x?x?x?x?xf32>, %arg1: !linalg.view<?x?x?
}
// CHECK-LABEL: func @conv_view6(%arg0: !linalg.view<?x?x?x?x?x?xf32>, %arg1: !linalg.view<?x?x?x?x?x?xf32>, %arg2: !linalg.view<?x?x?x?x?x?xf32>) {
// CHECK: linalg.conv(%arg0, %arg1, %arg2) {dilations: [4, 4, 5, 5], strides: [2, 2, 3, 3]} : !linalg.view<?x?x?x?x?x?xf32>, !linalg.view<?x?x?x?x?x?xf32>, !linalg.view<?x?x?x?x?x?xf32>
func @subview(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
%c0 = constant 0 : index
%0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view<?x?xvector<3x4xi4>>
return
}
// CHECK-LABEL: func @subview(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
// CHECK: %c0 = constant 0 : index
// CHECK: %0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view<?x?xvector<3x4xi4>>