Update Linalg to use std.view

Now that a view op has graduated to the std dialect, we can update Linalg to use it and remove ops that have become obsolete. As a byproduct, the linalg.buffer type and its associated ops can also disappear.
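
Concretely, code that used to allocate a 1-D !linalg.buffer and lay a linalg.view on top of it now allocates a plain byte memref and shapes it with std.view (linalg.buffer_dealloc likewise becomes a std dealloc). A minimal sketch of the rewrite, modeled on the test updates below (%M, %c0, %c1 and %bytes are placeholder SSA names, with %bytes holding %M times the element size in bytes):

  // Before: dedicated buffer type plus linalg.range/linalg.view.
  %buf = linalg.buffer_alloc %M : !linalg.buffer<?xf32>
  %I   = linalg.range %c0:%M:%c1 : !linalg.range
  %v   = linalg.view %buf[%I] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>

  // After: plain byte memref plus std.view (dynamic size %M, dynamic offset %c0).
  %buf = alloc(%bytes) : memref<?xi8>
  %v   = view %buf[%M][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>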

PiperOrigin-RevId: 279073591
Nicolas Vasilache 2019-11-07 06:32:39 -08:00 committed by A. Unique TensorFlower
parent eee9cbdeb7
commit 72040bf7c8
17 changed files with 149 additions and 853 deletions

View File

@ -44,108 +44,6 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def BufferAllocOp :
Linalg_Op<"buffer_alloc">,
Arguments<(ins Variadic<Index>:$size, OptionalAttr<I64Attr>:$alignment)>,
Results<(outs Buffer)> {
let summary = "buffer allocation operation";
let description = [{
The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
upon which a base view can be laid out to give it indexing semantics.
"buffer_alloc" takes a single argument, the size of the buffer to allocate
(in number of elements).
An optional alignment attribute may be specified in which case the actual
underlying allocation size may be increased. The base pointer is guaranteed
to be a multiple of `alignment`. Such an alignment must be a positive power
of 2.
Examples:
%0 = linalg.buffer_alloc(%arg0) : !linalg.buffer<?xf32>
%1 = linalg.buffer_alloc(%arg0) { alignment = 16 } :
!linalg.buffer<?xf32>
The size argument may be omitted if it is statically known, in which case it
must be reflected in the type.
Example:
%0 = linalg.buffer_alloc() : !linalg.buffer<4xf32>
}];
let builders = [
OpBuilder<
"Builder *b, OperationState &result, BufferType bufferType", [{
result.addTypes(bufferType);
}]>,
OpBuilder<
"Builder *b, OperationState &result, BufferType bufferType, "
"unsigned alignment", [{
build(b, result, bufferType);
if (alignment != 0)
result.addAttribute(BufferAllocOp::getAlignmentAttrName(),
b->getI64IntegerAttr(alignment));
}]>,
OpBuilder<
"Builder *b, OperationState &result, BufferType bufferType, "
"Value *size, unsigned alignment", [{
if (alignment == 0)
return build(b, result, bufferType, size);
build(b, result, bufferType, size, b->getI64IntegerAttr(alignment));
}]>,
OpBuilder<
"Builder *b, OperationState &result, BufferType bufferType, Value *size",
[{
result.addOperands(size);
result.addTypes(bufferType);
}]>
];
let extraClassDeclaration = [{
static StringRef getAlignmentAttrName() { return "alignment"; }
BufferType getBufferType() { return getType().cast<BufferType>(); }
Type getElementType() { return getBufferType().getElementType(); }
}];
}
def BufferDeallocOp :
Linalg_Op<"buffer_dealloc">,
Arguments<(ins Buffer:$buffer)>,
Results<(outs)> {
let summary = "buffer allocation operation";
let description = [{
The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
Example:
linalg.buffer_dealloc %0 : !linalg.buffer<f32>
}];
let extraClassDeclaration = [{
BufferType getBufferType() {
return buffer()->getType().cast<BufferType>();
}
}];
// Fully specified by traits.
let verifier = ?;
}
def BufferSizeOp :
Linalg_Op<"buffer_size", [NoSideEffect]>,
Arguments<(ins Buffer:$buffer)>,
Results<(outs Index)> {
let summary = "buffer size operation";
let description = [{
The "linalg.buffer_size" operation takes a linalg.buffer and returns an
"index".
Example:
%0 = linalg.buffer_size %arg0 : !linalg.buffer<f32>
}];
// Fully specified by traits.
let verifier = ?;
}
def RangeOp :
Linalg_Op<"range", [NoSideEffect]>,
Arguments<(ins Index:$min, Index:$max, Index:$step)>,
@ -329,51 +227,6 @@ def TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
}];
}
def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
Results<(outs AnyStridedMemRef)> {
let summary = "view operation";
let description = [{
The "linalg.view" op produces a strided memref which is a multi-dimensional
range abstraction on top of an underlying linalg.buffer. This gives an
indexing structure to an otherwise non-indexable linalg.buffer.
A "linalg.view" takes a buffer and a variadic number of ranges and produces
a `view` of rank the number of ranges. The elemental type may not match the
buffer element type:
Example:
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
%3 = linalg.view %1[%2, %2] :
memref<?x?xvector<4xf32>, stride_specification>
}];
let builders = [OpBuilder<
"Builder *b, OperationState &result, Value *buffer, "
"ArrayRef<Value *> ranges, Type resultType = Type(), "
"ArrayRef<NamedAttribute> attrs = {}">];
let verifier = [{
if (getViewType().getRank() != llvm::size(ranges()))
return emitOpError("the view rank must be the number of its ranges");
return success();
}];
let extraClassDeclaration = [{
enum { FirstIndexingOperand = 1 };
unsigned getRank() { return getViewType().getRank(); }
Type getElementType() { return getViewType().getElementType(); }
MemRefType getViewType() { return getType().cast<MemRefType>(); }
/// Get the underlying indexing at a given rank.
Value *getRange(unsigned rank) {
assert(rank < getRank() && "rank overflow");
return *(ranges().begin() + rank);
}
}];
}
def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "Linalg yield operation";

View File

@ -36,7 +36,8 @@ std::unique_ptr<OpPassBase<FuncOp>> createLinalgFusionPass();
std::unique_ptr<OpPassBase<FuncOp>>
createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {});
std::unique_ptr<OpPassBase<FuncOp>> createLinalgPromotionPass();
std::unique_ptr<OpPassBase<FuncOp>>
createLinalgPromotionPass(bool dynamicBuffers);
std::unique_ptr<OpPassBase<FuncOp>> createLowerLinalgToLoopsPass();

View File

@ -22,22 +22,15 @@
namespace mlir {
namespace linalg {
class BufferAllocOp;
class BufferDeallocOp;
class CopyOp;
class FillOp;
class RangeOp;
class SliceOp;
class ViewOp;
namespace intrinsics {
using buffer_alloc = mlir::edsc::intrinsics::ValueBuilder<BufferAllocOp>;
using buffer_dealloc =
mlir::edsc::intrinsics::OperationBuilder<BufferDeallocOp>;
using copy = mlir::edsc::intrinsics::OperationBuilder<CopyOp>;
using fill = mlir::edsc::intrinsics::OperationBuilder<FillOp>;
using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>;
using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>;
using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>;
} // namespace intrinsics
} // namespace linalg
} // namespace mlir

View File

@ -173,9 +173,10 @@ struct PromotionInfo {
///
/// Returns a list of PromotionInfo which hold the promoted buffer and the
/// full and partial views indexing into the buffer.
llvm::SmallVector<PromotionInfo, 8> promoteSubViews(OpBuilder &b, Location loc,
ArrayRef<Value *> subViews,
OperationFolder *folder);
llvm::SmallVector<PromotionInfo, 8>
promoteSubViews(OpBuilder &b, Location loc, ArrayRef<Value *> subViews,
bool dynamicBuffers = false,
OperationFolder *folder = nullptr);
/// Returns all the operands of `linalgOp` that are not views.
/// Asserts that these operands are value types to allow transformations like

View File

@ -216,6 +216,7 @@ using std_load = ValueBuilder<LoadOp>;
using std_store = OperationBuilder<StoreOp>;
using subi = ValueBuilder<SubIOp>;
using vector_type_cast = ValueBuilder<vector::VectorTypeCastOp>;
using view = ValueBuilder<ViewOp>;
/// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`.
///

View File

@ -1430,10 +1430,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
if (failed(successStrides))
return op->emitWarning("Cannot cast to non-strided shape"),
matchFailure();
if (strides.back() != 1)
return op->emitWarning("Cannot cast to non-contiguous shape"),
return op->emitWarning("cannot cast to non-strided shape"),
matchFailure();
// Create the descriptor.
@ -1466,7 +1463,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
rewriter.getI64ArrayAttr(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
// Early exit for 0-D corner case.
if (viewMemRefType.getRank() == 0)
return rewriter.replaceOp(op, desc), matchSuccess();
// Update sizes and strides.
if (strides.back() != 1)
return op->emitWarning("cannot cast to non-contiguous shape"),
matchFailure();
Value *stride = nullptr, *nextSize = nullptr;
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
// Update size.
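
The reordering above is what lets rank-0 views lower: a 0-D memref has no strides, so the contiguity check on the innermost stride only makes sense once the 0-D early exit has been taken. A sketch of the kind of rank-0 view this enables (buffer name assumed), as exercised by the Linalg tests further down:

  // Rank-0 view: no sizes and no dynamic offset; the lowering only needs to
  // set the data pointer and a zero offset in the descriptor.
  %f = view %buf[][] : memref<?xi8> to memref<f32>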

View File

@ -74,7 +74,7 @@ Value *Aliases::find(Value *v) {
return it.first->second;
}
if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
auto it = aliases.insert(std::make_pair(v, view.buffer()));
auto it = aliases.insert(std::make_pair(v, view.source()));
return it.first->second;
}
if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {

View File

@ -50,102 +50,6 @@ using namespace mlir::linalg;
// LinalgOps.td), we define an overloaded `print` function and a
// parse`className` function.
//===----------------------------------------------------------------------===//
// BufferAllocOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, BufferAllocOp op) {
p << op.getOperationName() << " ";
if (!llvm::empty(op.size()))
p << *op.getOperand(0);
if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0)
p.printOptionalAttrDict(op.getAttrs());
else
p.printOptionalAttrDict(op.getAttrs(),
BufferAllocOp::getAlignmentAttrName());
p << " : " << op.getBufferType();
}
static ParseResult parseBufferAllocOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 1> sizeInfo;
BufferType bufferType;
auto indexTy = parser.getBuilder().getIndexType();
if (parser.parseOperandList(sizeInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(bufferType))
return failure();
if (sizeInfo.empty())
return parser.addTypeToList(bufferType, result.types);
return failure(parser.resolveOperands(sizeInfo, indexTy, result.operands) ||
parser.addTypeToList(bufferType, result.types));
}
static LogicalResult verify(BufferAllocOp op) {
if (!op.getBufferType().hasConstantSize()) {
if (llvm::size(op.size()) != 1)
return op.emitOpError("expected one index operand");
} else { // op.getBufferType().hasConstantSize()
if (!llvm::empty(op.size()))
return op.emitOpError("expected zero operand");
if (op.getBufferType().getBufferSize().getValue() <= 0)
return op.emitOpError("expected nonnegative static buffer size");
}
if (op.alignment().hasValue()) {
auto align = op.alignment().getValue();
if (align.getSExtValue() < 0)
return op.emitOpError("expected positive alignment");
if (!llvm::isPowerOf2_64(align.getZExtValue()))
return op.emitOpError("expected power of 2 alignment");
}
if (!TensorType::isValidElementType(op.getElementType()))
return op.emitOpError("expected valid buffer element type");
return success();
}
//===----------------------------------------------------------------------===//
// BufferDeallocOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, BufferDeallocOp op) {
p << op.getOperationName() << " " << *op.buffer();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getBufferType();
}
static ParseResult parseBufferDeallocOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType bufferInfo;
BufferType bufferType;
if (parser.parseOperand(bufferInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(bufferType))
return failure();
return parser.resolveOperands(bufferInfo, bufferType, result.operands);
}
//===----------------------------------------------------------------------===//
// BufferSizeOp
//===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, BufferSizeOp op) {
p << op.getOperationName() << " " << *op.buffer();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.buffer()->getType();
}
static ParseResult parseBufferSizeOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType op;
Type type;
return failure(
parser.parseOperand(op) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(op, type, result.operands) ||
parser.addTypeToList(parser.getBuilder().getIndexType(), result.types));
}
//===----------------------------------------------------------------------===//
// GenericOps
//===----------------------------------------------------------------------===//
@ -426,7 +330,7 @@ static LogicalResult verify(SliceOp op) {
unsigned rank = op.getBaseViewRank();
if (rank != llvm::size(op.indexings()))
return op.emitOpError("expected ")
<< op.getRank() << " indexings, got " << llvm::size(op.indexings());
<< rank << " indexings, got " << llvm::size(op.indexings());
unsigned index = 0;
for (auto indexing : op.indexings()) {
if (indexing->getType().isa<IndexType>())
@ -562,59 +466,6 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
parser.addTypeToList(type, result.types));
}
//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//
void mlir::linalg::ViewOp::build(Builder *b, OperationState &result,
Value *buffer, ArrayRef<Value *> ranges,
Type resultType,
ArrayRef<NamedAttribute> attrs) {
// If the result type is not specified, assume sizes are fully dynamic.
// Strides are set to match an empty layout map which means "contiguous view".
if (!resultType) {
auto rank = ranges.size();
SmallVector<int64_t, 4> sizes(rank, -1);
Type elementType = buffer->getType().cast<BufferType>().getElementType();
resultType = MemRefType::get(sizes, elementType, {}, 0);
}
build(b, result, resultType, buffer, ranges);
result.addAttributes(attrs);
}
static void print(OpAsmPrinter &p, mlir::linalg::ViewOp op) {
p << op.getOperationName() << " " << *op.buffer() << "[";
interleaveComma(op.ranges(), p, [&](Value *v) { p << *v; });
p << "] ";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.buffer()->getType() << " -> " << op.getType();
}
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> rangesInfo;
Type bType, vType;
if (parser.parseOperand(bufferInfo) ||
parser.parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseType(bType) || parser.parseArrow() ||
parser.parseType(vType)) {
return failure();
}
MemRefType memRefType = vType.dyn_cast<MemRefType>();
if (!memRefType)
return parser.emitError(parser.getNameLoc(), "expected memref type");
if (static_cast<unsigned>(memRefType.getRank()) != rangesInfo.size())
return parser.emitError(parser.getNameLoc(), "expected ")
<< memRefType.getRank() << " ranges";
return failure(
parser.resolveOperand(bufferInfo, bType, result.operands) ||
(!rangesInfo.empty() &&
parser.resolveOperands(rangesInfo, RangeType::get(vType.getContext()),
result.operands)) ||
parser.addTypeToList(memRefType, result.types));
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

View File

@ -166,144 +166,6 @@ public:
};
} // namespace
// BufferAllocOp creates a new `!linalg.buffer` value.
class BufferAllocOpConversion : public LLVMOpLowering {
public:
explicit BufferAllocOpConversion(MLIRContext *context,
LLVMTypeConverter &lowering_)
: LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto indexType = IndexType::get(op->getContext());
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
.cast<LLVM::LLVMType>();
// Insert the `malloc` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
auto mallocFunc = module.lookupSymbol<LLVMFuncOp>("malloc");
if (!mallocFunc) {
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
mallocFunc = moduleBuilder.create<LLVMFuncOp>(
rewriter.getUnknownLoc(), "malloc",
LLVM::LLVMType::getFunctionTy(voidPtrTy, int64Ty,
/*isVarArg=*/false));
}
// Get MLIR types for injecting element pointer.
auto allocOp = cast<BufferAllocOp>(op);
auto elementType = allocOp.getElementType();
uint64_t elementSize = 0;
if (auto vectorType = elementType.dyn_cast<VectorType>())
elementSize = vectorType.getNumElements() *
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
else
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
auto bufferType = allocOp.getBufferType();
auto elementPtrType = getPtrToElementType(bufferType, lowering);
auto bufferDescriptorTy = convertLinalgType(bufferType, lowering);
// Emit IR for creating a new buffer descriptor with an underlying malloc.
edsc::ScopedContext context(rewriter, op->getLoc());
auto constantSize = bufferType.getBufferSize();
Value *size =
constantSize
? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
.getValue()
: operands[0];
Value *allocSize =
mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
Value *one = nullptr, *align = nullptr;
if (allocOp.alignment().hasValue()) {
one = constant(int64Ty, IntegerAttr::get(indexType, 1));
align =
constant(int64Ty, rewriter.getIntegerAttr(
rewriter.getIndexType(),
allocOp.alignment().getValue().getSExtValue()));
allocSize = sub(add(allocSize, align), one);
}
Value *allocated =
llvm_call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
.getOperation()
->getResult(0);
Value *data = allocated;
if (allocOp.alignment().hasValue()) {
// offset = (align - (ptr % align)) % align
Value *offset =
urem(sub(align, urem(ptrtoint(int64Ty, allocated), align)), align);
data = gep(voidPtrTy, allocated, offset);
}
data = bitcast(elementPtrType, data);
Value *desc = llvm_undef(bufferDescriptorTy);
desc = insertvalue(bufferDescriptorTy, desc, allocated,
rewriter.getI64ArrayAttr(kBasePtrPosInBuffer));
desc = insertvalue(bufferDescriptorTy, desc, data,
rewriter.getI64ArrayAttr(kPtrPosInBuffer));
desc = insertvalue(bufferDescriptorTy, desc, size,
rewriter.getI64ArrayAttr(kSizePosInBuffer));
rewriter.replaceOp(op, desc);
return matchSuccess();
}
};
// BufferDeallocOp creates no value.
class BufferDeallocOpConversion : public LLVMOpLowering {
public:
explicit BufferDeallocOpConversion(MLIRContext *context,
LLVMTypeConverter &lowering_)
: LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
lowering_) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto voidTy = LLVM::LLVMType::getVoidTy(lowering.getDialect());
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
auto module = op->getParentOfType<ModuleOp>();
auto freeFunc = module.lookupSymbol<LLVMFuncOp>("free");
if (!freeFunc) {
OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
freeFunc = moduleBuilder.create<LLVMFuncOp>(
rewriter.getUnknownLoc(), "free",
LLVM::LLVMType::getFunctionTy(voidTy, voidPtrTy,
/*isVarArg=*/false));
}
// Emit MLIR for buffer_dealloc.
BufferDeallocOpOperandAdaptor adaptor(operands);
edsc::ScopedContext context(rewriter, op->getLoc());
Value *base = extractvalue(voidPtrTy, adaptor.buffer(),
rewriter.getI64ArrayAttr(kBasePtrPosInBuffer));
llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), base);
rewriter.eraseOp(op);
return matchSuccess();
}
};
// BufferSizeOp creates a new `index` value.
class BufferSizeOpConversion : public LLVMOpLowering {
public:
BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
edsc::ScopedContext context(rewriter, op->getLoc());
BufferSizeOpOperandAdaptor adaptor(operands);
rewriter.replaceOp(
op, {extractvalue(int64Ty, adaptor.buffer(),
rewriter.getI64ArrayAttr(kSizePosInBuffer))});
return matchSuccess();
}
};
// RangeOp creates a new range descriptor.
class RangeOpConversion : public LLVMOpLowering {
public:
@ -480,78 +342,6 @@ public:
}
};
/// Conversion pattern that transforms a linalg.view op into:
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
/// and stride.
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The linalg.view op is replaced by the alloca'ed pointer.
class ViewOpConversion : public LLVMOpLowering {
public:
explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: LLVMOpLowering(mlir::linalg::ViewOp::getOperationName(), context,
lowering_) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
mlir::linalg::ViewOpOperandAdaptor adaptor(operands);
auto viewOp = cast<mlir::linalg::ViewOp>(op);
BaseViewConversionHelper helper(op->getLoc(), viewOp.getViewType(),
rewriter, lowering);
LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
Value *desc = helper.desc;
Value *bufferDescriptor = adaptor.buffer();
auto bufferTy = getPtrToElementType(
viewOp.buffer()->getType().cast<BufferType>(), lowering);
edsc::ScopedContext context(rewriter, op->getLoc());
// Copy the buffer pointer from the old descriptor to the new one.
Value *bufferAsViewElementType =
bitcast(elementTy, extractvalue(bufferTy, bufferDescriptor,
helper.pos(kPtrPosInBuffer)));
desc =
insertvalue(desc, bufferAsViewElementType, helper.pos(kPtrPosInView));
// Zero base offset.
auto indexTy = rewriter.getIndexType();
Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView));
// Corner case, no sizes or stride: early return the descriptor.
if (helper.zeroDMemRef) {
rewriter.replaceOp(op, desc);
return matchSuccess();
}
// Compute and insert view sizes (max - min along the range).
int numRanges = llvm::size(viewOp.ranges());
Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
for (int i = numRanges - 1; i >= 0; --i) {
// Update stride.
Value *rangeDescriptor = operands[1 + i];
Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
Value *stride = mul(runningStride, step);
desc = insertvalue(desc, stride, helper.pos({kStridePosInView, i}));
// Update size.
Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
Value *size = sub(max, min);
desc = insertvalue(desc, size, helper.pos({kSizePosInView, i}));
// Update stride for the next dimension.
if (i > 0)
runningStride = mul(runningStride, max);
}
rewriter.replaceOp(op, desc);
return matchSuccess();
}
};
// YieldOp produces an LLVM::ReturnOp.
class YieldOpConversion : public LLVMOpLowering {
public:
@ -731,10 +521,8 @@ static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
OwningRewritePatternList &patterns,
MLIRContext *ctx) {
patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
BufferSizeOpConversion, RangeOpConversion, SliceOpConversion,
TransposeOpConversion, ViewOpConversion, YieldOpConversion>(
ctx, converter);
patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion,
YieldOpConversion>(ctx, converter);
}
namespace {

View File

@ -49,17 +49,26 @@ using llvm::SetVector;
#define DEBUG_TYPE "linalg-promotion"
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
static llvm::cl::opt<bool> clPromoteDynamic(
"test-linalg-promote-dynamic",
llvm::cl::desc("Test generation of dynamic promoted buffers"),
llvm::cl::cat(clOptionsCategory), llvm::cl::init(false));
static AffineMap getAffineDifferenceMap(MLIRContext *context) {
AffineExpr d0(getAffineDimExpr(0, context)), d1(getAffineDimExpr(1, context));
return AffineMap::get(2, 0, {d0 - d1});
}
// TODO(ntv): replace this with 1-D memref alloc once there is an std.view op.
static Value *allocBuffer(Type elementType, Value *size) {
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
return buffer_alloc(
BufferType::get(size->getContext(), elementType, cst.getValue()));
return buffer_alloc(BufferType::get(size->getContext(), elementType), size);
static Value *allocBuffer(Type elementType, Value *size, bool dynamicBuffers) {
auto *ctx = size->getContext();
auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
if (!dynamicBuffers)
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
return alloc(
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
Value *mul = muli(constant_index(width), size);
return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
}
// Performs promotion of a `subView` into a local buffer of the size of the
@ -81,6 +90,7 @@ static Value *allocBuffer(Type elementType, Value *size) {
// by a partial `copy` op.
static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
SubViewOp subView,
bool dynamicBuffers,
OperationFolder *folder) {
auto zero = constant_index(folder, 0);
auto one = constant_index(folder, 1);
@ -101,18 +111,21 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
{rangeValue.max, rangeValue.min}, folder)
.front();
allocSize = muli(folder, allocSize, d).getValue();
fullRanges.push_back(range(folder, zero, d, one));
fullRanges.push_back(d);
partialRanges.push_back(range(folder, zero, dim(subView, rank), one));
}
auto *buffer = allocBuffer(viewType.getElementType(), allocSize);
auto fullLocalView = view(buffer, fullRanges);
SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
auto *buffer =
allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
auto fullLocalView = view(
MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
auto partialLocalView = slice(fullLocalView, partialRanges);
return PromotionInfo{buffer, fullLocalView, partialLocalView};
}
SmallVector<PromotionInfo, 8>
mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
ArrayRef<Value *> subViews,
ArrayRef<Value *> subViews, bool dynamicBuffers,
OperationFolder *folder) {
if (subViews.empty())
return {};
@ -127,7 +140,8 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
// TODO(ntv): support more cases than just float.
if (!viewType.getElementType().isa<FloatType>())
continue;
auto promotionInfo = promoteFullTileBuffer(b, loc, subView, folder);
auto promotionInfo =
promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder);
promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
res.push_back(promotionInfo);
}
@ -157,12 +171,13 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
}
static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
bool dynamicBuffers,
OperationFolder *folder) {
// 1. Promote the specified views and use them in the new op.
OpBuilder b(op);
ScopedContext scope(b, op.getLoc());
auto promotedBufferAndViews =
promoteSubViews(b, op.getLoc(), subViews.getArrayRef(), folder);
auto promotedBufferAndViews = promoteSubViews(
b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
SmallVector<Value *, 8> opViews;
opViews.reserve(op.getNumInputsAndOutputs());
SmallVector<std::pair<Value *, Value *>, 8> writebackViews;
@ -197,13 +212,13 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
// 4. Dealloc local buffers.
for (const auto &pi : promotedBufferAndViews)
buffer_dealloc(pi.buffer);
dealloc(pi.buffer);
}
static void promoteSubViews(FuncOp f) {
static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
SmallVector<LinalgOp, 8> toErase;
OperationFolder folder(f.getContext());
f.walk([&folder, &toErase](LinalgOp op) {
f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
// nothing.
SetVector<Value *> subViews;
@ -211,7 +226,7 @@ static void promoteSubViews(FuncOp f) {
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
promoteSubViewOperands(op, subViews, &folder);
promoteSubViewOperands(op, subViews, dynamicBuffers, &folder);
toErase.push_back(op);
}
});
@ -221,13 +236,23 @@ static void promoteSubViews(FuncOp f) {
namespace {
struct LinalgPromotionPass : public FunctionPass<LinalgPromotionPass> {
void runOnFunction() override { promoteSubViews(getFunction()); }
LinalgPromotionPass() = default;
LinalgPromotionPass(bool dynamicBuffers) : dynamicBuffers(dynamicBuffers) {}
void runOnFunction() override {
promoteSubViews(getFunction(), dynamicBuffers);
}
bool dynamicBuffers;
};
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgPromotionPass() {
return std::make_unique<LinalgPromotionPass>();
std::unique_ptr<OpPassBase<FuncOp>>
mlir::linalg::createLinalgPromotionPass(bool dynamicBuffers) {
return std::make_unique<LinalgPromotionPass>(dynamicBuffers);
}
static PassRegistration<LinalgPromotionPass>
pass("linalg-promote-subviews", "promote subview ops to local buffers");
pass("linalg-promote-subviews", "promote subview ops to local buffers", [] {
return std::make_unique<LinalgPromotionPass>(clPromoteDynamic);
});
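
In the default (static) mode, promotion now materializes each local buffer as a constant-size byte allocation viewed as the promoted element type, while the new -test-linalg-promote-dynamic flag keeps the allocation dynamically sized. Roughly, mirroring the CHECK/DYNAMIC lines of the promotion test below (%sz0, %sz1 and %bytes stand in for the computed sizes):

  // Static buffers (default):
  %tmp  = alloc() : memref<32xi8>
  %full = std.view %tmp[%sz0, %sz1][] : memref<32xi8> to memref<?x?xf32>
  ...
  dealloc %tmp : memref<32xi8>

  // Dynamic buffers (-test-linalg-promote-dynamic):
  %tmp  = alloc(%bytes) : memref<?xi8>
  %full = std.view %tmp[%sz0, %sz1][] : memref<?xi8> to memref<?x?xf32>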

View File

@ -34,6 +34,7 @@
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
//===----------------------------------------------------------------------===//

View File

@ -1,49 +1,5 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
// -----
func @buffer_alloc_single_index() {
// expected-error @+1 {{expected one index operand}}
%0 = linalg.buffer_alloc : !linalg.buffer<?xf32>
}
// -----
func @buffer_alloc_unexpected_index(%s : index) {
// expected-error @+1 {{expected zero operand}}
%0 = linalg.buffer_alloc %s : !linalg.buffer<32xf32>
}
// -----
func @buffer_alloc_nonegative_size() {
// expected-error @+1 {{expected nonnegative static buffer size}}
%0 = linalg.buffer_alloc : !linalg.buffer<0xf32>
}
// -----
func @buffer_alloc_nonegative_alignment(%arg0: index) {
// expected-error @+1 {{expected positive alignment}}
%0 = linalg.buffer_alloc %arg0 {alignment = -123}: !linalg.buffer<?xf32>
}
// -----
func @buffer_alloc_powerof2_alignment(%arg0: index) {
// expected-error @+1 {{expected power of 2 alignment}}
%0 = linalg.buffer_alloc %arg0 {alignment = 123}: !linalg.buffer<?xf32>
}
// -----
func @buffer_valid_element_type() {
// expected-error @+1 {{expected valid buffer element type}}
%0 = linalg.buffer_alloc : !linalg.buffer<4xindex>
}
// -----
func @load_number_of_indices(%v : memref<f32>) {
// expected-error @+2 {{incorrect number of indices for load}}
%c0 = constant 0 : index
@ -99,22 +55,6 @@ func @transpose_bad_rank(%v : memref<?x?xf32, (i, j)[off, M]->(off + M * i + j)>
// -----
func @view_type(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
// expected-error @+2 {{expected memref type}}
%r = linalg.range %min:%max:%step : !linalg.range
%0 = linalg.view %buf[%r]: !linalg.buffer<?xf32> -> index
}
// -----
func @view_num_ranges(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
// expected-error @+2 {{expected 2 ranges}}
%r = linalg.range %min:%max:%step : !linalg.range
%0 = linalg.view %buf[%r]: !linalg.buffer<?xf32> -> memref<?x?xf32, (i, j)[off, M]->(off + M * i + j)>
}
// -----
func @yield_parent(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
// expected-error @+1 {{op expected 'linalg.generic' or 'linalg.indexed_generic' parent op}}
linalg.yield %arg0: memref<?xf32, (i)[off]->(off + i)>
@ -396,15 +336,5 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)
// -----
// expected-error @+1 {{expected single element in size list}}
!invalid_type = type !linalg.buffer<1x1xf32>
// -----
// expected-error @+1 {{expected '>'}}
!invalid_type = type !linalg<"buffer<1xf32">
// -----
// expected-error @+1 {{expected valid keyword}}
!invalid_type = type !linalg<"?">

View File

@ -1,35 +1,6 @@
// RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s
// RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS
func @buffer_size(%arg0: !linalg.buffer<?xf32>) {
%c1 = constant 1 : index
%s = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
%t = addi %s, %c1 : index
return
}
// CHECK-LABEL: func @buffer_size
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i8*, float*, i64 }">
// CHECK-NEXT: llvm.add {{.*}}, {{.*}} : !llvm.i64
func @buffer_alloc_aligned(%arg0: index) {
%s = linalg.buffer_alloc %arg0 {alignment=16} : !linalg.buffer<?xf32>
return
}
// CHECK-LABEL: func @buffer_alloc_aligned
// CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK: %[[m:.*]] = llvm.mul %arg0, %[[c4]] : !llvm.i64
// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[c16:.*]] = llvm.mlir.constant(16 : index) : !llvm.i64
// CHECK: %[[a:.*]] = llvm.add %[[m]], %[[c16]] : !llvm.i64
// CHECK: %[[s:.*]] = llvm.sub %[[a]], %[[c1]] : !llvm.i64
// CHECK: %[[alloc:.*]] = llvm.call @malloc(%[[s]]) : (!llvm.i64) -> !llvm<"i8*">
// aligning `ptr` on `align` is done by computing the address `ptr + (align - ptr % align) % align`.
// CHECK: %[[cast:.*]] = llvm.ptrtoint %[[alloc]] : !llvm<"i8*"> to !llvm.i64
// CHECK: %[[rem:.*]] = llvm.urem %[[cast]], %[[c16]] : !llvm.i64
// CHECK: %[[drem:.*]] = llvm.sub %[[c16]], %[[rem]] : !llvm.i64
// CHECK: %[[off:.*]] = llvm.urem %[[drem]], %[[c16]] : !llvm.i64
// CHECK: llvm.getelementptr %{{.*}}[%[[off]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
func @range(%arg0: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
@ -44,48 +15,11 @@ func @range(%arg0: index) {
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
%0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
return
}
// CHECK-LABEL: func @view
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i8*, float*, i64 }">
// CHECK-NEXT: llvm.bitcast {{.*}} : !llvm<"float*"> to !llvm<"float*">
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.return
func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
%0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.buffer<?xf32> -> memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
return
}
// CHECK-LABEL: func @view3d
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
// CHECK: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
%0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
%1 = linalg.slice %0[%arg1] : memref<?xf32, offset: ?, strides: [1]>, !linalg.range, memref<?xf32, offset: ?, strides: [1]>
func @slice(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: !linalg.range) {
%1 = linalg.slice %arg0[%arg1] : memref<?xf32, offset: ?, strides: [1]>, !linalg.range, memref<?xf32, offset: ?, strides: [1]>
return
}
// CHECK-LABEL: func @slice
// insert ptr for view op
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// insert data ptr for slice op
// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
@ -122,13 +56,6 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
// CHECK-COUNT-3: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
// CHECK-NEXT: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64 }*">) -> ()
func @dim(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
%0 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
return
}
// CHECK-LABEL: func @dim(%{{.*}}: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">) {
// CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
func @subview(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
%c0 = constant 0 : index
%0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : memref<?x?xf32, offset: ?, strides: [?, 1]>

View File

@ -12,22 +12,20 @@
// CHECK-DAG: #[[Stride2Dilation4:.*]] = (d0, d1) -> (d0 * 2 + d1 * 4)
// CHECK-DAG: #[[Stride3Dilation5:.*]] = (d0, d1) -> (d0 * 3 + d1 * 5)
func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
%J = linalg.range %c0:%arg2:%c1 : !linalg.range
%K = linalg.range %c0:%arg3:%c1 : !linalg.range
%A = linalg.view %arg0[%I, %K] : !linalg.buffer<?xf32> -> memref<?x?xf32, offset: ?, strides: [?, 1]>
%B = linalg.view %arg0[%K, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, offset: ?, strides: [?, 1]>
%C = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, offset: ?, strides: [?, 1]>
%A = view %arg0[%M, %K][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
%B = view %arg0[%K, %N][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
%C = view %arg0[%M, %N][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
return
}
// CHECK-LABEL: func @matmul(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[N:.*]] = dim %[[B]], 1 : memref<?x?xf32, #[[strided2D]]>
@ -41,21 +39,19 @@ func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: in
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
// CHECK: store %[[res]], %[[C]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
%J = linalg.range %c0:%arg2:%c1 : !linalg.range
%2 = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, offset: ?, strides: [?, 1]>
%3 = linalg.view %arg0[%J] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
%4 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
%2 = view %arg0[%M, %N][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
%3 = view %arg0[%M][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
%4 = view %arg0[%N][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
linalg.matvec(%2, %3, %4) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>
return
}
// CHECK-LABEL: func @matvec(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
// CHECK-LABEL: func @matvec(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index) {
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} {
@ -67,20 +63,19 @@ func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: in
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
// CHECK: store %[[res]], %[[C]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
func @dot(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
func @dot(%arg0: memref<?xi8>, %M: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
%1 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
%2 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> memref<?xf32, offset: ?, strides: [1]>
%3 = linalg.view %arg0[] : !linalg.buffer<?xf32> -> memref<f32>
%1 = view %arg0[%M][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
%2 = view %arg0[%M][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
%3 = view %arg0[][] : memref<?xi8> to memref<f32>
linalg.dot(%1, %2, %3) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
return
}
// CHECK-LABEL: func @dot(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.buffer<?xf32> -> memref<f32>
// CHECK-LABEL: func @dot(%{{.*}}: memref<?xi8>, %{{.*}}: index) {
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[C:.*]] = std.view %{{.*}}[][] : memref<?xi8> to memref<f32>
// CHECK: %[[K:.*]] = dim %[[A]], 0 : memref<?xf32, #[[strided1D]]>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
// CHECK-DAG: %[[a:.*]] = load %[[A]][%{{.*}}] : memref<?xf32, #[[strided1D]]>

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt %s -linalg-promote-subviews | FileCheck %s
// RUN: mlir-opt %s -linalg-promote-subviews -test-linalg-promote-dynamic | FileCheck %s --check-prefix=DYNAMIC
#map0 = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
#map1 = (d0) -> (d0 + 2)
@ -9,18 +10,15 @@
// CHECK-DAG: #[[strided2DnoOffset:.*]] = (d0, d1)[s0] -> (d0 * s0 + d1)
module {
func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
func @matmul(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
%c4 = constant 4 : index
%c3 = constant 3 : index
%c2 = constant 2 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.range %c0:%arg1:%c1 : !linalg.range
%1 = linalg.range %c0:%arg2:%c1 : !linalg.range
%2 = linalg.range %c0:%arg3:%c1 : !linalg.range
%3 = linalg.view %arg0[%0, %2] : !linalg.buffer<?xf32> -> memref<?x?xf32, #map0>
%4 = linalg.view %arg0[%2, %1] : !linalg.buffer<?xf32> -> memref<?x?xf32, #map0>
%5 = linalg.view %arg0[%0, %1] : !linalg.buffer<?xf32> -> memref<?x?xf32, #map0>
%3 = view %A[%M, %K][%c0] : memref<?xi8> to memref<?x?xf32, #map0>
%4 = view %A[%K, %N][%c0] : memref<?xi8> to memref<?x?xf32, #map0>
%5 = view %A[%M, %N][%c0] : memref<?xi8> to memref<?x?xf32, #map0>
%6 = dim %3, 0 : memref<?x?xf32, #map0>
%7 = dim %3, 1 : memref<?x?xf32, #map0>
%8 = dim %4, 1 : memref<?x?xf32, #map0>
@ -44,7 +42,7 @@ module {
}
}
// CHECK-LABEL: func @matmul(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
@ -52,16 +50,19 @@ module {
// CHECK: %[[vB:.*]] = linalg.subview {{.*}} : memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[vC:.*]] = linalg.subview {{.*}} : memref<?x?xf32, #[[strided2D]]>
///
// CHECK: %[[tmpA:.*]] = linalg.buffer_alloc : !linalg.buffer<8xf32>
// CHECK: %[[fullA:.*]] = linalg.view %[[tmpA]][{{.*}}] : !linalg.buffer<8xf32> -> memref<?x?xf32>
// CHECK: %[[tmpA:.*]] = alloc() : memref<32xi8>
// CHECK: %[[fullA:.*]] = std.view %[[tmpA]][{{.*}}][] : memref<32xi8> to memref<?x?xf32>
// DYNAMIC: std.view %{{.*}}[{{.*}}][] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
///
// CHECK: %[[tmpB:.*]] = linalg.buffer_alloc : !linalg.buffer<12xf32>
// CHECK: %[[fullB:.*]] = linalg.view %[[tmpB]][{{.*}}] : !linalg.buffer<12xf32> -> memref<?x?xf32>
// CHECK: %[[tmpB:.*]] = alloc() : memref<48xi8>
// CHECK: %[[fullB:.*]] = std.view %[[tmpB]][{{.*}}][] : memref<48xi8> to memref<?x?xf32>
// DYNAMIC: std.view %{{.*}}[{{.*}}][] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
///
// CHECK: %[[tmpC:.*]] = linalg.buffer_alloc : !linalg.buffer<6xf32>
// CHECK: %[[fullC:.*]] = linalg.view %[[tmpC]][{{.*}}] : !linalg.buffer<6xf32> -> memref<?x?xf32>
// CHECK: %[[tmpC:.*]] = alloc() : memref<24xi8>
// CHECK: %[[fullC:.*]] = std.view %[[tmpC]][{{.*}}][] : memref<24xi8> to memref<?x?xf32>
// DYNAMIC: std.view %{{.*}}[{{.*}}][] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: linalg.fill(%[[fullA]], {{.*}}) : memref<?x?xf32>, f32
@ -75,6 +76,6 @@ module {
//
// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2DnoOffset]]>, memref<?x?xf32, #[[strided2D]]>
//
// CHECK: linalg.buffer_dealloc %[[tmpA]] : !linalg.buffer<8xf32>
// CHECK: linalg.buffer_dealloc %[[tmpB]] : !linalg.buffer<12xf32>
// CHECK: linalg.buffer_dealloc %[[tmpC]] : !linalg.buffer<6xf32>
// CHECK: dealloc %[[tmpA]] : memref<32xi8>
// CHECK: dealloc %[[tmpB]] : memref<48xi8>
// CHECK: dealloc %[[tmpC]] : memref<24xi8>

View File

@ -7,7 +7,6 @@
// CHECK-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0)
// CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
// CHECK-DAG: #[[strided2D42by1SymbolicOffset:.*]] = (d0, d1)[s0] -> (d0 * 42 + s0 + d1)
// CHECK-DAG: #[[strided3D:.*]] = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)
// CHECK-DAG: #[[strided6D:.*]] = (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)
@ -21,60 +20,31 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
func @buffer_size(%arg0: !linalg.buffer<?xf32>) -> index {
%0 = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
return %0 : index
}
// CHECK-LABEL: func @buffer_size
// CHECK: linalg.buffer_size {{.*}} : !linalg.buffer<?xf32>
func @buffer(%arg0: index, %arg1: index) {
%0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
%2 = linalg.buffer_alloc %0 {alignment = 16} : !linalg.buffer<?xvector<4xi8>>
%3 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
%4 = linalg.buffer_alloc {alignment = 32} : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %4 : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %3 : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %2 : !linalg.buffer<?xvector<4xi8>>
linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
return
}
// CHECK-LABEL: func @buffer(%{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} {alignment = 16 : i64} : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_alloc {alignment = 32 : i64} : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xvector<4xi8>>
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%c0 = constant 0 : index
%0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
%3 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> memref<?x?xf32, offset: ?, strides: [?, 1]>
%1 = alloc (%0) : memref<?xi8>
%2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
%3 = view %1[%arg0, %arg0][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
%4 = linalg.slice %3[%2, %2] : memref<?x?xf32, offset: ?, strides: [?, 1]>, !linalg.range, !linalg.range, memref<?x?xf32, offset: ?, strides: [?, 1]>
%5 = linalg.slice %3[%2, %arg2] : memref<?x?xf32, offset: ?, strides: [?, 1]>, !linalg.range, index, memref<?xf32, offset: ?, strides: [1]>
%6 = linalg.slice %3[%arg2, %2] : memref<?x?xf32, offset: ?, strides: [?, 1]>, index, !linalg.range, memref<?xf32, offset: ?, strides: [1]>
%7 = linalg.slice %3[%arg2, %arg3] : memref<?x?xf32, offset: ?, strides: [?, 1]>, index, index, memref<f32>
%8 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> memref<?x?xvector<4x4xf32>, offset: ?, strides: [?, 1]>
linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
%8 = view %1[%arg0, %arg0][%c0] : memref<?xi8> to memref<?x?xvector<4x4xf32>, offset: ?, strides: [?, 1]>
dealloc %1 : memref<?xi8>
return
}
// CHECK-LABEL: func @views(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xf32>
// CHECK-NEXT: linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
// CHECK-NEXT: linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
// CHECK: muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: alloc(%{{.*}}) : memref<?xi8>
// CHECK-NEXT: range
// CHECK-NEXT: std.view %{{.*}}[%{{.*}}][%{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2D]]>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, !linalg.range, index, memref<?xf32, #[[strided1D]]>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, index, !linalg.range, memref<?xf32, #[[strided1D]]>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, index, index, memref<f32>
// CHECK-NEXT: linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xvector<4x4xf32>, #[[strided2D]]>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>
// CHECK-NEXT: view %{{.*}}[%{{.*}}][%{{.*}}] : memref<?xi8> to memref<?x?xvector<4x4xf32>, #[[strided2D]]>
// CHECK-NEXT: dealloc %{{.*}} : memref<?xi8>
func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<?xf32, offset: ?, strides: [1]>, %arg3: memref<f32>) {
linalg.matmul(%arg0, %arg0, %arg0) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
@ -88,41 +58,6 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?xf3
// CHECK-NEXT: linalg.matvec(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[strided2D]]>, memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>
// CHECK-NEXT: linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>, memref<f32>
func @dim(%arg0: memref<?x?xf32, offset: ?, strides: [42, 1]>) {
%0 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [42, 1]>
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
return
}
// CHECK-LABEL: func @dim(
// CHECK: %{{.*}}: memref<?x?xf32, #[[strided2D42by1SymbolicOffset]]>) {
// CHECK-NEXT: dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D42by1SymbolicOffset]]>
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xf32>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>
func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) {
loop.for %i0 = %arg0 to %arg1 step %arg2 {
loop.for %i1 = %arg0 to %arg1 step %arg2 {
%min_cmp = cmpi "slt", %i0, %i1 : index
%min = select %min_cmp, %i0, %i1 : index
%max_cmp = cmpi "sge", %i0, %i1 : index
%max = select %max_cmp, %i0, %i1 : index
loop.for %i2 = %min to %max step %i1 {
}
}
}
return
}
// CHECK-LABEL: func @linalg_for(
// CHECK: %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: cmpi "slt", %{{.*}}, %{{.*}} : index
// CHECK-NEXT: select %{{.*}}, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: cmpi "sge", %{{.*}}, %{{.*}} : index
// CHECK-NEXT: select %{{.*}}, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} {
func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
linalg.fill(%arg0, %arg1) : memref<?xf32, offset: ?, strides: [1]>, f32
return
@ -190,13 +125,6 @@ func @subview(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>) {
// CHECK: constant 0 : index
// CHECK: linalg.subview %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?xvector<3x4xi4>, #[[strided2D]]>
func @const_buffer_view(%arg0: index, %arg1: index, %arg2: index) {
%c0 = linalg.buffer_alloc : !linalg.buffer<17xf32>
%c1 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
%c2 = linalg.view %c0[%c1] : !linalg.buffer<17xf32> -> memref<?xf32, offset: ?, strides: [1]>
return
}
#accesses = [
(i, j, k) -> (j, i),
(i, j, k) -> (i, k, i + j)

View File

@ -5,18 +5,19 @@
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -linalg-lower-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
#strided1D = (d0)[s0] -> (d0 + s0)
#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
#strided1D = (d0) -> (d0)
#strided2D = (d0, d1)[s0] -> (d0 * s0 + d1)
// Creates and returns a 1-D buffer of size %s filled with the value %f
func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<?xf32> {
func @alloc_filled_f32(%s : index, %f : f32) -> memref<?xi8> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%buf = linalg.buffer_alloc %s {alignment = 256} : !linalg.buffer<?xf32>
%R = linalg.range %c0:%s:%c1 : !linalg.range
%V = linalg.view %buf[%R] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
%c4 = constant 4 : index
%s4 = muli %s, %c4: index
%buf = alloc(%s4) {alignment = 256} : memref<?xi8>
%V = view %buf[%s][] : memref<?xi8> to memref<?xf32, #strided1D>
linalg.fill(%V, %f) : memref<?xf32, #strided1D>, f32
return %buf : !linalg.buffer<?xf32>
return %buf : memref<?xi8>
}
// Test for linalg.dot.
@ -28,21 +29,20 @@ func @dot() -> f32 {
%f1 = constant 1.00000e+00 : f32
%f2 = constant 2.00000e+00 : f32
%bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer<?xf32>)
%bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer<?xf32>)
%bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer<?xf32>)
%bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (memref<?xi8>)
%bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (memref<?xi8>)
%bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (memref<?xi8>)
%R = linalg.range %c0:%c16:%c1 : !linalg.range
%A = linalg.view %bA[%R] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
%B = linalg.view %bB[%R] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
%C = linalg.view %bC[] : !linalg.buffer<?xf32> -> memref<f32>
%A = view %bA[%c16][] : memref<?xi8> to memref<?xf32, #strided1D>
%B = view %bB[%c16][] : memref<?xi8> to memref<?xf32, #strided1D>
%C = view %bC[][] : memref<?xi8> to memref<f32>
linalg.dot(%A, %B, %C) : memref<?xf32, #strided1D>, memref<?xf32, #strided1D>, memref<f32>
%res = load %C[] : memref<f32>
linalg.buffer_dealloc %bC : !linalg.buffer<?xf32>
linalg.buffer_dealloc %bB : !linalg.buffer<?xf32>
linalg.buffer_dealloc %bA : !linalg.buffer<?xf32>
dealloc %bC : memref<?xi8>
dealloc %bB : memref<?xi8>
dealloc %bA : memref<?xi8>
return %res : f32
}
@ -61,23 +61,20 @@ func @matmul() -> f32 {
%f2 = constant 2.00000e+00 : f32
%f10 = constant 10.00000e+00 : f32
%bA = call @alloc_filled_f32(%c160, %f2) : (index, f32) -> (!linalg.buffer<?xf32>)
%bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (!linalg.buffer<?xf32>)
%bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (!linalg.buffer<?xf32>)
%bA = call @alloc_filled_f32(%c160, %f2) : (index, f32) -> (memref<?xi8>)
%bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (memref<?xi8>)
%bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (memref<?xi8>)
%M = linalg.range %c0:%c10:%c1 : !linalg.range
%N = linalg.range %c0:%c10:%c1 : !linalg.range
%K = linalg.range %c0:%c16:%c1 : !linalg.range
%A = linalg.view %bA[%M, %K] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
%B = linalg.view %bB[%K, %N] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
%C = linalg.view %bC[%M, %N] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
%A = view %bA[%c10, %c16][] : memref<?xi8> to memref<?x?xf32, #strided2D>
%B = view %bB[%c16, %c10][] : memref<?xi8> to memref<?x?xf32, #strided2D>
%C = view %bC[%c10, %c10][] : memref<?xi8> to memref<?x?xf32, #strided2D>
linalg.matmul(%A, %B, %C) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
%res = load %C[%c6, %c7] : memref<?x?xf32, #strided2D>
linalg.buffer_dealloc %bC : !linalg.buffer<?xf32>
linalg.buffer_dealloc %bB : !linalg.buffer<?xf32>
linalg.buffer_dealloc %bA : !linalg.buffer<?xf32>
dealloc %bC : memref<?xi8>
dealloc %bB : memref<?xi8>
dealloc %bA : memref<?xi8>
return %res : f32
}