Allow linalg.view to change the underlying elemental type.

This CL adds the ability for linalg.view to act as a bitcast operation.
This will be used when promoting views into faster memory and casting to vector types.

In the process, linalg.view is moved to ODS.

PiperOrigin-RevId: 262556246
This commit is contained in:
Nicolas Vasilache 2019-08-09 07:28:51 -07:00 committed by A. Unique TensorFlower
parent d2aba89f2e
commit 20f2d3b598
7 changed files with 129 additions and 133 deletions

View File

@ -186,47 +186,6 @@ public:
}
};
/// The "linalg.view" op produces a linalg.view which is a multi-dimensional
/// range abstraction on top of an underlying linalg.buffer. This gives an
/// indexing structure to an otherwise non-indexable linalg.buffer.
///
/// A "linalg.view" takes a buffer and a variadic number of ranges and produces
/// a `view` of the same elemental type as the buffer and of rank the number of
/// ranges:
///
/// ```{.mlir}
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
/// ```
class ViewOp : public Op<ViewOp, OpTrait::VariadicOperands, OpTrait::OneResult,
OpTrait::HasNoSideEffect> {
enum { FirstIndexingOperand = 1 };
public:
using Op::Op;
// Hooks to customize the behavior of this op.
static llvm::StringRef getOperationName() { return "linalg.view"; }
static void build(Builder *b, OperationState *result, Value *buffer,
llvm::ArrayRef<Value *> indexings);
LogicalResult verify();
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
// Op-specific functionality.
unsigned getRank() { return getViewType().getRank(); }
Type getElementType() { return getViewType().getElementType(); }
ViewType getViewType() { return getType().cast<ViewType>(); }
Value *getSupportingBuffer() { return getOperand(0); }
// Get the underlying indexing at a given rank.
Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); }
// Get all the indexings in this view.
Operation::operand_range getIndexings() {
return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
}
};
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.h.inc"

View File

@ -215,6 +215,51 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
}];
}
def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
Results<(outs View)> {
let summary = "view operation";
let description = [{
The "linalg.view" op produces a linalg.view which is a multi-dimensional
range abstraction on top of an underlying linalg.buffer. This gives an
indexing structure to an otherwise non-indexable linalg.buffer.
A "linalg.view" takes a buffer and a variadic number of ranges and produces
a `view` of rank the number of ranges. The elemental type may not match the
buffer element type:
Examples:
```
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
%3 = linalg.view %1[%2, %2] : !linalg.view<?x?xvector<4xf32>>
```
}];
let builders = [OpBuilder<
"Builder *b, OperationState *result, Value *buffer, "
"ArrayRef<Value *> ranges, Type resultType = Type(), "
"ArrayRef<NamedAttribute> attrs = {}">];
let verifier = [{
if (getViewType().getRank() != llvm::size(ranges()))
return emitOpError("the view rank must be the number of its ranges");
return success();
}];
let extraClassDeclaration = [{
enum { FirstIndexingOperand = 1 };
unsigned getRank() { return getViewType().getRank(); }
Type getElementType() { return getViewType().getElementType(); }
ViewType getViewType() { return getType().cast<ViewType>(); }
/// Get the underlying indexing at a given rank.
Value *getRange(unsigned rank) {
assert(rank < getRank() && "rank overflow");
return *(ranges().begin() + rank);
}
}];
}
def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "Linalg yield operation";

View File

@ -53,7 +53,7 @@ Value *Aliases::find(Value *v) {
return it.first->second;
}
if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
auto it = aliases.insert(std::make_pair(v, view.getSupportingBuffer()));
auto it = aliases.insert(std::make_pair(v, view.buffer()));
return it.first->second;
}
if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {

View File

@ -67,10 +67,10 @@ SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
Value *min, *max, *step;
if (view) {
// Cannot traverse block arguments, fail.
if (isa<BlockArgument>(view.getIndexing(dim)))
if (isa<BlockArgument>(view.getRange(dim)))
return matchFailure();
// Record min, max, step for further processing.
auto range = cast<RangeOp>(view.getIndexing(dim)->getDefiningOp());
auto range = cast<RangeOp>(view.getRange(dim)->getDefiningOp());
std::tie(min, max, step) =
std::make_tuple(range.min(), range.max(), range.step());
} else if (subView) {
@ -414,97 +414,15 @@ LogicalResult mlir::linalg::StoreOp::verify() {
return success();
}
//////////////////////////////////////////////////////////////////////////////
// ViewOp
//////////////////////////////////////////////////////////////////////////////
void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
Value *buffer, ArrayRef<Value *> indexings) {
BufferType bufferType = buffer->getType().cast<BufferType>();
result->addOperands({buffer});
result->addOperands(indexings);
assert(
std::none_of(indexings.begin(), indexings.end(),
[](Value *v) { return !v->getType().isa<RangeType>(); }) &&
"linalg.view takes only arguments of type linalg.range");
Type elementType = bufferType.getElementType();
result->addTypes(
{ViewType::get(b->getContext(), elementType, indexings.size())});
}
LogicalResult mlir::linalg::ViewOp::verify() {
if (llvm::empty(getOperands()))
return emitOpError(
"requires at least a buffer operand followed by indexings");
auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
if (!bufferType)
return emitOpError("first operand must be of BufferType");
unsigned index = 0;
for (auto indexing : getIndexings()) {
if (!indexing->getType().isa<RangeType>()) {
return emitOpError() << index << "^th index must be of range type";
}
++index;
}
if (getViewType().getRank() != index)
return emitOpError()
<< "the rank of the view must be the number of its indexings";
return success();
}
ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
Type bType, type;
if (parser->parseOperand(bufferInfo) ||
parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColon() || parser->parseType(bType) ||
parser->parseArrow() || parser->parseType(type)) {
return failure();
}
BufferType bufferType = bType.dyn_cast<BufferType>();
if (!bufferType) {
return parser->emitError(parser->getNameLoc(), "buffer type expected");
}
ViewType viewType = type.dyn_cast<ViewType>();
if (!viewType)
return parser->emitError(parser->getNameLoc(), "view type expected");
if (viewType.getRank() != indexingsInfo.size())
return parser->emitError(parser->getNameLoc(), "expected")
<< viewType.getRank() << " range indexings";
return failure(
parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
result->operands)) ||
parser->addTypeToList(viewType, result->types));
}
// A ViewOp prints as:
//
// ```{.mlir}
// linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// ```
//
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
// holding a range.
void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
interleave(
getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; },
[&]() { *p << ", "; });
*p << "] : " << getSupportingBuffer()->getType() << " -> " << getType();
}
///////////////////// Operations defined with Tablegen /////////////////////////
// For such operations that do not correspond to library calls (i.e. defined in
// LinalgOps.td), we define an overloaded `print` function and a
// parse`className` function.
//===----------------------------------------------------------------------===//
// BufferAllocOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter *p, BufferAllocOp op) {
*p << op.getOperationName() << " ";
if (!llvm::empty(op.size()))
@ -544,6 +462,10 @@ static LogicalResult verify(BufferAllocOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// BufferDeallocOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter *p, BufferDeallocOp op) {
*p << op.getOperationName() << " " << *op.buffer();
p->printOptionalAttrDict(op.getAttrs());
@ -565,6 +487,10 @@ static void print(OpAsmPrinter *p, BufferSizeOp op) {
*p << " : " << op.getOperand()->getType();
}
//===----------------------------------------------------------------------===//
// BufferSizeOp
//===----------------------------------------------------------------------===//
static ParseResult parseBufferSizeOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType op;
@ -747,6 +673,66 @@ static LogicalResult verify(GenericOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//
void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
Value *buffer, ArrayRef<Value *> ranges,
Type resultType,
ArrayRef<NamedAttribute> attrs) {
if (!resultType) {
Type elementType = buffer->getType().cast<BufferType>().getElementType();
resultType = ViewType::get(b->getContext(), elementType, ranges.size());
}
build(b, result, resultType, buffer, ranges);
result->addAttributes(attrs);
}
static ParseResult parseViewOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> rangesInfo;
Type bType, vType;
if (parser->parseOperand(bufferInfo) ||
parser->parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColon() || parser->parseType(bType) ||
parser->parseArrow() || parser->parseType(vType)) {
return failure();
}
BufferType bufferType = bType.dyn_cast<BufferType>();
if (!bufferType) {
return parser->emitError(parser->getNameLoc(), "buffer type expected");
}
ViewType viewType = vType.dyn_cast<ViewType>();
if (!viewType)
return parser->emitError(parser->getNameLoc(), "view type expected");
if (viewType.getRank() != rangesInfo.size())
return parser->emitError(parser->getNameLoc(), "expected")
<< viewType.getRank() << " range ranges";
return failure(
parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
(!rangesInfo.empty() &&
parser->resolveOperands(rangesInfo, RangeType::get(vType.getContext()),
result->operands)) ||
parser->addTypeToList(viewType, result->types));
}
// A ViewOp prints as:
//
// ```{.mlir}
// linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// ```
//
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
// holding a range.
static void print(OpAsmPrinter *p, ViewOp op) {
*p << op.getOperationName() << " " << *op.buffer() << "[";
interleaveComma(op.ranges(), *p, [&](Value *v) { *p << *v; });
*p << "] : " << op.buffer()->getType() << " -> " << op.getType();
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
@ -808,6 +794,10 @@ static void print(OpAsmPrinter *p, SubViewOp op) {
*p << " : " << op.getViewType();
}
//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//
static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType inputView, resultView;
Type viewType;

View File

@ -35,7 +35,7 @@ using namespace mlir::linalg;
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addTypes<BufferType, RangeType, ViewType>();
addOperations<LoadOp, RangeOp, StoreOp, SliceOp, ViewOp>();
addOperations<LoadOp, RangeOp, StoreOp, SliceOp>();
addOperations<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"

View File

@ -512,9 +512,9 @@ public:
desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
// Compute and insert view sizes (max - min along the range).
int numIndexings = llvm::size(viewOp.getIndexings());
int numRanges = llvm::size(viewOp.ranges());
Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
for (int i = numIndexings - 1; i >= 0; --i) {
for (int i = numRanges - 1; i >= 0; --i) {
// Update stride.
Value *rangeDescriptor = operands[1 + i];
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));

View File

@ -39,6 +39,7 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
%5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
%6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
%7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
%8 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xvector<4x4xf32>>
linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
return
}
@ -51,6 +52,7 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
// CHECK-NEXT: %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK-NEXT: %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
// CHECK-NEXT: %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
// CHECK-NEXT: %{{.*}} = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xvector<4x4xf32>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>
func @ops(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<?xf32>, %arg3: !linalg.view<f32>) {