[flang] Support multidimensional reductions in SimplifyIntrinsicsPass.

Create simplified functions for each rank with "x<rank>" suffix
that implement multidimensional reductions. To enable this I had to fix
an issue with taking incorrect box shape in cases of sliced embox/rebox.

Differential Revision: https://reviews.llvm.org/D133820
This commit is contained in:
Slava Zakharin 2022-09-13 09:41:22 -07:00
parent 2b138567e0
commit 8bd76ac151
2 changed files with 267 additions and 80 deletions

View File

@ -61,7 +61,7 @@ class SimplifyIntrinsicsPass
using FunctionBodyGeneratorTy =
llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
using GenReductionBodyTy = llvm::function_ref<void(
fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp)>;
fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank)>;
public:
/// Generate a new function implementing a simplified version
@ -110,10 +110,11 @@ using InitValGeneratorTy = llvm::function_ref<mlir::Value(
/// the reduction value
/// \p genBody is called to fill in the actual reduciton operation
/// for example add for SUM, MAX for MAXVAL, etc.
/// \p rank is the rank of the input argument.
static void genReductionLoop(fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp,
InitValGeneratorTy initVal,
BodyOpGeneratorTy genBody) {
BodyOpGeneratorTy genBody, unsigned rank) {
auto loc = mlir::UnknownLoc::get(builder.getContext());
mlir::Type elementType = funcOp.getResultTypes()[0];
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
@ -125,59 +126,98 @@ static void genReductionLoop(fir::FirOpBuilder &builder,
mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
fir::SequenceType::Shape flatShape(rank,
fir::SequenceType::getUnknownExtent());
mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
mlir::Type boxArrTy = fir::BoxType::get(arrTy);
mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
auto dims =
builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
mlir::Value len = dims.getResult(1);
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
mlir::Value step = one;
// We use C indexing here, so len-1 as loopcount
mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
mlir::Value init = initVal(builder, loc, elementType);
auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
/*unordered=*/false,
/*finalCountValue=*/false, init);
mlir::Value reductionVal = loop.getRegionIterArgs()[0];
// Begin loop code
mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loop.getBody());
llvm::SmallVector<mlir::Value, 15> bounds;
assert(rank > 0 && "rank cannot be zero");
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
// Compute all the upper bounds before the loop nest.
// It is not strictly necessary for performance, since the loop nest
// does not have any store operations and any LICM optimization
// should be able to optimize the redundancy.
for (unsigned i = 0; i < rank; ++i) {
mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
auto dims =
builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
mlir::Value len = dims.getResult(1);
// We use C indexing here, so len-1 as loopcount
mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
bounds.push_back(loopCount);
}
// Create a loop nest consisting of DoLoopOp operations.
// Collect the loops' induction variables into indices array,
// which will be used in the innermost loop to load the input
// array's element.
// The loops are generated such that the innermost loop processes
// the 0 dimension.
llvm::SmallVector<mlir::Value, 15> indices;
for (unsigned i = rank; 0 < i; --i) {
mlir::Value step = one;
mlir::Value loopCount = bounds[i - 1];
auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
/*unordered=*/false,
/*finalCountValue=*/false, init);
init = loop.getRegionIterArgs()[0];
indices.push_back(loop.getInductionVar());
// Set insertion point to the loop body so that the next loop
// is inserted inside the current one.
builder.setInsertionPointToStart(loop.getBody());
}
// Reverse the indices such that they are ordered as:
// <dim-0-idx, dim-1-idx, ...>
std::reverse(indices.begin(), indices.end());
// We are in the innermost loop: generate the reduction body.
mlir::Type eleRefTy = builder.getRefType(elementType);
mlir::Value index = loop.getInductionVar();
mlir::Value addr =
builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
reductionVal = genBody(builder, loc, elementType, elem, reductionVal);
mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
builder.create<fir::ResultOp>(loc, reductionVal);
// End of loop.
builder.restoreInsertionPoint(loopEndPt);
// Unwind the loop nest and insert ResultOp on each level
// to return the updated value of the reduction to the enclosing
// loops.
for (unsigned i = 0; i < rank; ++i) {
auto result = builder.create<fir::ResultOp>(loc, reductionVal);
// Proceed to the outer loop.
auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
reductionVal = loop.getResult(0);
// Set insertion point after the loop operation that we have
// just processed.
builder.setInsertionPointAfter(loop.getOperation());
}
mlir::Value resultVal = loop.getResult(0);
builder.create<mlir::func::ReturnOp>(loc, resultVal);
// End of loop nest. The insertion point is after the outermost loop.
// Return the reduction value from the function.
builder.create<mlir::func::ReturnOp>(loc, reductionVal);
}
/// Generate function body of the simplified version of RTNAME(Sum)
/// with signature provided by \p funcOp. The caller is responsible
/// for saving/restoring the original insertion point of \p builder.
/// \p funcOp is expected to be empty on entry to this function.
/// \p rank specifies the rank of the input argument.
static void genRuntimeSumBody(fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp) {
// function RTNAME(Sum)<T>_simplified(arr)
mlir::func::FuncOp &funcOp, unsigned rank) {
// function RTNAME(Sum)<T>x<rank>_simplified(arr)
// T, dimension(:) :: arr
// T sum = 0
// integer iter
// do iter = 0, extent(arr)
// sum = sum + arr[iter]
// end do
// RTNAME(Sum)<T>_simplified = sum
// end function RTNAME(Sum)<T>_simplified
// RTNAME(Sum)<T>x<rank>_simplified = sum
// end function RTNAME(Sum)<T>x<rank>_simplified
auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
@ -200,11 +240,11 @@ static void genRuntimeSumBody(fir::FirOpBuilder &builder,
return {};
};
genReductionLoop(builder, funcOp, zero, genBodyOp);
genReductionLoop(builder, funcOp, zero, genBodyOp, rank);
}
static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp) {
mlir::func::FuncOp &funcOp, unsigned rank) {
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
@ -228,7 +268,7 @@ static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
llvm_unreachable("unsupported type");
return {};
};
genReductionLoop(builder, funcOp, init, genBodyOp);
genReductionLoop(builder, funcOp, init, genBodyOp, rank);
}
/// Generate function type for the simplified version of RTNAME(DotProduct)
@ -410,21 +450,31 @@ static bool isZero(mlir::Value val) {
return false;
}
static mlir::Value findShape(mlir::Value val) {
static mlir::Value findBoxDef(mlir::Value val) {
if (auto op = expectConvertOp(val)) {
assert(op->getOperands().size() != 0);
if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
op->getOperand(0).getDefiningOp()))
return box.getShape();
return box.getResult();
if (auto box = mlir::dyn_cast_or_null<fir::ReboxOp>(
op->getOperand(0).getDefiningOp()))
return box.getResult();
}
return {};
}
static unsigned getDimCount(mlir::Value val) {
if (mlir::Value shapeVal = findShape(val)) {
mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
return fir::getRankOfShapeType(resType);
}
// In order to find the dimensions count, we look for EmboxOp/ReboxOp
// and take the count from its *result* type. Note that in case
// of sliced emboxing the operand and the result of EmboxOp/ReboxOp
// have different types.
// Actually, we can take the box type from the operand of
// the first ConvertOp that has non-opaque box type that we meet
// going through the ConvertOp chain.
if (mlir::Value emboxVal = findBoxDef(val))
if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
return seqTy.getDimension();
return 0;
}
@ -455,7 +505,6 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
const fir::KindMapping &kindMap,
GenReductionBodyTy genBodyFunc) {
mlir::SymbolRefAttr callee = call.getCalleeAttr();
mlir::StringRef funcName = callee.getLeafReference().getValue();
mlir::Operation::operand_range args = call.getArgs();
// args[1] and args[2] are source filename and line number, ignored.
const mlir::Value &dim = args[3];
@ -464,7 +513,7 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
// detail in the runtime library.
bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
unsigned rank = getDimCount(args[0]);
if (dimAndMaskAbsent && rank == 1) {
if (dimAndMaskAbsent && rank > 0) {
mlir::Location loc = call.getLoc();
fir::FirOpBuilder builder(call, kindMap);
@ -483,8 +532,17 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
return genNoneBoxType(builder, resultType);
};
auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder,
mlir::func::FuncOp &funcOp) {
genBodyFunc(builder, funcOp, rank);
};
// Mangle the function name with the rank value as "x<rank>".
std::string funcName =
(mlir::Twine{callee.getLeafReference().getValue(), "x"} +
mlir::Twine{rank})
.str();
mlir::func::FuncOp newFunc =
getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
auto newCall =
builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
call->replaceAllUsesWith(newCall.getResults());

View File

@ -34,20 +34,21 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<10xi32>>) -> !fir.box<none>
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: %[[RES:.*]] = fir.call @_FortranASumInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
// CHECK: %[[RES:.*]] = fir.call @_FortranASumInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: return %{{.*}} : i32
// CHECK: }
// CHECK: func.func private @_FortranASumInteger4(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32 attributes {fir.runtime}
// CHECK-LABEL: func.func private @_FortranASumInteger4_simplified(
// CHECK-LABEL: func.func private @_FortranASumInteger4x1_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[CI32_0:.*]] = arith.constant 0 : i32
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM:.*]] = %[[CI32_0]]) -> (i32) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i32>
@ -59,7 +60,7 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// -----
// Call to SUM with 2D I32 arrays is not replaced.
// Call to SUM with 2D I32 arrays is replaced.
module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "native"} {
func.func @sum_2d_array_int(%arg0: !fir.ref<!fir.array<10x10xi32>> {fir.bindc_name = "a"}) -> i32 {
%c10 = arith.constant 10 : index
@ -88,9 +89,39 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
}
// CHECK-LABEL: func.func @sum_2d_array_int({{.*}} !fir.ref<!fir.array<10x10xi32>> {fir.bindc_name = "a"}) -> i32 {
// CHECK-NOT: fir.call @_FortranASumInteger4_simplified({{.*}})
// CHECK: fir.call @_FortranASumInteger4({{.*}}) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
// CHECK-NOT: fir.call @_FortranASumInteger4_simplified({{.*}})
// CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index, index) -> !fir.shape<2>
// CHECK: %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10x10xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<10x10xi32>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<10x10xi32>>) -> !fir.box<none>
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: %[[RES:.*]] = fir.call @_FortranASumInteger4x2_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: return %{{.*}} : i32
// CHECK: }
// CHECK: func.func private @_FortranASumInteger4(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32 attributes {fir.runtime}
// CHECK-LABEL: func.func private @_FortranASumInteger4x2_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?x?xi32>>
// CHECK: %[[CI32_0:.*]] = arith.constant 0 : i32
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS_0:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
// CHECK: %[[EXTENT_0:.*]] = arith.subi %[[DIMS_0]]#1, %[[CINDEX_1]] : index
// CHECK: %[[DIMIDX_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIMS_1:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_1]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
// CHECK: %[[EXTENT_1:.*]] = arith.subi %[[DIMS_1]]#1, %[[CINDEX_1]] : index
// CHECK: %[[RES_1:.*]] = fir.do_loop %[[ITER_1:.*]] = %[[CINDEX_0]] to %[[EXTENT_1]] step %[[CINDEX_1]] iter_args(%[[SUM_1:.*]] = %[[CI32_0]]) -> (i32) {
// CHECK: %[[RES_0:.*]] = fir.do_loop %[[ITER_0:.*]] = %[[CINDEX_0]] to %[[EXTENT_0]] step %[[CINDEX_1]] iter_args(%[[SUM_0:.*]] = %[[SUM_1]]) -> (i32) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER_0]], %[[ITER_1]] : (!fir.box<!fir.array<?x?xi32>>, index, index) -> !fir.ref<i32>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i32>
// CHECK: %[[NEW_SUM:.*]] = arith.addi %[[ITEM_VAL]], %[[SUM_0]] : i32
// CHECK: fir.result %[[NEW_SUM]] : i32
// CHECK: }
// CHECK: fir.result %[[RES_0]]
// CHECK: }
// CHECK: return %[[RES_1]] : i32
// CHECK: }
// -----
@ -129,19 +160,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[A_BOX_F64:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F64]] : (!fir.box<!fir.array<10xf64>>) -> !fir.box<none>
// CHECK-NOT: fir.call @_FortranASumReal8({{.*}})
// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal8_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal8x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
// CHECK-NOT: fir.call @_FortranASumReal8({{.*}})
// CHECK: return %{{.*}} : f64
// CHECK: }
// CHECK-LABEL: func.func private @_FortranASumReal8_simplified(
// CHECK-LABEL: func.func private @_FortranASumReal8x1_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> f64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_F64:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf64>>
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM]] = %[[ZERO]]) -> (f64) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F64]], %[[ITER]] : (!fir.box<!fir.array<?xf64>>, index) -> !fir.ref<f64>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<f64>
@ -188,19 +220,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[A_BOX_F32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf32>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F32]] : (!fir.box<!fir.array<10xf32>>) -> !fir.box<none>
// CHECK-NOT: fir.call @_FortranASumReal4({{.*}})
// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f32
// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f32
// CHECK-NOT: fir.call @_FortranASumReal4({{.*}})
// CHECK: return %{{.*}} : f32
// CHECK: }
// CHECK-LABEL: func.func private @_FortranASumReal4_simplified(
// CHECK-LABEL: func.func private @_FortranASumReal4x1_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> f32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_F32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf32>>
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F32]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM]] = %[[ZERO]]) -> (f32) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F32]], %[[ITER]] : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<f32>
@ -243,9 +276,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
}
// CHECK-LABEL: func.func @sum_1d_complex(%{{.*}}: !fir.ref<!fir.array<10x!fir.complex<4>>> {fir.bindc_name = "a"}) -> !fir.complex<4> {
// CHECK-NOT: fir.call @_FortranACppSumComplex4_simplified({{.*}})
// CHECK-NOT: fir.call @_FortranACppSumComplex4x1_simplified({{.*}})
// CHECK: fir.call @_FortranACppSumComplex4({{.*}}) : (!fir.ref<complex<f32>>, !fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> none
// CHECK-NOT: fir.call @_FortranACppSumComplex4_simplified({{.*}})
// CHECK-NOT: fir.call @_FortranACppSumComplex4x1_simplified({{.*}})
// -----
@ -298,20 +331,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK-LABEL: func.func @sum_1d_calla(%{{.*}}) -> i32 {
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: fir.call @_FortranASumInteger4_simplified(%{{.*}})
// CHECK: fir.call @_FortranASumInteger4x1_simplified(%{{.*}})
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: }
// CHECK-LABEL: func.func @sum_1d_callb(%{{.*}}) -> i32 {
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: fir.call @_FortranASumInteger4_simplified(%{{.*}})
// CHECK: fir.call @_FortranASumInteger4x1_simplified(%{{.*}})
// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}})
// CHECK: }
// CHECK-LABEL: func.func private @_FortranASumInteger4_simplified({{.*}}) -> i32 {{.*}} {
// CHECK-LABEL: func.func private @_FortranASumInteger4x1_simplified({{.*}}) -> i32 {{.*}} {
// CHECK: return %{{.*}} : i32
// CHECK: }
// CHECK-NOT: func.func private @_FortranASumInteger4_simplified({{.*}})
// CHECK-NOT: func.func private @_FortranASumInteger4x1_simplified({{.*}})
// -----
@ -354,14 +387,14 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[SLICE:.*]] = fir.slice %{{.*}}, %{{.*}}, %[[CINDEX_2]] : (index, index, index) -> !fir.slice<1>
// CHECK: %[[A_BOX_I32:.*]] = fir.embox %{{.*}}(%[[SHAPE]]) {{\[}}%[[SLICE]]] : (!fir.ref<!fir.array<20xi32>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
// CHECK: %{{.*}} = fir.call @_FortranASumInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
// CHECK: %{{.*}} = fir.call @_FortranASumInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
// CHECK: return %{{.*}} : i32
// CHECK: }
// CHECK-LABEL: func.func private @_FortranASumInteger4_simplified(%{{.*}}) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK-LABEL: func.func private @_FortranASumInteger4x1_simplified(%{{.*}}) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %{{.*}} : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %{{.*}} : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %{{.*}} : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %{{.*}} to %[[EXTENT]] step %[[CINDEX_1]] iter_args({{.*}}) -> (i32) {
// CHECK: %{{.*}} = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
@ -792,18 +825,19 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
// CHECK: %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box<!fir.array<10xi32>>) -> !fir.box<none>
// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> i32
// CHECK: return %{{.*}} : i32
// CHECK: }
// CHECK-LABEL: func.func private @_FortranAMaxvalInteger4_simplified(
// CHECK-LABEL: func.func private @_FortranAMaxvalInteger4x1_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[CI32_MININT:.*]] = arith.constant -2147483648 : i32
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[MAX:.*]] = %[[CI32_MININT]]) -> (i32) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<i32>
@ -849,18 +883,19 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: %[[SHAPE:.*]] = fir.shape %[[CINDEX_10]] : (index) -> !fir.shape<1>
// CHECK: %[[A_BOX_F64:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F64]] : (!fir.box<!fir.array<10xf64>>) -> !fir.box<none>
// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalReal8_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalReal8x1_simplified(%[[A_BOX_NONE]]) : (!fir.box<none>) -> f64
// CHECK: return %{{.*}} : f64
// CHECK: }
// CHECK-LABEL: func.func private @_FortranAMaxvalReal8_simplified(
// CHECK-LABEL: func.func private @_FortranAMaxvalReal8x1_simplified(
// CHECK-SAME: %[[ARR:.*]]: !fir.box<none>) -> f64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index
// CHECK: %[[ARR_BOX_F64:.*]] = fir.convert %[[ARR]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf64>>
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[CINDEX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[NEG_DBL_MAX:.*]] = arith.constant -1.7976931348623157E+308 : f64
// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index
// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[DIMIDX_0]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index
// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[MAX]] = %[[NEG_DBL_MAX]]) -> (f64) {
// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F64]], %[[ITER]] : (!fir.box<!fir.array<?xf64>>, index) -> !fir.ref<f64>
// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref<f64>
@ -869,3 +904,97 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ
// CHECK: }
// CHECK: return %[[RES]] : f64
// CHECK: }
// -----
// SUM reduction of sliced explicit-shape array is replaced with
// 2D simplified implementation.
func.func @sum_sliced_embox_i64(%arg0: !fir.ref<!fir.array<10x10x10xi64>> {fir.bindc_name = "a"}) -> f32 {
%c10 = arith.constant 10 : index
%c10_0 = arith.constant 10 : index
%c10_1 = arith.constant 10 : index
%0 = fir.alloca f32 {bindc_name = "sum_sliced_embox_i64", uniq_name = "_QFsum_sliced_embox_i64Esum_sliced_embox_i64"}
%1 = fir.alloca i64 {bindc_name = "sum_sliced_i64", uniq_name = "_QFsum_sliced_embox_i64Esum_sliced_i64"}
%c1 = arith.constant 1 : index
%c1_i64 = arith.constant 1 : i64
%2 = fir.convert %c1_i64 : (i64) -> index
%3 = arith.addi %c1, %c10 : index
%4 = arith.subi %3, %c1 : index
%c1_i64_2 = arith.constant 1 : i64
%5 = fir.convert %c1_i64_2 : (i64) -> index
%6 = arith.addi %c1, %c10_0 : index
%7 = arith.subi %6, %c1 : index
%c1_i64_3 = arith.constant 1 : i64
%8 = fir.undefined index
%9 = fir.shape %c10, %c10_0, %c10_1 : (index, index, index) -> !fir.shape<3>
%10 = fir.slice %c1, %4, %2, %c1, %7, %5, %c1_i64_3, %8, %8 : (index, index, index, index, index, index, i64, index, index) -> !fir.slice<3>
%11 = fir.embox %arg0(%9) [%10] : (!fir.ref<!fir.array<10x10x10xi64>>, !fir.shape<3>, !fir.slice<3>) -> !fir.box<!fir.array<?x?xi64>>
%12 = fir.absent !fir.box<i1>
%c0 = arith.constant 0 : index
%13 = fir.address_of(@_QQcl.2E2F746573742E66393000) : !fir.ref<!fir.char<1,11>>
%c3_i32 = arith.constant 3 : i32
%14 = fir.convert %11 : (!fir.box<!fir.array<?x?xi64>>) -> !fir.box<none>
%15 = fir.convert %13 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
%16 = fir.convert %c0 : (index) -> i32
%17 = fir.convert %12 : (!fir.box<i1>) -> !fir.box<none>
%18 = fir.call @_FortranASumInteger8(%14, %15, %c3_i32, %16, %17) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64
fir.store %18 to %1 : !fir.ref<i64>
%19 = fir.load %0 : !fir.ref<f32>
return %19 : f32
}
func.func private @_FortranASumInteger8(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64 attributes {fir.runtime}
fir.global linkonce @_QQcl.2E2F746573742E66393000 constant : !fir.char<1,11> {
%0 = fir.string_lit "./test.f90\00"(11) : !fir.char<1,11>
fir.has_value %0 : !fir.char<1,11>
}
// CHECK-NOT: call{{.*}}_FortranASumInteger8(
// CHECK: call @_FortranASumInteger8x2_simplified(
// CHECK-NOT: call{{.*}}_FortranASumInteger8(
// -----
// SUM reduction of sliced assumed-shape array is replaced with
// 2D simplified implementation.
func.func @_QPsum_sliced_rebox_i64(%arg0: !fir.box<!fir.array<?x?x?xi64>> {fir.bindc_name = "a"}) -> f32 {
%0 = fir.alloca i64 {bindc_name = "sum_sliced_i64", uniq_name = "_QFsum_sliced_rebox_i64Esum_sliced_i64"}
%1 = fir.alloca f32 {bindc_name = "sum_sliced_rebox_i64", uniq_name = "_QFsum_sliced_rebox_i64Esum_sliced_rebox_i64"}
%c1 = arith.constant 1 : index
%c1_i64 = arith.constant 1 : i64
%2 = fir.convert %c1_i64 : (i64) -> index
%c0 = arith.constant 0 : index
%3:3 = fir.box_dims %arg0, %c0 : (!fir.box<!fir.array<?x?x?xi64>>, index) -> (index, index, index)
%4 = arith.addi %c1, %3#1 : index
%5 = arith.subi %4, %c1 : index
%c1_i64_0 = arith.constant 1 : i64
%6 = fir.convert %c1_i64_0 : (i64) -> index
%c1_1 = arith.constant 1 : index
%7:3 = fir.box_dims %arg0, %c1_1 : (!fir.box<!fir.array<?x?x?xi64>>, index) -> (index, index, index)
%8 = arith.addi %c1, %7#1 : index
%9 = arith.subi %8, %c1 : index
%c1_i64_2 = arith.constant 1 : i64
%10 = fir.undefined index
%11 = fir.slice %c1, %5, %2, %c1, %9, %6, %c1_i64_2, %10, %10 : (index, index, index, index, index, index, i64, index, index) -> !fir.slice<3>
%12 = fir.rebox %arg0 [%11] : (!fir.box<!fir.array<?x?x?xi64>>, !fir.slice<3>) -> !fir.box<!fir.array<?x?xi64>>
%13 = fir.absent !fir.box<i1>
%c0_3 = arith.constant 0 : index
%14 = fir.address_of(@_QQcl.2E2F746573742E66393000) : !fir.ref<!fir.char<1,11>>
%c8_i32 = arith.constant 8 : i32
%15 = fir.convert %12 : (!fir.box<!fir.array<?x?xi64>>) -> !fir.box<none>
%16 = fir.convert %14 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
%17 = fir.convert %c0_3 : (index) -> i32
%18 = fir.convert %13 : (!fir.box<i1>) -> !fir.box<none>
%19 = fir.call @_FortranASumInteger8(%15, %16, %c8_i32, %17, %18) : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64
fir.store %19 to %0 : !fir.ref<i64>
%20 = fir.load %1 : !fir.ref<f32>
return %20 : f32
}
func.func private @_FortranASumInteger8(!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i64 attributes {fir.runtime}
fir.global linkonce @_QQcl.2E2F746573742E66393000 constant : !fir.char<1,11> {
%0 = fir.string_lit "./test.f90\00"(11) : !fir.char<1,11>
fir.has_value %0 : !fir.char<1,11>
}
// CHECK-NOT: call{{.*}}_FortranASumInteger8(
// CHECK: call @_FortranASumInteger8x2_simplified(
// CHECK-NOT: call{{.*}}_FortranASumInteger8(