Add a linalg.dim

A linalg.dim operation is used to extract size information from !linalg.view objects passed
    through function call boundaries.

--

PiperOrigin-RevId: 248017488
This commit is contained in:
Nicolas Vasilache 2019-05-13 14:59:55 -07:00 committed by Mehdi Amini
parent cad382406f
commit cf3959f49d
7 changed files with 136 additions and 42 deletions

View File

@ -82,10 +82,8 @@ LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
class LinalgOp<string mnemonic, list<OpTrait> props> :
Op<Linalg_Dialect, mnemonic, props> {
let arguments = (ins Variadic<View>); // default variadic builder
let parser = [{ return impl::parseLinalgLibraryOp(parser, result); }];
let printer = [{ impl::printLinalgLibraryOp(p, *this); }];
let parser = [{ return parseLinalgLibraryOp(parser, result); }];
let printer = [{ printLinalgLibraryOp(p, *this); }];
}
def BufferSizeOp :
@ -93,12 +91,39 @@ def BufferSizeOp :
Arguments<(ins Buffer)>,
Results<(outs Index)>
{
let parser = [{
return impl::parseBufferSizeOp(parser, result);
let parser = [{ return parseBufferSizeOp(parser, result); }];
let printer = [{ return printBufferSizeOp(p, *this); }];
}
def DimOp : Op<Linalg_Dialect, "dim", [NoSideEffect]>,
Arguments<(ins View:$view, APIntAttr:$index)>,
Results<(outs Index)> {
let summary = "dimension index operation";
let description = [{
The "linalg.dim" operation takes a linalg.view and returns an
"index". It requires a single integer attribute named "index". It
returns the size of the specified dimension. For example:
%1 = linalg.dim %0, 2 : view<?x?x?xf32>
}];
let printer = [{
return impl::printBufferSizeOp(p, this->getOperation());
let parser = [{ return parseDimOp(parser, result); }];
let printer = [{ return printDimOp(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let builders = [OpBuilder<
"Builder *builder, OperationState *result, Value *view," "unsigned index",
[{
result->addOperands(view);
result->addAttribute(
"index", builder->getIntegerAttr(builder->getIndexType(), index));
result->types.push_back(builder->getIndexType());
}]>];
let extraClassDeclaration = [{
unsigned getIndex() {
return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
}
}];
}

View File

@ -28,10 +28,12 @@
namespace mlir {
class ModulePassBase;
mlir::ModulePassBase *
createLinalgTilingPass(llvm::ArrayRef<int64_t> tileSizes = {});
namespace linalg {
ModulePassBase *createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
mlir::ModulePassBase *createLowerLinalgToLLVMPass();
ModulePassBase *createLowerLinalgToLLVMPass();
} // namespace linalg
} // namespace mlir
#endif // MLIR_LINALG_PASSES_H_

View File

@ -488,30 +488,19 @@ void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
*p << "] : " << getType();
}
namespace mlir {
namespace linalg {
namespace impl {
void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op);
ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
void printBufferSizeOp(OpAsmPrinter *p, Operation *op);
ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
} // namespace impl
} // namespace linalg
/// Buffer size prints as:
///
/// ``` {.mlir}
/// %0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
/// ```
void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation");
*p << cast<BufferSizeOp>(op).getOperationName() << " " << *op->getOperand(0);
p->printOptionalAttrDict(op->getAttrs());
*p << " : " << op->getOperand(0)->getType();
static void printBufferSizeOp(OpAsmPrinter *p, BufferSizeOp op) {
*p << op.getOperationName() << " " << *op.getOperand();
p->printOptionalAttrDict(op.getAttrs());
*p << " : " << op.getOperand()->getType();
}
ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
OperationState *result) {
static ParseResult parseBufferSizeOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType op;
Type type;
return failure(parser->parseOperand(op) ||
@ -522,10 +511,44 @@ ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
result->types));
}
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
static void printDimOp(OpAsmPrinter *p, DimOp op) {
*p << op.getOperationName() << " " << *op.getOperand() << ", "
<< op.getIndex();
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
*p << " : " << op.getOperand()->getType();
}
} // namespace mlir
static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo;
IntegerAttr indexAttr;
Type type;
Type indexType = parser->getBuilder().getIndexType();
return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr, indexType, "index",
result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(operandInfo, type, result->operands) ||
parser->addTypeToList(indexType, result->types));
}
static LogicalResult verify(linalg::DimOp op) {
// Check that we have an integer index operand.
auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
if (!indexAttr)
return op.emitOpError("requires an integer attribute named 'index'");
uint64_t index = indexAttr.getValue().getZExtValue();
auto type = op.getOperand()->getType();
if (auto viewType = type.dyn_cast<ViewType>()) {
if (index >= viewType.getRank())
return op.emitOpError("index is out of range");
} else {
return op.emitOpError("requires an operand with view type");
}
return success();
}
// A LinalgLibraryOp prints as:
//
@ -541,7 +564,7 @@ ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
// ```
//
// Where %0, %1 and %2 are ssa-values of type ViewType.
void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
static void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation");
*p << op->getName().getStringRef() << "(";
interleave(
@ -553,8 +576,8 @@ void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
[&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
}
ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser,
OperationState *result) {
static ParseResult parseLinalgLibraryOp(OpAsmParser *parser,
OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types;
return failure(
@ -565,6 +588,13 @@ ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser,
result->operands));
}
namespace mlir {
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
} // namespace mlir
// Ideally this should all be Tablegen'd but there is no good story for
// AffineMap for now.
SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {

View File

@ -154,7 +154,7 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder,
return builder.getArrayAttr(attrs);
}
// BufferAllocOp creates a new `index` value.
// BufferAllocOp creates a new `!linalg.buffer` value.
class BufferAllocOpConversion : public LLVMOpLowering {
public:
explicit BufferAllocOpConversion(MLIRContext *context,
@ -213,7 +213,7 @@ public:
}
};
// BufferDeallocOp creates a new `index` value.
// BufferDeallocOp creates no value.
class BufferDeallocOpConversion : public LLVMOpLowering {
public:
explicit BufferDeallocOpConversion(MLIRContext *context,
@ -268,6 +268,23 @@ public:
}
};
// DimOp creates a new `index` value.
class DimOpConversion : public LLVMOpLowering {
public:
explicit DimOpConversion(MLIRContext *context, LLVMLowering &lowering_)
: LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
auto dimOp = cast<linalg::DimOp>(op);
auto indexTy = lowering.convertType(rewriter.getIndexType());
edsc::ScopedContext context(rewriter, op->getLoc());
return {extractvalue(
indexTy, operands[0],
makePositionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))};
}
};
namespace {
// Common functionality for Linalg LoadOp and StoreOp conversion to the
// LLVM IR Dialect.
@ -533,10 +550,11 @@ protected:
llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() override {
return ConversionListBuilder<
BufferAllocOpConversion, BufferDeallocOpConversion,
BufferSizeOpConversion, DotOpConversion, LoadOpConversion,
RangeOpConversion, SliceOpConversion, StoreOpConversion,
ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(),
*this);
BufferSizeOpConversion, DimOpConversion, DotOpConversion,
LoadOpConversion, RangeOpConversion, SliceOpConversion,
StoreOpConversion, ViewOpConversion>::build(&converterStorage,
llvmDialect->getContext(),
*this);
}
Type convertAdditionalType(Type t) override {
@ -564,7 +582,7 @@ void LowerLinalgToLLVMPass::runOnModule() {
signalPassFailure();
}
ModulePassBase *mlir::createLowerLinalgToLLVMPass() {
ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() {
return new LowerLinalgToLLVMPass();
}

View File

@ -360,7 +360,8 @@ LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes)
this->tileSizes.assign(sizes.begin(), sizes.end());
}
ModulePassBase *mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
ModulePassBase *
mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
return new LinalgTilingPass(tileSizes);
}

View File

@ -73,3 +73,10 @@ func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg
}
// CHECK-LABEL: func @dot(%arg0: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg1: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg2: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
// CHECK: llvm.call @linalg_dot(%arg0, %arg1, %arg2) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> ()
func @dim(%arg0: !linalg.view<?x?xf32>) {
%0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
return
}
// CHECK-LABEL: func @dim(%arg0: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
// CHECK: %0 = llvm.extractvalue %arg0[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">

View File

@ -52,3 +52,14 @@ func @ops(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !lina
// CHECK-NEXT: linalg.matvec(%arg0, %arg1, %arg2) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
// CHECK-NEXT: linalg.dot(%arg1, %arg2, %arg3) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
func @dim(%arg0: !linalg.view<?x?xf32>) {
%0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
linalg.buffer_dealloc %1 : !linalg.buffer<f32>
return
}
// CHECK-LABEL: func @dim(%arg0: !linalg.view<?x?xf32>) {
// CHECK-NEXT: %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>