[MLIR] Create memref dialect and move dialect-specific ops from std.

Create the memref dialect and move dialect-specific ops from std dialect to
this dialect. Moved ops:

    AllocOp -> MemRef_AllocOp
    AllocaOp -> MemRef_AllocaOp
    AssumeAlignmentOp -> MemRef_AssumeAlignmentOp
    DeallocOp -> MemRef_DeallocOp
    DimOp -> MemRef_DimOp
    MemRefCastOp -> MemRef_CastOp
    MemRefReinterpretCastOp -> MemRef_ReinterpretCastOp
    GetGlobalMemRefOp -> MemRef_GetGlobalOp
    GlobalMemRefOp -> MemRef_GlobalOp
    LoadOp -> MemRef_LoadOp
    PrefetchOp -> MemRef_PrefetchOp
    ReshapeOp -> MemRef_ReshapeOp
    StoreOp -> MemRef_StoreOp
    SubViewOp -> MemRef_SubViewOp
    TransposeOp -> MemRef_TransposeOp
    TensorLoadOp -> MemRef_TensorLoadOp
    TensorStoreOp -> MemRef_TensorStoreOp
    TensorToMemRefOp -> MemRef_BufferCastOp
    ViewOp -> MemRef_ViewOp

The roadmap to split the memref dialect from std is discussed here:
https://llvm.discourse.group/t/rfc-split-the-memref-dialect-from-std/2667

Differential Revision: https://reviews.llvm.org/D98041
Commit: e2310704d8 (parent: a88371490d)
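For illustration, a minimal sketch (not taken from the patch itself) of how the textual IR spelling changes once these ops live in the new dialect; the values `%t`, `%v`, and `%i` are assumed to be defined elsewhere:

```mlir
// Before: the ops print with their std-dialect spellings.
%buf = alloc() : memref<4xf32>
%m   = tensor_to_memref %t : memref<4xf32>
store %v, %buf[%i] : memref<4xf32>
dealloc %buf : memref<4xf32>

// After: the same ops under the memref dialect (tensor_to_memref is also
// renamed to buffer_cast as part of the move).
%buf = memref.alloc() : memref<4xf32>
%m   = memref.buffer_cast %t : memref<4xf32>
memref.store %v, %buf[%i] : memref<4xf32>
memref.dealloc %buf : memref<4xf32>
```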
@@ -779,8 +779,8 @@ the deallocation of the source value.
 ## Known Limitations
 
 BufferDeallocation introduces additional copies using allocations from the
-“std” dialect (“std.alloc”). Analogous, all deallocations use the “std”
-dialect-free operation “std.dealloc”. The actual copy process is realized using
-“linalg.copy”. Furthermore, buffers are essentially immutable after their
-creation in a block. Another limitations are known in the case using
-unstructered control flow.
+“memref” dialect (“memref.alloc”). Analogous, all deallocations use the
+“memref” dialect-free operation “memref.dealloc”. The actual copy process is
+realized using “linalg.copy”. Furthermore, buffers are essentially immutable
+after their creation in a block. Another limitations are known in the case
+using unstructered control flow.
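For readers unfamiliar with the pass, a hedged sketch of the shape of the copies it now materializes (buffer names and types are invented for illustration):

```mlir
%copy = memref.alloc() : memref<2xf32>        // copy buffer from the memref dialect
linalg.copy(%source, %copy) : memref<2xf32>, memref<2xf32>
// ... uses of %copy ...
memref.dealloc %copy : memref<2xf32>          // freed with memref.dealloc
```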
@@ -190,8 +190,8 @@ One convenient utility provided by the MLIR bufferization infrastructure is the
 `BufferizeTypeConverter`, which comes pre-loaded with the necessary conversions
 and materializations between `tensor` and `memref`.
 
-In this case, the `StandardOpsDialect` is marked as legal, so the `tensor_load`
-and `tensor_to_memref` ops, which are inserted automatically by the dialect
+In this case, the `MemRefOpsDialect` is marked as legal, so the `tensor_load`
+and `buffer_cast` ops, which are inserted automatically by the dialect
 conversion framework as materializations, are legal. There is a helper
 `populateBufferizeMaterializationLegality`
 ([code](https://github.com/llvm/llvm-project/blob/a0b65a7bcd6065688189b3d678c42ed6af9603db/mlir/include/mlir/Transforms/Bufferize.h#L53))
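As a rough, hedged sketch of how a partial bufferization pass wires these pieces together after this change (the pass name and `populateMyOpBufferizePatterns` are placeholders; the other calls are the utilities named above):

```c++
void MyDialectBufferizePass::runOnFunction() {
  MLIRContext *context = &getContext();
  BufferizeTypeConverter typeConverter;

  ConversionTarget target(*context);
  target.addLegalDialect<memref::MemRefDialect, StandardOpsDialect>();
  // Mark the tensor_load / buffer_cast materializations inserted by the
  // conversion framework as legal.
  populateBufferizeMaterializationLegality(target);

  OwningRewritePatternList patterns;
  // populateMyOpBufferizePatterns(typeConverter, patterns);  // placeholder

  if (failed(applyPartialConversion(getFunction(), target, std::move(patterns))))
    signalPassFailure();
}
```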
@@ -247,7 +247,7 @@ from the program.
 
 The easiest way to write a finalizing bufferize pass is to not write one at all!
 MLIR provides a pass `finalizing-bufferize` which eliminates the `tensor_load` /
-`tensor_to_memref` materialization ops inserted by partial bufferization passes
+`buffer_cast` materialization ops inserted by partial bufferization passes
 and emits an error if that is not sufficient to remove all tensors from the
 program.
 
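For example, an illustrative `mlir-opt` invocation (the partial bufferization passes shown are chosen arbitrarily) that finishes with the finalizing pass:

```
mlir-opt input.mlir -scf-bufferize -std-bufferize -func-bufferize -finalizing-bufferize
```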
@@ -268,7 +268,7 @@ recommended in new code. A helper,
 `populateEliminateBufferizeMaterializationsPatterns`
 ([code](https://github.com/llvm/llvm-project/blob/a0b65a7bcd6065688189b3d678c42ed6af9603db/mlir/include/mlir/Transforms/Bufferize.h#L58))
 is available for such passes to provide patterns that eliminate `tensor_load`
-and `tensor_to_memref`.
+and `buffer_cast`.
 
 ## Changes since [the talk](#the-talk)
 
@@ -406,9 +406,9 @@ into a form that will resemble:
 #map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 
 func @example(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
-  %0 = memref_cast %arg0 : memref<?x?xf32> to memref<?x?xf32, #map0>
-  %1 = memref_cast %arg1 : memref<?x?xf32> to memref<?x?xf32, #map0>
-  %2 = memref_cast %arg2 : memref<?x?xf32> to memref<?x?xf32, #map0>
+  %0 = memref.cast %arg0 : memref<?x?xf32> to memref<?x?xf32, #map0>
+  %1 = memref.cast %arg1 : memref<?x?xf32> to memref<?x?xf32, #map0>
+  %2 = memref.cast %arg2 : memref<?x?xf32> to memref<?x?xf32, #map0>
   call @pointwise_add(%0, %1, %2) : (memref<?x?xf32, #map0>, memref<?x?xf32, #map0>, memref<?x?xf32, #map0>) -> ()
   return
 }
@@ -518,9 +518,9 @@ A set of ops that manipulate metadata but do not move memory. These ops take
 generally alias the operand `view`. At the moment the existing ops are:
 
 ```
-* `std.view`,
+* `memref.view`,
 * `std.subview`,
-* `std.transpose`.
+* `memref.transpose`.
 * `linalg.range`,
 * `linalg.slice`,
 * `linalg.reshape`,
@@ -0,0 +1,76 @@
+# 'memref' Dialect
+
+This dialect provides documentation for operations within the MemRef dialect.
+
+**Please post an RFC on the [forum](https://llvm.discourse.group/c/mlir/31)
+before adding or changing any operation in this dialect.**
+
+[TOC]
+
+## Operations
+
+[include "Dialects/MemRefOps.md"]
+
+### 'dma_start' operation
+
+Syntax:
+
+```
+operation ::= `dma_start` ssa-use`[`ssa-use-list`]` `,`
+               ssa-use`[`ssa-use-list`]` `,` ssa-use `,`
+               ssa-use`[`ssa-use-list`]` (`,` ssa-use `,` ssa-use)?
+               `:` memref-type `,` memref-type `,` memref-type
+```
+
+Starts a non-blocking DMA operation that transfers data from a source memref to
+a destination memref. The operands include the source and destination memref's
+each followed by its indices, size of the data transfer in terms of the number
+of elements (of the elemental type of the memref), a tag memref with its
+indices, and optionally two additional arguments corresponding to the stride (in
+terms of number of elements) and the number of elements to transfer per stride.
+The tag location is used by a dma_wait operation to check for completion. The
+indices of the source memref, destination memref, and the tag memref have the
+same restrictions as any load/store operation in an affine context (whenever DMA
+operations appear in an affine context). See
+[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
+in affine contexts. This allows powerful static analysis and transformations in
+the presence of such DMAs including rescheduling, pipelining / overlap with
+computation, and checking for matching start/end operations. The source and
+destination memref need not be of the same dimensionality, but need to have the
+same elemental type.
+
+For example, a `dma_start` operation that transfers 32 vector elements from a
+memref `%src` at location `[%i, %j]` to memref `%dst` at `[%k, %l]` would be
+specified as shown below.
+
+Example:
+
+```mlir
+%size = constant 32 : index
+%tag = alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
+%idx = constant 0 : index
+dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] :
+  memref<40 x 8 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 0>,
+  memref<2 x 4 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 2>,
+  memref<1 x i32>, affine_map<(d0) -> (d0)>, 4>
+```
+
+### 'dma_wait' operation
+
+Syntax:
+
+```
+operation ::= `dma_wait` ssa-use`[`ssa-use-list`]` `,` ssa-use `:` memref-type
+```
+
+Blocks until the completion of a DMA operation associated with the tag element
+specified with a tag memref and its indices. The operands include the tag memref
+followed by its indices and the number of elements associated with the DMA being
+waited on. The indices of the tag memref have the same restrictions as
+load/store indices.
+
+Example:
+
+```mlir
+dma_wait %tag[%idx], %size : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
+```
@@ -13,67 +13,3 @@ before adding or changing any operation in this dialect.**
 ## Operations
 
 [include "Dialects/StandardOps.md"]
-
-### 'dma_start' operation
-
-Syntax:
-
-```
-operation ::= `dma_start` ssa-use`[`ssa-use-list`]` `,`
-               ssa-use`[`ssa-use-list`]` `,` ssa-use `,`
-               ssa-use`[`ssa-use-list`]` (`,` ssa-use `,` ssa-use)?
-               `:` memref-type `,` memref-type `,` memref-type
-```
-
-Starts a non-blocking DMA operation that transfers data from a source memref to
-a destination memref. The operands include the source and destination memref's
-each followed by its indices, size of the data transfer in terms of the number
-of elements (of the elemental type of the memref), a tag memref with its
-indices, and optionally two additional arguments corresponding to the stride (in
-terms of number of elements) and the number of elements to transfer per stride.
-The tag location is used by a dma_wait operation to check for completion. The
-indices of the source memref, destination memref, and the tag memref have the
-same restrictions as any load/store operation in an affine context (whenever DMA
-operations appear in an affine context). See
-[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
-in affine contexts. This allows powerful static analysis and transformations in
-the presence of such DMAs including rescheduling, pipelining / overlap with
-computation, and checking for matching start/end operations. The source and
-destination memref need not be of the same dimensionality, but need to have the
-same elemental type.
-
-For example, a `dma_start` operation that transfers 32 vector elements from a
-memref `%src` at location `[%i, %j]` to memref `%dst` at `[%k, %l]` would be
-specified as shown below.
-
-Example:
-
-```mlir
-%size = constant 32 : index
-%tag = alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
-%idx = constant 0 : index
-dma_start %src[%i, %j], %dst[%k, %l], %size, %tag[%idx] :
-  memref<40 x 8 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 0>,
-  memref<2 x 4 x vector<16xf32>, affine_map<(d0, d1) -> (d0, d1)>, 2>,
-  memref<1 x i32>, affine_map<(d0) -> (d0)>, 4>
-```
-
-### 'dma_wait' operation
-
-Syntax:
-
-```
-operation ::= `dma_wait` ssa-use`[`ssa-use-list`]` `,` ssa-use `:` memref-type
-```
-
-Blocks until the completion of a DMA operation associated with the tag element
-specified with a tag memref and its indices. The operands include the tag memref
-followed by its indices and the number of elements associated with the DMA being
-waited on. The indices of the tag memref have the same restrictions as
-load/store indices.
-
-Example:
-
-```mlir
-dma_wait %tag[%idx], %size : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
-```
@@ -200,7 +200,7 @@ for.
 ### The `OpPointer` and `ConstOpPointer` Classes
 
 The "typed operation" classes for registered operations (e.g. like `DimOp` for
-the "std.dim" operation in standard ops) contain a pointer to an operation and
+the "memref.dim" operation in memref ops) contain a pointer to an operation and
 provide typed APIs for processing it.
 
 However, this is a problem for our current `const` design - `const DimOp` means
@@ -211,7 +211,7 @@ are nested inside of other operations that themselves have this trait.
 This trait is carried by region holding operations that define a new scope for
 automatic allocation. Such allocations are automatically freed when control is
 transferred back from the regions of such operations. As an example, allocations
-performed by [`std.alloca`](Dialects/Standard.md#stdalloca-allocaop) are
+performed by [`memref.alloca`](Dialects/MemRef.md#memrefalloca-allocaop) are
 automatically freed when control leaves the region of its closest surrounding op
 that has the trait AutomaticAllocationScope.
 
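A minimal sketch of the behaviour being described (the function and buffer are invented for illustration; `func` carries the AutomaticAllocationScope trait):

```mlir
func @scratch() {
  // Freed automatically when control leaves @scratch, the closest surrounding
  // op with the AutomaticAllocationScope trait; no memref.dealloc is needed.
  %tmp = memref.alloca() : memref<16xf32>
  return
}
```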
@@ -50,8 +50,9 @@ framework, we need to provide two things (and an optional third):
 ## Conversion Target
 
 For our purposes, we want to convert the compute-intensive `Toy` operations into
-a combination of operations from the `Affine` `Standard` dialects for further
-optimization. To start off the lowering, we first define our conversion target:
+a combination of operations from the `Affine`, `MemRef` and `Standard` dialects
+for further optimization. To start off the lowering, we first define our
+conversion target:
 
 ```c++
 void ToyToAffineLoweringPass::runOnFunction() {

@@ -61,8 +62,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
 
   // We define the specific operations, or dialects, that are legal targets for
   // this lowering. In our case, we are lowering to a combination of the
-  // `Affine` and `Standard` dialects.
-  target.addLegalDialect<mlir::AffineDialect, mlir::StandardOpsDialect>();
+  // `Affine`, `MemRef` and `Standard` dialects.
+  target.addLegalDialect<mlir::AffineDialect, mlir::memref::MemRefDialect,
+                         mlir::StandardOpsDialect>();
 
   // We also define the Toy dialect as Illegal so that the conversion will fail
   // if any of these operations are *not* converted. Given that we actually want
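To make the effect concrete, a hedged sketch (not excerpted from the tutorial) of the kind of IR the Toy lowering below now produces around a lowered constant; the element stores are elided:

```mlir
%0 = memref.alloc() : memref<2x3xf64>
// ... affine.store of each constant element into %0 ...
memref.dealloc %0 : memref<2x3xf64>
```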
@ -7,8 +7,8 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a partial lowering of Toy operations to a combination of
|
||||
// affine loops and standard operations. This lowering expects that all calls
|
||||
// have been inlined, and all shapes have been resolved.
|
||||
// affine loops, memref operations and standard operations. This lowering
|
||||
// expects that all calls have been inlined, and all shapes have been resolved.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
|||
#include "toy/Passes.h"
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
@ -36,7 +37,7 @@ static MemRefType convertTensorToMemRef(TensorType type) {
|
|||
/// Insert an allocation and deallocation for the given MemRefType.
|
||||
static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
auto alloc = rewriter.create<AllocOp>(loc, type);
|
||||
auto alloc = rewriter.create<memref::AllocOp>(loc, type);
|
||||
|
||||
// Make sure to allocate at the beginning of the block.
|
||||
auto *parentBlock = alloc->getBlock();
|
||||
|
@ -44,7 +45,7 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
|
||||
// Make sure to deallocate this alloc at the end of the block. This is fine
|
||||
// as toy functions have no control flow.
|
||||
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
|
||||
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
|
||||
dealloc->moveBefore(&parentBlock->back());
|
||||
return alloc;
|
||||
}
|
||||
|
@ -152,8 +153,8 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
|
||||
if (!valueShape.empty()) {
|
||||
for (auto i : llvm::seq<int64_t>(
|
||||
0, *std::max_element(valueShape.begin(), valueShape.end())))
|
||||
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
|
||||
0, *std::max_element(valueShape.begin(), valueShape.end())))
|
||||
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
|
||||
} else {
|
||||
// This is the case of a tensor of rank 0.
|
||||
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
|
||||
|
@ -257,7 +258,7 @@ namespace {
|
|||
struct ToyToAffineLoweringPass
|
||||
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect, StandardOpsDialect>();
|
||||
registry.insert<AffineDialect, memref::MemRefDialect, StandardOpsDialect>();
|
||||
}
|
||||
void runOnFunction() final;
|
||||
};
|
||||
|
@ -283,8 +284,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
|
|||
|
||||
// We define the specific operations, or dialects, that are legal targets for
|
||||
// this lowering. In our case, we are lowering to a combination of the
|
||||
// `Affine` and `Standard` dialects.
|
||||
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
|
||||
// `Affine`, `MemRef` and `Standard` dialects.
|
||||
target.addLegalDialect<AffineDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect>();
|
||||
|
||||
// We also define the Toy dialect as Illegal so that the conversion will fail
|
||||
// if any of these operations are *not* converted. Given that we actually want
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a partial lowering of Toy operations to a combination of
|
||||
// affine loops and standard operations. This lowering expects that all calls
|
||||
// have been inlined, and all shapes have been resolved.
|
||||
// affine loops, memref operations and standard operations. This lowering
|
||||
// expects that all calls have been inlined, and all shapes have been resolved.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
|||
#include "toy/Passes.h"
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
@ -36,7 +37,7 @@ static MemRefType convertTensorToMemRef(TensorType type) {
|
|||
/// Insert an allocation and deallocation for the given MemRefType.
|
||||
static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
auto alloc = rewriter.create<AllocOp>(loc, type);
|
||||
auto alloc = rewriter.create<memref::AllocOp>(loc, type);
|
||||
|
||||
// Make sure to allocate at the beginning of the block.
|
||||
auto *parentBlock = alloc->getBlock();
|
||||
|
@ -44,7 +45,7 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
|
||||
// Make sure to deallocate this alloc at the end of the block. This is fine
|
||||
// as toy functions have no control flow.
|
||||
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
|
||||
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
|
||||
dealloc->moveBefore(&parentBlock->back());
|
||||
return alloc;
|
||||
}
|
||||
|
@ -152,8 +153,8 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
|
||||
if (!valueShape.empty()) {
|
||||
for (auto i : llvm::seq<int64_t>(
|
||||
0, *std::max_element(valueShape.begin(), valueShape.end())))
|
||||
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
|
||||
0, *std::max_element(valueShape.begin(), valueShape.end())))
|
||||
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
|
||||
} else {
|
||||
// This is the case of a tensor of rank 0.
|
||||
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
|
||||
|
@ -256,7 +257,7 @@ namespace {
|
|||
struct ToyToAffineLoweringPass
|
||||
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect, StandardOpsDialect>();
|
||||
registry.insert<AffineDialect, memref::MemRefDialect, StandardOpsDialect>();
|
||||
}
|
||||
void runOnFunction() final;
|
||||
};
|
||||
|
@ -282,8 +283,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
|
|||
|
||||
// We define the specific operations, or dialects, that are legal targets for
|
||||
// this lowering. In our case, we are lowering to a combination of the
|
||||
// `Affine` and `Standard` dialects.
|
||||
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
|
||||
// `Affine`, `MemRef` and `Standard` dialects.
|
||||
target.addLegalDialect<AffineDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect>();
|
||||
|
||||
// We also define the Toy dialect as Illegal so that the conversion will fail
|
||||
// if any of these operations are *not* converted. Given that we actually want
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
@ -91,7 +92,8 @@ public:
|
|||
|
||||
// Generate a call to printf for the current element of the loop.
|
||||
auto printOp = cast<toy::PrintOp>(op);
|
||||
auto elementLoad = rewriter.create<LoadOp>(loc, printOp.input(), loopIvs);
|
||||
auto elementLoad =
|
||||
rewriter.create<memref::LoadOp>(loc, printOp.input(), loopIvs);
|
||||
rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
|
||||
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
|
||||
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a partial lowering of Toy operations to a combination of
|
||||
// affine loops and standard operations. This lowering expects that all calls
|
||||
// have been inlined, and all shapes have been resolved.
|
||||
// affine loops, memref operations and standard operations. This lowering
|
||||
// expects that all calls have been inlined, and all shapes have been resolved.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
|||
#include "toy/Passes.h"
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
@ -36,7 +37,7 @@ static MemRefType convertTensorToMemRef(TensorType type) {
|
|||
/// Insert an allocation and deallocation for the given MemRefType.
|
||||
static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
auto alloc = rewriter.create<AllocOp>(loc, type);
|
||||
auto alloc = rewriter.create<memref::AllocOp>(loc, type);
|
||||
|
||||
// Make sure to allocate at the beginning of the block.
|
||||
auto *parentBlock = alloc->getBlock();
|
||||
|
@ -44,7 +45,7 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
|
||||
// Make sure to deallocate this alloc at the end of the block. This is fine
|
||||
// as toy functions have no control flow.
|
||||
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
|
||||
auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
|
||||
dealloc->moveBefore(&parentBlock->back());
|
||||
return alloc;
|
||||
}
|
||||
|
@ -152,8 +153,8 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
|
||||
if (!valueShape.empty()) {
|
||||
for (auto i : llvm::seq<int64_t>(
|
||||
0, *std::max_element(valueShape.begin(), valueShape.end())))
|
||||
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
|
||||
0, *std::max_element(valueShape.begin(), valueShape.end())))
|
||||
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
|
||||
} else {
|
||||
// This is the case of a tensor of rank 0.
|
||||
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
|
||||
|
@ -257,7 +258,7 @@ namespace {
|
|||
struct ToyToAffineLoweringPass
|
||||
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect, StandardOpsDialect>();
|
||||
registry.insert<AffineDialect, memref::MemRefDialect, StandardOpsDialect>();
|
||||
}
|
||||
void runOnFunction() final;
|
||||
};
|
||||
|
@ -283,8 +284,9 @@ void ToyToAffineLoweringPass::runOnFunction() {
|
|||
|
||||
// We define the specific operations, or dialects, that are legal targets for
|
||||
// this lowering. In our case, we are lowering to a combination of the
|
||||
// `Affine` and `Standard` dialects.
|
||||
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
|
||||
// `Affine`, `MemRef` and `Standard` dialects.
|
||||
target.addLegalDialect<AffineDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect>();
|
||||
|
||||
// We also define the Toy dialect as Illegal so that the conversion will fail
|
||||
// if any of these operations are *not* converted. Given that we actually want
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
@ -91,7 +92,8 @@ public:
|
|||
|
||||
// Generate a call to printf for the current element of the loop.
|
||||
auto printOp = cast<toy::PrintOp>(op);
|
||||
auto elementLoad = rewriter.create<LoadOp>(loc, printOp.input(), loopIvs);
|
||||
auto elementLoad =
|
||||
rewriter.create<memref::LoadOp>(loc, printOp.input(), loopIvs);
|
||||
rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
|
||||
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
|
||||
|
||||
|
|
|
@ -121,7 +121,7 @@ def LowerHostCodeToLLVM : Pass<"lower-host-to-llvm", "ModuleOp"> {
|
|||
def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
|
||||
let summary = "Generate NVVM operations for gpu operations";
|
||||
let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
|
||||
let dependentDialects = ["NVVM::NVVMDialect"];
|
||||
let dependentDialects = ["NVVM::NVVMDialect", "memref::MemRefDialect"];
|
||||
let options = [
|
||||
Option<"indexBitwidth", "index-bitwidth", "unsigned",
|
||||
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
|
||||
|
@ -210,7 +210,7 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
|
|||
let summary = "Convert the operations from the linalg dialect into the "
|
||||
"Standard dialect";
|
||||
let constructor = "mlir::createConvertLinalgToStandardPass()";
|
||||
let dependentDialects = ["StandardOpsDialect"];
|
||||
let dependentDialects = ["memref::MemRefDialect", "StandardOpsDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -316,7 +316,11 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
|
|||
let summary = "Convert operations from the shape dialect into the standard "
|
||||
"dialect";
|
||||
let constructor = "mlir::createConvertShapeToStandardPass()";
|
||||
let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
|
||||
let dependentDialects = [
|
||||
"memref::MemRefDialect",
|
||||
"StandardOpsDialect",
|
||||
"scf::SCFDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
|
||||
|
@ -474,7 +478,11 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
|
|||
let summary = "Lower the operations from the vector dialect into the SCF "
|
||||
"dialect";
|
||||
let constructor = "mlir::createConvertVectorToSCFPass()";
|
||||
let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
|
||||
let dependentDialects = [
|
||||
"AffineDialect",
|
||||
"memref::MemRefDialect",
|
||||
"scf::SCFDialect"
|
||||
];
|
||||
let options = [
|
||||
Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
|
||||
"Perform full unrolling when converting vector transfers to SCF">,
|
||||
|
|
|
@ -72,7 +72,8 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
|||
|
||||
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
|
||||
/// stdlib malloc/free is used by default for allocating memrefs allocated with
|
||||
/// std.alloc, while LLVM's alloca is used for those allocated with std.alloca.
|
||||
/// memref.alloc, while LLVM's alloca is used for those allocated with
|
||||
/// memref.alloca.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createLowerToLLVMPass(const LowerToLLVMOptions &options =
|
||||
LowerToLLVMOptions::getDefaultOptions());
|
||||
|
|
|
@ -18,6 +18,7 @@ include "mlir/Pass/PassBase.td"
|
|||
def AffineDataCopyGeneration : FunctionPass<"affine-data-copy-generate"> {
|
||||
let summary = "Generate explicit copying for affine memory operations";
|
||||
let constructor = "mlir::createAffineDataCopyGenerationPass()";
|
||||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
let options = [
|
||||
Option<"fastMemoryCapacity", "fast-mem-capacity", "uint64_t",
|
||||
/*default=*/"std::numeric_limits<uint64_t>::max()",
|
||||
|
|
|
@ -9,6 +9,7 @@ add_subdirectory(GPU)
|
|||
add_subdirectory(Math)
|
||||
add_subdirectory(Linalg)
|
||||
add_subdirectory(LLVMIR)
|
||||
add_subdirectory(MemRef)
|
||||
add_subdirectory(OpenACC)
|
||||
add_subdirectory(OpenMP)
|
||||
add_subdirectory(PDL)
|
||||
|
|
|
@ -480,7 +480,7 @@ def GPU_LaunchOp : GPU_Op<"launch">,
|
|||
%num_bx : index, %num_by : index, %num_bz : index,
|
||||
%num_tx : index, %num_ty : index, %num_tz : index)
|
||||
"some_op"(%bx, %tx) : (index, index) -> ()
|
||||
%3 = "std.load"(%val1, %bx) : (memref<?xf32, 1>, index) -> f32
|
||||
%3 = "memref.load"(%val1, %bx) : (memref<?xf32, 1>, index) -> f32
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -812,7 +812,7 @@ def GPU_AllocOp : GPU_Op<"alloc", [
|
|||
let summary = "GPU memory allocation operation.";
|
||||
let description = [{
|
||||
The `gpu.alloc` operation allocates a region of memory on the GPU. It is
|
||||
similar to the `std.alloc` op, but supports asynchronous GPU execution.
|
||||
similar to the `memref.alloc` op, but supports asynchronous GPU execution.
|
||||
|
||||
The op does not execute before all async dependencies have finished
|
||||
executing.
|
||||
|
@ -850,7 +850,7 @@ def GPU_DeallocOp : GPU_Op<"dealloc", [GPU_AsyncOpInterface]> {
|
|||
let description = [{
|
||||
The `gpu.dealloc` operation frees the region of memory referenced by a
|
||||
memref which was originally created by the `gpu.alloc` operation. It is
|
||||
similar to the `std.dealloc` op, but supports asynchronous GPU execution.
|
||||
similar to the `memref.dealloc` op, but supports asynchronous GPU execution.
|
||||
|
||||
The op does not execute before all async dependencies have finished
|
||||
executing.
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
@ -35,30 +36,25 @@ struct FoldedValueBuilder {
|
|||
};
|
||||
|
||||
using folded_math_tanh = FoldedValueBuilder<math::TanhOp>;
|
||||
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
|
||||
using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
|
||||
using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
|
||||
using folded_std_constant = FoldedValueBuilder<ConstantOp>;
|
||||
using folded_std_dim = FoldedValueBuilder<DimOp>;
|
||||
using folded_memref_alloc = FoldedValueBuilder<memref::AllocOp>;
|
||||
using folded_memref_cast = FoldedValueBuilder<memref::CastOp>;
|
||||
using folded_memref_dim = FoldedValueBuilder<memref::DimOp>;
|
||||
using folded_memref_load = FoldedValueBuilder<memref::LoadOp>;
|
||||
using folded_memref_sub_view = FoldedValueBuilder<memref::SubViewOp>;
|
||||
using folded_memref_tensor_load = FoldedValueBuilder<memref::TensorLoadOp>;
|
||||
using folded_memref_view = FoldedValueBuilder<memref::ViewOp>;
|
||||
using folded_std_muli = FoldedValueBuilder<MulIOp>;
|
||||
using folded_std_addi = FoldedValueBuilder<AddIOp>;
|
||||
using folded_std_addf = FoldedValueBuilder<AddFOp>;
|
||||
using folded_std_alloc = FoldedValueBuilder<AllocOp>;
|
||||
using folded_std_constant = FoldedValueBuilder<ConstantOp>;
|
||||
using folded_std_constant_float = FoldedValueBuilder<ConstantFloatOp>;
|
||||
using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
|
||||
using folded_std_constant_int = FoldedValueBuilder<ConstantIntOp>;
|
||||
using folded_std_dim = FoldedValueBuilder<DimOp>;
|
||||
using folded_std_index_cast = FoldedValueBuilder<IndexCastOp>;
|
||||
using folded_std_muli = FoldedValueBuilder<MulIOp>;
|
||||
using folded_std_mulf = FoldedValueBuilder<MulFOp>;
|
||||
using folded_std_memref_cast = FoldedValueBuilder<MemRefCastOp>;
|
||||
using folded_std_select = FoldedValueBuilder<SelectOp>;
|
||||
using folded_std_load = FoldedValueBuilder<LoadOp>;
|
||||
using folded_std_subi = FoldedValueBuilder<SubIOp>;
|
||||
using folded_std_sub_view = FoldedValueBuilder<SubViewOp>;
|
||||
using folded_std_tensor_load = FoldedValueBuilder<TensorLoadOp>;
|
||||
using folded_std_view = FoldedValueBuilder<ViewOp>;
|
||||
using folded_std_zero_extendi = FoldedValueBuilder<ZeroExtendIOp>;
|
||||
using folded_std_sign_extendi = FoldedValueBuilder<SignExtendIOp>;
|
||||
using folded_tensor_extract = FoldedValueBuilder<tensor::ExtractOp>;
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
//
|
||||
// The other operations form the bridge between the opaque pointer and
|
||||
// the actual storage of pointers, indices, and values. These operations
|
||||
// resemble 'tensor_to_memref' in the sense that they map tensors to
|
||||
// resemble 'buffer_cast' in the sense that they map tensors to
|
||||
// their bufferized memrefs, but they lower into actual calls since
|
||||
// sparse storage does not bufferize into a single memrefs, as dense
|
||||
// tensors do, but into a hierarchical storage scheme where pointers
|
||||
|
@ -74,9 +74,9 @@ def Linalg_SparseTensorToPointersMemRefOp :
|
|||
let description = [{
|
||||
Returns the pointers array of the sparse storage scheme at the
|
||||
given dimension for the given tensor. This is similar to the
|
||||
`tensor_to_memref` operation in the sense that it provides a bridge
|
||||
`buffer_cast` operation in the sense that it provides a bridge
|
||||
between a tensor world view and a bufferized world view. Unlike the
|
||||
`tensor_to_memref` operation, however, this sparse operation actually
|
||||
`buffer_cast` operation, however, this sparse operation actually
|
||||
lowers into a call into a support library to obtain access to the
|
||||
pointers array.
|
||||
|
||||
|
@ -98,9 +98,9 @@ def Linalg_SparseTensorToIndicesMemRefOp :
|
|||
let description = [{
|
||||
Returns the indices array of the sparse storage scheme at the
|
||||
given dimension for the given tensor. This is similar to the
|
||||
`tensor_to_memref` operation in the sense that it provides a bridge
|
||||
`buffer_cast` operation in the sense that it provides a bridge
|
||||
between a tensor world view and a bufferized world view. Unlike the
|
||||
`tensor_to_memref` operation, however, this sparse operation actually
|
||||
`buffer_cast` operation, however, this sparse operation actually
|
||||
lowers into a call into a support library to obtain access to the
|
||||
indices array.
|
||||
|
||||
|
@ -122,9 +122,9 @@ def Linalg_SparseTensorToValuesMemRefOp :
|
|||
let description = [{
|
||||
Returns the values array of the sparse storage scheme for the given
|
||||
tensor, independent of the actual dimension. This is similar to the
|
||||
`tensor_to_memref` operation in the sense that it provides a bridge
|
||||
`buffer_cast` operation in the sense that it provides a bridge
|
||||
between a tensor world view and a bufferized world view. Unlike the
|
||||
`tensor_to_memref` operation, however, this sparse operation actually
|
||||
`buffer_cast` operation, however, this sparse operation actually
|
||||
lowers into a call into a support library to obtain access to the
|
||||
values array.
|
||||
|
||||
|
|
|
@ -34,11 +34,11 @@ createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca);
|
|||
std::unique_ptr<OperationPass<FuncOp>> createLinalgPromotionPass();
|
||||
|
||||
/// Create a pass to convert Linalg operations to scf.for loops and
|
||||
/// std.load/std.store accesses.
|
||||
/// memref.load/memref.store accesses.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToLoopsPass();
|
||||
|
||||
/// Create a pass to convert Linalg operations to scf.parallel loops and
|
||||
/// std.load/std.store accesses.
|
||||
/// memref.load/memref.store accesses.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToParallelLoopsPass();
|
||||
|
||||
/// Create a pass to convert Linalg operations to affine.for loops and
|
||||
|
|
|
@ -19,7 +19,7 @@ def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
|
|||
This pass only converts ops that operate on ranked tensors.
|
||||
}];
|
||||
let constructor = "mlir::createConvertElementwiseToLinalgPass()";
|
||||
let dependentDialects = ["linalg::LinalgDialect"];
|
||||
let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
|
||||
|
@ -70,13 +70,21 @@ def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
|
|||
"interchange vector",
|
||||
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
|
||||
];
|
||||
let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect", "AffineDialect"];
|
||||
let dependentDialects = [
|
||||
"linalg::LinalgDialect",
|
||||
"scf::SCFDialect",
|
||||
"AffineDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgBufferize : Pass<"linalg-bufferize", "FuncOp"> {
|
||||
let summary = "Bufferize the linalg dialect";
|
||||
let constructor = "mlir::createLinalgBufferizePass()";
|
||||
let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
|
||||
let dependentDialects = [
|
||||
"linalg::LinalgDialect",
|
||||
"AffineDialect",
|
||||
"memref::MemRefDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgLowerToParallelLoops
|
||||
|
@ -90,7 +98,12 @@ def LinalgLowerToParallelLoops
|
|||
"interchange vector",
|
||||
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
|
||||
];
|
||||
let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
|
||||
let dependentDialects = [
|
||||
"AffineDialect",
|
||||
"linalg::LinalgDialect",
|
||||
"memref::MemRefDialect",
|
||||
"scf::SCFDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
|
||||
|
@ -109,7 +122,10 @@ def LinalgTiling : FunctionPass<"linalg-tile"> {
|
|||
let summary = "Tile operations in the linalg dialect";
|
||||
let constructor = "mlir::createLinalgTilingPass()";
|
||||
let dependentDialects = [
|
||||
"AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"
|
||||
"AffineDialect",
|
||||
"linalg::LinalgDialect",
|
||||
"memref::MemRefDialect",
|
||||
"scf::SCFDialect"
|
||||
];
|
||||
let options = [
|
||||
ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
|
||||
|
@ -127,7 +143,12 @@ def LinalgTilingToParallelLoops
|
|||
"Test generation of dynamic promoted buffers",
|
||||
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
|
||||
];
|
||||
let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
|
||||
let dependentDialects = [
|
||||
"AffineDialect",
|
||||
"linalg::LinalgDialect",
|
||||
"memref::MemRefDialect",
|
||||
"scf::SCFDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> {
|
||||
|
|
|
@ -147,8 +147,8 @@ LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);
|
|||
/// dimension. If that is not possible, contains the dynamic size of the
|
||||
/// subview. The call back should return the buffer to use.
|
||||
using AllocBufferCallbackFn = std::function<Optional<Value>(
|
||||
OpBuilder &b, SubViewOp subView, ArrayRef<Value> boundingSubViewSize,
|
||||
OperationFolder *folder)>;
|
||||
OpBuilder &b, memref::SubViewOp subView,
|
||||
ArrayRef<Value> boundingSubViewSize, OperationFolder *folder)>;
|
||||
|
||||
/// Callback function type used to deallocate the buffers used to hold the
|
||||
/// promoted subview.
|
||||
|
@ -244,7 +244,7 @@ struct PromotionInfo {
|
|||
Value partialLocalView;
|
||||
};
|
||||
Optional<PromotionInfo>
|
||||
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView,
|
||||
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
|
||||
AllocBufferCallbackFn allocationFn,
|
||||
OperationFolder *folder = nullptr);
|
||||
|
||||
|
@ -818,7 +818,7 @@ struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
|
|||
/// Match and rewrite for the pattern:
|
||||
/// ```
|
||||
/// %alloc = ...
|
||||
/// [optional] %view = std.view %alloc ...
|
||||
/// [optional] %view = memref.view %alloc ...
|
||||
/// %subView = subview %allocOrView ...
|
||||
/// [optional] linalg.fill(%allocOrView, %cst) ...
|
||||
/// ...
|
||||
|
@ -828,7 +828,7 @@ struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
|
|||
/// into
|
||||
/// ```
|
||||
/// [unchanged] %alloc = ...
|
||||
/// [unchanged] [optional] %view = std.view %alloc ...
|
||||
/// [unchanged] [optional] %view = memref.view %alloc ...
|
||||
/// [unchanged] [unchanged] %subView = subview %allocOrView ...
|
||||
/// ...
|
||||
/// vector.transfer_read %in[...], %cst ...
|
||||
|
@ -849,7 +849,7 @@ struct LinalgCopyVTRForwardingPattern
|
|||
/// Match and rewrite for the pattern:
|
||||
/// ```
|
||||
/// %alloc = ...
|
||||
/// [optional] %view = std.view %alloc ...
|
||||
/// [optional] %view = memref.view %alloc ...
|
||||
/// %subView = subview %allocOrView...
|
||||
/// ...
|
||||
/// vector.transfer_write %..., %allocOrView[...]
|
||||
|
@ -858,7 +858,7 @@ struct LinalgCopyVTRForwardingPattern
|
|||
/// into
|
||||
/// ```
|
||||
/// [unchanged] %alloc = ...
|
||||
/// [unchanged] [optional] %view = std.view %alloc ...
|
||||
/// [unchanged] [optional] %view = memref.view %alloc ...
|
||||
/// [unchanged] %subView = subview %allocOrView...
|
||||
/// ...
|
||||
/// vector.transfer_write %..., %out[...]
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
@ -21,7 +22,7 @@
|
|||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
using mlir::edsc::intrinsics::AffineIndexedValue;
|
||||
using mlir::edsc::intrinsics::StdIndexedValue;
|
||||
using mlir::edsc::intrinsics::MemRefIndexedValue;
|
||||
|
||||
namespace mlir {
|
||||
class AffineExpr;
|
||||
|
@ -213,7 +214,7 @@ template <typename LoopTy>
|
|||
struct GenerateLoopNest {
|
||||
using IndexedValueTy =
|
||||
typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
|
||||
AffineIndexedValue, StdIndexedValue>::type;
|
||||
AffineIndexedValue, MemRefIndexedValue>::type;
|
||||
|
||||
static void
|
||||
doit(ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(IR)
|
|
@ -0,0 +1,37 @@
|
|||
//===- Intrinsics.h - MLIR EDSC Intrinsics for MemRefOps --------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef MLIR_DIALECT_MEMREF_EDSC_INTRINSICS_H_
|
||||
#define MLIR_DIALECT_MEMREF_EDSC_INTRINSICS_H_
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/EDSC/Builders.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace edsc {
|
||||
namespace intrinsics {
|
||||
|
||||
using memref_alloc = ValueBuilder<memref::AllocOp>;
|
||||
using memref_alloca = ValueBuilder<memref::AllocaOp>;
|
||||
using memref_cast = ValueBuilder<memref::CastOp>;
|
||||
using memref_dealloc = OperationBuilder<memref::DeallocOp>;
|
||||
using memref_dim = ValueBuilder<memref::DimOp>;
|
||||
using memref_load = ValueBuilder<memref::LoadOp>;
|
||||
using memref_store = OperationBuilder<memref::StoreOp>;
|
||||
using memref_sub_view = ValueBuilder<memref::SubViewOp>;
|
||||
using memref_tensor_load = ValueBuilder<memref::TensorLoadOp>;
|
||||
using memref_tensor_store = OperationBuilder<memref::TensorStoreOp>;
|
||||
using memref_view = ValueBuilder<memref::ViewOp>;
|
||||
|
||||
/// Provide an index notation around memref_load and memref_store.
|
||||
using MemRefIndexedValue =
|
||||
TemplatedIndexedValue<intrinsics::memref_load, intrinsics::memref_store>;
|
||||
} // namespace intrinsics
|
||||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_MEMREF_EDSC_INTRINSICS_H_
|
|
@ -0,0 +1,2 @@
|
|||
add_mlir_dialect(MemRefOps memref)
|
||||
add_mlir_doc(MemRefOps -gen-dialect-doc MemRefOps Dialects/)
|
|
@ -0,0 +1,239 @@
|
|||
//===- MemRef.h - MemRef dialect --------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_
|
||||
#define MLIR_DIALECT_MEMREF_IR_MEMREF_H_
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
raw_ostream &operator<<(raw_ostream &os, Range &range);
|
||||
|
||||
/// Return the list of Range (i.e. offset, size, stride). Each Range
|
||||
/// entry contains either the dynamic value or a ConstantIndexOp constructed
|
||||
/// with `b` at location `loc`.
|
||||
SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
|
||||
OpBuilder &b, Location loc);
|
||||
} // namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemRef Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRefOpsDialect.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemRef Dialect Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace memref {
|
||||
// DmaStartOp starts a non-blocking DMA operation that transfers data from a
|
||||
// source memref to a destination memref. The source and destination memref need
|
||||
// not be of the same dimensionality, but need to have the same elemental type.
|
||||
// The operands include the source and destination memref's each followed by its
|
||||
// indices, size of the data transfer in terms of the number of elements (of the
|
||||
// elemental type of the memref), a tag memref with its indices, and optionally
|
||||
// at the end, a stride and a number_of_elements_per_stride arguments. The tag
|
||||
// location is used by a DmaWaitOp to check for completion. The indices of the
|
||||
// source memref, destination memref, and the tag memref have the same
|
||||
// restrictions as any load/store. The optional stride arguments should be of
|
||||
// 'index' type, and specify a stride for the slower memory space (memory space
|
||||
// with a lower memory space id), transferring chunks of
|
||||
// number_of_elements_per_stride every stride until %num_elements are
|
||||
// transferred. Either both or no stride arguments should be specified.
|
||||
//
|
||||
// For example, a DmaStartOp operation that transfers 256 elements of a memref
|
||||
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
|
||||
// 1 at indices [%k, %l], would be specified as follows:
|
||||
//
|
||||
// %num_elements = constant 256
|
||||
// %idx = constant 0 : index
|
||||
// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
|
||||
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
|
||||
// memref<40 x 128 x f32>, (d0) -> (d0), 0>,
|
||||
// memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
|
||||
// memref<1 x i32>, (d0) -> (d0), 2>
|
||||
//
|
||||
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
|
||||
// transfer %num_elt_per_stride elements every %stride elements apart from
|
||||
// memory space 0 until %num_elements are transferred.
|
||||
//
|
||||
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
|
||||
// %num_elt_per_stride :
|
||||
//
|
||||
// TODO: add additional operands to allow source and destination striding, and
|
||||
// multiple stride levels.
|
||||
// TODO: Consider replacing src/dst memref indices with view memrefs.
|
||||
class DmaStartOp
|
||||
: public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
|
||||
ValueRange srcIndices, Value destMemRef,
|
||||
ValueRange destIndices, Value numElements, Value tagMemRef,
|
||||
ValueRange tagIndices, Value stride = nullptr,
|
||||
Value elementsPerStride = nullptr);
|
||||
|
||||
// Returns the source MemRefType for this DMA operation.
|
||||
Value getSrcMemRef() { return getOperand(0); }
|
||||
// Returns the rank (number of indices) of the source MemRefType.
|
||||
unsigned getSrcMemRefRank() {
|
||||
return getSrcMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
// Returns the source memref indices for this DMA operation.
|
||||
operand_range getSrcIndices() {
|
||||
return {(*this)->operand_begin() + 1,
|
||||
(*this)->operand_begin() + 1 + getSrcMemRefRank()};
|
||||
}
|
||||
|
||||
// Returns the destination MemRefType for this DMA operations.
|
||||
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
|
||||
// Returns the rank (number of indices) of the destination MemRefType.
|
||||
unsigned getDstMemRefRank() {
|
||||
return getDstMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
unsigned getSrcMemorySpace() {
|
||||
return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
|
||||
}
|
||||
unsigned getDstMemorySpace() {
|
||||
return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
|
||||
}
|
||||
|
||||
// Returns the destination memref indices for this DMA operation.
|
||||
operand_range getDstIndices() {
|
||||
return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
|
||||
(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
|
||||
getDstMemRefRank()};
|
||||
}
|
||||
|
||||
// Returns the number of elements being transferred by this DMA operation.
|
||||
Value getNumElements() {
|
||||
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
|
||||
}
|
||||
|
||||
// Returns the Tag MemRef for this DMA operation.
|
||||
Value getTagMemRef() {
|
||||
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
|
||||
}
|
||||
// Returns the rank (number of indices) of the tag MemRefType.
|
||||
unsigned getTagMemRefRank() {
|
||||
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
// Returns the tag memref index for this DMA operation.
|
||||
operand_range getTagIndices() {
|
||||
unsigned tagIndexStartPos =
|
||||
1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
|
||||
return {(*this)->operand_begin() + tagIndexStartPos,
|
||||
(*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
|
||||
}
|
||||
|
||||
/// Returns true if this is a DMA from a faster memory space to a slower one.
|
||||
bool isDestMemorySpaceFaster() {
|
||||
return (getSrcMemorySpace() < getDstMemorySpace());
|
||||
}
|
||||
|
||||
/// Returns true if this is a DMA from a slower memory space to a faster one.
|
||||
bool isSrcMemorySpaceFaster() {
|
||||
// Assumes that a lower number is for a slower memory space.
|
||||
return (getDstMemorySpace() < getSrcMemorySpace());
|
||||
}
|
||||
|
||||
/// Given a DMA start operation, returns the operand position of either the
|
||||
/// source or destination memref depending on the one that is at the higher
|
||||
/// level of the memory hierarchy. Asserts failure if neither is true.
|
||||
unsigned getFasterMemPos() {
|
||||
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
|
||||
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
|
||||
}
|
||||
|
||||
static StringRef getOperationName() { return "memref.dma_start"; }
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
|
||||
LogicalResult fold(ArrayRef<Attribute> cstOperands,
|
||||
SmallVectorImpl<OpFoldResult> &results);
|
||||
|
||||
bool isStrided() {
|
||||
return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
|
||||
1 + 1 + getTagMemRefRank();
|
||||
}
|
||||
|
||||
Value getStride() {
|
||||
if (!isStrided())
|
||||
return nullptr;
|
||||
return getOperand(getNumOperands() - 1 - 1);
|
||||
}
|
||||
|
||||
Value getNumElementsPerStride() {
|
||||
if (!isStrided())
|
||||
return nullptr;
|
||||
return getOperand(getNumOperands() - 1);
|
||||
}
|
||||
};
|

// DmaWaitOp blocks until the completion of a DMA operation associated with the
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
// with the same restrictions as any load/store index. %num_elements is the
// number of elements associated with the DMA operation. For example:
//
//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
//     memref<2048 x f32>, (d0) -> (d0), 0>,
//     memref<256 x f32>, (d0) -> (d0), 1>
//     memref<1 x i32>, (d0) -> (d0), 2>
//   ...
//   ...
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
//
class DmaWaitOp
    : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
  using Op::Op;

  static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
                    ValueRange tagIndices, Value numElements);

  static StringRef getOperationName() { return "memref.dma_wait"; }

  // Returns the Tag MemRef associated with the DMA operation being waited on.
  Value getTagMemRef() { return getOperand(0); }

  // Returns the tag memref index for this DMA operation.
  operand_range getTagIndices() {
    return {(*this)->operand_begin() + 1,
            (*this)->operand_begin() + 1 + getTagMemRefRank()};
  }

  // Returns the rank (number of indices) of the tag memref.
  unsigned getTagMemRefRank() {
    return getTagMemRef().getType().cast<MemRefType>().getRank();
  }

  // Returns the number of elements transferred in the associated DMA operation.
  Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }

  static ParseResult parse(OpAsmParser &parser, OperationState &result);
  void print(OpAsmPrinter &p);
  LogicalResult fold(ArrayRef<Attribute> cstOperands,
                     SmallVectorImpl<OpFoldResult> &results);
  LogicalResult verify();
};
} // namespace memref
} // namespace mlir

#endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_
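As a quick orientation for the new header, here is a minimal sketch of how the two op classes above are typically paired from C++. The helper name `emitDmaPair` and all values passed in are hypothetical and not part of this commit; the builder signatures follow the declarations shown above.

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

// Sketch only: start a non-blocking DMA transfer and immediately wait on it.
static void emitDmaPair(mlir::OpBuilder &b, mlir::Location loc,
                        mlir::Value src, mlir::ValueRange srcIndices,
                        mlir::Value dst, mlir::ValueRange dstIndices,
                        mlir::Value numElements, mlir::Value tag,
                        mlir::ValueRange tagIndices) {
  // Kick off the non-blocking transfer, tagging it with %tag[%tagIndices].
  b.create<mlir::memref::DmaStartOp>(loc, src, srcIndices, dst, dstIndices,
                                     numElements, tag, tagIndices);
  // Block until the transfer associated with the same tag element completes.
  b.create<mlir::memref::DmaWaitOp>(loc, tag, tagIndices, numElements);
}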
@ -0,0 +1,25 @@
//===- MemRefBase.td - Base definitions for memref dialect -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MEMREF_BASE
#define MEMREF_BASE

include "mlir/IR/OpBase.td"

def MemRef_Dialect : Dialect {
  let name = "memref";
  let cppNamespace = "::mlir::memref";
  let description = [{
    The `memref` dialect is intended to hold core memref creation and
    manipulation ops, which are not strongly associated with any particular
    other dialect or domain abstraction.
  }];
  let hasConstantMaterializer = 1;
}

#endif // MEMREF_BASE
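For readers unfamiliar with the generated C++ side of a dialect definition, the snippet below sketches how the resulting `memref::MemRefDialect` class (named per the `cppNamespace` above) would be loaded into a context. This is an illustration, not code from the commit.

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext context;
  // Loading the dialect makes memref ops parseable and creatable in this
  // context; registry-based registration (see registerAllDialects below in
  // this diff) is the other common path.
  context.getOrLoadDialect<mlir::memref::MemRefDialect>();
  return 0;
}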
File diff suppressed because it is too large
@ -14,6 +14,7 @@ include "mlir/Pass/PassBase.td"
def SCFBufferize : FunctionPass<"scf-bufferize"> {
  let summary = "Bufferize the scf dialect.";
  let constructor = "mlir::createSCFBufferizePass()";
  let dependentDialects = ["memref::MemRefDialect"];
}

def SCFForLoopSpecialization

@ -25,5 +25,6 @@ def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
def ShapeBufferize : FunctionPass<"shape-bufferize"> {
  let summary = "Bufferize the shape dialect.";
  let constructor = "mlir::createShapeBufferizePass()";
  let dependentDialects = ["memref::MemRefDialect"];
}
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
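The `dependentDialects` entries added above correspond to the `getDependentDialects` hook on the generated pass base. A hand-written equivalent would look roughly like the following; `ExampleBufferizePass` is hypothetical and shown only to illustrate why the field is needed now that bufferization creates memref ops.

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"

namespace {
struct ExampleBufferizePass
    : public mlir::PassWrapper<ExampleBufferizePass, mlir::FunctionPass> {
  // Registering the dialect up front avoids lazily loading it while passes
  // run in parallel over functions.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::memref::MemRefDialect>();
  }
  void runOnFunction() override {
    // Bufferization patterns that create memref ops would be applied here.
  }
};
} // namespace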
@ -17,35 +17,24 @@ namespace intrinsics {

using std_addi = ValueBuilder<AddIOp>;
using std_addf = ValueBuilder<AddFOp>;
using std_alloc = ValueBuilder<AllocOp>;
using std_alloca = ValueBuilder<AllocaOp>;
using std_call = OperationBuilder<CallOp>;
using std_constant = ValueBuilder<ConstantOp>;
using std_constant_float = ValueBuilder<ConstantFloatOp>;
using std_constant_index = ValueBuilder<ConstantIndexOp>;
using std_constant_int = ValueBuilder<ConstantIntOp>;
using std_dealloc = OperationBuilder<DeallocOp>;
using std_divis = ValueBuilder<SignedDivIOp>;
using std_diviu = ValueBuilder<UnsignedDivIOp>;
using std_dim = ValueBuilder<DimOp>;
using std_fpext = ValueBuilder<FPExtOp>;
using std_fptrunc = ValueBuilder<FPTruncOp>;
using std_index_cast = ValueBuilder<IndexCastOp>;
using std_muli = ValueBuilder<MulIOp>;
using std_mulf = ValueBuilder<MulFOp>;
using std_memref_cast = ValueBuilder<MemRefCastOp>;
using std_ret = OperationBuilder<ReturnOp>;
using std_select = ValueBuilder<SelectOp>;
using std_load = ValueBuilder<LoadOp>;
using std_sign_extendi = ValueBuilder<SignExtendIOp>;
using std_splat = ValueBuilder<SplatOp>;
using std_store = OperationBuilder<StoreOp>;
using std_subf = ValueBuilder<SubFOp>;
using std_subi = ValueBuilder<SubIOp>;
using std_sub_view = ValueBuilder<SubViewOp>;
using std_tensor_load = ValueBuilder<TensorLoadOp>;
using std_tensor_store = OperationBuilder<TensorStoreOp>;
using std_view = ValueBuilder<ViewOp>;
using std_zero_extendi = ValueBuilder<ZeroExtendIOp>;
using tensor_extract = ValueBuilder<tensor::ExtractOp>;

@ -77,10 +66,6 @@ BranchOp std_br(Block *block, ValueRange operands);
/// or to `falseBranch` and `falseOperand` if `cond` evaluates to `false`.
CondBranchOp std_cond_br(Value cond, Block *trueBranch, ValueRange trueOperands,
                         Block *falseBranch, ValueRange falseOperands);

/// Provide an index notation around std_load and std_store.
using StdIndexedValue =
    TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
} // namespace intrinsics
} // namespace edsc
} // namespace mlir
@ -34,8 +34,6 @@ class Builder;
class FuncOp;
class OpBuilder;

raw_ostream &operator<<(raw_ostream &os, Range &range);

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
@ -110,200 +108,6 @@ public:
|
|||
static bool classof(Operation *op);
|
||||
};
|
||||
|
||||
// DmaStartOp starts a non-blocking DMA operation that transfers data from a
|
||||
// source memref to a destination memref. The source and destination memref need
|
||||
// not be of the same dimensionality, but need to have the same elemental type.
|
||||
// The operands include the source and destination memref's each followed by its
|
||||
// indices, size of the data transfer in terms of the number of elements (of the
|
||||
// elemental type of the memref), a tag memref with its indices, and optionally
|
||||
// at the end, a stride and a number_of_elements_per_stride arguments. The tag
|
||||
// location is used by a DmaWaitOp to check for completion. The indices of the
|
||||
// source memref, destination memref, and the tag memref have the same
|
||||
// restrictions as any load/store. The optional stride arguments should be of
|
||||
// 'index' type, and specify a stride for the slower memory space (memory space
|
||||
// with a lower memory space id), transferring chunks of
|
||||
// number_of_elements_per_stride every stride until %num_elements are
|
||||
// transferred. Either both or no stride arguments should be specified.
|
||||
//
|
||||
// For example, a DmaStartOp operation that transfers 256 elements of a memref
|
||||
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
|
||||
// 1 at indices [%k, %l], would be specified as follows:
|
||||
//
|
||||
// %num_elements = constant 256
|
||||
// %idx = constant 0 : index
|
||||
// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
|
||||
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
|
||||
// memref<40 x 128 x f32>, (d0) -> (d0), 0>,
|
||||
// memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
|
||||
// memref<1 x i32>, (d0) -> (d0), 2>
|
||||
//
|
||||
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
|
||||
// transfer %num_elt_per_stride elements every %stride elements apart from
|
||||
// memory space 0 until %num_elements are transferred.
|
||||
//
|
||||
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
|
||||
// %num_elt_per_stride :
|
||||
//
|
||||
// TODO: add additional operands to allow source and destination striding, and
|
||||
// multiple stride levels.
|
||||
// TODO: Consider replacing src/dst memref indices with view memrefs.
|
||||
class DmaStartOp
|
||||
: public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
|
||||
ValueRange srcIndices, Value destMemRef,
|
||||
ValueRange destIndices, Value numElements, Value tagMemRef,
|
||||
ValueRange tagIndices, Value stride = nullptr,
|
||||
Value elementsPerStride = nullptr);
|
||||
|
||||
// Returns the source MemRefType for this DMA operation.
|
||||
Value getSrcMemRef() { return getOperand(0); }
|
||||
// Returns the rank (number of indices) of the source MemRefType.
|
||||
unsigned getSrcMemRefRank() {
|
||||
return getSrcMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
// Returns the source memref indices for this DMA operation.
|
||||
operand_range getSrcIndices() {
|
||||
return {(*this)->operand_begin() + 1,
|
||||
(*this)->operand_begin() + 1 + getSrcMemRefRank()};
|
||||
}
|
||||
|
||||
// Returns the destination MemRefType for this DMA operations.
|
||||
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
|
||||
// Returns the rank (number of indices) of the destination MemRefType.
|
||||
unsigned getDstMemRefRank() {
|
||||
return getDstMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
unsigned getSrcMemorySpace() {
|
||||
return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
|
||||
}
|
||||
unsigned getDstMemorySpace() {
|
||||
return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
|
||||
}
|
||||
|
||||
// Returns the destination memref indices for this DMA operation.
|
||||
operand_range getDstIndices() {
|
||||
return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
|
||||
(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
|
||||
getDstMemRefRank()};
|
||||
}
|
||||
|
||||
// Returns the number of elements being transferred by this DMA operation.
|
||||
Value getNumElements() {
|
||||
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
|
||||
}
|
||||
|
||||
// Returns the Tag MemRef for this DMA operation.
|
||||
Value getTagMemRef() {
|
||||
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
|
||||
}
|
||||
// Returns the rank (number of indices) of the tag MemRefType.
|
||||
unsigned getTagMemRefRank() {
|
||||
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
// Returns the tag memref index for this DMA operation.
|
||||
operand_range getTagIndices() {
|
||||
unsigned tagIndexStartPos =
|
||||
1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
|
||||
return {(*this)->operand_begin() + tagIndexStartPos,
|
||||
(*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
|
||||
}
|
||||
|
||||
/// Returns true if this is a DMA from a faster memory space to a slower one.
|
||||
bool isDestMemorySpaceFaster() {
|
||||
return (getSrcMemorySpace() < getDstMemorySpace());
|
||||
}
|
||||
|
||||
/// Returns true if this is a DMA from a slower memory space to a faster one.
|
||||
bool isSrcMemorySpaceFaster() {
|
||||
// Assumes that a lower number is for a slower memory space.
|
||||
return (getDstMemorySpace() < getSrcMemorySpace());
|
||||
}
|
||||
|
||||
/// Given a DMA start operation, returns the operand position of either the
|
||||
/// source or destination memref depending on the one that is at the higher
|
||||
/// level of the memory hierarchy. Asserts failure if neither is true.
|
||||
unsigned getFasterMemPos() {
|
||||
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
|
||||
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
|
||||
}
|
||||
|
||||
static StringRef getOperationName() { return "std.dma_start"; }
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
|
||||
LogicalResult fold(ArrayRef<Attribute> cstOperands,
|
||||
SmallVectorImpl<OpFoldResult> &results);
|
||||
|
||||
bool isStrided() {
|
||||
return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
|
||||
1 + 1 + getTagMemRefRank();
|
||||
}
|
||||
|
||||
Value getStride() {
|
||||
if (!isStrided())
|
||||
return nullptr;
|
||||
return getOperand(getNumOperands() - 1 - 1);
|
||||
}
|
||||
|
||||
Value getNumElementsPerStride() {
|
||||
if (!isStrided())
|
||||
return nullptr;
|
||||
return getOperand(getNumOperands() - 1);
|
||||
}
|
||||
};
|
||||
|
||||
// DmaWaitOp blocks until the completion of a DMA operation associated with the
|
||||
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
|
||||
// with the same restrictions as any load/store index. %num_elements is the
|
||||
// number of elements associated with the DMA operation. For example:
|
||||
//
|
||||
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
|
||||
// memref<2048 x f32>, (d0) -> (d0), 0>,
|
||||
// memref<256 x f32>, (d0) -> (d0), 1>
|
||||
// memref<1 x i32>, (d0) -> (d0), 2>
|
||||
// ...
|
||||
// ...
|
||||
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
|
||||
//
|
||||
class DmaWaitOp
|
||||
: public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
|
||||
ValueRange tagIndices, Value numElements);
|
||||
|
||||
static StringRef getOperationName() { return "std.dma_wait"; }
|
||||
|
||||
// Returns the Tag MemRef associated with the DMA operation being waited on.
|
||||
Value getTagMemRef() { return getOperand(0); }
|
||||
|
||||
// Returns the tag memref index for this DMA operation.
|
||||
operand_range getTagIndices() {
|
||||
return {(*this)->operand_begin() + 1,
|
||||
(*this)->operand_begin() + 1 + getTagMemRefRank()};
|
||||
}
|
||||
|
||||
// Returns the rank (number of indices) of the tag memref.
|
||||
unsigned getTagMemRefRank() {
|
||||
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
// Returns the number of elements transferred in the associated DMA operation.
|
||||
Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }
|
||||
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult fold(ArrayRef<Attribute> cstOperands,
|
||||
SmallVectorImpl<OpFoldResult> &results);
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
|
||||
/// `originalShape` with some `1` entries erased, return the set of indices
|
||||
/// that specifies which of the entries of `originalShape` are dropped to obtain
|
||||
|
@ -316,45 +120,6 @@ llvm::Optional<llvm::SmallDenseSet<unsigned>>
|
|||
computeRankReductionMask(ArrayRef<int64_t> originalShape,
|
||||
ArrayRef<int64_t> reducedShape);
|
||||
|
||||
/// Determines whether MemRefCastOp casts to a more dynamic version of the
|
||||
/// source memref. This is useful to to fold a memref_cast into a consuming op
|
||||
/// and implement canonicalization patterns for ops in different dialects that
|
||||
/// may consume the results of memref_cast operations. Such foldable memref_cast
|
||||
/// operations are typically inserted as `view` and `subview` ops and are
|
||||
/// canonicalized, to preserve the type compatibility of their uses.
|
||||
///
|
||||
/// Returns true when all conditions are met:
|
||||
/// 1. source and result are ranked memrefs with strided semantics and same
|
||||
/// element type and rank.
|
||||
/// 2. each of the source's size, offset or stride has more static information
|
||||
/// than the corresponding result's size, offset or stride.
|
||||
///
|
||||
/// Example 1:
|
||||
/// ```mlir
|
||||
/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
|
||||
/// %2 = consumer %1 ... : memref<?x?xf32> ...
|
||||
/// ```
|
||||
///
|
||||
/// may fold into:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %2 = consumer %0 ... : memref<8x16xf32> ...
|
||||
/// ```
|
||||
///
|
||||
/// Example 2:
|
||||
/// ```
|
||||
/// %1 = memref_cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
|
||||
/// to memref<?x?xf32>
|
||||
/// consumer %1 : memref<?x?xf32> ...
|
||||
/// ```
|
||||
///
|
||||
/// may fold into:
|
||||
///
|
||||
/// ```
|
||||
/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
|
||||
/// ```
|
||||
bool canFoldIntoConsumerOp(MemRefCastOp castOp);
|
||||
|
||||
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
|
||||
/// comparison predicates.
|
||||
bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
|
||||
|
|
File diff suppressed because it is too large
|
@ -37,7 +37,7 @@ std::unique_ptr<Pass> createTensorConstantBufferizePass();
|
|||
/// Creates an instance of the StdExpand pass that legalizes Std
|
||||
/// dialect ops to be convertible to LLVM. For example,
|
||||
/// `std.ceildivi_signed` gets transformed to a number of std operations,
|
||||
/// which can be lowered to LLVM; `memref_reshape` gets converted to
|
||||
/// which can be lowered to LLVM; `memref.reshape` gets converted to
|
||||
/// `memref_reinterpret_cast`.
|
||||
std::unique_ptr<Pass> createStdExpandOpsPass();
|
||||
|
||||
|
|
|
@ -44,9 +44,10 @@ def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
|
|||
implement the `ReturnLike` trait are not rewritten in general, as they
|
||||
require that the corresponding parent operation is also rewritten.
|
||||
Finally, this pass fails for unknown terminators, as we cannot decide
|
||||
whether they need rewriting.
|
||||
whether they need rewriting.
|
||||
}];
|
||||
let constructor = "mlir::createFuncBufferizePass()";
|
||||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
|
||||
|
@ -54,12 +55,13 @@ def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
|
|||
let description = [{
|
||||
This pass bufferizes tensor constants.
|
||||
|
||||
This pass needs to be a module pass because it inserts std.global_memref
|
||||
This pass needs to be a module pass because it inserts memref.global
|
||||
ops into the module, which cannot be done safely from a function pass due to
|
||||
multi-threading. Most other bufferization passes can run in parallel at
|
||||
function granularity.
|
||||
}];
|
||||
let constructor = "mlir::createTensorConstantBufferizePass()";
|
||||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
#ifndef MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H
|
||||
#define MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -27,6 +30,51 @@ class OpBuilder;
|
|||
/// constructing the necessary DimOp operators.
|
||||
SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b);
|
||||
|
||||
/// Matches a ConstantIndexOp.
|
||||
detail::op_matcher<ConstantIndexOp> matchConstantIndex();
|
||||
|
||||
/// Detects the `values` produced by a ConstantIndexOp and places the new
|
||||
/// constant in place of the corresponding sentinel value.
|
||||
void canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> &values,
|
||||
function_ref<bool(int64_t)> isDynamic);
|
||||
|
||||
void getPositionsOfShapeOne(unsigned rank, ArrayRef<int64_t> shape,
|
||||
llvm::SmallDenseSet<unsigned> &dimsToProject);
|
||||
|
||||
/// Pattern to rewrite a subview op with constant arguments.
|
||||
template <typename OpType, typename CastOpFunc>
|
||||
class OpWithOffsetSizesAndStridesConstantArgumentFolder final
|
||||
: public OpRewritePattern<OpType> {
|
||||
public:
|
||||
using OpRewritePattern<OpType>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(OpType op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// No constant operand, just return;
|
||||
if (llvm::none_of(op.getOperands(), [](Value operand) {
|
||||
return matchPattern(operand, matchConstantIndex());
|
||||
}))
|
||||
return failure();
|
||||
|
||||
// At least one of offsets/sizes/strides is a new constant.
|
||||
// Form the new list of operands and constant attributes from the existing.
|
||||
SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
|
||||
SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
|
||||
SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
|
||||
canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
|
||||
canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
|
||||
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
|
||||
|
||||
// Create the new op in canonical form.
|
||||
auto newOp = rewriter.create<OpType>(op.getLoc(), op.source(), mixedOffsets,
|
||||
mixedSizes, mixedStrides);
|
||||
CastOpFunc func;
|
||||
func(rewriter, op, newOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_STANDARDOPS_UTILS_UTILS_H
|
||||
|
|
|
@ -180,11 +180,11 @@ private:
|
|||
/// ```
|
||||
/// %1:3 = scf.if (%inBounds) {
|
||||
/// // fastpath, direct cast
|
||||
/// memref_cast %A: memref<A...> to compatibleMemRefType
|
||||
/// memref.cast %A: memref<A...> to compatibleMemRefType
|
||||
/// scf.yield %view : compatibleMemRefType, index, index
|
||||
/// } else {
|
||||
/// // slowpath, masked vector.transfer or linalg.copy.
|
||||
/// memref_cast %alloc: memref<B...> to compatibleMemRefType
|
||||
/// memref.cast %alloc: memref<B...> to compatibleMemRefType
|
||||
/// scf.yield %4 : compatibleMemRefType, index, index
|
||||
// }
|
||||
/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
|
||||
|
|
|
@ -1133,7 +1133,7 @@ public:
|
|||
/// A trait of region holding operations that define a new scope for automatic
|
||||
/// allocations, i.e., allocations that are freed when control is transferred
|
||||
/// back from the operation's region. Any operations performing such allocations
|
||||
/// (for eg. std.alloca) will have their allocations automatically freed at
|
||||
/// (for eg. memref.alloca) will have their allocations automatically freed at
|
||||
/// their closest enclosing operation with this trait.
|
||||
template <typename ConcreteType>
|
||||
class AutomaticAllocationScope
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
#include "mlir/Dialect/PDL/IR/PDL.h"
|
||||
|
@ -60,6 +61,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
|
|||
LLVM::LLVMArmSVEDialect,
|
||||
linalg::LinalgDialect,
|
||||
math::MathDialect,
|
||||
memref::MemRefDialect,
|
||||
scf::SCFDialect,
|
||||
omp::OpenMPDialect,
|
||||
pdl::PDLDialect,
|
||||
|
|
|
@ -54,7 +54,7 @@ void populateBufferizeMaterializationLegality(ConversionTarget &target);
|
|||
|
||||
/// Populate patterns to eliminate bufferize materializations.
|
||||
///
|
||||
/// In particular, these are the tensor_load/tensor_to_memref ops.
|
||||
/// In particular, these are the tensor_load/buffer_cast ops.
|
||||
void populateEliminateBufferizeMaterializationsPatterns(
|
||||
MLIRContext *context, BufferizeTypeConverter &typeConverter,
|
||||
OwningRewritePatternList &patterns);
|
||||
|
|
|
@ -54,7 +54,7 @@ std::unique_ptr<Pass>
|
|||
createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
|
||||
|
||||
/// Creates a pass that finalizes a partial bufferization by removing remaining
|
||||
/// tensor_load and tensor_to_memref operations.
|
||||
/// tensor_load and buffer_cast operations.
|
||||
std::unique_ptr<FunctionPass> createFinalizingBufferizePass();
|
||||
|
||||
/// Creates a pass that converts memref function results to out-params.
|
||||
|
|
|
@ -352,7 +352,7 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp">
|
|||
works for static shaped memrefs.
|
||||
}];
|
||||
let constructor = "mlir::createBufferResultsToOutParamsPass()";
|
||||
let dependentDialects = ["linalg::LinalgDialect"];
|
||||
let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def Canonicalizer : Pass<"canonicalize"> {
|
||||
|
@ -363,6 +363,7 @@ def Canonicalizer : Pass<"canonicalize"> {
|
|||
details.
|
||||
}];
|
||||
let constructor = "mlir::createCanonicalizerPass()";
|
||||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def CopyRemoval : FunctionPass<"copy-removal"> {
|
||||
|
@ -406,11 +407,11 @@ def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> {
|
|||
let summary = "Finalize a partial bufferization";
|
||||
let description = [{
|
||||
A bufferize pass that finalizes a partial bufferization by removing
|
||||
remaining `tensor_load` and `tensor_to_memref` operations.
|
||||
remaining `memref.tensor_load` and `memref.buffer_cast` operations.
|
||||
|
||||
The removal of those operations is only possible if the operations only
|
||||
exist in pairs, i.e., all uses of `tensor_load` operations are
|
||||
`tensor_to_memref` operations.
|
||||
exist in pairs, i.e., all uses of `memref.tensor_load` operations are
|
||||
`memref.buffer_cast` operations.
|
||||
|
||||
This pass will fail if not all operations can be removed or if any operation
|
||||
with tensor typed operands remains.
|
||||
|
@ -535,7 +536,7 @@ def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
|
|||
contained in the op. Operations marked with the [MemRefsNormalizable]
|
||||
(https://mlir.llvm.org/docs/Traits/#memrefsnormalizable) trait are
|
||||
expected to be normalizable. Supported operations include affine
|
||||
operations, std.alloc, std.dealloc, and std.return.
|
||||
operations, memref.alloc, memref.dealloc, and std.return.
|
||||
|
||||
Given an appropriate layout map specified in the code, this transformation
|
||||
can express tiled or linearized access to multi-dimensional data
|
||||
|
|
|
@ -28,6 +28,10 @@ class AffineForOp;
|
|||
class Location;
|
||||
class OpBuilder;
|
||||
|
||||
namespace memref {
|
||||
class AllocOp;
|
||||
} // end namespace memref
|
||||
|
||||
/// Replaces all "dereferencing" uses of `oldMemRef` with `newMemRef` while
|
||||
/// optionally remapping the old memref's indices using the supplied affine map,
|
||||
/// `indexRemap`. The new memref could be of a different shape or rank.
|
||||
|
@ -88,7 +92,7 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
|
|||
/// Rewrites the memref defined by this alloc op to have an identity layout map
|
||||
/// and updates all its indexing uses. Returns failure if any of its uses
|
||||
/// escape (while leaving the IR in a valid state).
|
||||
LogicalResult normalizeMemRef(AllocOp op);
|
||||
LogicalResult normalizeMemRef(memref::AllocOp *op);
|
||||
|
||||
/// Uses the old memref type map layout and computes the new memref type to have
|
||||
/// a new shape and a layout map, where the old layout map has been normalized
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
|
@ -44,7 +45,8 @@ public:
|
|||
: builder(builder), dimValues(dimValues), symbolValues(symbolValues),
|
||||
loc(loc) {}
|
||||
|
||||
template <typename OpTy> Value buildBinaryExpr(AffineBinaryOpExpr expr) {
|
||||
template <typename OpTy>
|
||||
Value buildBinaryExpr(AffineBinaryOpExpr expr) {
|
||||
auto lhs = visit(expr.getLHS());
|
||||
auto rhs = visit(expr.getRHS());
|
||||
if (!lhs || !rhs)
|
||||
|
@ -563,8 +565,8 @@ public:
|
|||
};
|
||||
|
||||
/// Apply the affine map from an 'affine.load' operation to its operands, and
|
||||
/// feed the results to a newly created 'std.load' operation (which replaces the
|
||||
/// original 'affine.load').
|
||||
/// feed the results to a newly created 'memref.load' operation (which replaces
|
||||
/// the original 'affine.load').
|
||||
class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
|
||||
public:
|
||||
using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
|
||||
|
@ -579,14 +581,14 @@ public:
|
|||
return failure();
|
||||
|
||||
// Build vector.load memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, op.getMemRef(),
|
||||
*resultOperands);
|
||||
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, op.getMemRef(),
|
||||
*resultOperands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Apply the affine map from an 'affine.prefetch' operation to its operands,
|
||||
/// and feed the results to a newly created 'std.prefetch' operation (which
|
||||
/// and feed the results to a newly created 'memref.prefetch' operation (which
|
||||
/// replaces the original 'affine.prefetch').
|
||||
class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
|
||||
public:
|
||||
|
@ -601,16 +603,16 @@ public:
|
|||
if (!resultOperands)
|
||||
return failure();
|
||||
|
||||
// Build std.prefetch memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<PrefetchOp>(op, op.memref(), *resultOperands,
|
||||
op.isWrite(), op.localityHint(),
|
||||
op.isDataCache());
|
||||
// Build memref.prefetch memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
|
||||
op, op.memref(), *resultOperands, op.isWrite(), op.localityHint(),
|
||||
op.isDataCache());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Apply the affine map from an 'affine.store' operation to its operands, and
|
||||
/// feed the results to a newly created 'std.store' operation (which replaces
|
||||
/// feed the results to a newly created 'memref.store' operation (which replaces
|
||||
/// the original 'affine.store').
|
||||
class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
|
||||
public:
|
||||
|
@ -625,8 +627,8 @@ public:
|
|||
if (!maybeExpandedMap)
|
||||
return failure();
|
||||
|
||||
// Build std.store valueToStore, memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<mlir::StoreOp>(
|
||||
// Build memref.store valueToStore, memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<memref::StoreOp>(
|
||||
op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
|
||||
return success();
|
||||
}
|
||||
|
@ -634,7 +636,8 @@ public:
|
|||
|
||||
/// Apply the affine maps from an 'affine.dma_start' operation to each of their
|
||||
/// respective map operands, and feed the results to a newly created
|
||||
/// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
|
||||
/// 'memref.dma_start' operation (which replaces the original
|
||||
/// 'affine.dma_start').
|
||||
class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
|
||||
public:
|
||||
using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
|
||||
|
@ -663,8 +666,8 @@ public:
|
|||
if (!maybeExpandedTagMap)
|
||||
return failure();
|
||||
|
||||
// Build std.dma_start operation with affine map results.
|
||||
rewriter.replaceOpWithNewOp<DmaStartOp>(
|
||||
// Build memref.dma_start operation with affine map results.
|
||||
rewriter.replaceOpWithNewOp<memref::DmaStartOp>(
|
||||
op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
|
||||
*maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
|
||||
*maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
|
||||
|
@ -673,7 +676,7 @@ public:
|
|||
};
|
||||
|
||||
/// Apply the affine map from an 'affine.dma_wait' operation tag memref,
|
||||
/// and feed the results to a newly created 'std.dma_wait' operation (which
|
||||
/// and feed the results to a newly created 'memref.dma_wait' operation (which
|
||||
/// replaces the original 'affine.dma_wait').
|
||||
class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
|
||||
public:
|
||||
|
@ -688,8 +691,8 @@ public:
|
|||
if (!maybeExpandedTagMap)
|
||||
return failure();
|
||||
|
||||
// Build std.dma_wait operation with affine map results.
|
||||
rewriter.replaceOpWithNewOp<DmaWaitOp>(
|
||||
// Build memref.dma_wait operation with affine map results.
|
||||
rewriter.replaceOpWithNewOp<memref::DmaWaitOp>(
|
||||
op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
|
||||
return success();
|
||||
}
|
||||
|
@ -777,8 +780,8 @@ class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
|
|||
populateAffineToStdConversionPatterns(patterns, &getContext());
|
||||
populateAffineToVectorConversionPatterns(patterns, &getContext());
|
||||
ConversionTarget target(getContext());
|
||||
target
|
||||
.addLegalDialect<scf::SCFDialect, StandardOpsDialect, VectorDialect>();
|
||||
target.addLegalDialect<memref::MemRefDialect, scf::SCFDialect,
|
||||
StandardOpsDialect, VectorDialect>();
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
signalPassFailure();
|
||||
|
|
|
@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRAffineToStandard
|
|||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAffine
|
||||
MLIRMemRef
|
||||
MLIRSCF
|
||||
MLIRPass
|
||||
MLIRStandard
|
||||
|
|
|
@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToNVVMTransforms
|
|||
MLIRGPU
|
||||
MLIRGPUToGPURuntimeTransforms
|
||||
MLIRLLVMIR
|
||||
MLIRMemRef
|
||||
MLIRNVVMIR
|
||||
MLIRPass
|
||||
MLIRStandardToLLVM
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/Dialect/GPU/Passes.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
|
|
@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRLinalgToStandard
|
|||
MLIREDSC
|
||||
MLIRIR
|
||||
MLIRLinalg
|
||||
MLIRMemRef
|
||||
MLIRPass
|
||||
MLIRSCF
|
||||
MLIRTransforms
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
||||
|
@ -93,7 +94,7 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
|
|||
continue;
|
||||
}
|
||||
Value cast =
|
||||
b.create<MemRefCastOp>(loc, eraseStridedLayout(memrefType), op);
|
||||
b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op);
|
||||
res.push_back(cast);
|
||||
}
|
||||
return res;
|
||||
|
@ -143,12 +144,12 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
|
|||
// If either inputPerm or outputPerm are non-identities, insert transposes.
|
||||
auto inputPerm = op.inputPermutation();
|
||||
if (inputPerm.hasValue() && !inputPerm->isIdentity())
|
||||
in = rewriter.create<TransposeOp>(op.getLoc(), in,
|
||||
AffineMapAttr::get(*inputPerm));
|
||||
in = rewriter.create<memref::TransposeOp>(op.getLoc(), in,
|
||||
AffineMapAttr::get(*inputPerm));
|
||||
auto outputPerm = op.outputPermutation();
|
||||
if (outputPerm.hasValue() && !outputPerm->isIdentity())
|
||||
out = rewriter.create<TransposeOp>(op.getLoc(), out,
|
||||
AffineMapAttr::get(*outputPerm));
|
||||
out = rewriter.create<memref::TransposeOp>(op.getLoc(), out,
|
||||
AffineMapAttr::get(*outputPerm));
|
||||
|
||||
// If nothing was transposed, fail and let the conversion kick in.
|
||||
if (in == op.input() && out == op.output())
|
||||
|
@ -213,7 +214,8 @@ struct ConvertLinalgToStandardPass
|
|||
void ConvertLinalgToStandardPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect>();
|
||||
target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
|
||||
StandardOpsDialect>();
|
||||
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
|
||||
target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
|
||||
OwningRewritePatternList patterns;
|
||||
|
|
|
@ -38,6 +38,10 @@ namespace NVVM {
|
|||
class NVVMDialect;
|
||||
} // end namespace NVVM
|
||||
|
||||
namespace memref {
|
||||
class MemRefDialect;
|
||||
} // end namespace memref
|
||||
|
||||
namespace omp {
|
||||
class OpenMPDialect;
|
||||
} // end namespace omp
|
||||
|
|
|
@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRSCFToGPU
|
|||
MLIRGPU
|
||||
MLIRIR
|
||||
MLIRLinalg
|
||||
MLIRMemRef
|
||||
MLIRPass
|
||||
MLIRStandard
|
||||
MLIRSupport
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/ParallelLoopMapper.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
|
@ -647,6 +648,7 @@ void mlir::populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
|
|||
}
|
||||
|
||||
void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) {
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
target.addDynamicallyLegalOp<scf::ParallelOp>([](scf::ParallelOp parallelOp) {
|
||||
return !parallelOp->getAttr(gpu::getMappingAttrName());
|
||||
});
|
||||
|
|
|
@ -19,6 +19,7 @@ add_mlir_conversion_library(MLIRShapeToStandard
|
|||
LINK_LIBS PUBLIC
|
||||
MLIREDSC
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
MLIRShape
|
||||
MLIRTensor
|
||||
MLIRPass
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
@ -139,7 +140,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
|
|||
// dimension in the tensor.
|
||||
SmallVector<Value> ranks, rankDiffs;
|
||||
llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
|
||||
return lb.create<DimOp>(v, zero);
|
||||
return lb.create<memref::DimOp>(v, zero);
|
||||
}));
|
||||
|
||||
// Find the maximum rank
|
||||
|
@ -252,7 +253,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
|
|||
// dimension in the tensor.
|
||||
SmallVector<Value> ranks, rankDiffs;
|
||||
llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
|
||||
return lb.create<DimOp>(v, zero);
|
||||
return lb.create<memref::DimOp>(v, zero);
|
||||
}));
|
||||
|
||||
// Find the maximum rank
|
||||
|
@ -344,8 +345,8 @@ LogicalResult GetExtentOpConverter::matchAndRewrite(
|
|||
// circumvents the necessity to materialize the shape in memory.
|
||||
if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
|
||||
if (shapeOfOp.arg().getType().isa<ShapedType>()) {
|
||||
rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
|
||||
transformed.dim());
|
||||
rewriter.replaceOpWithNewOp<memref::DimOp>(op, shapeOfOp.arg(),
|
||||
transformed.dim());
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
@ -375,7 +376,7 @@ RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
|
|||
return failure();
|
||||
|
||||
shape::RankOp::Adaptor transformed(operands);
|
||||
rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
|
||||
rewriter.replaceOpWithNewOp<memref::DimOp>(op, transformed.shape(), 0);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -404,7 +405,8 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
|
|||
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
|
||||
Type indexTy = rewriter.getIndexType();
|
||||
Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
|
||||
Value rank =
|
||||
rewriter.create<memref::DimOp>(loc, indexTy, transformed.shape(), zero);
|
||||
|
||||
auto loop = rewriter.create<scf::ForOp>(
|
||||
loc, zero, rank, one, op.initVals(),
|
||||
|
@ -490,11 +492,12 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
|
|||
Type indexTy = rewriter.getIndexType();
|
||||
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||
Value firstShape = transformed.shapes().front();
|
||||
Value firstRank = rewriter.create<DimOp>(loc, indexTy, firstShape, zero);
|
||||
Value firstRank =
|
||||
rewriter.create<memref::DimOp>(loc, indexTy, firstShape, zero);
|
||||
Value result = nullptr;
|
||||
// Generate a linear sequence of compares, all with firstShape as lhs.
|
||||
for (Value shape : transformed.shapes().drop_front(1)) {
|
||||
Value rank = rewriter.create<DimOp>(loc, indexTy, shape, zero);
|
||||
Value rank = rewriter.create<memref::DimOp>(loc, indexTy, shape, zero);
|
||||
Value eqRank =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
|
||||
auto same = rewriter.create<IfOp>(
|
||||
|
@ -559,7 +562,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
|
|||
int64_t rank = rankedTensorTy.getRank();
|
||||
for (int64_t i = 0; i < rank; i++) {
|
||||
if (rankedTensorTy.isDynamicDim(i)) {
|
||||
Value extent = rewriter.create<DimOp>(loc, tensor, i);
|
||||
Value extent = rewriter.create<memref::DimOp>(loc, tensor, i);
|
||||
extentValues.push_back(extent);
|
||||
} else {
|
||||
Value extent =
|
||||
|
@ -583,7 +586,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
|
|||
op, getExtentTensorType(ctx), ValueRange{rank},
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value dim = args.front();
|
||||
Value extent = b.create<DimOp>(loc, tensor, dim);
|
||||
Value extent = b.create<memref::DimOp>(loc, tensor, dim);
|
||||
b.create<tensor::YieldOp>(loc, extent);
|
||||
});
|
||||
|
||||
|
@ -613,7 +616,7 @@ LogicalResult SplitAtOpConversion::matchAndRewrite(
|
|||
SplitAtOp::Adaptor transformed(op);
|
||||
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
Value zero = b.create<ConstantIndexOp>(0);
|
||||
Value rank = b.create<DimOp>(transformed.operand(), zero);
|
||||
Value rank = b.create<memref::DimOp>(transformed.operand(), zero);
|
||||
|
||||
// index < 0 ? index + rank : index
|
||||
Value originalIndex = transformed.index();
|
||||
|
@ -670,8 +673,8 @@ void ConvertShapeToStandardPass::runOnOperation() {
|
|||
// Setup target legality.
|
||||
MLIRContext &ctx = getContext();
|
||||
ConversionTarget target(ctx);
|
||||
target
|
||||
.addLegalDialect<StandardOpsDialect, SCFDialect, tensor::TensorDialect>();
|
||||
target.addLegalDialect<memref::MemRefDialect, StandardOpsDialect, SCFDialect,
|
||||
tensor::TensorDialect>();
|
||||
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
|
||||
|
||||
// Setup conversion patterns.
|
||||
|
|
|
@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRStandardToLLVM
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRLLVMIR
|
||||
MLIRMath
|
||||
MLIRMemRef
|
||||
MLIRTransforms
|
||||
)
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
@ -1864,13 +1865,13 @@ private:
|
|||
|
||||
struct AllocOpLowering : public AllocLikeOpLowering {
|
||||
AllocOpLowering(LLVMTypeConverter &converter)
|
||||
: AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
|
||||
: AllocLikeOpLowering(memref::AllocOp::getOperationName(), converter) {}
|
||||
|
||||
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value sizeBytes,
|
||||
Operation *op) const override {
|
||||
// Heap allocations.
|
||||
AllocOp allocOp = cast<AllocOp>(op);
|
||||
memref::AllocOp allocOp = cast<memref::AllocOp>(op);
|
||||
MemRefType memRefType = allocOp.getType();
|
||||
|
||||
Value alignment;
|
||||
|
@ -1917,7 +1918,7 @@ struct AllocOpLowering : public AllocLikeOpLowering {
|
|||
|
||||
struct AlignedAllocOpLowering : public AllocLikeOpLowering {
|
||||
AlignedAllocOpLowering(LLVMTypeConverter &converter)
|
||||
: AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}
|
||||
: AllocLikeOpLowering(memref::AllocOp::getOperationName(), converter) {}
|
||||
|
||||
/// Returns the memref's element size in bytes.
|
||||
// TODO: there are other places where this is used. Expose publicly?
|
||||
|
@ -1950,7 +1951,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
|
|||
/// Returns the alignment to be used for the allocation call itself.
|
||||
/// aligned_alloc requires the allocation size to be a power of two, and the
|
||||
/// allocation size to be a multiple of alignment,
|
||||
int64_t getAllocationAlignment(AllocOp allocOp) const {
|
||||
int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
|
||||
if (Optional<uint64_t> alignment = allocOp.alignment())
|
||||
return *alignment;
|
||||
|
||||
|
@ -1966,7 +1967,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
|
|||
Location loc, Value sizeBytes,
|
||||
Operation *op) const override {
|
||||
// Heap allocations.
|
||||
AllocOp allocOp = cast<AllocOp>(op);
|
||||
memref::AllocOp allocOp = cast<memref::AllocOp>(op);
|
||||
MemRefType memRefType = allocOp.getType();
|
||||
int64_t alignment = getAllocationAlignment(allocOp);
|
||||
Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
|
||||
|
@ -1997,7 +1998,7 @@ constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
|
|||
|
||||
struct AllocaOpLowering : public AllocLikeOpLowering {
|
||||
AllocaOpLowering(LLVMTypeConverter &converter)
|
||||
: AllocLikeOpLowering(AllocaOp::getOperationName(), converter) {}
|
||||
: AllocLikeOpLowering(memref::AllocaOp::getOperationName(), converter) {}
|
||||
|
||||
/// Allocates the underlying buffer using the right call. `allocatedBytePtr`
|
||||
/// is set to null for stack allocations. `accessAlignment` is set if
|
||||
|
@ -2008,7 +2009,7 @@ struct AllocaOpLowering : public AllocLikeOpLowering {
|
|||
|
||||
// With alloca, one gets a pointer to the element type right away.
|
||||
// For stack allocations.
|
||||
auto allocaOp = cast<AllocaOp>(op);
|
||||
auto allocaOp = cast<memref::AllocaOp>(op);
|
||||
auto elementPtrType = this->getElementPtrType(allocaOp.getType());
|
||||
|
||||
auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
|
||||
|
@ -2180,17 +2181,17 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
|
|||
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
|
||||
// The memref descriptor being an SSA value, there is no need to clean it up
|
||||
// in any way.
|
||||
struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
|
||||
using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;
|
||||
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
|
||||
using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
explicit DeallocOpLowering(LLVMTypeConverter &converter)
|
||||
: ConvertOpToLLVMPattern<DeallocOp>(converter) {}
|
||||
: ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(DeallocOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::DeallocOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
assert(operands.size() == 1 && "dealloc takes one operand");
|
||||
DeallocOp::Adaptor transformed(operands);
|
||||
memref::DeallocOp::Adaptor transformed(operands);
|
||||
|
||||
// Insert the `free` declaration if it is not already present.
|
||||
auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
|
||||
|
@ -2209,7 +2210,7 @@ static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
|
|||
LLVMTypeConverter &typeConverter) {
|
||||
// LLVM type for a global memref will be a multi-dimension array. For
|
||||
// declarations or uninitialized global memrefs, we can potentially flatten
|
||||
// this to a 1D array. However, for global_memref's with an initial value,
|
||||
// this to a 1D array. However, for memref.global's with an initial value,
|
||||
// we do not intend to flatten the ElementsAttribute when going from std ->
|
||||
// LLVM dialect, so the LLVM type needs to me a multi-dimension array.
|
||||
Type elementType = unwrap(typeConverter.convertType(type.getElementType()));
|
||||
|
@ -2221,11 +2222,12 @@ static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
|
|||
}
|
||||
|
||||
/// GlobalMemrefOp is lowered to a LLVM Global Variable.
|
||||
struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
|
||||
using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;
|
||||
struct GlobalMemrefOpLowering
|
||||
: public ConvertOpToLLVMPattern<memref::GlobalOp> {
|
||||
using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(GlobalMemrefOp global, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::GlobalOp global, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MemRefType type = global.type().cast<MemRefType>();
|
||||
if (!isConvertibleAndHasIdentityMaps(type))
|
||||
|
@ -2259,14 +2261,15 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
|
|||
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
|
||||
struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
|
||||
GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
|
||||
: AllocLikeOpLowering(GetGlobalMemrefOp::getOperationName(), converter) {}
|
||||
: AllocLikeOpLowering(memref::GetGlobalOp::getOperationName(),
|
||||
converter) {}
|
||||
|
||||
/// Buffer "allocation" for get_global_memref op is getting the address of
|
||||
/// Buffer "allocation" for memref.get_global op is getting the address of
|
||||
/// the global variable referenced.
|
||||
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value sizeBytes,
|
||||
Operation *op) const override {
|
||||
auto getGlobalOp = cast<GetGlobalMemrefOp>(op);
|
||||
auto getGlobalOp = cast<memref::GetGlobalOp>(op);
|
||||
MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
|
||||
unsigned memSpace = type.getMemorySpaceAsInt();
|
||||
|
||||
|
@ -2285,7 +2288,7 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
|
|||
createIndexConstant(rewriter, loc, 0));
|
||||
auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
|
||||
|
||||
// We do not expect the memref obtained using `get_global_memref` to be
|
||||
// We do not expect the memref obtained using `memref.get_global` to be
|
||||
// ever deallocated. Set the allocated pointer to be known bad value to
|
||||
// help debug if that ever happens.
|
||||
auto intPtrType = getIntPtrType(memSpace);
|
||||
|
@ -2354,17 +2357,17 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
|
|||
}
|
||||
};
|
||||
|
||||
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
|
||||
using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;
|
||||
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
|
||||
using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult match(MemRefCastOp memRefCastOp) const override {
|
||||
LogicalResult match(memref::CastOp memRefCastOp) const override {
|
||||
Type srcType = memRefCastOp.getOperand().getType();
|
||||
Type dstType = memRefCastOp.getType();
|
||||
|
||||
// MemRefCastOp reduce to bitcast in the ranked MemRef case and can be used
|
||||
// for type erasure. For now they must preserve underlying element type and
|
||||
// require source and result type to have the same rank. Therefore, perform
|
||||
// a sanity check that the underlying structs are the same. Once op
|
||||
// memref::CastOp reduce to bitcast in the ranked MemRef case and can be
|
||||
// used for type erasure. For now they must preserve underlying element type
|
||||
// and require source and result type to have the same rank. Therefore,
|
||||
// perform a sanity check that the underlying structs are the same. Once op
|
||||
// semantics are relaxed we can revisit.
|
||||
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
|
||||
return success(typeConverter->convertType(srcType) ==
|
||||
|
@ -2381,9 +2384,9 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
|
|||
: failure();
|
||||
}
|
||||
|
||||
void rewrite(MemRefCastOp memRefCastOp, ArrayRef<Value> operands,
|
||||
void rewrite(memref::CastOp memRefCastOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MemRefCastOp::Adaptor transformed(operands);
|
||||
memref::CastOp::Adaptor transformed(operands);
|
||||
|
||||
auto srcType = memRefCastOp.getOperand().getType();
|
||||
auto dstType = memRefCastOp.getType();
|
||||
|
@ -2486,14 +2489,15 @@ static void extractPointersAndOffset(Location loc,
|
|||
}
|
||||
|
||||
struct MemRefReinterpretCastOpLowering
|
||||
: public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> {
|
||||
using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern;
|
||||
: public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
|
||||
using ConvertOpToLLVMPattern<
|
||||
memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(MemRefReinterpretCastOp castOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::ReinterpretCastOp castOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MemRefReinterpretCastOp::Adaptor adaptor(operands,
|
||||
castOp->getAttrDictionary());
|
||||
memref::ReinterpretCastOp::Adaptor adaptor(operands,
|
||||
castOp->getAttrDictionary());
|
||||
Type srcType = castOp.source().getType();
|
||||
|
||||
Value descriptor;
|
||||
|
@ -2505,11 +2509,10 @@ struct MemRefReinterpretCastOpLowering
|
|||
}
|
||||
|
||||
private:
|
||||
LogicalResult
|
||||
convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
|
||||
Type srcType, MemRefReinterpretCastOp castOp,
|
||||
MemRefReinterpretCastOp::Adaptor adaptor,
|
||||
Value *descriptor) const {
|
||||
LogicalResult convertSourceMemRefToDescriptor(
|
||||
ConversionPatternRewriter &rewriter, Type srcType,
|
||||
memref::ReinterpretCastOp castOp,
|
||||
memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
|
||||
MemRefType targetMemRefType =
|
||||
castOp.getResult().getType().cast<MemRefType>();
|
||||
auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
|
||||
|
@ -2555,14 +2558,14 @@ private:
|
|||
};
|
||||
|
||||
struct MemRefReshapeOpLowering
|
||||
: public ConvertOpToLLVMPattern<MemRefReshapeOp> {
|
||||
using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern;
|
||||
: public ConvertOpToLLVMPattern<memref::ReshapeOp> {
|
||||
using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(MemRefReshapeOp reshapeOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::ReshapeOp reshapeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto *op = reshapeOp.getOperation();
|
||||
MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
|
||||
memref::ReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
|
||||
Type srcType = reshapeOp.source().getType();
|
||||
|
||||
Value descriptor;
|
||||
|
@ -2576,8 +2579,8 @@ struct MemRefReshapeOpLowering
|
|||
private:
|
||||
LogicalResult
|
||||
convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
|
||||
Type srcType, MemRefReshapeOp reshapeOp,
|
||||
MemRefReshapeOp::Adaptor adaptor,
|
||||
Type srcType, memref::ReshapeOp reshapeOp,
|
||||
memref::ReshapeOp::Adaptor adaptor,
|
||||
Value *descriptor) const {
|
||||
// Conversion for statically-known shape args is performed via
|
||||
// `memref_reinterpret_cast`.
|
||||
|
@ -2722,11 +2725,11 @@ struct DialectCastOpLowering
|
|||
|
||||
// A `dim` is converted to a constant for static sizes and to an access to the
|
||||
// size stored in the memref descriptor for dynamic sizes.
|
||||
struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
|
||||
using ConvertOpToLLVMPattern<DimOp>::ConvertOpToLLVMPattern;
|
||||
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
|
||||
using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(DimOp dimOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::DimOp dimOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type operandType = dimOp.memrefOrTensor().getType();
|
||||
if (operandType.isa<UnrankedMemRefType>()) {
|
||||
|
@ -2744,11 +2747,11 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
|
|||
}
|
||||
|
||||
private:
|
||||
Value extractSizeOfUnrankedMemRef(Type operandType, DimOp dimOp,
|
||||
Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = dimOp.getLoc();
|
||||
DimOp::Adaptor transformed(operands);
|
||||
memref::DimOp::Adaptor transformed(operands);
|
||||
|
||||
auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
|
||||
auto scalarMemRefType =
|
||||
|
@ -2785,11 +2788,11 @@ private:
|
|||
return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
|
||||
}
|
||||
|
||||
Value extractSizeOfRankedMemRef(Type operandType, DimOp dimOp,
|
||||
Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = dimOp.getLoc();
|
||||
DimOp::Adaptor transformed(operands);
|
||||
memref::DimOp::Adaptor transformed(operands);
|
||||
// Take advantage if index is constant.
|
||||
MemRefType memRefType = operandType.cast<MemRefType>();
|
||||
if (Optional<int64_t> index = dimOp.getConstantIndex()) {
|
||||
|
@ -2833,7 +2836,7 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
|
|||
};
|
||||
|
||||
// Common base for load and store operations on MemRefs. Restricts the match
|
||||
// to supported MemRef types. Provides functionality to emit code accessing a
|
||||
// to supported MemRef types. Provides functionality to emit code accessing a
|
||||
// specific element of the underlying data buffer.
|
||||
template <typename Derived>
|
||||
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
|
||||
|
@ -2849,13 +2852,13 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
|
|||
|
||||
// Load operation is lowered to obtaining a pointer to the indexed element
|
||||
// and loading it.
|
||||
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
|
||||
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
|
||||
using Base::Base;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LoadOp::Adaptor transformed(operands);
|
||||
memref::LoadOp::Adaptor transformed(operands);
|
||||
auto type = loadOp.getMemRefType();
|
||||
|
||||
Value dataPtr =
|
||||
|
@ -2868,14 +2871,14 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
|
|||
|
||||
// Store operation is lowered to obtaining a pointer to the indexed element,
|
||||
// and storing the given value to it.
|
||||
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
|
||||
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
|
||||
using Base::Base;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(StoreOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::StoreOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto type = op.getMemRefType();
|
||||
StoreOp::Adaptor transformed(operands);
|
||||
memref::StoreOp::Adaptor transformed(operands);
|
||||
|
||||
Value dataPtr =
|
||||
getStridedElementPtr(op.getLoc(), type, transformed.memref(),
|
||||
|
@ -2888,13 +2891,13 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
|
|||
|
||||
// The prefetch operation is lowered in a way similar to the load operation
|
||||
// except that the llvm.prefetch operation is used for replacement.
|
||||
struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
|
||||
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
|
||||
using Base::Base;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(PrefetchOp prefetchOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::PrefetchOp prefetchOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
PrefetchOp::Adaptor transformed(operands);
|
||||
memref::PrefetchOp::Adaptor transformed(operands);
|
||||
auto type = prefetchOp.getMemRefType();
|
||||
auto loc = prefetchOp.getLoc();
|
||||
|
||||
|
@ -3221,11 +3224,11 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
|
|||
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
|
||||
/// and stride.
|
||||
/// The subview op is replaced by the descriptor.
|
||||
struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
|
||||
using ConvertOpToLLVMPattern<SubViewOp>::ConvertOpToLLVMPattern;
|
||||
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
|
||||
using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SubViewOp subViewOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::SubViewOp subViewOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = subViewOp.getLoc();
|
||||
|
||||
|
@ -3234,7 +3237,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
|
|||
typeConverter->convertType(sourceMemRefType.getElementType());
|
||||
|
||||
auto viewMemRefType = subViewOp.getType();
|
||||
auto inferredType = SubViewOp::inferResultType(
|
||||
auto inferredType = memref::SubViewOp::inferResultType(
|
||||
subViewOp.getSourceType(),
|
||||
extractFromI64ArrayAttr(subViewOp.static_offsets()),
|
||||
extractFromI64ArrayAttr(subViewOp.static_sizes()),
|
||||
|
@ -3335,7 +3338,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
|
|||
if (static_cast<unsigned>(i) >= mixedSizes.size()) {
|
||||
size = rewriter.create<LLVM::DialectCastOp>(
|
||||
loc, llvmIndexType,
|
||||
rewriter.create<DimOp>(loc, subViewOp.source(), i));
|
||||
rewriter.create<memref::DimOp>(loc, subViewOp.source(), i));
|
||||
stride = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
|
||||
} else {
|
||||
|
@ -3376,15 +3379,15 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
|
|||
/// and stride. Size and stride are permutations of the original values.
|
||||
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
|
||||
/// The transpose op is replaced by the alloca'ed pointer.
|
||||
class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> {
|
||||
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<TransposeOp>::ConvertOpToLLVMPattern;
|
||||
using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(TransposeOp transposeOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::TransposeOp transposeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = transposeOp.getLoc();
|
||||
TransposeOpAdaptor adaptor(operands);
|
||||
memref::TransposeOpAdaptor adaptor(operands);
|
||||
MemRefDescriptor viewMemRef(adaptor.in());
|
||||
|
||||
// No permutation, early exit.
|
||||
|
@ -3424,8 +3427,8 @@ public:
|
|||
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
|
||||
/// and stride.
|
||||
/// The view op is replaced by the descriptor.
|
||||
struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
|
||||
using ConvertOpToLLVMPattern<ViewOp>::ConvertOpToLLVMPattern;
|
||||
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
|
||||
using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
// Build and return the value for the idx^th shape dimension, either by
|
||||
// returning the constant shape dimension or counting the proper dynamic size.
|
||||
|
@ -3461,10 +3464,10 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
|
|||
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ViewOp viewOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::ViewOp viewOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = viewOp.getLoc();
|
||||
ViewOpAdaptor adaptor(operands);
|
||||
memref::ViewOpAdaptor adaptor(operands);
|
||||
|
||||
auto viewMemRefType = viewOp.getType();
|
||||
auto targetElementTy =
|
||||
|
@ -3540,13 +3543,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
|
|||
};
|
||||
|
||||
struct AssumeAlignmentOpLowering
|
||||
: public ConvertOpToLLVMPattern<AssumeAlignmentOp> {
|
||||
using ConvertOpToLLVMPattern<AssumeAlignmentOp>::ConvertOpToLLVMPattern;
|
||||
: public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
|
||||
using ConvertOpToLLVMPattern<
|
||||
memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AssumeAlignmentOp op, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::AssumeAlignmentOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
AssumeAlignmentOp::Adaptor transformed(operands);
|
||||
memref::AssumeAlignmentOp::Adaptor transformed(operands);
|
||||
Value memref = transformed.memref();
|
||||
unsigned alignment = op.alignment();
|
||||
auto loc = op.getLoc();
|
||||
|
|
|
@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRStandardToSPIRV
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRMath
|
||||
MLIRMemRef
|
||||
MLIRPass
|
||||
MLIRSPIRV
|
||||
MLIRSPIRVConversion
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
|
||||
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
|
@ -23,11 +24,11 @@
|
|||
using namespace mlir;
|
||||
|
||||
/// Helpers to access the memref operand for each op.
|
||||
static Value getMemRefOperand(LoadOp op) { return op.memref(); }
|
||||
static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
|
||||
|
||||
static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
|
||||
|
||||
static Value getMemRefOperand(StoreOp op) { return op.memref(); }
|
||||
static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }
|
||||
|
||||
static Value getMemRefOperand(vector::TransferWriteOp op) {
|
||||
return op.source();
|
||||
|
@ -44,7 +45,7 @@ public:
|
|||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
void replaceOp(OpTy loadOp, SubViewOp subViewOp,
|
||||
void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
|
||||
ArrayRef<Value> sourceIndices,
|
||||
PatternRewriter &rewriter) const;
|
||||
};
|
||||
|
@ -59,23 +60,22 @@ public:
|
|||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
void replaceOp(OpTy StoreOp, SubViewOp subViewOp,
|
||||
void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
|
||||
ArrayRef<Value> sourceIndices,
|
||||
PatternRewriter &rewriter) const;
|
||||
};
|
||||
|
||||
template <>
|
||||
void LoadOpOfSubViewFolder<LoadOp>::replaceOp(LoadOp loadOp,
|
||||
SubViewOp subViewOp,
|
||||
ArrayRef<Value> sourceIndices,
|
||||
PatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
|
||||
sourceIndices);
|
||||
void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
|
||||
memref::LoadOp loadOp, memref::SubViewOp subViewOp,
|
||||
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
|
||||
sourceIndices);
|
||||
}
|
||||
|
||||
template <>
|
||||
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
|
||||
vector::TransferReadOp loadOp, SubViewOp subViewOp,
|
||||
vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
|
||||
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
|
||||
loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
|
||||
|
@ -83,16 +83,16 @@ void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
|
|||
}
|
||||
|
||||
template <>
|
||||
void StoreOpOfSubViewFolder<StoreOp>::replaceOp(
|
||||
StoreOp storeOp, SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
|
||||
PatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
|
||||
subViewOp.source(), sourceIndices);
|
||||
void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
|
||||
memref::StoreOp storeOp, memref::SubViewOp subViewOp,
|
||||
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<memref::StoreOp>(
|
||||
storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
|
||||
}
|
||||
|
||||
template <>
|
||||
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
|
||||
vector::TransferWriteOp tranferWriteOp, SubViewOp subViewOp,
|
||||
vector::TransferWriteOp tranferWriteOp, memref::SubViewOp subViewOp,
|
||||
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
||||
tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(),
|
||||
|
@ -120,7 +120,7 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
|
|||
/// memref<12x42xf32>
|
||||
static LogicalResult
|
||||
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
|
||||
SubViewOp subViewOp, ValueRange indices,
|
||||
memref::SubViewOp subViewOp, ValueRange indices,
|
||||
SmallVectorImpl<Value> &sourceIndices) {
|
||||
// TODO: Aborting when the offsets are static. There might be a way to fold
|
||||
// the subview op with load even if the offsets have been canonicalized
|
||||
|
@ -152,7 +152,8 @@ template <typename OpTy>
|
|||
LogicalResult
|
||||
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp<SubViewOp>();
|
||||
auto subViewOp =
|
||||
getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
|
||||
if (!subViewOp) {
|
||||
return failure();
|
||||
}
|
||||
|
@ -174,7 +175,7 @@ LogicalResult
|
|||
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto subViewOp =
|
||||
getMemRefOperand(storeOp).template getDefiningOp<SubViewOp>();
|
||||
getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
|
||||
if (!subViewOp) {
|
||||
return failure();
|
||||
}
|
||||
|
@ -193,9 +194,9 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
|
|||
|
||||
void mlir::populateStdLegalizationPatternsForSPIRVLowering(
|
||||
MLIRContext *context, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<LoadOpOfSubViewFolder<LoadOp>,
|
||||
patterns.insert<LoadOpOfSubViewFolder<memref::LoadOp>,
|
||||
LoadOpOfSubViewFolder<vector::TransferReadOp>,
|
||||
StoreOpOfSubViewFolder<StoreOp>,
|
||||
StoreOpOfSubViewFolder<memref::StoreOp>,
|
||||
StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
|
@ -237,12 +238,12 @@ namespace {
|
|||
/// to Workgroup memory when the size is constant. Note that this pattern needs
|
||||
/// to be applied in a pass that runs at least at spv.module scope since it wil
|
||||
/// ladd global variables into the spv.module.
|
||||
class AllocOpPattern final : public OpConversionPattern<AllocOp> {
|
||||
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
|
||||
public:
|
||||
using OpConversionPattern<AllocOp>::OpConversionPattern;
|
||||
using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::AllocOp operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MemRefType allocType = operation.getType();
|
||||
if (!isAllocationSupported(allocType))
|
||||
|
@ -278,12 +279,12 @@ public:
|
|||
|
||||
/// Removed a deallocation if it is a supported allocation. Currently only
|
||||
/// removes deallocation if the memory space is workgroup memory.
|
||||
class DeallocOpPattern final : public OpConversionPattern<DeallocOp> {
|
||||
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
|
||||
public:
|
||||
using OpConversionPattern<DeallocOp>::OpConversionPattern;
|
||||
using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::DeallocOp operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
|
||||
if (!isAllocationSupported(deallocType))
|
||||
|
@ -430,23 +431,23 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts std.load to spv.Load.
|
||||
class IntLoadOpPattern final : public OpConversionPattern<LoadOp> {
|
||||
/// Converts memref.load to spv.Load.
|
||||
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
|
||||
public:
|
||||
using OpConversionPattern<LoadOp>::OpConversionPattern;
|
||||
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts std.load to spv.Load.
|
||||
class LoadOpPattern final : public OpConversionPattern<LoadOp> {
|
||||
/// Converts memref.load to spv.Load.
|
||||
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
|
||||
public:
|
||||
using OpConversionPattern<LoadOp>::OpConversionPattern;
|
||||
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
|
@ -469,23 +470,23 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts std.store to spv.Store on integers.
|
||||
class IntStoreOpPattern final : public OpConversionPattern<StoreOp> {
|
||||
/// Converts memref.store to spv.Store on integers.
|
||||
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
|
||||
public:
|
||||
using OpConversionPattern<StoreOp>::OpConversionPattern;
|
||||
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Converts std.store to spv.Store.
|
||||
class StoreOpPattern final : public OpConversionPattern<StoreOp> {
|
||||
/// Converts memref.store to spv.Store.
|
||||
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
|
||||
public:
|
||||
using OpConversionPattern<StoreOp>::OpConversionPattern;
|
||||
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(memref::StoreOp storeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
|
@ -975,9 +976,10 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
||||
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
LoadOpAdaptor loadOperands(operands);
|
||||
memref::LoadOpAdaptor loadOperands(operands);
|
||||
auto loc = loadOp.getLoc();
|
||||
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
|
||||
if (!memrefType.getElementType().isSignlessInteger())
|
||||
|
@ -1051,9 +1053,9 @@ IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
|||
}
|
||||
|
||||
LogicalResult
|
||||
LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
||||
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
LoadOpAdaptor loadOperands(operands);
|
||||
memref::LoadOpAdaptor loadOperands(operands);
|
||||
auto memrefType = loadOp.memref().getType().cast<MemRefType>();
|
||||
if (memrefType.getElementType().isSignlessInteger())
|
||||
return failure();
|
||||
|
@ -1101,9 +1103,10 @@ SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
|
||||
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
StoreOpAdaptor storeOperands(operands);
|
||||
memref::StoreOpAdaptor storeOperands(operands);
|
||||
auto memrefType = storeOp.memref().getType().cast<MemRefType>();
|
||||
if (!memrefType.getElementType().isSignlessInteger())
|
||||
return failure();
|
||||
|
@ -1180,9 +1183,10 @@ IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
|
|||
}
|
||||
|
||||
LogicalResult
|
||||
StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
|
||||
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
StoreOpAdaptor storeOperands(operands);
|
||||
memref::StoreOpAdaptor storeOperands(operands);
|
||||
auto memrefType = storeOp.memref().getType().cast<MemRefType>();
|
||||
if (memrefType.getElementType().isSignlessInteger())
|
||||
return failure();
|
||||
|
|
|
@ -20,6 +20,7 @@ add_mlir_conversion_library(MLIRVectorToLLVM
|
|||
MLIRArmSVEToLLVM
|
||||
MLIRLLVMArmSVE
|
||||
MLIRLLVMIR
|
||||
MLIRMemRef
|
||||
MLIRStandardToLLVM
|
||||
MLIRTargetLLVMIRExport
|
||||
MLIRTransforms
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
@ -1262,7 +1263,7 @@ public:
|
|||
unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
|
||||
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
|
||||
Value off = xferOp.indices()[lastIndex];
|
||||
Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
|
||||
Value dim = rewriter.create<memref::DimOp>(loc, xferOp.source(), lastIndex);
|
||||
Value mask = buildVectorComparison(
|
||||
rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
@ -39,6 +40,7 @@ struct LowerVectorToLLVMPass
|
|||
// Override explicitly to allow conditional dialect dependence.
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<LLVM::LLVMDialect>();
|
||||
registry.insert<memref::MemRefDialect>();
|
||||
if (enableArmNeon)
|
||||
registry.insert<arm_neon::ArmNeonDialect>();
|
||||
if (enableArmSVE)
|
||||
|
@ -72,6 +74,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
|
|||
// Architecture specific augmentations.
|
||||
LLVMConversionTarget target(getContext());
|
||||
target.addLegalOp<LLVM::DialectCastOp>();
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalOp<UnrealizedConversionCastOp>();
|
||||
if (enableArmNeon) {
|
||||
|
|
|
@ -11,5 +11,6 @@ add_mlir_conversion_library(MLIRVectorToSCF
|
|||
MLIREDSC
|
||||
MLIRAffineEDSC
|
||||
MLIRLLVMIR
|
||||
MLIRMemRef
|
||||
MLIRTransforms
|
||||
)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/SCF/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
|
@ -252,7 +253,7 @@ static Value setAllocAtFunctionEntry(MemRefType memRefMinorVectorType,
|
|||
op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
|
||||
assert(scope && "Expected op to be inside automatic allocation scope");
|
||||
b.setInsertionPointToStart(&scope->getRegion(0).front());
|
||||
Value res = std_alloca(memRefMinorVectorType);
|
||||
Value res = memref_alloca(memRefMinorVectorType);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -314,7 +315,7 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
|
|||
return {vector};
|
||||
}
|
||||
// 3.b. Otherwise, just go through the temporary `alloc`.
|
||||
std_store(vector, alloc, majorIvs);
|
||||
memref_store(vector, alloc, majorIvs);
|
||||
return {};
|
||||
},
|
||||
[&]() -> scf::ValueVector {
|
||||
|
@ -326,7 +327,7 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
|
|||
return {vector};
|
||||
}
|
||||
// 3.d. Otherwise, just go through the temporary `alloc`.
|
||||
std_store(vector, alloc, majorIvs);
|
||||
memref_store(vector, alloc, majorIvs);
|
||||
return {};
|
||||
});
|
||||
|
||||
|
@ -341,14 +342,15 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
|
|||
result = vector_insert(loaded1D, result, majorIvs);
|
||||
// 5.b. Otherwise, just go through the temporary `alloc`.
|
||||
else
|
||||
std_store(loaded1D, alloc, majorIvs);
|
||||
memref_store(loaded1D, alloc, majorIvs);
|
||||
}
|
||||
});
|
||||
|
||||
assert((!options.unroll ^ (bool)result) &&
|
||||
"Expected resulting Value iff unroll");
|
||||
if (!result)
|
||||
result = std_load(vector_type_cast(MemRefType::get({}, vectorType), alloc));
|
||||
result =
|
||||
memref_load(vector_type_cast(MemRefType::get({}, vectorType), alloc));
|
||||
rewriter.replaceOp(op, result);
|
||||
|
||||
return success();
|
||||
|
@ -359,8 +361,8 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
|
|||
Value alloc;
|
||||
if (!options.unroll) {
|
||||
alloc = setAllocAtFunctionEntry(memRefMinorVectorType, op);
|
||||
std_store(xferOp.vector(),
|
||||
vector_type_cast(MemRefType::get({}, vectorType), alloc));
|
||||
memref_store(xferOp.vector(),
|
||||
vector_type_cast(MemRefType::get({}, vectorType), alloc));
|
||||
}
|
||||
|
||||
emitLoops([&](ValueRange majorIvs, ValueRange leadingOffsets,
|
||||
|
@ -379,7 +381,7 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
|
|||
if (options.unroll)
|
||||
result = vector_extract(xferOp.vector(), majorIvs);
|
||||
else
|
||||
result = std_load(alloc, majorIvs);
|
||||
result = memref_load(alloc, majorIvs);
|
||||
auto map =
|
||||
getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType);
|
||||
ArrayAttr masked;
|
||||
|
@ -560,7 +562,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
|||
// Conservative lowering to scalar load / stores.
|
||||
// 1. Setup all the captures.
|
||||
ScopedContext scope(rewriter, transfer.getLoc());
|
||||
StdIndexedValue remote(transfer.source());
|
||||
MemRefIndexedValue remote(transfer.source());
|
||||
MemRefBoundsCapture memRefBoundsCapture(transfer.source());
|
||||
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
|
||||
int coalescedIdx = computeCoalescedIndex(transfer);
|
||||
|
@ -579,7 +581,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
|||
// 2. Emit alloc-copy-load-dealloc.
|
||||
MLIRContext *ctx = op->getContext();
|
||||
Value tmp = setAllocAtFunctionEntry(tmpMemRefType(transfer), transfer);
|
||||
StdIndexedValue local(tmp);
|
||||
MemRefIndexedValue local(tmp);
|
||||
loopNestBuilder(lbs, ubs, steps, [&](ValueRange loopIvs) {
|
||||
auto ivsStorage = llvm::to_vector<8>(loopIvs);
|
||||
// Swap the ivs which will reorder memory accesses.
|
||||
|
@ -601,7 +603,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
|||
rewriter, cast<VectorTransferOpInterface>(transfer.getOperation()), ivs,
|
||||
memRefBoundsCapture, loadValue, loadPadding);
|
||||
});
|
||||
Value vectorValue = std_load(vector_type_cast(tmp));
|
||||
Value vectorValue = memref_load(vector_type_cast(tmp));
|
||||
|
||||
// 3. Propagate.
|
||||
rewriter.replaceOp(op, vectorValue);
|
||||
|
@ -646,7 +648,7 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
|
|||
|
||||
// 1. Setup all the captures.
|
||||
ScopedContext scope(rewriter, transfer.getLoc());
|
||||
StdIndexedValue remote(transfer.source());
|
||||
MemRefIndexedValue remote(transfer.source());
|
||||
MemRefBoundsCapture memRefBoundsCapture(transfer.source());
|
||||
Value vectorValue(transfer.vector());
|
||||
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
|
||||
|
@ -665,9 +667,9 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
|
|||
|
||||
// 2. Emit alloc-store-copy-dealloc.
|
||||
Value tmp = setAllocAtFunctionEntry(tmpMemRefType(transfer), transfer);
|
||||
StdIndexedValue local(tmp);
|
||||
MemRefIndexedValue local(tmp);
|
||||
Value vec = vector_type_cast(tmp);
|
||||
std_store(vectorValue, vec);
|
||||
memref_store(vectorValue, vec);
|
||||
loopNestBuilder(lbs, ubs, steps, [&](ValueRange loopIvs) {
|
||||
auto ivsStorage = llvm::to_vector<8>(loopIvs);
|
||||
// Swap the ivsStorage which will reorder memory accesses.
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
@ -64,7 +65,7 @@ remainsLegalAfterInline(Value value, Region *src, Region *dest,
|
|||
// op won't be top-level anymore after inlining.
|
||||
Attribute operandCst;
|
||||
return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
|
||||
value.getDefiningOp<DimOp>();
|
||||
value.getDefiningOp<memref::DimOp>();
|
||||
}
|
||||
|
||||
/// Checks if all values known to be legal affine dimensions or symbols in `src`
|
||||
|
@ -295,7 +296,7 @@ bool mlir::isValidDim(Value value, Region *region) {
|
|||
return applyOp.isValidDim(region);
|
||||
// The dim op is okay if its operand memref/tensor is defined at the top
|
||||
// level.
|
||||
if (auto dimOp = dyn_cast<DimOp>(op))
|
||||
if (auto dimOp = dyn_cast<memref::DimOp>(op))
|
||||
return isTopLevelValue(dimOp.memrefOrTensor());
|
||||
return false;
|
||||
}
|
||||
|
@ -317,9 +318,8 @@ static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
|
|||
}
|
||||
|
||||
/// Returns true if the result of the dim op is a valid symbol for `region`.
|
||||
static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
|
||||
// The dim op is okay if its operand memref/tensor is defined at the top
|
||||
// level.
|
||||
static bool isDimOpValidSymbol(memref::DimOp dimOp, Region *region) {
|
||||
// The dim op is okay if its operand memref is defined at the top level.
|
||||
if (isTopLevelValue(dimOp.memrefOrTensor()))
|
||||
return true;
|
||||
|
||||
|
@ -328,14 +328,14 @@ static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
|
|||
if (dimOp.memrefOrTensor().isa<BlockArgument>())
|
||||
return false;
|
||||
|
||||
// The dim op is also okay if its operand memref/tensor is a view/subview
|
||||
// whose corresponding size is a valid symbol.
|
||||
// The dim op is also okay if its operand memref is a view/subview whose
|
||||
// corresponding size is a valid symbol.
|
||||
Optional<int64_t> index = dimOp.getConstantIndex();
|
||||
assert(index.hasValue() &&
|
||||
"expect only `dim` operations with a constant index");
|
||||
int64_t i = index.getValue();
|
||||
return TypeSwitch<Operation *, bool>(dimOp.memrefOrTensor().getDefiningOp())
|
||||
.Case<ViewOp, SubViewOp, AllocOp>(
|
||||
.Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
|
||||
[&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
|
||||
.Default([](Operation *) { return false; });
|
||||
}
|
||||
|
@ -404,7 +404,7 @@ bool mlir::isValidSymbol(Value value, Region *region) {
|
|||
return applyOp.isValidSymbol(region);
|
||||
|
||||
// Dim op results could be valid symbols at any level.
|
||||
if (auto dimOp = dyn_cast<DimOp>(defOp))
|
||||
if (auto dimOp = dyn_cast<memref::DimOp>(defOp))
|
||||
return isDimOpValidSymbol(dimOp, region);
|
||||
|
||||
// Check for values dominating `region`'s parent op.
|
||||
|
@ -915,12 +915,12 @@ void AffineApplyOp::getCanonicalizationPatterns(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This is a common class used for patterns of the form
|
||||
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
|
||||
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
|
||||
/// into the root operation directly.
|
||||
static LogicalResult foldMemRefCast(Operation *op) {
|
||||
bool folded = false;
|
||||
for (OpOperand &operand : op->getOpOperands()) {
|
||||
auto cast = operand.get().getDefiningOp<MemRefCastOp>();
|
||||
auto cast = operand.get().getDefiningOp<memref::CastOp>();
|
||||
if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
|
||||
operand.set(cast.getOperand());
|
||||
folded = true;
|
||||
|
@ -2254,7 +2254,8 @@ LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
|
|||
// AffineMinMaxOpBase
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) {
|
||||
template <typename T>
|
||||
static LogicalResult verifyAffineMinMaxOp(T op) {
|
||||
// Verify that operand count matches affine map dimension and symbol count.
|
||||
if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
|
||||
return op.emitOpError(
|
||||
|
@ -2262,7 +2263,8 @@ template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename T> static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
|
||||
template <typename T>
|
||||
static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
|
||||
p << op.getOperationName() << ' ' << op->getAttr(T::getMapAttrName());
|
||||
auto operands = op.getOperands();
|
||||
unsigned numDims = op.map().getNumDims();
|
||||
|
|
|
@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAffine
|
|||
MLIREDSC
|
||||
MLIRIR
|
||||
MLIRLoopLikeInterface
|
||||
MLIRMemRef
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRStandard
|
||||
)
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "mlir/Analysis/Utils.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
|
|
|
@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
|
|||
MLIRAffineUtils
|
||||
MLIREDSC
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
MLIRPass
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRStandard
|
||||
|
|
|
@ -19,6 +19,11 @@ void registerDialect(DialectRegistry ®istry);
|
|||
namespace linalg {
|
||||
class LinalgDialect;
|
||||
} // end namespace linalg
|
||||
|
||||
namespace memref {
|
||||
class MemRefDialect;
|
||||
} // end namespace memref
|
||||
|
||||
namespace vector {
|
||||
class VectorDialect;
|
||||
} // end namespace vector
|
||||
|
|
|
@ -9,6 +9,7 @@ add_subdirectory(GPU)
|
|||
add_subdirectory(Linalg)
|
||||
add_subdirectory(LLVMIR)
|
||||
add_subdirectory(Math)
|
||||
add_subdirectory(MemRef)
|
||||
add_subdirectory(OpenACC)
|
||||
add_subdirectory(OpenMP)
|
||||
add_subdirectory(PDL)
|
||||
|
|
|
@ -35,6 +35,7 @@ add_mlir_dialect_library(MLIRGPU
|
|||
MLIRAsync
|
||||
MLIREDSC
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
MLIRLLVMIR
|
||||
MLIRLLVMToLLVMIRTranslation
|
||||
MLIRSCF
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -107,7 +108,7 @@ struct GpuAllReduceRewriter {
|
|||
createPredicatedBlock(isFirstLane, [&] {
|
||||
Value subgroupId = getDivideBySubgroupSize(invocationIdx);
|
||||
Value index = create<IndexCastOp>(indexType, subgroupId);
|
||||
create<StoreOp>(subgroupReduce, buffer, index);
|
||||
create<memref::StoreOp>(subgroupReduce, buffer, index);
|
||||
});
|
||||
create<gpu::BarrierOp>();
|
||||
|
||||
|
@ -124,27 +125,29 @@ struct GpuAllReduceRewriter {
|
|||
Value zero = create<ConstantIndexOp>(0);
|
||||
createPredicatedBlock(isValidSubgroup, [&] {
|
||||
Value index = create<IndexCastOp>(indexType, invocationIdx);
|
||||
Value value = create<LoadOp>(valueType, buffer, index);
|
||||
Value value = create<memref::LoadOp>(valueType, buffer, index);
|
||||
Value result =
|
||||
createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
|
||||
create<StoreOp>(result, buffer, zero);
|
||||
create<memref::StoreOp>(result, buffer, zero);
|
||||
});
|
||||
|
||||
// Synchronize workgroup and load result from workgroup memory.
|
||||
create<gpu::BarrierOp>();
|
||||
Value result = create<LoadOp>(valueType, buffer, zero);
|
||||
Value result = create<memref::LoadOp>(valueType, buffer, zero);
|
||||
|
||||
rewriter.replaceOp(reduceOp, result);
|
||||
}
|
||||
|
||||
private:
|
||||
// Shortcut to create an op from rewriter using loc as the first argument.
|
||||
template <typename T, typename... Args> T create(Args... args) {
|
||||
template <typename T, typename... Args>
|
||||
T create(Args... args) {
|
||||
return rewriter.create<T>(loc, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// Creates dimension op of type T, with the result casted to int32.
|
||||
template <typename T> Value getDimOp(StringRef dimension) {
|
||||
template <typename T>
|
||||
Value getDimOp(StringRef dimension) {
|
||||
Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
|
||||
return create<IndexCastOp>(int32Type, dim);
|
||||
}
|
||||
|
@ -236,7 +239,8 @@ private:
|
|||
}
|
||||
|
||||
/// Returns an accumulator factory that creates an op of type T.
|
||||
template <typename T> AccumulatorFactory getFactory() {
|
||||
template <typename T>
|
||||
AccumulatorFactory getFactory() {
|
||||
return [&](Value lhs, Value rhs) {
|
||||
return create<T>(lhs.getType(), lhs, rhs);
|
||||
};
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/Passes.h"
|
||||
#include "mlir/Dialect/GPU/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -58,7 +59,7 @@ static void injectGpuIndexOperations(Location loc, Region &launchFuncOpBody,
|
|||
/// operations may not have side-effects, as otherwise sinking (and hence
|
||||
/// duplicating them) is not legal.
|
||||
static bool isSinkingBeneficiary(Operation *op) {
|
||||
return isa<ConstantOp, DimOp, SelectOp, CmpIOp>(op);
|
||||
return isa<ConstantOp, memref::DimOp, SelectOp, CmpIOp>(op);
|
||||
}
|
||||
|
||||
/// For a given operation `op`, computes whether it is beneficial to sink the
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include "mlir/Dialect/GPU/MemoryPromotion.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/SCF/EDSC/Builders.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
@ -82,7 +83,7 @@ static void insertCopyLoops(OpBuilder &builder, Location loc,
|
|||
loopNestBuilder(lbs, ubs, steps, [&](ValueRange loopIvs) {
|
||||
ivs.assign(loopIvs.begin(), loopIvs.end());
|
||||
auto activeIvs = llvm::makeArrayRef(ivs).take_back(rank);
|
||||
StdIndexedValue fromHandle(from), toHandle(to);
|
||||
MemRefIndexedValue fromHandle(from), toHandle(to);
|
||||
toHandle(activeIvs) = fromHandle(activeIvs);
|
||||
});
|
||||
|
||||
|
|
|
@ -7,5 +7,6 @@ add_mlir_dialect_library(MLIRLinalgAnalysis
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRLinalg
|
||||
MLIRMemRef
|
||||
MLIRStandard
|
||||
)
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
|
@ -48,7 +49,7 @@ Value Aliases::find(Value v) {
|
|||
// the aliasing further.
|
||||
if (isa<RegionBranchOpInterface>(defOp))
|
||||
return v;
|
||||
if (isa<TensorToMemrefOp>(defOp))
|
||||
if (isa<memref::BufferCastOp>(defOp))
|
||||
return v;
|
||||
|
||||
if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
|
||||
|
|
|
@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRLinalgEDSC
|
|||
MLIRAffineEDSC
|
||||
MLIRLinalg
|
||||
MLIRMath
|
||||
MLIRMemRef
|
||||
MLIRSCF
|
||||
MLIRStandard
|
||||
)
|
||||
|
|
|
@ -19,5 +19,6 @@ add_mlir_dialect_library(MLIRLinalg
|
|||
MLIRSideEffectInterfaces
|
||||
MLIRViewLikeInterface
|
||||
MLIRStandard
|
||||
MLIRMemRef
|
||||
MLIRTensor
|
||||
)
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
@ -187,7 +188,7 @@ SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
|
|||
for (Value v : getShapedOperands()) {
|
||||
ShapedType t = v.getType().template cast<ShapedType>();
|
||||
for (unsigned i = 0, e = t.getRank(); i < e; ++i)
|
||||
res.push_back(b.create<DimOp>(loc, v, i));
|
||||
res.push_back(b.create<memref::DimOp>(loc, v, i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
@ -109,12 +110,12 @@ static void dispatchIndexOpFoldResult(OpFoldResult ofr,
|
|||
/// ```
|
||||
/// someop(memrefcast) -> someop
|
||||
/// ```
|
||||
/// It folds the source of the memref_cast into the root operation directly.
|
||||
/// It folds the source of the memref.cast into the root operation directly.
|
||||
static LogicalResult foldMemRefCast(Operation *op) {
|
||||
bool folded = false;
|
||||
for (OpOperand &operand : op->getOpOperands()) {
|
||||
auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
|
||||
if (castOp && canFoldIntoConsumerOp(castOp)) {
|
||||
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
|
||||
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
|
||||
operand.set(castOp.getOperand());
|
||||
folded = true;
|
||||
}
|
||||
|
@ -776,10 +777,10 @@ struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
|
|||
/// - A constant value if the size is static along the dimension.
|
||||
/// - The dynamic value that defines the size of the result of
|
||||
/// `linalg.init_tensor` op.
|
||||
struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
|
||||
using OpRewritePattern<DimOp>::OpRewritePattern;
|
||||
struct ReplaceDimOfInitTensorOp : public OpRewritePattern<memref::DimOp> {
|
||||
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(DimOp dimOp,
|
||||
LogicalResult matchAndRewrite(memref::DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
|
||||
if (!initTensorOp)
|
||||
|
@ -986,7 +987,7 @@ PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
|
|||
assert(rankedTensorType.hasStaticShape());
|
||||
int rank = rankedTensorType.getRank();
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
auto dimOp = builder.createOrFold<DimOp>(loc, source, i);
|
||||
auto dimOp = builder.createOrFold<memref::DimOp>(loc, source, i);
|
||||
auto resultDimSize = builder.createOrFold<ConstantIndexOp>(
|
||||
loc, rankedTensorType.getDimSize(i));
|
||||
auto highValue = builder.createOrFold<SubIOp>(loc, resultDimSize, dimOp);
|
||||
|
@ -1292,7 +1293,7 @@ getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
|
|||
AffineExpr expr;
|
||||
SmallVector<Value, 2> dynamicDims;
|
||||
for (auto dim : llvm::seq(startPos, endPos + 1)) {
|
||||
dynamicDims.push_back(builder.create<DimOp>(loc, src, dim));
|
||||
dynamicDims.push_back(builder.create<memref::DimOp>(loc, src, dim));
|
||||
AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
|
||||
expr = (expr ? expr * currExpr : currExpr);
|
||||
}
|
||||
|
@ -1361,7 +1362,7 @@ static Value getExpandedOutputDimFromInputShape(
|
|||
"dimensions");
|
||||
linearizedStaticDim *= d.value();
|
||||
}
|
||||
Value sourceDim = builder.create<DimOp>(loc, src, sourceDimPos);
|
||||
Value sourceDim = builder.create<memref::DimOp>(loc, src, sourceDimPos);
|
||||
return applyMapToValues(
|
||||
builder, loc,
|
||||
AffineMap::get(
|
||||
|
@ -1637,9 +1638,9 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
|
|||
};
|
||||
|
||||
/// Canonicalize dim ops that use the output shape with dim of the input.
|
||||
struct ReplaceDimOfReshapeOpResult : OpRewritePattern<DimOp> {
|
||||
using OpRewritePattern<DimOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(DimOp dimOp,
|
||||
struct ReplaceDimOfReshapeOpResult : OpRewritePattern<memref::DimOp> {
|
||||
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(memref::DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value dimValue = dimOp.memrefOrTensor();
|
||||
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
|
||||
|
@ -2445,24 +2446,25 @@ struct FoldTensorCastOp : public RewritePattern {
|
|||
}
|
||||
};
|
||||
|
||||
/// Replaces std.dim operations that use the result of a LinalgOp (on tensors)
|
||||
/// with std.dim operations that use one of the arguments. For example,
|
||||
/// Replaces memref.dim operations that use the result of a LinalgOp (on
|
||||
/// tensors) with memref.dim operations that use one of the arguments. For
|
||||
/// example,
|
||||
///
|
||||
/// %0 = linalg.matmul ins(%arg0, %arg1, ...)
|
||||
/// %1 = dim %0, %c0
|
||||
/// %1 = memref.dim %0, %c0
|
||||
///
|
||||
/// with
|
||||
///
|
||||
/// %1 = dim %arg0, %c0
|
||||
/// %1 = memref.dim %arg0, %c0
|
||||
///
|
||||
/// where possible. With this the result of the `linalg.matmul` is not used in
|
||||
/// dim operations. If the value produced is replaced with another value (say by
|
||||
/// tiling `linalg.matmul`) will make the `linalg.matmul` truly dead instead of
|
||||
/// used in a dim op that would prevent the DCE of this op.
|
||||
struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
|
||||
using OpRewritePattern<DimOp>::OpRewritePattern;
|
||||
struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<memref::DimOp> {
|
||||
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(DimOp dimOp,
|
||||
LogicalResult matchAndRewrite(memref::DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value dimValue = dimOp.memrefOrTensor();
|
||||
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
|
||||
|
@ -2479,7 +2481,7 @@ struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
|
|||
if (!operandDimValue) {
|
||||
// Its always possible to replace using the corresponding `outs`
|
||||
// parameter.
|
||||
operandDimValue = rewriter.create<DimOp>(
|
||||
operandDimValue = rewriter.create<memref::DimOp>(
|
||||
dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex);
|
||||
}
|
||||
rewriter.replaceOp(dimOp, *operandDimValue);
|
||||
|
|
|
@ -25,8 +25,8 @@ using namespace ::mlir::linalg;
|
|||
|
||||
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
|
||||
auto memrefType = memref.getType().cast<MemRefType>();
|
||||
auto alloc =
|
||||
b.create<AllocOp>(loc, memrefType, getDynOperands(loc, memref, b));
|
||||
auto alloc = b.create<memref::AllocOp>(loc, memrefType,
|
||||
getDynOperands(loc, memref, b));
|
||||
b.create<linalg::CopyOp>(loc, memref, alloc);
|
||||
return alloc;
|
||||
}
|
||||
|
@ -60,17 +60,17 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp,
|
|||
continue;
|
||||
}
|
||||
|
||||
if (auto alloc = resultTensor.getDefiningOp<AllocOp>()) {
|
||||
if (auto alloc = resultTensor.getDefiningOp<memref::AllocOp>()) {
|
||||
resultBuffers.push_back(resultTensor);
|
||||
continue;
|
||||
}
|
||||
// Allocate buffers for statically-shaped results.
|
||||
if (memrefType.hasStaticShape()) {
|
||||
resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
|
||||
resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
|
||||
continue;
|
||||
}
|
||||
|
||||
resultBuffers.push_back(b.create<AllocOp>(
|
||||
resultBuffers.push_back(b.create<memref::AllocOp>(
|
||||
loc, memrefType, getDynOperands(loc, resultTensor, b)));
|
||||
}
|
||||
return success();
|
||||
|
@ -148,7 +148,7 @@ public:
|
|||
matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
|
||||
rewriter.replaceOpWithNewOp<AllocOp>(
|
||||
rewriter.replaceOpWithNewOp<memref::AllocOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
|
||||
adaptor.sizes());
|
||||
return success();
|
||||
|
@ -231,9 +231,9 @@ public:
|
|||
// op.sizes() capture exactly the dynamic alloc operands matching the
|
||||
// subviewMemRefType thanks to subview/subtensor canonicalization and
|
||||
// verification.
|
||||
Value alloc =
|
||||
rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
|
||||
Value subView = rewriter.create<SubViewOp>(
|
||||
Value alloc = rewriter.create<memref::AllocOp>(
|
||||
op.getLoc(), subviewMemRefType, op.sizes());
|
||||
Value subView = rewriter.create<memref::SubViewOp>(
|
||||
op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
|
||||
op.getMixedStrides());
|
||||
rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
|
||||
|
@ -243,8 +243,8 @@ public:
|
|||
};
|
||||
|
||||
/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
|
||||
/// %t` to an tensor_to_memref + subview + copy + tensor_load pattern.
|
||||
/// tensor_to_memref and tensor_load are inserted automatically by the
|
||||
/// %t` to an buffer_cast + subview + copy + tensor_load pattern.
|
||||
/// buffer_cast and tensor_load are inserted automatically by the
|
||||
/// conversion infra:
|
||||
/// ```
|
||||
/// %sv = subview %dest [offsets][sizes][strides]
|
||||
|
@ -273,7 +273,7 @@ public:
|
|||
assert(destMemRef.getType().isa<MemRefType>());
|
||||
|
||||
// Take a subview to copy the small memref.
|
||||
Value subview = rewriter.create<SubViewOp>(
|
||||
Value subview = rewriter.create<memref::SubViewOp>(
|
||||
op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
|
||||
op.getMixedStrides());
|
||||
// Copy the small memref.
|
||||
|
@ -295,7 +295,7 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
|
|||
|
||||
// Mark all Standard operations legal.
|
||||
target.addLegalDialect<AffineDialect, math::MathDialect,
|
||||
StandardOpsDialect>();
|
||||
memref::MemRefDialect, StandardOpsDialect>();
|
||||
target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();
|
||||
|
||||
// Mark all Linalg operations illegal as long as they work on tensors.
|
||||
|
|
|
@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||
MLIRAnalysis
|
||||
MLIREDSC
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
MLIRLinalgAnalysis
|
||||
MLIRLinalgEDSC
|
||||
MLIRLinalg
|
||||
|
|
|
@@ -18,6 +18,8 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"

@@ -104,11 +106,12 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
SmallVector<OpFoldResult, 4> offsets, sizes, strides;
inferShapeComponents(map, loopRanges, offsets, sizes, strides);
Value shape = en.value();
Value sub = shape.getType().isa<MemRefType>()
? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
.getResult()
: b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
.getResult();
Value sub =
shape.getType().isa<MemRefType>()
? b.create<memref::SubViewOp>(loc, shape, offsets, sizes, strides)
.getResult()
: b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
.getResult();
clonedShapes.push_back(sub);
}
// Append the other operands.

@@ -177,8 +180,8 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
// `ViewInterface`. The interface needs a `getOrCreateRanges` method which
// currently returns a `linalg.range`. The fix here is to move this op to
// `std` dialect and add the method to `ViewInterface`.
if (fromSubViewOpOnly &&
!isa_and_nonnull<SubViewOp, SubTensorOp>(en.value().getDefiningOp())
if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
en.value().getDefiningOp()))
continue;

unsigned idx = en.index();

@@ -227,9 +230,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
<< "existing LoopRange: " << loopRanges[i] << "\n");
else {
auto shapeDim = getShapeDefiningLoopRange(producer, i);
loopRanges[i] = Range{std_constant_index(0),
std_dim(shapeDim.shape, shapeDim.dimension),
std_constant_index(1)};
Value dim = memref_dim(shapeDim.shape, shapeDim.dimension);
loopRanges[i] = Range{std_constant_index(0), dim, std_constant_index(1)};
LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
}
}

@@ -242,7 +244,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
Value shapedOperand, unsigned dim) {
Operation *shapeProducingOp = shapedOperand.getDefiningOp();
if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
return subViewOp.getOrCreateRanges(b, loc)[dim];
if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
return subTensorOp.getOrCreateRanges(b, loc)[dim];

@@ -425,7 +427,7 @@ mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
// Must be a subview or a slice to guarantee there are loops we can fuse
// into.
auto subView = consumerOpOperand.get().getDefiningOp<SubViewOp>();
auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>();
if (!subView) {
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
return llvm::None;
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"

@@ -200,7 +201,7 @@ Value getPaddedInput(Value input, ArrayRef<Value> indices,
conds.push_back(leftOutOfBound);
else
conds.push_back(conds.back() || leftOutOfBound);
Value rightBound = std_dim(input, idx);
Value rightBound = memref_dim(input, idx);
conds.push_back(conds.back() || (sge(dim, rightBound)));

// When padding is involved, the indices will only be shifted to negative,

@@ -307,12 +308,12 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
IndexedValueType F(convOp.filter()), O(convOp.output());

// Emit scalar form. Padded conv involves an affine.max in the memory access
// which is not allowed by affine.load. Override to use an StdIndexedValue
// which is not allowed by affine.load. Override to use an MemRefIndexedValue
// when there is non-zero padding.
if (hasPadding(convOp)) {
Type type = convOp.input().getType().cast<MemRefType>().getElementType();
Value padValue = std_constant(type, getPadValueAttr<ConvOp>(type));
Value paddedInput = getPaddedInput<StdIndexedValue>(
Value paddedInput = getPaddedInput<MemRefIndexedValue>(
convOp.input(), imIdx,
/* Only need to pad the window dimensions */
{0, static_cast<int>(imIdx.size()) - 1}, padValue);

@@ -338,9 +339,9 @@ static Value getPoolingInput(PoolingOp op, ArrayRef<Value> inputIndices) {
Type type =
op.input().getType().template cast<MemRefType>().getElementType();
Value padValue = std_constant(type, getPadValueAttr<PoolingOp>(type));
return getPaddedInput<StdIndexedValue>(op.input(), inputIndices,
/*Pad every dimension*/ {},
padValue);
return getPaddedInput<MemRefIndexedValue>(op.input(), inputIndices,
/*Pad every dimension*/ {},
padValue);
}
IndexedValueType input(op.input());
return input(inputIndices);

@@ -546,7 +547,7 @@ static void lowerLinalgToLoopsImpl(FuncOp funcOp,
MLIRContext *context = funcOp.getContext();
OwningRewritePatternList patterns;
patterns.insert<LinalgRewritePattern<LoopType>>(interchangeVector);
DimOp::getCanonicalizationPatterns(patterns, context);
memref::DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldAffineOp>(context);
// Just apply the patterns greedily.

@@ -593,12 +594,18 @@ struct FoldAffineOp : public RewritePattern {

struct LowerToAffineLoops
: public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnFunction() override {
lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), interchangeVector);
}
};

struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect, scf::SCFDialect>();
}
void runOnFunction() override {
lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), interchangeVector);
}
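The two `getDependentDialects` overrides added above are required because these lowerings now create `memref.dim` (and `scf.for` on the parallel path); a pass must declare every dialect it may introduce so the context loads it before the pass runs. The same idiom in isolation, for a hypothetical pass:

```cpp
// Sketch of the getDependentDialects idiom; the pass itself is hypothetical.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Pass/Pass.h"

namespace {
struct MyLoweringSketch
    : public mlir::PassWrapper<MyLoweringSketch, mlir::FunctionPass> {
  // Declare the dialects whose ops this pass may create.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::memref::MemRefDialect, mlir::scf::SCFDialect>();
  }
  void runOnFunction() override {
    // ...rewrites that create memref.dim / scf.for would go here...
  }
};
} // namespace
```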
@@ -26,6 +26,10 @@ namespace scf {
class SCFDialect;
} // end namespace scf

namespace memref {
class MemRefDialect;
} // end namespace memref

namespace vector {
class VectorDialect;
} // end namespace vector
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"

@@ -38,9 +39,9 @@ using llvm::MapVector;

using folded_affine_min = FoldedValueBuilder<AffineMinOp>;
using folded_linalg_range = FoldedValueBuilder<linalg::RangeOp>;
using folded_std_dim = FoldedValueBuilder<DimOp>;
using folded_std_subview = FoldedValueBuilder<SubViewOp>;
using folded_std_view = FoldedValueBuilder<ViewOp>;
using folded_memref_dim = FoldedValueBuilder<memref::DimOp>;
using folded_memref_subview = FoldedValueBuilder<memref::SubViewOp>;
using folded_memref_view = FoldedValueBuilder<memref::ViewOp>;

#define DEBUG_TYPE "linalg-promotion"

@@ -59,22 +60,22 @@ static Value allocBuffer(const LinalgPromotionOptions &options,
if (!dynamicBuffers)
if (auto cst = size.getDefiningOp<ConstantIndexOp>())
return options.useAlloca
? std_alloca(MemRefType::get(width * cst.getValue(),
IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr)
? memref_alloca(MemRefType::get(width * cst.getValue(),
IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr)
.value
: std_alloc(MemRefType::get(width * cst.getValue(),
IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr)
: memref_alloc(MemRefType::get(width * cst.getValue(),
IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr)
.value;
Value mul =
folded_std_muli(folder, folded_std_constant_index(folder, width), size);
return options.useAlloca
? std_alloca(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr)
? memref_alloca(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr)
.value
: std_alloc(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr)
: memref_alloc(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr)
.value;
}
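The `memref_alloc`/`memref_alloca` helpers above come from the new `MemRef/EDSC/Intrinsics.h` header and are thin value builders over the corresponding ops; outside EDSC code the same allocation is a plain `OpBuilder` call. A minimal non-EDSC equivalent, where the element type and dynamic-size convention are illustrative:

```cpp
// Sketch: allocate and free a dynamically sized i8 scratch buffer with a
// plain OpBuilder, the non-EDSC counterpart of memref_alloc above.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static Value allocScratchBuffer(OpBuilder &b, Location loc, Value numBytes) {
  // memref<?xi8>, with the single dynamic size supplied at runtime.
  auto bufferType =
      MemRefType::get({ShapedType::kDynamicSize}, b.getIntegerType(8));
  return b.create<memref::AllocOp>(loc, bufferType, ValueRange{numBytes});
}

static void freeScratchBuffer(OpBuilder &b, Location loc, Value buffer) {
  b.create<memref::DeallocOp>(loc, buffer);
}
```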
@@ -82,10 +83,12 @@ static Value allocBuffer(const LinalgPromotionOptions &options,
/// no call back to do so is provided. The default is to allocate a
/// memref<..xi8> and return a view to get a memref type of shape
/// boundingSubViewSize.
static Optional<Value> defaultAllocBufferCallBack(
const LinalgPromotionOptions &options, OpBuilder &builder,
SubViewOp subView, ArrayRef<Value> boundingSubViewSize, bool dynamicBuffers,
Optional<unsigned> alignment, OperationFolder *folder) {
static Optional<Value>
defaultAllocBufferCallBack(const LinalgPromotionOptions &options,
OpBuilder &builder, memref::SubViewOp subView,
ArrayRef<Value> boundingSubViewSize,
bool dynamicBuffers, Optional<unsigned> alignment,
OperationFolder *folder) {
ShapedType viewType = subView.getType();
int64_t rank = viewType.getRank();
(void)rank;

@@ -100,7 +103,7 @@ static Optional<Value> defaultAllocBufferCallBack(
dynamicBuffers, folder, alignment);
SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(),
ShapedType::kDynamicSize);
Value view = folded_std_view(
Value view = folded_memref_view(
folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
zero, boundingSubViewSize);
return view;

@@ -112,10 +115,10 @@ static Optional<Value> defaultAllocBufferCallBack(
static LogicalResult
defaultDeallocBufferCallBack(const LinalgPromotionOptions &options,
OpBuilder &b, Value fullLocalView) {
auto viewOp = fullLocalView.getDefiningOp<ViewOp>();
auto viewOp = fullLocalView.getDefiningOp<memref::ViewOp>();
assert(viewOp && "expected full local view to be a ViewOp");
if (!options.useAlloca)
std_dealloc(viewOp.source());
memref_dealloc(viewOp.source());
return success();
}

@@ -161,21 +164,21 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
if (options.operandsToPromote && !options.operandsToPromote->count(idx))
continue;
auto *op = linalgOp.getShapedOperand(idx).getDefiningOp();
if (auto sv = dyn_cast_or_null<SubViewOp>(op)) {
if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
subViews[idx] = sv;
useFullTileBuffers[sv] = vUseFullTileBuffers[idx];
}
}

allocationFn =
(options.allocationFn ? *(options.allocationFn)
: [&](OpBuilder &builder, SubViewOp subViewOp,
ArrayRef<Value> boundingSubViewSize,
OperationFolder *folder) -> Optional<Value> {
return defaultAllocBufferCallBack(options, builder, subViewOp,
boundingSubViewSize, dynamicBuffers,
alignment, folder);
});
allocationFn = (options.allocationFn
? *(options.allocationFn)
: [&](OpBuilder &builder, memref::SubViewOp subViewOp,
ArrayRef<Value> boundingSubViewSize,
OperationFolder *folder) -> Optional<Value> {
return defaultAllocBufferCallBack(options, builder, subViewOp,
boundingSubViewSize, dynamicBuffers,
alignment, folder);
});
deallocationFn =
(options.deallocationFn
? *(options.deallocationFn)
@@ -209,7 +212,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
// boundary tiles. For now this is done with an unconditional `fill` op followed
// by a partial `copy` op.
Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
OpBuilder &b, Location loc, SubViewOp subView,
OpBuilder &b, Location loc, memref::SubViewOp subView,
AllocBufferCallbackFn allocationFn, OperationFolder *folder) {
ScopedContext scopedContext(b, loc);
auto viewType = subView.getType();

@@ -227,7 +230,8 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
(!sizeAttr) ? rangeValue.size : b.create<ConstantOp>(loc, sizeAttr);
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
fullSizes.push_back(size);
partialSizes.push_back(folded_std_dim(folder, subView, en.index()).value);
partialSizes.push_back(
folded_memref_dim(folder, subView, en.index()).value);
}
SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
// If a callback is not specified, then use the default implementation for

@@ -238,7 +242,7 @@ Optional<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
SmallVector<OpFoldResult, 4> zeros(fullSizes.size(), b.getIndexAttr(0));
SmallVector<OpFoldResult, 4> ones(fullSizes.size(), b.getIndexAttr(1));
auto partialLocalView =
folded_std_subview(folder, *fullLocalView, zeros, partialSizes, ones);
folded_memref_subview(folder, *fullLocalView, zeros, partialSizes, ones);
return PromotionInfo{*fullLocalView, partialLocalView};
}

@@ -253,7 +257,8 @@ promoteSubViews(OpBuilder &b, Location loc,
MapVector<unsigned, PromotionInfo> promotionInfoMap;

for (auto v : options.subViews) {
SubViewOp subView = cast<SubViewOp>(v.second.getDefiningOp());
memref::SubViewOp subView =
cast<memref::SubViewOp>(v.second.getDefiningOp());
Optional<PromotionInfo> promotionInfo = promoteSubviewAsNewBuffer(
b, loc, subView, options.allocationFn, folder);
if (!promotionInfo)

@@ -277,8 +282,9 @@ promoteSubViews(OpBuilder &b, Location loc,
auto info = promotionInfoMap.find(v.first);
if (info == promotionInfoMap.end())
continue;
if (failed(options.copyInFn(b, cast<SubViewOp>(v.second.getDefiningOp()),
info->second.partialLocalView)))
if (failed(options.copyInFn(
b, cast<memref::SubViewOp>(v.second.getDefiningOp()),
info->second.partialLocalView)))
return {};
}
return promotionInfoMap;

@@ -353,7 +359,7 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op,
return failure();
// Check that at least one of the requested operands is indeed a subview.
for (auto en : llvm::enumerate(linOp.getShapedOperands())) {
auto sv = isa_and_nonnull<SubViewOp>(en.value().getDefiningOp());
auto sv = isa_and_nonnull<memref::SubViewOp>(en.value().getDefiningOp());
if (sv) {
if (!options.operandsToPromote.hasValue() ||
options.operandsToPromote->count(en.index()))
@@ -44,11 +44,11 @@ class TensorFromPointerConverter
};

/// Sparse conversion rule for dimension accesses.
class TensorToDimSizeConverter : public OpConversionPattern<DimOp> {
class TensorToDimSizeConverter : public OpConversionPattern<memref::DimOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(DimOp op, ArrayRef<Value> operands,
matchAndRewrite(memref::DimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
return failure();
@@ -533,13 +533,13 @@ static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
// positions for the output tensor. Currently this results in functional,
// but slightly imprecise IR, so it is put under an experimental option.
if (codegen.options.fastOutput)
return rewriter.create<TensorToMemrefOp>(loc, denseTp, tensor);
return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
// By default, a new buffer is allocated which is initialized to the
// tensor defined in the outs() clause. This is always correct but
// introduces a dense initialization component that may negatively
// impact the running complexity of the sparse kernel.
Value init = rewriter.create<TensorToMemrefOp>(loc, denseTp, tensor);
Value alloc = rewriter.create<AllocOp>(loc, denseTp, args);
Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
rewriter.create<linalg::CopyOp>(loc, init, alloc);
return alloc;
}

@@ -585,8 +585,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
}
// Find lower and upper bound in current dimension.
Value up;
if (shape[d] == TensorType::kDynamicSize) {
up = rewriter.create<DimOp>(loc, tensor, d);
if (shape[d] == MemRefType::kDynamicSize) {
up = rewriter.create<memref::DimOp>(loc, tensor, d);
args.push_back(up);
} else {
up = rewriter.create<ConstantIndexOp>(loc, shape[d]);

@@ -600,7 +600,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
auto denseTp = MemRefType::get(shape, tensorType.getElementType());
if (t < numInputs)
codegen.buffers[t] =
rewriter.create<TensorToMemrefOp>(loc, denseTp, tensor);
rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
else
codegen.buffers[t] =
genOutputBuffer(codegen, rewriter, op, denseTp, args);

@@ -716,7 +716,7 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
Value ptr = codegen.buffers[tensor];
if (codegen.curVecLength > 1)
return genVectorLoad(codegen, rewriter, ptr, args);
return rewriter.create<LoadOp>(loc, ptr, args);
return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense tensor.

@@ -744,7 +744,7 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
if (codegen.curVecLength > 1)
genVectorStore(codegen, rewriter, rhs, ptr, args);
else
rewriter.create<StoreOp>(loc, rhs, ptr, args);
rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme.

@@ -752,7 +752,7 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
Value ptr, Value s) {
if (codegen.curVecLength > 1)
return genVectorLoad(codegen, rewriter, ptr, {s});
Value load = rewriter.create<LoadOp>(loc, ptr, s);
Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
return load.getType().isa<IndexType>()
? load
: rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());

@@ -1345,8 +1345,8 @@ public:
CodeGen codegen(options, numTensors, numLoops);
genBuffers(merger, codegen, rewriter, op);
genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
Value result =
rewriter.create<TensorLoadOp>(op.getLoc(), codegen.buffers.back());
Value result = rewriter.create<memref::TensorLoadOp>(
op.getLoc(), codegen.buffers.back());
rewriter.replaceOp(op, result);
return success();
}
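Taken together, these hunks show the whole tensor-to-buffer round trip in the sparse code generator after the rename: `memref.buffer_cast` materializes a buffer for each dense tensor, scalar accesses go through `memref.load`/`memref.store`, and the final result is turned back into a tensor with `memref.tensor_load`. A compact sketch of that round trip, assuming a rank-1 tensor and with placeholder value names:

```cpp
// Sketch of the buffer_cast / load / store / tensor_load round trip used by
// the sparse code generator. Assumes a rank-1 tensor; names are placeholders.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static Value bufferizeRoundTrip(PatternRewriter &rewriter, Location loc,
                                Value tensor, Value index, Value newValue) {
  auto tensorType = tensor.getType().cast<RankedTensorType>();
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());

  // Tensor -> memref materialization.
  Value buffer =
      rewriter.create<memref::BufferCastOp>(loc, memrefType, tensor);

  // Scalar accesses now target the buffer.
  Value elem = rewriter.create<memref::LoadOp>(loc, buffer, ValueRange{index});
  (void)elem;
  rewriter.create<memref::StoreOp>(loc, newValue, buffer, ValueRange{index});

  // Memref -> tensor materialization for the op's result.
  return rewriter.create<memref::TensorLoadOp>(loc, buffer);
}
```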
@@ -17,6 +17,8 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

@@ -34,7 +36,6 @@ using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::scf;

#define DEBUG_TYPE "linalg-tiling"

static bool isZero(Value v) {

@@ -144,9 +145,9 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
// operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
// scf.for %k = %c0 to operand_dim_0 step %c10 {
// scf.for %l = %c0 to operand_dim_1 step %c25 {
// %4 = std.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
// %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
// : memref<50x100xf32> to memref<?x?xf32, #strided>
// %5 = std.subview %result[%k, %l][%c10, %c25][%c1, %c1]
// %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
// : memref<50x100xf32> to memref<?x?xf32, #strided>
// linalg.indexed_generic pointwise_2d_trait %4, %5 {
// ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):

@@ -262,7 +263,7 @@ makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
for (unsigned r = 0; r < rank; ++r) {
if (!isTiled(map.getSubMap({r}), tileSizes)) {
offsets.push_back(b.getIndexAttr(0));
sizes.push_back(std_dim(shapedOp, r).value);
sizes.push_back(memref_dim(shapedOp, r).value);
strides.push_back(b.getIndexAttr(1));
continue;
}

@@ -290,7 +291,7 @@ makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
getAffineDimExpr(/*position=*/1, b.getContext()) -
getAffineDimExpr(/*position=*/2, b.getContext())},
b.getContext());
auto d = std_dim(shapedOp, r);
Value d = memref_dim(shapedOp, r);
SmallVector<Value, 4> operands{size, d, offset};
fullyComposeAffineMapAndOperands(&minMap, &operands);
size = affine_min(b.getIndexType(), minMap, operands);

@@ -302,7 +303,7 @@ makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,

if (shapedType.isa<MemRefType>())
res.push_back(
b.create<SubViewOp>(loc, shapedOp, offsets, sizes, strides));
b.create<memref::SubViewOp>(loc, shapedOp, offsets, sizes, strides));
else
res.push_back(
b.create<SubTensorOp>(loc, shapedOp, offsets, sizes, strides));

@@ -474,7 +475,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(

if (!options.tileSizeComputationFunction)
return llvm::None;

// Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of
// adjusting affine maps to account for missing dimensions.

@@ -564,9 +565,9 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
SubViewOp::getCanonicalizationPatterns(patterns, ctx);
memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
ViewOp::getCanonicalizationPatterns(patterns, ctx);
memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
@@ -212,7 +212,7 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
auto sizes = llvm::to_vector<4>(llvm::map_range(
llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
auto dimOp = rewriter.create<DimOp>(loc, std::get<0>(it), d);
auto dimOp = rewriter.create<memref::DimOp>(loc, std::get<0>(it), d);
newUsersOfOpToPad.insert(dimOp);
return dimOp.getResult();
}));
@@ -85,7 +85,7 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
}

/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If source has rank zero, build an std.load.
/// If source has rank zero, build an memref.load.
/// Return the produced value.
static Value buildVectorRead(OpBuilder &builder, Value source) {
edsc::ScopedContext scope(builder);

@@ -94,11 +94,11 @@ static Value buildVectorRead(OpBuilder &builder, Value source) {
SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
return vector_transfer_read(vectorType, source, indices);
}
return std_load(source);
return memref_load(source);
}

/// Build a vector.transfer_write of `value` into `dest` at indices set to all
/// `0`. If `dest` has null rank, build an std.store.
/// `0`. If `dest` has null rank, build an memref.store.
/// Return the produced value or null if no value is produced.
static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
edsc::ScopedContext scope(builder);

@@ -110,7 +110,7 @@ static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
value = vector_broadcast(vectorType, value);
write = vector_transfer_write(value, dest, indices);
} else {
write = std_store(value, dest);
write = memref_store(value, dest);
}
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
if (!write->getResults().empty())
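`buildVectorRead`/`buildVectorWrite` above pick between a vector transfer and a plain scalar access based on the rank of the shaped value; after the split the scalar fallback is `memref.load`/`memref.store`. A rough OpBuilder-level sketch of the read half, assuming a statically shaped memref (the helper name is hypothetical):

```cpp
// Sketch of the rank-dependent read: vector.transfer_read for ranked memrefs,
// memref.load for the rank-0 case. Assumes a static shape; name is made up.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static Value readAsVectorOrScalar(OpBuilder &b, Location loc, Value source) {
  auto memrefType = source.getType().cast<MemRefType>();
  if (memrefType.getRank() == 0)
    return b.create<memref::LoadOp>(loc, source, ValueRange());

  auto vectorType =
      VectorType::get(memrefType.getShape(), memrefType.getElementType());
  Value zero = b.create<ConstantIndexOp>(loc, 0);
  SmallVector<Value, 4> indices(memrefType.getRank(), zero);
  return b.create<vector::TransferReadOp>(loc, vectorType, source, indices);
}
```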
@@ -544,7 +544,7 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
rewriter.getAffineMapArrayAttr(indexingMaps),
rewriter.getStrArrayAttr(iteratorTypes));

rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
rewriter.eraseOp(op);
return success();
}

@@ -667,12 +667,12 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
SubViewOp subViewOp;
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
memref::SubViewOp subViewOp;
for (auto &u : v.getUses()) {
if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
if (subViewOp)
return SubViewOp();
return memref::SubViewOp();
subViewOp = newSubViewOp;
}
}

@@ -686,14 +686,14 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(

// Transfer into `view`.
Value viewOrAlloc = xferOp.source();
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
!viewOrAlloc.getDefiningOp<AllocOp>())
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
return failure();

LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);

// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
if (!subViewOp)
return failure();
Value subView = subViewOp.getResult();

@@ -765,12 +765,12 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
// Transfer into `viewOrAlloc`.
Value viewOrAlloc = xferOp.source();
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
!viewOrAlloc.getDefiningOp<AllocOp>())
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
return failure();

// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
if (!subViewOp)
return failure();
Value subView = subViewOp.getResult();
@@ -0,0 +1 @@
add_subdirectory(IR)

@@ -0,0 +1,17 @@
add_mlir_dialect_library(MLIRMemRef
MemRefDialect.cpp
MemRefOps.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect

DEPENDS
MLIRMemRefOpsIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRDialect
MLIRIR
)
@@ -0,0 +1,39 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/InliningUtils.h"

using namespace mlir;
using namespace mlir::memref;

//===----------------------------------------------------------------------===//
// MemRefDialect Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
struct MemRefInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
BlockAndValueMapping &) const final {
return true;
}
};
} // end anonymous namespace

void mlir::memref::MemRefDialect::initialize() {
addOperations<DmaStartOp, DmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
>();
addInterfaces<MemRefInlinerInterface>();
}
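`MemRefDialect::initialize()` above registers the moved ops (plus `DmaStartOp`/`DmaWaitOp`) and an inliner interface. Client code that builds these ops must now load the dialect explicitly, or declare it via `getDependentDialects`. A small usage sketch; the function scaffolding and buffer shape are invented for illustration:

```cpp
// Sketch: load the new dialect and build memref.alloc / memref.dealloc.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  MLIRContext context;
  // Without these loads (or a registry / getDependentDialects entry),
  // creating the ops below would fail with an unregistered-dialect error.
  context.loadDialect<memref::MemRefDialect, StandardOpsDialect>();

  OpBuilder b(&context);
  Location loc = b.getUnknownLoc();
  ModuleOp module = ModuleOp::create(loc);
  FuncOp func = FuncOp::create(loc, "alloc_example", b.getFunctionType({}, {}));
  module.push_back(func);
  b.setInsertionPointToStart(func.addEntryBlock());

  // %0 = memref.alloc() : memref<16xf32>; memref.dealloc %0
  auto type = MemRefType::get({16}, b.getF32Type());
  Value buffer = b.create<memref::AllocOp>(loc, type);
  b.create<memref::DeallocOp>(loc, buffer);
  b.create<ReturnOp>(loc);

  module.print(llvm::outs());
  return 0;
}
```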
File diff suppressed because it is too large
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRSCF
MLIREDSC
MLIRIR
MLIRLoopLikeInterface
MLIRMemRef
MLIRSideEffectInterfaces
MLIRStandard
)
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"

@@ -568,7 +569,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
/// %t0 = ... : tensor_type
/// %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
/// ...
/// // %m is either tensor_to_memref(%bb00) or defined above the loop
/// // %m is either buffer_cast(%bb00) or defined above the loop
/// %m... : memref_type
/// ... // uses of %m with potential inplace updates
/// %new_tensor = tensor_load %m : memref_type

@@ -578,7 +579,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
/// ```
///
/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
/// `%m = tensor_to_memref %bb0` op that feeds into the yielded `tensor_load`
/// `%m = buffer_cast %bb0` op that feeds into the yielded `tensor_load`
/// op.
///
/// If no aliasing write to the memref `%m`, from which `%new_tensor`is loaded,

@@ -590,7 +591,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
///
/// The canonicalization rewrites the pattern as:
/// ```
/// // %m is either a tensor_to_memref or defined above
/// // %m is either a buffer_cast or defined above
/// %m... : memref_type
/// scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
/// ... // uses of %m with potential inplace updates

@@ -601,7 +602,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
///
/// A later bbArg canonicalization will further rewrite as:
/// ```
/// // %m is either a tensor_to_memref or defined above
/// // %m is either a buffer_cast or defined above
/// %m... : memref_type
/// scf.for ... { // no iter_args
/// ... // uses of %m with potential inplace updates
@@ -622,19 +623,18 @@ struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
Value yieldVal = yieldOp->getOperand(idx);
auto tensorLoadOp = yieldVal.getDefiningOp<TensorLoadOp>();
auto tensorLoadOp = yieldVal.getDefiningOp<memref::TensorLoadOp>();
bool isTensor = bbArg.getType().isa<TensorType>();

TensorToMemrefOp tensorToMemRefOp;
// Either bbArg has no use or it has a single tensor_to_memref use.
memref::BufferCastOp bufferCastOp;
// Either bbArg has no use or it has a single buffer_cast use.
if (bbArg.hasOneUse())
tensorToMemRefOp =
dyn_cast<TensorToMemrefOp>(*bbArg.getUsers().begin());
if (!isTensor || !tensorLoadOp ||
(!bbArg.use_empty() && !tensorToMemRefOp))
bufferCastOp =
dyn_cast<memref::BufferCastOp>(*bbArg.getUsers().begin());
if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !bufferCastOp))
continue;
// If tensorToMemRefOp is present, it must feed into the `tensorLoadOp`.
if (tensorToMemRefOp && tensorLoadOp.memref() != tensorToMemRefOp)
// If bufferCastOp is present, it must feed into the `tensorLoadOp`.
if (bufferCastOp && tensorLoadOp.memref() != bufferCastOp)
continue;
// TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
// must be before `tensorLoadOp` in the block so that the lastWrite

@@ -644,18 +644,18 @@ struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
if (tensorLoadOp->getNextNode() != yieldOp)
continue;

// Clone the optional tensorToMemRefOp before forOp.
if (tensorToMemRefOp) {
// Clone the optional bufferCastOp before forOp.
if (bufferCastOp) {
rewriter.setInsertionPoint(forOp);
rewriter.replaceOpWithNewOp<TensorToMemrefOp>(
tensorToMemRefOp, tensorToMemRefOp.memref().getType(),
tensorToMemRefOp.tensor());
rewriter.replaceOpWithNewOp<memref::BufferCastOp>(
bufferCastOp, bufferCastOp.memref().getType(),
bufferCastOp.tensor());
}

// Clone the tensorLoad after forOp.
rewriter.setInsertionPointAfter(forOp);
Value newTensorLoad =
rewriter.create<TensorLoadOp>(loc, tensorLoadOp.memref());
rewriter.create<memref::TensorLoadOp>(loc, tensorLoadOp.memref());
Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
replacements.insert(std::make_pair(forOpResult, newTensorLoad));
@@ -8,6 +8,7 @@

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
Some files were not shown because too many files have changed in this diff.