Introduce memref replacement/rewrite support: replace an existing memref with a
new one (of a potentially different rank/shape), with an optional index
remapping.

- introduce Utils::replaceAllMemRefUsesWith
- use this for DMA double buffering

(This CL also adds a few temporary utilities / code that will be done away with
once:
1) abstract DMA ops are added,
2) a memref-dereferencing side effect / trait is available on ops, and
3) b/117159533 is resolved (memref index computation slices).)
PiperOrigin-RevId: 215831373
Commit 6cfdb756b1 (parent b55b407601)
Uday Bondhugula, 2018-10-04 17:15:30 -07:00; committed by jpienaar
6 changed files with 519 additions and 87 deletions


@@ -343,15 +343,21 @@ public:
return MLFuncBuilder(forStmt, forStmt->end());
}
/// Get the current insertion point of the builder.
/// Returns the current insertion point of the builder.
StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
/// Get the current block of the builder.
/// Returns the current block of the builder.
StmtBlock *getBlock() const { return block; }
/// Create an operation given the fields represented as an OperationState.
/// Creates an operation given the fields represented as an OperationState.
OperationStmt *createOperation(const OperationState &state);
/// Creates an operation given the fields.
OperationStmt *createOperation(Location *location, Identifier name,
ArrayRef<MLValue *> operands,
ArrayRef<Type *> types,
ArrayRef<NamedAttribute> attrs);
/// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Location *location, Args... args) {


@@ -0,0 +1,49 @@
//===- Utils.h - General transformation utilities ---------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This header file defines prototypes for various transformation utilities for
// memref's and non-loop IR structures. These are not passes by themselves but
// are used either by passes, optimization sequences, or in turn by other
// transformation utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TRANSFORMS_UTILS_H
#define MLIR_TRANSFORMS_UTILS_H
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
class AffineMap;
class MLValue;
class SSAValue;
/// Replace all uses of oldMemRef with newMemRef while optionally remapping the
/// old memref's indices using the supplied affine map and adding any additional
/// indices. The new memref could be of a different shape or rank. Returns true
/// on success and false if the replacement is not possible (whenever a memref
/// is used as an operand in a non-dereferencing scenario).
/// Additional indices are added at the start.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
// extended to add additional indices at any position.
bool replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
llvm::ArrayRef<SSAValue *> extraIndices,
AffineMap *indexRemap = nullptr);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_UTILS_H
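
A minimal usage sketch of the extraIndices path, assuming an MLFunction context; the helper name below is hypothetical, and only the builder calls and replaceAllMemRefUsesWith itself are taken from this CL (the pipelining pass further down does the same thing inside its doubleBuffer() helper):

#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/Transforms/Utils.h"
using namespace mlir;

// Hypothetical sketch: 'newMemRef' is assumed to have one more (leading)
// dimension than 'oldMemRef', e.g. a double buffer with extent 2; every
// dereferencing use of 'oldMemRef' inside 'forStmt' then gets (iv mod 2)
// prepended as the extra leading index.
static bool prependModTwoIndex(MLValue *oldMemRef, MLValue *newMemRef,
                               ForStmt *forStmt) {
  MLFuncBuilder b(forStmt, forStmt->begin());
  auto d0 = b.getDimExpr(0);
  auto *modTwoMap = b.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
                                   {d0 % 2}, {});
  auto ivModTwo =
      b.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
  // Returns false if oldMemRef appears in a non-dereferencing op; nothing is
  // rewritten in that case.
  return replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                  /*extraIndices=*/ivModTwo->getResult(0));
}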


@@ -312,6 +312,18 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
return op;
}
/// Create an operation given the fields.
OperationStmt *MLFuncBuilder::createOperation(Location *location,
Identifier name,
ArrayRef<MLValue *> operands,
ArrayRef<Type *> types,
ArrayRef<NamedAttribute> attrs) {
auto *op = OperationStmt::create(location, name, operands, types, attrs,
getContext());
block->getStatements().insert(insertPoint, op);
return op;
}
ForStmt *MLFuncBuilder::createFor(Location *location,
ArrayRef<MLValue *> lbOperands,
AffineMap *lbMap,


@@ -21,10 +21,13 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;
@@ -43,27 +46,237 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
// For testing purposes, this just runs on the first statement of the MLFunction
// if that statement is a for stmt, and shifts the second half of its body by
// one.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or
// op traits for it are added. TODO(b/117228571)
static bool isDmaStartStmt(const OperationStmt &stmt) {
return stmt.getName().strref().contains("dma.in.start") ||
stmt.getName().strref().contains("dma.out.start");
}
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static bool isDmaFinishStmt(const OperationStmt &stmt) {
return stmt.getName().strref().contains("dma.finish");
}
/// Given a DMA start operation, returns the operand position of the source or
/// destination memref, whichever is at the higher level of the memory
/// hierarchy.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
assert(isDmaStartStmt(dmaStartStmt));
unsigned srcDmaPos = 0;
unsigned destDmaPos =
cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1;
if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType())
->getMemorySpace() >
cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType())
->getMemorySpace())
return srcDmaPos;
return destDmaPos;
}
// Returns the position of the tag memref operand given a DMA statement.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt));
if (isDmaStartStmt(dmaStmt)) {
// Second to last operand.
return dmaStmt.getNumOperands() - 2;
}
// First operand for a dma finish statement.
return 0;
}
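// Hedged note (inferred from the position arithmetic above and the updated
// test case at the end of this CL, not from an op definition): the
// string-matched DMA ops are assumed to have the operand layout
//   "dma.in.start"(srcMemRef, srcIndices..., dstMemRef, dstIndices...,
//                  numElements, tagMemRef, tagIndex)
//   "dma.finish"(tagMemRef, tagIndices...)
// so getHigherMemRefPos() compares operand 0 (source) against operand
// rank(source) + 1 (destination), and getTagMemRefPos() returns
// numOperands - 2 for a start op and 0 for a finish op.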
/// Doubles the buffer of the supplied memref.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
MLFuncBuilder bInner(forStmt, forStmt->begin());
bInner.setInsertionPoint(forStmt, forStmt->begin());
// Doubles the shape with a leading dimension extent of 2.
auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * {
// Add the leading dimension in the shape for the double buffer.
ArrayRef<int> shape = origMemRefType->getShape();
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
shapeSizes.insert(shapeSizes.begin(), 2);
auto *newMemRefType = bInner.getMemRefType(shapeSizes, bInner.getF32Type());
return newMemRefType;
};
auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
// Create and place the alloc at the top level.
auto *func = forStmt->getFunction();
MLFuncBuilder topBuilder(func, func->begin());
auto *newMemRef = cast<MLValue>(
topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
->getResult());
auto d0 = bInner.getDimExpr(0);
auto *modTwoMap = bInner.getAffineMap(1, 0, {d0 % 2}, {});
auto ivModTwoOp =
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0)))
return false;
// We don't need ivModTwoOp any more - it is cloned by
// replaceAllMemRefUsesWith wherever the memref replacement happens. Once
// b/117159533 is addressed, we'll eventually only need to pass
// ivModTwoOp->getResult(0) to replaceAllMemRefUsesWith.
cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock();
return true;
}
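// For illustration, the net effect on one dereferencing use, with the IR lines
// taken verbatim from the updated pipeline-data-transfer test at the end of
// this CL (the SSA names come from that test, not from this function):
//   before: %v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
//   after:  %12 = affine_apply #map0(%9)   // #map0 = (d0) -> (d0 mod 2)
//           %13 = load %1[%12, %9] : memref<2x32xf32>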
// For testing purposes, this just runs on the first for statement of an
// MLFunction at the top level.
// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
// the other TODOs listed inside are dealt with.
PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
if (f->empty())
return PassResult::Success;
auto *forStmt = dyn_cast<ForStmt>(&f->front());
ForStmt *forStmt = nullptr;
for (auto &stmt : *f) {
if ((forStmt = dyn_cast<ForStmt>(&stmt))) {
break;
}
}
if (!forStmt)
return PassResult::Failure;
return PassResult::Success;
unsigned numStmts = forStmt->getStatements().size();
if (numStmts == 0)
return PassResult::Success;
std::vector<uint64_t> delays(numStmts);
for (unsigned i = 0; i < numStmts; i++)
delays[i] = (i < numStmts / 2) ? 0 : 1;
SmallVector<OperationStmt *, 4> dmaStartStmts;
SmallVector<OperationStmt *, 4> dmaFinishStmts;
for (auto &stmt : *forStmt) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
if (isDmaStartStmt(*opStmt)) {
dmaStartStmts.push_back(opStmt);
} else if (isDmaFinishStmt(*opStmt)) {
dmaFinishStmts.push_back(opStmt);
}
}
if (!checkDominancePreservationOnShift(*forStmt, delays))
// TODO(bondhugula,andydavis): match tag memref's (requires memory-based
// subscript check utilities). Assume for now that start/finish are matched in
// the order they appear.
if (dmaStartStmts.size() != dmaFinishStmts.size())
return PassResult::Failure;
// Double the buffers for the higher memory space memref's.
// TODO(bondhugula): assuming we don't have multiple DMA starts for the same
// memref.
// TODO(bondhugula): check whether double-buffering is even necessary.
// TODO(bondhugula): make this work with different layouts: assuming here that
// the dimension we are adding here for the double buffering is the outermost
// dimension.
// Identify memref's to replace by scanning through all DMA start statements.
// A DMA start statement has two memref's - the one from the higher level of
// memory hierarchy is the one to double buffer.
for (auto *dmaStartStmt : dmaStartStmts) {
MLValue *oldMemRef = cast<MLValue>(
dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt)));
if (!doubleBuffer(oldMemRef, forStmt))
return PassResult::Failure;
}
// Double the buffers for tag memref's.
for (auto *dmaFinishStmt : dmaFinishStmts) {
MLValue *oldTagMemRef = cast<MLValue>(
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
if (!doubleBuffer(oldTagMemRef, forStmt))
return PassResult::Failure;
}
// Collect all compute ops.
std::vector<const Statement *> computeOps;
computeOps.reserve(forStmt->getStatements().size());
// Store each statement's delay for later lookup when assigning delays to
// AffineApplyOp's.
DenseMap<const Statement *, unsigned> opDelayMap;
for (const auto &stmt : *forStmt) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt) {
// All for and if stmt's are treated as pure compute operations.
// TODO(bondhugula): check whether such statements do not have any DMAs
// nested within.
opDelayMap[&stmt] = 1;
} else if (isDmaStartStmt(*opStmt)) {
// DMA starts are not shifted.
opDelayMap[&stmt] = 0;
} else if (isDmaFinishStmt(*opStmt)) {
// DMA finish op shifted by one.
opDelayMap[&stmt] = 1;
} else if (!opStmt->is<AffineApplyOp>()) {
// Compute op shifted by one.
opDelayMap[&stmt] = 1;
computeOps.push_back(&stmt);
}
// Shifts for affine apply op's determined later.
}
// Get the ancestor of a 'stmt' that lies in forStmt's block.
auto getAncestorInForBlock =
[&](const Statement *stmt, const StmtBlock &block) -> const Statement * {
// Traverse up the statement hierarchy starting from the owner of operand to
// find the ancestor statement that resides in the block of 'forStmt'.
while (stmt != nullptr && stmt->getBlock() != &block) {
stmt = stmt->getParentStmt();
}
return stmt;
};
// Determine delays for affine apply op's by looking up the delay from each
// op's consumer. This code will be thrown away once we have a way to obtain
// indices through a composed affine_apply op. See TODO(b/117159533). Such a
// composed affine_apply will be used exclusively by a given memref
// dereferencing op.
for (const auto &stmt : *forStmt) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
// Skip statements that aren't affine apply ops.
if (!opStmt || !opStmt->is<AffineApplyOp>())
continue;
// Traverse uses of each result of the affine apply op.
for (auto *res : opStmt->getResults()) {
for (auto &use : res->getUses()) {
auto *ancestorInForBlock =
getAncestorInForBlock(use.getOwner(), *forStmt);
assert(ancestorInForBlock &&
"traversing parent should reach forStmt block");
auto *opCheck = dyn_cast<OperationStmt>(ancestorInForBlock);
if (!opCheck || opCheck->is<AffineApplyOp>())
continue;
assert(opDelayMap.find(ancestorInForBlock) != opDelayMap.end());
if (opDelayMap.find(&stmt) != opDelayMap.end()) {
// This is where we enforce all uses of this affine_apply to have
// the same shifts - so that we know what shift to use for the
// affine_apply to preserve semantics.
assert(opDelayMap[&stmt] == opDelayMap[ancestorInForBlock]);
} else {
// Obtain delay from its consumer.
opDelayMap[&stmt] = opDelayMap[ancestorInForBlock];
}
}
}
}
// Get delays stored in map.
std::vector<uint64_t> delays(forStmt->getStatements().size());
unsigned s = 0;
for (const auto &stmt : *forStmt) {
delays[s++] = opDelayMap[&stmt];
}
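// As a concrete example (taken from the rewritten loop_nest_dma test below,
// whose body is dma.in.start, dma.finish, load, "compute", store, inner for):
// delays = [0, 1, 1, 1, 1, 1], i.e. DMA starts stay in place while everything
// else is shifted by one iteration, so iteration i's incoming transfer
// overlaps iteration i-1's compute.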
if (!checkDominancePreservationOnShift(*forStmt, delays)) {
// Violates SSA dominance.
return PassResult::Failure;
}
if (stmtBodySkew(forStmt, delays))
return PassResult::Failure;


@@ -0,0 +1,165 @@
//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous transformation routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Utils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardOps.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;
/// Return true if this operation dereferences one or more memref's.
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO(b/117228571)
static bool isMemRefDereferencingOp(const Operation &op) {
if (op.is<LoadOp>() || op.is<StoreOp>() ||
op.getName().strref().contains("dma.in.start") ||
op.getName().strref().contains("dma.out.start") ||
op.getName().strref().contains("dma.finish")) {
return true;
}
return false;
}
/// Replaces all uses of oldMemRef with newMemRef while optionally remapping
/// old memref's indices to the new memref using the supplied affine map
/// and adding any additional indices. The new memref could be of a different
/// shape or rank, but of the same elemental type. Additional indices are added
/// at the start for now.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
// extended to add additional indices at any position.
bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
ArrayRef<SSAValue *> extraIndices,
AffineMap *indexRemap) {
unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
if (indexRemap) {
assert(indexRemap->getNumInputs() == oldMemRefRank);
assert(indexRemap->getNumResults() + extraIndices.size() == newMemRefRank);
} else {
assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
}
// Assert same elemental type.
assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
cast<MemRefType>(newMemRef->getType())->getElementType());
// Check if the memref was used in a non-dereferencing context.
for (const StmtOperand &use : oldMemRef->getUses()) {
auto *opStmt = cast<OperationStmt>(use.getOwner());
// Failure: memref used in a non-dereferencing op (potentially escapes); no
// replacement in these cases.
if (!isMemRefDereferencingOp(*opStmt))
return false;
}
// Walk all uses of the old memref; each dereferencing statement gets replaced.
for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
StmtOperand &use = *(it++);
auto *opStmt = cast<OperationStmt>(use.getOwner());
assert(isMemRefDereferencingOp(*opStmt) &&
"memref dereferencing op expected");
auto getMemRefOperandPos = [&]() -> unsigned {
unsigned i;
for (i = 0; i < opStmt->getNumOperands(); i++) {
if (opStmt->getOperand(i) == oldMemRef)
break;
}
assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
return i;
};
unsigned memRefOperandPos = getMemRefOperandPos();
// Construct the new operation statement using this memref.
SmallVector<MLValue *, 8> operands;
operands.reserve(opStmt->getNumOperands() + extraIndices.size());
// Insert the non-memref operands.
operands.insert(operands.end(), opStmt->operand_begin(),
opStmt->operand_begin() + memRefOperandPos);
operands.push_back(newMemRef);
MLFuncBuilder builder(opStmt);
// Normally, we could just use extraIndices as operands, but we will
// clone them so that each op gets its own "private" index. See b/117159533.
for (auto *extraIndex : extraIndices) {
OperationStmt::OperandMapTy operandMap;
// TODO(mlir-team): An operation/SSA value should provide a method to
// return the position of an SSA result in its defining
// operation.
assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
"single result op's expected to generate these indices");
// TODO: actually check if this is a result of an affine_apply op.
assert((cast<MLValue>(extraIndex)->isValidDim() ||
cast<MLValue>(extraIndex)->isValidSymbol()) &&
"invalid memory op index");
auto *clonedExtraIndex =
cast<OperationStmt>(
builder.clone(*extraIndex->getDefiningStmt(), operandMap))
->getResult(0);
operands.push_back(cast<MLValue>(clonedExtraIndex));
}
// Construct new indices. The indices of a memref come right after it, i.e.,
// at position memRefOperandPos + 1.
SmallVector<SSAValue *, 4> indices(
opStmt->operand_begin() + memRefOperandPos + 1,
opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
if (indexRemap) {
auto remapOp =
builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices);
// Remapped indices.
for (auto *index : remapOp->getOperation()->getResults())
operands.push_back(cast<MLValue>(index));
} else {
// No remapping specified.
for (auto *index : indices)
operands.push_back(cast<MLValue>(index));
}
// Insert the remaining operands unmodified.
operands.insert(operands.end(),
opStmt->operand_begin() + memRefOperandPos + 1 +
oldMemRefRank,
opStmt->operand_end());
// Result types don't change. Both memref's are of the same elemental type.
SmallVector<Type *, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults());
for (const auto *result : opStmt->getResults())
resultTypes.push_back(result->getType());
// Create the new operation.
auto *repOp =
builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands,
resultTypes, opStmt->getAttrs());
// Replace the old memref's dereferencing op's uses.
unsigned r = 0;
for (auto *res : opStmt->getResults()) {
res->replaceAllUsesWith(repOp->getResult(r++));
}
opStmt->eraseFromBlock();
}
return true;
}
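
The pass in this CL only exercises the extraIndices path. A hedged sketch of the indexRemap path, assuming the includes already present in this file and a hypothetical helper name, rewriting a rank-1 memref's accesses %old[d0] into %new[d0 mod 2, d0]:

// Hypothetical: remap the existing index instead of adding an extra one. The
// assertions above require numResults(indexRemap) + numExtraIndices ==
// rank(newMemRef), i.e. 2 + 0 == 2 here.
static bool remapToDoubleBuffer(MLValue *oldMemRef, MLValue *newMemRef,
                                MLFuncBuilder &b) {
  auto d0 = b.getDimExpr(0);
  auto *remap = b.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
                               {d0 % 2, d0}, {});
  return replaceAllMemRefUsesWith(oldMemRef, newMemRef, /*extraIndices=*/{},
                                  remap);
}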


@@ -1,79 +1,66 @@
// RUN: mlir-opt %s -pipeline-data-transfer | FileCheck %s
// CHECK-LABEL: mlfunc @loop_nest_simple() {
// CHECK: %c8 = constant 8 : affineint
// CHECK-NEXT: %c0 = constant 0 : affineint
// CHECK-NEXT: %0 = "foo"(%c0) : (affineint) -> affineint
// CHECK-NEXT: for %i0 = 1 to 7 {
// CHECK-NEXT: %1 = "foo"(%i0) : (affineint) -> affineint
// CHECK-NEXT: %2 = affine_apply #map0(%i0)
// CHECK-NEXT: %3 = "bar"(%2) : (affineint) -> affineint
// CHECK-NEXT: }
// CHECK-NEXT: %4 = affine_apply #map0(%c8)
// CHECK-NEXT: %5 = "bar"(%4) : (affineint) -> affineint
// CHECK-NEXT: return
mlfunc @loop_nest_simple() {
for %i = 0 to 7 {
%y = "foo"(%i) : (affineint) -> affineint
%x = "bar"(%i) : (affineint) -> affineint
}
return
}
// CHECK-LABEL: mlfunc @loop_nest_dma() {
// CHECK: %c8 = constant 8 : affineint
// CHECK-NEXT: %c0 = constant 0 : affineint
// CHECK-NEXT: %0 = affine_apply #map1(%c0)
// CHECK-NEXT: %1 = "dma.enqueue"(%0) : (affineint) -> affineint
// CHECK-NEXT: %2 = "dma.enqueue"(%0) : (affineint) -> affineint
// CHECK-NEXT: for %i0 = 1 to 7 {
// CHECK-NEXT: %3 = affine_apply #map1(%i0)
// CHECK-NEXT: %4 = "dma.enqueue"(%3) : (affineint) -> affineint
// CHECK-NEXT: %5 = "dma.enqueue"(%3) : (affineint) -> affineint
// CHECK-NEXT: %6 = affine_apply #map0(%i0)
// CHECK-NEXT: %7 = affine_apply #map1(%6)
// CHECK-NEXT: %8 = "dma.wait"(%7) : (affineint) -> affineint
// CHECK-NEXT: %9 = "compute1"(%7) : (affineint) -> affineint
// CHECK-NEXT: }
// CHECK-NEXT: %10 = affine_apply #map0(%c8)
// CHECK-NEXT: %11 = affine_apply #map1(%10)
// CHECK-NEXT: %12 = "dma.wait"(%11) : (affineint) -> affineint
// CHECK-NEXT: %13 = "compute1"(%11) : (affineint) -> affineint
// CHECK-NEXT: return
// CHECK: #map0 = (d0) -> (d0 mod 2)
// CHECK-NEXT: #map1 = (d0) -> (d0 - 1)
// CHECK-NEXT: mlfunc @loop_nest_dma() {
// CHECK-NEXT: %c8 = constant 8 : affineint
// CHECK-NEXT: %c0 = constant 0 : affineint
// CHECK-NEXT: %0 = alloc() : memref<2x1xf32>
// CHECK-NEXT: %1 = alloc() : memref<2x32xf32>
// CHECK-NEXT: %2 = alloc() : memref<256xf32, (d0) -> (d0)>
// CHECK-NEXT: %3 = alloc() : memref<32xf32, (d0) -> (d0), 1>
// CHECK-NEXT: %4 = alloc() : memref<1xf32>
// CHECK-NEXT: %c0_0 = constant 0 : affineint
// CHECK-NEXT: %c128 = constant 128 : affineint
// CHECK-NEXT: %5 = affine_apply #map0(%c0)
// CHECK-NEXT: %6 = affine_apply #map0(%c0)
// CHECK-NEXT: "dma.in.start"(%2, %c0, %1, %5, %c0, %c128, %0, %6, %c0_0) : (memref<256xf32, (d0) -> (d0)>, affineint, memref<2x32xf32>, affineint, affineint, affineint, memref<2x1xf32>, affineint, affineint) -> ()
// CHECK-NEXT: for %i0 = 1 to 7 {
// CHECK-NEXT: %7 = affine_apply #map0(%i0)
// CHECK-NEXT: %8 = affine_apply #map0(%i0)
// CHECK-NEXT: "dma.in.start"(%2, %i0, %1, %7, %i0, %c128, %0, %8, %c0_0) : (memref<256xf32, (d0) -> (d0)>, affineint, memref<2x32xf32>, affineint, affineint, affineint, memref<2x1xf32>, affineint, affineint) -> ()
// CHECK-NEXT: %9 = affine_apply #map1(%i0)
// CHECK-NEXT: %10 = affine_apply #map0(%9)
// CHECK-NEXT: %11 = "dma.finish"(%0, %10, %c0_0) : (memref<2x1xf32>, affineint, affineint) -> affineint
// CHECK-NEXT: %12 = affine_apply #map0(%9)
// CHECK-NEXT: %13 = load %1[%12, %9] : memref<2x32xf32>
// CHECK-NEXT: %14 = "compute"(%13) : (f32) -> f32
// CHECK-NEXT: %15 = affine_apply #map0(%9)
// CHECK-NEXT: store %14, %1[%15, %9] : memref<2x32xf32>
// CHECK-NEXT: for %i1 = 0 to 127 {
// CHECK-NEXT: "do_more_compute"(%9, %i1) : (affineint, affineint) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: %16 = affine_apply #map1(%c8)
// CHECK-NEXT: %17 = affine_apply #map0(%16)
// CHECK-NEXT: %18 = "dma.finish"(%0, %17, %c0_0) : (memref<2x1xf32>, affineint, affineint) -> affineint
// CHECK-NEXT: %19 = affine_apply #map0(%16)
// CHECK-NEXT: %20 = load %1[%19, %16] : memref<2x32xf32>
// CHECK-NEXT: %21 = "compute"(%20) : (f32) -> f32
// CHECK-NEXT: %22 = affine_apply #map0(%16)
// CHECK-NEXT: store %21, %1[%22, %16] : memref<2x32xf32>
// CHECK-NEXT: for %i2 = 0 to 127 {
// CHECK-NEXT: "do_more_compute"(%16, %i2) : (affineint, affineint) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
mlfunc @loop_nest_dma() {
for %i = 0 to 7 {
%pingpong = affine_apply (d0) -> (d0 mod 2) (%i)
"dma.enqueue"(%pingpong) : (affineint) -> affineint
"dma.enqueue"(%pingpong) : (affineint) -> affineint
%pongping = affine_apply (d0) -> (d0 mod 2) (%i)
"dma.wait"(%pongping) : (affineint) -> affineint
"compute1"(%pongping) : (affineint) -> affineint
}
return
}
// CHECK-LABEL: mlfunc @loop_nest_bound_map(%arg0 : affineint) {
// CHECK: %0 = affine_apply #map2()[%arg0]
// CHECK-NEXT: %1 = "foo"(%0) : (affineint) -> affineint
// CHECK-NEXT: %2 = "bar"(%0) : (affineint) -> affineint
// CHECK-NEXT: for %i0 = #map3()[%arg0] to #map4()[%arg0] {
// CHECK-NEXT: %3 = "foo"(%i0) : (affineint) -> affineint
// CHECK-NEXT: %4 = "bar"(%i0) : (affineint) -> affineint
// CHECK-NEXT: %5 = affine_apply #map0(%i0)
// CHECK-NEXT: %6 = "foo_bar"(%5) : (affineint) -> affineint
// CHECK-NEXT: %7 = "bar_foo"(%5) : (affineint) -> affineint
// CHECK-NEXT: }
// CHECK-NEXT: %8 = affine_apply #map5()[%arg0]
// CHECK-NEXT: %9 = affine_apply #map0(%8)
// CHECK-NEXT: %10 = "foo_bar"(%9) : (affineint) -> affineint
// CHECK-NEXT: %11 = "bar_foo"(%9) : (affineint) -> affineint
// CHECK-NEXT: return
mlfunc @loop_nest_bound_map(%N : affineint) {
for %i = %N to ()[s0] -> (s0 + 7)()[%N] {
"foo"(%i) : (affineint) -> affineint
"bar"(%i) : (affineint) -> affineint
"foo_bar"(%i) : (affineint) -> (affineint)
"bar_foo"(%i) : (affineint) -> (affineint)
%A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
%Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1>
%tag = alloc() : memref<1 x f32>
%zero = constant 0 : affineint
%size = constant 128 : affineint
for %i = 0 to 7 {
"dma.in.start"(%A, %i, %Ah, %i, %size, %tag, %zero) : (memref<256 x f32, (d0)->(d0), 0>, affineint, memref<32 x f32, (d0)->(d0), 1>, affineint, affineint, memref<1 x f32>, affineint) -> ()
"dma.finish"(%tag, %zero) : (memref<1 x f32>, affineint) -> affineint
%v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
%r = "compute"(%v) : (f32) -> (f32)
store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
for %j = 0 to 127 {
"do_more_compute"(%i, %j) : (affineint, affineint) -> ()
}
}
return
}