forked from OSchip/llvm-project
Add a linalg.dim
A linalg.dim operation is used to extract size information from !linalg.view objects passed through function call boundaries. -- PiperOrigin-RevId: 248017488
This commit is contained in:
parent
cad382406f
commit
cf3959f49d
|
@ -82,10 +82,8 @@ LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
|
|||
class LinalgOp<string mnemonic, list<OpTrait> props> :
|
||||
Op<Linalg_Dialect, mnemonic, props> {
|
||||
let arguments = (ins Variadic<View>); // default variadic builder
|
||||
|
||||
let parser = [{ return impl::parseLinalgLibraryOp(parser, result); }];
|
||||
|
||||
let printer = [{ impl::printLinalgLibraryOp(p, *this); }];
|
||||
let parser = [{ return parseLinalgLibraryOp(parser, result); }];
|
||||
let printer = [{ printLinalgLibraryOp(p, *this); }];
|
||||
}
|
||||
|
||||
def BufferSizeOp :
|
||||
|
@ -93,12 +91,39 @@ def BufferSizeOp :
|
|||
Arguments<(ins Buffer)>,
|
||||
Results<(outs Index)>
|
||||
{
|
||||
let parser = [{
|
||||
return impl::parseBufferSizeOp(parser, result);
|
||||
let parser = [{ return parseBufferSizeOp(parser, result); }];
|
||||
let printer = [{ return printBufferSizeOp(p, *this); }];
|
||||
}
|
||||
|
||||
def DimOp : Op<Linalg_Dialect, "dim", [NoSideEffect]>,
|
||||
Arguments<(ins View:$view, APIntAttr:$index)>,
|
||||
Results<(outs Index)> {
|
||||
let summary = "dimension index operation";
|
||||
let description = [{
|
||||
The "linalg.dim" operation takes a linalg.view and returns an
|
||||
"index". It requires a single integer attribute named "index". It
|
||||
returns the size of the specified dimension. For example:
|
||||
|
||||
%1 = linalg.dim %0, 2 : view<?x?x?xf32>
|
||||
}];
|
||||
|
||||
let printer = [{
|
||||
return impl::printBufferSizeOp(p, this->getOperation());
|
||||
let parser = [{ return parseDimOp(parser, result); }];
|
||||
let printer = [{ return printDimOp(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, Value *view," "unsigned index",
|
||||
[{
|
||||
result->addOperands(view);
|
||||
result->addAttribute(
|
||||
"index", builder->getIntegerAttr(builder->getIndexType(), index));
|
||||
result->types.push_back(builder->getIndexType());
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
unsigned getIndex() {
|
||||
return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -28,10 +28,12 @@
|
|||
namespace mlir {
|
||||
class ModulePassBase;
|
||||
|
||||
mlir::ModulePassBase *
|
||||
createLinalgTilingPass(llvm::ArrayRef<int64_t> tileSizes = {});
|
||||
namespace linalg {
|
||||
ModulePassBase *createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
|
||||
|
||||
mlir::ModulePassBase *createLowerLinalgToLLVMPass();
|
||||
ModulePassBase *createLowerLinalgToLLVMPass();
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_LINALG_PASSES_H_
|
||||
|
|
|
@ -488,30 +488,19 @@ void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
|
|||
*p << "] : " << getType();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
namespace impl {
|
||||
void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op);
|
||||
ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
|
||||
void printBufferSizeOp(OpAsmPrinter *p, Operation *op);
|
||||
ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
|
||||
} // namespace impl
|
||||
} // namespace linalg
|
||||
|
||||
/// Buffer size prints as:
|
||||
///
|
||||
/// ``` {.mlir}
|
||||
/// %0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
|
||||
/// ```
|
||||
void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) {
|
||||
assert(op->getAbstractOperation() && "unregistered operation");
|
||||
*p << cast<BufferSizeOp>(op).getOperationName() << " " << *op->getOperand(0);
|
||||
p->printOptionalAttrDict(op->getAttrs());
|
||||
*p << " : " << op->getOperand(0)->getType();
|
||||
static void printBufferSizeOp(OpAsmPrinter *p, BufferSizeOp op) {
|
||||
*p << op.getOperationName() << " " << *op.getOperand();
|
||||
p->printOptionalAttrDict(op.getAttrs());
|
||||
*p << " : " << op.getOperand()->getType();
|
||||
}
|
||||
|
||||
ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
static ParseResult parseBufferSizeOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
OpAsmParser::OperandType op;
|
||||
Type type;
|
||||
return failure(parser->parseOperand(op) ||
|
||||
|
@ -522,10 +511,44 @@ ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
|
|||
result->types));
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
|
||||
static void printDimOp(OpAsmPrinter *p, DimOp op) {
|
||||
*p << op.getOperationName() << " " << *op.getOperand() << ", "
|
||||
<< op.getIndex();
|
||||
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
|
||||
*p << " : " << op.getOperand()->getType();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType operandInfo;
|
||||
IntegerAttr indexAttr;
|
||||
Type type;
|
||||
Type indexType = parser->getBuilder().getIndexType();
|
||||
return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
|
||||
parser->parseAttribute(indexAttr, indexType, "index",
|
||||
result->attributes) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(type) ||
|
||||
parser->resolveOperand(operandInfo, type, result->operands) ||
|
||||
parser->addTypeToList(indexType, result->types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(linalg::DimOp op) {
|
||||
// Check that we have an integer index operand.
|
||||
auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
|
||||
if (!indexAttr)
|
||||
return op.emitOpError("requires an integer attribute named 'index'");
|
||||
|
||||
uint64_t index = indexAttr.getValue().getZExtValue();
|
||||
auto type = op.getOperand()->getType();
|
||||
if (auto viewType = type.dyn_cast<ViewType>()) {
|
||||
if (index >= viewType.getRank())
|
||||
return op.emitOpError("index is out of range");
|
||||
} else {
|
||||
return op.emitOpError("requires an operand with view type");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// A LinalgLibraryOp prints as:
|
||||
//
|
||||
|
@ -541,7 +564,7 @@ ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser,
|
|||
// ```
|
||||
//
|
||||
// Where %0, %1 and %2 are ssa-values of type ViewType.
|
||||
void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
|
||||
static void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
|
||||
assert(op->getAbstractOperation() && "unregistered operation");
|
||||
*p << op->getName().getStringRef() << "(";
|
||||
interleave(
|
||||
|
@ -553,8 +576,8 @@ void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
|
|||
[&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
|
||||
}
|
||||
|
||||
ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
static ParseResult parseLinalgLibraryOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 3> ops;
|
||||
SmallVector<Type, 3> types;
|
||||
return failure(
|
||||
|
@ -565,6 +588,13 @@ ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser,
|
|||
result->operands));
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
// Ideally this should all be Tablegen'd but there is no good story for
|
||||
// AffineMap for now.
|
||||
SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
|
||||
|
|
|
@ -154,7 +154,7 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder,
|
|||
return builder.getArrayAttr(attrs);
|
||||
}
|
||||
|
||||
// BufferAllocOp creates a new `index` value.
|
||||
// BufferAllocOp creates a new `!linalg.buffer` value.
|
||||
class BufferAllocOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit BufferAllocOpConversion(MLIRContext *context,
|
||||
|
@ -213,7 +213,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// BufferDeallocOp creates a new `index` value.
|
||||
// BufferDeallocOp creates no value.
|
||||
class BufferDeallocOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit BufferDeallocOpConversion(MLIRContext *context,
|
||||
|
@ -268,6 +268,23 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// DimOp creates a new `index` value.
|
||||
class DimOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit DimOpConversion(MLIRContext *context, LLVMLowering &lowering_)
|
||||
: LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
FuncBuilder &rewriter) const override {
|
||||
auto dimOp = cast<linalg::DimOp>(op);
|
||||
auto indexTy = lowering.convertType(rewriter.getIndexType());
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
return {extractvalue(
|
||||
indexTy, operands[0],
|
||||
makePositionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))};
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
// Common functionality for Linalg LoadOp and StoreOp conversion to the
|
||||
// LLVM IR Dialect.
|
||||
|
@ -533,10 +550,11 @@ protected:
|
|||
llvm::DenseSet<DialectOpConversion *> initAdditionalConverters() override {
|
||||
return ConversionListBuilder<
|
||||
BufferAllocOpConversion, BufferDeallocOpConversion,
|
||||
BufferSizeOpConversion, DotOpConversion, LoadOpConversion,
|
||||
RangeOpConversion, SliceOpConversion, StoreOpConversion,
|
||||
ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(),
|
||||
*this);
|
||||
BufferSizeOpConversion, DimOpConversion, DotOpConversion,
|
||||
LoadOpConversion, RangeOpConversion, SliceOpConversion,
|
||||
StoreOpConversion, ViewOpConversion>::build(&converterStorage,
|
||||
llvmDialect->getContext(),
|
||||
*this);
|
||||
}
|
||||
|
||||
Type convertAdditionalType(Type t) override {
|
||||
|
@ -564,7 +582,7 @@ void LowerLinalgToLLVMPass::runOnModule() {
|
|||
signalPassFailure();
|
||||
}
|
||||
|
||||
ModulePassBase *mlir::createLowerLinalgToLLVMPass() {
|
||||
ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() {
|
||||
return new LowerLinalgToLLVMPass();
|
||||
}
|
||||
|
||||
|
|
|
@ -360,7 +360,8 @@ LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes)
|
|||
this->tileSizes.assign(sizes.begin(), sizes.end());
|
||||
}
|
||||
|
||||
ModulePassBase *mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
|
||||
ModulePassBase *
|
||||
mlir::linalg::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
|
||||
return new LinalgTilingPass(tileSizes);
|
||||
}
|
||||
|
||||
|
|
|
@ -73,3 +73,10 @@ func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg
|
|||
}
|
||||
// CHECK-LABEL: func @dot(%arg0: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg1: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg2: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
|
||||
// CHECK: llvm.call @linalg_dot(%arg0, %arg1, %arg2) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> ()
|
||||
|
||||
func @dim(%arg0: !linalg.view<?x?xf32>) {
|
||||
%0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dim(%arg0: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
|
||||
// CHECK: %0 = llvm.extractvalue %arg0[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||
|
|
|
@ -52,3 +52,14 @@ func @ops(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !lina
|
|||
// CHECK-NEXT: linalg.matvec(%arg0, %arg1, %arg2) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
|
||||
// CHECK-NEXT: linalg.dot(%arg1, %arg2, %arg3) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
|
||||
|
||||
func @dim(%arg0: !linalg.view<?x?xf32>) {
|
||||
%0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
|
||||
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
linalg.buffer_dealloc %1 : !linalg.buffer<f32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dim(%arg0: !linalg.view<?x?xf32>) {
|
||||
// CHECK-NEXT: %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
|
||||
|
||||
|
|
Loading…
Reference in New Issue