forked from OSchip/llvm-project
[Linalg] Add a view type with base_view op
This CL adds a linalg.view<?x?xf32> type and base_view op with the proper roundtripping test. The parser will be improved in a subsequent CL once portions of the mlir::Parser are exposed. For now this only supports dynamic views, static views will be introduced at a later time when they are needed. -- PiperOrigin-RevId: 244374180
This commit is contained in:
parent
e8d551e2bd
commit
1d5dc840e7
|
@ -24,6 +24,50 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
/// A `BaseViewOp` produces a `ViewType` which is a multi-dimensional range
|
||||
/// abstraction on top of an underlying linalg.buffer. A BaseViewOp gives a
|
||||
/// buffer an indexing structure.
|
||||
///
|
||||
/// A new value of ViewType is constructed from a buffer with a base_view op and
|
||||
/// ranges:
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
/// %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
/// ```
|
||||
class BaseViewOp : public mlir::Op<BaseViewOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
enum { FirstIndexingOperand = 1 };
|
||||
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
// Hooks to customize the behavior of this op.
|
||||
static llvm::StringRef getOperationName() { return "linalg.base_view"; }
|
||||
static void build(mlir::Builder *b, mlir::OperationState *result,
|
||||
mlir::Value *buffer,
|
||||
llvm::ArrayRef<mlir::Value *> indexings);
|
||||
mlir::LogicalResult verify();
|
||||
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
|
||||
void print(mlir::OpAsmPrinter *p);
|
||||
|
||||
// Op-specific functionality.
|
||||
unsigned getRank() { return getViewType().getRank(); }
|
||||
mlir::Type getElementType() { return getViewType().getElementType(); }
|
||||
ViewType getViewType() { return getType().cast<ViewType>(); }
|
||||
mlir::Value *getSupportingBuffer() { return getOperand(0); }
|
||||
// Get the underlying indexing at a given rank.
|
||||
mlir::Value *getIndexing(unsigned rank) {
|
||||
return *(getIndexings().begin() + rank);
|
||||
}
|
||||
// Get all the indexings in this view.
|
||||
mlir::Operation::operand_range getIndexings() {
|
||||
return {operand_begin() + BaseViewOp::FirstIndexingOperand, operand_end()};
|
||||
}
|
||||
};
|
||||
|
||||
/// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base
|
||||
/// view can be laid out. The size argument is an `i64` (and not an index), so
|
||||
/// that we can
|
||||
|
|
|
@ -27,7 +27,8 @@ class MLIRContext;
|
|||
enum LinalgTypes {
|
||||
Buffer = Type::FIRST_LINALG_TYPE,
|
||||
Range,
|
||||
LAST_USED_LINALG_TYPE = Range,
|
||||
View,
|
||||
LAST_USED_LINALG_TYPE = View,
|
||||
};
|
||||
|
||||
class LinalgDialect : public Dialect {
|
||||
|
@ -51,9 +52,8 @@ public:
|
|||
static BufferType get(MLIRContext *context, Type elementType);
|
||||
/// Used to implement llvm-style cast.
|
||||
static bool kindof(unsigned kind) { return kind == LinalgTypes::Buffer; }
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Type-specific functionality.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
Type getElementType();
|
||||
};
|
||||
|
||||
|
@ -71,6 +71,37 @@ public:
|
|||
static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
|
||||
};
|
||||
|
||||
/// A ViewType represents a multi-dimensional range abstraction on top of an
|
||||
/// underlying storage type. It is parameterizable by the underlying element
|
||||
/// type and the rank of the view.
|
||||
/// A new value of ViewType is constructed from a buffer with a base_view op and
|
||||
/// passing it ranges:
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
/// %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
/// ```
|
||||
class ViewTypeStorage;
|
||||
class ViewType
|
||||
: public mlir::Type::TypeBase<ViewType, mlir::Type, ViewTypeStorage> {
|
||||
public:
|
||||
// Used for generic hooks in TypeBase.
|
||||
using Base::Base;
|
||||
/// Construction hook.
|
||||
static ViewType get(mlir::MLIRContext *context, mlir::Type elementType,
|
||||
unsigned rank);
|
||||
// Used to implement llvm-style cast.
|
||||
static bool kindof(unsigned kind) { return kind == LinalgTypes::View; }
|
||||
|
||||
// Type-specific functionality.
|
||||
/// Return the underlying elemental type.
|
||||
mlir::Type getElementType();
|
||||
/// Return the rank of the view.
|
||||
/// This is the number of indexings needed to reach an underlying element.
|
||||
unsigned getRank();
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_LINALG_LINALGTYPES_H_
|
||||
|
|
|
@ -25,48 +25,91 @@
|
|||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Linalg/LinalgTypes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// RangeOp
|
||||
// BaseViewOp
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
|
||||
Value *max, Value *step) {
|
||||
result->addOperands({min, max, step});
|
||||
result->addTypes({RangeType::get(b->getContext())});
|
||||
void mlir::BaseViewOp::build(Builder *b, OperationState *result, Value *buffer,
|
||||
ArrayRef<Value *> indexings) {
|
||||
BufferType bufferType = buffer->getType().cast<BufferType>();
|
||||
result->addOperands({buffer});
|
||||
result->addOperands(indexings);
|
||||
assert(
|
||||
std::none_of(indexings.begin(), indexings.end(),
|
||||
[](Value *v) { return !v->getType().isa<RangeType>(); }) &&
|
||||
"linalg.base_view takes only arguments of type linalg.range");
|
||||
|
||||
Type elementType = bufferType.getElementType();
|
||||
result->addTypes(
|
||||
{ViewType::get(b->getContext(), elementType, indexings.size())});
|
||||
}
|
||||
|
||||
// Verification is simply that a RangeOp takes 3 index ssa-value.
|
||||
mlir::LogicalResult mlir::RangeOp::verify() {
|
||||
if (!min() || !min()->getType().isa<IndexType>())
|
||||
return emitOpError("first operand should be of type index");
|
||||
if (!max() || !max()->getType().isa<IndexType>())
|
||||
return emitOpError("second operand should be of type index");
|
||||
if (!step() || !step()->getType().isa<IndexType>())
|
||||
return emitOpError("third operand should be of type index");
|
||||
return mlir::success();
|
||||
LogicalResult mlir::BaseViewOp::verify() {
|
||||
if (llvm::empty(getOperands()))
|
||||
return emitOpError(
|
||||
"requires at least a buffer operand followed by indexings");
|
||||
auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
|
||||
if (!bufferType)
|
||||
return emitOpError("first operand must be of BufferType");
|
||||
unsigned index = 0;
|
||||
for (auto indexing : getIndexings()) {
|
||||
if (!indexing->getType().isa<RangeType>()) {
|
||||
return emitOpError(Twine(index) + "^th index must be of range type");
|
||||
}
|
||||
++index;
|
||||
}
|
||||
if (getViewType().getRank() != index)
|
||||
return emitOpError(
|
||||
"the rank of the base view must be the number of its indexings");
|
||||
return success();
|
||||
}
|
||||
|
||||
// A RangeOp prints as:
|
||||
bool mlir::BaseViewOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType bufferInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
|
||||
Type type;
|
||||
if (parser->parseOperand(bufferInfo) ||
|
||||
parser->parseOperandList(indexingsInfo, -1,
|
||||
OpAsmParser::Delimiter::Square) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(type))
|
||||
return true;
|
||||
|
||||
ViewType viewType = type.dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return parser->emitError(parser->getNameLoc(), "view type expected");
|
||||
if (viewType.getRank() != indexingsInfo.size())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"expected" + Twine(viewType.getRank()) +
|
||||
" range indexings");
|
||||
return parser->resolveOperand(
|
||||
bufferInfo,
|
||||
BufferType::get(type.getContext(), viewType.getElementType()),
|
||||
result->operands) ||
|
||||
(!indexingsInfo.empty() &&
|
||||
parser->resolveOperands(indexingsInfo,
|
||||
RangeType::get(type.getContext()),
|
||||
result->operands)) ||
|
||||
parser->addTypeToList(viewType, result->types);
|
||||
}
|
||||
|
||||
// A BaseViewOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.range %0:%1:%2 : !linalg.range
|
||||
// linalg.base_view %0[%1, %2] : !linalg.view<?x?xf32>
|
||||
// ```
|
||||
void mlir::RangeOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
|
||||
<< " : " << getType();
|
||||
}
|
||||
|
||||
bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
|
||||
RangeType type;
|
||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
|
||||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
|
||||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
|
||||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
|
||||
parser->addTypeToList(type, result->types);
|
||||
//
|
||||
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
|
||||
// holding a range.
|
||||
void mlir::BaseViewOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
|
||||
interleave(
|
||||
getIndexings().begin(), getIndexings().end(),
|
||||
[&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
|
||||
*p << "] : " << getType();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -140,3 +183,44 @@ bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
|
||||
parser->resolveOperands(sizeInfo, bufferType, result->operands);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// RangeOp
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
|
||||
Value *max, Value *step) {
|
||||
result->addOperands({min, max, step});
|
||||
result->addTypes({RangeType::get(b->getContext())});
|
||||
}
|
||||
|
||||
// Verification is simply that a RangeOp takes 3 index ssa-value.
|
||||
mlir::LogicalResult mlir::RangeOp::verify() {
|
||||
if (!min() || !min()->getType().isa<IndexType>())
|
||||
return emitOpError("first operand should be of type index");
|
||||
if (!max() || !max()->getType().isa<IndexType>())
|
||||
return emitOpError("second operand should be of type index");
|
||||
if (!step() || !step()->getType().isa<IndexType>())
|
||||
return emitOpError("third operand should be of type index");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// A RangeOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.range %0:%1:%2 : !linalg.range
|
||||
// ```
|
||||
void mlir::RangeOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
|
||||
<< " : " << getType();
|
||||
}
|
||||
|
||||
bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
|
||||
RangeType type;
|
||||
auto affineIntTy = parser->getBuilder().getIndexType();
|
||||
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
|
||||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
|
||||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
|
||||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
|
||||
parser->addTypeToList(type, result->types);
|
||||
}
|
||||
|
|
|
@ -29,8 +29,8 @@ using namespace mlir;
|
|||
|
||||
mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
|
||||
: Dialect("linalg", context) {
|
||||
addTypes<BufferType, RangeType>();
|
||||
addOperations<BufferAllocOp, BufferDeallocOp, RangeOp>();
|
||||
addTypes<BufferType, RangeType, ViewType>();
|
||||
addOperations<BaseViewOp, BufferAllocOp, BufferDeallocOp, RangeOp>();
|
||||
}
|
||||
|
||||
struct mlir::BufferTypeStorage : public mlir::TypeStorage {
|
||||
|
@ -80,15 +80,97 @@ Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
|
|||
// TODO(ntv): reuse mlir Parser once exposed.
|
||||
if (spec == "buffer<f32>")
|
||||
return BufferType::get(getContext(), FloatType::getF32(getContext()));
|
||||
// TODO(ntv): reuse mlir Parser once exposed.
|
||||
if (spec.startswith("view")) {
|
||||
spec.consume_front("view");
|
||||
// Just count the number of ? to get the rank, the type must be f32 for now.
|
||||
unsigned rank = 0;
|
||||
for (unsigned i = 0, e = spec.size(); i < e; ++i)
|
||||
if (spec[i] == '?')
|
||||
++rank;
|
||||
return ViewType::get(context, FloatType::getF32(context), rank);
|
||||
}
|
||||
return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
|
||||
}
|
||||
|
||||
/// RangeType prints as just "range".
|
||||
struct mlir::ViewTypeStorage : public mlir::TypeStorage {
|
||||
/// Underlying Key type to transport the payload needed to construct a custom
|
||||
/// type in a generic way.
|
||||
struct Key {
|
||||
Key(Type elementType, unsigned rank)
|
||||
: elementType(elementType), rank(rank) {}
|
||||
Type elementType;
|
||||
unsigned rank;
|
||||
};
|
||||
/// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
|
||||
using KeyTy = Key;
|
||||
|
||||
/// Construction in the llvm::BumpPtrAllocator given a key.
|
||||
static ViewTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const Key &key) {
|
||||
return new (allocator.allocate<ViewTypeStorage>()) ViewTypeStorage(key);
|
||||
}
|
||||
|
||||
/// Equality operator for hashing.
|
||||
bool operator==(const Key &key) const {
|
||||
return elementType == key.elementType && rank == key.rank;
|
||||
}
|
||||
|
||||
/// Hashing for unique'ing.
|
||||
static unsigned hashKey(const Key &key) {
|
||||
return llvm::hash_combine(key.elementType, key.rank);
|
||||
}
|
||||
|
||||
unsigned getRank() { return rank; };
|
||||
Type getElementType() { return elementType; };
|
||||
|
||||
private:
|
||||
ViewTypeStorage(const Key &key)
|
||||
: elementType(key.elementType), rank(key.rank) {}
|
||||
|
||||
Type elementType;
|
||||
unsigned rank;
|
||||
};
|
||||
|
||||
ViewType mlir::ViewType::get(MLIRContext *context, Type elementType,
|
||||
unsigned rank) {
|
||||
return Base::get(context, LinalgTypes::View, elementType, rank);
|
||||
}
|
||||
|
||||
Type mlir::ViewType::getElementType() { return getImpl()->getElementType(); }
|
||||
|
||||
unsigned mlir::ViewType::getRank() { return getImpl()->getRank(); }
|
||||
|
||||
/// BufferType prints as "buffer<element_type>".
|
||||
static void print(BufferType bt, raw_ostream &os) {
|
||||
os << "buffer<" << bt.getElementType() << ">";
|
||||
}
|
||||
|
||||
/// RangeType prints as just "range".
|
||||
static void print(RangeType rt, raw_ostream &os) { os << "range"; }
|
||||
|
||||
/// ViewType prints as:
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// view<?x?xf32>
|
||||
/// ```
|
||||
///
|
||||
/// or
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// view<?xf32>
|
||||
/// ```
|
||||
///
|
||||
/// for 0-D views (a.k.a pointer to a scalar value).
|
||||
static void print(mlir::ViewType rt, raw_ostream &os) {
|
||||
os << "view<";
|
||||
for (unsigned i = 0, e = rt.getRank(); i < e; ++i) {
|
||||
os << "?x";
|
||||
}
|
||||
os << rt.getElementType();
|
||||
os << ">";
|
||||
}
|
||||
|
||||
void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
|
||||
switch (type.getKind()) {
|
||||
default:
|
||||
|
@ -99,5 +181,8 @@ void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
|
|||
case LinalgTypes::Range:
|
||||
print(type.cast<RangeType>(), os);
|
||||
break;
|
||||
case LinalgTypes::View:
|
||||
print(type.cast<ViewType>(), os);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -verify | mlir-opt | FileCheck %s
|
||||
// RUN: mlir-opt %s -verify | FileCheck %s
|
||||
|
||||
func @range(%arg0: index, %arg1: index, %arg2: index) {
|
||||
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
||||
|
@ -16,4 +16,19 @@ func @buffer(%arg0: i64, %arg1: i64) {
|
|||
// CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) {
|
||||
// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
|
||||
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
|
||||
|
||||
func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
|
||||
%0 = muli %arg0, %arg0 : i64
|
||||
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
%3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
linalg.buffer_dealloc %1 : !linalg.buffer<f32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
|
||||
// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
|
||||
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
// CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
// CHECK-NEXT: %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
|
Loading…
Reference in New Issue