forked from OSchip/llvm-project
Add higher-level linalg.view_slice operation.
This will be useful to simplify the IR emitted during transformations as well as lowering to affine. PiperOrigin-RevId: 254757641
This commit is contained in:
parent
3df510bf42
commit
2ff1c01063
|
@ -134,4 +134,51 @@ def TerminatorOp :
|
|||
let verifier = ?;
|
||||
}
|
||||
|
||||
def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
|
||||
Arguments<(ins View:$view, Variadic<Index>:$ranges)>,
|
||||
Results<(outs View)> {
|
||||
let summary = "subview operation";
|
||||
let description = [{
|
||||
The "linalg.subview" operation takes a linalg.view, a list of indices and
|
||||
returns a new linalg.view of the same type that is contained within the
|
||||
operand view.
|
||||
This operation is equivalent to a non-rank-reducing slice operation. The
|
||||
main difference is the operands are all of type `index` and no intermediate
|
||||
linalg.range operations are required. A "linalg.subview" is thus a
|
||||
specialized linalg.slice with a higher level of abstraction.
|
||||
|
||||
%1 = linalg.subview %0[%1, %2, %3, %4, %5, %6] : view<?x?xf32>
|
||||
|
||||
}];
|
||||
// TODO(ntv) evolve syntax towards:
|
||||
// linalg.subview %0[%1:%2:%3][%4:%5:%6] : view<?x?xf32>
|
||||
|
||||
let verifier = [{
|
||||
auto numRanges = (getNumOperands() - 1) / 3;
|
||||
if (getNumOperands() != 3 * numRanges + 1 ||
|
||||
numRanges != getViewType().getRank())
|
||||
return emitOpError("expected a view followed by 3 indices specifying ") <<
|
||||
"a range for each dimension";
|
||||
return success();
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Value *getView() { return getOperand(0); }
|
||||
ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
|
||||
struct Range { Value *min; Value *max; Value *step; };
|
||||
Range getRange(unsigned i) {
|
||||
return Range{
|
||||
getOperand(1 + 3*i), getOperand(1 + 3*i + 1), getOperand(1 + 3*i + 2)};
|
||||
}
|
||||
SmallVector<Range, 8> getRanges() {
|
||||
SmallVector<Range, 8> res;
|
||||
unsigned rank = getViewType().getRank();
|
||||
res.reserve(rank);
|
||||
for (unsigned i = 0; i < rank; ++i)
|
||||
res.push_back(getRange(i));
|
||||
return res;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // LINALG_OPS
|
||||
|
|
|
@ -662,6 +662,40 @@ static ParseResult parseRangeIntersectOp(OpAsmParser *parser,
|
|||
parser->addTypeToList(type, result->types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, SubViewOp op) {
|
||||
*p << op.getOperationName() << " " << *op.getOperand(0) << "[";
|
||||
auto ranges = op.getRanges();
|
||||
interleaveComma(ranges, *p, [&p](const SubViewOp::Range &i) {
|
||||
*p << *i.min << ", " << *i.max << ", " << *i.step;
|
||||
});
|
||||
*p << "]";
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getViewType();
|
||||
}
|
||||
|
||||
static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType inputView, resultView;
|
||||
Type viewType;
|
||||
if (parser->parseOperand(inputView))
|
||||
return failure();
|
||||
|
||||
SmallVector<OpAsmParser::OperandType, 12> ops;
|
||||
// TODO(ntv) evolve parsing from
|
||||
// linalg.subview %0[%1, %2, %3, %4, %5, %6]
|
||||
// to something resembling
|
||||
// linalg.subview %0[%1:%2:%3][%4:%5:%6]
|
||||
if (parser->parseOperandList(ops, OpAsmParser::Delimiter::Square) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(viewType))
|
||||
return failure();
|
||||
|
||||
auto indexTy = parser->getBuilder().getIndexType();
|
||||
return failure(
|
||||
parser->resolveOperand(inputView, viewType, result->operands) ||
|
||||
parser->resolveOperands(ops, indexTy, result->operands) ||
|
||||
parser->addTypeToList(viewType, result->types));
|
||||
}
|
||||
|
||||
/////// Operations corresponding to library calls defined with Tablegen ////////
|
||||
// For such operations correspond to library calls (i.e. defined in
|
||||
// LinalgLibraryOps.td), we define an overloaded `print` function and a
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "mlir/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Linalg/Passes.h"
|
||||
#include "mlir/Linalg/Utils/Intrinsics.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
@ -48,6 +49,7 @@ using namespace mlir::edsc;
|
|||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::LLVM;
|
||||
using namespace mlir::linalg;
|
||||
using namespace mlir::linalg::intrinsics;
|
||||
|
||||
using add = ValueBuilder<mlir::LLVM::AddOp>;
|
||||
using addi = ValueBuilder<mlir::AddIOp>;
|
||||
|
@ -716,6 +718,30 @@ struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// This is currently written as a standalone function because the lowering to
|
||||
// affine will look different than lowering to LLVM and it is still unclear how
|
||||
// everything will be eventually structured.
|
||||
static void lowerLinalgSubViewOps(Function &f) {
|
||||
f.walk<SubViewOp>([&](SubViewOp op) {
|
||||
OpBuilder b(op);
|
||||
ScopedContext scope(b, op.getLoc());
|
||||
auto *view = op.getView();
|
||||
SmallVector<Value *, 8> ranges;
|
||||
for (auto en : llvm::enumerate(op.getRanges())) {
|
||||
using edsc::op::operator<;
|
||||
using linalg::intrinsics::dim;
|
||||
unsigned rank = en.index();
|
||||
auto sliceRange = en.value();
|
||||
auto size = dim(view, rank);
|
||||
ValueHandle ub(sliceRange.max);
|
||||
auto max = edsc::intrinsics::select(size < ub, size, ub);
|
||||
ranges.push_back(range(sliceRange.min, max, sliceRange.step));
|
||||
}
|
||||
op.replaceAllUsesWith(slice(view, ranges));
|
||||
op.erase();
|
||||
});
|
||||
}
|
||||
|
||||
// Converts a `linalg.for` op to CFG form before actual conversion to the LLVM
|
||||
// dialect starts.
|
||||
static void lowerLinalgForToCFG(Function &f) {
|
||||
|
@ -773,6 +799,7 @@ void LowerLinalgToLLVMPass::runOnModule() {
|
|||
auto &module = getModule();
|
||||
|
||||
for (auto &f : module.getFunctions()) {
|
||||
lowerLinalgSubViewOps(f);
|
||||
lowerLinalgForToCFG(f);
|
||||
if (failed(lowerAffineConstructs(f)))
|
||||
signalPassFailure();
|
||||
|
|
|
@ -171,3 +171,25 @@ func @linalg_for_2(%arg0 : index, %arg1 : index, %arg2 : index) {
|
|||
// CHECK: %9 = llvm.mul %0, %7 : !llvm.i64
|
||||
// CHECK: %10 = llvm.add %7, %arg2 : !llvm.i64
|
||||
// CHECK: llvm.br ^bb8(%10 : !llvm.i64)
|
||||
|
||||
func @subview(%arg0: !linalg.view<?x?xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view<?x?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @subview(%arg0: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
|
||||
// CHECK: %0 = llvm.constant(0 : index) : !llvm.i64
|
||||
// CHECK: %1 = llvm.extractvalue %arg0[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: %2 = llvm.icmp "slt" %1, %0 : !llvm.i64
|
||||
// CHECK: %3 = llvm.select %2, %1, %0 : !llvm.i1, !llvm.i64
|
||||
// CHECK: %4 = llvm.undef : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %5 = llvm.insertvalue %0, %4[0] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %6 = llvm.insertvalue %3, %5[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %7 = llvm.insertvalue %0, %6[2] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %8 = llvm.extractvalue %arg0[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: %9 = llvm.icmp "slt" %8, %0 : !llvm.i64
|
||||
// CHECK: %10 = llvm.select %9, %8, %0 : !llvm.i1, !llvm.i64
|
||||
// CHECK: %11 = llvm.undef : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %12 = llvm.insertvalue %0, %11[0] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %13 = llvm.insertvalue %10, %12[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: %14 = llvm.insertvalue %0, %13[2] : !llvm<"{ i64, i64, i64 }">
|
||||
|
|
|
@ -144,3 +144,12 @@ func @conv_view6(%arg0: !linalg.view<?x?x?x?x?x?xf32>, %arg1: !linalg.view<?x?x?
|
|||
}
|
||||
// CHECK-LABEL: func @conv_view6(%arg0: !linalg.view<?x?x?x?x?x?xf32>, %arg1: !linalg.view<?x?x?x?x?x?xf32>, %arg2: !linalg.view<?x?x?x?x?x?xf32>) {
|
||||
// CHECK: linalg.conv(%arg0, %arg1, %arg2) {dilations: [4, 4, 5, 5], strides: [2, 2, 3, 3]} : !linalg.view<?x?x?x?x?x?xf32>, !linalg.view<?x?x?x?x?x?xf32>, !linalg.view<?x?x?x?x?x?xf32>
|
||||
|
||||
func @subview(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
|
||||
%c0 = constant 0 : index
|
||||
%0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view<?x?xvector<3x4xi4>>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @subview(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
|
||||
// CHECK: %c0 = constant 0 : index
|
||||
// CHECK: %0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view<?x?xvector<3x4xi4>>
|
||||
|
|
Loading…
Reference in New Issue