forked from OSchip/llvm-project
523 lines
23 KiB
C++
523 lines
23 KiB
C++
//===- NormalizeMemRefs.cpp -----------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements an interprocedural pass to normalize memrefs to have
|
|
// identity layout maps.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "PassDetail.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
#include "mlir/Transforms/Utils.h"
|
|
#include "llvm/ADT/SmallSet.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "normalize-memrefs"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
/// All memrefs passed across functions with non-trivial layout maps are
|
|
/// converted to ones with trivial identity layout ones.
|
|
/// If all the memref types/uses in a function are normalizable, we treat
|
|
/// such functions as normalizable. Also, if a normalizable function is known
|
|
/// to call a non-normalizable function, we treat that function as
|
|
/// non-normalizable as well. We assume external functions to be normalizable.
|
|
struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
|
|
void runOnOperation() override;
|
|
void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
|
|
bool areMemRefsNormalizable(FuncOp funcOp);
|
|
void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
|
|
void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
|
|
DenseSet<FuncOp> &normalizableFuncs);
|
|
Operation *createOpResultsNormalized(FuncOp funcOp, Operation *oldOp);
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Creates a fresh instance of the memref-normalization module pass.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
  // The pass has no construction-time configuration; a default instance,
  // handed out through its OperationPass<ModuleOp> base, is all that's needed.
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<NormalizeMemRefs>();
  return pass;
}
|
|
|
|
void NormalizeMemRefs::runOnOperation() {
|
|
LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
|
|
ModuleOp moduleOp = getOperation();
|
|
// We maintain all normalizable FuncOps in a DenseSet. It is initialized
|
|
// with all the functions within a module and then functions which are not
|
|
// normalizable are removed from this set.
|
|
// TODO: Change this to work on FuncLikeOp once there is an operation
|
|
// interface for it.
|
|
DenseSet<FuncOp> normalizableFuncs;
|
|
// Initialize `normalizableFuncs` with all the functions within a module.
|
|
moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
|
|
|
|
// Traverse through all the functions applying a filter which determines
|
|
// whether that function is normalizable or not. All callers/callees of
|
|
// a non-normalizable function will also become non-normalizable even if
|
|
// they aren't passing any or specific non-normalizable memrefs. So,
|
|
// functions which calls or get called by a non-normalizable becomes non-
|
|
// normalizable functions themselves.
|
|
moduleOp.walk([&](FuncOp funcOp) {
|
|
if (normalizableFuncs.contains(funcOp)) {
|
|
if (!areMemRefsNormalizable(funcOp)) {
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< "@" << funcOp.getName()
|
|
<< " contains ops that cannot normalize MemRefs\n");
|
|
// Since this function is not normalizable, we set all the caller
|
|
// functions and the callees of this function as not normalizable.
|
|
// TODO: Drop this conservative assumption in the future.
|
|
setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
|
|
normalizableFuncs);
|
|
}
|
|
}
|
|
});
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
|
|
<< " functions\n");
|
|
// Those functions which can be normalized are subjected to normalization.
|
|
for (FuncOp &funcOp : normalizableFuncs)
|
|
normalizeFuncOpMemRefs(funcOp, moduleOp);
|
|
}
|
|
|
|
/// Check whether all the uses of oldMemRef are either dereferencing uses or the
|
|
/// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
|
|
/// are satisfied will the value become a candidate for replacement.
|
|
/// TODO: Extend this for DimOps.
|
|
static bool isMemRefNormalizable(Value::user_range opUsers) {
|
|
if (llvm::any_of(opUsers, [](Operation *op) {
|
|
if (op->hasTrait<OpTrait::MemRefsNormalizable>())
|
|
return false;
|
|
return true;
|
|
}))
|
|
return false;
|
|
return true;
|
|
}
|
|
|
|
/// Set all the calling functions and the callees of the function as not
|
|
/// normalizable.
|
|
void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
|
|
FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) {
|
|
if (!normalizableFuncs.contains(funcOp))
|
|
return;
|
|
|
|
LLVM_DEBUG(
|
|
llvm::dbgs() << "@" << funcOp.getName()
|
|
<< " calls or is called by non-normalizable function\n");
|
|
normalizableFuncs.erase(funcOp);
|
|
// Caller of the function.
|
|
Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
|
|
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
|
|
// TODO: Extend this for ops that are FunctionLike. This would require
|
|
// creating an OpInterface for FunctionLike ops.
|
|
FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>();
|
|
for (FuncOp &funcOp : normalizableFuncs) {
|
|
if (parentFuncOp == funcOp) {
|
|
setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
|
|
normalizableFuncs);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Functions called by this function.
|
|
funcOp.walk([&](CallOp callOp) {
|
|
StringAttr callee = callOp.getCalleeAttr().getAttr();
|
|
for (FuncOp &funcOp : normalizableFuncs) {
|
|
// We compare FuncOp and callee's name.
|
|
if (callee == funcOp.getNameAttr()) {
|
|
setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
|
|
normalizableFuncs);
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
/// Check whether all the uses of AllocOps, CallOps and function arguments of a
|
|
/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
|
|
/// or ReturnOp. Only if these constraints are satisfied will the function
|
|
/// become a candidate for normalization. We follow a conservative approach here
|
|
/// wherein even if the non-normalizable memref is not a part of the function's
|
|
/// argument or return type, we still label the entire function as
|
|
/// non-normalizable. We assume external functions to be normalizable.
|
|
bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
|
|
// We assume external functions to be normalizable.
|
|
if (funcOp.isExternal())
|
|
return true;
|
|
|
|
if (funcOp
|
|
.walk([&](memref::AllocOp allocOp) -> WalkResult {
|
|
Value oldMemRef = allocOp.getResult();
|
|
if (!isMemRefNormalizable(oldMemRef.getUsers()))
|
|
return WalkResult::interrupt();
|
|
return WalkResult::advance();
|
|
})
|
|
.wasInterrupted())
|
|
return false;
|
|
|
|
if (funcOp
|
|
.walk([&](CallOp callOp) -> WalkResult {
|
|
for (unsigned resIndex :
|
|
llvm::seq<unsigned>(0, callOp.getNumResults())) {
|
|
Value oldMemRef = callOp.getResult(resIndex);
|
|
if (oldMemRef.getType().isa<MemRefType>())
|
|
if (!isMemRefNormalizable(oldMemRef.getUsers()))
|
|
return WalkResult::interrupt();
|
|
}
|
|
return WalkResult::advance();
|
|
})
|
|
.wasInterrupted())
|
|
return false;
|
|
|
|
for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
|
|
BlockArgument oldMemRef = funcOp.getArgument(argIndex);
|
|
if (oldMemRef.getType().isa<MemRefType>())
|
|
if (!isMemRefNormalizable(oldMemRef.getUsers()))
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Fetch the updated argument list and result of the function and update the
|
|
/// function signature. This updates the function's return type at the caller
|
|
/// site and in case the return type is a normalized memref then it updates
|
|
/// the calling function's signature.
|
|
/// TODO: An update to the calling function signature is required only if the
|
|
/// returned value is in turn used in ReturnOp of the calling function.
|
|
void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
|
|
ModuleOp moduleOp) {
|
|
FunctionType functionType = funcOp.getType();
|
|
SmallVector<Type, 4> resultTypes;
|
|
FunctionType newFuncType;
|
|
resultTypes = llvm::to_vector<4>(functionType.getResults());
|
|
|
|
// External function's signature was already updated in
|
|
// 'normalizeFuncOpMemRefs()'.
|
|
if (!funcOp.isExternal()) {
|
|
SmallVector<Type, 8> argTypes;
|
|
for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
|
|
argTypes.push_back(argEn.value().getType());
|
|
|
|
// Traverse ReturnOps to check if an update to the return type in the
|
|
// function signature is required.
|
|
funcOp.walk([&](ReturnOp returnOp) {
|
|
for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
|
|
Type opType = operandEn.value().getType();
|
|
MemRefType memrefType = opType.dyn_cast<MemRefType>();
|
|
// If type is not memref or if the memref type is same as that in
|
|
// function's return signature then no update is required.
|
|
if (!memrefType || memrefType == resultTypes[operandEn.index()])
|
|
continue;
|
|
// Update function's return type signature.
|
|
// Return type gets normalized either as a result of function argument
|
|
// normalization, AllocOp normalization or an update made at CallOp.
|
|
// There can be many call flows inside a function and an update to a
|
|
// specific ReturnOp has not yet been made. So we check that the result
|
|
// memref type is normalized.
|
|
// TODO: When selective normalization is implemented, handle multiple
|
|
// results case where some are normalized, some aren't.
|
|
if (memrefType.getLayout().isIdentity())
|
|
resultTypes[operandEn.index()] = memrefType;
|
|
}
|
|
});
|
|
|
|
// We create a new function type and modify the function signature with this
|
|
// new type.
|
|
newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
|
|
/*results=*/resultTypes);
|
|
}
|
|
|
|
// Since we update the function signature, it might affect the result types at
|
|
// the caller site. Since this result might even be used by the caller
|
|
// function in ReturnOps, the caller function's signature will also change.
|
|
// Hence we record the caller function in 'funcOpsToUpdate' to update their
|
|
// signature as well.
|
|
llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate;
|
|
// We iterate over all symbolic uses of the function and update the return
|
|
// type at the caller site.
|
|
Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
|
|
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
|
|
Operation *userOp = symbolUse.getUser();
|
|
OpBuilder builder(userOp);
|
|
// When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
|
|
// that the non-CallOp has no memrefs to be replaced.
|
|
// TODO: Handle cases where a non-CallOp symbol use of a function deals with
|
|
// memrefs.
|
|
auto callOp = dyn_cast<CallOp>(userOp);
|
|
if (!callOp)
|
|
continue;
|
|
Operation *newCallOp =
|
|
builder.create<CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
|
|
resultTypes, userOp->getOperands());
|
|
bool replacingMemRefUsesFailed = false;
|
|
bool returnTypeChanged = false;
|
|
for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
|
|
OpResult oldResult = userOp->getResult(resIndex);
|
|
OpResult newResult = newCallOp->getResult(resIndex);
|
|
// This condition ensures that if the result is not of type memref or if
|
|
// the resulting memref was already having a trivial map layout then we
|
|
// need not perform any use replacement here.
|
|
if (oldResult.getType() == newResult.getType())
|
|
continue;
|
|
AffineMap layoutMap =
|
|
oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
|
|
if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
|
|
/*extraIndices=*/{},
|
|
/*indexRemap=*/layoutMap,
|
|
/*extraOperands=*/{},
|
|
/*symbolOperands=*/{},
|
|
/*domInstFilter=*/nullptr,
|
|
/*postDomInstFilter=*/nullptr,
|
|
/*allowDereferencingOps=*/true,
|
|
/*replaceInDeallocOp=*/true))) {
|
|
// If it failed (due to escapes for example), bail out.
|
|
// It should never hit this part of the code because it is called by
|
|
// only those functions which are normalizable.
|
|
newCallOp->erase();
|
|
replacingMemRefUsesFailed = true;
|
|
break;
|
|
}
|
|
returnTypeChanged = true;
|
|
}
|
|
if (replacingMemRefUsesFailed)
|
|
continue;
|
|
// Replace all uses for other non-memref result types.
|
|
userOp->replaceAllUsesWith(newCallOp);
|
|
userOp->erase();
|
|
if (returnTypeChanged) {
|
|
// Since the return type changed it might lead to a change in function's
|
|
// signature.
|
|
// TODO: If funcOp doesn't return any memref type then no need to update
|
|
// signature.
|
|
// TODO: Further optimization - Check if the memref is indeed part of
|
|
// ReturnOp at the parentFuncOp and only then updation of signature is
|
|
// required.
|
|
// TODO: Extend this for ops that are FunctionLike. This would require
|
|
// creating an OpInterface for FunctionLike ops.
|
|
FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>();
|
|
funcOpsToUpdate.insert(parentFuncOp);
|
|
}
|
|
}
|
|
// Because external function's signature is already updated in
|
|
// 'normalizeFuncOpMemRefs()', we don't need to update it here again.
|
|
if (!funcOp.isExternal())
|
|
funcOp.setType(newFuncType);
|
|
|
|
// Updating the signature type of those functions which call the current
|
|
// function. Only if the return type of the current function has a normalized
|
|
// memref will the caller function become a candidate for signature update.
|
|
for (FuncOp parentFuncOp : funcOpsToUpdate)
|
|
updateFunctionSignature(parentFuncOp, moduleOp);
|
|
}
|
|
|
|
/// Normalizes the memrefs within a function which includes those arising as a
|
|
/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
|
|
/// is used to help update function's signature after normalization.
|
|
void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
|
|
ModuleOp moduleOp) {
|
|
// Turn memrefs' non-identity layouts maps into ones with identity. Collect
|
|
// alloc ops first and then process since normalizeMemRef replaces/erases ops
|
|
// during memref rewriting.
|
|
SmallVector<memref::AllocOp, 4> allocOps;
|
|
funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
|
|
for (memref::AllocOp allocOp : allocOps)
|
|
(void)normalizeMemRef(&allocOp);
|
|
|
|
// We use this OpBuilder to create new memref layout later.
|
|
OpBuilder b(funcOp);
|
|
|
|
FunctionType functionType = funcOp.getType();
|
|
SmallVector<Type, 8> inputTypes;
|
|
// Walk over each argument of a function to perform memref normalization (if
|
|
for (unsigned argIndex :
|
|
llvm::seq<unsigned>(0, functionType.getNumInputs())) {
|
|
Type argType = functionType.getInput(argIndex);
|
|
MemRefType memrefType = argType.dyn_cast<MemRefType>();
|
|
// Check whether argument is of MemRef type. Any other argument type can
|
|
// simply be part of the final function signature.
|
|
if (!memrefType) {
|
|
inputTypes.push_back(argType);
|
|
continue;
|
|
}
|
|
// Fetch a new memref type after normalizing the old memref to have an
|
|
// identity map layout.
|
|
MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
|
|
/*numSymbolicOperands=*/0);
|
|
if (newMemRefType == memrefType || funcOp.isExternal()) {
|
|
// Either memrefType already had an identity map or the map couldn't be
|
|
// transformed to an identity map.
|
|
inputTypes.push_back(newMemRefType);
|
|
continue;
|
|
}
|
|
|
|
// Insert a new temporary argument with the new memref type.
|
|
BlockArgument newMemRef =
|
|
funcOp.front().insertArgument(argIndex, newMemRefType);
|
|
BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
|
|
AffineMap layoutMap = memrefType.getLayout().getAffineMap();
|
|
// Replace all uses of the old memref.
|
|
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
|
|
/*extraIndices=*/{},
|
|
/*indexRemap=*/layoutMap,
|
|
/*extraOperands=*/{},
|
|
/*symbolOperands=*/{},
|
|
/*domInstFilter=*/nullptr,
|
|
/*postDomInstFilter=*/nullptr,
|
|
/*allowNonDereferencingOps=*/true,
|
|
/*replaceInDeallocOp=*/true))) {
|
|
// If it failed (due to escapes for example), bail out. Removing the
|
|
// temporary argument inserted previously.
|
|
funcOp.front().eraseArgument(argIndex);
|
|
continue;
|
|
}
|
|
|
|
// All uses for the argument with old memref type were replaced
|
|
// successfully. So we remove the old argument now.
|
|
funcOp.front().eraseArgument(argIndex + 1);
|
|
}
|
|
|
|
// Walk over normalizable operations to normalize memrefs of the operation
|
|
// results. When `op` has memrefs with affine map in the operation results,
|
|
// new operation containin normalized memrefs is created. Then, the memrefs
|
|
// are replaced. `CallOp` is skipped here because it is handled in
|
|
// `updateFunctionSignature()`.
|
|
funcOp.walk([&](Operation *op) {
|
|
if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
|
|
op->getNumResults() > 0 && !isa<CallOp>(op) && !funcOp.isExternal()) {
|
|
// Create newOp containing normalized memref in the operation result.
|
|
Operation *newOp = createOpResultsNormalized(funcOp, op);
|
|
// When all of the operation results have no memrefs or memrefs without
|
|
// affine map, `newOp` is the same with `op` and following process is
|
|
// skipped.
|
|
if (op != newOp) {
|
|
bool replacingMemRefUsesFailed = false;
|
|
for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
|
|
// Replace all uses of the old memrefs.
|
|
Value oldMemRef = op->getResult(resIndex);
|
|
Value newMemRef = newOp->getResult(resIndex);
|
|
MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
|
|
// Check whether the operation result is MemRef type.
|
|
if (!oldMemRefType)
|
|
continue;
|
|
MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
|
|
if (oldMemRefType == newMemRefType)
|
|
continue;
|
|
// TODO: Assume single layout map. Multiple maps not supported.
|
|
AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
|
|
if (failed(replaceAllMemRefUsesWith(oldMemRef,
|
|
/*newMemRef=*/newMemRef,
|
|
/*extraIndices=*/{},
|
|
/*indexRemap=*/layoutMap,
|
|
/*extraOperands=*/{},
|
|
/*symbolOperands=*/{},
|
|
/*domInstFilter=*/nullptr,
|
|
/*postDomInstFilter=*/nullptr,
|
|
/*allowDereferencingOps=*/true,
|
|
/*replaceInDeallocOp=*/true))) {
|
|
newOp->erase();
|
|
replacingMemRefUsesFailed = true;
|
|
continue;
|
|
}
|
|
}
|
|
if (!replacingMemRefUsesFailed) {
|
|
// Replace other ops with new op and delete the old op when the
|
|
// replacement succeeded.
|
|
op->replaceAllUsesWith(newOp);
|
|
op->erase();
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
// In a normal function, memrefs in the return type signature gets normalized
|
|
// as a result of normalization of functions arguments, AllocOps or CallOps'
|
|
// result types. Since an external function doesn't have a body, memrefs in
|
|
// the return type signature can only get normalized by iterating over the
|
|
// individual return types.
|
|
if (funcOp.isExternal()) {
|
|
SmallVector<Type, 4> resultTypes;
|
|
for (unsigned resIndex :
|
|
llvm::seq<unsigned>(0, functionType.getNumResults())) {
|
|
Type resType = functionType.getResult(resIndex);
|
|
MemRefType memrefType = resType.dyn_cast<MemRefType>();
|
|
// Check whether result is of MemRef type. Any other argument type can
|
|
// simply be part of the final function signature.
|
|
if (!memrefType) {
|
|
resultTypes.push_back(resType);
|
|
continue;
|
|
}
|
|
// Computing a new memref type after normalizing the old memref to have an
|
|
// identity map layout.
|
|
MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
|
|
/*numSymbolicOperands=*/0);
|
|
resultTypes.push_back(newMemRefType);
|
|
}
|
|
|
|
FunctionType newFuncType =
|
|
FunctionType::get(&getContext(), /*inputs=*/inputTypes,
|
|
/*results=*/resultTypes);
|
|
// Setting the new function signature for this external function.
|
|
funcOp.setType(newFuncType);
|
|
}
|
|
updateFunctionSignature(funcOp, moduleOp);
|
|
}
|
|
|
|
/// Create an operation containing normalized memrefs in the operation results.
|
|
/// When the results of `oldOp` have memrefs with affine map, the memrefs are
|
|
/// normalized, and new operation containing them in the operation results is
|
|
/// returned. If all of the results of `oldOp` have no memrefs or memrefs
|
|
/// without affine map, `oldOp` is returned without modification.
|
|
Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
|
|
Operation *oldOp) {
|
|
// Prepare OperationState to create newOp containing normalized memref in
|
|
// the operation results.
|
|
OperationState result(oldOp->getLoc(), oldOp->getName());
|
|
result.addOperands(oldOp->getOperands());
|
|
result.addAttributes(oldOp->getAttrs());
|
|
// Add normalized MemRefType to the OperationState.
|
|
SmallVector<Type, 4> resultTypes;
|
|
OpBuilder b(funcOp);
|
|
bool resultTypeNormalized = false;
|
|
for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
|
|
auto resultType = oldOp->getResult(resIndex).getType();
|
|
MemRefType memrefType = resultType.dyn_cast<MemRefType>();
|
|
// Check whether the operation result is MemRef type.
|
|
if (!memrefType) {
|
|
resultTypes.push_back(resultType);
|
|
continue;
|
|
}
|
|
// Fetch a new memref type after normalizing the old memref.
|
|
MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
|
|
/*numSymbolicOperands=*/0);
|
|
if (newMemRefType == memrefType) {
|
|
// Either memrefType already had an identity map or the map couldn't
|
|
// be transformed to an identity map.
|
|
resultTypes.push_back(memrefType);
|
|
continue;
|
|
}
|
|
resultTypes.push_back(newMemRefType);
|
|
resultTypeNormalized = true;
|
|
}
|
|
result.addTypes(resultTypes);
|
|
// When all of the results of `oldOp` have no memrefs or memrefs without
|
|
// affine map, `oldOp` is returned without modification.
|
|
if (resultTypeNormalized) {
|
|
OpBuilder bb(oldOp);
|
|
for (auto &oldRegion : oldOp->getRegions()) {
|
|
Region *newRegion = result.addRegion();
|
|
newRegion->takeBody(oldRegion);
|
|
}
|
|
return bb.createOperation(result);
|
|
} else
|
|
return oldOp;
|
|
}
|