From 8bd76ac151534d2b9534ed919c0a7f4511002d84 Mon Sep 17 00:00:00 2001 From: Slava Zakharin Date: Tue, 13 Sep 2022 09:41:22 -0700 Subject: [PATCH] [flang] Support multidimensional reductions in SimplifyIntrinsicsPass. Create simplified functions for each rank with "x" suffix that implement multidimensional reductions. To enable this I had to fix an issue with taking incorrect box shape in cases of sliced embox/rebox. Differential Revision: https://reviews.llvm.org/D133820 --- .../Transforms/SimplifyIntrinsics.cpp | 142 ++++++++---- flang/test/Transforms/simplifyintrinsics.fir | 205 ++++++++++++++---- 2 files changed, 267 insertions(+), 80 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp index d23736ef8a68..5682fa281671 100644 --- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp @@ -61,7 +61,7 @@ class SimplifyIntrinsicsPass using FunctionBodyGeneratorTy = llvm::function_ref; using GenReductionBodyTy = llvm::function_ref; + fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank)>; public: /// Generate a new function implementing a simplified version @@ -110,10 +110,11 @@ using InitValGeneratorTy = llvm::function_ref(loc, boxArrTy, arg); - auto dims = - builder.create(loc, idxTy, idxTy, idxTy, array, zeroIdx); - mlir::Value len = dims.getResult(1); - mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); - mlir::Value step = one; - - // We use C indexing here, so len-1 as loopcount - mlir::Value loopCount = builder.create(loc, len, one); mlir::Value init = initVal(builder, loc, elementType); - auto loop = builder.create(loc, zeroIdx, loopCount, step, - /*unordered=*/false, - /*finalCountValue=*/false, init); - mlir::Value reductionVal = loop.getRegionIterArgs()[0]; - // Begin loop code - mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(loop.getBody()); + llvm::SmallVector bounds; + assert(rank > 0 && "rank cannot be zero"); + mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); + + // Compute all the upper bounds before the loop nest. + // It is not strictly necessary for performance, since the loop nest + // does not have any store operations and any LICM optimization + // should be able to optimize the redundancy. + for (unsigned i = 0; i < rank; ++i) { + mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); + auto dims = + builder.create(loc, idxTy, idxTy, idxTy, array, dimIdx); + mlir::Value len = dims.getResult(1); + // We use C indexing here, so len-1 as loopcount + mlir::Value loopCount = builder.create(loc, len, one); + bounds.push_back(loopCount); + } + + // Create a loop nest consisting of DoLoopOp operations. + // Collect the loops' induction variables into indices array, + // which will be used in the innermost loop to load the input + // array's element. + // The loops are generated such that the innermost loop processes + // the 0 dimension. + llvm::SmallVector indices; + for (unsigned i = rank; 0 < i; --i) { + mlir::Value step = one; + mlir::Value loopCount = bounds[i - 1]; + auto loop = builder.create(loc, zeroIdx, loopCount, step, + /*unordered=*/false, + /*finalCountValue=*/false, init); + init = loop.getRegionIterArgs()[0]; + indices.push_back(loop.getInductionVar()); + // Set insertion point to the loop body so that the next loop + // is inserted inside the current one. + builder.setInsertionPointToStart(loop.getBody()); + } + + // Reverse the indices such that they are ordered as: + // + std::reverse(indices.begin(), indices.end()); + + // We are in the innermost loop: generate the reduction body. mlir::Type eleRefTy = builder.getRefType(elementType); - mlir::Value index = loop.getInductionVar(); mlir::Value addr = - builder.create(loc, eleRefTy, array, index); + builder.create(loc, eleRefTy, array, indices); mlir::Value elem = builder.create(loc, addr); - reductionVal = genBody(builder, loc, elementType, elem, reductionVal); + mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init); - builder.create(loc, reductionVal); - // End of loop. - builder.restoreInsertionPoint(loopEndPt); + // Unwind the loop nest and insert ResultOp on each level + // to return the updated value of the reduction to the enclosing + // loops. + for (unsigned i = 0; i < rank; ++i) { + auto result = builder.create(loc, reductionVal); + // Proceed to the outer loop. + auto loop = mlir::cast(result->getParentOp()); + reductionVal = loop.getResult(0); + // Set insertion point after the loop operation that we have + // just processed. + builder.setInsertionPointAfter(loop.getOperation()); + } - mlir::Value resultVal = loop.getResult(0); - builder.create(loc, resultVal); + // End of loop nest. The insertion point is after the outermost loop. + // Return the reduction value from the function. + builder.create(loc, reductionVal); } /// Generate function body of the simplified version of RTNAME(Sum) /// with signature provided by \p funcOp. The caller is responsible /// for saving/restoring the original insertion point of \p builder. /// \p funcOp is expected to be empty on entry to this function. +/// \p rank specifies the rank of the input argument. static void genRuntimeSumBody(fir::FirOpBuilder &builder, - mlir::func::FuncOp &funcOp) { - // function RTNAME(Sum)_simplified(arr) + mlir::func::FuncOp &funcOp, unsigned rank) { + // function RTNAME(Sum)x_simplified(arr) // T, dimension(:) :: arr // T sum = 0 // integer iter // do iter = 0, extent(arr) // sum = sum + arr[iter] // end do - // RTNAME(Sum)_simplified = sum - // end function RTNAME(Sum)_simplified + // RTNAME(Sum)x_simplified = sum + // end function RTNAME(Sum)x_simplified auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType) { if (auto ty = elementType.dyn_cast()) { @@ -200,11 +240,11 @@ static void genRuntimeSumBody(fir::FirOpBuilder &builder, return {}; }; - genReductionLoop(builder, funcOp, zero, genBodyOp); + genReductionLoop(builder, funcOp, zero, genBodyOp, rank); } static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder, - mlir::func::FuncOp &funcOp) { + mlir::func::FuncOp &funcOp, unsigned rank) { auto init = [](fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType) { if (auto ty = elementType.dyn_cast()) { @@ -228,7 +268,7 @@ static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder, llvm_unreachable("unsupported type"); return {}; }; - genReductionLoop(builder, funcOp, init, genBodyOp); + genReductionLoop(builder, funcOp, init, genBodyOp, rank); } /// Generate function type for the simplified version of RTNAME(DotProduct) @@ -410,21 +450,31 @@ static bool isZero(mlir::Value val) { return false; } -static mlir::Value findShape(mlir::Value val) { +static mlir::Value findBoxDef(mlir::Value val) { if (auto op = expectConvertOp(val)) { assert(op->getOperands().size() != 0); if (auto box = mlir::dyn_cast_or_null( op->getOperand(0).getDefiningOp())) - return box.getShape(); + return box.getResult(); + if (auto box = mlir::dyn_cast_or_null( + op->getOperand(0).getDefiningOp())) + return box.getResult(); } return {}; } static unsigned getDimCount(mlir::Value val) { - if (mlir::Value shapeVal = findShape(val)) { - mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0]; - return fir::getRankOfShapeType(resType); - } + // In order to find the dimensions count, we look for EmboxOp/ReboxOp + // and take the count from its *result* type. Note that in case + // of sliced emboxing the operand and the result of EmboxOp/ReboxOp + // have different types. + // Actually, we can take the box type from the operand of + // the first ConvertOp that has non-opaque box type that we meet + // going through the ConvertOp chain. + if (mlir::Value emboxVal = findBoxDef(val)) + if (auto boxTy = emboxVal.getType().dyn_cast()) + if (auto seqTy = boxTy.getEleTy().dyn_cast()) + return seqTy.getDimension(); return 0; } @@ -455,7 +505,6 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap, GenReductionBodyTy genBodyFunc) { mlir::SymbolRefAttr callee = call.getCalleeAttr(); - mlir::StringRef funcName = callee.getLeafReference().getValue(); mlir::Operation::operand_range args = call.getArgs(); // args[1] and args[2] are source filename and line number, ignored. const mlir::Value &dim = args[3]; @@ -464,7 +513,7 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call, // detail in the runtime library. bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask); unsigned rank = getDimCount(args[0]); - if (dimAndMaskAbsent && rank == 1) { + if (dimAndMaskAbsent && rank > 0) { mlir::Location loc = call.getLoc(); fir::FirOpBuilder builder(call, kindMap); @@ -483,8 +532,17 @@ void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call, auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) { return genNoneBoxType(builder, resultType); }; + auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder, + mlir::func::FuncOp &funcOp) { + genBodyFunc(builder, funcOp, rank); + }; + // Mangle the function name with the rank value as "x". + std::string funcName = + (mlir::Twine{callee.getLeafReference().getValue(), "x"} + + mlir::Twine{rank}) + .str(); mlir::func::FuncOp newFunc = - getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc); + getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); auto newCall = builder.create(loc, newFunc, mlir::ValueRange{args[0]}); call->replaceAllUsesWith(newCall.getResults()); diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir index b5d24c578524..e3ac9c930d29 100644 --- a/flang/test/Transforms/simplifyintrinsics.fir +++ b/flang/test/Transforms/simplifyintrinsics.fir @@ -34,20 +34,21 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ // CHECK: %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> // CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box>) -> !fir.box // CHECK-NOT: fir.call @_FortranASumInteger4({{.*}}) -// CHECK: %[[RES:.*]] = fir.call @_FortranASumInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> i32 +// CHECK: %[[RES:.*]] = fir.call @_FortranASumInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> i32 // CHECK-NOT: fir.call @_FortranASumInteger4({{.*}}) // CHECK: return %{{.*}} : i32 // CHECK: } // CHECK: func.func private @_FortranASumInteger4(!fir.box, !fir.ref, i32, i32, !fir.box) -> i32 attributes {fir.runtime} -// CHECK-LABEL: func.func private @_FortranASumInteger4_simplified( +// CHECK-LABEL: func.func private @_FortranASumInteger4x1_simplified( // CHECK-SAME: %[[ARR:.*]]: !fir.box) -> i32 attributes {llvm.linkage = #llvm.linkage} { // CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index // CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> -// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[CINDEX_0]] : (!fir.box>, index) -> (index, index, index) -// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index -// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[CI32_0:.*]] = arith.constant 0 : i32 +// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM:.*]] = %[[CI32_0]]) -> (i32) { // CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box>, index) -> !fir.ref // CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref @@ -59,7 +60,7 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ // ----- -// Call to SUM with 2D I32 arrays is not replaced. +// Call to SUM with 2D I32 arrays is replaced. module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "native"} { func.func @sum_2d_array_int(%arg0: !fir.ref> {fir.bindc_name = "a"}) -> i32 { %c10 = arith.constant 10 : index @@ -88,9 +89,39 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ } // CHECK-LABEL: func.func @sum_2d_array_int({{.*}} !fir.ref> {fir.bindc_name = "a"}) -> i32 { -// CHECK-NOT: fir.call @_FortranASumInteger4_simplified({{.*}}) -// CHECK: fir.call @_FortranASumInteger4({{.*}}) : (!fir.box, !fir.ref, i32, i32, !fir.box) -> i32 -// CHECK-NOT: fir.call @_FortranASumInteger4_simplified({{.*}}) +// CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index, index) -> !fir.shape<2> +// CHECK: %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref>, !fir.shape<2>) -> !fir.box> +// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box>) -> !fir.box +// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}}) +// CHECK: %[[RES:.*]] = fir.call @_FortranASumInteger4x2_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> i32 +// CHECK-NOT: fir.call @_FortranASumInteger4({{.*}}) +// CHECK: return %{{.*}} : i32 +// CHECK: } +// CHECK: func.func private @_FortranASumInteger4(!fir.box, !fir.ref, i32, i32, !fir.box) -> i32 attributes {fir.runtime} + +// CHECK-LABEL: func.func private @_FortranASumInteger4x2_simplified( +// CHECK-SAME: %[[ARR:.*]]: !fir.box) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index +// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> +// CHECK: %[[CI32_0:.*]] = arith.constant 0 : i32 +// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS_0:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[EXTENT_0:.*]] = arith.subi %[[DIMS_0]]#1, %[[CINDEX_1]] : index +// CHECK: %[[DIMIDX_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMS_1:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_1]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[EXTENT_1:.*]] = arith.subi %[[DIMS_1]]#1, %[[CINDEX_1]] : index +// CHECK: %[[RES_1:.*]] = fir.do_loop %[[ITER_1:.*]] = %[[CINDEX_0]] to %[[EXTENT_1]] step %[[CINDEX_1]] iter_args(%[[SUM_1:.*]] = %[[CI32_0]]) -> (i32) { +// CHECK: %[[RES_0:.*]] = fir.do_loop %[[ITER_0:.*]] = %[[CINDEX_0]] to %[[EXTENT_0]] step %[[CINDEX_1]] iter_args(%[[SUM_0:.*]] = %[[SUM_1]]) -> (i32) { +// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER_0]], %[[ITER_1]] : (!fir.box>, index, index) -> !fir.ref +// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref +// CHECK: %[[NEW_SUM:.*]] = arith.addi %[[ITEM_VAL]], %[[SUM_0]] : i32 +// CHECK: fir.result %[[NEW_SUM]] : i32 +// CHECK: } +// CHECK: fir.result %[[RES_0]] +// CHECK: } +// CHECK: return %[[RES_1]] : i32 +// CHECK: } // ----- @@ -129,19 +160,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ // CHECK: %[[A_BOX_F64:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> // CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F64]] : (!fir.box>) -> !fir.box // CHECK-NOT: fir.call @_FortranASumReal8({{.*}}) -// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal8_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> f64 +// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal8x1_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> f64 // CHECK-NOT: fir.call @_FortranASumReal8({{.*}}) // CHECK: return %{{.*}} : f64 // CHECK: } -// CHECK-LABEL: func.func private @_FortranASumReal8_simplified( +// CHECK-LABEL: func.func private @_FortranASumReal8x1_simplified( // CHECK-SAME: %[[ARR:.*]]: !fir.box) -> f64 attributes {llvm.linkage = #llvm.linkage} { // CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index // CHECK: %[[ARR_BOX_F64:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> -// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[CINDEX_0]] : (!fir.box>, index) -> (index, index, index) -// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index -// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[DIMIDX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM]] = %[[ZERO]]) -> (f64) { // CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F64]], %[[ITER]] : (!fir.box>, index) -> !fir.ref // CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref @@ -188,19 +220,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ // CHECK: %[[A_BOX_F32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> // CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F32]] : (!fir.box>) -> !fir.box // CHECK-NOT: fir.call @_FortranASumReal4({{.*}}) -// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal4_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> f32 +// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> f32 // CHECK-NOT: fir.call @_FortranASumReal4({{.*}}) // CHECK: return %{{.*}} : f32 // CHECK: } -// CHECK-LABEL: func.func private @_FortranASumReal4_simplified( +// CHECK-LABEL: func.func private @_FortranASumReal4x1_simplified( // CHECK-SAME: %[[ARR:.*]]: !fir.box) -> f32 attributes {llvm.linkage = #llvm.linkage} { // CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index // CHECK: %[[ARR_BOX_F32:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> -// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F32]], %[[CINDEX_0]] : (!fir.box>, index) -> (index, index, index) -// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index -// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F32]], %[[DIMIDX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM]] = %[[ZERO]]) -> (f32) { // CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F32]], %[[ITER]] : (!fir.box>, index) -> !fir.ref // CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref @@ -243,9 +276,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ } // CHECK-LABEL: func.func @sum_1d_complex(%{{.*}}: !fir.ref>> {fir.bindc_name = "a"}) -> !fir.complex<4> { -// CHECK-NOT: fir.call @_FortranACppSumComplex4_simplified({{.*}}) +// CHECK-NOT: fir.call @_FortranACppSumComplex4x1_simplified({{.*}}) // CHECK: fir.call @_FortranACppSumComplex4({{.*}}) : (!fir.ref>, !fir.box, !fir.ref, i32, i32, !fir.box) -> none -// CHECK-NOT: fir.call @_FortranACppSumComplex4_simplified({{.*}}) +// CHECK-NOT: fir.call @_FortranACppSumComplex4x1_simplified({{.*}}) // ----- @@ -298,20 +331,20 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ // CHECK-LABEL: func.func @sum_1d_calla(%{{.*}}) -> i32 { // CHECK-NOT: fir.call @_FortranASumInteger4({{.*}}) -// CHECK: fir.call @_FortranASumInteger4_simplified(%{{.*}}) +// CHECK: fir.call @_FortranASumInteger4x1_simplified(%{{.*}}) // CHECK-NOT: fir.call @_FortranASumInteger4({{.*}}) // CHECK: } // CHECK-LABEL: func.func @sum_1d_callb(%{{.*}}) -> i32 { // CHECK-NOT: fir.call @_FortranASumInteger4({{.*}}) -// CHECK: fir.call @_FortranASumInteger4_simplified(%{{.*}}) +// CHECK: fir.call @_FortranASumInteger4x1_simplified(%{{.*}}) // CHECK-NOT: fir.call @_FortranASumInteger4({{.*}}) // CHECK: } -// CHECK-LABEL: func.func private @_FortranASumInteger4_simplified({{.*}}) -> i32 {{.*}} { +// CHECK-LABEL: func.func private @_FortranASumInteger4x1_simplified({{.*}}) -> i32 {{.*}} { // CHECK: return %{{.*}} : i32 // CHECK: } -// CHECK-NOT: func.func private @_FortranASumInteger4_simplified({{.*}}) +// CHECK-NOT: func.func private @_FortranASumInteger4x1_simplified({{.*}}) // ----- @@ -354,14 +387,14 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ // CHECK: %[[SLICE:.*]] = fir.slice %{{.*}}, %{{.*}}, %[[CINDEX_2]] : (index, index, index) -> !fir.slice<1> // CHECK: %[[A_BOX_I32:.*]] = fir.embox %{{.*}}(%[[SHAPE]]) {{\[}}%[[SLICE]]] : (!fir.ref>, !fir.shape<1>, !fir.slice<1>) -> !fir.box> // CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box>) -> !fir.box -// CHECK: %{{.*}} = fir.call @_FortranASumInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> i32 +// CHECK: %{{.*}} = fir.call @_FortranASumInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> i32 // CHECK: return %{{.*}} : i32 // CHECK: } -// CHECK-LABEL: func.func private @_FortranASumInteger4_simplified(%{{.*}}) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// CHECK-LABEL: func.func private @_FortranASumInteger4x1_simplified(%{{.*}}) -> i32 attributes {llvm.linkage = #llvm.linkage} { // CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %{{.*}} : (!fir.box) -> !fir.box> -// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %{{.*}} : (!fir.box>, index) -> (index, index, index) // CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %{{.*}} : (!fir.box>, index) -> (index, index, index) // CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %{{.*}} to %[[EXTENT]] step %[[CINDEX_1]] iter_args({{.*}}) -> (i32) { // CHECK: %{{.*}} = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box>, index) -> !fir.ref @@ -792,18 +825,19 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ // CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1> // CHECK: %[[A_BOX_I32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> // CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_I32]] : (!fir.box>) -> !fir.box -// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalInteger4_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> i32 +// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalInteger4x1_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> i32 // CHECK: return %{{.*}} : i32 // CHECK: } -// CHECK-LABEL: func.func private @_FortranAMaxvalInteger4_simplified( +// CHECK-LABEL: func.func private @_FortranAMaxvalInteger4x1_simplified( // CHECK-SAME: %[[ARR:.*]]: !fir.box) -> i32 attributes {llvm.linkage = #llvm.linkage} { // CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index // CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> -// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[CINDEX_0]] : (!fir.box>, index) -> (index, index, index) -// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index -// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[CI32_MININT:.*]] = arith.constant -2147483648 : i32 +// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[MAX:.*]] = %[[CI32_MININT]]) -> (i32) { // CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box>, index) -> !fir.ref // CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref @@ -849,18 +883,19 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ // CHECK: %[[SHAPE:.*]] = fir.shape %[[CINDEX_10]] : (index) -> !fir.shape<1> // CHECK: %[[A_BOX_F64:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> // CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F64]] : (!fir.box>) -> !fir.box -// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalReal8_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> f64 +// CHECK: %[[RES:.*]] = fir.call @_FortranAMaxvalReal8x1_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> f64 // CHECK: return %{{.*}} : f64 // CHECK: } -// CHECK-LABEL: func.func private @_FortranAMaxvalReal8_simplified( +// CHECK-LABEL: func.func private @_FortranAMaxvalReal8x1_simplified( // CHECK-SAME: %[[ARR:.*]]: !fir.box) -> f64 attributes {llvm.linkage = #llvm.linkage} { // CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index // CHECK: %[[ARR_BOX_F64:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> -// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[CINDEX_0]] : (!fir.box>, index) -> (index, index, index) -// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index -// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[NEG_DBL_MAX:.*]] = arith.constant -1.7976931348623157E+308 : f64 +// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F64]], %[[DIMIDX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index // CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[MAX]] = %[[NEG_DBL_MAX]]) -> (f64) { // CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F64]], %[[ITER]] : (!fir.box>, index) -> !fir.ref // CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref @@ -869,3 +904,97 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.targ // CHECK: } // CHECK: return %[[RES]] : f64 // CHECK: } + +// ----- + +// SUM reduction of sliced explicit-shape array is replaced with +// 2D simplified implementation. +func.func @sum_sliced_embox_i64(%arg0: !fir.ref> {fir.bindc_name = "a"}) -> f32 { + %c10 = arith.constant 10 : index + %c10_0 = arith.constant 10 : index + %c10_1 = arith.constant 10 : index + %0 = fir.alloca f32 {bindc_name = "sum_sliced_embox_i64", uniq_name = "_QFsum_sliced_embox_i64Esum_sliced_embox_i64"} + %1 = fir.alloca i64 {bindc_name = "sum_sliced_i64", uniq_name = "_QFsum_sliced_embox_i64Esum_sliced_i64"} + %c1 = arith.constant 1 : index + %c1_i64 = arith.constant 1 : i64 + %2 = fir.convert %c1_i64 : (i64) -> index + %3 = arith.addi %c1, %c10 : index + %4 = arith.subi %3, %c1 : index + %c1_i64_2 = arith.constant 1 : i64 + %5 = fir.convert %c1_i64_2 : (i64) -> index + %6 = arith.addi %c1, %c10_0 : index + %7 = arith.subi %6, %c1 : index + %c1_i64_3 = arith.constant 1 : i64 + %8 = fir.undefined index + %9 = fir.shape %c10, %c10_0, %c10_1 : (index, index, index) -> !fir.shape<3> + %10 = fir.slice %c1, %4, %2, %c1, %7, %5, %c1_i64_3, %8, %8 : (index, index, index, index, index, index, i64, index, index) -> !fir.slice<3> + %11 = fir.embox %arg0(%9) [%10] : (!fir.ref>, !fir.shape<3>, !fir.slice<3>) -> !fir.box> + %12 = fir.absent !fir.box + %c0 = arith.constant 0 : index + %13 = fir.address_of(@_QQcl.2E2F746573742E66393000) : !fir.ref> + %c3_i32 = arith.constant 3 : i32 + %14 = fir.convert %11 : (!fir.box>) -> !fir.box + %15 = fir.convert %13 : (!fir.ref>) -> !fir.ref + %16 = fir.convert %c0 : (index) -> i32 + %17 = fir.convert %12 : (!fir.box) -> !fir.box + %18 = fir.call @_FortranASumInteger8(%14, %15, %c3_i32, %16, %17) : (!fir.box, !fir.ref, i32, i32, !fir.box) -> i64 + fir.store %18 to %1 : !fir.ref + %19 = fir.load %0 : !fir.ref + return %19 : f32 +} +func.func private @_FortranASumInteger8(!fir.box, !fir.ref, i32, i32, !fir.box) -> i64 attributes {fir.runtime} +fir.global linkonce @_QQcl.2E2F746573742E66393000 constant : !fir.char<1,11> { + %0 = fir.string_lit "./test.f90\00"(11) : !fir.char<1,11> + fir.has_value %0 : !fir.char<1,11> +} + +// CHECK-NOT: call{{.*}}_FortranASumInteger8( +// CHECK: call @_FortranASumInteger8x2_simplified( +// CHECK-NOT: call{{.*}}_FortranASumInteger8( + +// ----- + +// SUM reduction of sliced assumed-shape array is replaced with +// 2D simplified implementation. +func.func @_QPsum_sliced_rebox_i64(%arg0: !fir.box> {fir.bindc_name = "a"}) -> f32 { + %0 = fir.alloca i64 {bindc_name = "sum_sliced_i64", uniq_name = "_QFsum_sliced_rebox_i64Esum_sliced_i64"} + %1 = fir.alloca f32 {bindc_name = "sum_sliced_rebox_i64", uniq_name = "_QFsum_sliced_rebox_i64Esum_sliced_rebox_i64"} + %c1 = arith.constant 1 : index + %c1_i64 = arith.constant 1 : i64 + %2 = fir.convert %c1_i64 : (i64) -> index + %c0 = arith.constant 0 : index + %3:3 = fir.box_dims %arg0, %c0 : (!fir.box>, index) -> (index, index, index) + %4 = arith.addi %c1, %3#1 : index + %5 = arith.subi %4, %c1 : index + %c1_i64_0 = arith.constant 1 : i64 + %6 = fir.convert %c1_i64_0 : (i64) -> index + %c1_1 = arith.constant 1 : index + %7:3 = fir.box_dims %arg0, %c1_1 : (!fir.box>, index) -> (index, index, index) + %8 = arith.addi %c1, %7#1 : index + %9 = arith.subi %8, %c1 : index + %c1_i64_2 = arith.constant 1 : i64 + %10 = fir.undefined index + %11 = fir.slice %c1, %5, %2, %c1, %9, %6, %c1_i64_2, %10, %10 : (index, index, index, index, index, index, i64, index, index) -> !fir.slice<3> + %12 = fir.rebox %arg0 [%11] : (!fir.box>, !fir.slice<3>) -> !fir.box> + %13 = fir.absent !fir.box + %c0_3 = arith.constant 0 : index + %14 = fir.address_of(@_QQcl.2E2F746573742E66393000) : !fir.ref> + %c8_i32 = arith.constant 8 : i32 + %15 = fir.convert %12 : (!fir.box>) -> !fir.box + %16 = fir.convert %14 : (!fir.ref>) -> !fir.ref + %17 = fir.convert %c0_3 : (index) -> i32 + %18 = fir.convert %13 : (!fir.box) -> !fir.box + %19 = fir.call @_FortranASumInteger8(%15, %16, %c8_i32, %17, %18) : (!fir.box, !fir.ref, i32, i32, !fir.box) -> i64 + fir.store %19 to %0 : !fir.ref + %20 = fir.load %1 : !fir.ref + return %20 : f32 +} +func.func private @_FortranASumInteger8(!fir.box, !fir.ref, i32, i32, !fir.box) -> i64 attributes {fir.runtime} +fir.global linkonce @_QQcl.2E2F746573742E66393000 constant : !fir.char<1,11> { + %0 = fir.string_lit "./test.f90\00"(11) : !fir.char<1,11> + fir.has_value %0 : !fir.char<1,11> +} + +// CHECK-NOT: call{{.*}}_FortranASumInteger8( +// CHECK: call @_FortranASumInteger8x2_simplified( +// CHECK-NOT: call{{.*}}_FortranASumInteger8(