forked from OSchip/llvm-project
Allow linalg.view to change the underlying elemental type.
This CL adds the ability for linalg.view to act as a bitcast operation. This will be used when promoting views into faster memory and casting to vector types. In the process, linalg.view is moved to ODS. PiperOrigin-RevId: 262556246
This commit is contained in:
parent
d2aba89f2e
commit
20f2d3b598
|
@ -186,47 +186,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// The "linalg.view" op produces a linalg.view which is a multi-dimensional
|
||||
/// range abstraction on top of an underlying linalg.buffer. This gives an
|
||||
/// indexing structure to an otherwise non-indexable linalg.buffer.
|
||||
///
|
||||
/// A "linalg.view" takes a buffer and a variadic number of ranges and produces
|
||||
/// a `view` of the same elemental type as the buffer and of rank the number of
|
||||
/// ranges:
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
/// ```
|
||||
class ViewOp : public Op<ViewOp, OpTrait::VariadicOperands, OpTrait::OneResult,
|
||||
OpTrait::HasNoSideEffect> {
|
||||
enum { FirstIndexingOperand = 1 };
|
||||
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
// Hooks to customize the behavior of this op.
|
||||
static llvm::StringRef getOperationName() { return "linalg.view"; }
|
||||
static void build(Builder *b, OperationState *result, Value *buffer,
|
||||
llvm::ArrayRef<Value *> indexings);
|
||||
LogicalResult verify();
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
|
||||
// Op-specific functionality.
|
||||
unsigned getRank() { return getViewType().getRank(); }
|
||||
Type getElementType() { return getViewType().getElementType(); }
|
||||
ViewType getViewType() { return getType().cast<ViewType>(); }
|
||||
Value *getSupportingBuffer() { return getOperand(0); }
|
||||
// Get the underlying indexing at a given rank.
|
||||
Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); }
|
||||
// Get all the indexings in this view.
|
||||
Operation::operand_range getIndexings() {
|
||||
return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
|
||||
}
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Linalg/IR/LinalgOps.h.inc"
|
||||
|
||||
|
|
|
@ -215,6 +215,51 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
|
|||
}];
|
||||
}
|
||||
|
||||
def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
|
||||
Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
|
||||
Results<(outs View)> {
|
||||
let summary = "view operation";
|
||||
let description = [{
|
||||
The "linalg.view" op produces a linalg.view which is a multi-dimensional
|
||||
range abstraction on top of an underlying linalg.buffer. This gives an
|
||||
indexing structure to an otherwise non-indexable linalg.buffer.
|
||||
|
||||
A "linalg.view" takes a buffer and a variadic number of ranges and produces
|
||||
a `view` of rank the number of ranges. The elemental type may not match the
|
||||
buffer element type:
|
||||
|
||||
Examples:
|
||||
```
|
||||
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
%3 = linalg.view %1[%2, %2] : !linalg.view<?x?xvector<4xf32>>
|
||||
```
|
||||
}];
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *b, OperationState *result, Value *buffer, "
|
||||
"ArrayRef<Value *> ranges, Type resultType = Type(), "
|
||||
"ArrayRef<NamedAttribute> attrs = {}">];
|
||||
|
||||
let verifier = [{
|
||||
if (getViewType().getRank() != llvm::size(ranges()))
|
||||
return emitOpError("the view rank must be the number of its ranges");
|
||||
return success();
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
enum { FirstIndexingOperand = 1 };
|
||||
unsigned getRank() { return getViewType().getRank(); }
|
||||
Type getElementType() { return getViewType().getElementType(); }
|
||||
ViewType getViewType() { return getType().cast<ViewType>(); }
|
||||
/// Get the underlying indexing at a given rank.
|
||||
Value *getRange(unsigned rank) {
|
||||
assert(rank < getRank() && "rank overflow");
|
||||
return *(ranges().begin() + rank);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
|
||||
Arguments<(ins Variadic<AnyType>:$values)> {
|
||||
let summary = "Linalg yield operation";
|
||||
|
|
|
@ -53,7 +53,7 @@ Value *Aliases::find(Value *v) {
|
|||
return it.first->second;
|
||||
}
|
||||
if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
|
||||
auto it = aliases.insert(std::make_pair(v, view.getSupportingBuffer()));
|
||||
auto it = aliases.insert(std::make_pair(v, view.buffer()));
|
||||
return it.first->second;
|
||||
}
|
||||
if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {
|
||||
|
|
|
@ -67,10 +67,10 @@ SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
|
|||
Value *min, *max, *step;
|
||||
if (view) {
|
||||
// Cannot traverse block arguments, fail.
|
||||
if (isa<BlockArgument>(view.getIndexing(dim)))
|
||||
if (isa<BlockArgument>(view.getRange(dim)))
|
||||
return matchFailure();
|
||||
// Record min, max, step for further processing.
|
||||
auto range = cast<RangeOp>(view.getIndexing(dim)->getDefiningOp());
|
||||
auto range = cast<RangeOp>(view.getRange(dim)->getDefiningOp());
|
||||
std::tie(min, max, step) =
|
||||
std::make_tuple(range.min(), range.max(), range.step());
|
||||
} else if (subView) {
|
||||
|
@ -414,97 +414,15 @@ LogicalResult mlir::linalg::StoreOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// ViewOp
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
|
||||
Value *buffer, ArrayRef<Value *> indexings) {
|
||||
BufferType bufferType = buffer->getType().cast<BufferType>();
|
||||
result->addOperands({buffer});
|
||||
result->addOperands(indexings);
|
||||
assert(
|
||||
std::none_of(indexings.begin(), indexings.end(),
|
||||
[](Value *v) { return !v->getType().isa<RangeType>(); }) &&
|
||||
"linalg.view takes only arguments of type linalg.range");
|
||||
|
||||
Type elementType = bufferType.getElementType();
|
||||
result->addTypes(
|
||||
{ViewType::get(b->getContext(), elementType, indexings.size())});
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::ViewOp::verify() {
|
||||
if (llvm::empty(getOperands()))
|
||||
return emitOpError(
|
||||
"requires at least a buffer operand followed by indexings");
|
||||
auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
|
||||
if (!bufferType)
|
||||
return emitOpError("first operand must be of BufferType");
|
||||
unsigned index = 0;
|
||||
for (auto indexing : getIndexings()) {
|
||||
if (!indexing->getType().isa<RangeType>()) {
|
||||
return emitOpError() << index << "^th index must be of range type";
|
||||
}
|
||||
++index;
|
||||
}
|
||||
if (getViewType().getRank() != index)
|
||||
return emitOpError()
|
||||
<< "the rank of the view must be the number of its indexings";
|
||||
return success();
|
||||
}
|
||||
|
||||
ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
OpAsmParser::OperandType bufferInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
|
||||
Type bType, type;
|
||||
if (parser->parseOperand(bufferInfo) ||
|
||||
parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColon() || parser->parseType(bType) ||
|
||||
parser->parseArrow() || parser->parseType(type)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
BufferType bufferType = bType.dyn_cast<BufferType>();
|
||||
if (!bufferType) {
|
||||
return parser->emitError(parser->getNameLoc(), "buffer type expected");
|
||||
}
|
||||
|
||||
ViewType viewType = type.dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return parser->emitError(parser->getNameLoc(), "view type expected");
|
||||
if (viewType.getRank() != indexingsInfo.size())
|
||||
return parser->emitError(parser->getNameLoc(), "expected")
|
||||
<< viewType.getRank() << " range indexings";
|
||||
return failure(
|
||||
parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
|
||||
(!indexingsInfo.empty() &&
|
||||
parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
|
||||
result->operands)) ||
|
||||
parser->addTypeToList(viewType, result->types));
|
||||
}
|
||||
|
||||
// A ViewOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
// ```
|
||||
//
|
||||
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
|
||||
// holding a range.
|
||||
void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
|
||||
interleave(
|
||||
getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; },
|
||||
[&]() { *p << ", "; });
|
||||
*p << "] : " << getSupportingBuffer()->getType() << " -> " << getType();
|
||||
}
|
||||
|
||||
///////////////////// Operations defined with Tablegen /////////////////////////
|
||||
// For such operations that do not correspond to library calls (i.e. defined in
|
||||
// LinalgOps.td), we define an overloaded `print` function and a
|
||||
// parse`className` function.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferAllocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, BufferAllocOp op) {
|
||||
*p << op.getOperationName() << " ";
|
||||
if (!llvm::empty(op.size()))
|
||||
|
@ -544,6 +462,10 @@ static LogicalResult verify(BufferAllocOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferDeallocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter *p, BufferDeallocOp op) {
|
||||
*p << op.getOperationName() << " " << *op.buffer();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
|
@ -565,6 +487,10 @@ static void print(OpAsmPrinter *p, BufferSizeOp op) {
|
|||
*p << " : " << op.getOperand()->getType();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseBufferSizeOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
OpAsmParser::OperandType op;
|
||||
|
@ -747,6 +673,66 @@ static LogicalResult verify(GenericOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
|
||||
Value *buffer, ArrayRef<Value *> ranges,
|
||||
Type resultType,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
if (!resultType) {
|
||||
Type elementType = buffer->getType().cast<BufferType>().getElementType();
|
||||
resultType = ViewType::get(b->getContext(), elementType, ranges.size());
|
||||
}
|
||||
build(b, result, resultType, buffer, ranges);
|
||||
result->addAttributes(attrs);
|
||||
}
|
||||
|
||||
static ParseResult parseViewOp(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType bufferInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 8> rangesInfo;
|
||||
Type bType, vType;
|
||||
if (parser->parseOperand(bufferInfo) ||
|
||||
parser->parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColon() || parser->parseType(bType) ||
|
||||
parser->parseArrow() || parser->parseType(vType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
BufferType bufferType = bType.dyn_cast<BufferType>();
|
||||
if (!bufferType) {
|
||||
return parser->emitError(parser->getNameLoc(), "buffer type expected");
|
||||
}
|
||||
|
||||
ViewType viewType = vType.dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return parser->emitError(parser->getNameLoc(), "view type expected");
|
||||
if (viewType.getRank() != rangesInfo.size())
|
||||
return parser->emitError(parser->getNameLoc(), "expected")
|
||||
<< viewType.getRank() << " range ranges";
|
||||
return failure(
|
||||
parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
|
||||
(!rangesInfo.empty() &&
|
||||
parser->resolveOperands(rangesInfo, RangeType::get(vType.getContext()),
|
||||
result->operands)) ||
|
||||
parser->addTypeToList(viewType, result->types));
|
||||
}
|
||||
|
||||
// A ViewOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
|
||||
// ```
|
||||
//
|
||||
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
|
||||
// holding a range.
|
||||
static void print(OpAsmPrinter *p, ViewOp op) {
|
||||
*p << op.getOperationName() << " " << *op.buffer() << "[";
|
||||
interleaveComma(op.ranges(), *p, [&](Value *v) { *p << *v; });
|
||||
*p << "] : " << op.buffer()->getType() << " -> " << op.getType();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// YieldOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -808,6 +794,10 @@ static void print(OpAsmPrinter *p, SubViewOp op) {
|
|||
*p << " : " << op.getViewType();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SubViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType inputView, resultView;
|
||||
Type viewType;
|
||||
|
|
|
@ -35,7 +35,7 @@ using namespace mlir::linalg;
|
|||
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addTypes<BufferType, RangeType, ViewType>();
|
||||
addOperations<LoadOp, RangeOp, StoreOp, SliceOp, ViewOp>();
|
||||
addOperations<LoadOp, RangeOp, StoreOp, SliceOp>();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
|
||||
|
|
|
@ -512,9 +512,9 @@ public:
|
|||
desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
|
||||
|
||||
// Compute and insert view sizes (max - min along the range).
|
||||
int numIndexings = llvm::size(viewOp.getIndexings());
|
||||
int numRanges = llvm::size(viewOp.ranges());
|
||||
Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
|
||||
for (int i = numIndexings - 1; i >= 0; --i) {
|
||||
for (int i = numRanges - 1; i >= 0; --i) {
|
||||
// Update stride.
|
||||
Value *rangeDescriptor = operands[1 + i];
|
||||
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
|
||||
|
|
|
@ -39,6 +39,7 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
|
|||
%5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
|
||||
%6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
|
||||
%7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
|
||||
%8 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xvector<4x4xf32>>
|
||||
linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
|
||||
return
|
||||
}
|
||||
|
@ -51,6 +52,7 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
|
|||
// CHECK-NEXT: %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
|
||||
// CHECK-NEXT: %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
|
||||
// CHECK-NEXT: %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
|
||||
// CHECK-NEXT: %{{.*}} = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xvector<4x4xf32>>
|
||||
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>
|
||||
|
||||
func @ops(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<?xf32>, %arg3: !linalg.view<f32>) {
|
||||
|
|
Loading…
Reference in New Issue