forked from OSchip/llvm-project
Move BufferAllocOp and BufferDeallocOp to ODS
This CL also fixes a parsing issue in the BufferType, adds LLVM lowering support for handling the static constant buffer size and a roundtrip test. PiperOrigin-RevId: 255834356
This commit is contained in:
parent
ef76343488
commit
6a7a1ca25d
|
@ -29,56 +29,6 @@ class OperationFolder;
|
||||||
|
|
||||||
namespace linalg {
|
namespace linalg {
|
||||||
|
|
||||||
/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
|
|
||||||
/// upon which a base view can be laid out to give it indexing semantics.
|
|
||||||
/// "buffer_alloc" takes a single argument, the size of the buffer to allocate
|
|
||||||
/// (in number of elements).
|
|
||||||
///
|
|
||||||
/// ```{.mlir}
|
|
||||||
/// %0 = linalg.buffer_alloc %arg0 : !linalg.buffer<f32>
|
|
||||||
/// ```
|
|
||||||
class BufferAllocOp
|
|
||||||
: public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
|
|
||||||
public:
|
|
||||||
using Op::Op;
|
|
||||||
|
|
||||||
// Hooks to customize the behavior of this op.
|
|
||||||
static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; }
|
|
||||||
static void build(Builder *b, OperationState *result, Type type, Value *size);
|
|
||||||
LogicalResult verify();
|
|
||||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
|
||||||
void print(OpAsmPrinter *p);
|
|
||||||
|
|
||||||
// Op-specific functionality.
|
|
||||||
Value *size() { return getOperand(); }
|
|
||||||
BufferType getBufferType() { return getType().cast<BufferType>(); }
|
|
||||||
Type getElementType() { return getBufferType().getElementType(); }
|
|
||||||
};
|
|
||||||
|
|
||||||
/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
|
|
||||||
///
|
|
||||||
/// ```{.mlir}
|
|
||||||
/// linalg.buffer_dealloc %0 : !linalg.buffer<f32>
|
|
||||||
/// ```
|
|
||||||
class BufferDeallocOp
|
|
||||||
: public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
|
|
||||||
public:
|
|
||||||
using Op::Op;
|
|
||||||
|
|
||||||
// Hooks to customize the behavior of this op.
|
|
||||||
static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; }
|
|
||||||
static void build(Builder *b, OperationState *result, Value *buffer);
|
|
||||||
LogicalResult verify();
|
|
||||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
|
||||||
void print(OpAsmPrinter *p);
|
|
||||||
|
|
||||||
// Op-specific functionality.
|
|
||||||
Value *getBuffer() { return getOperand(); }
|
|
||||||
BufferType getBufferType() {
|
|
||||||
return getOperand()->getType().cast<BufferType>();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// The "linalg.for" operation represents a loop nest taking 3 SSA value as
|
/// The "linalg.for" operation represents a loop nest taking 3 SSA value as
|
||||||
/// operands that represent the lower bound, upper bound and step respectively.
|
/// operands that represent the lower bound, upper bound and step respectively.
|
||||||
/// The operation defines an SSA value for its induction variable. It has one
|
/// The operation defines an SSA value for its induction variable. It has one
|
||||||
|
|
|
@ -39,6 +39,65 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def BufferAllocOp :
|
||||||
|
Linalg_Op<"buffer_alloc">,
|
||||||
|
Arguments<(ins Variadic<Index>:$size)>,
|
||||||
|
Results<(outs Buffer)> {
|
||||||
|
let summary = "buffer allocation operation";
|
||||||
|
let description = [{
|
||||||
|
The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
|
||||||
|
upon which a base view can be laid out to give it indexing semantics.
|
||||||
|
"buffer_alloc" takes a single argument, the size of the buffer to allocate
|
||||||
|
(in number of elements).
|
||||||
|
|
||||||
|
```{.mlir}
|
||||||
|
%0 = linalg.buffer_alloc(%arg0) : !linalg.buffer<?xf32>
|
||||||
|
```
|
||||||
|
|
||||||
|
The size argument may be omitted if it is statically known, in which case it
|
||||||
|
must be reflected in the type.
|
||||||
|
|
||||||
|
```{.mlir}
|
||||||
|
%0 = linalg.buffer_alloc() : !linalg.buffer<4xf32>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"Builder *builder, OperationState *result, BufferType bufferType", [{
|
||||||
|
result->types.push_back(bufferType);
|
||||||
|
}]
|
||||||
|
>];
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
BufferType getBufferType() { return getType().cast<BufferType>(); }
|
||||||
|
Type getElementType() { return getBufferType().getElementType(); }
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def BufferDeallocOp :
|
||||||
|
Linalg_Op<"buffer_dealloc">,
|
||||||
|
Arguments<(ins Buffer:$buffer)>,
|
||||||
|
Results<(outs)> {
|
||||||
|
let summary = "buffer allocation operation";
|
||||||
|
let description = [{
|
||||||
|
The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
|
||||||
|
|
||||||
|
```{.mlir}
|
||||||
|
linalg.buffer_dealloc %0 : !linalg.buffer<f32>
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"Builder *builder, OperationState *result, BufferType bufferType", [{
|
||||||
|
result->types.push_back(bufferType);
|
||||||
|
}]
|
||||||
|
>];
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
BufferType getBufferType() {
|
||||||
|
return getOperand()->getType().cast<BufferType>();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
// Fully specified by traits.
|
||||||
|
let verifier = ?;
|
||||||
|
}
|
||||||
|
|
||||||
def BufferSizeOp :
|
def BufferSizeOp :
|
||||||
Linalg_Op<"buffer_size", [NoSideEffect]>,
|
Linalg_Op<"buffer_size", [NoSideEffect]>,
|
||||||
Arguments<(ins Buffer)>,
|
Arguments<(ins Buffer)>,
|
||||||
|
|
|
@ -37,76 +37,6 @@ using namespace mlir::edsc;
|
||||||
using namespace mlir::edsc::intrinsics;
|
using namespace mlir::edsc::intrinsics;
|
||||||
using namespace mlir::linalg;
|
using namespace mlir::linalg;
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
// BufferAllocOp
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
void mlir::linalg::BufferAllocOp::build(Builder *b, OperationState *result,
|
|
||||||
Type type, Value *size) {
|
|
||||||
result->addOperands({size});
|
|
||||||
result->addTypes(type);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult mlir::linalg::BufferAllocOp::verify() {
|
|
||||||
if (!size() || !size()->getType().isa<IndexType>())
|
|
||||||
return emitOpError("first operand should be of type index");
|
|
||||||
if (!VectorType::isValidElementType(getElementType()) &&
|
|
||||||
!getElementType().isa<VectorType>())
|
|
||||||
return emitOpError("unsupported buffer element type");
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// A BufferAllocOp prints as:
|
|
||||||
//
|
|
||||||
// ```{.mlir}
|
|
||||||
// linalg.alloc %0 : !linalg.buffer<f32>
|
|
||||||
// ```
|
|
||||||
void mlir::linalg::BufferAllocOp::print(OpAsmPrinter *p) {
|
|
||||||
*p << getOperationName() << " " << *size() << " : " << getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser,
|
|
||||||
OperationState *result) {
|
|
||||||
OpAsmParser::OperandType sizeInfo;
|
|
||||||
BufferType bufferType;
|
|
||||||
auto indexTy = parser->getBuilder().getIndexType();
|
|
||||||
if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
|
|
||||||
return failure();
|
|
||||||
return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
|
|
||||||
parser->addTypeToList(bufferType, result->types));
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
// BufferDeallocOp
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
|
||||||
void mlir::linalg::BufferDeallocOp::build(Builder *b, OperationState *result,
|
|
||||||
Value *buffer) {
|
|
||||||
result->addOperands({buffer});
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult mlir::linalg::BufferDeallocOp::verify() {
|
|
||||||
if (!getBuffer()->getType())
|
|
||||||
return emitOpError("first operand should be of type buffer");
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// A BufferDeallocOp prints as:
|
|
||||||
//
|
|
||||||
// ```{.mlir}
|
|
||||||
// linalg.dealloc %0 : !linalg.buffer<f32>
|
|
||||||
// ```
|
|
||||||
void mlir::linalg::BufferDeallocOp::print(OpAsmPrinter *p) {
|
|
||||||
*p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser,
|
|
||||||
OperationState *result) {
|
|
||||||
OpAsmParser::OperandType sizeInfo;
|
|
||||||
BufferType bufferType;
|
|
||||||
return failure(
|
|
||||||
parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
|
|
||||||
parser->resolveOperands(sizeInfo, bufferType, result->operands));
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
// ForOp.
|
// ForOp.
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -605,6 +535,60 @@ void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
|
||||||
// LinalgOps.td), we define an overloaded `print` function and a
|
// LinalgOps.td), we define an overloaded `print` function and a
|
||||||
// parse`className` function.
|
// parse`className` function.
|
||||||
|
|
||||||
|
static void print(OpAsmPrinter *p, BufferAllocOp op) {
|
||||||
|
*p << op.getOperationName() << " ";
|
||||||
|
if (!llvm::empty(op.size()))
|
||||||
|
*p << *op.getOperand(0);
|
||||||
|
p->printOptionalAttrDict(op.getAttrs());
|
||||||
|
*p << " : " << op.getBufferType();
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseBufferAllocOp(OpAsmParser *parser,
|
||||||
|
OperationState *result) {
|
||||||
|
SmallVector<OpAsmParser::OperandType, 1> sizeInfo;
|
||||||
|
BufferType bufferType;
|
||||||
|
auto indexTy = parser->getBuilder().getIndexType();
|
||||||
|
if (parser->parseOperandList(sizeInfo) || parser->parseColonType(bufferType))
|
||||||
|
return failure();
|
||||||
|
if (sizeInfo.empty())
|
||||||
|
return parser->addTypeToList(bufferType, result->types);
|
||||||
|
return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
|
||||||
|
parser->addTypeToList(bufferType, result->types));
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verify(BufferAllocOp op) {
|
||||||
|
if (!op.getBufferType().hasConstantSize()) {
|
||||||
|
if (llvm::size(op.size()) != 1 ||
|
||||||
|
!op.getOperand(0)->getType().isa<IndexType>())
|
||||||
|
return op.emitOpError(
|
||||||
|
"one operand of type index expected for dynamic buffer");
|
||||||
|
} else { // op.getBufferType().hasConstantSize()
|
||||||
|
if (!llvm::empty(op.size()))
|
||||||
|
return op.emitOpError("unexpected static buffer operand");
|
||||||
|
if (op.getBufferType().getBufferSize().getValue() <= 0)
|
||||||
|
return op.emitOpError("expected nonnegative static buffer size");
|
||||||
|
}
|
||||||
|
if (!VectorType::isValidElementType(op.getElementType()) &&
|
||||||
|
!op.getElementType().isa<VectorType>())
|
||||||
|
return op.emitOpError("unsupported buffer element type");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print(OpAsmPrinter *p, BufferDeallocOp op) {
|
||||||
|
*p << op.getOperationName() << " " << *op.buffer();
|
||||||
|
p->printOptionalAttrDict(op.getAttrs());
|
||||||
|
*p << " : " << op.getBufferType();
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseBufferDeallocOp(OpAsmParser *parser,
|
||||||
|
OperationState *result) {
|
||||||
|
OpAsmParser::OperandType bufferInfo;
|
||||||
|
BufferType bufferType;
|
||||||
|
if (parser->parseOperand(bufferInfo) || parser->parseColonType(bufferType))
|
||||||
|
return failure();
|
||||||
|
return parser->resolveOperands(bufferInfo, bufferType, result->operands);
|
||||||
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter *p, BufferSizeOp op) {
|
static void print(OpAsmPrinter *p, BufferSizeOp op) {
|
||||||
*p << op.getOperationName() << " " << *op.getOperand();
|
*p << op.getOperationName() << " " << *op.getOperand();
|
||||||
p->printOptionalAttrDict(op.getAttrs());
|
p->printOptionalAttrDict(op.getAttrs());
|
||||||
|
|
|
@ -34,8 +34,7 @@ using namespace mlir::linalg;
|
||||||
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
|
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
|
||||||
: Dialect(getDialectNamespace(), context) {
|
: Dialect(getDialectNamespace(), context) {
|
||||||
addTypes<BufferType, RangeType, ViewType>();
|
addTypes<BufferType, RangeType, ViewType>();
|
||||||
addOperations<BufferAllocOp, BufferDeallocOp, ForOp, LoadOp, RangeOp, StoreOp,
|
addOperations<ForOp, LoadOp, RangeOp, StoreOp, SliceOp, ViewOp>();
|
||||||
SliceOp, ViewOp>();
|
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
|
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
|
||||||
|
@ -119,8 +118,8 @@ Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
|
||||||
// Check for '?'
|
// Check for '?'
|
||||||
int64_t bufferSize = -1;
|
int64_t bufferSize = -1;
|
||||||
if (!spec.consume_front("?")) {
|
if (!spec.consume_front("?")) {
|
||||||
unsigned parsedBufferSize;
|
unsigned long long parsedBufferSize = 0;
|
||||||
if (!spec.consumeInteger(10, parsedBufferSize)) {
|
if (spec.consumeInteger(10, parsedBufferSize)) {
|
||||||
emitError(loc, "expected buffer size to be an unsigned integer");
|
emitError(loc, "expected buffer size to be an unsigned integer");
|
||||||
return Type();
|
return Type();
|
||||||
}
|
}
|
||||||
|
|
|
@ -168,7 +168,7 @@ public:
|
||||||
auto indexType = IndexType::get(op->getContext());
|
auto indexType = IndexType::get(op->getContext());
|
||||||
auto voidPtrTy =
|
auto voidPtrTy =
|
||||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||||
auto int64Ty = lowering.convertType(operands[0]->getType());
|
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||||
// Insert the `malloc` declaration if it is not already present.
|
// Insert the `malloc` declaration if it is not already present.
|
||||||
auto *module = op->getFunction()->getModule();
|
auto *module = op->getFunction()->getModule();
|
||||||
Function *mallocFunc = module->getNamedFunction("malloc");
|
Function *mallocFunc = module->getNamedFunction("malloc");
|
||||||
|
@ -187,14 +187,19 @@ public:
|
||||||
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
|
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
|
||||||
else
|
else
|
||||||
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
|
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
|
||||||
auto elementPtrType = getPtrToElementType(
|
auto bufferType = allocOp.getResult()->getType().cast<BufferType>();
|
||||||
allocOp.getResult()->getType().cast<BufferType>(), lowering);
|
auto elementPtrType = getPtrToElementType(bufferType, lowering);
|
||||||
auto bufferDescriptorType =
|
auto bufferDescriptorType =
|
||||||
convertLinalgType(allocOp.getResult()->getType(), lowering);
|
convertLinalgType(allocOp.getResult()->getType(), lowering);
|
||||||
|
|
||||||
// Emit IR for creating a new buffer descriptor with an underlying malloc.
|
// Emit IR for creating a new buffer descriptor with an underlying malloc.
|
||||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||||
Value *size = operands[0];
|
auto constantSize = bufferType.getBufferSize();
|
||||||
|
Value *size =
|
||||||
|
constantSize
|
||||||
|
? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
|
||||||
|
.getValue()
|
||||||
|
: operands[0];
|
||||||
Value *allocSize =
|
Value *allocSize =
|
||||||
mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
|
mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
|
||||||
Value *allocated =
|
Value *allocated =
|
||||||
|
|
|
@ -13,12 +13,16 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
|
||||||
func @buffer(%arg0: index, %arg1: index) {
|
func @buffer(%arg0: index, %arg1: index) {
|
||||||
%0 = muli %arg0, %arg0 : index
|
%0 = muli %arg0, %arg0 : index
|
||||||
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
|
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
|
||||||
|
%2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
|
||||||
|
linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
|
||||||
linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
|
linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
|
// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
|
||||||
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
|
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
|
||||||
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
|
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
|
||||||
|
// CHECK-NEXT: %2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
|
||||||
|
// CHECK-NEXT: linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
|
||||||
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
|
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
|
||||||
|
|
||||||
func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
|
func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
|
||||||
|
|
Loading…
Reference in New Issue