Move BufferAllocOp and BufferDeallocOp to ODS

This CL also fixes a parsing issue in the BufferType, adds LLVM lowering support for handling the static constant buffer size and a roundtrip test.

PiperOrigin-RevId: 255834356
This commit is contained in:
Nicolas Vasilache 2019-06-30 07:07:50 -07:00 committed by jpienaar
parent ef76343488
commit 6a7a1ca25d
6 changed files with 129 additions and 128 deletions

View File

@ -29,56 +29,6 @@ class OperationFolder;
namespace linalg {
/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
/// upon which a base view can be laid out to give it indexing semantics.
/// "buffer_alloc" takes a single argument, the size of the buffer to allocate
/// (in number of elements).
///
/// ```{.mlir}
/// %0 = linalg.buffer_alloc %arg0 : !linalg.buffer<f32>
/// ```
class BufferAllocOp
: public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
public:
using Op::Op;
// Hooks to customize the behavior of this op.
static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; }
static void build(Builder *b, OperationState *result, Type type, Value *size);
LogicalResult verify();
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
// Op-specific functionality.
Value *size() { return getOperand(); }
BufferType getBufferType() { return getType().cast<BufferType>(); }
Type getElementType() { return getBufferType().getElementType(); }
};
/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
///
/// ```{.mlir}
/// linalg.buffer_dealloc %0 : !linalg.buffer<f32>
/// ```
class BufferDeallocOp
: public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
public:
using Op::Op;
// Hooks to customize the behavior of this op.
static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; }
static void build(Builder *b, OperationState *result, Value *buffer);
LogicalResult verify();
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
// Op-specific functionality.
Value *getBuffer() { return getOperand(); }
BufferType getBufferType() {
return getOperand()->getType().cast<BufferType>();
}
};
/// The "linalg.for" operation represents a loop nest taking 3 SSA value as
/// operands that represent the lower bound, upper bound and step respectively.
/// The operation defines an SSA value for its induction variable. It has one

View File

@ -39,6 +39,65 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def BufferAllocOp :
Linalg_Op<"buffer_alloc">,
Arguments<(ins Variadic<Index>:$size)>,
Results<(outs Buffer)> {
let summary = "buffer allocation operation";
let description = [{
The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
upon which a base view can be laid out to give it indexing semantics.
"buffer_alloc" takes a single argument, the size of the buffer to allocate
(in number of elements).
```{.mlir}
%0 = linalg.buffer_alloc(%arg0) : !linalg.buffer<?xf32>
```
The size argument may be omitted if it is statically known, in which case it
must be reflected in the type.
```{.mlir}
%0 = linalg.buffer_alloc() : !linalg.buffer<4xf32>
```
}];
let builders = [OpBuilder<
"Builder *builder, OperationState *result, BufferType bufferType", [{
result->types.push_back(bufferType);
}]
>];
let extraClassDeclaration = [{
BufferType getBufferType() { return getType().cast<BufferType>(); }
Type getElementType() { return getBufferType().getElementType(); }
}];
}
def BufferDeallocOp :
Linalg_Op<"buffer_dealloc">,
Arguments<(ins Buffer:$buffer)>,
Results<(outs)> {
let summary = "buffer allocation operation";
let description = [{
The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
```{.mlir}
linalg.buffer_dealloc %0 : !linalg.buffer<f32>
```
}];
let builders = [OpBuilder<
"Builder *builder, OperationState *result, BufferType bufferType", [{
result->types.push_back(bufferType);
}]
>];
let extraClassDeclaration = [{
BufferType getBufferType() {
return getOperand()->getType().cast<BufferType>();
}
}];
// Fully specified by traits.
let verifier = ?;
}
def BufferSizeOp :
Linalg_Op<"buffer_size", [NoSideEffect]>,
Arguments<(ins Buffer)>,

View File

@ -37,76 +37,6 @@ using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
//////////////////////////////////////////////////////////////////////////////
// BufferAllocOp
//////////////////////////////////////////////////////////////////////////////
void mlir::linalg::BufferAllocOp::build(Builder *b, OperationState *result,
Type type, Value *size) {
result->addOperands({size});
result->addTypes(type);
}
LogicalResult mlir::linalg::BufferAllocOp::verify() {
if (!size() || !size()->getType().isa<IndexType>())
return emitOpError("first operand should be of type index");
if (!VectorType::isValidElementType(getElementType()) &&
!getElementType().isa<VectorType>())
return emitOpError("unsupported buffer element type");
return success();
}
// A BufferAllocOp prints as:
//
// ```{.mlir}
// linalg.alloc %0 : !linalg.buffer<f32>
// ```
void mlir::linalg::BufferAllocOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *size() << " : " << getType();
}
ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType sizeInfo;
BufferType bufferType;
auto indexTy = parser->getBuilder().getIndexType();
if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
return failure();
return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
parser->addTypeToList(bufferType, result->types));
}
//////////////////////////////////////////////////////////////////////////////
// BufferDeallocOp
//////////////////////////////////////////////////////////////////////////////
void mlir::linalg::BufferDeallocOp::build(Builder *b, OperationState *result,
Value *buffer) {
result->addOperands({buffer});
}
LogicalResult mlir::linalg::BufferDeallocOp::verify() {
if (!getBuffer()->getType())
return emitOpError("first operand should be of type buffer");
return success();
}
// A BufferDeallocOp prints as:
//
// ```{.mlir}
// linalg.dealloc %0 : !linalg.buffer<f32>
// ```
void mlir::linalg::BufferDeallocOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
}
ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType sizeInfo;
BufferType bufferType;
return failure(
parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
parser->resolveOperands(sizeInfo, bufferType, result->operands));
}
////////////////////////////////////////////////////////////////////////////////
// ForOp.
////////////////////////////////////////////////////////////////////////////////
@ -605,6 +535,60 @@ void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
// LinalgOps.td), we define an overloaded `print` function and a
// parse`className` function.
static void print(OpAsmPrinter *p, BufferAllocOp op) {
*p << op.getOperationName() << " ";
if (!llvm::empty(op.size()))
*p << *op.getOperand(0);
p->printOptionalAttrDict(op.getAttrs());
*p << " : " << op.getBufferType();
}
static ParseResult parseBufferAllocOp(OpAsmParser *parser,
OperationState *result) {
SmallVector<OpAsmParser::OperandType, 1> sizeInfo;
BufferType bufferType;
auto indexTy = parser->getBuilder().getIndexType();
if (parser->parseOperandList(sizeInfo) || parser->parseColonType(bufferType))
return failure();
if (sizeInfo.empty())
return parser->addTypeToList(bufferType, result->types);
return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
parser->addTypeToList(bufferType, result->types));
}
static LogicalResult verify(BufferAllocOp op) {
if (!op.getBufferType().hasConstantSize()) {
if (llvm::size(op.size()) != 1 ||
!op.getOperand(0)->getType().isa<IndexType>())
return op.emitOpError(
"one operand of type index expected for dynamic buffer");
} else { // op.getBufferType().hasConstantSize()
if (!llvm::empty(op.size()))
return op.emitOpError("unexpected static buffer operand");
if (op.getBufferType().getBufferSize().getValue() <= 0)
return op.emitOpError("expected nonnegative static buffer size");
}
if (!VectorType::isValidElementType(op.getElementType()) &&
!op.getElementType().isa<VectorType>())
return op.emitOpError("unsupported buffer element type");
return success();
}
static void print(OpAsmPrinter *p, BufferDeallocOp op) {
*p << op.getOperationName() << " " << *op.buffer();
p->printOptionalAttrDict(op.getAttrs());
*p << " : " << op.getBufferType();
}
static ParseResult parseBufferDeallocOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType bufferInfo;
BufferType bufferType;
if (parser->parseOperand(bufferInfo) || parser->parseColonType(bufferType))
return failure();
return parser->resolveOperands(bufferInfo, bufferType, result->operands);
}
static void print(OpAsmPrinter *p, BufferSizeOp op) {
*p << op.getOperationName() << " " << *op.getOperand();
p->printOptionalAttrDict(op.getAttrs());

View File

@ -34,8 +34,7 @@ using namespace mlir::linalg;
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addTypes<BufferType, RangeType, ViewType>();
addOperations<BufferAllocOp, BufferDeallocOp, ForOp, LoadOp, RangeOp, StoreOp,
SliceOp, ViewOp>();
addOperations<ForOp, LoadOp, RangeOp, StoreOp, SliceOp, ViewOp>();
addOperations<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
@ -119,8 +118,8 @@ Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
// Check for '?'
int64_t bufferSize = -1;
if (!spec.consume_front("?")) {
unsigned parsedBufferSize;
if (!spec.consumeInteger(10, parsedBufferSize)) {
unsigned long long parsedBufferSize = 0;
if (spec.consumeInteger(10, parsedBufferSize)) {
emitError(loc, "expected buffer size to be an unsigned integer");
return Type();
}

View File

@ -168,7 +168,7 @@ public:
auto indexType = IndexType::get(op->getContext());
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(operands[0]->getType());
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Insert the `malloc` declaration if it is not already present.
auto *module = op->getFunction()->getModule();
Function *mallocFunc = module->getNamedFunction("malloc");
@ -187,14 +187,19 @@ public:
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
else
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
auto elementPtrType = getPtrToElementType(
allocOp.getResult()->getType().cast<BufferType>(), lowering);
auto bufferType = allocOp.getResult()->getType().cast<BufferType>();
auto elementPtrType = getPtrToElementType(bufferType, lowering);
auto bufferDescriptorType =
convertLinalgType(allocOp.getResult()->getType(), lowering);
// Emit IR for creating a new buffer descriptor with an underlying malloc.
edsc::ScopedContext context(rewriter, op->getLoc());
Value *size = operands[0];
auto constantSize = bufferType.getBufferSize();
Value *size =
constantSize
? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
.getValue()
: operands[0];
Value *allocSize =
mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
Value *allocated =

View File

@ -13,12 +13,16 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
func @buffer(%arg0: index, %arg1: index) {
%0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
%2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
return
}
// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: %2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {