forked from OSchip/llvm-project
Move BufferAllocOp and BufferDeallocOp to ODS
This CL also fixes a parsing issue in the BufferType, adds LLVM lowering support for handling the static constant buffer size and a roundtrip test. PiperOrigin-RevId: 255834356
This commit is contained in:
parent
ef76343488
commit
6a7a1ca25d
|
@ -29,56 +29,6 @@ class OperationFolder;
|
|||
|
||||
namespace linalg {
|
||||
|
||||
/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
|
||||
/// upon which a base view can be laid out to give it indexing semantics.
|
||||
/// "buffer_alloc" takes a single argument, the size of the buffer to allocate
|
||||
/// (in number of elements).
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %0 = linalg.buffer_alloc %arg0 : !linalg.buffer<f32>
|
||||
/// ```
|
||||
class BufferAllocOp
|
||||
: public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
// Hooks to customize the behavior of this op.
|
||||
static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; }
|
||||
static void build(Builder *b, OperationState *result, Type type, Value *size);
|
||||
LogicalResult verify();
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
|
||||
// Op-specific functionality.
|
||||
Value *size() { return getOperand(); }
|
||||
BufferType getBufferType() { return getType().cast<BufferType>(); }
|
||||
Type getElementType() { return getBufferType().getElementType(); }
|
||||
};
|
||||
|
||||
/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// linalg.buffer_dealloc %0 : !linalg.buffer<f32>
|
||||
/// ```
|
||||
class BufferDeallocOp
|
||||
: public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
// Hooks to customize the behavior of this op.
|
||||
static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; }
|
||||
static void build(Builder *b, OperationState *result, Value *buffer);
|
||||
LogicalResult verify();
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
|
||||
// Op-specific functionality.
|
||||
Value *getBuffer() { return getOperand(); }
|
||||
BufferType getBufferType() {
|
||||
return getOperand()->getType().cast<BufferType>();
|
||||
}
|
||||
};
|
||||
|
||||
/// The "linalg.for" operation represents a loop nest taking 3 SSA value as
|
||||
/// operands that represent the lower bound, upper bound and step respectively.
|
||||
/// The operation defines an SSA value for its induction variable. It has one
|
||||
|
|
|
@ -39,6 +39,65 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def BufferAllocOp :
|
||||
Linalg_Op<"buffer_alloc">,
|
||||
Arguments<(ins Variadic<Index>:$size)>,
|
||||
Results<(outs Buffer)> {
|
||||
let summary = "buffer allocation operation";
|
||||
let description = [{
|
||||
The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
|
||||
upon which a base view can be laid out to give it indexing semantics.
|
||||
"buffer_alloc" takes a single argument, the size of the buffer to allocate
|
||||
(in number of elements).
|
||||
|
||||
```{.mlir}
|
||||
%0 = linalg.buffer_alloc(%arg0) : !linalg.buffer<?xf32>
|
||||
```
|
||||
|
||||
The size argument may be omitted if it is statically known, in which case it
|
||||
must be reflected in the type.
|
||||
|
||||
```{.mlir}
|
||||
%0 = linalg.buffer_alloc() : !linalg.buffer<4xf32>
|
||||
```
|
||||
}];
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, BufferType bufferType", [{
|
||||
result->types.push_back(bufferType);
|
||||
}]
|
||||
>];
|
||||
let extraClassDeclaration = [{
|
||||
BufferType getBufferType() { return getType().cast<BufferType>(); }
|
||||
Type getElementType() { return getBufferType().getElementType(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def BufferDeallocOp :
|
||||
Linalg_Op<"buffer_dealloc">,
|
||||
Arguments<(ins Buffer:$buffer)>,
|
||||
Results<(outs)> {
|
||||
let summary = "buffer allocation operation";
|
||||
let description = [{
|
||||
The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
|
||||
|
||||
```{.mlir}
|
||||
linalg.buffer_dealloc %0 : !linalg.buffer<f32>
|
||||
```
|
||||
}];
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, BufferType bufferType", [{
|
||||
result->types.push_back(bufferType);
|
||||
}]
|
||||
>];
|
||||
let extraClassDeclaration = [{
|
||||
BufferType getBufferType() {
|
||||
return getOperand()->getType().cast<BufferType>();
|
||||
}
|
||||
}];
|
||||
// Fully specified by traits.
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
def BufferSizeOp :
|
||||
Linalg_Op<"buffer_size", [NoSideEffect]>,
|
||||
Arguments<(ins Buffer)>,
|
||||
|
|
|
@ -37,76 +37,6 @@ using namespace mlir::edsc;
|
|||
using namespace mlir::edsc::intrinsics;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// BufferAllocOp
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
void mlir::linalg::BufferAllocOp::build(Builder *b, OperationState *result,
|
||||
Type type, Value *size) {
|
||||
result->addOperands({size});
|
||||
result->addTypes(type);
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::BufferAllocOp::verify() {
|
||||
if (!size() || !size()->getType().isa<IndexType>())
|
||||
return emitOpError("first operand should be of type index");
|
||||
if (!VectorType::isValidElementType(getElementType()) &&
|
||||
!getElementType().isa<VectorType>())
|
||||
return emitOpError("unsupported buffer element type");
|
||||
return success();
|
||||
}
|
||||
|
||||
// A BufferAllocOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.alloc %0 : !linalg.buffer<f32>
|
||||
// ```
|
||||
void mlir::linalg::BufferAllocOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *size() << " : " << getType();
|
||||
}
|
||||
|
||||
ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
OpAsmParser::OperandType sizeInfo;
|
||||
BufferType bufferType;
|
||||
auto indexTy = parser->getBuilder().getIndexType();
|
||||
if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
|
||||
return failure();
|
||||
return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
|
||||
parser->addTypeToList(bufferType, result->types));
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// BufferDeallocOp
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
void mlir::linalg::BufferDeallocOp::build(Builder *b, OperationState *result,
|
||||
Value *buffer) {
|
||||
result->addOperands({buffer});
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::BufferDeallocOp::verify() {
|
||||
if (!getBuffer()->getType())
|
||||
return emitOpError("first operand should be of type buffer");
|
||||
return success();
|
||||
}
|
||||
|
||||
// A BufferDeallocOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.dealloc %0 : !linalg.buffer<f32>
|
||||
// ```
|
||||
void mlir::linalg::BufferDeallocOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
|
||||
}
|
||||
|
||||
ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
OpAsmParser::OperandType sizeInfo;
|
||||
BufferType bufferType;
|
||||
return failure(
|
||||
parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
|
||||
parser->resolveOperands(sizeInfo, bufferType, result->operands));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// ForOp.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -605,6 +535,60 @@ void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
|
|||
// LinalgOps.td), we define an overloaded `print` function and a
|
||||
// parse`className` function.
|
||||
|
||||
static void print(OpAsmPrinter *p, BufferAllocOp op) {
|
||||
*p << op.getOperationName() << " ";
|
||||
if (!llvm::empty(op.size()))
|
||||
*p << *op.getOperand(0);
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getBufferType();
|
||||
}
|
||||
|
||||
static ParseResult parseBufferAllocOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 1> sizeInfo;
|
||||
BufferType bufferType;
|
||||
auto indexTy = parser->getBuilder().getIndexType();
|
||||
if (parser->parseOperandList(sizeInfo) || parser->parseColonType(bufferType))
|
||||
return failure();
|
||||
if (sizeInfo.empty())
|
||||
return parser->addTypeToList(bufferType, result->types);
|
||||
return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
|
||||
parser->addTypeToList(bufferType, result->types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(BufferAllocOp op) {
|
||||
if (!op.getBufferType().hasConstantSize()) {
|
||||
if (llvm::size(op.size()) != 1 ||
|
||||
!op.getOperand(0)->getType().isa<IndexType>())
|
||||
return op.emitOpError(
|
||||
"one operand of type index expected for dynamic buffer");
|
||||
} else { // op.getBufferType().hasConstantSize()
|
||||
if (!llvm::empty(op.size()))
|
||||
return op.emitOpError("unexpected static buffer operand");
|
||||
if (op.getBufferType().getBufferSize().getValue() <= 0)
|
||||
return op.emitOpError("expected nonnegative static buffer size");
|
||||
}
|
||||
if (!VectorType::isValidElementType(op.getElementType()) &&
|
||||
!op.getElementType().isa<VectorType>())
|
||||
return op.emitOpError("unsupported buffer element type");
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, BufferDeallocOp op) {
|
||||
*p << op.getOperationName() << " " << *op.buffer();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getBufferType();
|
||||
}
|
||||
|
||||
static ParseResult parseBufferDeallocOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
OpAsmParser::OperandType bufferInfo;
|
||||
BufferType bufferType;
|
||||
if (parser->parseOperand(bufferInfo) || parser->parseColonType(bufferType))
|
||||
return failure();
|
||||
return parser->resolveOperands(bufferInfo, bufferType, result->operands);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, BufferSizeOp op) {
|
||||
*p << op.getOperationName() << " " << *op.getOperand();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
|
|
|
@ -34,8 +34,7 @@ using namespace mlir::linalg;
|
|||
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addTypes<BufferType, RangeType, ViewType>();
|
||||
addOperations<BufferAllocOp, BufferDeallocOp, ForOp, LoadOp, RangeOp, StoreOp,
|
||||
SliceOp, ViewOp>();
|
||||
addOperations<ForOp, LoadOp, RangeOp, StoreOp, SliceOp, ViewOp>();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
|
||||
|
@ -119,8 +118,8 @@ Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
|
|||
// Check for '?'
|
||||
int64_t bufferSize = -1;
|
||||
if (!spec.consume_front("?")) {
|
||||
unsigned parsedBufferSize;
|
||||
if (!spec.consumeInteger(10, parsedBufferSize)) {
|
||||
unsigned long long parsedBufferSize = 0;
|
||||
if (spec.consumeInteger(10, parsedBufferSize)) {
|
||||
emitError(loc, "expected buffer size to be an unsigned integer");
|
||||
return Type();
|
||||
}
|
||||
|
|
|
@ -168,7 +168,7 @@ public:
|
|||
auto indexType = IndexType::get(op->getContext());
|
||||
auto voidPtrTy =
|
||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
auto int64Ty = lowering.convertType(operands[0]->getType());
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto *module = op->getFunction()->getModule();
|
||||
Function *mallocFunc = module->getNamedFunction("malloc");
|
||||
|
@ -187,14 +187,19 @@ public:
|
|||
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
|
||||
else
|
||||
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
|
||||
auto elementPtrType = getPtrToElementType(
|
||||
allocOp.getResult()->getType().cast<BufferType>(), lowering);
|
||||
auto bufferType = allocOp.getResult()->getType().cast<BufferType>();
|
||||
auto elementPtrType = getPtrToElementType(bufferType, lowering);
|
||||
auto bufferDescriptorType =
|
||||
convertLinalgType(allocOp.getResult()->getType(), lowering);
|
||||
|
||||
// Emit IR for creating a new buffer descriptor with an underlying malloc.
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
Value *size = operands[0];
|
||||
auto constantSize = bufferType.getBufferSize();
|
||||
Value *size =
|
||||
constantSize
|
||||
? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
|
||||
.getValue()
|
||||
: operands[0];
|
||||
Value *allocSize =
|
||||
mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
|
||||
Value *allocated =
|
||||
|
|
|
@ -13,12 +13,16 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
|
|||
func @buffer(%arg0: index, %arg1: index) {
|
||||
%0 = muli %arg0, %arg0 : index
|
||||
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
|
||||
%2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
|
||||
linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
|
||||
linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
|
||||
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
|
||||
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
|
||||
// CHECK-NEXT: %2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
|
||||
// CHECK-NEXT: linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
|
||||
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
|
||||
|
||||
func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
|
||||
|
|
Loading…
Reference in New Issue