[MLIR] Introduce inter-procedural memref layout normalization
-- Introduces a pass that normalizes affine layout maps to the identity layout map, both within and across functions, by rewriting function arguments and call operands where necessary.
-- Memref normalization is now implemented entirely in the module pass '-normalize-memrefs'; the limited intra-procedural version has been removed from '-simplify-affine-structures'.
-- Run using -normalize-memrefs.
-- Return ops are not handled yet and will be handled in subsequent revisions.

Signed-off-by: Abhishek Varma <abhishek.varma@polymagelabs.com>
Differential Revision: https://reviews.llvm.org/D84490
Parent: e12db3ed99. Commit: 76d07503f0.
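As a quick illustration of the index rewriting the pass performs, consider the tile layout used in the test below, #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>: a 1-d access A[i] on memref<16xf64, #tile> becomes a 2-d access A[i floordiv 4, i mod 4] on the normalized memref<4x4xf64>. A minimal standalone sketch of that arithmetic (illustrative only; the helper below is hypothetical and not part of the patch):

#include <cassert>
#include <utility>

// Remap a 1-d index through the #tile layout (i) -> (i floordiv 4, i mod 4).
// Plain / and % coincide with floordiv/mod here because i is non-negative.
static std::pair<long, long> tileIndex(long i) { return {i / 4, i % 4}; }

int main() {
  // Matches the test's CHECK lines below: the store to %A[10] is rewritten to
  // %A[2, 2] on the normalized memref<4x4xf64>.
  assert(tileIndex(10) == std::make_pair(2L, 2L));
  return 0;
}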
@@ -24,7 +24,9 @@ class AffineForOp;
 class FuncOp;
+class ModuleOp;
 class Pass;
-template <typename T> class OperationPass;

+template <typename T>
+class OperationPass;

 /// Creates an instance of the BufferPlacement pass.
 std::unique_ptr<Pass> createBufferPlacementPass();
@@ -89,6 +91,10 @@ std::unique_ptr<Pass> createSCCPPass();
 /// Creates a pass which delete symbol operations that are unreachable. This
 /// pass may *only* be scheduled on an operation that defines a SymbolTable.
 std::unique_ptr<Pass> createSymbolDCEPass();
+
+/// Creates an interprocedural pass to normalize memrefs to have a trivial
+/// (identity) layout map.
+std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
 } // end namespace mlir

 #endif // MLIR_TRANSFORMS_PASSES_H
@@ -309,6 +309,11 @@ def MemRefDataFlowOpt : FunctionPass<"memref-dataflow-opt"> {
   let constructor = "mlir::createMemRefDataFlowOptPass()";
 }

+def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
+  let summary = "Normalize memrefs";
+  let constructor = "mlir::createNormalizeMemRefsPass()";
+}
+
 def ParallelLoopCollapsing : Pass<"parallel-loop-collapsing"> {
   let summary = "Collapse parallel loops to use less induction variables";
   let constructor = "mlir::createParallelLoopCollapsingPass()";
@@ -405,5 +410,4 @@ def SymbolDCE : Pass<"symbol-dce"> {
   }];
   let constructor = "mlir::createSymbolDCEPass()";
 }
-
 #endif // MLIR_TRANSFORMS_PASSES
@@ -45,10 +45,19 @@ class OpBuilder;
 /// operations that are dominated by the former; similarly, `postDomInstFilter`
 /// restricts replacement to only those operations that are postdominated by it.
 ///
+/// 'allowNonDereferencingOps', if set, allows replacement of non-dereferencing
+/// uses of a memref without any requirement for access index rewrites. The
+/// default value of this flag variable is false.
+///
+/// 'replaceInDeallocOp', if set, lets DeallocOp, a non-dereferencing user, to
+/// also be a candidate for replacement. The default value of this flag is
+/// false.
+///
 /// Returns true on success and false if the replacement is not possible,
-/// whenever a memref is used as an operand in a non-dereferencing context,
-/// except for dealloc's on the memref which are left untouched. See comments at
-/// function definition for an example.
+/// whenever a memref is used as an operand in a non-dereferencing context and
+/// 'allowNonDereferencingOps' is false, except for dealloc's on the memref
+/// which are left untouched. See comments at function definition for an
+/// example.
 //
 // Ex: to replace load %A[%i, %j] with load %Abuf[%t mod 2, %ii - %i, %j]:
 // The SSA value corresponding to '%t mod 2' should be in 'extraIndices', and
@@ -57,28 +66,38 @@ class OpBuilder;
 // extra operands, note that 'indexRemap' would just be applied to existing
 // indices (%i, %j).
 // TODO: allow extraIndices to be added at any position.
-LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
-                                       ArrayRef<Value> extraIndices = {},
-                                       AffineMap indexRemap = AffineMap(),
-                                       ArrayRef<Value> extraOperands = {},
-                                       ArrayRef<Value> symbolOperands = {},
-                                       Operation *domInstFilter = nullptr,
-                                       Operation *postDomInstFilter = nullptr);
+LogicalResult replaceAllMemRefUsesWith(
+    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices = {},
+    AffineMap indexRemap = AffineMap(), ArrayRef<Value> extraOperands = {},
+    ArrayRef<Value> symbolOperands = {}, Operation *domInstFilter = nullptr,
+    Operation *postDomInstFilter = nullptr,
+    bool allowNonDereferencingOps = false, bool replaceInDeallocOp = false);

 /// Performs the same replacement as the other version above but only for the
-/// dereferencing uses of `oldMemRef` in `op`.
+/// dereferencing uses of `oldMemRef` in `op`, except in cases where
+/// 'allowNonDereferencingOps' is set to true where we replace the
+/// non-dereferencing uses as well.
 LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                        Operation *op,
                                        ArrayRef<Value> extraIndices = {},
                                        AffineMap indexRemap = AffineMap(),
                                        ArrayRef<Value> extraOperands = {},
-                                       ArrayRef<Value> symbolOperands = {});
+                                       ArrayRef<Value> symbolOperands = {},
+                                       bool allowNonDereferencingOps = false);

 /// Rewrites the memref defined by this alloc op to have an identity layout map
 /// and updates all its indexing uses. Returns failure if any of its uses
 /// escape (while leaving the IR in a valid state).
 LogicalResult normalizeMemRef(AllocOp op);

+/// Uses the old memref type map layout and computes the new memref type to have
+/// a new shape and a layout map, where the old layout map has been normalized
+/// to an identity layout map. It returns the old memref in case no
+/// normalization was needed or a failure occurs while transforming the old map
+/// layout to an identity layout map.
+MemRefType normalizeMemRefType(MemRefType memrefType, OpBuilder builder,
+                               unsigned numSymbolicOperands);
+
 /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
 /// its results equal to the number of operands, as a composition
 /// of all other AffineApplyOps reachable from input parameter 'operands'. If
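To show how the new flags are meant to be used, here is a minimal sketch of a caller invoking the extended overload; it mirrors the call that NormalizeMemRefs.cpp makes further below. The wrapper name and surrounding setup are illustrative only, not part of this patch.

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"

using namespace mlir;

// Sketch: rewrite every use of `oldMemRef` (including non-dereferencing users
// such as call operands and dealloc's) to use `newMemRef`, folding the old
// layout map into the access indices.
static LogicalResult rewriteUses(Value oldMemRef, Value newMemRef,
                                 AffineMap oldLayoutMap) {
  return replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                  /*extraIndices=*/{},
                                  /*indexRemap=*/oldLayoutMap,
                                  /*extraOperands=*/{},
                                  /*symbolOperands=*/{},
                                  /*domInstFilter=*/nullptr,
                                  /*postDomInstFilter=*/nullptr,
                                  /*allowNonDereferencingOps=*/true,
                                  /*replaceInDeallocOp=*/true);
}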
@@ -96,13 +96,4 @@ void SimplifyAffineStructures::runOnFunction() {
     if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
       applyOpPatternsAndFold(op, patterns);
   });
-
-  // Turn memrefs' non-identity layouts maps into ones with identity. Collect
-  // alloc ops first and then process since normalizeMemRef replaces/erases ops
-  // during memref rewriting.
-  SmallVector<AllocOp, 4> allocOps;
-  func.walk([&](AllocOp op) { allocOps.push_back(op); });
-  for (auto allocOp : allocOps) {
-    normalizeMemRef(allocOp);
-  }
 }
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTransforms
   LoopFusion.cpp
   LoopInvariantCodeMotion.cpp
   MemRefDataFlowOpt.cpp
+  NormalizeMemRefs.cpp
   OpStats.cpp
   ParallelLoopCollapsing.cpp
   PipelineDataTransfer.cpp
@@ -0,0 +1,218 @@
+//===- NormalizeMemRefs.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an interprocedural pass to normalize memrefs to have
+// identity layout maps.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+#define DEBUG_TYPE "normalize-memrefs"
+
+using namespace mlir;
+
+namespace {
+
+/// All memrefs passed across functions with non-trivial layout maps are
+/// converted to ones with trivial identity layout ones.
+
+// Input :-
+// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
+// (memref<16xf64, #tile>) {
+//   affine.for %arg3 = 0 to 16 {
+//     %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+//     %p = mulf %a, %a : f64
+//     affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+//   }
+//   %c = alloc() : memref<16xf64, #tile>
+//   %d = affine.load %c[0] : memref<16xf64, #tile>
+//   return %A: memref<16xf64, #tile>
+// }
+
+// Output :-
+// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+// -> memref<4x4xf64> {
+//   affine.for %arg3 = 0 to 16 {
+//     %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+//     %3 = mulf %2, %2 : f64
+//     affine.store %3, %arg0[%arg3 floordiv 4, %arg3 mod 4] : memref<4x4xf64>
+//   }
+//   %0 = alloc() : memref<16xf64, #map0>
+//   %1 = affine.load %0[0] : memref<16xf64, #map0>
+//   return %arg0 : memref<4x4xf64>
+// }
+
+struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
+  void runOnOperation() override;
+  void runOnFunction(FuncOp funcOp);
+  bool areMemRefsNormalizable(FuncOp funcOp);
+  void updateFunctionSignature(FuncOp funcOp);
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
+  return std::make_unique<NormalizeMemRefs>();
+}
+
+void NormalizeMemRefs::runOnOperation() {
+  ModuleOp moduleOp = getOperation();
+  // We traverse each function within the module in order to normalize the
+  // memref type arguments.
+  // TODO: Handle external functions.
+  moduleOp.walk([&](FuncOp funcOp) {
+    if (areMemRefsNormalizable(funcOp))
+      runOnFunction(funcOp);
+  });
+}
+
+// Return true if this operation dereferences one or more memref's.
+// TODO: Temporary utility, will be replaced when this is modeled through
+// side-effects/op traits.
+static bool isMemRefDereferencingOp(Operation &op) {
+  return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
+             AffineDmaWaitOp>(op);
+}
+
+// Check whether all the uses of oldMemRef are either dereferencing uses or the
+// op is of type : DeallocOp, CallOp. Only if these constraints are satisfied
+// will the value become a candidate for replacement.
+static bool isMemRefNormalizable(Value::user_range opUsers) {
+  if (llvm::any_of(opUsers, [](Operation *op) {
+        if (isMemRefDereferencingOp(*op))
+          return false;
+        return !isa<DeallocOp, CallOp>(*op);
+      }))
+    return false;
+  return true;
+}
+
+// Check whether all the uses of AllocOps, CallOps and function arguments of a
+// function are either of dereferencing type or of type: DeallocOp, CallOp. Only
+// if these constraints are satisfied will the function become a candidate for
+// normalization.
+bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
+  if (funcOp
+          .walk([&](AllocOp allocOp) -> WalkResult {
+            Value oldMemRef = allocOp.getResult();
+            if (!isMemRefNormalizable(oldMemRef.getUsers()))
+              return WalkResult::interrupt();
+            return WalkResult::advance();
+          })
+          .wasInterrupted())
+    return false;
+
+  if (funcOp
+          .walk([&](CallOp callOp) -> WalkResult {
+            for (unsigned resIndex :
+                 llvm::seq<unsigned>(0, callOp.getNumResults())) {
+              Value oldMemRef = callOp.getResult(resIndex);
+              if (oldMemRef.getType().isa<MemRefType>())
+                if (!isMemRefNormalizable(oldMemRef.getUsers()))
+                  return WalkResult::interrupt();
+            }
+            return WalkResult::advance();
+          })
+          .wasInterrupted())
+    return false;
+
+  for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+    BlockArgument oldMemRef = funcOp.getArgument(argIndex);
+    if (oldMemRef.getType().isa<MemRefType>())
+      if (!isMemRefNormalizable(oldMemRef.getUsers()))
+        return false;
+  }
+
+  return true;
+}
+
+// Fetch the updated argument list and result of the function and update the
+// function signature.
+void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp) {
+  FunctionType functionType = funcOp.getType();
+  SmallVector<Type, 8> argTypes;
+  SmallVector<Type, 4> resultTypes;
+
+  for (const auto &arg : llvm::enumerate(funcOp.getArguments()))
+    argTypes.push_back(arg.value().getType());
+
+  resultTypes = llvm::to_vector<4>(functionType.getResults());
+  // We create a new function type and modify the function signature with this
+  // new type.
+  FunctionType newFuncType = FunctionType::get(/*inputs=*/argTypes,
+                                               /*results=*/resultTypes,
+                                               /*context=*/&getContext());
+
+  // TODO: Handle ReturnOps to update function results the caller site.
+  funcOp.setType(newFuncType);
+}
+
+void NormalizeMemRefs::runOnFunction(FuncOp funcOp) {
+  // Turn memrefs' non-identity layouts maps into ones with identity. Collect
+  // alloc ops first and then process since normalizeMemRef replaces/erases ops
+  // during memref rewriting.
+  SmallVector<AllocOp, 4> allocOps;
+  funcOp.walk([&](AllocOp op) { allocOps.push_back(op); });
+  for (AllocOp allocOp : allocOps)
+    normalizeMemRef(allocOp);
+
+  // We use this OpBuilder to create new memref layout later.
+  OpBuilder b(funcOp);
+
+  // Walk over each argument of a function to perform memref normalization (if
+  // any).
+  for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+    Type argType = funcOp.getArgument(argIndex).getType();
+    MemRefType memrefType = argType.dyn_cast<MemRefType>();
+    // Check whether argument is of MemRef type. Any other argument type can
+    // simply be part of the final function signature.
+    if (!memrefType)
+      continue;
+    // Fetch a new memref type after normalizing the old memref to have an
+    // identity map layout.
+    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
+                                                   /*numSymbolicOperands=*/0);
+    if (newMemRefType == memrefType) {
+      // Either memrefType already had an identity map or the map couldn't be
+      // transformed to an identity map.
+      continue;
+    }
+
+    // Insert a new temporary argument with the new memref type.
+    BlockArgument newMemRef =
+        funcOp.front().insertArgument(argIndex, newMemRefType);
+    BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
+    AffineMap layoutMap = memrefType.getAffineMaps().front();
+    // Replace all uses of the old memref.
+    if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
+                                        /*extraIndices=*/{},
+                                        /*indexRemap=*/layoutMap,
+                                        /*extraOperands=*/{},
+                                        /*symbolOperands=*/{},
+                                        /*domInstFilter=*/nullptr,
+                                        /*postDomInstFilter=*/nullptr,
+                                        /*allowNonDereferencingOps=*/true,
+                                        /*handleDeallocOp=*/true))) {
+      // If it failed (due to escapes for example), bail out. Removing the
+      // temporary argument inserted previously.
+      funcOp.front().eraseArgument(argIndex);
+      continue;
+    }
+
+    // All uses for the argument with old memref type were replaced
+    // successfully. So we remove the old argument now.
+    funcOp.front().eraseArgument(argIndex + 1);
+  }
+
+  updateFunctionSignature(funcOp);
+}
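For completeness, a minimal sketch of scheduling the new module pass programmatically; this is assumed, standard MLIR pass-manager usage of that era rather than anything this patch adds (only createNormalizeMemRefsPass comes from the patch; the include paths and wrapper are assumptions). From the command line, the equivalent is mlir-opt -normalize-memrefs, as the updated RUN line in the test below shows.

#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Sketch: run -normalize-memrefs over a module that is already loaded.
static LogicalResult normalizeModule(ModuleOp module) {
  PassManager pm(module.getContext());
  pm.addPass(createNormalizeMemRefsPass());
  return pm.run(module);
}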
@@ -48,7 +48,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                              ArrayRef<Value> extraIndices,
                                              AffineMap indexRemap,
                                              ArrayRef<Value> extraOperands,
-                                             ArrayRef<Value> symbolOperands) {
+                                             ArrayRef<Value> symbolOperands,
+                                             bool allowNonDereferencingOps) {
   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
@@ -67,11 +68,6 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
          newMemRef.getType().cast<MemRefType>().getElementType());

-  if (!isMemRefDereferencingOp(*op))
-    // Failure: memref used in a non-dereferencing context (potentially
-    // escapes); no replacement in these cases.
-    return failure();
-
   SmallVector<unsigned, 2> usePositions;
   for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
     if (opEntry.value() == oldMemRef)
@@ -91,6 +87,18 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   unsigned memRefOperandPos = usePositions.front();

   OpBuilder builder(op);
+  // The following checks if op is dereferencing memref and performs the access
+  // index rewrites.
+  if (!isMemRefDereferencingOp(*op)) {
+    if (!allowNonDereferencingOps)
+      // Failure: memref used in a non-dereferencing context (potentially
+      // escapes); no replacement in these cases unless allowNonDereferencingOps
+      // is set.
+      return failure();
+    op->setOperand(memRefOperandPos, newMemRef);
+    return success();
+  }
+  // Perform index rewrites for the dereferencing op and then replace the op
   NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
   AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
   unsigned oldMapNumInputs = oldMap.getNumInputs();
@@ -112,7 +120,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
       affineApplyOps.push_back(afOp);
     }
   } else {
-    oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
+    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
   }

   // Construct new indices as a remap of the old ones if a remapping has been
@@ -141,14 +149,14 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
     }
   } else {
     // No remapping specified.
-    remapOutputs.append(remapOperands.begin(), remapOperands.end());
+    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
   }

   SmallVector<Value, 4> newMapOperands;
   newMapOperands.reserve(newMemRefRank);

   // Prepend 'extraIndices' in 'newMapOperands'.
-  for (auto extraIndex : extraIndices) {
+  for (Value extraIndex : extraIndices) {
     assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
            "single result op's expected to generate these indices");
     assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
@@ -167,12 +175,12 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   newMap = simplifyAffineMap(newMap);
   canonicalizeMapAndOperands(&newMap, &newMapOperands);
   // Remove any affine.apply's that became dead as a result of composition.
-  for (auto value : affineApplyOps)
+  for (Value value : affineApplyOps)
     if (value.use_empty())
       value.getDefiningOp()->erase();

-  // Construct the new operation using this memref.
   OperationState state(op->getLoc(), op->getName());
+  // Construct the new operation using this memref.
   state.operands.reserve(op->getNumOperands() + extraIndices.size());
   // Insert the non-memref operands.
   state.operands.append(op->operand_begin(),
@@ -196,11 +204,10 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   // Add attribute for 'newMap', other Attributes do not change.
   auto newMapAttr = AffineMapAttr::get(newMap);
   for (auto namedAttr : op->getAttrs()) {
-    if (namedAttr.first == oldMapAttrPair.first) {
+    if (namedAttr.first == oldMapAttrPair.first)
       state.attributes.push_back({namedAttr.first, newMapAttr});
-    } else {
+    else
       state.attributes.push_back(namedAttr);
-    }
   }

   // Create the new operation.
@@ -211,13 +218,12 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   return success();
 }

-LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
-                                             ArrayRef<Value> extraIndices,
-                                             AffineMap indexRemap,
-                                             ArrayRef<Value> extraOperands,
-                                             ArrayRef<Value> symbolOperands,
-                                             Operation *domInstFilter,
-                                             Operation *postDomInstFilter) {
+LogicalResult mlir::replaceAllMemRefUsesWith(
+    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
+    AffineMap indexRemap, ArrayRef<Value> extraOperands,
+    ArrayRef<Value> symbolOperands, Operation *domInstFilter,
+    Operation *postDomInstFilter, bool allowNonDereferencingOps,
+    bool replaceInDeallocOp) {
   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
@@ -261,16 +267,21 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,

     // Skip dealloc's - no replacement is necessary, and a memref replacement
     // at other uses doesn't hurt these dealloc's.
-    if (isa<DeallocOp>(op))
+    if (isa<DeallocOp>(op) && !replaceInDeallocOp)
       continue;

     // Check if the memref was used in a non-dereferencing context. It is fine
     // for the memref to be used in a non-dereferencing way outside of the
     // region where this replacement is happening.
-    if (!isMemRefDereferencingOp(*op))
-      // Failure: memref used in a non-dereferencing op (potentially escapes);
-      // no replacement in these cases.
-      return failure();
+    if (!isMemRefDereferencingOp(*op)) {
+      // Currently we support the following non-dereferencing types to be a
+      // candidate for replacement: Dealloc and CallOp.
+      // TODO: Add support for other kinds of ops.
+      if (!allowNonDereferencingOps)
+        return failure();
+      if (!(isa<DeallocOp, CallOp>(*op)))
+        return failure();
+    }

     // We'll first collect and then replace --- since replacement erases the op
     // that has the use, and that op could be postDomFilter or domFilter itself!
@@ -278,9 +289,9 @@ LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
   }

   for (auto *op : opsToReplace) {
-    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
-                                        indexRemap, extraOperands,
-                                        symbolOperands)))
+    if (failed(replaceAllMemRefUsesWith(
+            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
+            symbolOperands, allowNonDereferencingOps)))
       llvm_unreachable("memref replacement guaranteed to succeed here");
   }

@@ -385,76 +396,32 @@ void mlir::createAffineComputationSlice(
 // TODO: Currently works for static memrefs with a single layout map.
 LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   MemRefType memrefType = allocOp.getType();
-  unsigned rank = memrefType.getRank();
-  if (rank == 0)
-    return success();
-
-  auto layoutMaps = memrefType.getAffineMaps();
   OpBuilder b(allocOp);
-  if (layoutMaps.size() != 1)
-    return failure();
-
-  AffineMap layoutMap = layoutMaps.front();
-
-  // Nothing to do for identity layout maps.
-  if (layoutMap == b.getMultiDimIdentityMap(rank))
-    return success();
-
-  // We don't do any checks for one-to-one'ness; we assume that it is
-  // one-to-one.
-
-  // TODO: Only for static memref's for now.
-  if (memrefType.getNumDynamicDims() > 0)
-    return failure();
-
-  // We have a single map that is not an identity map. Create a new memref with
-  // the right shape and an identity layout map.
-  auto shape = memrefType.getShape();
-  FlatAffineConstraints fac(rank, allocOp.getNumSymbolicOperands());
-  for (unsigned d = 0; d < rank; ++d) {
-    fac.addConstantLowerBound(d, 0);
-    fac.addConstantUpperBound(d, shape[d] - 1);
-  }
-
-  // We compose this map with the original index (logical) space to derive the
-  // upper bounds for the new index space.
-  unsigned newRank = layoutMap.getNumResults();
-  if (failed(fac.composeMatchingMap(layoutMap)))
-    // TODO: semi-affine maps.
-    return failure();
-
-  // Project out the old data dimensions.
-  fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
-  SmallVector<int64_t, 4> newShape(newRank);
-  for (unsigned d = 0; d < newRank; ++d) {
-    // The lower bound for the shape is always zero.
-    auto ubConst = fac.getConstantUpperBound(d);
-    // For a static memref and an affine map with no symbols, this is always
-    // bounded.
-    assert(ubConst.hasValue() && "should always have an upper bound");
-    if (ubConst.getValue() < 0)
-      // This is due to an invalid map that maps to a negative space.
-      return failure();
-    newShape[d] = ubConst.getValue() + 1;
-  }
-
-  auto oldMemRef = allocOp.getResult();
-  SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());

+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
   MemRefType newMemRefType =
-      MemRefType::Builder(memrefType)
-          .setShape(newShape)
-          .setAffineMaps(b.getMultiDimIdentityMap(newRank));
+      normalizeMemRefType(memrefType, b, allocOp.getNumSymbolicOperands());
+  if (newMemRefType == memrefType)
+    // Either memrefType already had an identity map or the map couldn't be
+    // transformed to an identity map.
+    return failure();

-  auto newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType, llvm::None,
-                                    allocOp.alignmentAttr());
+  Value oldMemRef = allocOp.getResult();

+  SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
+  AllocOp newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType,
+                                       llvm::None, allocOp.alignmentAttr());
+  AffineMap layoutMap = memrefType.getAffineMaps().front();
   // Replace all uses of the old memref.
   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
                                       /*extraIndices=*/{},
                                       /*indexRemap=*/layoutMap,
                                       /*extraOperands=*/{},
-                                      /*symbolOperands=*/symbolOperands))) {
+                                      /*symbolOperands=*/symbolOperands,
+                                      /*domInstFilter=*/nullptr,
+                                      /*postDomInstFilter=*/nullptr,
+                                      /*allowDereferencingOps=*/true))) {
     // If it failed (due to escapes for example), bail out.
     newAlloc.erase();
     return failure();
@@ -467,3 +434,64 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) {
   allocOp.erase();
   return success();
 }
+
+MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
+                                     unsigned numSymbolicOperands) {
+  unsigned rank = memrefType.getRank();
+  if (rank == 0)
+    return memrefType;
+
+  ArrayRef<AffineMap> layoutMaps = memrefType.getAffineMaps();
+  if (layoutMaps.empty() ||
+      layoutMaps.front() == b.getMultiDimIdentityMap(rank)) {
+    // Either no maps is associated with this memref or this memref has
+    // a trivial (identity) map.
+    return memrefType;
+  }
+
+  // We don't do any checks for one-to-one'ness; we assume that it is
+  // one-to-one.
+
+  // TODO: Only for static memref's for now.
+  if (memrefType.getNumDynamicDims() > 0)
+    return memrefType;
+
+  // We have a single map that is not an identity map. Create a new memref
+  // with the right shape and an identity layout map.
+  ArrayRef<int64_t> shape = memrefType.getShape();
+  // FlatAffineConstraint may later on use symbolicOperands.
+  FlatAffineConstraints fac(rank, numSymbolicOperands);
+  for (unsigned d = 0; d < rank; ++d) {
+    fac.addConstantLowerBound(d, 0);
+    fac.addConstantUpperBound(d, shape[d] - 1);
+  }
+  // We compose this map with the original index (logical) space to derive
+  // the upper bounds for the new index space.
+  AffineMap layoutMap = layoutMaps.front();
+  unsigned newRank = layoutMap.getNumResults();
+  if (failed(fac.composeMatchingMap(layoutMap)))
+    return memrefType;
+  // TODO: Handle semi-affine maps.
+  // Project out the old data dimensions.
+  fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
+  SmallVector<int64_t, 4> newShape(newRank);
+  for (unsigned d = 0; d < newRank; ++d) {
+    // The lower bound for the shape is always zero.
+    auto ubConst = fac.getConstantUpperBound(d);
+    // For a static memref and an affine map with no symbols, this is
+    // always bounded.
+    assert(ubConst.hasValue() && "should always have an upper bound");
+    if (ubConst.getValue() < 0)
+      // This is due to an invalid map that maps to a negative space.
+      return memrefType;
+    newShape[d] = ubConst.getValue() + 1;
+  }
+
+  // Create the new memref type after trivializing the old layout map.
+  MemRefType newMemRefType =
+      MemRefType::Builder(memrefType)
+          .setShape(newShape)
+          .setAffineMaps(b.getMultiDimIdentityMap(newRank));
+
+  return newMemRefType;
+}
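As a concrete check of the shape computation above: for memref<16xf64> with layout (d0) -> (d0 floordiv 4, d0 mod 4), the constraints 0 <= d0 <= 15 composed with the map give upper bounds of 3 on both result dimensions, hence the normalized type memref<4x4xf64>. A small standalone sketch (illustration only, not part of the patch) that brute-forces the same bounds:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // Image of (d0) -> (d0 floordiv 4, d0 mod 4) over 0 <= d0 <= 15.
  int64_t ub0 = 0, ub1 = 0;
  for (int64_t d0 = 0; d0 < 16; ++d0) {
    ub0 = std::max(ub0, d0 / 4);
    ub1 = std::max(ub1, d0 % 4);
  }
  // Shape entries are upper bound + 1, matching normalizeMemRefType's
  // newShape[d] = ubConst.getValue() + 1.
  assert(ub0 + 1 == 4 && ub1 + 1 == 4);
  return 0;
}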
@@ -1,4 +1,7 @@
-// RUN: mlir-opt -allow-unregistered-dialect -simplify-affine-structures %s | FileCheck %s
+// RUN: mlir-opt -normalize-memrefs -allow-unregistered-dialect %s | FileCheck %s

+// This file tests whether the memref type having non-trivial map layouts
+// are normalized to trivial (identity) layouts.
+
 // CHECK-LABEL: func @permute()
 func @permute() {
@@ -150,3 +153,61 @@ func @alignment() {
   // CHECK-NEXT: alloc() {alignment = 32 : i64} : memref<256x64x128xf32>
   return
 }
+
+#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+
+// Following test cases check the inter-procedural memref normalization.
+
+// Test case 1: Check normalization for multiple memrefs in a function argument list.
+// CHECK-LABEL: func @multiple_argument_type
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>, %[[D:arg[0-9]+]]: memref<24xf64>) -> f64
+func @multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>, %D: memref<24xf64>) -> f64 {
+  %a = affine.load %A[0] : memref<16xf64, #tile>
+  %p = mulf %a, %a : f64
+  affine.store %p, %A[10] : memref<16xf64, #tile>
+  call @single_argument_type(%C): (memref<8xf64, #tile>) -> ()
+  return %B : f64
+}
+
+// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
+// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
+// CHECK: affine.store %[[p]], %[[A]][2, 2] : memref<4x4xf64>
+// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
+// CHECK: return %[[B]] : f64
+
+// Test case 2: Check normalization for single memref argument in a function.
+// CHECK-LABEL: func @single_argument_type
+// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>)
+func @single_argument_type(%C : memref<8xf64, #tile>) {
+  %a = alloc(): memref<8xf64, #tile>
+  %b = alloc(): memref<16xf64, #tile>
+  %d = constant 23.0 : f64
+  %e = alloc(): memref<24xf64>
+  call @single_argument_type(%a): (memref<8xf64, #tile>) -> ()
+  call @single_argument_type(%C): (memref<8xf64, #tile>) -> ()
+  call @multiple_argument_type(%b, %d, %a, %e): (memref<16xf64, #tile>, f64, memref<8xf64, #tile>, memref<24xf64>) -> f64
+  return
+}
+
+// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
+// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
+// CHECK: %cst = constant 2.300000e+01 : f64
+// CHECK: %[[e:[0-9]+]] = alloc() : memref<24xf64>
+// CHECK: call @single_argument_type(%[[a]]) : (memref<2x4xf64>) -> ()
+// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
+// CHECK: call @multiple_argument_type(%[[b]], %cst, %[[a]], %[[e]]) : (memref<4x4xf64>, f64, memref<2x4xf64>, memref<24xf64>) -> f64
+
+// Test case 3: Check function returning any other type except memref.
+// CHECK-LABEL: func @non_memref_ret
+// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> i1
+func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 {
+  %d = constant 1 : i1
+  return %d : i1
+}
+
+// Test case 4: No normalization should take place because the function is returning the memref.
+// CHECK-LABEL: func @memref_used_in_return
+// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>) -> memref<8xf64, #map{{[0-9]+}}>
+func @memref_used_in_return(%A: memref<8xf64, #tile>) -> (memref<8xf64, #tile>) {
+  return %A : memref<8xf64, #tile>
+}