[MLIR] Create memref dialect and move dialect-specific ops from std.

Create the memref dialect and move the memref-specific ops
from the std dialect into the new dialect.

Moved ops:
AllocOp -> MemRef_AllocOp
AllocaOp -> MemRef_AllocaOp
AssumeAlignmentOp -> MemRef_AssumeAlignmentOp
DeallocOp -> MemRef_DeallocOp
DimOp -> MemRef_DimOp
MemRefCastOp -> MemRef_CastOp
MemRefReinterpretCastOp -> MemRef_ReinterpretCastOp
GetGlobalMemRefOp -> MemRef_GetGlobalOp
GlobalMemRefOp -> MemRef_GlobalOp
LoadOp -> MemRef_LoadOp
PrefetchOp -> MemRef_PrefetchOp
ReshapeOp -> MemRef_ReshapeOp
StoreOp -> MemRef_StoreOp
SubViewOp -> MemRef_SubViewOp
TransposeOp -> MemRef_TransposeOp
TensorLoadOp -> MemRef_TensorLoadOp
TensorStoreOp -> MemRef_TensorStoreOp
TensorToMemRefOp -> MemRef_BufferCastOp
ViewOp -> MemRef_ViewOp

The roadmap to split the memref dialect from std is discussed here:
https://llvm.discourse.group/t/rfc-split-the-memref-dialect-from-std/2667

Differential Revision: https://reviews.llvm.org/D98041
Julian Gross 2021-02-10 13:53:11 +01:00
parent a88371490d
commit e2310704d8
367 changed files with 10070 additions and 9539 deletions
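For downstream code, the migration typically amounts to depending on the new dialect and switching the op spellings, as the Toy tutorial diffs below show. A minimal sketch of that pattern (a hypothetical pass, not part of this commit):

```c++
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
/// Hypothetical pass illustrating the renamed ops. Any pass that creates
/// memref ops must now declare the memref dialect as a dependent dialect.
struct ExampleMemRefPass
    : public PassWrapper<ExampleMemRefPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    // Previously, registering StandardOpsDialect alone was enough.
    registry.insert<memref::MemRefDialect, StandardOpsDialect>();
  }
  void runOnFunction() final {
    FuncOp func = getFunction();
    if (func.isExternal())
      return;
    OpBuilder builder(&getContext());
    builder.setInsertionPointToStart(&func.getBody().front());
    Location loc = func.getLoc();
    auto type = MemRefType::get({4, 4}, builder.getF32Type());
    // std.alloc / std.dealloc become memref.alloc / memref.dealloc.
    auto buffer = builder.create<memref::AllocOp>(loc, type);
    builder.create<memref::DeallocOp>(loc, buffer);
  }
};
} // namespace
```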

View File

@ -779,8 +779,8 @@ the deallocation of the source value.
## Known Limitations
BufferDeallocation introduces additional copies using allocations from the
“std” dialect (“std.alloc”). Analogously, all deallocations use the “std”
dialect (“std.dealloc”). The actual copy process is realized using
“linalg.copy”. Furthermore, buffers are essentially immutable after their
creation in a block. Other limitations are known in the case of unstructured
control flow.
“memref” dialect (“memref.alloc”). Analogously, all deallocations use the
“memref” dialect (“memref.dealloc”). The actual copy process is realized using
“linalg.copy”. Furthermore, buffers are essentially immutable after their
creation in a block. Other limitations are known in the case of unstructured
control flow.

View File

@ -190,8 +190,8 @@ One convenient utility provided by the MLIR bufferization infrastructure is the
`BufferizeTypeConverter`, which comes pre-loaded with the necessary conversions
and materializations between `tensor` and `memref`.
In this case, the `StandardOpsDialect` is marked as legal, so the `tensor_load`
and `tensor_to_memref` ops, which are inserted automatically by the dialect
In this case, the `MemRefOpsDialect` is marked as legal, so the `tensor_load`
and `buffer_cast` ops, which are inserted automatically by the dialect
conversion framework as materializations, are legal. There is a helper
`populateBufferizeMaterializationLegality`
([code](https://github.com/llvm/llvm-project/blob/a0b65a7bcd6065688189b3d678c42ed6af9603db/mlir/include/mlir/Transforms/Bufferize.h#L53))
@ -247,7 +247,7 @@ from the program.
The easiest way to write a finalizing bufferize pass is to not write one at all!
MLIR provides a pass `finalizing-bufferize` which eliminates the `tensor_load` /
`tensor_to_memref` materialization ops inserted by partial bufferization passes
`buffer_cast` materialization ops inserted by partial bufferization passes
and emits an error if that is not sufficient to remove all tensors from the
program.
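For reference, a rough sketch of how a partial bufferization pass of this vintage wires these utilities together, assuming the `BufferizeTypeConverter` and `populateBufferizeMaterializationLegality` helpers described above (the dialect-specific patterns and the pass boilerplate are elided):

```c++
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

/// Sketch of the body of a partial bufferization pass.
static LogicalResult runPartialBufferize(FuncOp func) {
  MLIRContext *context = func.getContext();
  BufferizeTypeConverter typeConverter;

  ConversionTarget target(*context);
  // The bufferized ops now live in the memref dialect, so mark it legal.
  target.addLegalDialect<memref::MemRefDialect>();
  // Make the tensor_load / buffer_cast materializations legal so this pass
  // only has to bufferize its own ops.
  populateBufferizeMaterializationLegality(target);

  OwningRewritePatternList patterns;
  // ... add the dialect-specific bufferization patterns here, built with
  //     `typeConverter` ...

  return applyPartialConversion(func, target, std::move(patterns));
}
```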
@ -268,7 +268,7 @@ recommended in new code. A helper,
`populateEliminateBufferizeMaterializationsPatterns`
([code](https://github.com/llvm/llvm-project/blob/a0b65a7bcd6065688189b3d678c42ed6af9603db/mlir/include/mlir/Transforms/Bufferize.h#L58))
is available for such passes to provide patterns that eliminate `tensor_load`
and `tensor_to_memref`.
and `buffer_cast`.
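As a usage sketch, a C++-assembled pipeline would run the partial bufferization passes first and schedule `finalizing-bufferize` last; this assumes the standard `createFinalizingBufferizePass()` constructor declared in `Transforms/Passes.h`:

```c++
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

/// Sketch of a bufferization pipeline: partial bufferization first, then the
/// finalizing pass to remove the remaining tensor_load / buffer_cast ops.
static LogicalResult runBufferizePipeline(ModuleOp module) {
  PassManager pm(module.getContext());
  // ... add partial bufferization passes here (e.g. scf-bufferize,
  //     linalg-bufferize, func-bufferize) ...
  pm.addNestedPass<FuncOp>(createFinalizingBufferizePass());
  return pm.run(module);
}
```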
## Changes since [the talk](#the-talk)

View File

@ -406,9 +406,9 @@ into a form that will resemble:
#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
func @example(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
%0 = memref_cast %arg0 : memref<?x?xf32> to memref<?x?xf32, #map0>
%1 = memref_cast %arg1 : memref<?x?xf32> to memref<?x?xf32, #map0>
%2 = memref_cast %arg2 : memref<?x?xf32> to memref<?x?xf32, #map0>
%0 = memref.cast %arg0 : memref<?x?xf32> to memref<?x?xf32, #map0>
%1 = memref.cast %arg1 : memref<?x?xf32> to memref<?x?xf32, #map0>
%2 = memref.cast %arg2 : memref<?x?xf32> to memref<?x?xf32, #map0>
call @pointwise_add(%0, %1, %2) : (memref<?x?xf32, #map0>, memref<?x?xf32, #map0>, memref<?x?xf32, #map0>) -> ()
return
}
@ -518,9 +518,9 @@ A set of ops that manipulate metadata but do not move memory. These ops take
generally alias the operand `view`. At the moment the existing ops are:
```
* `std.view`,
* `memref.view`,
* `std.subview`,
* `std.transpose`.
* `memref.transpose`.
* `linalg.range`,
* `linalg.slice`,
* `linalg.reshape`,

View File

@ -0,0 +1,76 @@
# 'memref' Dialect
This dialect provides documentation for operations within the MemRef dialect.
**Please post an RFC on the [forum](https://llvm.discourse.group/c/mlir/31)
before adding or changing any operation in this dialect.**
[TOC]
## Operations
[include "Dialects/MemRefOps.md"]
### 'dma_start' operation
Syntax:
```
operation ::= `dma_start` ssa-use`[`ssa-use-list`]` `,`
ssa-use`[`ssa-use-list`]` `,` ssa-use `,`
ssa-use`[`ssa-use-list`]` (`,` ssa-use `,` ssa-use)?
`:` memref-type `,` memref-type `,` memref-type
```
Starts a non-blocking DMA operation that transfers data from a source memref to
a destination memref. The operands include the source and destination memrefs,
each followed by its indices, the size of the data transfer in terms of the
number of elements (of the elemental type of the memref), a tag memref with its
indices, and optionally two additional arguments corresponding to the stride (in
terms of number of elements) and the number of elements to transfer per stride.
The tag location is used by a dma_wait operation to check for completion. The
indices of the source memref, destination memref, and the tag memref have the
same restrictions as any load/store operation in an affine context (whenever DMA
operations appear in an affine context). See
[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
in affine contexts. This allows powerful static analysis and transformations in
the presence of such DMAs including rescheduling, pipelining / overlap with
computation, and checking for matching start/end operations. The source and
destination memref need not be of the same dimensionality, but need to have the
same elemental type.
For example, a `dma_start` operation that transfers 32 vector elements from a
memref `%src` at location `[%i, %j]` to memref `%dst` at `[%k, %l]` would be
specified as shown below.
Example:
```mlir
%size = constant 32 : index
%tag = alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
%idx = constant 0 : index
dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] :
memref<40 x 8 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 0>,
memref<2 x 4 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 2>,
memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
```
### 'dma_wait' operation
Syntax:
```
operation ::= `dma_wait` ssa-use`[`ssa-use-list`]` `,` ssa-use `:` memref-type
```
Blocks until the completion of a DMA operation associated with the tag element
specified with a tag memref and its indices. The operands include the tag memref
followed by its indices and the number of elements associated with the DMA being
waited on. The indices of the tag memref have the same restrictions as
load/store indices.
Example:
```mlir
dma_wait %tag[%idx], %size : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
```
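The C++ builders for these two ops are declared in the new `MemRef.h` added later in this commit. A minimal sketch of emitting the same start/wait pair programmatically, assuming the memref, index, tag, and size values from the examples above are already available to the caller:

```c++
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

/// Emit a memref.dma_start followed by the memref.dma_wait that blocks on its
/// completion. All values are assumed to be created by the caller, mirroring
/// the textual examples above.
static void emitDmaTransfer(OpBuilder &builder, Location loc, Value src,
                            ValueRange srcIndices, Value dst,
                            ValueRange dstIndices, Value numElements, Value tag,
                            ValueRange tagIndices) {
  builder.create<memref::DmaStartOp>(loc, src, srcIndices, dst, dstIndices,
                                     numElements, tag, tagIndices);
  builder.create<memref::DmaWaitOp>(loc, tag, tagIndices, numElements);
}
```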

View File

@ -13,67 +13,3 @@ before adding or changing any operation in this dialect.**
## Operations
[include "Dialects/StandardOps.md"]
### 'dma_start' operation
Syntax:
```
operation ::= `dma_start` ssa-use`[`ssa-use-list`]` `,`
ssa-use`[`ssa-use-list`]` `,` ssa-use `,`
ssa-use`[`ssa-use-list`]` (`,` ssa-use `,` ssa-use)?
`:` memref-type `,` memref-type `,` memref-type
```
Starts a non-blocking DMA operation that transfers data from a source memref to
a destination memref. The operands include the source and destination memref's
each followed by its indices, size of the data transfer in terms of the number
of elements (of the elemental type of the memref), a tag memref with its
indices, and optionally two additional arguments corresponding to the stride (in
terms of number of elements) and the number of elements to transfer per stride.
The tag location is used by a dma_wait operation to check for completion. The
indices of the source memref, destination memref, and the tag memref have the
same restrictions as any load/store operation in an affine context (whenever DMA
operations appear in an affine context). See
[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
in affine contexts. This allows powerful static analysis and transformations in
the presence of such DMAs including rescheduling, pipelining / overlap with
computation, and checking for matching start/end operations. The source and
destination memref need not be of the same dimensionality, but need to have the
same elemental type.
For example, a `dma_start` operation that transfers 32 vector elements from a
memref `%src` at location `[%i, %j]` to memref `%dst` at `[%k, %l]` would be
specified as shown below.
Example:
```mlir
%size = constant 32 : index
%tag = alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
%idx = constant 0 : index
dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] :
memref<40 x 8 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 0>,
memref<2 x 4 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 2>,
memref<1 x i32>, affine_map<(d0) -> (d0)>, 4>
```
### 'dma_wait' operation
Syntax:
```
operation ::= `dma_wait` ssa-use`[`ssa-use-list`]` `,` ssa-use `:` memref-type
```
Blocks until the completion of a DMA operation associated with the tag element
specified with a tag memref and its indices. The operands include the tag memref
followed by its indices and the number of elements associated with the DMA being
waited on. The indices of the tag memref have the same restrictions as
load/store indices.
Example:
```mlir
dma_wait %tag[%idx], %size : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
```

View File

@ -200,7 +200,7 @@ for.
### The `OpPointer` and `ConstOpPointer` Classes
The "typed operation" classes for registered operations (e.g. like `DimOp` for
the "std.dim" operation in standard ops) contain a pointer to an operation and
the "memref.dim" operation in memref ops) contain a pointer to an operation and
provide typed APIs for processing it.
However, this is a problem for our current `const` design - `const DimOp` means

View File

@ -211,7 +211,7 @@ are nested inside of other operations that themselves have this trait.
This trait is carried by region holding operations that define a new scope for
automatic allocation. Such allocations are automatically freed when control is
transferred back from the regions of such operations. As an example, allocations
performed by [`std.alloca`](Dialects/Standard.md#stdalloca-allocaop) are
performed by [`memref.alloca`](Dialects/MemRef.md#memrefalloca-allocaop) are
automatically freed when control leaves the region of its closest surrounding op
that has the trait AutomaticAllocationScope.

View File

@ -50,8 +50,9 @@ framework, we need to provide two things (and an optional third):
## Conversion Target
For our purposes, we want to convert the compute-intensive `Toy` operations into
a combination of operations from the `Affine` `Standard` dialects for further
optimization. To start off the lowering, we first define our conversion target:
a combination of operations from the `Affine`, `MemRef` and `Standard` dialects
for further optimization. To start off the lowering, we first define our
conversion target:
```c++
void ToyToAffineLoweringPass::runOnFunction() {
@ -61,8 +62,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine` and `Standard` dialects.
target.addLegalDialect<mlir::AffineDialect, mlir::StandardOpsDialect>();
// `Affine`, `MemRef` and `Standard` dialects.
target.addLegalDialect<mlir::AffineDialect, mlir::memref::MemRefDialect,
mlir::StandardOpsDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want

View File

@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
//
// This file implements a partial lowering of Toy operations to a combination of
// affine loops and standard operations. This lowering expects that all calls
// have been inlined, and all shapes have been resolved.
// affine loops, memref operations and standard operations. This lowering
// expects that all calls have been inlined, and all shapes have been resolved.
//
//===----------------------------------------------------------------------===//
@ -16,6 +16,7 @@
#include "toy/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@ -36,7 +37,7 @@ static MemRefType convertTensorToMemRef(TensorType type) {
/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter) {
auto alloc = rewriter.create<AllocOp>(loc, type);
auto alloc = rewriter.create<memref::AllocOp>(loc, type);
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
@ -44,7 +45,7 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
dealloc->moveBefore(&parentBlock->back());
return alloc;
}
@ -152,8 +153,8 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
if (!valueShape.empty()) {
for (auto i : llvm::seq<int64_t>(
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
} else {
// This is the case of a tensor of rank 0.
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
@ -257,7 +258,7 @@ namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, StandardOpsDialect>();
registry.insert<AffineDialect, memref::MemRefDialect, StandardOpsDialect>();
}
void runOnFunction() final;
};
@ -283,8 +284,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine` and `Standard` dialects.
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
// `Affine`, `MemRef` and `Standard` dialects.
target.addLegalDialect<AffineDialect, memref::MemRefDialect,
StandardOpsDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want

View File

@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
//
// This file implements a partial lowering of Toy operations to a combination of
// affine loops and standard operations. This lowering expects that all calls
// have been inlined, and all shapes have been resolved.
// affine loops, memref operations and standard operations. This lowering
// expects that all calls have been inlined, and all shapes have been resolved.
//
//===----------------------------------------------------------------------===//
@ -16,6 +16,7 @@
#include "toy/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@ -36,7 +37,7 @@ static MemRefType convertTensorToMemRef(TensorType type) {
/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter) {
auto alloc = rewriter.create<AllocOp>(loc, type);
auto alloc = rewriter.create<memref::AllocOp>(loc, type);
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
@ -44,7 +45,7 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
dealloc->moveBefore(&parentBlock->back());
return alloc;
}
@ -152,8 +153,8 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
if (!valueShape.empty()) {
for (auto i : llvm::seq<int64_t>(
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
} else {
// This is the case of a tensor of rank 0.
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
@ -256,7 +257,7 @@ namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, StandardOpsDialect>();
registry.insert<AffineDialect, memref::MemRefDialect, StandardOpsDialect>();
}
void runOnFunction() final;
};
@ -282,8 +283,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine` and `Standard` dialects.
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
// `Affine`, `MemRef` and `Standard` dialects.
target.addLegalDialect<AffineDialect, memref::MemRefDialect,
StandardOpsDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want

View File

@ -30,6 +30,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
@ -91,7 +92,8 @@ public:
// Generate a call to printf for the current element of the loop.
auto printOp = cast<toy::PrintOp>(op);
auto elementLoad = rewriter.create<LoadOp>(loc, printOp.input(), loopIvs);
auto elementLoad =
rewriter.create<memref::LoadOp>(loc, printOp.input(), loopIvs);
rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
ArrayRef<Value>({formatSpecifierCst, elementLoad}));

View File

@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
//
// This file implements a partial lowering of Toy operations to a combination of
// affine loops and standard operations. This lowering expects that all calls
// have been inlined, and all shapes have been resolved.
// affine loops, memref operations and standard operations. This lowering
// expects that all calls have been inlined, and all shapes have been resolved.
//
//===----------------------------------------------------------------------===//
@ -16,6 +16,7 @@
#include "toy/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@ -36,7 +37,7 @@ static MemRefType convertTensorToMemRef(TensorType type) {
/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter) {
auto alloc = rewriter.create<AllocOp>(loc, type);
auto alloc = rewriter.create<memref::AllocOp>(loc, type);
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
@ -44,7 +45,7 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
dealloc->moveBefore(&parentBlock->back());
return alloc;
}
@ -152,8 +153,8 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
if (!valueShape.empty()) {
for (auto i : llvm::seq<int64_t>(
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
} else {
// This is the case of a tensor of rank 0.
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
@ -257,7 +258,7 @@ namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, StandardOpsDialect>();
registry.insert<AffineDialect, memref::MemRefDialect, StandardOpsDialect>();
}
void runOnFunction() final;
};
@ -283,8 +284,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine` and `Standard` dialects.
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
// `Affine`, `MemRef` and `Standard` dialects.
target.addLegalDialect<AffineDialect, memref::MemRefDialect,
StandardOpsDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want

View File

@ -30,6 +30,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
@ -91,7 +92,8 @@ public:
// Generate a call to printf for the current element of the loop.
auto printOp = cast<toy::PrintOp>(op);
auto elementLoad = rewriter.create<LoadOp>(loc, printOp.input(), loopIvs);
auto elementLoad =
rewriter.create<memref::LoadOp>(loc, printOp.input(), loopIvs);
rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
ArrayRef<Value>({formatSpecifierCst, elementLoad}));

View File

@ -121,7 +121,7 @@ def LowerHostCodeToLLVM : Pass<"lower-host-to-llvm", "ModuleOp"> {
def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
let summary = "Generate NVVM operations for gpu operations";
let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
let dependentDialects = ["NVVM::NVVMDialect"];
let dependentDialects = ["NVVM::NVVMDialect", "memref::MemRefDialect"];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
@ -210,7 +210,7 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
let summary = "Convert the operations from the linalg dialect into the "
"Standard dialect";
let constructor = "mlir::createConvertLinalgToStandardPass()";
let dependentDialects = ["StandardOpsDialect"];
let dependentDialects = ["memref::MemRefDialect", "StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
@ -316,7 +316,11 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
let summary = "Convert operations from the shape dialect into the standard "
"dialect";
let constructor = "mlir::createConvertShapeToStandardPass()";
let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
let dependentDialects = [
"memref::MemRefDialect",
"StandardOpsDialect",
"scf::SCFDialect"
];
}
def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
@ -474,7 +478,11 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
let summary = "Lower the operations from the vector dialect into the SCF "
"dialect";
let constructor = "mlir::createConvertVectorToSCFPass()";
let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
let dependentDialects = [
"AffineDialect",
"memref::MemRefDialect",
"scf::SCFDialect"
];
let options = [
Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
"Perform full unrolling when converting vector transfers to SCF">,

View File

@ -72,7 +72,8 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
/// stdlib malloc/free is used by default for allocating memrefs allocated with
/// std.alloc, while LLVM's alloca is used for those allocated with std.alloca.
/// memref.alloc, while LLVM's alloca is used for those allocated with
/// memref.alloca.
std::unique_ptr<OperationPass<ModuleOp>>
createLowerToLLVMPass(const LowerToLLVMOptions &options =
LowerToLLVMOptions::getDefaultOptions());

View File

@ -18,6 +18,7 @@ include "mlir/Pass/PassBase.td"
def AffineDataCopyGeneration : FunctionPass<"affine-data-copy-generate"> {
let summary = "Generate explicit copying for affine memory operations";
let constructor = "mlir::createAffineDataCopyGenerationPass()";
let dependentDialects = ["memref::MemRefDialect"];
let options = [
Option<"fastMemoryCapacity", "fast-mem-capacity", "uint64_t",
/*default=*/"std::numeric_limits<uint64_t>::max()",

View File

@ -9,6 +9,7 @@ add_subdirectory(GPU)
add_subdirectory(Math)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(PDL)

View File

@ -480,7 +480,7 @@ def GPU_LaunchOp : GPU_Op<"launch">,
%num_bx : index, %num_by : index, %num_bz : index,
%num_tx : index, %num_ty : index, %num_tz : index)
"some_op"(%bx, %tx) : (index, index) -> ()
%3 = "std.load"(%val1, %bx) : (memref<?xf32, 1>, index) -> f32
%3 = "memref.load"(%val1, %bx) : (memref<?xf32, 1>, index) -> f32
}
```
@ -812,7 +812,7 @@ def GPU_AllocOp : GPU_Op<"alloc", [
let summary = "GPU memory allocation operation.";
let description = [{
The `gpu.alloc` operation allocates a region of memory on the GPU. It is
similar to the `std.alloc` op, but supports asynchronous GPU execution.
similar to the `memref.alloc` op, but supports asynchronous GPU execution.
The op does not execute before all async dependencies have finished
executing.
@ -850,7 +850,7 @@ def GPU_DeallocOp : GPU_Op<"dealloc", [GPU_AsyncOpInterface]> {
let description = [{
The `gpu.dealloc` operation frees the region of memory referenced by a
memref which was originally created by the `gpu.alloc` operation. It is
similar to the `std.dealloc` op, but supports asynchronous GPU execution.
similar to the `memref.dealloc` op, but supports asynchronous GPU execution.
The op does not execute before all async dependencies have finished
executing.

View File

@ -11,6 +11,7 @@
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/FoldUtils.h"
@ -35,30 +36,25 @@ struct FoldedValueBuilder {
};
using folded_math_tanh = FoldedValueBuilder<math::TanhOp>;
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
using folded_std_constant = FoldedValueBuilder<ConstantOp>;
using folded_std_dim = FoldedValueBuilder<DimOp>;
using folded_memref_alloc = FoldedValueBuilder<memref::AllocOp>;
using folded_memref_cast = FoldedValueBuilder<memref::CastOp>;
using folded_memref_dim = FoldedValueBuilder<memref::DimOp>;
using folded_memref_load = FoldedValueBuilder<memref::LoadOp>;
using folded_memref_sub_view = FoldedValueBuilder<memref::SubViewOp>;
using folded_memref_tensor_load = FoldedValueBuilder<memref::TensorLoadOp>;
using folded_memref_view = FoldedValueBuilder<memref::ViewOp>;
using folded_std_muli = FoldedValueBuilder<MulIOp>;
using folded_std_addi = FoldedValueBuilder<AddIOp>;
using folded_std_addf = FoldedValueBuilder<AddFOp>;
using folded_std_alloc = FoldedValueBuilder<AllocOp>;
using folded_std_constant = FoldedValueBuilder<ConstantOp>;
using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
using folded_std_dim = FoldedValueBuilder<DimOp>;
using folded_std_index_cast = FoldedValueBuilder<IndexCastOp>;
using folded_std_muli = FoldedValueBuilder<MulIOp>;
using folded_std_mulf = FoldedValueBuilder<MulFOp>;
using folded_std_memref_cast = FoldedValueBuilder<MemRefCastOp>;
using folded_std_select = FoldedValueBuilder<SelectOp>;
using folded_std_load = FoldedValueBuilder<LoadOp>;
using folded_std_subi = FoldedValueBuilder<SubIOp>;
using folded_std_sub_view = FoldedValueBuilder<SubViewOp>;
using folded_std_tensor_load = FoldedValueBuilder<TensorLoadOp>;
using folded_std_view = FoldedValueBuilder<ViewOp>;
using folded_std_zero_extendi = FoldedValueBuilder<ZeroExtendIOp>;
using folded_std_sign_extendi = FoldedValueBuilder<SignExtendIOp>;
using folded_tensor_extract = FoldedValueBuilder<tensor::ExtractOp>;

View File

@ -18,7 +18,7 @@
//
// The other operations form the bridge between the opaque pointer and
// the actual storage of pointers, indices, and values. These operations
// resemble 'tensor_to_memref' in the sense that they map tensors to
// resemble 'buffer_cast' in the sense that they map tensors to
// their bufferized memrefs, but they lower into actual calls since
// sparse storage does not bufferize into a single memrefs, as dense
// tensors do, but into a hierarchical storage scheme where pointers
@ -74,9 +74,9 @@ def Linalg_SparseTensorToPointersMemRefOp :
let description = [{
Returns the pointers array of the sparse storage scheme at the
given dimension for the given tensor. This is similar to the
`tensor_to_memref` operation in the sense that it provides a bridge
`buffer_cast` operation in the sense that it provides a bridge
between a tensor world view and a bufferized world view. Unlike the
`tensor_to_memref` operation, however, this sparse operation actually
`buffer_cast` operation, however, this sparse operation actually
lowers into a call into a support library to obtain access to the
pointers array.
@ -98,9 +98,9 @@ def Linalg_SparseTensorToIndicesMemRefOp :
let description = [{
Returns the indices array of the sparse storage scheme at the
given dimension for the given tensor. This is similar to the
`tensor_to_memref` operation in the sense that it provides a bridge
`buffer_cast` operation in the sense that it provides a bridge
between a tensor world view and a bufferized world view. Unlike the
`tensor_to_memref` operation, however, this sparse operation actually
`buffer_cast` operation, however, this sparse operation actually
lowers into a call into a support library to obtain access to the
indices array.
@ -122,9 +122,9 @@ def Linalg_SparseTensorToValuesMemRefOp :
let description = [{
Returns the values array of the sparse storage scheme for the given
tensor, independent of the actual dimension. This is similar to the
`tensor_to_memref` operation in the sense that it provides a bridge
`buffer_cast` operation in the sense that it provides a bridge
between a tensor world view and a bufferized world view. Unlike the
`tensor_to_memref` operation, however, this sparse operation actually
`buffer_cast` operation, however, this sparse operation actually
lowers into a call into a support library to obtain access to the
values array.

View File

@ -34,11 +34,11 @@ createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();
/// Create a pass to convert Linalg operations to scf.for loops and
/// std.load/std.store accesses.
/// memref.load/memref.store accesses.
std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToLoopsPass();
/// Create a pass to convert Linalg operations to scf.parallel loops and
/// std.load/std.store accesses.
/// memref.load/memref.store accesses.
std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToParallelLoopsPass();
/// Create a pass to convert Linalg operations to affine.for loops and

View File

@ -19,7 +19,7 @@ def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
This pass only converts ops that operate on ranked tensors.
}];
let constructor = "mlir::createConvertElementwiseToLinalgPass()";
let dependentDialects = ["linalg::LinalgDialect"];
let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
}
def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
@ -70,13 +70,21 @@ def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
"interchange vector",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect", "AffineDialect"];
let dependentDialects = [
"linalg::LinalgDialect",
"scf::SCFDialect",
"AffineDialect"
];
}
def LinalgBufferize : Pass<"linalg-bufferize", "FuncOp"> {
let summary = "Bufferize the linalg dialect";
let constructor = "mlir::createLinalgBufferizePass()";
let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
let dependentDialects = [
"linalg::LinalgDialect",
"AffineDialect",
"memref::MemRefDialect"
];
}
def LinalgLowerToParallelLoops
@ -90,7 +98,12 @@ def LinalgLowerToParallelLoops
"interchange vector",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
let dependentDialects = [
"AffineDialect",
"linalg::LinalgDialect",
"memref::MemRefDialect",
"scf::SCFDialect"
];
}
def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
@ -109,7 +122,10 @@ def LinalgTiling : FunctionPass<"linalg-tile"> {
let summary = "Tile operations in the linalg dialect";
let constructor = "mlir::createLinalgTilingPass()";
let dependentDialects = [
"AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"
"AffineDialect",
"linalg::LinalgDialect",
"memref::MemRefDialect",
"scf::SCFDialect"
];
let options = [
ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
@ -127,7 +143,12 @@ def LinalgTilingToParallelLoops
"Test generation of dynamic promoted buffers",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
let dependentDialects = [
"AffineDialect",
"linalg::LinalgDialect",
"memref::MemRefDialect",
"scf::SCFDialect"
];
}
def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> {

View File

@ -147,8 +147,8 @@ LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);
/// dimension. If that is not possible, contains the dynamic size of the
/// subview. The call back should return the buffer to use.
using AllocBufferCallbackFn = std::function<Optional<Value>(
OpBuilder &b, SubViewOp subView, ArrayRef<Value> boundingSubViewSize,
OperationFolder *folder)>;
OpBuilder &b, memref::SubViewOp subView,
ArrayRef<Value> boundingSubViewSize, OperationFolder *folder)>;
/// Callback function type used to deallocate the buffers used to hold the
/// promoted subview.
@ -244,7 +244,7 @@ struct PromotionInfo {
Value partialLocalView;
};
Optional<PromotionInfo>
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView,
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
AllocBufferCallbackFn allocationFn,
OperationFolder *folder = nullptr);
@ -818,7 +818,7 @@ struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
/// Match and rewrite for the pattern:
/// ```
/// %alloc = ...
/// [optional] %view = std.view %alloc ...
/// [optional] %view = memref.view %alloc ...
/// %subView = subview %allocOrView ...
/// [optional] linalg.fill(%allocOrView, %cst) ...
/// ...
@ -828,7 +828,7 @@ struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
/// into
/// ```
/// [unchanged] %alloc = ...
/// [unchanged] [optional] %view = std.view %alloc ...
/// [unchanged] [optional] %view = memref.view %alloc ...
/// [unchanged] [unchanged] %subView = subview %allocOrView ...
/// ...
/// vector.transfer_read %in[...], %cst ...
@ -849,7 +849,7 @@ struct LinalgCopyVTRForwardingPattern
/// Match and rewrite for the pattern:
/// ```
/// %alloc = ...
/// [optional] %view = std.view %alloc ...
/// [optional] %view = memref.view %alloc ...
/// %subView = subview %allocOrView...
/// ...
/// vector.transfer_write %..., %allocOrView[...]
@ -858,7 +858,7 @@ struct LinalgCopyVTRForwardingPattern
/// into
/// ```
/// [unchanged] %alloc = ...
/// [unchanged] [optional] %view = std.view %alloc ...
/// [unchanged] [optional] %view = memref.view %alloc ...
/// [unchanged] %subView = subview %allocOrView...
/// ...
/// vector.transfer_write %..., %out[...]

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -21,7 +22,7 @@
#include "llvm/ADT/SetVector.h"
using mlir::edsc::intrinsics::AffineIndexedValue;
using mlir::edsc::intrinsics::StdIndexedValue;
using mlir::edsc::intrinsics::MemRefIndexedValue;
namespace mlir {
class AffineExpr;
@ -213,7 +214,7 @@ template <typename LoopTy>
struct GenerateLoopNest {
using IndexedValueTy =
typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
AffineIndexedValue, StdIndexedValue>::type;
AffineIndexedValue, MemRefIndexedValue>::type;
static void
doit(ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,37 @@
//===- Intrinsics.h - MLIR EDSC Intrinsics for MemRefOps --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MEMREF_EDSC_INTRINSICS_H_
#define MLIR_DIALECT_MEMREF_EDSC_INTRINSICS_H_
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/EDSC/Builders.h"
namespace mlir {
namespace edsc {
namespace intrinsics {
using memref_alloc = ValueBuilder<memref::AllocOp>;
using memref_alloca = ValueBuilder<memref::AllocaOp>;
using memref_cast = ValueBuilder<memref::CastOp>;
using memref_dealloc = OperationBuilder<memref::DeallocOp>;
using memref_dim = ValueBuilder<memref::DimOp>;
using memref_load = ValueBuilder<memref::LoadOp>;
using memref_store = OperationBuilder<memref::StoreOp>;
using memref_sub_view = ValueBuilder<memref::SubViewOp>;
using memref_tensor_load = ValueBuilder<memref::TensorLoadOp>;
using memref_tensor_store = OperationBuilder<memref::TensorStoreOp>;
using memref_view = ValueBuilder<memref::ViewOp>;
/// Provide an index notation around memref_load and memref_store.
using MemRefIndexedValue =
TemplatedIndexedValue<intrinsics::memref_load, intrinsics::memref_store>;
} // namespace intrinsics
} // namespace edsc
} // namespace mlir
#endif // MLIR_DIALECT_MEMREF_EDSC_INTRINSICS_H_

View File

@ -0,0 +1,2 @@
add_mlir_dialect(MemRefOps memref)
add_mlir_doc(MemRefOps -gen-dialect-doc MemRefOps Dialects/)

View File

@ -0,0 +1,239 @@
//===- MemRef.h - MemRef dialect --------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_
#define MLIR_DIALECT_MEMREF_IR_MEMREF_H_
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
raw_ostream &operator<<(raw_ostream &os, Range &range);
/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
OpBuilder &b, Location loc);
} // namespace mlir
//===----------------------------------------------------------------------===//
// MemRef Dialect
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/IR/MemRefOpsDialect.h.inc"
//===----------------------------------------------------------------------===//
// MemRef Dialect Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc"
namespace mlir {
namespace memref {
// DmaStartOp starts a non-blocking DMA operation that transfers data from a
// source memref to a destination memref. The source and destination memref need
// not be of the same dimensionality, but need to have the same elemental type.
// The operands include the source and destination memref's each followed by its
// indices, size of the data transfer in terms of the number of elements (of the
// elemental type of the memref), a tag memref with its indices, and optionally
// at the end, a stride and a number_of_elements_per_stride arguments. The tag
// location is used by a DmaWaitOp to check for completion. The indices of the
// source memref, destination memref, and the tag memref have the same
// restrictions as any load/store. The optional stride arguments should be of
// 'index' type, and specify a stride for the slower memory space (memory space
// with a lower memory space id), transferring chunks of
// number_of_elements_per_stride every stride until %num_elements are
// transferred. Either both or no stride arguments should be specified.
//
// For example, a DmaStartOp operation that transfers 256 elements of a memref
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
// 1 at indices [%k, %l], would be specified as follows:
//
// %num_elements = constant 256
// %idx = constant 0 : index
// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
// memref<40 x 128 x f32>, (d0) -> (d0), 0>,
// memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
// memref<1 x i32>, (d0) -> (d0), 2>
//
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
// transfer %num_elt_per_stride elements every %stride elements apart from
// memory space 0 until %num_elements are transferred.
//
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
// %num_elt_per_stride :
//
// TODO: add additional operands to allow source and destination striding, and
// multiple stride levels.
// TODO: Consider replacing src/dst memref indices with view memrefs.
class DmaStartOp
: public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
using Op::Op;
static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
ValueRange srcIndices, Value destMemRef,
ValueRange destIndices, Value numElements, Value tagMemRef,
ValueRange tagIndices, Value stride = nullptr,
Value elementsPerStride = nullptr);
// Returns the source MemRefType for this DMA operation.
Value getSrcMemRef() { return getOperand(0); }
// Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() {
return getSrcMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the source memref indices for this DMA operation.
operand_range getSrcIndices() {
return {(*this)->operand_begin() + 1,
(*this)->operand_begin() + 1 + getSrcMemRefRank()};
}
// Returns the destination MemRefType for this DMA operation.
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
// Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() {
return getDstMemRef().getType().cast<MemRefType>().getRank();
}
unsigned getSrcMemorySpace() {
return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
unsigned getDstMemorySpace() {
return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
// Returns the destination memref indices for this DMA operation.
operand_range getDstIndices() {
return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
getDstMemRefRank()};
}
// Returns the number of elements being transferred by this DMA operation.
Value getNumElements() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
}
// Returns the Tag MemRef for this DMA operation.
Value getTagMemRef() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
}
// Returns the rank (number of indices) of the tag MemRefType.
unsigned getTagMemRefRank() {
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() {
unsigned tagIndexStartPos =
1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
return {(*this)->operand_begin() + tagIndexStartPos,
(*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
}
/// Returns true if this is a DMA from a faster memory space to a slower one.
bool isDestMemorySpaceFaster() {
return (getSrcMemorySpace() < getDstMemorySpace());
}
/// Returns true if this is a DMA from a slower memory space to a faster one.
bool isSrcMemorySpaceFaster() {
// Assumes that a lower number is for a slower memory space.
return (getDstMemorySpace() < getSrcMemorySpace());
}
/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy. Asserts failure if neither is true.
unsigned getFasterMemPos() {
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
}
static StringRef getOperationName() { return "memref.dma_start"; }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verify();
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
bool isStrided() {
return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
1 + 1 + getTagMemRefRank();
}
Value getStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1 - 1);
}
Value getNumElementsPerStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1);
}
};
// DmaWaitOp blocks until the completion of a DMA operation associated with the
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
// with the same restrictions as any load/store index. %num_elements is the
// number of elements associated with the DMA operation. For example:
//
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
// memref<2048 x f32>, (d0) -> (d0), 0>,
// memref<256 x f32>, (d0) -> (d0), 1>
// memref<1 x i32>, (d0) -> (d0), 2>
// ...
// ...
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
//
class DmaWaitOp
: public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
using Op::Op;
static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
ValueRange tagIndices, Value numElements);
static StringRef getOperationName() { return "memref.dma_wait"; }
// Returns the Tag MemRef associated with the DMA operation being waited on.
Value getTagMemRef() { return getOperand(0); }
// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() {
return {(*this)->operand_begin() + 1,
(*this)->operand_begin() + 1 + getTagMemRefRank()};
}
// Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() {
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the number of elements transferred in the associated DMA operation.
Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
LogicalResult verify();
};
} // namespace memref
} // namespace mlir
#endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_

View File

@ -0,0 +1,25 @@
//===- MemRefBase.td - Base definitions for memref dialect -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MEMREF_BASE
#define MEMREF_BASE
include "mlir/IR/OpBase.td"
def MemRef_Dialect : Dialect {
let name = "memref";
let cppNamespace = "::mlir::memref";
let description = [{
The `memref` dialect is intended to hold core memref creation and
manipulation ops, which are not strongly associated with any particular
other dialect or domain abstraction.
}];
let hasConstantMaterializer = 1;
}
#endif // MEMREF_BASE

File diff suppressed because it is too large

View File

@ -14,6 +14,7 @@ include "mlir/Pass/PassBase.td"
def SCFBufferize : FunctionPass<"scf-bufferize"> {
let summary = "Bufferize the scf dialect.";
let constructor = "mlir::createSCFBufferizePass()";
let dependentDialects = ["memref::MemRefDialect"];
}
def SCFForLoopSpecialization

View File

@ -25,5 +25,6 @@ def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
def ShapeBufferize : FunctionPass<"shape-bufferize"> {
let summary = "Bufferize the shape dialect.";
let constructor = "mlir::createShapeBufferizePass()";
let dependentDialects = ["memref::MemRefDialect"];
}
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES

View File

@ -17,35 +17,24 @@ namespace intrinsics {
using std_addi = ValueBuilder<AddIOp>;
using std_addf = ValueBuilder<AddFOp>;
using std_alloc = ValueBuilder<AllocOp>;
using std_alloca = ValueBuilder<AllocaOp>;
using std_call = OperationBuilder<CallOp>;
using std_constant = ValueBuilder<ConstantOp>;
using std_constant_float = ValueBuilder<ConstantFloatOp>;
using std_constant_index = ValueBuilder<ConstantIndexOp>;
using std_constant_int = ValueBuilder<ConstantIntOp>;
using std_dealloc = OperationBuilder<DeallocOp>;
using std_divis = ValueBuilder<SignedDivIOp>;
using std_diviu = ValueBuilder<UnsignedDivIOp>;
using std_dim = ValueBuilder<DimOp>;
using std_fpext = ValueBuilder<FPExtOp>;
using std_fptrunc = ValueBuilder<FPTruncOp>;
using std_index_cast = ValueBuilder<IndexCastOp>;
using std_muli = ValueBuilder<MulIOp>;
using std_mulf = ValueBuilder<MulFOp>;
using std_memref_cast = ValueBuilder<MemRefCastOp>;
using std_ret = OperationBuilder<ReturnOp>;
using std_select = ValueBuilder<SelectOp>;
using std_load = ValueBuilder<LoadOp>;
using std_sign_extendi = ValueBuilder<SignExtendIOp>;
using std_splat = ValueBuilder<SplatOp>;
using std_store = OperationBuilder<StoreOp>;
using std_subf = ValueBuilder<SubFOp>;
using std_subi = ValueBuilder<SubIOp>;
using std_sub_view = ValueBuilder<SubViewOp>;
using std_tensor_load = ValueBuilder<TensorLoadOp>;
using std_tensor_store = OperationBuilder<TensorStoreOp>;
using std_view = ValueBuilder<ViewOp>;
using std_zero_extendi = ValueBuilder<ZeroExtendIOp>;
using tensor_extract = ValueBuilder<tensor::ExtractOp>;
@ -77,10 +66,6 @@ BranchOp std_br(Block *block, ValueRange operands);
/// or to `falseBranch` and `falseOperand` if `cond` evaluates to `false`.
CondBranchOp std_cond_br(Value cond, Block *trueBranch, ValueRange trueOperands,
Block *falseBranch, ValueRange falseOperands);
/// Provide an index notation around sdt_load and std_store.
using StdIndexedValue =
TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
} // namespace intrinsics
} // namespace edsc
} // namespace mlir

View File

@ -34,8 +34,6 @@ class Builder;
class FuncOp;
class OpBuilder;
raw_ostream &operator<<(raw_ostream &os, Range &range);
/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
@ -110,200 +108,6 @@ public:
static bool classof(Operation *op);
};
// DmaStartOp starts a non-blocking DMA operation that transfers data from a
// source memref to a destination memref. The source and destination memref need
// not be of the same dimensionality, but need to have the same elemental type.
// The operands include the source and destination memref's each followed by its
// indices, size of the data transfer in terms of the number of elements (of the
// elemental type of the memref), a tag memref with its indices, and optionally
// at the end, a stride and a number_of_elements_per_stride arguments. The tag
// location is used by a DmaWaitOp to check for completion. The indices of the
// source memref, destination memref, and the tag memref have the same
// restrictions as any load/store. The optional stride arguments should be of
// 'index' type, and specify a stride for the slower memory space (memory space
// with a lower memory space id), transferring chunks of
// number_of_elements_per_stride every stride until %num_elements are
// transferred. Either both or no stride arguments should be specified.
//
// For example, a DmaStartOp operation that transfers 256 elements of a memref
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
// 1 at indices [%k, %l], would be specified as follows:
//
// %num_elements = constant 256
// %idx = constant 0 : index
// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
// memref<40 x 128 x f32>, (d0) -> (d0), 0>,
// memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
// memref<1 x i32>, (d0) -> (d0), 2>
//
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
// transfer %num_elt_per_stride elements every %stride elements apart from
// memory space 0 until %num_elements are transferred.
//
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
// %num_elt_per_stride :
//
// TODO: add additional operands to allow source and destination striding, and
// multiple stride levels.
// TODO: Consider replacing src/dst memref indices with view memrefs.
class DmaStartOp
: public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
using Op::Op;
static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
ValueRange srcIndices, Value destMemRef,
ValueRange destIndices, Value numElements, Value tagMemRef,
ValueRange tagIndices, Value stride = nullptr,
Value elementsPerStride = nullptr);
// Returns the source MemRefType for this DMA operation.
Value getSrcMemRef() { return getOperand(0); }
// Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() {
return getSrcMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the source memref indices for this DMA operation.
operand_range getSrcIndices() {
return {(*this)->operand_begin() + 1,
(*this)->operand_begin() + 1 + getSrcMemRefRank()};
}
// Returns the destination MemRefType for this DMA operations.
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
// Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() {
return getDstMemRef().getType().cast<MemRefType>().getRank();
}
unsigned getSrcMemorySpace() {
return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
unsigned getDstMemorySpace() {
return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
}
// Returns the destination memref indices for this DMA operation.
operand_range getDstIndices() {
return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
getDstMemRefRank()};
}
// Returns the number of elements being transferred by this DMA operation.
Value getNumElements() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
}
// Returns the Tag MemRef for this DMA operation.
Value getTagMemRef() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
}
// Returns the rank (number of indices) of the tag MemRefType.
unsigned getTagMemRefRank() {
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() {
unsigned tagIndexStartPos =
1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
return {(*this)->operand_begin() + tagIndexStartPos,
(*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
}
/// Returns true if this is a DMA from a faster memory space to a slower one.
bool isDestMemorySpaceFaster() {
return (getSrcMemorySpace() < getDstMemorySpace());
}
/// Returns true if this is a DMA from a slower memory space to a faster one.
bool isSrcMemorySpaceFaster() {
// Assumes that a lower number is for a slower memory space.
return (getDstMemorySpace() < getSrcMemorySpace());
}
/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy. Asserts failure if neither is true.
unsigned getFasterMemPos() {
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
}
static StringRef getOperationName() { return "std.dma_start"; }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verify();
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
bool isStrided() {
return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
1 + 1 + getTagMemRefRank();
}
Value getStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1 - 1);
}
Value getNumElementsPerStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1);
}
};
// DmaWaitOp blocks until the completion of a DMA operation associated with the
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
// with the same restrictions as any load/store index. %num_elements is the
// number of elements associated with the DMA operation. For example:
//
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
// memref<2048 x f32>, (d0) -> (d0), 0>,
// memref<256 x f32>, (d0) -> (d0), 1>
// memref<1 x i32>, (d0) -> (d0), 2>
// ...
// ...
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
//
class DmaWaitOp
: public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
using Op::Op;
static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
ValueRange tagIndices, Value numElements);
static StringRef getOperationName() { return "std.dma_wait"; }
// Returns the Tag MemRef associated with the DMA operation being waited on.
Value getTagMemRef() { return getOperand(0); }
// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() {
return {(*this)->operand_begin() + 1,
(*this)->operand_begin() + 1 + getTagMemRefRank()};
}
// Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() {
return getTagMemRef().getType().cast<MemRefType>().getRank();
}
// Returns the number of elements transferred in the associated DMA operation.
Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
LogicalResult verify();
};
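As these DMA declarations leave the `std` dialect in this split, the ops keep their shape but pick up the `memref.` prefix. A minimal hand-written sketch of the new spelling (shapes, memory spaces, and the tag buffer below are illustrative, not taken from this patch):

```mlir
%c0  = constant 0 : index
%num = constant 256 : index
%tag = memref.alloc() : memref<1xi32>
// Non-blocking transfer from %src (space 0) to %dst (space 1), completion
// tracked through %tag.
memref.dma_start %src[%i], %dst[%j], %num, %tag[%c0]
  : memref<1024xf32>, memref<1024xf32, 1>, memref<1xi32>
// Block until the transfer associated with %tag[%c0] has completed.
memref.dma_wait %tag[%c0], %num : memref<1xi32>
```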
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
@ -316,45 +120,6 @@ llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
ArrayRef<int64_t> reducedShape);
/// Determines whether MemRefCastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref_cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref_cast operations. Such foldable memref_cast
/// operations are typically inserted as `view` and `subview` ops and are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
/// %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
/// %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
/// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// to memref<?x?xf32>
/// consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool canFoldIntoConsumerOp(MemRefCastOp castOp);
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,

File diff suppressed because it is too large

View File

@ -37,7 +37,7 @@ std::unique_ptr<Pass> createTensorConstantBufferizePass();
/// Creates an instance of the StdExpand pass that legalizes Std
/// dialect ops to be convertible to LLVM. For example,
/// `std.ceildivi_signed` gets transformed to a number of std operations,
/// which can be lowered to LLVM; `memref_reshape` gets converted to
/// which can be lowered to LLVM; `memref.reshape` gets converted to
/// `memref_reinterpret_cast`.
std::unique_ptr<Pass> createStdExpandOpsPass();

View File

@ -44,9 +44,10 @@ def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
implement the `ReturnLike` trait are not rewritten in general, as they
require that the corresponding parent operation is also rewritten.
Finally, this pass fails for unknown terminators, as we cannot decide
whether they need rewriting.
whether they need rewriting.
}];
let constructor = "mlir::createFuncBufferizePass()";
let dependentDialects = ["memref::MemRefDialect"];
}
def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
@ -54,12 +55,13 @@ def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
let description = [{
This pass bufferizes tensor constants.
This pass needs to be a module pass because it inserts std.global_memref
This pass needs to be a module pass because it inserts memref.global
ops into the module, which cannot be done safely from a function pass due to
multi-threading. Most other bufferization passes can run in parallel at
function granularity.
}];
let constructor = "mlir::createTensorConstantBufferizePass()";
let dependentDialects = ["memref::MemRefDialect"];
}
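For illustration, a hedged sketch of the kind of IR this pass produces after the rename; the symbol name, shapes, and the `tensor_load` materialization are invented for the example and are not part of this patch:

```mlir
// The module-level global created for the constant data.
memref.global "private" constant @__constant_2xf32 : memref<2xf32> = dense<[1.0, 2.0]>

func @forward() -> tensor<2xf32> {
  // Reference the global and materialize it back as a tensor for
  // not-yet-bufferized uses.
  %0 = memref.get_global @__constant_2xf32 : memref<2xf32>
  %1 = memref.tensor_load %0 : memref<2xf32>
  return %1 : tensor<2xf32>
}
```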
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES

View File

@ -16,6 +16,9 @@
#ifndef MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H
#define MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
namespace mlir {
@ -27,6 +30,51 @@ class OpBuilder;
/// constructing the necessary DimOp operators.
SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b);
/// Matches a ConstantIndexOp.
detail::op_matcher<ConstantIndexOp> matchConstantIndex();
/// Detects the `values` produced by a ConstantIndexOp and places the new
/// constant in place of the corresponding sentinel value.
void canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> &values,
function_ref<bool(int64_t)> isDynamic);
void getPositionsOfShapeOne(unsigned rank, ArrayRef<int64_t> shape,
llvm::SmallDenseSet<unsigned> &dimsToProject);
/// Pattern to rewrite a subview op with constant arguments.
template <typename OpType, typename CastOpFunc>
class OpWithOffsetSizesAndStridesConstantArgumentFolder final
: public OpRewritePattern<OpType> {
public:
using OpRewritePattern<OpType>::OpRewritePattern;
LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override {
// No constant operand, just return.
if (llvm::none_of(op.getOperands(), [](Value operand) {
return matchPattern(operand, matchConstantIndex());
}))
return failure();
// At least one of offsets/sizes/strides is a new constant.
// Form the new list of operands and constant attributes from the existing.
SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
// Create the new op in canonical form.
auto newOp = rewriter.create<OpType>(op.getLoc(), op.source(), mixedOffsets,
mixedSizes, mixedStrides);
CastOpFunc func;
func(rewriter, op, newOp);
return success();
}
};
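As a rough illustration of what this folder does (the types and layout maps below are hand-written for the example, not taken from this patch): a subview whose offsets are plain `constant` ops is rewritten to carry them as static attributes, and the `CastOpFunc` hook is expected to reconcile the now more-static result type with the old uses, e.g. via a `memref.cast`:

```mlir
#map_dyn    = affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>
#map_static = affine_map<(d0, d1) -> (d0 * 16 + d1)>

// Before canonicalization: offsets are SSA constants.
%c0 = constant 0 : index
%0  = memref.subview %arg[%c0, %c0] [4, 4] [1, 1]
        : memref<8x16xf32> to memref<4x4xf32, #map_dyn>

// After: offsets become attributes; a cast preserves the original type.
%1  = memref.subview %arg[0, 0] [4, 4] [1, 1]
        : memref<8x16xf32> to memref<4x4xf32, #map_static>
%2  = memref.cast %1 : memref<4x4xf32, #map_static> to memref<4x4xf32, #map_dyn>
```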
} // end namespace mlir
#endif // MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H

View File

@ -180,11 +180,11 @@ private:
/// ```
/// %1:3 = scf.if (%inBounds) {
/// // fastpath, direct cast
/// memref_cast %A: memref<A...> to compatibleMemRefType
/// memref.cast %A: memref<A...> to compatibleMemRefType
/// scf.yield %view : compatibleMemRefType, index, index
/// } else {
/// // slowpath, masked vector.transfer or linalg.copy.
/// memref_cast %alloc: memref<B...> to compatibleMemRefType
/// memref.cast %alloc: memref<B...> to compatibleMemRefType
/// scf.yield %4 : compatibleMemRefType, index, index
// }
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}

View File

@ -1133,7 +1133,7 @@ public:
/// A trait of region holding operations that define a new scope for automatic
/// allocations, i.e., allocations that are freed when control is transferred
/// back from the operation's region. Any operations performing such allocations
/// (for eg. std.alloca) will have their allocations automatically freed at
/// (for eg. memref.alloca) will have their allocations automatically freed at
/// their closest enclosing operation with this trait.
template <typename ConcreteType>
class AutomaticAllocationScope

View File

@ -28,6 +28,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
@ -60,6 +61,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
LLVM::LLVMArmSVEDialect,
linalg::LinalgDialect,
math::MathDialect,
memref::MemRefDialect,
scf::SCFDialect,
omp::OpenMPDialect,
pdl::PDLDialect,

View File

@ -54,7 +54,7 @@ void populateBufferizeMaterializationLegality(ConversionTarget &target);
/// Populate patterns to eliminate bufferize materializations.
///
/// In particular, these are the tensor_load/tensor_to_memref ops.
/// In particular, these are the tensor_load/buffer_cast ops.
void populateEliminateBufferizeMaterializationsPatterns(
MLIRContext *context, BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns);

View File

@ -54,7 +54,7 @@ std::unique_ptr<Pass>
createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
/// Creates a pass that finalizes a partial bufferization by removing remaining
/// tensor_load and tensor_to_memref operations.
/// tensor_load and buffer_cast operations.
std::unique_ptr<FunctionPass> createFinalizingBufferizePass();
/// Creates a pass that converts memref function results to out-params.

View File

@ -352,7 +352,7 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
works for static shaped memrefs.
}];
let constructor = "mlir::createBufferResultsToOutParamsPass()";
let dependentDialects = ["linalg::LinalgDialect"];
let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
}
def Canonicalizer : Pass<"canonicalize"> {
@ -363,6 +363,7 @@ def Canonicalizer : Pass<"canonicalize"> {
details.
}];
let constructor = "mlir::createCanonicalizerPass()";
let dependentDialects = ["memref::MemRefDialect"];
}
def CopyRemoval : FunctionPass<"copy-removal"> {
@ -406,11 +407,11 @@ def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> {
let summary = "Finalize a partial bufferization";
let description = [{
A bufferize pass that finalizes a partial bufferization by removing
remaining `tensor_load` and `tensor_to_memref` operations.
remaining `memref.tensor_load` and `memref.buffer_cast` operations.
The removal of those operations is only possible if the operations only
exist in pairs, i.e., all uses of `tensor_load` operations are
`tensor_to_memref` operations.
exist in pairs, i.e., all uses of `memref.tensor_load` operations are
`memref.buffer_cast` operations.
This pass will fail if not all operations can be removed or if any operation
with tensor typed operands remains.
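A hedged before/after sketch of the kind of pair this pass removes (the surrounding ops and names are invented for the example):

```mlir
// Leftover materialization pair after a partial bufferization pass:
%t  = memref.tensor_load %m : memref<4xf32>
%m2 = memref.buffer_cast %t : memref<4xf32>
%v  = memref.load %m2[%c0] : memref<4xf32>

// After finalizing-bufferize, the pair folds away and %m is used directly:
%v  = memref.load %m[%c0] : memref<4xf32>
```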
@ -535,7 +536,7 @@ def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
contained in the op. Operations marked with the [MemRefsNormalizable]
(https://mlir.llvm.org/docs/Traits/#memrefsnormalizable) trait are
expected to be normalizable. Supported operations include affine
operations, std.alloc, std.dealloc, and std.return.
operations, memref.alloc, memref.dealloc, and std.return.
Given an appropriate layout map specified in the code, this transformation
can express tiled or linearized access to multi-dimensional data
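For instance, a hand-written sketch in the spirit of the pass documentation (not code from this patch): a tiled layout on a `memref.alloc` is rewritten to an identity layout, with every access remapped through the old layout map.

```mlir
#tile = affine_map<(d0) -> (d0 floordiv 4, d0 mod 4)>

// Before normalization: non-identity layout on the allocation.
%0 = memref.alloc() : memref<16xf32, #tile>
%v = affine.load %0[%i] : memref<16xf32, #tile>

// After -normalize-memrefs: identity layout, indices remapped.
%0 = memref.alloc() : memref<4x4xf32>
%v = affine.load %0[%i floordiv 4, %i mod 4] : memref<4x4xf32>
```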

View File

@ -28,6 +28,10 @@ class AffineForOp;
class Location;
class OpBuilder;
namespace memref {
class AllocOp;
} // end namespace memref
/// Replaces all "dereferencing" uses of `oldMemRef` with `newMemRef` while
/// optionally remapping the old memref's indices using the supplied affine map,
/// `indexRemap`. The new memref could be of a different shape or rank.
@ -88,7 +92,7 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
/// Rewrites the memref defined by this alloc op to have an identity layout map
/// and updates all its indexing uses. Returns failure if any of its uses
/// escape (while leaving the IR in a valid state).
LogicalResult normalizeMemRef(AllocOp op);
LogicalResult normalizeMemRef(memref::AllocOp *op);
/// Uses the old memref type map layout and computes the new memref type to have
/// a new shape and a layout map, where the old layout map has been normalized

View File

@ -15,6 +15,7 @@
#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
@ -44,7 +45,8 @@ public:
: builder(builder), dimValues(dimValues), symbolValues(symbolValues),
loc(loc) {}
template <typename OpTy> Value buildBinaryExpr(AffineBinaryOpExpr expr) {
template <typename OpTy>
Value buildBinaryExpr(AffineBinaryOpExpr expr) {
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
if (!lhs || !rhs)
@ -563,8 +565,8 @@ public:
};
/// Apply the affine map from an 'affine.load' operation to its operands, and
/// feed the results to a newly created 'std.load' operation (which replaces the
/// original 'affine.load').
/// feed the results to a newly created 'memref.load' operation (which replaces
/// the original 'affine.load').
class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
public:
using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
@ -579,14 +581,14 @@ public:
return failure();
// Build vector.load memref[expandedMap.results].
rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, op.getMemRef(),
*resultOperands);
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, op.getMemRef(),
*resultOperands);
return success();
}
};
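A small hand-written example of this rewrite (the index expression is illustrative): the affine map is expanded into index arithmetic and the access itself becomes a `memref.load`.

```mlir
// Input:
%v = affine.load %m[%i + 1] : memref<10xf32>

// After -lower-affine (roughly):
%c1 = constant 1 : index
%0  = addi %i, %c1 : index
%v  = memref.load %m[%0] : memref<10xf32>
```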
/// Apply the affine map from an 'affine.prefetch' operation to its operands,
/// and feed the results to a newly created 'std.prefetch' operation (which
/// and feed the results to a newly created 'memref.prefetch' operation (which
/// replaces the original 'affine.prefetch').
class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
public:
@ -601,16 +603,16 @@ public:
if (!resultOperands)
return failure();
// Build std.prefetch memref[expandedMap.results].
rewriter.replaceOpWithNewOp<PrefetchOp>(op, op.memref(), *resultOperands,
op.isWrite(), op.localityHint(),
op.isDataCache());
// Build memref.prefetch memref[expandedMap.results].
rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
op, op.memref(), *resultOperands, op.isWrite(), op.localityHint(),
op.isDataCache());
return success();
}
};
/// Apply the affine map from an 'affine.store' operation to its operands, and
/// feed the results to a newly created 'std.store' operation (which replaces
/// feed the results to a newly created 'memref.store' operation (which replaces
/// the original 'affine.store').
class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
public:
@ -625,8 +627,8 @@ public:
if (!maybeExpandedMap)
return failure();
// Build std.store valueToStore, memref[expandedMap.results].
rewriter.replaceOpWithNewOp<mlir::StoreOp>(
// Build memref.store valueToStore, memref[expandedMap.results].
rewriter.replaceOpWithNewOp<memref::StoreOp>(
op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
return success();
}
@ -634,7 +636,8 @@ public:
/// Apply the affine maps from an 'affine.dma_start' operation to each of their
/// respective map operands, and feed the results to a newly created
/// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
/// 'memref.dma_start' operation (which replaces the original
/// 'affine.dma_start').
class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
public:
using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
@ -663,8 +666,8 @@ public:
if (!maybeExpandedTagMap)
return failure();
// Build std.dma_start operation with affine map results.
rewriter.replaceOpWithNewOp<DmaStartOp>(
// Build memref.dma_start operation with affine map results.
rewriter.replaceOpWithNewOp<memref::DmaStartOp>(
op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
*maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
*maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
@ -673,7 +676,7 @@ public:
};
/// Apply the affine map from an 'affine.dma_wait' operation tag memref,
/// and feed the results to a newly created 'std.dma_wait' operation (which
/// and feed the results to a newly created 'memref.dma_wait' operation (which
/// replaces the original 'affine.dma_wait').
class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
public:
@ -688,8 +691,8 @@ public:
if (!maybeExpandedTagMap)
return failure();
// Build std.dma_wait operation with affine map results.
rewriter.replaceOpWithNewOp<DmaWaitOp>(
// Build memref.dma_wait operation with affine map results.
rewriter.replaceOpWithNewOp<memref::DmaWaitOp>(
op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
return success();
}
@ -777,8 +780,8 @@ class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
populateAffineToStdConversionPatterns(patterns, &getContext());
populateAffineToVectorConversionPatterns(patterns, &getContext());
ConversionTarget target(getContext());
target
.addLegalDialect<scf::SCFDialect, StandardOpsDialect, VectorDialect>();
target.addLegalDialect<memref::MemRefDialect, scf::SCFDialect,
StandardOpsDialect, VectorDialect>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();

View File

@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRAffineToStandard
LINK_LIBS PUBLIC
MLIRAffine
MLIRMemRef
MLIRSCF
MLIRPass
MLIRStandard

View File

@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToNVVMTransforms
MLIRGPU
MLIRGPUToGPURuntimeTransforms
MLIRLLVMIR
MLIRMemRef
MLIRNVVMIR
MLIRPass
MLIRStandardToLLVM

View File

@ -18,6 +18,7 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

View File

@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRLinalgToStandard
MLIREDSC
MLIRIR
MLIRLinalg
MLIRMemRef
MLIRPass
MLIRSCF
MLIRTransforms

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -93,7 +94,7 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
continue;
}
Value cast =
b.create<MemRefCastOp>(loc, eraseStridedLayout(memrefType), op);
b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op);
res.push_back(cast);
}
return res;
@ -143,12 +144,12 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
// If either inputPerm or outputPerm are non-identities, insert transposes.
auto inputPerm = op.inputPermutation();
if (inputPerm.hasValue() && !inputPerm->isIdentity())
in = rewriter.create<TransposeOp>(op.getLoc(), in,
AffineMapAttr::get(*inputPerm));
in = rewriter.create<memref::TransposeOp>(op.getLoc(), in,
AffineMapAttr::get(*inputPerm));
auto outputPerm = op.outputPermutation();
if (outputPerm.hasValue() && !outputPerm->isIdentity())
out = rewriter.create<TransposeOp>(op.getLoc(), out,
AffineMapAttr::get(*outputPerm));
out = rewriter.create<memref::TransposeOp>(op.getLoc(), out,
AffineMapAttr::get(*outputPerm));
// If nothing was transposed, fail and let the conversion kick in.
if (in == op.input() && out == op.output())
@ -213,7 +214,8 @@ struct ConvertLinalgToStandardPass
void ConvertLinalgToStandardPass::runOnOperation() {
auto module = getOperation();
ConversionTarget target(getContext());
target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>();
target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
StandardOpsDialect>();
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
OwningRewritePatternList patterns;

View File

@ -38,6 +38,10 @@ namespace NVVM {
class NVVMDialect;
} // end namespace NVVM
namespace memref {
class MemRefDialect;
} // end namespace memref
namespace omp {
class OpenMPDialect;
} // end namespace omp

View File

@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRSCFToGPU
MLIRGPU
MLIRIR
MLIRLinalg
MLIRMemRef
MLIRPass
MLIRStandard
MLIRSupport

View File

@ -18,6 +18,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/ParallelLoopMapper.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExpr.h"
@ -647,6 +648,7 @@ void mlir::populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
}
void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) {
target.addLegalDialect<memref::MemRefDialect>();
target.addDynamicallyLegalOp<scf::ParallelOp>([](scf::ParallelOp parallelOp) {
return !parallelOp->getAttr(gpu::getMappingAttrName());
});

View File

@ -19,6 +19,7 @@ add_mlir_conversion_library(MLIRShapeToStandard
LINK_LIBS PUBLIC
MLIREDSC
MLIRIR
MLIRMemRef
MLIRShape
MLIRTensor
MLIRPass

View File

@ -9,6 +9,7 @@
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "../PassDetail.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -139,7 +140,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
// dimension in the tensor.
SmallVector<Value> ranks, rankDiffs;
llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
return lb.create<DimOp>(v, zero);
return lb.create<memref::DimOp>(v, zero);
}));
// Find the maximum rank
@ -252,7 +253,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
// dimension in the tensor.
SmallVector<Value> ranks, rankDiffs;
llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
return lb.create<DimOp>(v, zero);
return lb.create<memref::DimOp>(v, zero);
}));
// Find the maximum rank
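As an aside on the `memref.dim` calls above: like `std.dim` before the split, the op still accepts both memref and tensor operands at this point, which is what lets the shape lowering query the extent-tensor rank directly. A minimal sketch (names invented):

```mlir
%c0   = constant 0 : index
%rank = memref.dim %shape, %c0 : tensor<?xindex>
```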
@ -344,8 +345,8 @@ LogicalResult GetExtentOpConverter::matchAndRewrite(
// circumvents the necessity to materialize the shape in memory.
if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
if (shapeOfOp.arg().getType().isa<ShapedType>()) {
rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
transformed.dim());
rewriter.replaceOpWithNewOp<memref::DimOp>(op, shapeOfOp.arg(),
transformed.dim());
return success();
}
}
@ -375,7 +376,7 @@ RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
return failure();
shape::RankOp::Adaptor transformed(operands);
rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
rewriter.replaceOpWithNewOp<memref::DimOp>(op, transformed.shape(), 0);
return success();
}
@ -404,7 +405,8 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
Type indexTy = rewriter.getIndexType();
Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
Value rank =
rewriter.create<memref::DimOp>(loc, indexTy, transformed.shape(), zero);
auto loop = rewriter.create<scf::ForOp>(
loc, zero, rank, one, op.initVals(),
@ -490,11 +492,12 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
Type indexTy = rewriter.getIndexType();
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value firstShape = transformed.shapes().front();
Value firstRank = rewriter.create<DimOp>(loc, indexTy, firstShape, zero);
Value firstRank =
rewriter.create<memref::DimOp>(loc, indexTy, firstShape, zero);
Value result = nullptr;
// Generate a linear sequence of compares, all with firstShape as lhs.
for (Value shape : transformed.shapes().drop_front(1)) {
Value rank = rewriter.create<DimOp>(loc, indexTy, shape, zero);
Value rank = rewriter.create<memref::DimOp>(loc, indexTy, shape, zero);
Value eqRank =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
auto same = rewriter.create<IfOp>(
@ -559,7 +562,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
int64_t rank = rankedTensorTy.getRank();
for (int64_t i = 0; i < rank; i++) {
if (rankedTensorTy.isDynamicDim(i)) {
Value extent = rewriter.create<DimOp>(loc, tensor, i);
Value extent = rewriter.create<memref::DimOp>(loc, tensor, i);
extentValues.push_back(extent);
} else {
Value extent =
@ -583,7 +586,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
op, getExtentTensorType(ctx), ValueRange{rank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value dim = args.front();
Value extent = b.create<DimOp>(loc, tensor, dim);
Value extent = b.create<memref::DimOp>(loc, tensor, dim);
b.create<tensor::YieldOp>(loc, extent);
});
@ -613,7 +616,7 @@ LogicalResult SplitAtOpConversion::matchAndRewrite(
SplitAtOp::Adaptor transformed(op);
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value zero = b.create<ConstantIndexOp>(0);
Value rank = b.create<DimOp>(transformed.operand(), zero);
Value rank = b.create<memref::DimOp>(transformed.operand(), zero);
// index < 0 ? index + rank : index
Value originalIndex = transformed.index();
@ -670,8 +673,8 @@ void ConvertShapeToStandardPass::runOnOperation() {
// Setup target legality.
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target
.addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>();
target.addLegalDialect<memref::MemRefDialect, StandardOpsDialect, SCFDialect,
tensor::TensorDialect>();
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
// Setup conversion patterns.

View File

@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRStandardToLLVM
LINK_LIBS PUBLIC
MLIRLLVMIR
MLIRMath
MLIRMemRef
MLIRTransforms
)

View File

@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@ -1864,13 +1865,13 @@ private:
struct AllocOpLowering : public AllocLikeOpLowering {
AllocOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
: AllocLikeOpLowering(memref::AllocOp::getOperationName(), converter) {}
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
Location loc, Value sizeBytes,
Operation *op) const override {
// Heap allocations.
AllocOp allocOp = cast<AllocOp>(op);
memref::AllocOp allocOp = cast<memref::AllocOp>(op);
MemRefType memRefType = allocOp.getType();
Value alignment;
@ -1917,7 +1918,7 @@ struct AllocOpLowering : public AllocLikeOpLowering {
struct AlignedAllocOpLowering : public AllocLikeOpLowering {
AlignedAllocOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
: AllocLikeOpLowering(memref::AllocOp::getOperationName(), converter) {}
/// Returns the memref's element size in bytes.
// TODO: there are other places where this is used. Expose publicly?
@ -1950,7 +1951,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
/// Returns the alignment to be used for the allocation call itself.
/// aligned_alloc requires the alignment to be a power of two, and the
/// allocation size to be a multiple of alignment,
int64_t getAllocationAlignment(AllocOp allocOp) const {
int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
if (Optional<uint64_t> alignment = allocOp.alignment())
return *alignment;
@ -1966,7 +1967,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
Location loc, Value sizeBytes,
Operation *op) const override {
// Heap allocations.
AllocOp allocOp = cast<AllocOp>(op);
memref::AllocOp allocOp = cast<memref::AllocOp>(op);
MemRefType memRefType = allocOp.getType();
int64_t alignment = getAllocationAlignment(allocOp);
Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
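For reference, the alignment consumed here comes from the op's optional `alignment` attribute; a small hand-written sketch (sizes and values invented), where the second form may be lowered through `aligned_alloc` depending on the conversion options:

```mlir
%0 = memref.alloc() : memref<4096xf32>                         // plain malloc-based lowering
%1 = memref.alloc() {alignment = 64 : i64} : memref<4096xf32>  // may use aligned_alloc
```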
@ -1997,7 +1998,7 @@ constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
struct AllocaOpLowering : public AllocLikeOpLowering {
AllocaOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLowering(AllocaOp::getOperationName(), converter) {}
: AllocLikeOpLowering(memref::AllocaOp::getOperationName(), converter) {}
/// Allocates the underlying buffer using the right call. `allocatedBytePtr`
/// is set to null for stack allocations. `accessAlignment` is set if
@ -2008,7 +2009,7 @@ struct AllocaOpLowering : public AllocLikeOpLowering {
// With alloca, one gets a pointer to the element type right away.
// For stack allocations.
auto allocaOp = cast<AllocaOp>(op);
auto allocaOp = cast<memref::AllocaOp>(op);
auto elementPtrType = this->getElementPtrType(allocaOp.getType());
auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
@ -2180,17 +2181,17 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
explicit DeallocOpLowering(LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern<DeallocOp>(converter) {}
: ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
LogicalResult
matchAndRewrite(DeallocOp op, ArrayRef<Value> operands,
matchAndRewrite(memref::DeallocOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 1 && "dealloc takes one operand");
DeallocOp::Adaptor transformed(operands);
memref::DeallocOp::Adaptor transformed(operands);
// Insert the `free` declaration if it is not already present.
auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
@ -2209,7 +2210,7 @@ static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
LLVMTypeConverter &typeConverter) {
// LLVM type for a global memref will be a multi-dimension array. For
// declarations or uninitialized global memrefs, we can potentially flatten
// this to a 1D array. However, for global_memref's with an initial value,
// this to a 1D array. However, for memref.global's with an initial value,
// we do not intend to flatten the ElementsAttribute when going from std ->
// LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
Type elementType = unwrap(typeConverter.convertType(type.getElementType()));
@ -2221,11 +2222,12 @@ static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
}
/// GlobalMemrefOp is lowered to a LLVM Global Variable.
struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;
struct GlobalMemrefOpLowering
: public ConvertOpToLLVMPattern<memref::GlobalOp> {
using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(GlobalMemrefOp global, ArrayRef<Value> operands,
matchAndRewrite(memref::GlobalOp global, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MemRefType type = global.type().cast<MemRefType>();
if (!isConvertibleAndHasIdentityMaps(type))
@ -2259,14 +2261,15 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
: AllocLikeOpLowering(GetGlobalMemrefOp::getOperationName(), converter) {}
: AllocLikeOpLowering(memref::GetGlobalOp::getOperationName(),
converter) {}
/// Buffer "allocation" for get_global_memref op is getting the address of
/// Buffer "allocation" for memref.get_global op is getting the address of
/// the global variable referenced.
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
Location loc, Value sizeBytes,
Operation *op) const override {
auto getGlobalOp = cast<GetGlobalMemrefOp>(op);
auto getGlobalOp = cast<memref::GetGlobalOp>(op);
MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
unsigned memSpace = type.getMemorySpaceAsInt();
@ -2285,7 +2288,7 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
createIndexConstant(rewriter, loc, 0));
auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
// We do not expect the memref obtained using `get_global_memref` to be
// We do not expect the memref obtained using `memref.get_global` to be
// ever deallocated. Set the allocated pointer to be known bad value to
// help debug if that ever happens.
auto intPtrType = getIntPtrType(memSpace);
@ -2354,17 +2357,17 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
}
};
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
LogicalResult match(MemRefCastOp memRefCastOp) const override {
LogicalResult match(memref::CastOp memRefCastOp) const override {
Type srcType = memRefCastOp.getOperand().getType();
Type dstType = memRefCastOp.getType();
// MemRefCastOp reduce to bitcast in the ranked MemRef case and can be used
// for type erasure. For now they must preserve underlying element type and
// require source and result type to have the same rank. Therefore, perform
// a sanity check that the underlying structs are the same. Once op
// memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
// used for type erasure. For now they must preserve underlying element type
// and require source and result type to have the same rank. Therefore,
// perform a sanity check that the underlying structs are the same. Once op
// semantics are relaxed we can revisit.
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
return success(typeConverter->convertType(srcType) ==
@ -2381,9 +2384,9 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
: failure();
}
void rewrite(MemRefCastOp memRefCastOp, ArrayRef<Value> operands,
void rewrite(memref::CastOp memRefCastOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MemRefCastOp::Adaptor transformed(operands);
memref::CastOp::Adaptor transformed(operands);
auto srcType = memRefCastOp.getOperand().getType();
auto dstType = memRefCastOp.getType();
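Concretely, the two flavors this pattern handles look like the following (hand-written examples, not taken from this patch):

```mlir
%0 = memref.cast %arg : memref<4x?xf32> to memref<?x?xf32>  // same rank, more dynamic
%1 = memref.cast %arg : memref<4x?xf32> to memref<*xf32>    // erase to unranked
```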
@ -2486,14 +2489,15 @@ static void extractPointersAndOffset(Location loc,
}
struct MemRefReinterpretCastOpLowering
: public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
: public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
using ConvertOpToLLVMPattern<
memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(MemRefReinterpretCastOp castOp, ArrayRef<Value> operands,
matchAndRewrite(memref::ReinterpretCastOp castOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MemRefReinterpretCastOp::Adaptor adaptor(operands,
castOp->getAttrDictionary());
memref::ReinterpretCastOp::Adaptor adaptor(operands,
castOp->getAttrDictionary());
Type srcType = castOp.source().getType();
Value descriptor;
@ -2505,11 +2509,10 @@ struct MemRefReinterpretCastOpLowering
}
private:
LogicalResult
convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
Type srcType, MemRefReinterpretCastOp castOp,
MemRefReinterpretCastOp::Adaptor adaptor,
Value *descriptor) const {
LogicalResult convertSourceMemRefToDescriptor(
ConversionPatternRewriter &rewriter, Type srcType,
memref::ReinterpretCastOp castOp,
memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
MemRefType targetMemRefType =
castOp.getResult().getType().cast<MemRefType>();
auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
@ -2555,14 +2558,14 @@ private:
};
struct MemRefReshapeOpLowering
: public ConvertOpToLLVMPattern<MemRefReshapeOp> {
using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern;
: public ConvertOpToLLVMPattern<memref::ReshapeOp> {
using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(MemRefReshapeOp reshapeOp, ArrayRef<Value> operands,
matchAndRewrite(memref::ReshapeOp reshapeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto *op = reshapeOp.getOperation();
MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
memref::ReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
Type srcType = reshapeOp.source().getType();
Value descriptor;
@ -2576,8 +2579,8 @@ struct MemRefReshapeOpLowering
private:
LogicalResult
convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
Type srcType, MemRefReshapeOp reshapeOp,
MemRefReshapeOp::Adaptor adaptor,
Type srcType, memref::ReshapeOp reshapeOp,
memref::ReshapeOp::Adaptor adaptor,
Value *descriptor) const {
// Conversion for statically-known shape args is performed via
// `memref_reinterpret_cast`.
@ -2722,11 +2725,11 @@ struct DialectCastOpLowering
// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
using ConvertOpToLLVMPattern<DimOp>::ConvertOpToLLVMPattern;
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(DimOp dimOp, ArrayRef<Value> operands,
matchAndRewrite(memref::DimOp dimOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type operandType = dimOp.memrefOrTensor().getType();
if (operandType.isa<UnrankedMemRefType>()) {
@ -2744,11 +2747,11 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
}
private:
Value extractSizeOfUnrankedMemRef(Type operandType, DimOp dimOp,
Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Location loc = dimOp.getLoc();
DimOp::Adaptor transformed(operands);
memref::DimOp::Adaptor transformed(operands);
auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
auto scalarMemRefType =
@ -2785,11 +2788,11 @@ private:
return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
}
Value extractSizeOfRankedMemRef(Type operandType, DimOp dimOp,
Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Location loc = dimOp.getLoc();
DimOp::Adaptor transformed(operands);
memref::DimOp::Adaptor transformed(operands);
// Take advantage if index is constant.
MemRefType memRefType = operandType.cast<MemRefType>();
if (Optional<int64_t> index = dimOp.getConstantIndex()) {
@ -2833,7 +2836,7 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
};
// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
@ -2849,13 +2852,13 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
using Base::Base;
LogicalResult
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
LoadOp::Adaptor transformed(operands);
memref::LoadOp::Adaptor transformed(operands);
auto type = loadOp.getMemRefType();
Value dataPtr =
@ -2868,14 +2871,14 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
using Base::Base;
LogicalResult
matchAndRewrite(StoreOp op, ArrayRef<Value> operands,
matchAndRewrite(memref::StoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto type = op.getMemRefType();
StoreOp::Adaptor transformed(operands);
memref::StoreOp::Adaptor transformed(operands);
Value dataPtr =
getStridedElementPtr(op.getLoc(), type, transformed.memref(),
@ -2888,13 +2891,13 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
using Base::Base;
LogicalResult
matchAndRewrite(PrefetchOp prefetchOp, ArrayRef<Value> operands,
matchAndRewrite(memref::PrefetchOp prefetchOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
PrefetchOp::Adaptor transformed(operands);
memref::PrefetchOp::Adaptor transformed(operands);
auto type = prefetchOp.getMemRefType();
auto loc = prefetchOp.getLoc();
@ -3221,11 +3224,11 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
/// and stride.
/// The subview op is replaced by the descriptor.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
using ConvertOpToLLVMPattern<SubViewOp>::ConvertOpToLLVMPattern;
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(SubViewOp subViewOp, ArrayRef<Value> operands,
matchAndRewrite(memref::SubViewOp subViewOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = subViewOp.getLoc();
@ -3234,7 +3237,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
typeConverter->convertType(sourceMemRefType.getElementType());
auto viewMemRefType = subViewOp.getType();
auto inferredType = SubViewOp::inferResultType(
auto inferredType = memref::SubViewOp::inferResultType(
subViewOp.getSourceType(),
extractFromI64ArrayAttr(subViewOp.static_offsets()),
extractFromI64ArrayAttr(subViewOp.static_sizes()),
@ -3335,7 +3338,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
if (static_cast<unsigned>(i) >= mixedSizes.size()) {
size = rewriter.create<LLVM::DialectCastOp>(
loc, llvmIndexType,
rewriter.create<DimOp>(loc, subViewOp.source(), i));
rewriter.create<memref::DimOp>(loc, subViewOp.source(), i));
stride = rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
} else {
@ -3376,15 +3379,15 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
/// and stride. Size and stride are permutations of the original values.
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The transpose op is replaced by the alloca'ed pointer.
class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
using ConvertOpToLLVMPattern<TransposeOp>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(TransposeOp transposeOp, ArrayRef<Value> operands,
matchAndRewrite(memref::TransposeOp transposeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = transposeOp.getLoc();
TransposeOpAdaptor adaptor(operands);
memref::TransposeOpAdaptor adaptor(operands);
MemRefDescriptor viewMemRef(adaptor.in());
// No permutation, early exit.
@ -3424,8 +3427,8 @@ public:
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
/// and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
using ConvertOpToLLVMPattern<ViewOp>::ConvertOpToLLVMPattern;
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
// Build and return the value for the idx^th shape dimension, either by
// returning the constant shape dimension or counting the proper dynamic size.
@ -3461,10 +3464,10 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
}
LogicalResult
matchAndRewrite(ViewOp viewOp, ArrayRef<Value> operands,
matchAndRewrite(memref::ViewOp viewOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = viewOp.getLoc();
ViewOpAdaptor adaptor(operands);
memref::ViewOpAdaptor adaptor(operands);
auto viewMemRefType = viewOp.getType();
auto targetElementTy =
@ -3540,13 +3543,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
};
struct AssumeAlignmentOpLowering
: public ConvertOpToLLVMPattern<AssumeAlignmentOp> {
using ConvertOpToLLVMPattern<AssumeAlignmentOp>::ConvertOpToLLVMPattern;
: public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
using ConvertOpToLLVMPattern<
memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(AssumeAlignmentOp op, ArrayRef<Value> operands,
matchAndRewrite(memref::AssumeAlignmentOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
AssumeAlignmentOp::Adaptor transformed(operands);
memref::AssumeAlignmentOp::Adaptor transformed(operands);
Value memref = transformed.memref();
unsigned alignment = op.alignment();
auto loc = op.getLoc();

View File

@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRStandardToSPIRV
LINK_LIBS PUBLIC
MLIRIR
MLIRMath
MLIRMemRef
MLIRPass
MLIRSPIRV
MLIRSPIRVConversion

View File

@ -14,6 +14,7 @@
#include "../PassDetail.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
@ -23,11 +24,11 @@
using namespace mlir;
/// Helpers to access the memref operand for each op.
static Value getMemRefOperand(LoadOp op) { return op.memref(); }
static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
static Value getMemRefOperand(StoreOp op) { return op.memref(); }
static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }
static Value getMemRefOperand(vector::TransferWriteOp op) {
return op.source();
@ -44,7 +45,7 @@ public:
PatternRewriter &rewriter) const override;
private:
void replaceOp(OpTy loadOp, SubViewOp subViewOp,
void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
ArrayRef<Value> sourceIndices,
PatternRewriter &rewriter) const;
};
@ -59,23 +60,22 @@ public:
PatternRewriter &rewriter) const override;
private:
void replaceOp(OpTy StoreOp, SubViewOp subViewOp,
void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
ArrayRef<Value> sourceIndices,
PatternRewriter &rewriter) const;
};
template <>
void LoadOpOfSubViewFolder<LoadOp>::replaceOp(LoadOp loadOp,
SubViewOp subViewOp,
ArrayRef<Value> sourceIndices,
PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
sourceIndices);
void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
memref::LoadOp loadOp, memref::SubViewOp subViewOp,
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
sourceIndices);
}
template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
vector::TransferReadOp loadOp, SubViewOp subViewOp,
vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
@ -83,16 +83,16 @@ void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
}
template <>
void StoreOpOfSubViewFolder<StoreOp>::replaceOp(
StoreOp storeOp, SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
subViewOp.source(), sourceIndices);
void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
memref::StoreOp storeOp, memref::SubViewOp subViewOp,
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
}
template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
vector::TransferWriteOp tranferWriteOp, SubViewOp subViewOp,
vector::TransferWriteOp tranferWriteOp, memref::SubViewOp subViewOp,
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(),
@ -120,7 +120,7 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
/// memref<12x42xf32>
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
SubViewOp subViewOp, ValueRange indices,
memref::SubViewOp subViewOp, ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
// TODO: Aborting when the offsets are static. There might be a way to fold
// the subview op with load even if the offsets have been canonicalized
@ -152,7 +152,8 @@ template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const {
auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp<SubViewOp>();
auto subViewOp =
getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
if (!subViewOp) {
return failure();
}
@ -174,7 +175,7 @@ LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const {
auto subViewOp =
getMemRefOperand(storeOp).template getDefiningOp<SubViewOp>();
getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
if (!subViewOp) {
return failure();
}
@ -193,9 +194,9 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
void mlir::populateStdLegalizationPatternsForSPIRVLowering(
MLIRContext *context, OwningRewritePatternList &patterns) {
patterns.insert<LoadOpOfSubViewFolder<LoadOp>,
patterns.insert<LoadOpOfSubViewFolder<memref::LoadOp>,
LoadOpOfSubViewFolder<vector::TransferReadOp>,
StoreOpOfSubViewFolder<StoreOp>,
StoreOpOfSubViewFolder<memref::StoreOp>,
StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
}

View File

@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@ -237,12 +238,12 @@ namespace {
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spv.module scope since it
/// will add global variables into the spv.module.
class AllocOpPattern final : public OpConversionPattern<AllocOp> {
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
using OpConversionPattern<AllocOp>::OpConversionPattern;
using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
matchAndRewrite(memref::AllocOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MemRefType allocType = operation.getType();
if (!isAllocationSupported(allocType))
@ -278,12 +279,12 @@ public:
/// Removes a deallocation if it is for a supported allocation. Currently it
/// only removes the deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<DeallocOp> {
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
using OpConversionPattern<DeallocOp>::OpConversionPattern;
using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
matchAndRewrite(memref::DeallocOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
if (!isAllocationSupported(deallocType))
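For orientation, a minimal hand-written sketch of the kind of IR these two patterns target after the rename, assuming the usual GPU convention of memory space `3` for workgroup memory (shape and values are illustrative only):

```mlir
func @workgroup_buffer() {
  // Static-size allocation in workgroup memory (space 3), matched by AllocOpPattern.
  %buf = memref.alloc() : memref<4x5xf32, 3>
  %c0 = constant 0 : index
  %v = memref.load %buf[%c0, %c0] : memref<4x5xf32, 3>
  memref.store %v, %buf[%c0, %c0] : memref<4x5xf32, 3>
  // Matched by DeallocOpPattern and dropped when the allocation is supported.
  memref.dealloc %buf : memref<4x5xf32, 3>
  return
}
```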
@ -430,23 +431,23 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.load to spv.Load.
class IntLoadOpPattern final : public OpConversionPattern<LoadOp> {
/// Converts memref.load to spv.Load.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
using OpConversionPattern<LoadOp>::OpConversionPattern;
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.load to spv.Load.
class LoadOpPattern final : public OpConversionPattern<LoadOp> {
/// Converts memref.load to spv.Load.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
using OpConversionPattern<LoadOp>::OpConversionPattern;
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -469,23 +470,23 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.store to spv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<StoreOp> {
/// Converts memref.store to spv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
using OpConversionPattern<StoreOp>::OpConversionPattern;
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.store to spv.Store.
class StoreOpPattern final : public OpConversionPattern<StoreOp> {
/// Converts memref.store to spv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
using OpConversionPattern<StoreOp>::OpConversionPattern;
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
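The split between the `Int*` patterns and the plain load/store patterns hinges only on whether the memref element type is a signless integer; a small illustrative sketch (types chosen arbitrarily):

```mlir
func @loads(%i: memref<10xi8>, %f: memref<10xf32>, %idx: index) -> (i8, f32) {
  // Signless-integer element type: handled by IntLoadOpPattern.
  %a = memref.load %i[%idx] : memref<10xi8>
  // Non-integer element type: handled by LoadOpPattern.
  %b = memref.load %f[%idx] : memref<10xf32>
  return %a, %b : i8, f32
}
```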
@ -975,9 +976,10 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
//===----------------------------------------------------------------------===//
LogicalResult
IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
LoadOpAdaptor loadOperands(operands);
memref::LoadOpAdaptor loadOperands(operands);
auto loc = loadOp.getLoc();
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
if (!memrefType.getElementType().isSignlessInteger())
@ -1051,9 +1053,9 @@ IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
}
LogicalResult
LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
LoadOpAdaptor loadOperands(operands);
memref::LoadOpAdaptor loadOperands(operands);
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
if (memrefType.getElementType().isSignlessInteger())
return failure();
@ -1101,9 +1103,10 @@ SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
//===----------------------------------------------------------------------===//
LogicalResult
IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
StoreOpAdaptor storeOperands(operands);
memref::StoreOpAdaptor storeOperands(operands);
auto memrefType = storeOp.memref().getType().cast<MemRefType>();
if (!memrefType.getElementType().isSignlessInteger())
return failure();
@ -1180,9 +1183,10 @@ IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
}
LogicalResult
StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
StoreOpAdaptor storeOperands(operands);
memref::StoreOpAdaptor storeOperands(operands);
auto memrefType = storeOp.memref().getType().cast<MemRefType>();
if (memrefType.getElementType().isSignlessInteger())
return failure();

View File

@ -20,6 +20,7 @@ add_mlir_conversion_library(MLIRVectorToLLVM
MLIRArmSVEToLLVM
MLIRLLVMArmSVE
MLIRLLVMIR
MLIRMemRef
MLIRStandardToLLVM
MLIRTargetLLVMIRExport
MLIRTransforms

View File

@ -12,6 +12,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
@ -1262,7 +1263,7 @@ public:
unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
Value mask = buildVectorComparison(
rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);

View File

@ -19,6 +19,7 @@
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@ -39,6 +40,7 @@ struct LowerVectorToLLVMPass
// Override explicitly to allow conditional dialect dependence.
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
registry.insert<memref::MemRefDialect>();
if (enableArmNeon)
registry.insert<arm_neon::ArmNeonDialect>();
if (enableArmSVE)
@ -72,6 +74,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
// Architecture specific augmentations.
LLVMConversionTarget target(getContext());
target.addLegalOp<LLVM::DialectCastOp>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
if (enableArmNeon) {

View File

@ -11,5 +11,6 @@ add_mlir_conversion_library(MLIRVectorToSCF
MLIREDSC
MLIRAffineEDSC
MLIRLLVMIR
MLIRMemRef
MLIRTransforms
)

View File

@ -16,6 +16,7 @@
#include "../PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
@ -252,7 +253,7 @@ static Value setAllocAtFunctionEntry(MemRefType memRefMinorVectorType,
op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
assert(scope && "Expected op to be inside automatic allocation scope");
b.setInsertionPointToStart(&scope->getRegion(0).front());
Value res = std_alloca(memRefMinorVectorType);
Value res = memref_alloca(memRefMinorVectorType);
return res;
}
@ -314,7 +315,7 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
return {vector};
}
// 3.b. Otherwise, just go through the temporary `alloc`.
std_store(vector, alloc, majorIvs);
memref_store(vector, alloc, majorIvs);
return {};
},
[&]() -> scf::ValueVector {
@ -326,7 +327,7 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
return {vector};
}
// 3.d. Otherwise, just go through the temporary `alloc`.
std_store(vector, alloc, majorIvs);
memref_store(vector, alloc, majorIvs);
return {};
});
@ -341,14 +342,15 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
result = vector_insert(loaded1D, result, majorIvs);
// 5.b. Otherwise, just go through the temporary `alloc`.
else
std_store(loaded1D, alloc, majorIvs);
memref_store(loaded1D, alloc, majorIvs);
}
});
assert((!options.unroll ^ (bool)result) &&
"Expected resulting Value iff unroll");
if (!result)
result = std_load(vector_type_cast(MemRefType::get({}, vectorType), alloc));
result =
memref_load(vector_type_cast(MemRefType::get({}, vectorType), alloc));
rewriter.replaceOp(op, result);
return success();
@ -359,8 +361,8 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
Value alloc;
if (!options.unroll) {
alloc = setAllocAtFunctionEntry(memRefMinorVectorType, op);
std_store(xferOp.vector(),
vector_type_cast(MemRefType::get({}, vectorType), alloc));
memref_store(xferOp.vector(),
vector_type_cast(MemRefType::get({}, vectorType), alloc));
}
emitLoops([&](ValueRange majorIvs, ValueRange leadingOffsets,
@ -379,7 +381,7 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
if (options.unroll)
result = vector_extract(xferOp.vector(), majorIvs);
else
result = std_load(alloc, majorIvs);
result = memref_load(alloc, majorIvs);
auto map =
getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType);
ArrayAttr masked;
@ -560,7 +562,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
// Conservative lowering to scalar load / stores.
// 1. Setup all the captures.
ScopedContext scope(rewriter, transfer.getLoc());
StdIndexedValue remote(transfer.source());
MemRefIndexedValue remote(transfer.source());
MemRefBoundsCapture memRefBoundsCapture(transfer.source());
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
int coalescedIdx = computeCoalescedIndex(transfer);
@ -579,7 +581,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
// 2. Emit alloc-copy-load-dealloc.
MLIRContext *ctx = op->getContext();
Value tmp = setAllocAtFunctionEntry(tmpMemRefType(transfer), transfer);
StdIndexedValue local(tmp);
MemRefIndexedValue local(tmp);
loopNestBuilder(lbs, ubs, steps, [&](ValueRange loopIvs) {
auto ivsStorage = llvm::to_vector<8>(loopIvs);
// Swap the ivs which will reorder memory accesses.
@ -601,7 +603,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
rewriter, cast<VectorTransferOpInterface>(transfer.getOperation()), ivs,
memRefBoundsCapture, loadValue, loadPadding);
});
Value vectorValue = std_load(vector_type_cast(tmp));
Value vectorValue = memref_load(vector_type_cast(tmp));
// 3. Propagate.
rewriter.replaceOp(op, vectorValue);
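The `memref_load(vector_type_cast(tmp))` idiom above reads the temporary scalar buffer back as one vector value; a hedged sketch of that idiom in IR form, with an arbitrary 4x8 shape:

```mlir
func @load_as_vector(%tmp: memref<4x8xf32>) -> vector<4x8xf32> {
  // Reinterpret the scalar buffer as a 0-d memref holding a vector, then load it.
  %vmem = vector.type_cast %tmp : memref<4x8xf32> to memref<vector<4x8xf32>>
  %v = memref.load %vmem[] : memref<vector<4x8xf32>>
  return %v : vector<4x8xf32>
}
```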
@ -646,7 +648,7 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
// 1. Setup all the captures.
ScopedContext scope(rewriter, transfer.getLoc());
StdIndexedValue remote(transfer.source());
MemRefIndexedValue remote(transfer.source());
MemRefBoundsCapture memRefBoundsCapture(transfer.source());
Value vectorValue(transfer.vector());
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
@ -665,9 +667,9 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
// 2. Emit alloc-store-copy-dealloc.
Value tmp = setAllocAtFunctionEntry(tmpMemRefType(transfer), transfer);
StdIndexedValue local(tmp);
MemRefIndexedValue local(tmp);
Value vec = vector_type_cast(tmp);
std_store(vectorValue, vec);
memref_store(vectorValue, vec);
loopNestBuilder(lbs, ubs, steps, [&](ValueRange loopIvs) {
auto ivsStorage = llvm::to_vector<8>(loopIvs);
// Swap the ivsStorage which will reorder memory accesses.

View File

@ -8,6 +8,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
@ -64,7 +65,7 @@ remainsLegalAfterInline(Value value, Region *src, Region *dest,
// op won't be top-level anymore after inlining.
Attribute operandCst;
return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
value.getDefiningOp<DimOp>();
value.getDefiningOp<memref::DimOp>();
}
/// Checks if all values known to be legal affine dimensions or symbols in `src`
@ -295,7 +296,7 @@ bool mlir::isValidDim(Value value, Region *region) {
return applyOp.isValidDim(region);
// The dim op is okay if its operand memref/tensor is defined at the top
// level.
if (auto dimOp = dyn_cast<DimOp>(op))
if (auto dimOp = dyn_cast<memref::DimOp>(op))
return isTopLevelValue(dimOp.memrefOrTensor());
return false;
}
@ -317,9 +318,8 @@ static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
}
/// Returns true if the result of the dim op is a valid symbol for `region`.
static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
// The dim op is okay if its operand memref/tensor is defined at the top
// level.
static bool isDimOpValidSymbol(memref::DimOp dimOp, Region *region) {
// The dim op is okay if its operand memref is defined at the top level.
if (isTopLevelValue(dimOp.memrefOrTensor()))
return true;
@ -328,14 +328,14 @@ static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
if (dimOp.memrefOrTensor().isa<BlockArgument>())
return false;
// The dim op is also okay if its operand memref/tensor is a view/subview
// whose corresponding size is a valid symbol.
// The dim op is also okay if its operand memref is a view/subview whose
// corresponding size is a valid symbol.
Optional<int64_t> index = dimOp.getConstantIndex();
assert(index.hasValue() &&
"expect only `dim` operations with a constant index");
int64_t i = index.getValue();
return TypeSwitch<Operation *, bool>(dimOp.memrefOrTensor().getDefiningOp())
.Case<ViewOp, SubViewOp, AllocOp>(
.Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
[&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
.Default([](Operation *) { return false; });
}
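A minimal sketch of the case this predicate accepts, assuming an illustrative function signature: a `memref.dim` of a top-level memref is a valid symbol and can bound an affine loop.

```mlir
func @dim_as_symbol(%m: memref<?x42xf32>) {
  %c0 = constant 0 : index
  // Valid symbol: the memref operand is defined at the top level.
  %d = memref.dim %m, %c0 : memref<?x42xf32>
  affine.for %i = 0 to %d {
  }
  return
}
```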
@ -404,7 +404,7 @@ bool mlir::isValidSymbol(Value value, Region *region) {
return applyOp.isValidSymbol(region);
// Dim op results could be valid symbols at any level.
if (auto dimOp = dyn_cast<DimOp>(defOp))
if (auto dimOp = dyn_cast<memref::DimOp>(defOp))
return isDimOpValidSymbol(dimOp, region);
// Check for values dominating `region`'s parent op.
@ -915,12 +915,12 @@ void AffineApplyOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto cast = operand.get().getDefiningOp<MemRefCastOp>();
auto cast = operand.get().getDefiningOp<memref::CastOp>();
if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
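A sketch of the fold this helper enables, shown on `affine.load` (shapes are illustrative): the ranked `memref.cast` is bypassed so the load reads from the original memref directly.

```mlir
func @fold_cast(%arg: memref<4xf32>) -> f32 {
  %c = memref.cast %arg : memref<4xf32> to memref<?xf32>
  // After folding, this loads straight from %arg and the cast becomes dead.
  %v = affine.load %c[0] : memref<?xf32>
  return %v : f32
}
```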
@ -2254,7 +2254,8 @@ LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
// AffineMinMaxOpBase
//===----------------------------------------------------------------------===//
template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) {
template <typename T>
static LogicalResult verifyAffineMinMaxOp(T op) {
// Verify that operand count matches affine map dimension and symbol count.
if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
return op.emitOpError(
@ -2262,7 +2263,8 @@ template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) {
return success();
}
template <typename T> static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
template <typename T>
static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
p << op.getOperationName() << ' ' << op->getAttr(T::getMapAttrName());
auto operands = op.getOperands();
unsigned numDims = op.map().getNumDims();

View File

@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAffine
MLIREDSC
MLIRIR
MLIRLoopLikeInterface
MLIRMemRef
MLIRSideEffectInterfaces
MLIRStandard
)

View File

@ -23,6 +23,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"

View File

@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
MLIRAffineUtils
MLIREDSC
MLIRIR
MLIRMemRef
MLIRPass
MLIRSideEffectInterfaces
MLIRStandard

View File

@ -19,6 +19,11 @@ void registerDialect(DialectRegistry &registry);
namespace linalg {
class LinalgDialect;
} // end namespace linalg
namespace memref {
class MemRefDialect;
} // end namespace memref
namespace vector {
class VectorDialect;
} // end namespace vector

View File

@ -9,6 +9,7 @@ add_subdirectory(GPU)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(Math)
add_subdirectory(MemRef)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(PDL)

View File

@ -35,6 +35,7 @@ add_mlir_dialect_library(MLIRGPU
MLIRAsync
MLIREDSC
MLIRIR
MLIRMemRef
MLIRLLVMIR
MLIRLLVMToLLVMIRTranslation
MLIRSCF

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@ -107,7 +108,7 @@ struct GpuAllReduceRewriter {
createPredicatedBlock(isFirstLane, [&] {
Value subgroupId = getDivideBySubgroupSize(invocationIdx);
Value index = create<IndexCastOp>(indexType, subgroupId);
create<StoreOp>(subgroupReduce, buffer, index);
create<memref::StoreOp>(subgroupReduce, buffer, index);
});
create<gpu::BarrierOp>();
@ -124,27 +125,29 @@ struct GpuAllReduceRewriter {
Value zero = create<ConstantIndexOp>(0);
createPredicatedBlock(isValidSubgroup, [&] {
Value index = create<IndexCastOp>(indexType, invocationIdx);
Value value = create<LoadOp>(valueType, buffer, index);
Value value = create<memref::LoadOp>(valueType, buffer, index);
Value result =
createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
create<StoreOp>(result, buffer, zero);
create<memref::StoreOp>(result, buffer, zero);
});
// Synchronize workgroup and load result from workgroup memory.
create<gpu::BarrierOp>();
Value result = create<LoadOp>(valueType, buffer, zero);
Value result = create<memref::LoadOp>(valueType, buffer, zero);
rewriter.replaceOp(reduceOp, result);
}
private:
// Shortcut to create an op from rewriter using loc as the first argument.
template <typename T, typename... Args> T create(Args... args) {
template <typename T, typename... Args>
T create(Args... args) {
return rewriter.create<T>(loc, std::forward<Args>(args)...);
}
// Creates dimension op of type T, with the result casted to int32.
template <typename T> Value getDimOp(StringRef dimension) {
template <typename T>
Value getDimOp(StringRef dimension) {
Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
return create<IndexCastOp>(int32Type, dim);
}
@ -236,7 +239,8 @@ private:
}
/// Returns an accumulator factory that creates an op of type T.
template <typename T> AccumulatorFactory getFactory() {
template <typename T>
AccumulatorFactory getFactory() {
return [&](Value lhs, Value rhs) {
return create<T>(lhs.getType(), lhs, rhs);
};

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/GPU/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@ -58,7 +59,7 @@ static void injectGpuIndexOperations(Location loc, Region &launchFuncOpBody,
/// operations may not have side-effects, as otherwise sinking (and hence
/// duplicating them) is not legal.
static bool isSinkingBeneficiary(Operation *op) {
return isa<ConstantOp, DimOp, SelectOp, CmpIOp>(op);
return isa<ConstantOp, memref::DimOp, SelectOp, CmpIOp>(op);
}
/// For a given operation `op`, computes whether it is beneficial to sink the

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/GPU/MemoryPromotion.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Pass/Pass.h"
@ -82,7 +83,7 @@ static void insertCopyLoops(OpBuilder &builder, Location loc,
loopNestBuilder(lbs, ubs, steps, [&](ValueRange loopIvs) {
ivs.assign(loopIvs.begin(), loopIvs.end());
auto activeIvs = llvm::makeArrayRef(ivs).take_back(rank);
StdIndexedValue fromHandle(from), toHandle(to);
MemRefIndexedValue fromHandle(from), toHandle(to);
toHandle(activeIvs) = fromHandle(activeIvs);
});

View File

@ -7,5 +7,6 @@ add_mlir_dialect_library(MLIRLinalgAnalysis
LINK_LIBS PUBLIC
MLIRIR
MLIRLinalg
MLIRMemRef
MLIRStandard
)

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
@ -48,7 +49,7 @@ Value Aliases::find(Value v) {
// the aliasing further.
if (isa<RegionBranchOpInterface>(defOp))
return v;
if (isa<TensorToMemrefOp>(defOp))
if (isa<memref::BufferCastOp>(defOp))
return v;
if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {

View File

@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRLinalgEDSC
MLIRAffineEDSC
MLIRLinalg
MLIRMath
MLIRMemRef
MLIRSCF
MLIRStandard
)

View File

@ -19,5 +19,6 @@ add_mlir_dialect_library(MLIRLinalg
MLIRSideEffectInterfaces
MLIRViewLikeInterface
MLIRStandard
MLIRMemRef
MLIRTensor
)

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/SmallSet.h"
@ -187,7 +188,7 @@ SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
for (Value v : getShapedOperands()) {
ShapedType t = v.getType().template cast<ShapedType>();
for (unsigned i = 0, e = t.getRank(); i < e; ++i)
res.push_back(b.create<DimOp>(loc, v, i));
res.push_back(b.create<memref::DimOp>(loc, v, i));
}
return res;
}

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
@ -109,12 +110,12 @@ static void dispatchIndexOpFoldResult(OpFoldResult ofr,
/// ```
/// someop(memrefcast) -> someop
/// ```
/// It folds the source of the memref_cast into the root operation directly.
/// It folds the source of the memref.cast into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
if (castOp && canFoldIntoConsumerOp(castOp)) {
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand());
folded = true;
}
@ -776,10 +777,10 @@ struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
/// - A constant value if the size is static along the dimension.
/// - The dynamic value that defines the size of the result of
/// `linalg.init_tensor` op.
struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
struct ReplaceDimOfInitTensorOp : public OpRewritePattern<memref::DimOp> {
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp dimOp,
LogicalResult matchAndRewrite(memref::DimOp dimOp,
PatternRewriter &rewriter) const override {
auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
if (!initTensorOp)
@ -986,7 +987,7 @@ PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
assert(rankedTensorType.hasStaticShape());
int rank = rankedTensorType.getRank();
for (int i = 0; i < rank; ++i) {
auto dimOp = builder.createOrFold<DimOp>(loc, source, i);
auto dimOp = builder.createOrFold<memref::DimOp>(loc, source, i);
auto resultDimSize = builder.createOrFold<ConstantIndexOp>(
loc, rankedTensorType.getDimSize(i));
auto highValue = builder.createOrFold<SubIOp>(loc, resultDimSize, dimOp);
@ -1292,7 +1293,7 @@ getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
AffineExpr expr;
SmallVector<Value, 2> dynamicDims;
for (auto dim : llvm::seq(startPos, endPos + 1)) {
dynamicDims.push_back(builder.create<DimOp>(loc, src, dim));
dynamicDims.push_back(builder.create<memref::DimOp>(loc, src, dim));
AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
expr = (expr ? expr * currExpr : currExpr);
}
@ -1361,7 +1362,7 @@ static Value getExpandedOutputDimFromInputShape(
"dimensions");
linearizedStaticDim *= d.value();
}
Value sourceDim = builder.create<DimOp>(loc, src, sourceDimPos);
Value sourceDim = builder.create<memref::DimOp>(loc, src, sourceDimPos);
return applyMapToValues(
builder, loc,
AffineMap::get(
@ -1637,9 +1638,9 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
};
/// Canonicalize dim ops that use the output shape with dim of the input.
struct ReplaceDimOfReshapeOpResult : OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp dimOp,
struct ReplaceDimOfReshapeOpResult : OpRewritePattern<memref::DimOp> {
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(memref::DimOp dimOp,
PatternRewriter &rewriter) const override {
Value dimValue = dimOp.memrefOrTensor();
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
@ -2445,24 +2446,25 @@ struct FoldTensorCastOp : public RewritePattern {
}
};
/// Replaces std.dim operations that use the result of a LinalgOp (on tensors)
/// with std.dim operations that use one of the arguments. For example,
/// Replaces memref.dim operations that use the result of a LinalgOp (on
/// tensors) with memref.dim operations that use one of the arguments. For
/// example,
///
/// %0 = linalg.matmul ins(%arg0, %arg1, ...)
/// %1 = dim %0, %c0
/// %1 = memref.dim %0, %c0
///
/// with
///
/// %1 = dim %arg0, %c0
/// %1 = memref.dim %arg0, %c0
///
/// where possible. With this the result of the `linalg.matmul` is not used in
/// dim operations. Then, if the value produced is replaced with another value
/// (say by tiling `linalg.matmul`), the `linalg.matmul` becomes truly dead
/// instead of being kept live by a dim op that would prevent its DCE.
struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<memref::DimOp> {
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp dimOp,
LogicalResult matchAndRewrite(memref::DimOp dimOp,
PatternRewriter &rewriter) const override {
Value dimValue = dimOp.memrefOrTensor();
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
@ -2479,7 +2481,7 @@ struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
if (!operandDimValue) {
// Its always possible to replace using the corresponding `outs`
// parameter.
operandDimValue = rewriter.create<DimOp>(
operandDimValue = rewriter.create<memref::DimOp>(
dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex);
}
rewriter.replaceOp(dimOp, *operandDimValue);

View File

@ -25,8 +25,8 @@ using namespace ::mlir::linalg;
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
auto memrefType = memref.getType().cast<MemRefType>();
auto alloc =
b.create<AllocOp>(loc, memrefType, getDynOperands(loc, memref, b));
auto alloc = b.create<memref::AllocOp>(loc, memrefType,
getDynOperands(loc, memref, b));
b.create<linalg::CopyOp>(loc, memref, alloc);
return alloc;
}
@ -60,17 +60,17 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp,
continue;
}
if (auto alloc = resultTensor.getDefiningOp<AllocOp>()) {
if (auto alloc = resultTensor.getDefiningOp<memref::AllocOp>()) {
resultBuffers.push_back(resultTensor);
continue;
}
// Allocate buffers for statically-shaped results.
if (memrefType.hasStaticShape()) {
resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
continue;
}
resultBuffers.push_back(b.create<AllocOp>(
resultBuffers.push_back(b.create<memref::AllocOp>(
loc, memrefType, getDynOperands(loc, resultTensor, b)));
}
return success();
@ -148,7 +148,7 @@ public:
matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
rewriter.replaceOpWithNewOp<AllocOp>(
rewriter.replaceOpWithNewOp<memref::AllocOp>(
op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
adaptor.sizes());
return success();
@ -231,9 +231,9 @@ public:
// op.sizes() capture exactly the dynamic alloc operands matching the
// subviewMemRefType thanks to subview/subtensor canonicalization and
// verification.
Value alloc =
rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
Value subView = rewriter.create<SubViewOp>(
Value alloc = rewriter.create<memref::AllocOp>(
op.getLoc(), subviewMemRefType, op.sizes());
Value subView = rewriter.create<memref::SubViewOp>(
op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
op.getMixedStrides());
rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
@ -243,8 +243,8 @@ public:
};
/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to an tensor_to_memref + subview + copy + tensor_load pattern.
/// tensor_to_memref and tensor_load are inserted automatically by the
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
/// %sv = subview %dest [offsets][sizes][strides]
@ -273,7 +273,7 @@ public:
assert(destMemRef.getType().isa<MemRefType>());
// Take a subview to copy the small memref.
Value subview = rewriter.create<SubViewOp>(
Value subview = rewriter.create<memref::SubViewOp>(
op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
op.getMixedStrides());
// Copy the small memref.
@ -295,7 +295,7 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
// Mark all Standard operations legal.
target.addLegalDialect<AffineDialect, math::MathDialect,
StandardOpsDialect>();
memref::MemRefDialect, StandardOpsDialect>();
target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();
// Mark all Linalg operations illegal as long as they work on tensors.

View File

@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRAnalysis
MLIREDSC
MLIRIR
MLIRMemRef
MLIRLinalgAnalysis
MLIRLinalgEDSC
MLIRLinalg

View File

@ -18,6 +18,8 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
@ -104,11 +106,12 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
SmallVector<OpFoldResult, 4> offsets, sizes, strides;
inferShapeComponents(map, loopRanges, offsets, sizes, strides);
Value shape = en.value();
Value sub = shape.getType().isa<MemRefType>()
? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
.getResult()
: b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
.getResult();
Value sub =
shape.getType().isa<MemRefType>()
? b.create<memref::SubViewOp>(loc, shape, offsets, sizes, strides)
.getResult()
: b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
.getResult();
clonedShapes.push_back(sub);
}
// Append the other operands.
@ -177,8 +180,8 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
// `ViewInterface`. The interface needs a `getOrCreateRanges` method which
// currently returns a `linalg.range`. The fix here is to move this op to
// `std` dialect and add the method to `ViewInterface`.
if (fromSubViewOpOnly &&
!isa_and_nonnull<SubViewOp, SubTensorOp>(en.value().getDefiningOp()))
if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
en.value().getDefiningOp()))
continue;
unsigned idx = en.index();
@ -227,9 +230,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
<< "existing LoopRange: " << loopRanges[i] << "\n");
else {
auto shapeDim = getShapeDefiningLoopRange(producer, i);
loopRanges[i] = Range{std_constant_index(0),
std_dim(shapeDim.shape, shapeDim.dimension),
std_constant_index(1)};
Value dim = memref_dim(shapeDim.shape, shapeDim.dimension);
loopRanges[i] = Range{std_constant_index(0), dim, std_constant_index(1)};
LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
}
}
@ -242,7 +244,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
Value shapedOperand, unsigned dim) {
Operation *shapeProducingOp = shapedOperand.getDefiningOp();
if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
return subViewOp.getOrCreateRanges(b, loc)[dim];
if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
return subTensorOp.getOrCreateRanges(b, loc)[dim];
@ -425,7 +427,7 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
// Must be a subview or a slice to guarantee there are loops we can fuse
// into.
auto subView = consumerOpOperand.get().getDefiningOp<SubViewOp>();
auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>();
if (!subView) {
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
return llvm::None;

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
@ -200,7 +201,7 @@ Value getPaddedInput(Value input, ArrayRef<Value> indices,
conds.push_back(leftOutOfBound);
else
conds.push_back(conds.back() || leftOutOfBound);
Value rightBound = std_dim(input, idx);
Value rightBound = memref_dim(input, idx);
conds.push_back(conds.back() || (sge(dim, rightBound)));
// When padding is involved, the indices will only be shifted to negative,
@ -307,12 +308,12 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
IndexedValueType F(convOp.filter()), O(convOp.output());
// Emit scalar form. Padded conv involves an affine.max in the memory access
// which is not allowed by affine.load. Override to use an StdIndexedValue
// which is not allowed by affine.load. Override to use a MemRefIndexedValue
// when there is non-zero padding.
if (hasPadding(convOp)) {
Type type = convOp.input().getType().cast<MemRefType>().getElementType();
Value padValue = std_constant(type, getPadValueAttr<ConvOp>(type));
Value paddedInput = getPaddedInput<StdIndexedValue>(
Value paddedInput = getPaddedInput<MemRefIndexedValue>(
convOp.input(), imIdx,
/* Only need to pad the window dimensions */
{0, static_cast<int>(imIdx.size()) - 1}, padValue);
@ -338,9 +339,9 @@ static Value getPoolingInput(PoolingOp op, ArrayRef<Value> inputIndices) {
Type type =
op.input().getType().template cast<MemRefType>().getElementType();
Value padValue = std_constant(type, getPadValueAttr<PoolingOp>(type));
return getPaddedInput<StdIndexedValue>(op.input(), inputIndices,
/*Pad every dimension*/ {},
padValue);
return getPaddedInput<MemRefIndexedValue>(op.input(), inputIndices,
/*Pad every dimension*/ {},
padValue);
}
IndexedValueType input(op.input());
return input(inputIndices);
@ -546,7 +547,7 @@ static void lowerLinalgToLoopsImpl(FuncOp funcOp,
MLIRContext *context = funcOp.getContext();
OwningRewritePatternList patterns;
patterns.insert<LinalgRewritePattern<LoopType>>(interchangeVector);
DimOp::getCanonicalizationPatterns(patterns, context);
memref::DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldAffineOp>(context);
// Just apply the patterns greedily.
@ -593,12 +594,18 @@ struct FoldAffineOp : public RewritePattern {
struct LowerToAffineLoops
: public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnFunction() override {
lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), interchangeVector);
}
};
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect, scf::SCFDialect>();
}
void runOnFunction() override {
lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), interchangeVector);
}

View File

@ -26,6 +26,10 @@ namespace scf {
class SCFDialect;
} // end namespace scf
namespace memref {
class MemRefDialect;
} // end namespace memref
namespace vector {
class VectorDialect;
} // end namespace vector

View File

@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
@ -38,9 +39,9 @@ using llvm::MapVector;
using folded_affine_min = FoldedValueBuilder<AffineMinOp>;
using folded_linalg_range = FoldedValueBuilder<linalg::RangeOp>;
using folded_std_dim = FoldedValueBuilder<DimOp>;
using folded_std_subview = FoldedValueBuilder<SubViewOp>;
using folded_std_view = FoldedValueBuilder<ViewOp>;
using folded_memref_dim = FoldedValueBuilder<memref::DimOp>;
using folded_memref_subview = FoldedValueBuilder<memref::SubViewOp>;
using folded_memref_view = FoldedValueBuilder<memref::ViewOp>;
#define DEBUG_TYPE "linalg-promotion"
@ -59,22 +60,22 @@ static Value allocBuffer(const LinalgPromotionOptions &options,
if (!dynamicBuffers)
if (auto cst = size.getDefiningOp<ConstantIndexOp>())
return options.useAlloca
? std_alloca(MemRefType::get(width * cst.getValue(),
IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr)
? memref_alloca(MemRefType::get(width * cst.getValue(),
IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr)
.value
: std_alloc(MemRefType::get(width * cst.getValue(),
IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr)
: memref_alloc(MemRefType::get(width * cst.getValue(),
IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr)
.value;
Value mul =
folded_std_muli(folder, folded_std_constant_index(folder, width), size);
return options.useAlloca
? std_alloca(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr)
? memref_alloca(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr)
.value
: std_alloc(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr)
: memref_alloc(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr)
.value;
}
@ -82,10 +83,12 @@ static Value allocBuffer(const LinalgPromotionOptions &options,
/// no call back to do so is provided. The default is to allocate a
/// memref<..xi8> and return a view to get a memref type of shape
/// boundingSubViewSize.
static Optional<Value> defaultAllocBufferCallBack(
const LinalgPromotionOptions &options, OpBuilder &builder,
SubViewOp subView, ArrayRef<Value> boundingSubViewSize, bool dynamicBuffers,
Optional<unsigned> alignment, OperationFolder *folder) {
static Optional<Value>
defaultAllocBufferCallBack(const LinalgPromotionOptions &options,
OpBuilder &builder, memref::SubViewOp subView,
ArrayRef<Value> boundingSubViewSize,
bool dynamicBuffers, Optional<unsigned> alignment,
OperationFolder *folder) {
ShapedType viewType = subView.getType();
int64_t rank = viewType.getRank();
(void)rank;
@ -100,7 +103,7 @@ static Optional<Value> defaultAllocBufferCallBack(
dynamicBuffers, folder, alignment);
SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(),
ShapedType::kDynamicSize);
Value view = folded_std_view(
Value view = folded_memref_view(
folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
zero, boundingSubViewSize);
return view;
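The default callback above materializes the promoted buffer as a flat `xi8` allocation exposed through a `memref.view`; a hedged sketch of the resulting IR, with a tile size chosen only for illustration (8x64 f32 = 2048 bytes):

```mlir
func @promoted_tile() -> memref<8x64xf32> {
  %c0 = constant 0 : index
  %raw = memref.alloc() : memref<2048xi8>
  // View the raw byte buffer as the promoted f32 tile.
  %tile = memref.view %raw[%c0][] : memref<2048xi8> to memref<8x64xf32>
  return %tile : memref<8x64xf32>
}
```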
@ -112,10 +115,10 @@ static Optional<Value> defaultAllocBufferCallBack(
static LogicalResult
defaultDeallocBufferCallBack(const LinalgPromotionOptions &options,
OpBuilder &b, Value fullLocalView) {
auto viewOp = fullLocalView.getDefiningOp<ViewOp>();
auto viewOp = fullLocalView.getDefiningOp<memref::ViewOp>();
assert(viewOp && "expected full local view to be a ViewOp");
if (!options.useAlloca)
std_dealloc(viewOp.source());
memref_dealloc(viewOp.source());
return success();
}
@ -161,21 +164,21 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
if (options.operandsToPromote && !options.operandsToPromote->count(idx))
continue;
auto *op = linalgOp.getShapedOperand(idx).getDefiningOp();
if (auto sv = dyn_cast_or_null<SubViewOp>(op)) {
if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
subViews[idx] = sv;
useFullTileBuffers[sv] = vUseFullTileBuffers[idx];
}
}
allocationFn =
(options.allocationFn ? *(options.allocationFn)
: [&](OpBuilder &builder, SubViewOp subViewOp,
ArrayRef<Value> boundingSubViewSize,
OperationFolder *folder) -> Optional<Value> {
return defaultAllocBufferCallBack(options, builder, subViewOp,
boundingSubViewSize, dynamicBuffers,
alignment, folder);
});
allocationFn = (options.allocationFn
? *(options.allocationFn)
: [&](OpBuilder &builder, memref::SubViewOp subViewOp,
ArrayRef<Value> boundingSubViewSize,
OperationFolder *folder) -> Optional<Value> {
return defaultAllocBufferCallBack(options, builder, subViewOp,
boundingSubViewSize, dynamicBuffers,
alignment, folder);
});
deallocationFn =
(options.deallocationFn
? *(options.deallocationFn)
@ -209,7 +212,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
// boundary tiles. For now this is done with an unconditional `fill` op followed
// by a partial `copy` op.
Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
OpBuilder &b, Location loc, SubViewOp subView,
OpBuilder &b, Location loc, memref::SubViewOp subView,
AllocBufferCallbackFn allocationFn, OperationFolder *folder) {
ScopedContext scopedContext(b, loc);
auto viewType = subView.getType();
@ -227,7 +230,8 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
(!sizeAttr) ? rangeValue.size : b.create<ConstantOp>(loc, sizeAttr);
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
fullSizes.push_back(size);
partialSizes.push_back(folded_std_dim(folder, subView, en.index()).value);
partialSizes.push_back(
folded_memref_dim(folder, subView, en.index()).value);
}
SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
// If a callback is not specified, then use the default implementation for
@ -238,7 +242,7 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
SmallVector<OpFoldResult, 4> zeros(fullSizes.size(), b.getIndexAttr(0));
SmallVector<OpFoldResult, 4> ones(fullSizes.size(), b.getIndexAttr(1));
auto partialLocalView =
folded_std_subview(folder, *fullLocalView, zeros, partialSizes, ones);
folded_memref_subview(folder, *fullLocalView, zeros, partialSizes, ones);
return PromotionInfo{*fullLocalView, partialLocalView};
}
@ -253,7 +257,8 @@ promoteSubViews(OpBuilder &b, Location loc,
MapVector<unsigned, PromotionInfo> promotionInfoMap;
for (auto v : options.subViews) {
SubViewOp subView = cast<SubViewOp>(v.second.getDefiningOp());
memref::SubViewOp subView =
cast<memref::SubViewOp>(v.second.getDefiningOp());
Optional<PromotionInfo> promotionInfo = promoteSubviewAsNewBuffer(
b, loc, subView, options.allocationFn, folder);
if (!promotionInfo)
@ -277,8 +282,9 @@ promoteSubViews(OpBuilder &b, Location loc,
auto info = promotionInfoMap.find(v.first);
if (info == promotionInfoMap.end())
continue;
if (failed(options.copyInFn(b, cast<SubViewOp>(v.second.getDefiningOp()),
info->second.partialLocalView)))
if (failed(options.copyInFn(
b, cast<memref::SubViewOp>(v.second.getDefiningOp()),
info->second.partialLocalView)))
return {};
}
return promotionInfoMap;
@ -353,7 +359,7 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op,
return failure();
// Check that at least one of the requested operands is indeed a subview.
for (auto en : llvm::enumerate(linOp.getShapedOperands())) {
auto sv = isa_and_nonnull<SubViewOp>(en.value().getDefiningOp());
auto sv = isa_and_nonnull<memref::SubViewOp>(en.value().getDefiningOp());
if (sv) {
if (!options.operandsToPromote.hasValue() ||
options.operandsToPromote->count(en.index()))

View File

@ -44,11 +44,11 @@ class TensorFromPointerConverter
};
/// Sparse conversion rule for dimension accesses.
class TensorToDimSizeConverter : public OpConversionPattern<DimOp> {
class TensorToDimSizeConverter : public OpConversionPattern<memref::DimOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(DimOp op, ArrayRef<Value> operands,
matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
return failure();

View File

@ -533,13 +533,13 @@ static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
// positions for the output tensor. Currently this results in functional,
// but slightly imprecise IR, so it is put under an experimental option.
if (codegen.options.fastOutput)
return rewriter.create<TensorToMemrefOp>(loc, denseTp, tensor);
return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
// By default, a new buffer is allocated which is initialized to the
// tensor defined in the outs() clause. This is always correct but
// introduces a dense initialization component that may negatively
// impact the running complexity of the sparse kernel.
Value init = rewriter.create<TensorToMemrefOp>(loc, denseTp, tensor);
Value alloc = rewriter.create<AllocOp>(loc, denseTp, args);
Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
rewriter.create<linalg::CopyOp>(loc, init, alloc);
return alloc;
}
@ -585,8 +585,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
}
// Find lower and upper bound in current dimension.
Value up;
if (shape[d] == TensorType::kDynamicSize) {
up = rewriter.create<DimOp>(loc, tensor, d);
if (shape[d] == MemRefType::kDynamicSize) {
up = rewriter.create<memref::DimOp>(loc, tensor, d);
args.push_back(up);
} else {
up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
@ -600,7 +600,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
auto denseTp = MemRefType::get(shape, tensorType.getElementType());
if (t < numInputs)
codegen.buffers[t] =
rewriter.create<TensorToMemrefOp>(loc, denseTp, tensor);
rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
else
codegen.buffers[t] =
genOutputBuffer(codegen, rewriter, op, denseTp, args);
@ -716,7 +716,7 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
Value ptr = codegen.buffers[tensor];
if (codegen.curVecLength > 1)
return genVectorLoad(codegen, rewriter, ptr, args);
return rewriter.create<LoadOp>(loc, ptr, args);
return rewriter.create<memref::LoadOp>(loc, ptr, args);
}
/// Generates a store on a dense tensor.
@ -744,7 +744,7 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
if (codegen.curVecLength > 1)
genVectorStore(codegen, rewriter, rhs, ptr, args);
else
rewriter.create<StoreOp>(loc, rhs, ptr, args);
rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}
/// Generates a pointer/index load from the sparse storage scheme.
@ -752,7 +752,7 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
Value ptr, Value s) {
if (codegen.curVecLength > 1)
return genVectorLoad(codegen, rewriter, ptr, {s});
Value load = rewriter.create<LoadOp>(loc, ptr, s);
Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
return load.getType().isa<IndexType>()
? load
: rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
@ -1345,8 +1345,8 @@ public:
CodeGen codegen(options, numTensors, numLoops);
genBuffers(merger, codegen, rewriter, op);
genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
Value result =
rewriter.create<TensorLoadOp>(op.getLoc(), codegen.buffers.back());
Value result = rewriter.create<memref::TensorLoadOp>(
op.getLoc(), codegen.buffers.back());
rewriter.replaceOp(op, result);
return success();
}
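A small sketch of the buffer handling genBuffers/genOutputBuffer emit around a dense output tensor (shape and names are illustrative; the actual kernel stores happen between the copy and the tensor_load):

```mlir
func @dense_output(%t: tensor<8xf32>) -> tensor<8xf32> {
  // Materialize the outs() tensor as a buffer and initialize a fresh allocation from it.
  %init = memref.buffer_cast %t : memref<8xf32>
  %buf = memref.alloc() : memref<8xf32>
  linalg.copy(%init, %buf) : memref<8xf32>, memref<8xf32>
  // ... kernel stores into %buf via memref.store ...
  %r = memref.tensor_load %buf : memref<8xf32>
  return %r : tensor<8xf32>
}
```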

View File

@ -17,6 +17,8 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -34,7 +36,6 @@ using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::scf;
#define DEBUG_TYPE "linalg-tiling"
static bool isZero(Value v) {
@ -144,9 +145,9 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
// scf.for %k = %c0 to operand_dim_0 step %c10 {
// scf.for %l = %c0 to operand_dim_1 step %c25 {
// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
// %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
// : memref<50x100xf32> to memref<?x?xf32, #strided>
// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
// %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
// : memref<50x100xf32> to memref<?x?xf32, #strided>
// linalg.indexed_generic pointwise_2d_trait %4, %5 {
// ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
@ -262,7 +263,7 @@ makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
for (unsigned r = 0; r < rank; ++r) {
if (!isTiled(map.getSubMap({r}), tileSizes)) {
offsets.push_back(b.getIndexAttr(0));
sizes.push_back(std_dim(shapedOp, r).value);
sizes.push_back(memref_dim(shapedOp, r).value);
strides.push_back(b.getIndexAttr(1));
continue;
}
@ -290,7 +291,7 @@ makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
getAffineDimExpr(/*position=*/1, b.getContext()) -
getAffineDimExpr(/*position=*/2, b.getContext())},
b.getContext());
auto d = std_dim(shapedOp, r);
Value d = memref_dim(shapedOp, r);
SmallVector<Value, 4> operands{size, d, offset};
fullyComposeAffineMapAndOperands(&minMap, &operands);
size = affine_min(b.getIndexType(), minMap, operands);
@ -302,7 +303,7 @@ makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
if (shapedType.isa<MemRefType>())
res.push_back(
b.create<SubViewOp>(loc, shapedOp, offsets, sizes, strides));
b.create<memref::SubViewOp>(loc, shapedOp, offsets, sizes, strides));
else
res.push_back(
b.create<SubTensorOp>(loc, shapedOp, offsets, sizes, strides));
@ -474,7 +475,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
if (!options.tileSizeComputationFunction)
return llvm::None;
// Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of
// adjusting affine maps to account for missing dimensions.
@ -564,9 +565,9 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
SubViewOp::getCanonicalizationPatterns(patterns, ctx);
memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
ViewOp::getCanonicalizationPatterns(patterns, ctx);
memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

View File

@ -212,7 +212,7 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
auto sizes = llvm::to_vector<4>(llvm::map_range(
llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
auto dimOp = rewriter.create<DimOp>(loc, std::get<0>(it), d);
auto dimOp = rewriter.create<memref::DimOp>(loc, std::get<0>(it), d);
newUsersOfOpToPad.insert(dimOp);
return dimOp.getResult();
}));

View File

@ -85,7 +85,7 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
}
/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If source has rank zero, build an std.load.
/// If source has rank zero, build a memref.load.
/// Return the produced value.
static Value buildVectorRead(OpBuilder &builder, Value source) {
edsc::ScopedContext scope(builder);
@ -94,11 +94,11 @@ static Value buildVectorRead(OpBuilder &builder, Value source) {
SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
return vector_transfer_read(vectorType, source, indices);
}
return std_load(source);
return memref_load(source);
}
/// Build a vector.transfer_write of `value` into `dest` at indices set to all
/// `0`. If `dest` has null rank, build an std.store.
/// `0`. If `dest` has null rank, build a memref.store.
/// Return the produced value or null if no value is produced.
static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
edsc::ScopedContext scope(builder);
@ -110,7 +110,7 @@ static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
value = vector_broadcast(vectorType, value);
write = vector_transfer_write(value, dest, indices);
} else {
write = std_store(value, dest);
write = memref_store(value, dest);
}
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
if (!write->getResults().empty())
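A sketch of the rank-zero case these helpers fall back to (0-d memref, types illustrative): the read and write lower to a plain `memref.load` / `memref.store` with no indices.

```mlir
func @rank_zero(%m: memref<f32>, %v: f32) -> f32 {
  memref.store %v, %m[] : memref<f32>
  %r = memref.load %m[] : memref<f32>
  return %r : f32
}
```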
@ -544,7 +544,7 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
rewriter.getAffineMapArrayAttr(indexingMaps),
rewriter.getStrArrayAttr(iteratorTypes));
rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
rewriter.eraseOp(op);
return success();
}
@ -667,12 +667,12 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
}
/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
SubViewOp subViewOp;
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
memref::SubViewOp subViewOp;
for (auto &u : v.getUses()) {
if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
if (subViewOp)
return SubViewOp();
return memref::SubViewOp();
subViewOp = newSubViewOp;
}
}
@ -686,14 +686,14 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
// Transfer into `view`.
Value viewOrAlloc = xferOp.source();
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
!viewOrAlloc.getDefiningOp<AllocOp>())
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
return failure();
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
if (!subViewOp)
return failure();
Value subView = subViewOp.getResult();
@ -765,12 +765,12 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
// Transfer into `viewOrAlloc`.
Value viewOrAlloc = xferOp.source();
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
!viewOrAlloc.getDefiningOp<AllocOp>())
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
return failure();
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
if (!subViewOp)
return failure();
Value subView = subViewOp.getResult();
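The changes in this file swap the `std` ops and builders for their `memref` counterparts. For reference, a minimal C++ sketch of creating the renamed load/store ops with the builder API after this patch (the helper and variable names are illustrative, not part of the patch):

```c++
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Illustrative only: load a scalar through the renamed memref.load and write
// it back with memref.store (formerly std.load / std.store).
static void loadStoreThroughMemRef(OpBuilder &b, Location loc, Value buffer,
                                   ValueRange indices) {
  Value scalar = b.create<memref::LoadOp>(loc, buffer, indices);
  b.create<memref::StoreOp>(loc, scalar, buffer, indices);
}
```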

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,17 @@
add_mlir_dialect_library(MLIRMemRef
MemRefDialect.cpp
MemRefOps.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/mlir/Dialect/MemRef/IR
DEPENDS
MLIRMemRefOpsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRDialect
MLIRIR
)

View File

@ -0,0 +1,39 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
using namespace mlir::memref;
//===----------------------------------------------------------------------===//
// MemRefDialect Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct MemRefInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
BlockAndValueMapping &) const final {
return true;
}
};
} // end anonymous namespace
void mlir::memref::MemRefDialect::initialize() {
addOperations<DmaStartOp, DmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
>();
addInterfaces<MemRefInlinerInterface>();
}
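Downstream code that creates the moved ops must now make sure the new dialect is loaded. A minimal sketch of the usual pattern, assuming a pass that builds memref ops (the pass itself is hypothetical and not part of this patch):

```c++
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace {
// Hypothetical pass, shown only to illustrate the dependent-dialect declaration.
struct ExampleMemRefPass
    : public mlir::PassWrapper<ExampleMemRefPass,
                               mlir::OperationPass<mlir::FuncOp>> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // memref.alloc, memref.load, etc. can only be created once the dialect is
    // loaded in the context.
    registry.insert<mlir::memref::MemRefDialect>();
  }
  void runOnOperation() override {
    // A real pass would build memref ops here.
  }
};
} // end anonymous namespace
```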

File diff suppressed because it is too large

View File

@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRSCF
MLIREDSC
MLIRIR
MLIRLoopLikeInterface
MLIRMemRef
MLIRSideEffectInterfaces
MLIRStandard
)

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
@ -568,7 +569,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
/// %t0 = ... : tensor_type
/// %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
/// ...
/// // %m is either tensor_to_memref(%bb00) or defined above the loop
/// // %m is either buffer_cast(%bb0) or defined above the loop
/// %m... : memref_type
/// ... // uses of %m with potential inplace updates
/// %new_tensor = tensor_load %m : memref_type
@ -578,7 +579,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
/// ```
///
/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
/// `%m = tensor_to_memref %bb0` op that feeds into the yielded `tensor_load`
/// `%m = buffer_cast %bb0` op that feeds into the yielded `tensor_load`
/// op.
///
/// If no aliasing write to the memref `%m`, from which `%new_tensor` is loaded,
@ -590,7 +591,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
///
/// The canonicalization rewrites the pattern as:
/// ```
/// // %m is either a tensor_to_memref or defined above
/// // %m is either a buffer_cast or defined above
/// %m... : memref_type
/// scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
/// ... // uses of %m with potential inplace updates
@ -601,7 +602,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
///
/// A later bbArg canonicalization will further rewrite as:
/// ```
/// // %m is either a tensor_to_memref or defined above
/// // %m is either a buffer_cast or defined above
/// %m... : memref_type
/// scf.for ... { // no iter_args
/// ... // uses of %m with potential inplace updates
@ -622,19 +623,18 @@ struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
Value yieldVal = yieldOp->getOperand(idx);
auto tensorLoadOp = yieldVal.getDefiningOp<TensorLoadOp>();
auto tensorLoadOp = yieldVal.getDefiningOp<memref::TensorLoadOp>();
bool isTensor = bbArg.getType().isa<TensorType>();
TensorToMemrefOp tensorToMemRefOp;
// Either bbArg has no use or it has a single tensor_to_memref use.
memref::BufferCastOp bufferCastOp;
// Either bbArg has no use or it has a single buffer_cast use.
if (bbArg.hasOneUse())
tensorToMemRefOp =
dyn_cast<TensorToMemrefOp>(*bbArg.getUsers().begin());
if (!isTensor || !tensorLoadOp ||
(!bbArg.use_empty() && !tensorToMemRefOp))
bufferCastOp =
dyn_cast<memref::BufferCastOp>(*bbArg.getUsers().begin());
if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !bufferCastOp))
continue;
// If tensorToMemRefOp is present, it must feed into the `tensorLoadOp`.
if (tensorToMemRefOp && tensorLoadOp.memref() != tensorToMemRefOp)
// If bufferCastOp is present, it must feed into the `tensorLoadOp`.
if (bufferCastOp && tensorLoadOp.memref() != bufferCastOp)
continue;
// TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
// must be before `tensorLoadOp` in the block so that the lastWrite
@ -644,18 +644,18 @@ struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
if (tensorLoadOp->getNextNode() != yieldOp)
continue;
// Clone the optional tensorToMemRefOp before forOp.
if (tensorToMemRefOp) {
// Clone the optional bufferCastOp before forOp.
if (bufferCastOp) {
rewriter.setInsertionPoint(forOp);
rewriter.replaceOpWithNewOp<TensorToMemrefOp>(
tensorToMemRefOp, tensorToMemRefOp.memref().getType(),
tensorToMemRefOp.tensor());
rewriter.replaceOpWithNewOp<memref::BufferCastOp>(
bufferCastOp, bufferCastOp.memref().getType(),
bufferCastOp.tensor());
}
// Clone the tensorLoad after forOp.
rewriter.setInsertionPointAfter(forOp);
Value newTensorLoad =
rewriter.create<TensorLoadOp>(loc, tensorLoadOp.memref());
rewriter.create<memref::TensorLoadOp>(loc, tensorLoadOp.memref());
Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
replacements.insert(std::make_pair(forOpResult, newTensorLoad));
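The pattern above works directly on the renamed pair. As a quick reference, a minimal sketch of building that pair with the C++ builder API (variable and helper names are illustrative only):

```c++
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Illustrative only: materialize a tensor as a buffer and read it back using
// the renamed ops (tensor_to_memref -> memref.buffer_cast,
// tensor_load -> memref.tensor_load).
static Value roundTripThroughBuffer(OpBuilder &b, Location loc, Value tensor,
                                    MemRefType bufferType) {
  Value buffer = b.create<memref::BufferCastOp>(loc, bufferType, tensor);
  return b.create<memref::TensorLoadOp>(loc, buffer);
}
```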

View File

@ -8,6 +8,7 @@
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"

Some files were not shown because too many files have changed in this diff