[Linalg] Add a view type with base_view op

This CL adds a linalg.view<?x?xf32> type and base_view op with the proper roundtripping test. The parser will be improved in a subsequent CL once portions of the mlir::Parser are exposed.

    For now this only supports dynamic views, static views will be introduced at a later time when they are needed.

--

PiperOrigin-RevId: 244374180
This commit is contained in:
Nicolas Vasilache 2019-04-19 09:56:11 -07:00 committed by Mehdi Amini
parent e8d551e2bd
commit 1d5dc840e7
5 changed files with 296 additions and 37 deletions

View File

@ -24,6 +24,50 @@
namespace mlir {
/// A `BaseViewOp` produces a `ViewType` which is a multi-dimensional range
/// abstraction on top of an underlying linalg.buffer. A BaseViewOp gives a
/// buffer an indexing structure.
///
/// A new value of ViewType is constructed from a buffer with a base_view op and
/// ranges:
///
/// ```{.mlir}
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
/// %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
/// ```
class BaseViewOp : public mlir::Op<BaseViewOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
enum { FirstIndexingOperand = 1 };
public:
using Op::Op;
// Hooks to customize the behavior of this op.
static llvm::StringRef getOperationName() { return "linalg.base_view"; }
static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *buffer,
llvm::ArrayRef<mlir::Value *> indexings);
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
// Op-specific functionality.
unsigned getRank() { return getViewType().getRank(); }
mlir::Type getElementType() { return getViewType().getElementType(); }
ViewType getViewType() { return getType().cast<ViewType>(); }
mlir::Value *getSupportingBuffer() { return getOperand(0); }
// Get the underlying indexing at a given rank.
mlir::Value *getIndexing(unsigned rank) {
return *(getIndexings().begin() + rank);
}
// Get all the indexings in this view.
mlir::Operation::operand_range getIndexings() {
return {operand_begin() + BaseViewOp::FirstIndexingOperand, operand_end()};
}
};
/// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base
/// view can be laid out. The size argument is an `i64` (and not an index), so
/// that we can

View File

@ -27,7 +27,8 @@ class MLIRContext;
enum LinalgTypes {
Buffer = Type::FIRST_LINALG_TYPE,
Range,
LAST_USED_LINALG_TYPE = Range,
View,
LAST_USED_LINALG_TYPE = View,
};
class LinalgDialect : public Dialect {
@ -51,9 +52,8 @@ public:
static BufferType get(MLIRContext *context, Type elementType);
/// Used to implement llvm-style cast.
static bool kindof(unsigned kind) { return kind == LinalgTypes::Buffer; }
//////////////////////////////////////////////////////////////////////////////
// Type-specific functionality.
//////////////////////////////////////////////////////////////////////////////
Type getElementType();
};
@ -71,6 +71,37 @@ public:
static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
};
/// A ViewType represents a multi-dimensional range abstraction on top of an
/// underlying storage type. It is parameterizable by the underlying element
/// type and the rank of the view.
/// A new value of ViewType is constructed from a buffer with a base_view op and
/// passing it ranges:
///
/// ```{.mlir}
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
/// %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
/// ```
class ViewTypeStorage;
class ViewType
: public mlir::Type::TypeBase<ViewType, mlir::Type, ViewTypeStorage> {
public:
// Used for generic hooks in TypeBase.
using Base::Base;
/// Construction hook.
static ViewType get(mlir::MLIRContext *context, mlir::Type elementType,
unsigned rank);
// Used to implement llvm-style cast.
static bool kindof(unsigned kind) { return kind == LinalgTypes::View; }
// Type-specific functionality.
/// Return the underlying elemental type.
mlir::Type getElementType();
/// Return the rank of the view.
/// This is the number of indexings needed to reach an underlying element.
unsigned getRank();
};
} // namespace mlir
#endif // MLIR_LINALG_LINALGTYPES_H_

View File

@ -25,48 +25,91 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Linalg/LinalgTypes.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
using namespace mlir;
//////////////////////////////////////////////////////////////////////////////
// RangeOp
// BaseViewOp
//////////////////////////////////////////////////////////////////////////////
void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
Value *max, Value *step) {
result->addOperands({min, max, step});
result->addTypes({RangeType::get(b->getContext())});
void mlir::BaseViewOp::build(Builder *b, OperationState *result, Value *buffer,
ArrayRef<Value *> indexings) {
BufferType bufferType = buffer->getType().cast<BufferType>();
result->addOperands({buffer});
result->addOperands(indexings);
assert(
std::none_of(indexings.begin(), indexings.end(),
[](Value *v) { return !v->getType().isa<RangeType>(); }) &&
"linalg.base_view takes only arguments of type linalg.range");
Type elementType = bufferType.getElementType();
result->addTypes(
{ViewType::get(b->getContext(), elementType, indexings.size())});
}
// Verification is simply that a RangeOp takes 3 index ssa-value.
mlir::LogicalResult mlir::RangeOp::verify() {
if (!min() || !min()->getType().isa<IndexType>())
return emitOpError("first operand should be of type index");
if (!max() || !max()->getType().isa<IndexType>())
return emitOpError("second operand should be of type index");
if (!step() || !step()->getType().isa<IndexType>())
return emitOpError("third operand should be of type index");
return mlir::success();
LogicalResult mlir::BaseViewOp::verify() {
if (llvm::empty(getOperands()))
return emitOpError(
"requires at least a buffer operand followed by indexings");
auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
if (!bufferType)
return emitOpError("first operand must be of BufferType");
unsigned index = 0;
for (auto indexing : getIndexings()) {
if (!indexing->getType().isa<RangeType>()) {
return emitOpError(Twine(index) + "^th index must be of range type");
}
++index;
}
if (getViewType().getRank() != index)
return emitOpError(
"the rank of the base view must be the number of its indexings");
return success();
}
// A RangeOp prints as:
bool mlir::BaseViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
Type type;
if (parser->parseOperand(bufferInfo) ||
parser->parseOperandList(indexingsInfo, -1,
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
ViewType viewType = type.dyn_cast<ViewType>();
if (!viewType)
return parser->emitError(parser->getNameLoc(), "view type expected");
if (viewType.getRank() != indexingsInfo.size())
return parser->emitError(parser->getNameLoc(),
"expected" + Twine(viewType.getRank()) +
" range indexings");
return parser->resolveOperand(
bufferInfo,
BufferType::get(type.getContext(), viewType.getElementType()),
result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo,
RangeType::get(type.getContext()),
result->operands)) ||
parser->addTypeToList(viewType, result->types);
}
// A BaseViewOp prints as:
//
// ```{.mlir}
// linalg.range %0:%1:%2 : !linalg.range
// linalg.base_view %0[%1, %2] : !linalg.view<?x?xf32>
// ```
void mlir::RangeOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
<< " : " << getType();
}
bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types);
//
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
// holding a range.
void mlir::BaseViewOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
interleave(
getIndexings().begin(), getIndexings().end(),
[&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
*p << "] : " << getType();
}
//////////////////////////////////////////////////////////////////////////////
@ -140,3 +183,44 @@ bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
parser->resolveOperands(sizeInfo, bufferType, result->operands);
}
//////////////////////////////////////////////////////////////////////////////
// RangeOp
//////////////////////////////////////////////////////////////////////////////
void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
Value *max, Value *step) {
result->addOperands({min, max, step});
result->addTypes({RangeType::get(b->getContext())});
}
// Verification is simply that a RangeOp takes 3 index ssa-value.
mlir::LogicalResult mlir::RangeOp::verify() {
if (!min() || !min()->getType().isa<IndexType>())
return emitOpError("first operand should be of type index");
if (!max() || !max()->getType().isa<IndexType>())
return emitOpError("second operand should be of type index");
if (!step() || !step()->getType().isa<IndexType>())
return emitOpError("third operand should be of type index");
return mlir::success();
}
// A RangeOp prints as:
//
// ```{.mlir}
// linalg.range %0:%1:%2 : !linalg.range
// ```
void mlir::RangeOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
<< " : " << getType();
}
bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types);
}

View File

@ -29,8 +29,8 @@ using namespace mlir;
mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect("linalg", context) {
addTypes<BufferType, RangeType>();
addOperations<BufferAllocOp, BufferDeallocOp, RangeOp>();
addTypes<BufferType, RangeType, ViewType>();
addOperations<BaseViewOp, BufferAllocOp, BufferDeallocOp, RangeOp>();
}
struct mlir::BufferTypeStorage : public mlir::TypeStorage {
@ -80,15 +80,97 @@ Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
// TODO(ntv): reuse mlir Parser once exposed.
if (spec == "buffer<f32>")
return BufferType::get(getContext(), FloatType::getF32(getContext()));
// TODO(ntv): reuse mlir Parser once exposed.
if (spec.startswith("view")) {
spec.consume_front("view");
// Just count the number of ? to get the rank, the type must be f32 for now.
unsigned rank = 0;
for (unsigned i = 0, e = spec.size(); i < e; ++i)
if (spec[i] == '?')
++rank;
return ViewType::get(context, FloatType::getF32(context), rank);
}
return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
}
/// RangeType prints as just "range".
struct mlir::ViewTypeStorage : public mlir::TypeStorage {
/// Underlying Key type to transport the payload needed to construct a custom
/// type in a generic way.
struct Key {
Key(Type elementType, unsigned rank)
: elementType(elementType), rank(rank) {}
Type elementType;
unsigned rank;
};
/// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
using KeyTy = Key;
/// Construction in the llvm::BumpPtrAllocator given a key.
static ViewTypeStorage *construct(TypeStorageAllocator &allocator,
const Key &key) {
return new (allocator.allocate<ViewTypeStorage>()) ViewTypeStorage(key);
}
/// Equality operator for hashing.
bool operator==(const Key &key) const {
return elementType == key.elementType && rank == key.rank;
}
/// Hashing for unique'ing.
static unsigned hashKey(const Key &key) {
return llvm::hash_combine(key.elementType, key.rank);
}
unsigned getRank() { return rank; };
Type getElementType() { return elementType; };
private:
ViewTypeStorage(const Key &key)
: elementType(key.elementType), rank(key.rank) {}
Type elementType;
unsigned rank;
};
ViewType mlir::ViewType::get(MLIRContext *context, Type elementType,
unsigned rank) {
return Base::get(context, LinalgTypes::View, elementType, rank);
}
Type mlir::ViewType::getElementType() { return getImpl()->getElementType(); }
unsigned mlir::ViewType::getRank() { return getImpl()->getRank(); }
/// BufferType prints as "buffer<element_type>".
static void print(BufferType bt, raw_ostream &os) {
os << "buffer<" << bt.getElementType() << ">";
}
/// RangeType prints as just "range".
static void print(RangeType rt, raw_ostream &os) { os << "range"; }
/// ViewType prints as:
///
/// ```{.mlir}
/// view<?x?xf32>
/// ```
///
/// or
///
/// ```{.mlir}
/// view<?xf32>
/// ```
///
/// for 0-D views (a.k.a pointer to a scalar value).
static void print(mlir::ViewType rt, raw_ostream &os) {
os << "view<";
for (unsigned i = 0, e = rt.getRank(); i < e; ++i) {
os << "?x";
}
os << rt.getElementType();
os << ">";
}
void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
switch (type.getKind()) {
default:
@ -99,5 +181,8 @@ void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
case LinalgTypes::Range:
print(type.cast<RangeType>(), os);
break;
case LinalgTypes::View:
print(type.cast<ViewType>(), os);
break;
}
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -verify | mlir-opt | FileCheck %s
// RUN: mlir-opt %s -verify | FileCheck %s
func @range(%arg0: index, %arg1: index, %arg2: index) {
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
@ -16,4 +16,19 @@ func @buffer(%arg0: i64, %arg1: i64) {
// CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
%0 = muli %arg0, %arg0 : i64
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
%3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
linalg.buffer_dealloc %1 : !linalg.buffer<f32>
return
}
// CHECK-LABEL: func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
// CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
// CHECK-NEXT: %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>