[flang] Avoid losing type parameter information

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D127738

Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
Valentin Clement 2022-06-14 15:57:24 +02:00
parent 0c66a4ce0a
commit 3260d42398
No known key found for this signature in database
GPG Key ID: 086D54783C928776
1 changed files with 48 additions and 40 deletions

View File

@ -667,11 +667,11 @@ amendingAccess(llvm::ArrayRef<mlir::Operation *> mentions) {
return {};
}
// Are either of types of conflicts present?
inline bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
llvm::ArrayRef<mlir::Operation *> accesses,
// Are any conflicts present? The conflicts detected here are described above.
static bool conflictDetected(llvm::ArrayRef<mlir::Operation *> reach,
llvm::ArrayRef<mlir::Operation *> mentions,
ArrayMergeStoreOp st) {
return conflictOnLoad(reach, st) || conflictOnMerge(accesses);
return conflictOnLoad(reach, st) || conflictOnMerge(mentions);
}
// Assume that any call to a function that uses host-associations will be
@ -871,21 +871,33 @@ static mlir::Type toRefType(mlir::Type ty) {
return fir::ReferenceType::get(ty);
}
static mlir::Value
genCoorOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type eleTy,
mlir::Type resTy, mlir::Value alloc, mlir::Value shape,
mlir::Value slice, mlir::ValueRange indices,
mlir::ValueRange typeparams, bool skipOrig = false) {
static llvm::SmallVector<mlir::Value>
getTypeParamsIfRawData(mlir::Location loc, FirOpBuilder &builder,
ArrayLoadOp arrLoad, mlir::Type ty) {
if (ty.isa<BoxType>())
return {};
return fir::factory::getTypeParams(loc, builder, arrLoad);
}
static mlir::Value genCoorOp(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Type eleTy,
mlir::Type resTy, mlir::Value alloc,
mlir::Value shape, mlir::Value slice,
mlir::ValueRange indices, ArrayLoadOp load,
bool skipOrig = false) {
llvm::SmallVector<mlir::Value> originated;
if (skipOrig)
originated.assign(indices.begin(), indices.end());
else
originated = fir::factory::originateIndices(loc, rewriter, alloc.getType(),
shape, indices);
auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(alloc.getType());
assert(seqTy && seqTy.isa<fir::SequenceType>());
const auto dimension = seqTy.cast<fir::SequenceType>().getDimension();
mlir::Value result = rewriter.create<fir::ArrayCoorOp>(
originated = factory::originateIndices(loc, rewriter, alloc.getType(),
shape, indices);
auto seqTy = dyn_cast_ptrOrBoxEleTy(alloc.getType());
assert(seqTy && seqTy.isa<SequenceType>());
const auto dimension = seqTy.cast<SequenceType>().getDimension();
auto module = load->getParentOfType<mlir::ModuleOp>();
FirOpBuilder builder(rewriter, getKindMapping(module));
auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType());
mlir::Value result = rewriter.create<ArrayCoorOp>(
loc, eleTy, alloc, shape, slice,
llvm::ArrayRef<mlir::Value>{originated}.take_front(dimension),
typeparams);
@ -946,20 +958,19 @@ void genArrayCopy(mlir::Location loc, mlir::PatternRewriter &rewriter,
}
// Reverse the indices so they are in column-major order.
std::reverse(indices.begin(), indices.end());
auto typeparams = arrLoad.getTypeparams();
auto module = arrLoad->getParentOfType<mlir::ModuleOp>();
FirOpBuilder builder(rewriter, getKindMapping(module));
auto fromAddr = rewriter.create<ArrayCoorOp>(
loc, getEleTy(src.getType()), src, shapeOp,
CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
factory::originateIndices(loc, rewriter, src.getType(), shapeOp, indices),
typeparams);
getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
auto toAddr = rewriter.create<ArrayCoorOp>(
loc, getEleTy(dst.getType()), dst, shapeOp,
!CopyIn && copyUsingSlice ? sliceOp : mlir::Value{},
factory::originateIndices(loc, rewriter, dst.getType(), shapeOp, indices),
typeparams);
getTypeParamsIfRawData(loc, builder, arrLoad, src.getType()));
auto eleTy = unwrapSequenceType(unwrapPassByRefType(dst.getType()));
auto module = toAddr->getParentOfType<mlir::ModuleOp>();
FirOpBuilder builder(rewriter, getKindMapping(module));
// Copy from (to) object to (from) temp copy of same object.
if (auto charTy = eleTy.dyn_cast<CharacterType>()) {
auto len = getCharacterLen(loc, builder, arrLoad, charTy);
@ -1093,11 +1104,10 @@ public:
load);
// Generate the reference for the access.
rewriter.setInsertionPoint(op);
auto coor =
genCoorOp(rewriter, loc, getEleTy(load.getType()), eleTy, allocmem,
shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
access.getIndices(), load.getTypeparams(),
access->hasAttr(factory::attrFortranArrayOffsets()));
auto coor = genCoorOp(
rewriter, loc, getEleTy(load.getType()), eleTy, allocmem, shapeOp,
copyUsingSlice ? mlir::Value{} : load.getSlice(), access.getIndices(),
load, access->hasAttr(factory::attrFortranArrayOffsets()));
// Copy out.
auto *storeOp = useMap.lookup(loadOp);
auto store = mlir::cast<ArrayMergeStoreOp>(storeOp);
@ -1145,7 +1155,7 @@ public:
auto coor = genCoorOp(
rewriter, loc, getEleTy(load.getType()), lhsEltRefType, allocmem,
shapeOp, copyUsingSlice ? mlir::Value{} : load.getSlice(),
update.getIndices(), load.getTypeparams(),
update.getIndices(), load,
update->hasAttr(factory::attrFortranArrayOffsets()));
assignElement(coor);
auto *storeOp = useMap.lookup(loadOp);
@ -1163,10 +1173,10 @@ public:
LLVM_DEBUG(llvm::outs() << "No, conflict wasn't found\n");
rewriter.setInsertionPoint(op);
auto coorTy = getEleTy(load.getType());
auto coor = genCoorOp(rewriter, loc, coorTy, lhsEltRefType,
load.getMemref(), load.getShape(), load.getSlice(),
update.getIndices(), load.getTypeparams(),
update->hasAttr(factory::attrFortranArrayOffsets()));
auto coor =
genCoorOp(rewriter, loc, coorTy, lhsEltRefType, load.getMemref(),
load.getShape(), load.getSlice(), update.getIndices(), load,
update->hasAttr(factory::attrFortranArrayOffsets()));
assignElement(coor);
return {coor, load.getResult()};
}
@ -1240,11 +1250,10 @@ public:
rewriter.setInsertionPoint(op);
auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
auto loc = fetch.getLoc();
auto coor =
genCoorOp(rewriter, loc, getEleTy(load.getType()),
toRefType(fetch.getType()), load.getMemref(), load.getShape(),
load.getSlice(), fetch.getIndices(), load.getTypeparams(),
fetch->hasAttr(factory::attrFortranArrayOffsets()));
auto coor = genCoorOp(
rewriter, loc, getEleTy(load.getType()), toRefType(fetch.getType()),
load.getMemref(), load.getShape(), load.getSlice(), fetch.getIndices(),
load, fetch->hasAttr(factory::attrFortranArrayOffsets()));
if (isa_ref_type(fetch.getType()))
rewriter.replaceOp(fetch, coor);
else
@ -1280,11 +1289,10 @@ public:
}
rewriter.setInsertionPoint(op);
auto load = mlir::cast<ArrayLoadOp>(useMap.lookup(op));
auto coor = genCoorOp(rewriter, loc, getEleTy(load.getType()),
toRefType(access.getType()), load.getMemref(),
load.getShape(), load.getSlice(), access.getIndices(),
load.getTypeparams(),
access->hasAttr(factory::attrFortranArrayOffsets()));
auto coor = genCoorOp(
rewriter, loc, getEleTy(load.getType()), toRefType(access.getType()),
load.getMemref(), load.getShape(), load.getSlice(), access.getIndices(),
load, access->hasAttr(factory::attrFortranArrayOffsets()));
rewriter.replaceOp(access, coor);
return mlir::success();
}