forked from OSchip/llvm-project
Update Linalg to use std.view
Now that a view op has graduated to the std dialect, we can update Linalg to use it and remove ops that have become obsolete. As a byproduct, the linalg buffer and associated ops can also disappear. PiperOrigin-RevId: 279073591
This commit is contained in:
parent
eee9cbdeb7
commit
72040bf7c8
|
@ -44,108 +44,6 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def BufferAllocOp :
|
||||
Linalg_Op<"buffer_alloc">,
|
||||
Arguments<(ins Variadic<Index>:$size, OptionalAttr<I64Attr>:$alignment)>,
|
||||
Results<(outs Buffer)> {
|
||||
let summary = "buffer allocation operation";
|
||||
let description = [{
|
||||
The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
|
||||
upon which a base view can be laid out to give it indexing semantics.
|
||||
"buffer_alloc" takes a single argument, the size of the buffer to allocate
|
||||
(in number of elements).
|
||||
An optional alignment attribute may be specified in which case the actual
|
||||
underlying allocation size may be increased. The base pointer is guaranteed
|
||||
to be a multiple of `alignment`. Such an alignment must be a positive power
|
||||
of 2.
|
||||
|
||||
Examples:
|
||||
|
||||
%0 = linalg.buffer_alloc(%arg0) : !linalg.buffer<?xf32>
|
||||
|
||||
%1 = linalg.buffer_alloc(%arg0) { alignment = 16 } :
|
||||
!linalg.buffer<?xf32>
|
||||
|
||||
The size argument may be omitted if it is statically known, in which case it
|
||||
must be reflected in the type.
|
||||
|
||||
Example:
|
||||
|
||||
%0 = linalg.buffer_alloc() : !linalg.buffer<4xf32>
|
||||
}];
|
||||
let builders = [
|
||||
OpBuilder<
|
||||
"Builder *b, OperationState &result, BufferType bufferType", [{
|
||||
result.addTypes(bufferType);
|
||||
}]>,
|
||||
OpBuilder<
|
||||
"Builder *b, OperationState &result, BufferType bufferType, "
|
||||
"unsigned alignment", [{
|
||||
build(b, result, bufferType);
|
||||
if (alignment != 0)
|
||||
result.addAttribute(BufferAllocOp::getAlignmentAttrName(),
|
||||
b->getI64IntegerAttr(alignment));
|
||||
}]>,
|
||||
OpBuilder<
|
||||
"Builder *b, OperationState &result, BufferType bufferType, "
|
||||
"Value *size, unsigned alignment", [{
|
||||
if (alignment == 0)
|
||||
return build(b, result, bufferType, size);
|
||||
build(b, result, bufferType, size, b->getI64IntegerAttr(alignment));
|
||||
}]>,
|
||||
OpBuilder<
|
||||
"Builder *b, OperationState &result, BufferType bufferType, Value *size",
|
||||
[{
|
||||
result.addOperands(size);
|
||||
result.addTypes(bufferType);
|
||||
}]>
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getAlignmentAttrName() { return "alignment"; }
|
||||
BufferType getBufferType() { return getType().cast<BufferType>(); }
|
||||
Type getElementType() { return getBufferType().getElementType(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def BufferDeallocOp :
|
||||
Linalg_Op<"buffer_dealloc">,
|
||||
Arguments<(ins Buffer:$buffer)>,
|
||||
Results<(outs)> {
|
||||
let summary = "buffer allocation operation";
|
||||
let description = [{
|
||||
The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
|
||||
|
||||
Example:
|
||||
|
||||
linalg.buffer_dealloc %0 : !linalg.buffer<f32>
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
BufferType getBufferType() {
|
||||
return buffer()->getType().cast<BufferType>();
|
||||
}
|
||||
}];
|
||||
|
||||
// Fully specified by traits.
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
def BufferSizeOp :
|
||||
Linalg_Op<"buffer_size", [NoSideEffect]>,
|
||||
Arguments<(ins Buffer:$buffer)>,
|
||||
Results<(outs Index)> {
|
||||
let summary = "buffer size operation";
|
||||
let description = [{
|
||||
The "linalg.buffer_size" operation takes a linalg.buffer and returns an
|
||||
"index".
|
||||
|
||||
Example:
|
||||
|
||||
%0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
|
||||
}];
|
||||
// Fully specified by traits.
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
def RangeOp :
|
||||
Linalg_Op<"range", [NoSideEffect]>,
|
||||
Arguments<(ins Index:$min, Index:$max, Index:$step)>,
|
||||
|
@ -329,51 +227,6 @@ def TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
|
|||
}];
|
||||
}
|
||||
|
||||
def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
|
||||
Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
|
||||
Results<(outs AnyStridedMemRef)> {
|
||||
let summary = "view operation";
|
||||
let description = [{
|
||||
The "linalg.view" op produces a strided memref which is a multi-dimensional
|
||||
range abstraction on top of an underlying linalg.buffer. This gives an
|
||||
indexing structure to an otherwise non-indexable linalg.buffer.
|
||||
|
||||
A "linalg.view" takes a buffer and a variadic number of ranges and produces
|
||||
a `view` of rank the number of ranges. The elemental type may not match the
|
||||
buffer element type:
|
||||
|
||||
Example:
|
||||
|
||||
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
%3 = linalg.view %1[%2, %2] :
|
||||
memref<?x?xvector<4xf32>, stride_specification>
|
||||
}];
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *b, OperationState &result, Value *buffer, "
|
||||
"ArrayRef<Value *> ranges, Type resultType = Type(), "
|
||||
"ArrayRef<NamedAttribute> attrs = {}">];
|
||||
|
||||
let verifier = [{
|
||||
if (getViewType().getRank() != llvm::size(ranges()))
|
||||
return emitOpError("the view rank must be the number of its ranges");
|
||||
return success();
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
enum { FirstIndexingOperand = 1 };
|
||||
unsigned getRank() { return getViewType().getRank(); }
|
||||
Type getElementType() { return getViewType().getElementType(); }
|
||||
MemRefType getViewType() { return getType().cast<MemRefType>(); }
|
||||
/// Get the underlying indexing at a given rank.
|
||||
Value *getRange(unsigned rank) {
|
||||
assert(rank < getRank() && "rank overflow");
|
||||
return *(ranges().begin() + rank);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
|
||||
Arguments<(ins Variadic<AnyType>:$values)> {
|
||||
let summary = "Linalg yield operation";
|
||||
|
|
|
@ -36,7 +36,8 @@ std::unique_ptr<OpPassBase<FuncOp>> createLinalgFusionPass();
|
|||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createLinalgPromotionPass();
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
createLinalgPromotionPass(bool dynamicBuffers);
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>> createLowerLinalgToLoopsPass();
|
||||
|
||||
|
|
|
@ -22,22 +22,15 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
class BufferAllocOp;
|
||||
class BufferDeallocOp;
|
||||
class CopyOp;
|
||||
class FillOp;
|
||||
class RangeOp;
|
||||
class SliceOp;
|
||||
class ViewOp;
|
||||
namespace intrinsics {
|
||||
using buffer_alloc = mlir::edsc::intrinsics::ValueBuilder<BufferAllocOp>;
|
||||
using buffer_dealloc =
|
||||
mlir::edsc::intrinsics::OperationBuilder<BufferDeallocOp>;
|
||||
using copy = mlir::edsc::intrinsics::OperationBuilder<CopyOp>;
|
||||
using fill = mlir::edsc::intrinsics::OperationBuilder<FillOp>;
|
||||
using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>;
|
||||
using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>;
|
||||
using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>;
|
||||
} // namespace intrinsics
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
|
|
@ -173,9 +173,10 @@ struct PromotionInfo {
|
|||
///
|
||||
/// Returns a list of PromotionInfo which hold the promoted buffer and the
|
||||
/// full and partial views indexing into the buffer.
|
||||
llvm::SmallVector<PromotionInfo, 8> promoteSubViews(OpBuilder &b, Location loc,
|
||||
ArrayRef<Value *> subViews,
|
||||
OperationFolder *folder);
|
||||
llvm::SmallVector<PromotionInfo, 8>
|
||||
promoteSubViews(OpBuilder &b, Location loc, ArrayRef<Value *> subViews,
|
||||
bool promoteSubViews = false,
|
||||
OperationFolder *folder = nullptr);
|
||||
|
||||
/// Returns all the operands of `linalgOp` that are not views.
|
||||
/// Asserts that these operands are value types to allow transformations like
|
||||
|
|
|
@ -216,6 +216,7 @@ using std_load = ValueBuilder<LoadOp>;
|
|||
using std_store = OperationBuilder<StoreOp>;
|
||||
using subi = ValueBuilder<SubIOp>;
|
||||
using vector_type_cast = ValueBuilder<vector::VectorTypeCastOp>;
|
||||
using view = ValueBuilder<ViewOp>;
|
||||
|
||||
/// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`.
|
||||
///
|
||||
|
|
|
@ -1430,10 +1430,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
|||
SmallVector<int64_t, 4> strides;
|
||||
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
|
||||
if (failed(successStrides))
|
||||
return op->emitWarning("Cannot cast to non-strided shape"),
|
||||
matchFailure();
|
||||
if (strides.back() != 1)
|
||||
return op->emitWarning("Cannot cast to non-contiguous shape"),
|
||||
return op->emitWarning("cannot cast to non-strided shape"),
|
||||
matchFailure();
|
||||
|
||||
// Create the descriptor.
|
||||
|
@ -1466,7 +1463,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
|||
rewriter.getI64ArrayAttr(
|
||||
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
|
||||
|
||||
// Early exit for 0-D corner case.
|
||||
if (viewMemRefType.getRank() == 0)
|
||||
return rewriter.replaceOp(op, desc), matchSuccess();
|
||||
|
||||
// Update sizes and strides.
|
||||
if (strides.back() != 1)
|
||||
return op->emitWarning("cannot cast to non-contiguous shape"),
|
||||
matchFailure();
|
||||
Value *stride = nullptr, *nextSize = nullptr;
|
||||
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
|
||||
// Update size.
|
||||
|
|
|
@ -74,7 +74,7 @@ Value *Aliases::find(Value *v) {
|
|||
return it.first->second;
|
||||
}
|
||||
if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
|
||||
auto it = aliases.insert(std::make_pair(v, view.buffer()));
|
||||
auto it = aliases.insert(std::make_pair(v, view.source()));
|
||||
return it.first->second;
|
||||
}
|
||||
if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {
|
||||
|
|
|
@ -50,102 +50,6 @@ using namespace mlir::linalg;
|
|||
// LinalgOps.td), we define an overloaded `print` function and a
|
||||
// parse`className` function.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferAllocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, BufferAllocOp op) {
|
||||
p << op.getOperationName() << " ";
|
||||
if (!llvm::empty(op.size()))
|
||||
p << *op.getOperand(0);
|
||||
if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0)
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
else
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
BufferAllocOp::getAlignmentAttrName());
|
||||
p << " : " << op.getBufferType();
|
||||
}
|
||||
|
||||
static ParseResult parseBufferAllocOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 1> sizeInfo;
|
||||
BufferType bufferType;
|
||||
auto indexTy = parser.getBuilder().getIndexType();
|
||||
if (parser.parseOperandList(sizeInfo) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(bufferType))
|
||||
return failure();
|
||||
if (sizeInfo.empty())
|
||||
return parser.addTypeToList(bufferType, result.types);
|
||||
return failure(parser.resolveOperands(sizeInfo, indexTy, result.operands) ||
|
||||
parser.addTypeToList(bufferType, result.types));
|
||||
}
|
||||
|
||||
static LogicalResult verify(BufferAllocOp op) {
|
||||
if (!op.getBufferType().hasConstantSize()) {
|
||||
if (llvm::size(op.size()) != 1)
|
||||
return op.emitOpError("expected one index operand");
|
||||
} else { // op.getBufferType().hasConstantSize()
|
||||
if (!llvm::empty(op.size()))
|
||||
return op.emitOpError("expected zero operand");
|
||||
if (op.getBufferType().getBufferSize().getValue() <= 0)
|
||||
return op.emitOpError("expected nonnegative static buffer size");
|
||||
}
|
||||
if (op.alignment().hasValue()) {
|
||||
auto align = op.alignment().getValue();
|
||||
if (align.getSExtValue() < 0)
|
||||
return op.emitOpError("expected positive alignment");
|
||||
if (!llvm::isPowerOf2_64(align.getZExtValue()))
|
||||
return op.emitOpError("expected power of 2 alignment");
|
||||
}
|
||||
if (!TensorType::isValidElementType(op.getElementType()))
|
||||
return op.emitOpError("expected valid buffer element type");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferDeallocOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, BufferDeallocOp op) {
|
||||
p << op.getOperationName() << " " << *op.buffer();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getBufferType();
|
||||
}
|
||||
|
||||
static ParseResult parseBufferDeallocOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType bufferInfo;
|
||||
BufferType bufferType;
|
||||
if (parser.parseOperand(bufferInfo) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(bufferType))
|
||||
return failure();
|
||||
return parser.resolveOperands(bufferInfo, bufferType, result.operands);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, BufferSizeOp op) {
|
||||
p << op.getOperationName() << " " << *op.buffer();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.buffer()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseBufferSizeOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType op;
|
||||
Type type;
|
||||
return failure(
|
||||
parser.parseOperand(op) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.resolveOperand(op, type, result.operands) ||
|
||||
parser.addTypeToList(parser.getBuilder().getIndexType(), result.types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericOps
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -426,7 +330,7 @@ static LogicalResult verify(SliceOp op) {
|
|||
unsigned rank = op.getBaseViewRank();
|
||||
if (rank != llvm::size(op.indexings()))
|
||||
return op.emitOpError("expected ")
|
||||
<< op.getRank() << " indexings, got " << llvm::size(op.indexings());
|
||||
<< rank << " indexings, got " << llvm::size(op.indexings());
|
||||
unsigned index = 0;
|
||||
for (auto indexing : op.indexings()) {
|
||||
if (indexing->getType().isa<IndexType>())
|
||||
|
@ -562,59 +466,6 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
|
|||
parser.addTypeToList(type, result.types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void mlir::linalg::ViewOp::build(Builder *b, OperationState &result,
|
||||
Value *buffer, ArrayRef<Value *> ranges,
|
||||
Type resultType,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
// If the result type is not specified, assume sizes are fully dynamic.
|
||||
// Strides are set to match an empty layout map which means "contiguous view".
|
||||
if (!resultType) {
|
||||
auto rank = ranges.size();
|
||||
SmallVector<int64_t, 4> sizes(rank, -1);
|
||||
Type elementType = buffer->getType().cast<BufferType>().getElementType();
|
||||
resultType = MemRefType::get(sizes, elementType, {}, 0);
|
||||
}
|
||||
build(b, result, resultType, buffer, ranges);
|
||||
result.addAttributes(attrs);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, mlir::linalg::ViewOp op) {
|
||||
p << op.getOperationName() << " " << *op.buffer() << "[";
|
||||
interleaveComma(op.ranges(), p, [&](Value *v) { p << *v; });
|
||||
p << "] ";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.buffer()->getType() << " -> " << op.getType();
|
||||
}
|
||||
|
||||
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType bufferInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 8> rangesInfo;
|
||||
Type bType, vType;
|
||||
if (parser.parseOperand(bufferInfo) ||
|
||||
parser.parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
|
||||
parser.parseType(bType) || parser.parseArrow() ||
|
||||
parser.parseType(vType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
MemRefType memRefType = vType.dyn_cast<MemRefType>();
|
||||
if (!memRefType)
|
||||
return parser.emitError(parser.getNameLoc(), "expected memref type");
|
||||
if (static_cast<unsigned>(memRefType.getRank()) != rangesInfo.size())
|
||||
return parser.emitError(parser.getNameLoc(), "expected ")
|
||||
<< memRefType.getRank() << " ranges";
|
||||
return failure(
|
||||
parser.resolveOperand(bufferInfo, bType, result.operands) ||
|
||||
(!rangesInfo.empty() &&
|
||||
parser.resolveOperands(rangesInfo, RangeType::get(vType.getContext()),
|
||||
result.operands)) ||
|
||||
parser.addTypeToList(memRefType, result.types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// YieldOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -166,144 +166,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// BufferAllocOp creates a new `!linalg.buffer` value.
|
||||
class BufferAllocOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit BufferAllocOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &lowering_)
|
||||
: LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto indexType = IndexType::get(op->getContext());
|
||||
auto voidPtrTy =
|
||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
|
||||
.cast<LLVM::LLVMType>();
|
||||
// Insert the `malloc` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
auto mallocFunc = module.lookupSymbol<LLVMFuncOp>("malloc");
|
||||
if (!mallocFunc) {
|
||||
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||
mallocFunc = moduleBuilder.create<LLVMFuncOp>(
|
||||
rewriter.getUnknownLoc(), "malloc",
|
||||
LLVM::LLVMType::getFunctionTy(voidPtrTy, int64Ty,
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
|
||||
// Get MLIR types for injecting element pointer.
|
||||
auto allocOp = cast<BufferAllocOp>(op);
|
||||
auto elementType = allocOp.getElementType();
|
||||
uint64_t elementSize = 0;
|
||||
if (auto vectorType = elementType.dyn_cast<VectorType>())
|
||||
elementSize = vectorType.getNumElements() *
|
||||
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
|
||||
else
|
||||
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
|
||||
auto bufferType = allocOp.getBufferType();
|
||||
auto elementPtrType = getPtrToElementType(bufferType, lowering);
|
||||
auto bufferDescriptorTy = convertLinalgType(bufferType, lowering);
|
||||
|
||||
// Emit IR for creating a new buffer descriptor with an underlying malloc.
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
auto constantSize = bufferType.getBufferSize();
|
||||
Value *size =
|
||||
constantSize
|
||||
? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
|
||||
.getValue()
|
||||
: operands[0];
|
||||
Value *allocSize =
|
||||
mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
|
||||
Value *one = nullptr, *align = nullptr;
|
||||
if (allocOp.alignment().hasValue()) {
|
||||
one = constant(int64Ty, IntegerAttr::get(indexType, 1));
|
||||
align =
|
||||
constant(int64Ty, rewriter.getIntegerAttr(
|
||||
rewriter.getIndexType(),
|
||||
allocOp.alignment().getValue().getSExtValue()));
|
||||
allocSize = sub(add(allocSize, align), one);
|
||||
}
|
||||
|
||||
Value *allocated =
|
||||
llvm_call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
|
||||
.getOperation()
|
||||
->getResult(0);
|
||||
Value *data = allocated;
|
||||
if (allocOp.alignment().hasValue()) {
|
||||
// offset = (align - (ptr % align))% align
|
||||
Value *offset =
|
||||
urem(sub(align, urem(ptrtoint(int64Ty, allocated), align)), align);
|
||||
data = gep(voidPtrTy, allocated, offset);
|
||||
}
|
||||
data = bitcast(elementPtrType, data);
|
||||
Value *desc = llvm_undef(bufferDescriptorTy);
|
||||
desc = insertvalue(bufferDescriptorTy, desc, allocated,
|
||||
rewriter.getI64ArrayAttr(kBasePtrPosInBuffer));
|
||||
desc = insertvalue(bufferDescriptorTy, desc, data,
|
||||
rewriter.getI64ArrayAttr(kPtrPosInBuffer));
|
||||
desc = insertvalue(bufferDescriptorTy, desc, size,
|
||||
rewriter.getI64ArrayAttr(kSizePosInBuffer));
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// BufferDeallocOp creates no value.
|
||||
class BufferDeallocOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit BufferDeallocOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &lowering_)
|
||||
: LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
|
||||
lowering_) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto voidTy = LLVM::LLVMType::getVoidTy(lowering.getDialect());
|
||||
auto voidPtrTy =
|
||||
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
auto freeFunc = module.lookupSymbol<LLVMFuncOp>("free");
|
||||
if (!freeFunc) {
|
||||
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
|
||||
freeFunc = moduleBuilder.create<LLVMFuncOp>(
|
||||
rewriter.getUnknownLoc(), "free",
|
||||
LLVM::LLVMType::getFunctionTy(voidTy, voidPtrTy,
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
|
||||
// Emit MLIR for buffer_dealloc.
|
||||
BufferDeallocOpOperandAdaptor adaptor(operands);
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
Value *base = extractvalue(voidPtrTy, adaptor.buffer(),
|
||||
rewriter.getI64ArrayAttr(kBasePtrPosInBuffer));
|
||||
llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), base);
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// BufferSizeOp creates a new `index` value.
|
||||
class BufferSizeOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||
: LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
BufferSizeOpOperandAdaptor adaptor(operands);
|
||||
rewriter.replaceOp(
|
||||
op, {extractvalue(int64Ty, adaptor.buffer(),
|
||||
rewriter.getI64ArrayAttr(kSizePosInBuffer))});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// RangeOp creates a new range descriptor.
|
||||
class RangeOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
|
@ -480,78 +342,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern that transforms a linalg.view op into:
|
||||
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
|
||||
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
|
||||
/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
|
||||
/// and stride.
|
||||
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
|
||||
/// The linalg.view op is replaced by the alloca'ed pointer.
|
||||
class ViewOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||
: LLVMOpLowering(mlir::linalg::ViewOp::getOperationName(), context,
|
||||
lowering_) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
mlir::linalg::ViewOpOperandAdaptor adaptor(operands);
|
||||
|
||||
auto viewOp = cast<mlir::linalg::ViewOp>(op);
|
||||
BaseViewConversionHelper helper(op->getLoc(), viewOp.getViewType(),
|
||||
rewriter, lowering);
|
||||
LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
|
||||
Value *desc = helper.desc;
|
||||
|
||||
Value *bufferDescriptor = adaptor.buffer();
|
||||
auto bufferTy = getPtrToElementType(
|
||||
viewOp.buffer()->getType().cast<BufferType>(), lowering);
|
||||
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
|
||||
// Copy the buffer pointer from the old descriptor to the new one.
|
||||
Value *bufferAsViewElementType =
|
||||
bitcast(elementTy, extractvalue(bufferTy, bufferDescriptor,
|
||||
helper.pos(kPtrPosInBuffer)));
|
||||
desc =
|
||||
insertvalue(desc, bufferAsViewElementType, helper.pos(kPtrPosInView));
|
||||
|
||||
// Zero base offset.
|
||||
auto indexTy = rewriter.getIndexType();
|
||||
Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
|
||||
desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView));
|
||||
|
||||
// Corner case, no sizes or stride: early return the descriptor.
|
||||
if (helper.zeroDMemRef) {
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
// Compute and insert view sizes (max - min along the range).
|
||||
int numRanges = llvm::size(viewOp.ranges());
|
||||
Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
|
||||
for (int i = numRanges - 1; i >= 0; --i) {
|
||||
// Update stride.
|
||||
Value *rangeDescriptor = operands[1 + i];
|
||||
Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
|
||||
Value *stride = mul(runningStride, step);
|
||||
desc = insertvalue(desc, stride, helper.pos({kStridePosInView, i}));
|
||||
// Update size.
|
||||
Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
|
||||
Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
|
||||
Value *size = sub(max, min);
|
||||
desc = insertvalue(desc, size, helper.pos({kSizePosInView, i}));
|
||||
// Update stride for the next dimension.
|
||||
if (i > 0)
|
||||
runningStride = mul(runningStride, max);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// YieldOp produces and LLVM::ReturnOp.
|
||||
class YieldOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
|
@ -731,10 +521,8 @@ static void
|
|||
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx) {
|
||||
patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
|
||||
BufferSizeOpConversion, RangeOpConversion, SliceOpConversion,
|
||||
TransposeOpConversion, ViewOpConversion, YieldOpConversion>(
|
||||
ctx, converter);
|
||||
patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion,
|
||||
YieldOpConversion>(ctx, converter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -49,17 +49,26 @@ using llvm::SetVector;
|
|||
|
||||
#define DEBUG_TYPE "linalg-promotion"
|
||||
|
||||
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
|
||||
static llvm::cl::opt<bool> clPromoteDynamic(
|
||||
"test-linalg-promote-dynamic",
|
||||
llvm::cl::desc("Test generation of dynamic promoted buffers"),
|
||||
llvm::cl::cat(clOptionsCategory), llvm::cl::init(false));
|
||||
|
||||
static AffineMap getAffineDifferenceMap(MLIRContext *context) {
|
||||
AffineExpr d0(getAffineDimExpr(0, context)), d1(getAffineDimExpr(1, context));
|
||||
return AffineMap::get(2, 0, {d0 - d1});
|
||||
}
|
||||
|
||||
// TODO(ntv): replace this with 1-D memref alloc once there is an std.view op.
|
||||
static Value *allocBuffer(Type elementType, Value *size) {
|
||||
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
|
||||
return buffer_alloc(
|
||||
BufferType::get(size->getContext(), elementType, cst.getValue()));
|
||||
return buffer_alloc(BufferType::get(size->getContext(), elementType), size);
|
||||
static Value *allocBuffer(Type elementType, Value *size, bool dynamicBuffers) {
|
||||
auto *ctx = size->getContext();
|
||||
auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
|
||||
if (!dynamicBuffers)
|
||||
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
|
||||
return alloc(
|
||||
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
|
||||
Value *mul = muli(constant_index(width), size);
|
||||
return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
|
||||
}
|
||||
|
||||
// Performs promotion of a `subView` into a local buffer of the size of the
|
||||
|
@ -81,6 +90,7 @@ static Value *allocBuffer(Type elementType, Value *size) {
|
|||
// by a partial `copy` op.
|
||||
static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
|
||||
SubViewOp subView,
|
||||
bool dynamicBuffers,
|
||||
OperationFolder *folder) {
|
||||
auto zero = constant_index(folder, 0);
|
||||
auto one = constant_index(folder, 1);
|
||||
|
@ -101,18 +111,21 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
|
|||
{rangeValue.max, rangeValue.min}, folder)
|
||||
.front();
|
||||
allocSize = muli(folder, allocSize, d).getValue();
|
||||
fullRanges.push_back(range(folder, zero, d, one));
|
||||
fullRanges.push_back(d);
|
||||
partialRanges.push_back(range(folder, zero, dim(subView, rank), one));
|
||||
}
|
||||
auto *buffer = allocBuffer(viewType.getElementType(), allocSize);
|
||||
auto fullLocalView = view(buffer, fullRanges);
|
||||
SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
|
||||
auto *buffer =
|
||||
allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
|
||||
auto fullLocalView = view(
|
||||
MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
|
||||
auto partialLocalView = slice(fullLocalView, partialRanges);
|
||||
return PromotionInfo{buffer, fullLocalView, partialLocalView};
|
||||
}
|
||||
|
||||
SmallVector<PromotionInfo, 8>
|
||||
mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
|
||||
ArrayRef<Value *> subViews,
|
||||
ArrayRef<Value *> subViews, bool dynamicBuffers,
|
||||
OperationFolder *folder) {
|
||||
if (subViews.empty())
|
||||
return {};
|
||||
|
@ -127,7 +140,8 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
|
|||
// TODO(ntv): support more cases than just float.
|
||||
if (!viewType.getElementType().isa<FloatType>())
|
||||
continue;
|
||||
auto promotionInfo = promoteFullTileBuffer(b, loc, subView, folder);
|
||||
auto promotionInfo =
|
||||
promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder);
|
||||
promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
|
||||
res.push_back(promotionInfo);
|
||||
}
|
||||
|
@ -157,12 +171,13 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
|
|||
}
|
||||
|
||||
static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
|
||||
bool dynamicBuffers,
|
||||
OperationFolder *folder) {
|
||||
// 1. Promote the specified views and use them in the new op.
|
||||
OpBuilder b(op);
|
||||
ScopedContext scope(b, op.getLoc());
|
||||
auto promotedBufferAndViews =
|
||||
promoteSubViews(b, op.getLoc(), subViews.getArrayRef(), folder);
|
||||
auto promotedBufferAndViews = promoteSubViews(
|
||||
b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
|
||||
SmallVector<Value *, 8> opViews;
|
||||
opViews.reserve(op.getNumInputsAndOutputs());
|
||||
SmallVector<std::pair<Value *, Value *>, 8> writebackViews;
|
||||
|
@ -197,13 +212,13 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
|
|||
|
||||
// 4. Dealloc local buffers.
|
||||
for (const auto &pi : promotedBufferAndViews)
|
||||
buffer_dealloc(pi.buffer);
|
||||
dealloc(pi.buffer);
|
||||
}
|
||||
|
||||
static void promoteSubViews(FuncOp f) {
|
||||
static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
|
||||
SmallVector<LinalgOp, 8> toErase;
|
||||
OperationFolder folder(f.getContext());
|
||||
f.walk([&folder, &toErase](LinalgOp op) {
|
||||
f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
|
||||
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
|
||||
// nothing.
|
||||
SetVector<Value *> subViews;
|
||||
|
@ -211,7 +226,7 @@ static void promoteSubViews(FuncOp f) {
|
|||
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
|
||||
subViews.insert(sv);
|
||||
if (!subViews.empty()) {
|
||||
promoteSubViewOperands(op, subViews, &folder);
|
||||
promoteSubViewOperands(op, subViews, dynamicBuffers, &folder);
|
||||
toErase.push_back(op);
|
||||
}
|
||||
});
|
||||
|
@ -221,13 +236,23 @@ static void promoteSubViews(FuncOp f) {
|
|||
|
||||
namespace {
|
||||
struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> {
|
||||
void runOnFunction() override { promoteSubViews(getFunction()); }
|
||||
LinalgPromotionPass() = default;
|
||||
LinalgPromotionPass(bool dynamicBuffers) : dynamicBuffers(dynamicBuffers) {}
|
||||
|
||||
void runOnFunction() override {
|
||||
promoteSubViews(getFunction(), dynamicBuffers);
|
||||
}
|
||||
|
||||
bool dynamicBuffers;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgPromotionPass() {
|
||||
return std::make_unique<LinalgPromotionPass>();
|
||||
std::unique_ptr<OpPassBase<FuncOp>>
|
||||
mlir::linalg::createLinalgPromotionPass(bool dynamicBuffers) {
|
||||
return std::make_unique<LinalgPromotionPass>(dynamicBuffers);
|
||||
}
|
||||
|
||||
static PassRegistration<LinalgPromotionPass>
|
||||
pass("linalg-promote-subviews", "promote subview ops to local buffers");
|
||||
pass("linalg-promote-subviews", "promote subview ops to local buffers", [] {
|
||||
return std::make_unique<LinalgPromotionPass>(clPromoteDynamic);
|
||||
});
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,49 +1,5 @@
|
|||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
|
||||
|
||||
// -----
|
||||
|
||||
func @buffer_alloc_single_index() {
|
||||
// expected-error @+1 {{expected one index operand}}
|
||||
%0 = linalg.buffer_alloc : !linalg.buffer<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @buffer_alloc_unexpected_index(%s : index) {
|
||||
// expected-error @+1 {{expected zero operand}}
|
||||
%0 = linalg.buffer_alloc %s : !linalg.buffer<32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @buffer_alloc_nonegative_size() {
|
||||
// expected-error @+1 {{expected nonnegative static buffer size}}
|
||||
%0 = linalg.buffer_alloc : !linalg.buffer<0xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @buffer_alloc_nonegative_alignment(%arg0: index) {
|
||||
// expected-error @+1 {{expected positive alignment}}
|
||||
%0 = linalg.buffer_alloc %arg0 {alignment = -123}: !linalg.buffer<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @buffer_alloc_powerof2_alignment(%arg0: index) {
|
||||
// expected-error @+1 {{expected power of 2 alignment}}
|
||||
%0 = linalg.buffer_alloc %arg0 {alignment = 123}: !linalg.buffer<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @buffer_valid_element_type() {
|
||||
// expected-error @+1 {{expected valid buffer element type}}
|
||||
%0 = linalg.buffer_alloc : !linalg.buffer<4xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @load_number_of_indices(%v : memref<f32>) {
|
||||
// expected-error @+2 {{incorrect number of indices for load}}
|
||||
%c0 = constant 0 : index
|
||||
|
@ -99,22 +55,6 @@ func @transpose_bad_rank(%v : memref<?x?xf32, (i, j)[off, M]->(off + M * i + j)>
|
|||
|
||||
// -----
|
||||
|
||||
func @view_type(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
|
||||
// expected-error @+2 {{expected memref type}}
|
||||
%r = linalg.range %min:%max:%step : !linalg.range
|
||||
%0 = linalg.view %buf[%r]: !linalg.buffer<?xf32> -> index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @view_num_ranges(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
|
||||
// expected-error @+2 {{expected 2 ranges}}
|
||||
%r = linalg.range %min:%max:%step : !linalg.range
|
||||
%0 = linalg.view %buf[%r]: !linalg.buffer<?xf32> -> memref<?x?xf32, (i, j)[off, M]->(off + M * i + j)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @yield_parent(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
|
||||
// expected-error @+1 {{op expected 'linalg.generic' or 'linalg.indexed_generic' parent op}}
|
||||
linalg.yield %arg0: memref<?xf32, (i)[off]->(off + i)>
|
||||
|
@ -396,15 +336,5 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected single element in size list}}
|
||||
!invalid_type = type !linalg.buffer<1x1xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected '>'}}
|
||||
!invalid_type = type !linalg<"buffer<1xf32">
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected valid keyword}}
|
||||
!invalid_type = type !linalg<"?">
|
||||
|
|
|
@ -1,35 +1,6 @@
|
|||
// RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
|
||||
// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS
|
||||
|
||||
func @buffer_size(%arg0: !linalg.buffer<?xf32>) {
|
||||
%c1 = constant 1 : index
|
||||
%s = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
|
||||
%t = addi %s, %c1 : index
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @buffer_size
|
||||
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i8*, float*, i64 }">
|
||||
// CHECK-NEXT: llvm.add {{.*}}, {{.*}} : !llvm.i64
|
||||
|
||||
func @buffer_alloc_aligned(%arg0: index) {
|
||||
%s = linalg.buffer_alloc %arg0 {alignment=16} : !linalg.buffer<?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @buffer_alloc_aligned
|
||||
// CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
|
||||
// CHECK: %[[m:.*]] = llvm.mul %arg0, %[[c4]] : !llvm.i64
|
||||
// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: %[[c16:.*]] = llvm.mlir.constant(16 : index) : !llvm.i64
|
||||
// CHECK: %[[a:.*]] = llvm.add %[[m]], %[[c16]] : !llvm.i64
|
||||
// CHECK: %[[s:.*]] = llvm.sub %[[a]], %[[c1]] : !llvm.i64
|
||||
// CHECK: %[[alloc:.*]] = llvm.call @malloc(%[[s]]) : (!llvm.i64) -> !llvm<"i8*">
|
||||
// aligning `ptr` on `align` is done computing the address `ptr + (align - ptr % align) % align`.
|
||||
// CHECK: %[[cast:.*]] = llvm.ptrtoint %[[alloc]] : !llvm<"i8*"> to !llvm.i64
|
||||
// CHECK: %[[rem:.*]] = llvm.urem %[[cast]], %[[c16]] : !llvm.i64
|
||||
// CHECK: %[[drem:.*]] = llvm.sub %[[c16]], %[[rem]] : !llvm.i64
|
||||
// CHECK: %[[off:.*]] = llvm.urem %[[drem]], %[[c16]] : !llvm.i64
|
||||
// CHECK: llvm.getelementptr %{{.*}}[%[[off]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
|
||||
|
||||
func @range(%arg0: index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
|
@ -44,48 +15,11 @@ func @range(%arg0: index) {
|
|||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
||||
|
||||
func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
|
||||
%0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @view
|
||||
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i8*, float*, i64 }">
|
||||
// CHECK-NEXT: llvm.bitcast {{.*}} : !llvm<"float*"> to !llvm<"float*">
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.return
|
||||
|
||||
func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
|
||||
%0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.buffer<?xf32> -> memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @view3d
|
||||
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
|
||||
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
|
||||
|
||||
func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
|
||||
%0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
|
||||
%1 = linalg.slice %0[%arg1] : memref<?xf32, offset: ?, strides: [1]>, !linalg.range, memref<?xf32, offset: ?, strides: [1]>
|
||||
func @slice(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: !linalg.range) {
|
||||
%1 = linalg.slice %arg0[%arg1] : memref<?xf32, offset: ?, strides: [1]>, !linalg.range, memref<?xf32, offset: ?, strides: [1]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @slice
|
||||
// insert ptr for view op
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// insert data ptr for slice op
|
||||
// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
|
||||
|
@ -122,13 +56,6 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
|
|||
// CHECK-COUNT-3: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
|
||||
// CHECK-NEXT: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64 }*">) -> ()
|
||||
|
||||
func @dim(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||
%0 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dim(%{{.*}}: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">) {
|
||||
// CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
|
||||
|
||||
func @subview(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||
%c0 = constant 0 : index
|
||||
%0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
|
|
|
@ -12,22 +12,20 @@
|
|||
// CHECK-DAG: #[[Stride2Dilation4:.*]] = (d0, d1) -> (d0 * 2 + d1 * 4)
|
||||
// CHECK-DAG: #[[Stride3Dilation5:.*]] = (d0, d1) -> (d0 * 3 + d1 * 5)
|
||||
|
||||
func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
|
||||
|
||||
func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
|
||||
%J = linalg.range %c0:%arg2:%c1 : !linalg.range
|
||||
%K = linalg.range %c0:%arg3:%c1 : !linalg.range
|
||||
%A = linalg.view %arg0[%I, %K] : !linalg.buffer<?xf32> -> memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
%B = linalg.view %arg0[%K, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
%C = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
%A = view %arg0[%M, %K][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
%B = view %arg0[%K, %N][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
%C = view %arg0[%M, %N][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @matmul(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[N:.*]] = dim %[[B]], 1 : memref<?x?xf32, #[[strided2D]]>
|
||||
|
@ -41,21 +39,19 @@ func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: in
|
|||
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
|
||||
// CHECK: store %[[res]], %[[C]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
|
||||
|
||||
func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
|
||||
func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
|
||||
%J = linalg.range %c0:%arg2:%c1 : !linalg.range
|
||||
%2 = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
%3 = linalg.view %arg0[%J] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
|
||||
%4 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
|
||||
%2 = view %arg0[%M, %N][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
|
||||
%3 = view %arg0[%M][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
|
||||
%4 = view %arg0[%N][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
|
||||
linalg.matvec(%2, %3, %4) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @matvec(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
|
||||
// CHECK-LABEL: func @matvec(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index) {
|
||||
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
|
||||
// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
|
||||
// CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} {
|
||||
|
@ -67,20 +63,19 @@ func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: in
|
|||
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
|
||||
// CHECK: store %[[res]], %[[C]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
|
||||
|
||||
func @dot(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
|
||||
func @dot(%arg0: memref<?xi8>, %M: index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
|
||||
%1 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
|
||||
%2 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
|
||||
%3 = linalg.view %arg0[] : !linalg.buffer<?xf32> -> memref<f32>
|
||||
%1 = view %arg0[%M][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
|
||||
%2 = view %arg0[%M][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
|
||||
%3 = view %arg0[][] : memref<?xi8> to memref<f32>
|
||||
linalg.dot(%1, %2, %3) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dot(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
|
||||
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
|
||||
// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.buffer<?xf32> -> memref<f32>
|
||||
// CHECK-LABEL: func @dot(%{{.*}}: memref<?xi8>, %{{.*}}: index) {
|
||||
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
|
||||
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
|
||||
// CHECK: %[[C:.*]] = std.view %{{.*}}[][] : memref<?xi8> to memref<f32>
|
||||
// CHECK: %[[K:.*]] = dim %[[A]], 0 : memref<?xf32, #[[strided1D]]>
|
||||
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
|
||||
// CHECK-DAG: %[[a:.*]] = load %[[A]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// RUN: mlir-opt %s -linalg-promote-subviews | FileCheck %s
|
||||
// RUN: mlir-opt %s -linalg-promote-subviews -test-linalg-promote-dynamic | FileCheck %s --check-prefix=DYNAMIC
|
||||
|
||||
#map0 = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
|
||||
#map1 = (d0) -> (d0 + 2)
|
||||
|
@ -9,18 +10,15 @@
|
|||
// CHECK-DAG: #[[strided2DnoOffset:.*]] = (d0, d1)[s0] -> (d0 * s0 + d1)
|
||||
|
||||
module {
|
||||
func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
|
||||
func @matmul(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
|
||||
%c4 = constant 4 : index
|
||||
%c3 = constant 3 : index
|
||||
%c2 = constant 2 : index
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%0 = linalg.range %c0:%arg1:%c1 : !linalg.range
|
||||
%1 = linalg.range %c0:%arg2:%c1 : !linalg.range
|
||||
%2 = linalg.range %c0:%arg3:%c1 : !linalg.range
|
||||
%3 = linalg.view %arg0[%0, %2] : !linalg.buffer<?xf32> -> memref<?x?xf32, #map0>
|
||||
%4 = linalg.view %arg0[%2, %1] : !linalg.buffer<?xf32> -> memref<?x?xf32, #map0>
|
||||
%5 = linalg.view %arg0[%0, %1] : !linalg.buffer<?xf32> -> memref<?x?xf32, #map0>
|
||||
%3 = view %A[%M, %K][%c0] : memref<?xi8> to memref<?x?xf32, #map0>
|
||||
%4 = view %A[%K, %N][%c0] : memref<?xi8> to memref<?x?xf32, #map0>
|
||||
%5 = view %A[%M, %N][%c0] : memref<?xi8> to memref<?x?xf32, #map0>
|
||||
%6 = dim %3, 0 : memref<?x?xf32, #map0>
|
||||
%7 = dim %3, 1 : memref<?x?xf32, #map0>
|
||||
%8 = dim %4, 1 : memref<?x?xf32, #map0>
|
||||
|
@ -44,7 +42,7 @@ module {
|
|||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @matmul(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||
// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
|
||||
// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
|
||||
|
@ -52,16 +50,19 @@ module {
|
|||
// CHECK: %[[vB:.*]] = linalg.subview {{.*}} : memref<?x?xf32, #[[strided2D]]>
|
||||
// CHECK: %[[vC:.*]] = linalg.subview {{.*}} : memref<?x?xf32, #[[strided2D]]>
|
||||
///
|
||||
// CHECK: %[[tmpA:.*]] = linalg.buffer_alloc : !linalg.buffer<8xf32>
|
||||
// CHECK: %[[fullA:.*]] = linalg.view %[[tmpA]][{{.*}}] : !linalg.buffer<8xf32> -> memref<?x?xf32>
|
||||
// CHECK: %[[tmpA:.*]] = alloc() : memref<32xi8>
|
||||
// CHECK: %[[fullA:.*]] = std.view %[[tmpA]][{{.*}}][] : memref<32xi8> to memref<?x?xf32>
|
||||
// DYNAMIC: std.view %{{.*}}[{{.*}}][] : memref<?xi8> to memref<?x?xf32>
|
||||
// CHECK: %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
|
||||
///
|
||||
// CHECK: %[[tmpB:.*]] = linalg.buffer_alloc : !linalg.buffer<12xf32>
|
||||
// CHECK: %[[fullB:.*]] = linalg.view %[[tmpB]][{{.*}}] : !linalg.buffer<12xf32> -> memref<?x?xf32>
|
||||
// CHECK: %[[tmpB:.*]] = alloc() : memref<48xi8>
|
||||
// CHECK: %[[fullB:.*]] = std.view %[[tmpB]][{{.*}}][] : memref<48xi8> to memref<?x?xf32>
|
||||
// DYNAMIC: std.view %{{.*}}[{{.*}}][] : memref<?xi8> to memref<?x?xf32>
|
||||
// CHECK: %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
|
||||
///
|
||||
// CHECK: %[[tmpC:.*]] = linalg.buffer_alloc : !linalg.buffer<6xf32>
|
||||
// CHECK: %[[fullC:.*]] = linalg.view %[[tmpC]][{{.*}}] : !linalg.buffer<6xf32> -> memref<?x?xf32>
|
||||
// CHECK: %[[tmpC:.*]] = alloc() : memref<24xi8>
|
||||
// CHECK: %[[fullC:.*]] = std.view %[[tmpC]][{{.*}}][] : memref<24xi8> to memref<?x?xf32>
|
||||
// DYNAMIC: std.view %{{.*}}[{{.*}}][] : memref<?xi8> to memref<?x?xf32>
|
||||
// CHECK: %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
|
||||
|
||||
// CHECK: linalg.fill(%[[fullA]], {{.*}}) : memref<?x?xf32>, f32
|
||||
|
@ -75,6 +76,6 @@ module {
//
// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2DnoOffset]]>, memref<?x?xf32, #[[strided2D]]>
//
// CHECK: linalg.buffer_dealloc %[[tmpA]] : !linalg.buffer<8xf32>
// CHECK: linalg.buffer_dealloc %[[tmpB]] : !linalg.buffer<12xf32>
// CHECK: linalg.buffer_dealloc %[[tmpC]] : !linalg.buffer<6xf32>
// CHECK: dealloc %[[tmpA]] : memref<32xi8>
// CHECK: dealloc %[[tmpB]] : memref<48xi8>
// CHECK: dealloc %[[tmpC]] : memref<24xi8>

@ -7,7 +7,6 @@

// CHECK-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0)
// CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
// CHECK-DAG: #[[strided2D42by1SymbolicOffset:.*]] = (d0, d1)[s0] -> (d0 * 42 + s0 + d1)
// CHECK-DAG: #[[strided3D:.*]] = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)
// CHECK-DAG: #[[strided6D:.*]] = (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)

@ -21,60 +20,31 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range

func @buffer_size(%arg0: !linalg.buffer<?xf32>) -> index {
%0 = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
return %0 : index
}
// CHECK-LABEL: func @buffer_size
// CHECK: linalg.buffer_size {{.*}} : !linalg.buffer<?xf32>

func @buffer(%arg0: index, %arg1: index) {
%0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
%2 = linalg.buffer_alloc %0 {alignment = 16} : !linalg.buffer<?xvector<4xi8>>
%3 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
%4 = linalg.buffer_alloc {alignment = 32} : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %4 : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %3 : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %2 : !linalg.buffer<?xvector<4xi8>>
linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
return
}
// CHECK-LABEL: func @buffer(%{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} {alignment = 16 : i64} : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_alloc {alignment = 32 : i64} : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xvector<4xi8>>

func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%c0 = constant 0 : index
%0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
%3 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> memref<?x?xf32, offset: ?, strides: [?, 1]>
%1 = alloc (%0) : memref<?xi8>
%2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
%3 = view %1[%arg0, %arg0][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
%4 = linalg.slice %3[%2, %2] : memref<?x?xf32, offset: ?, strides: [?, 1]>, !linalg.range, !linalg.range, memref<?x?xf32, offset: ?, strides: [?, 1]>
%5 = linalg.slice %3[%2, %arg2] : memref<?x?xf32, offset: ?, strides: [?, 1]>, !linalg.range, index, memref<?xf32, offset: ?, strides: [1]>
%6 = linalg.slice %3[%arg2, %2] : memref<?x?xf32, offset: ?, strides: [?, 1]>, index, !linalg.range, memref<?xf32, offset: ?, strides: [1]>
%7 = linalg.slice %3[%arg2, %arg3] : memref<?x?xf32, offset: ?, strides: [?, 1]>, index, index, memref<f32>
%8 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> memref<?x?xvector<4x4xf32>, offset: ?, strides: [?, 1]>
linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
%8 = view %1[%arg0, %arg0][%c0] : memref<?xi8> to memref<?x?xvector<4x4xf32>, offset: ?, strides: [?, 1]>
dealloc %1 : memref<?xi8>
return
}
// CHECK-LABEL: func @views(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xf32>
// CHECK-NEXT: linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
// CHECK-NEXT: linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
// CHECK: muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: alloc(%{{.*}}) : memref<?xi8>
// CHECK-NEXT: range
// CHECK-NEXT: std.view %{{.*}}[%{{.*}}][%{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2D]]>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, !linalg.range, index, memref<?xf32, #[[strided1D]]>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, index, !linalg.range, memref<?xf32, #[[strided1D]]>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, index, index, memref<f32>
// CHECK-NEXT: linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xvector<4x4xf32>, #[[strided2D]]>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>
// CHECK-NEXT: view %{{.*}}[%{{.*}}][%{{.*}}] : memref<?xi8> to memref<?x?xvector<4x4xf32>, #[[strided2D]]>
// CHECK-NEXT: dealloc %{{.*}} : memref<?xi8>

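Taken together, the added lines above swap the linalg.buffer allocation and linalg.view for a raw i8 alloc plus a std.view. As a rough aid to reading the interleaved old/new lines, the @views test after this change would look approximately like the sketch below (reassembled from the added lines only, with the four unchanged linalg.slice ops elided); this is an editorial reconstruction, not a verbatim copy of the updated file:

func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
  %c0 = constant 0 : index
  %0 = muli %arg0, %arg0 : index
  %1 = alloc (%0) : memref<?xi8>
  %2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
  %3 = view %1[%arg0, %arg0][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
  // ... the four linalg.slice ops on %3 are unchanged ...
  %8 = view %1[%arg0, %arg0][%c0] : memref<?xi8> to memref<?x?xvector<4x4xf32>, offset: ?, strides: [?, 1]>
  dealloc %1 : memref<?xi8>
  return
}
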
func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<?xf32, offset: ?, strides: [1]>, %arg3: memref<f32>) {
linalg.matmul(%arg0, %arg0, %arg0) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>

@ -88,41 +58,6 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?xf3
// CHECK-NEXT: linalg.matvec(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[strided2D]]>, memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>
// CHECK-NEXT: linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>, memref<f32>

func @dim(%arg0: memref<?x?xf32, offset: ?, strides: [42, 1]>) {
%0 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [42, 1]>
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
return
}
// CHECK-LABEL: func @dim(
// CHECK: %{{.*}}: memref<?x?xf32, #[[strided2D42by1SymbolicOffset]]>) {
// CHECK-NEXT: dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D42by1SymbolicOffset]]>
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xf32>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>

func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) {
loop.for %i0 = %arg0 to %arg1 step %arg2 {
loop.for %i1 = %arg0 to %arg1 step %arg2 {
%min_cmp = cmpi "slt", %i0, %i1 : index
%min = select %min_cmp, %i0, %i1 : index
%max_cmp = cmpi "sge", %i0, %i1 : index
%max = select %max_cmp, %i0, %i1 : index
loop.for %i2 = %min to %max step %i1 {
}
}
}
return
}
// CHECK-LABEL: func @linalg_for(
// CHECK: %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: cmpi "slt", %{{.*}}, %{{.*}} : index
// CHECK-NEXT: select %{{.*}}, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: cmpi "sge", %{{.*}}, %{{.*}} : index
// CHECK-NEXT: select %{{.*}}, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} {

func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
linalg.fill(%arg0, %arg1) : memref<?xf32, offset: ?, strides: [1]>, f32
return

@ -190,13 +125,6 @@ func @subview(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>) {
// CHECK: constant 0 : index
// CHECK: linalg.subview %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?xvector<3x4xi4>, #[[strided2D]]>

func @const_buffer_view(%arg0: index, %arg1: index, %arg2: index) {
%c0 = linalg.buffer_alloc : !linalg.buffer<17xf32>
%c1 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
%c2 = linalg.view %c0[%c1] : !linalg.buffer<17xf32> -> memref<?xf32, offset: ?, strides: [1]>
return
}

#accesses = [
(i, j, k) -> (j, i),
(i, j, k) -> (i, k, i + j)

@ -5,18 +5,19 @@
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -linalg-lower-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s

#strided1D = (d0)[s0] -> (d0 + s0)
#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
#strided1D = (d0) -> (d0)
#strided2D = (d0, d1)[s0] -> (d0 * s0 + d1)

// Creates and returns a 1-D buffer of size %s filled with the value %f
func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<?xf32> {
func @alloc_filled_f32(%s : index, %f : f32) -> memref<?xi8> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%buf = linalg.buffer_alloc %s {alignment = 256} : !linalg.buffer<?xf32>
%R = linalg.range %c0:%s:%c1 : !linalg.range
%V = linalg.view %buf[%R] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
%c4 = constant 4 : index
%s4 = muli %s, %c4: index
%buf = alloc(%s4) {alignment = 256} : memref<?xi8>
%V = view %buf[%s][] : memref<?xi8> to memref<?xf32, #strided1D>
linalg.fill(%V, %f) : memref<?xf32, #strided1D>, f32
return %buf : !linalg.buffer<?xf32>
return %buf : memref<?xi8>
}

// Test for linalg.dot.
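In other words, the helper no longer hands back a !linalg.buffer; it allocates a raw i8 memref (four bytes per f32 element), lays a typed view over it to fill it, and returns the raw buffer. Reassembled from the added lines above (with the untouched %c0/%c1 context constants omitted), the post-change helper reads roughly as follows; this is an editorial sketch, not a verbatim copy of the updated test:

func @alloc_filled_f32(%s : index, %f : f32) -> memref<?xi8> {
  %c4 = constant 4 : index
  %s4 = muli %s, %c4 : index
  %buf = alloc(%s4) {alignment = 256} : memref<?xi8>
  %V = view %buf[%s][] : memref<?xi8> to memref<?xf32, #strided1D>
  linalg.fill(%V, %f) : memref<?xf32, #strided1D>, f32
  return %buf : memref<?xi8>
}
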
@ -28,21 +29,20 @@ func @dot() -> f32 {
%f1 = constant 1.00000e+00 : f32
%f2 = constant 2.00000e+00 : f32

%bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer<?xf32>)
%bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer<?xf32>)
%bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer<?xf32>)
%bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (memref<?xi8>)
%bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (memref<?xi8>)
%bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (memref<?xi8>)

%R = linalg.range %c0:%c16:%c1 : !linalg.range
%A = linalg.view %bA[%R] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
%B = linalg.view %bB[%R] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
%C = linalg.view %bC[] : !linalg.buffer<?xf32> -> memref<f32>
%A = view %bA[%c16][] : memref<?xi8> to memref<?xf32, #strided1D>
%B = view %bB[%c16][] : memref<?xi8> to memref<?xf32, #strided1D>
%C = view %bC[][] : memref<?xi8> to memref<f32>

linalg.dot(%A, %B, %C) : memref<?xf32, #strided1D>, memref<?xf32, #strided1D>, memref<f32>
%res = load %C[] : memref<f32>

linalg.buffer_dealloc %bC : !linalg.buffer<?xf32>
linalg.buffer_dealloc %bB : !linalg.buffer<?xf32>
linalg.buffer_dealloc %bA : !linalg.buffer<?xf32>
dealloc %bC : memref<?xi8>
dealloc %bB : memref<?xi8>
dealloc %bA : memref<?xi8>

return %res : f32
}
@ -61,23 +61,20 @@ func @matmul() -> f32 {
%f2 = constant 2.00000e+00 : f32
%f10 = constant 10.00000e+00 : f32

%bA = call @alloc_filled_f32(%c160, %f2) : (index, f32) -> (!linalg.buffer<?xf32>)
%bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (!linalg.buffer<?xf32>)
%bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (!linalg.buffer<?xf32>)
%bA = call @alloc_filled_f32(%c160, %f2) : (index, f32) -> (memref<?xi8>)
%bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (memref<?xi8>)
%bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (memref<?xi8>)

%M = linalg.range %c0:%c10:%c1 : !linalg.range
%N = linalg.range %c0:%c10:%c1 : !linalg.range
%K = linalg.range %c0:%c16:%c1 : !linalg.range
%A = linalg.view %bA[%M, %K] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
%B = linalg.view %bB[%K, %N] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
%C = linalg.view %bC[%M, %N] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
%A = view %bA[%c10, %c16][] : memref<?xi8> to memref<?x?xf32, #strided2D>
%B = view %bB[%c16, %c10][] : memref<?xi8> to memref<?x?xf32, #strided2D>
%C = view %bC[%c10, %c10][] : memref<?xi8> to memref<?x?xf32, #strided2D>

linalg.matmul(%A, %B, %C) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
%res = load %C[%c6, %c7] : memref<?x?xf32, #strided2D>

linalg.buffer_dealloc %bC : !linalg.buffer<?xf32>
linalg.buffer_dealloc %bB : !linalg.buffer<?xf32>
linalg.buffer_dealloc %bA : !linalg.buffer<?xf32>
dealloc %bC : memref<?xi8>
dealloc %bB : memref<?xi8>
dealloc %bA : memref<?xi8>

return %res : f32
}