forked from OSchip/llvm-project

Introduce memref replacement/rewrite support: replace an existing memref with a
new one (of a potentially different rank/shape) with an optional index
remapping.

- Introduce Utils::replaceAllMemRefUsesWith.
- Use it for DMA double buffering.

This CL also adds a few temporary utilities / pieces of code that will be done
away with once: 1) abstract DMA ops are added, 2) a memref dereferencing side
effect / trait is available on ops, and 3) b/117159533 (memref index
computation slices) is resolved.

PiperOrigin-RevId: 215831373

This commit is contained in:
parent b55b407601
commit 6cfdb756b1
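Before the diff itself, here is the double-buffered schedule the new pass sets up, written as a plain C++ analogy rather than against the MLIR API. Everything in this sketch is illustrative only: pipelinedTiles and the fixed buffer width are assumptions, and memcpy stands in for an asynchronous DMA start that would overlap with the compute loop on a real target.

#include <cassert>
#include <cstring>

// Tile i is processed out of buf[i % 2] while the copy for tile i + 1
// (conceptually asynchronous) targets buf[(i + 1) % 2] - the ping-pong that
// the pass creates by adding a leading dimension of extent 2 to the memref.
void pipelinedTiles(const float *src, float *dst, int numTiles, int tileSize) {
  assert(tileSize <= 32);
  float buf[2][32];
  std::memcpy(buf[0], src, tileSize * sizeof(float)); // "dma.in.start", tile 0
  for (int i = 0; i < numTiles; ++i) {
    if (i + 1 < numTiles) // start the copy for the next tile
      std::memcpy(buf[(i + 1) % 2], src + (i + 1) * tileSize,
                  tileSize * sizeof(float));
    // A "dma.finish" would wait here; then compute on this iteration's half.
    for (int j = 0; j < tileSize; ++j)
      dst[i * tileSize + j] = 2.0f * buf[i % 2][j];
  }
}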
@@ -343,15 +343,21 @@ public:
     return MLFuncBuilder(forStmt, forStmt->end());
   }
 
-  /// Get the current insertion point of the builder.
+  /// Returns the current insertion point of the builder.
   StmtBlock::iterator getInsertionPoint() const { return insertPoint; }
 
-  /// Get the current block of the builder.
+  /// Returns the current block of the builder.
   StmtBlock *getBlock() const { return block; }
 
-  /// Create an operation given the fields represented as an OperationState.
+  /// Creates an operation given the fields represented as an OperationState.
   OperationStmt *createOperation(const OperationState &state);
 
+  /// Creates an operation given the fields.
+  OperationStmt *createOperation(Location *location, Identifier name,
+                                 ArrayRef<MLValue *> operands,
+                                 ArrayRef<Type *> types,
+                                 ArrayRef<NamedAttribute> attrs);
+
   /// Create operation of specific op type at the current insertion point.
   template <typename OpTy, typename... Args>
   OpPointer<OpTy> create(Location *location, Args... args) {
@@ -0,0 +1,49 @@
+//===- Utils.h - General transformation utilities ---------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for various transformation utilities for
+// memref's and non-loop IR structures. These are not passes by themselves but
+// are used either by passes, optimization sequences, or in turn by other
+// transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_UTILS_H
+#define MLIR_TRANSFORMS_UTILS_H
+
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+
+class AffineMap;
+class MLValue;
+class SSAValue;
+
+/// Replace all uses of oldMemRef with newMemRef while optionally remapping the
+/// old memref's indices using the supplied affine map and adding any additional
+/// indices. The new memref could be of a different shape or rank. Returns true
+/// on success and false if the replacement is not possible (whenever a memref
+/// is used as an operand in a non-dereferencing scenario).
+/// Additional indices are added at the start.
+// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
+// extended to add additional indices at any position.
+bool replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
+                              llvm::ArrayRef<SSAValue *> extraIndices,
+                              AffineMap *indexRemap = nullptr);
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_UTILS_H
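A minimal sketch of the intended call pattern, mirroring the doubleBuffer use later in this commit; it is a fragment, not a standalone program, and assumes oldMemRef, newMemRef, and iv are in scope, with iv a valid dimension or symbol operand such as an affine_apply result.

// With no indexRemap, the ranks must satisfy
// oldRank + extraIndices.size() == newRank: an access like load %old[%i]
// becomes load %new[%iv, %i], the extra index being prepended and the
// original indices passing through unchanged.
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, /*extraIndices=*/{iv},
                              /*indexRemap=*/nullptr)) {
  // The memref is used by a non-dereferencing op (it may escape); the
  // utility makes no changes in that case and the caller must bail out.
  return false;
}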
@@ -312,6 +312,18 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
   return op;
 }
 
+/// Create an operation given the fields.
+OperationStmt *MLFuncBuilder::createOperation(Location *location,
+                                              Identifier name,
+                                              ArrayRef<MLValue *> operands,
+                                              ArrayRef<Type *> types,
+                                              ArrayRef<NamedAttribute> attrs) {
+  auto *op = OperationStmt::create(location, name, operands, types, attrs,
+                                   getContext());
+  block->getStatements().insert(insertPoint, op);
+  return op;
+}
+
 ForStmt *MLFuncBuilder::createFor(Location *location,
                                   ArrayRef<MLValue *> lbOperands,
                                   AffineMap *lbMap,
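This raw-field overload is what op-rewriting utilities such as replaceAllMemRefUsesWith (later in this commit) build on: an operation can be re-created with an edited operand list without going through a registered op class. A hedged sketch of that pattern, assuming builder, opStmt, and newValue are in scope:

// Rebuild opStmt with the same name, attributes, and result types but a
// modified operand list (here: operand 0 swapped for newValue).
SmallVector<MLValue *, 8> operands(opStmt->operand_begin(),
                                   opStmt->operand_end());
operands[0] = newValue;

SmallVector<Type *, 4> resultTypes;
for (auto *result : opStmt->getResults())
  resultTypes.push_back(result->getType());

auto *replacement = builder.createOperation(
    opStmt->getLoc(), opStmt->getName(), operands, resultTypes,
    opStmt->getAttrs());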
@@ -21,10 +21,13 @@
 
 #include "mlir/Transforms/Passes.h"
 
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardOps.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Pass.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
 
 using namespace mlir;
 
@@ -43,27 +46,237 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
   return new PipelineDataTransfer();
 }
 
-// For testing purposes, this just runs on the first statement of the MLFunction
-// if that statement is a for stmt, and shifts the second half of its body by
-// one.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or
+// op traits for them are added. TODO(b/117228571)
+static bool isDmaStartStmt(const OperationStmt &stmt) {
+  return stmt.getName().strref().contains("dma.in.start") ||
+         stmt.getName().strref().contains("dma.out.start");
+}
+
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static bool isDmaFinishStmt(const OperationStmt &stmt) {
+  return stmt.getName().strref().contains("dma.finish");
+}
+
+/// Given a DMA start operation, returns the operand position of either the
+/// source or destination memref depending on the one that is at the higher
+/// level of the memory hierarchy.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
+  assert(isDmaStartStmt(dmaStartStmt));
+  unsigned srcDmaPos = 0;
+  unsigned destDmaPos =
+      cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1;
+
+  if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType())
+          ->getMemorySpace() >
+      cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType())
+          ->getMemorySpace())
+    return srcDmaPos;
+  return destDmaPos;
+}
+
+// Returns the position of the tag memref operand given a DMA statement.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
+// added. TODO(b/117228571)
+unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
+  assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt));
+  if (isDmaStartStmt(dmaStmt)) {
+    // Second to last operand.
+    return dmaStmt.getNumOperands() - 2;
+  }
+  // First operand for a dma finish statement.
+  return 0;
+}
+
+/// Doubles the buffer of the supplied memref.
+static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
+  MLFuncBuilder bInner(forStmt, forStmt->begin());
+  bInner.setInsertionPoint(forStmt, forStmt->begin());
+
+  // Doubles the shape with a leading dimension extent of 2.
+  auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * {
+    // Add the leading dimension in the shape for the double buffer.
+    ArrayRef<int> shape = origMemRefType->getShape();
+    SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
+    shapeSizes.insert(shapeSizes.begin(), 2);
+
+    auto *newMemRefType = bInner.getMemRefType(shapeSizes, bInner.getF32Type());
+    return newMemRefType;
+  };
+
+  auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
+
+  // Create and place the alloc at the top level.
+  auto *func = forStmt->getFunction();
+  MLFuncBuilder topBuilder(func, func->begin());
+  auto *newMemRef = cast<MLValue>(
+      topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
+          ->getResult());
+
+  auto d0 = bInner.getDimExpr(0);
+  auto *modTwoMap = bInner.getAffineMap(1, 0, {d0 % 2}, {});
+  auto ivModTwoOp =
+      bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
+  if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0)))
+    return false;
+  // We don't need ivModTwoOp any more - this is cloned by
+  // replaceAllMemRefUsesWith wherever the memref replacement happens. Once
+  // b/117159533 is addressed, we'll eventually only need to pass
+  // ivModTwoOp->getResult(0) to replaceAllMemRefUsesWith.
+  cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock();
+  return true;
+}
+
+// For testing purposes, this just runs on the first for statement of an
+// MLFunction at the top level.
+// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
+// the other TODOs listed inside are dealt with.
 PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
   if (f->empty())
     return PassResult::Success;
-  auto *forStmt = dyn_cast<ForStmt>(&f->front());
+
+  ForStmt *forStmt = nullptr;
+  for (auto &stmt : *f) {
+    if ((forStmt = dyn_cast<ForStmt>(&stmt))) {
+      break;
+    }
+  }
   if (!forStmt)
-    return PassResult::Failure;
+    return PassResult::Success;
 
-  unsigned numStmts = forStmt->getStatements().size();
-  if (numStmts == 0)
-    return PassResult::Success;
-
-  std::vector<uint64_t> delays(numStmts);
-  for (unsigned i = 0; i < numStmts; i++)
-    delays[i] = (i < numStmts / 2) ? 0 : 1;
+  SmallVector<OperationStmt *, 4> dmaStartStmts;
+  SmallVector<OperationStmt *, 4> dmaFinishStmts;
+  for (auto &stmt : *forStmt) {
+    auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+    if (!opStmt)
+      continue;
+    if (isDmaStartStmt(*opStmt)) {
+      dmaStartStmts.push_back(opStmt);
+    } else if (isDmaFinishStmt(*opStmt)) {
+      dmaFinishStmts.push_back(opStmt);
+    }
+  }
 
-  if (!checkDominancePreservationOnShift(*forStmt, delays))
+  // TODO(bondhugula,andydavis): match tag memref's (requires memory-based
+  // subscript check utilities). Assume for now that start/finish are matched in
+  // the order they appear.
+  if (dmaStartStmts.size() != dmaFinishStmts.size())
     return PassResult::Failure;
 
+  // Double the buffers for the higher memory space memref's.
+  // TODO(bondhugula): assuming we don't have multiple DMA starts for the same
+  // memref.
+  // TODO(bondhugula): check whether double-buffering is even necessary.
+  // TODO(bondhugula): make this work with different layouts: assuming here that
+  // the dimension we are adding here for the double buffering is the outermost
+  // dimension.
+  // Identify memref's to replace by scanning through all DMA start statements.
+  // A DMA start statement has two memref's - the one from the higher level of
+  // memory hierarchy is the one to double buffer.
+  for (auto *dmaStartStmt : dmaStartStmts) {
+    MLValue *oldMemRef = cast<MLValue>(
+        dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt)));
+    if (!doubleBuffer(oldMemRef, forStmt))
+      return PassResult::Failure;
+  }
+
+  // Double the buffers for tag memref's.
+  for (auto *dmaFinishStmt : dmaFinishStmts) {
+    MLValue *oldTagMemRef = cast<MLValue>(
+        dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
+    if (!doubleBuffer(oldTagMemRef, forStmt))
+      return PassResult::Failure;
+  }
+
+  // Collect all compute ops.
+  std::vector<const Statement *> computeOps;
+  computeOps.reserve(forStmt->getStatements().size());
+  // Store delay for statement for later lookup for AffineApplyOp's.
+  DenseMap<const Statement *, unsigned> opDelayMap;
+  for (const auto &stmt : *forStmt) {
+    auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+    if (!opStmt) {
+      // All for and if stmt's are treated as pure compute operations.
+      // TODO(bondhugula): check whether such statements do not have any DMAs
+      // nested within.
+      opDelayMap[&stmt] = 1;
+    } else if (isDmaStartStmt(*opStmt)) {
+      // DMA starts are not shifted.
+      opDelayMap[&stmt] = 0;
+    } else if (isDmaFinishStmt(*opStmt)) {
+      // DMA finish op shifted by one.
+      opDelayMap[&stmt] = 1;
+    } else if (!opStmt->is<AffineApplyOp>()) {
+      // Compute op shifted by one.
+      opDelayMap[&stmt] = 1;
+      computeOps.push_back(&stmt);
+    }
+    // Shifts for affine apply op's determined later.
+  }
+
+  // Get the ancestor of a 'stmt' that lies in forStmt's block.
+  auto getAncestorInForBlock =
+      [&](const Statement *stmt, const StmtBlock &block) -> const Statement * {
+    // Traverse up the statement hierarchy starting from the owner of operand to
+    // find the ancestor statement that resides in the block of 'forStmt'.
+    while (stmt != nullptr && stmt->getBlock() != &block) {
+      stmt = stmt->getParentStmt();
+    }
+    return stmt;
+  };
+
+  // Determine delays for affine apply op's: look up delay from its consumer op.
+  // This code will be thrown away once we have a way to obtain indices through
+  // a composed affine_apply op. See TODO(b/117159533). Such a composed
+  // affine_apply will be used exclusively by a given memref dereferencing op.
+  for (const auto &stmt : *forStmt) {
+    auto *opStmt = dyn_cast<OperationStmt>(&stmt);
+    // Skip statements that aren't affine apply ops.
+    if (!opStmt || !opStmt->is<AffineApplyOp>())
+      continue;
+    // Traverse uses of each result of the affine apply op.
+    for (auto *res : opStmt->getResults()) {
+      for (auto &use : res->getUses()) {
+        auto *ancestorInForBlock =
+            getAncestorInForBlock(use.getOwner(), *forStmt);
+        assert(ancestorInForBlock &&
+               "traversing parent should reach forStmt block");
+        auto *opCheck = dyn_cast<OperationStmt>(ancestorInForBlock);
+        if (!opCheck || opCheck->is<AffineApplyOp>())
+          continue;
+        assert(opDelayMap.find(ancestorInForBlock) != opDelayMap.end());
+        if (opDelayMap.find(&stmt) != opDelayMap.end()) {
+          // This is where we enforce all uses of this affine_apply to have
+          // the same shifts - so that we know what shift to use for the
+          // affine_apply to preserve semantics.
+          assert(opDelayMap[&stmt] == opDelayMap[ancestorInForBlock]);
+        } else {
+          // Obtain delay from its consumer.
+          opDelayMap[&stmt] = opDelayMap[ancestorInForBlock];
+        }
+      }
+    }
+  }
+
+  // Get delays stored in map.
+  std::vector<uint64_t> delays(forStmt->getStatements().size());
+  unsigned s = 0;
+  for (const auto &stmt : *forStmt) {
+    delays[s++] = opDelayMap[&stmt];
+  }
+
+  if (!checkDominancePreservationOnShift(*forStmt, delays)) {
+    // Violates SSA dominance.
+    return PassResult::Failure;
+  }
+
   if (stmtBodySkew(forStmt, delays))
     return PassResult::Failure;
 
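To see what the delay vector buys, here is a toy model (plain C++, not MLIR code) of the iteration-space skew that stmtBodySkew applies, under the assumption that a delay of d moves statement s of source iteration i into shifted iteration i + d; grouping the output by shifted iteration shows iteration i's DMA start (delay 0) overlapping iteration i - 1's finish and compute (delay 1), which is exactly what makes the double buffer ping-pong on (i mod 2).

#include <cstdio>
#include <vector>

int main() {
  // Delays as assigned above: DMA start = 0, DMA finish and compute = 1.
  const char *stmts[] = {"dma.in.start", "dma.finish", "compute"};
  std::vector<unsigned> delay = {0, 1, 1};
  for (unsigned i = 0; i < 3; ++i)     // source iterations
    for (unsigned s = 0; s < 3; ++s)   // statements in the loop body
      printf("shifted iter %u: %s from source iter %u\n", i + delay[s],
             stmts[s], i);
  return 0;
}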
@@ -0,0 +1,165 @@
+//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous transformation routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Utils.h"
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardOps.h"
+#include "llvm/ADT/DenseMap.h"
+
+using namespace mlir;
+
+/// Return true if this operation dereferences one or more memref's.
+// Temporary utility: will be replaced when this is modeled through
+// side-effects/op traits. TODO(b/117228571)
+static bool isMemRefDereferencingOp(const Operation &op) {
+  if (op.is<LoadOp>() || op.is<StoreOp>() ||
+      op.getName().strref().contains("dma.in.start") ||
+      op.getName().strref().contains("dma.out.start") ||
+      op.getName().strref().contains("dma.finish")) {
+    return true;
+  }
+  return false;
+}
+
+/// Replaces all uses of oldMemRef with newMemRef while optionally remapping
+/// old memref's indices to the new memref using the supplied affine map
+/// and adding any additional indices. The new memref could be of a different
+/// shape or rank, but of the same elemental type. Additional indices are added
+/// at the start for now.
+// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
+// extended to add additional indices at any position.
+bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
+                                    ArrayRef<SSAValue *> extraIndices,
+                                    AffineMap *indexRemap) {
+  unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
+  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
+  if (indexRemap) {
+    assert(indexRemap->getNumInputs() == oldMemRefRank);
+    assert(indexRemap->getNumResults() + extraIndices.size() == newMemRefRank);
+  } else {
+    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+  }
+
+  // Assert same elemental type.
+  assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
+         cast<MemRefType>(newMemRef->getType())->getElementType());
+
+  // Check if memref was used in a non-dereferencing context.
+  for (const StmtOperand &use : oldMemRef->getUses()) {
+    auto *opStmt = cast<OperationStmt>(use.getOwner());
+    // Failure: memref used in a non-dereferencing op (potentially escapes); no
+    // replacement in these cases.
+    if (!isMemRefDereferencingOp(*opStmt))
+      return false;
+  }
+
+  // Walk all uses of old memref. Statement using the memref gets replaced.
+  for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
+    StmtOperand &use = *(it++);
+    auto *opStmt = cast<OperationStmt>(use.getOwner());
+    assert(isMemRefDereferencingOp(*opStmt) &&
+           "memref dereferencing op expected");
+
+    auto getMemRefOperandPos = [&]() -> unsigned {
+      unsigned i;
+      for (i = 0; i < opStmt->getNumOperands(); i++) {
+        if (opStmt->getOperand(i) == oldMemRef)
+          break;
+      }
+      assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
+      return i;
+    };
+    unsigned memRefOperandPos = getMemRefOperandPos();
+
+    // Construct the new operation statement using this memref.
+    SmallVector<MLValue *, 8> operands;
+    operands.reserve(opStmt->getNumOperands() + extraIndices.size());
+    // Insert the non-memref operands.
+    operands.insert(operands.end(), opStmt->operand_begin(),
+                    opStmt->operand_begin() + memRefOperandPos);
+    operands.push_back(newMemRef);
+
+    MLFuncBuilder builder(opStmt);
+    // Normally, we could just use extraIndices as operands, but we will
+    // clone it so that each op gets its own "private" index. See b/117159533.
+    for (auto *extraIndex : extraIndices) {
+      OperationStmt::OperandMapTy operandMap;
+      // TODO(mlir-team): An operation/SSA value should provide a method to
+      // return the position of an SSA result in its defining
+      // operation.
+      assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
+             "single result op's expected to generate these indices");
+      // TODO: actually check if this is a result of an affine_apply op.
+      assert((cast<MLValue>(extraIndex)->isValidDim() ||
+              cast<MLValue>(extraIndex)->isValidSymbol()) &&
+             "invalid memory op index");
+      auto *clonedExtraIndex =
+          cast<OperationStmt>(
+              builder.clone(*extraIndex->getDefiningStmt(), operandMap))
+              ->getResult(0);
+      operands.push_back(cast<MLValue>(clonedExtraIndex));
+    }
+
+    // Construct new indices. The indices of a memref come right after it, i.e.,
+    // at position memRefOperandPos + 1.
+    SmallVector<SSAValue *, 4> indices(
+        opStmt->operand_begin() + memRefOperandPos + 1,
+        opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
+    if (indexRemap) {
+      auto remapOp =
+          builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices);
+      // Remapped indices.
+      for (auto *index : remapOp->getOperation()->getResults())
+        operands.push_back(cast<MLValue>(index));
+    } else {
+      // No remapping specified.
+      for (auto *index : indices)
+        operands.push_back(cast<MLValue>(index));
+    }
+
+    // Insert the remaining operands unmodified.
+    operands.insert(operands.end(),
+                    opStmt->operand_begin() + memRefOperandPos + 1 +
+                        oldMemRefRank,
+                    opStmt->operand_end());
+
+    // Result types don't change. Both memref's are of the same elemental type.
+    SmallVector<Type *, 8> resultTypes;
+    resultTypes.reserve(opStmt->getNumResults());
+    for (const auto *result : opStmt->getResults())
+      resultTypes.push_back(result->getType());
+
+    // Create the new operation.
+    auto *repOp =
+        builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands,
+                                resultTypes, opStmt->getAttrs());
+    // Replace old memref's dereferencing op's uses.
+    unsigned r = 0;
+    for (auto *res : opStmt->getResults()) {
+      res->replaceAllUsesWith(repOp->getResult(r++));
+    }
+    opStmt->eraseFromBlock();
+  }
+  return true;
+}
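The operand-list surgery above is easy to misread, so here is a self-contained mirror of the splicing logic on strings (hypothetical values, one extra index, no indexRemap) showing where each piece lands:

#include <cstdio>
#include <string>
#include <vector>

int main() {
  // A dereferencing op's operands: [.., memref, idx_0 .. idx_{rank-1}, ..].
  std::vector<std::string> ops = {"val", "oldMemRef", "i", "j", "tag"};
  unsigned memRefPos = 1, oldRank = 2;

  std::vector<std::string> out(ops.begin(), ops.begin() + memRefPos);
  out.push_back("newMemRef"); // replaces the memref in place
  out.push_back("ivMod2");    // cloned extra index, prepended to the indices
  // No indexRemap: the old indices pass through unchanged.
  out.insert(out.end(), ops.begin() + memRefPos + 1,
             ops.begin() + memRefPos + 1 + oldRank);
  // Remaining operands are carried over unmodified.
  out.insert(out.end(), ops.begin() + memRefPos + 1 + oldRank, ops.end());

  for (const auto &s : out)
    printf("%s ", s.c_str()); // val newMemRef ivMod2 i j tag
  printf("\n");
  return 0;
}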
@@ -1,79 +1,66 @@
 // RUN: mlir-opt %s -pipeline-data-transfer | FileCheck %s
 
-// CHECK-LABEL: mlfunc @loop_nest_simple() {
-// CHECK: %c8 = constant 8 : affineint
-// CHECK-NEXT: %c0 = constant 0 : affineint
-// CHECK-NEXT: %0 = "foo"(%c0) : (affineint) -> affineint
-// CHECK-NEXT: for %i0 = 1 to 7 {
-// CHECK-NEXT: %1 = "foo"(%i0) : (affineint) -> affineint
-// CHECK-NEXT: %2 = affine_apply #map0(%i0)
-// CHECK-NEXT: %3 = "bar"(%2) : (affineint) -> affineint
-// CHECK-NEXT: }
-// CHECK-NEXT: %4 = affine_apply #map0(%c8)
-// CHECK-NEXT: %5 = "bar"(%4) : (affineint) -> affineint
-// CHECK-NEXT: return
-mlfunc @loop_nest_simple() {
-  for %i = 0 to 7 {
-    %y = "foo"(%i) : (affineint) -> affineint
-    %x = "bar"(%i) : (affineint) -> affineint
-  }
-  return
-}
-
-// CHECK-LABEL: mlfunc @loop_nest_dma() {
-// CHECK: %c8 = constant 8 : affineint
-// CHECK-NEXT: %c0 = constant 0 : affineint
-// CHECK-NEXT: %0 = affine_apply #map1(%c0)
-// CHECK-NEXT: %1 = "dma.enqueue"(%0) : (affineint) -> affineint
-// CHECK-NEXT: %2 = "dma.enqueue"(%0) : (affineint) -> affineint
-// CHECK-NEXT: for %i0 = 1 to 7 {
-// CHECK-NEXT: %3 = affine_apply #map1(%i0)
-// CHECK-NEXT: %4 = "dma.enqueue"(%3) : (affineint) -> affineint
-// CHECK-NEXT: %5 = "dma.enqueue"(%3) : (affineint) -> affineint
-// CHECK-NEXT: %6 = affine_apply #map0(%i0)
-// CHECK-NEXT: %7 = affine_apply #map1(%6)
-// CHECK-NEXT: %8 = "dma.wait"(%7) : (affineint) -> affineint
-// CHECK-NEXT: %9 = "compute1"(%7) : (affineint) -> affineint
-// CHECK-NEXT: }
-// CHECK-NEXT: %10 = affine_apply #map0(%c8)
-// CHECK-NEXT: %11 = affine_apply #map1(%10)
-// CHECK-NEXT: %12 = "dma.wait"(%11) : (affineint) -> affineint
-// CHECK-NEXT: %13 = "compute1"(%11) : (affineint) -> affineint
-// CHECK-NEXT: return
+// CHECK: #map0 = (d0) -> (d0 mod 2)
+// CHECK-NEXT: #map1 = (d0) -> (d0 - 1)
+// CHECK-NEXT: mlfunc @loop_nest_dma() {
+// CHECK-NEXT: %c8 = constant 8 : affineint
+// CHECK-NEXT: %c0 = constant 0 : affineint
+// CHECK-NEXT: %0 = alloc() : memref<2x1xf32>
+// CHECK-NEXT: %1 = alloc() : memref<2x32xf32>
+// CHECK-NEXT: %2 = alloc() : memref<256xf32, (d0) -> (d0)>
+// CHECK-NEXT: %3 = alloc() : memref<32xf32, (d0) -> (d0), 1>
+// CHECK-NEXT: %4 = alloc() : memref<1xf32>
+// CHECK-NEXT: %c0_0 = constant 0 : affineint
+// CHECK-NEXT: %c128 = constant 128 : affineint
+// CHECK-NEXT: %5 = affine_apply #map0(%c0)
+// CHECK-NEXT: %6 = affine_apply #map0(%c0)
+// CHECK-NEXT: "dma.in.start"(%2, %c0, %1, %5, %c0, %c128, %0, %6, %c0_0) : (memref<256xf32, (d0) -> (d0)>, affineint, memref<2x32xf32>, affineint, affineint, affineint, memref<2x1xf32>, affineint, affineint) -> ()
+// CHECK-NEXT: for %i0 = 1 to 7 {
+// CHECK-NEXT: %7 = affine_apply #map0(%i0)
+// CHECK-NEXT: %8 = affine_apply #map0(%i0)
+// CHECK-NEXT: "dma.in.start"(%2, %i0, %1, %7, %i0, %c128, %0, %8, %c0_0) : (memref<256xf32, (d0) -> (d0)>, affineint, memref<2x32xf32>, affineint, affineint, affineint, memref<2x1xf32>, affineint, affineint) -> ()
+// CHECK-NEXT: %9 = affine_apply #map1(%i0)
+// CHECK-NEXT: %10 = affine_apply #map0(%9)
+// CHECK-NEXT: %11 = "dma.finish"(%0, %10, %c0_0) : (memref<2x1xf32>, affineint, affineint) -> affineint
+// CHECK-NEXT: %12 = affine_apply #map0(%9)
+// CHECK-NEXT: %13 = load %1[%12, %9] : memref<2x32xf32>
+// CHECK-NEXT: %14 = "compute"(%13) : (f32) -> f32
+// CHECK-NEXT: %15 = affine_apply #map0(%9)
+// CHECK-NEXT: store %14, %1[%15, %9] : memref<2x32xf32>
+// CHECK-NEXT: for %i1 = 0 to 127 {
+// CHECK-NEXT: "do_more_compute"(%9, %i1) : (affineint, affineint) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: %16 = affine_apply #map1(%c8)
+// CHECK-NEXT: %17 = affine_apply #map0(%16)
+// CHECK-NEXT: %18 = "dma.finish"(%0, %17, %c0_0) : (memref<2x1xf32>, affineint, affineint) -> affineint
+// CHECK-NEXT: %19 = affine_apply #map0(%16)
+// CHECK-NEXT: %20 = load %1[%19, %16] : memref<2x32xf32>
+// CHECK-NEXT: %21 = "compute"(%20) : (f32) -> f32
+// CHECK-NEXT: %22 = affine_apply #map0(%16)
+// CHECK-NEXT: store %21, %1[%22, %16] : memref<2x32xf32>
+// CHECK-NEXT: for %i2 = 0 to 127 {
+// CHECK-NEXT: "do_more_compute"(%16, %i2) : (affineint, affineint) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
 mlfunc @loop_nest_dma() {
-  for %i = 0 to 7 {
-    %pingpong = affine_apply (d0) -> (d0 mod 2) (%i)
-    "dma.enqueue"(%pingpong) : (affineint) -> affineint
-    "dma.enqueue"(%pingpong) : (affineint) -> affineint
-    %pongping = affine_apply (d0) -> (d0 mod 2) (%i)
-    "dma.wait"(%pongping) : (affineint) -> affineint
-    "compute1"(%pongping) : (affineint) -> affineint
-  }
-  return
-}
-
-// CHECK-LABEL: mlfunc @loop_nest_bound_map(%arg0 : affineint) {
-// CHECK: %0 = affine_apply #map2()[%arg0]
-// CHECK-NEXT: %1 = "foo"(%0) : (affineint) -> affineint
-// CHECK-NEXT: %2 = "bar"(%0) : (affineint) -> affineint
-// CHECK-NEXT: for %i0 = #map3()[%arg0] to #map4()[%arg0] {
-// CHECK-NEXT: %3 = "foo"(%i0) : (affineint) -> affineint
-// CHECK-NEXT: %4 = "bar"(%i0) : (affineint) -> affineint
-// CHECK-NEXT: %5 = affine_apply #map0(%i0)
-// CHECK-NEXT: %6 = "foo_bar"(%5) : (affineint) -> affineint
-// CHECK-NEXT: %7 = "bar_foo"(%5) : (affineint) -> affineint
-// CHECK-NEXT: }
-// CHECK-NEXT: %8 = affine_apply #map5()[%arg0]
-// CHECK-NEXT: %9 = affine_apply #map0(%8)
-// CHECK-NEXT: %10 = "foo_bar"(%9) : (affineint) -> affineint
-// CHECK-NEXT: %11 = "bar_foo"(%9) : (affineint) -> affineint
-// CHECK-NEXT: return
-mlfunc @loop_nest_bound_map(%N : affineint) {
-  for %i = %N to ()[s0] -> (s0 + 7)()[%N] {
-    "foo"(%i) : (affineint) -> affineint
-    "bar"(%i) : (affineint) -> affineint
-    "foo_bar"(%i) : (affineint) -> (affineint)
-    "bar_foo"(%i) : (affineint) -> (affineint)
-  }
-  return
-}
+  %A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
+  %Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1>
+
+  %tag = alloc() : memref<1 x f32>
+
+  %zero = constant 0 : affineint
+  %size = constant 128 : affineint
+
+  for %i = 0 to 7 {
+    "dma.in.start"(%A, %i, %Ah, %i, %size, %tag, %zero) : (memref<256 x f32, (d0)->(d0), 0>, affineint, memref<32 x f32, (d0)->(d0), 1>, affineint, affineint, memref<1 x f32>, affineint) -> ()
+    "dma.finish"(%tag, %zero) : (memref<1 x f32>, affineint) -> affineint
+    %v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
+    %r = "compute"(%v) : (f32) -> (f32)
+    store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
+    for %j = 0 to 127 {
+      "do_more_compute"(%i, %j) : (affineint, affineint) -> ()
+    }
+  }
+  return
+}