forked from OSchip/llvm-project
[fir] Add affine demotion pass
Add affine demotion pass. Affine dialect's default lowering for loads and stores is different from fir as it uses the `memref` type. The `memref` type is not compatible with the Fortran runtime. Therefore, conversion of memory operations back to `fir.load` and `fir.store` with `!fir.ref<?>` types is required. This patch is part of the upstreaming effort from fir-dev branch. Co-authored-by: Jean Perier <jperier@nvidia.com> Co-authored-by: V Donaldson <vdonaldson@nvidia.com> Co-authored-by: Rajan Walia <walrajan@gmail.com> Co-authored-by: Sourabh Singh Tomar <SourabhSingh.Tomar@amd.com> Co-authored-by: Valentin Clement <clementval@gmail.com> Reviewed By: schweitz Differential Revision: https://reviews.llvm.org/D111257
This commit is contained in:
parent
b6426d5211
commit
80c27abb2f
|
@ -22,8 +22,13 @@ class Region;
|
|||
|
||||
namespace fir {
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Passes defined in Passes.td
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::unique_ptr<mlir::Pass> createAffineDemotionPass();
|
||||
std::unique_ptr<mlir::Pass> createExternalNameConversionPass();
|
||||
std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
|
||||
|
||||
/// Support for inlining on FIR.
|
||||
bool canLegallyInline(mlir::Operation *op, mlir::Region *reg,
|
||||
|
|
|
@ -41,6 +41,20 @@ def AffineDialectPromotion : FunctionPass<"promote-to-affine"> {
|
|||
];
|
||||
}
|
||||
|
||||
def AffineDialectDemotion : FunctionPass<"demote-affine"> {
|
||||
let summary = "Converts `affine.{load,store}` back to fir operations";
|
||||
let description = [{
|
||||
Affine dialect's default lowering for loads and stores is different from
|
||||
fir as it uses the `memref` type. The `memref` type is not compatible with
|
||||
the Fortran runtime. Therefore, conversion of memory operations back to
|
||||
`fir.load` and `fir.store` with `!fir.ref<?>` types is required.
|
||||
}];
|
||||
let constructor = "::fir::createAffineDemotionPass()";
|
||||
let dependentDialects = [
|
||||
"fir::FIROpsDialect", "mlir::StandardOpsDialect", "mlir::AffineDialect"
|
||||
];
|
||||
}
|
||||
|
||||
def ExternalNameConversion : Pass<"external-name-interop", "mlir::ModuleOp"> {
|
||||
let summary = "Convert name for external interoperability";
|
||||
let description = [{
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
//===-- AffineDemotion.cpp -----------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "flang/Optimizer/Dialect/FIRDialect.h"
|
||||
#include "flang/Optimizer/Dialect/FIROps.h"
|
||||
#include "flang/Optimizer/Dialect/FIRType.h"
|
||||
#include "flang/Optimizer/Transforms/Passes.h"
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Visitors.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "flang-affine-demotion"
|
||||
|
||||
using namespace fir;
|
||||
|
||||
namespace {
|
||||
|
||||
class AffineLoadConversion : public OpRewritePattern<mlir::AffineLoadOp> {
|
||||
public:
|
||||
using OpRewritePattern<mlir::AffineLoadOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::AffineLoadOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Value> indices(op.getMapOperands());
|
||||
auto maybeExpandedMap =
|
||||
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
|
||||
if (!maybeExpandedMap)
|
||||
return failure();
|
||||
|
||||
auto coorOp = rewriter.create<fir::CoordinateOp>(
|
||||
op.getLoc(), fir::ReferenceType::get(op.getResult().getType()),
|
||||
op.getMemRef(), *maybeExpandedMap);
|
||||
|
||||
rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class AffineStoreConversion : public OpRewritePattern<mlir::AffineStoreOp> {
|
||||
public:
|
||||
using OpRewritePattern<mlir::AffineStoreOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::AffineStoreOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Value> indices(op.getMapOperands());
|
||||
auto maybeExpandedMap =
|
||||
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
|
||||
if (!maybeExpandedMap)
|
||||
return failure();
|
||||
|
||||
auto coorOp = rewriter.create<fir::CoordinateOp>(
|
||||
op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()),
|
||||
op.getMemRef(), *maybeExpandedMap);
|
||||
rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValueToStore(),
|
||||
coorOp.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(fir::ConvertOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (op.res().getType().isa<mlir::MemRefType>()) {
|
||||
// due to index calculation moving to affine maps we still need to
|
||||
// add converts for sequence types this has a side effect of losing
|
||||
// some information about arrays with known dimensions by creating:
|
||||
// fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) ->
|
||||
// !fir.ref<!fir.array<?xi32>>
|
||||
if (auto refTy = op.value().getType().dyn_cast<fir::ReferenceType>())
|
||||
if (auto arrTy = refTy.getEleTy().dyn_cast<fir::SequenceType>()) {
|
||||
fir::SequenceType::Shape flatShape = {
|
||||
fir::SequenceType::getUnknownExtent()};
|
||||
auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy());
|
||||
auto flatTy = fir::ReferenceType::get(flatArrTy);
|
||||
rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy, op.value());
|
||||
return success();
|
||||
}
|
||||
rewriter.startRootUpdate(op->getParentOp());
|
||||
op.getResult().replaceAllUsesWith(op.value());
|
||||
rewriter.finalizeRootUpdate(op->getParentOp());
|
||||
rewriter.eraseOp(op);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
mlir::Type convertMemRef(mlir::MemRefType type) {
|
||||
return fir::SequenceType::get(
|
||||
SmallVector<int64_t>(type.getShape().begin(), type.getShape().end()),
|
||||
type.getElementType());
|
||||
}
|
||||
|
||||
class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(memref::AllocOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()),
|
||||
op.memref());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class AffineDialectDemotion
|
||||
: public AffineDialectDemotionBase<AffineDialectDemotion> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
auto *context = &getContext();
|
||||
auto function = getFunction();
|
||||
LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n";
|
||||
function.print(llvm::dbgs()););
|
||||
|
||||
mlir::OwningRewritePatternList patterns(context);
|
||||
patterns.insert<ConvertConversion>(context);
|
||||
patterns.insert<AffineLoadConversion>(context);
|
||||
patterns.insert<AffineStoreConversion>(context);
|
||||
patterns.insert<StdAllocConversion>(context);
|
||||
mlir::ConversionTarget target(*context);
|
||||
target.addIllegalOp<memref::AllocOp>();
|
||||
target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) {
|
||||
if (op.res().getType().isa<mlir::MemRefType>())
|
||||
return false;
|
||||
return true;
|
||||
});
|
||||
target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
|
||||
mlir::StandardOpsDialect>();
|
||||
|
||||
if (mlir::failed(mlir::applyPartialConversion(function, target,
|
||||
std::move(patterns)))) {
|
||||
mlir::emitError(mlir::UnknownLoc::get(context),
|
||||
"error in converting affine dialect\n");
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() {
|
||||
return std::make_unique<AffineDialectDemotion>();
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
add_flang_library(FIRTransforms
|
||||
AffinePromotion.cpp
|
||||
AffineDemotion.cpp
|
||||
Inliner.cpp
|
||||
ExternalNameConversion.cpp
|
||||
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
// Test affine demotion pass
|
||||
|
||||
// RUN: fir-opt --split-input-file --demote-affine %s | FileCheck %s
|
||||
|
||||
#map0 = affine_map<()[s0, s1] -> (s1 - s0 + 1)>
|
||||
#map1 = affine_map<()[s0] -> (s0 + 1)>
|
||||
#map2 = affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)>
|
||||
module {
|
||||
func @calc(%arg0: !fir.ref<!fir.array<?xf32>>, %arg1: !fir.ref<!fir.array<?xf32>>, %arg2: !fir.ref<!fir.array<?xf32>>) {
|
||||
%c1 = constant 1 : index
|
||||
%c100 = constant 100 : index
|
||||
%0 = fir.shape %c100 : (index) -> !fir.shape<1>
|
||||
%1 = affine.apply #map0()[%c1, %c100]
|
||||
%2 = fir.alloca !fir.array<?xf32>, %1
|
||||
%3 = fir.convert %arg0 : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
|
||||
%4 = fir.convert %arg1 : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
|
||||
%5 = fir.convert %2 : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
|
||||
affine.for %arg3 = %c1 to #map1()[%c100] {
|
||||
%7 = affine.apply #map2(%arg3)[%c1, %c100, %c1]
|
||||
%8 = affine.load %3[%7] : memref<?xf32>
|
||||
%9 = affine.load %4[%7] : memref<?xf32>
|
||||
%10 = addf %8, %9 : f32
|
||||
affine.store %10, %5[%7] : memref<?xf32>
|
||||
}
|
||||
%6 = fir.convert %arg2 : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
|
||||
affine.for %arg3 = %c1 to #map1()[%c100] {
|
||||
%7 = affine.apply #map2(%arg3)[%c1, %c100, %c1]
|
||||
%8 = affine.load %5[%7] : memref<?xf32>
|
||||
%9 = affine.load %4[%7] : memref<?xf32>
|
||||
%10 = mulf %8, %9 : f32
|
||||
affine.store %10, %6[%7] : memref<?xf32>
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: func @calc(%[[VAL_0:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_1:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_2:.*]]: !fir.ref<!fir.array<?xf32>>) {
|
||||
// CHECK: %[[VAL_3:.*]] = constant 1 : index
|
||||
// CHECK: %[[VAL_4:.*]] = constant 100 : index
|
||||
// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
|
||||
// CHECK: %[[VAL_6:.*]] = constant 100 : index
|
||||
// CHECK: %[[VAL_7:.*]] = fir.alloca !fir.array<?xf32>, %[[VAL_6]]
|
||||
// CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
|
||||
// CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
|
||||
// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_7]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
|
||||
// CHECK: affine.for %[[VAL_11:.*]] = 1 to 101 {
|
||||
// CHECK: %[[VAL_12:.*]] = affine.apply #map(%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
|
||||
// CHECK: %[[VAL_13:.*]] = fir.coordinate_of %[[VAL_8]], %[[VAL_12]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
|
||||
// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<f32>
|
||||
// CHECK: %[[VAL_15:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_12]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
|
||||
// CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<f32>
|
||||
// CHECK: %[[VAL_17:.*]] = addf %[[VAL_14]], %[[VAL_16]] : f32
|
||||
// CHECK: %[[VAL_18:.*]] = fir.coordinate_of %[[VAL_10]], %[[VAL_12]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
|
||||
// CHECK: fir.store %[[VAL_17]] to %[[VAL_18]] : !fir.ref<f32>
|
||||
// CHECK: }
|
||||
// CHECK: %[[VAL_19:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
|
||||
// CHECK: affine.for %[[VAL_20:.*]] = 1 to 101 {
|
||||
// CHECK: %[[VAL_21:.*]] = affine.apply #map(%[[VAL_20]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
|
||||
// CHECK: %[[VAL_22:.*]] = fir.coordinate_of %[[VAL_10]], %[[VAL_21]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
|
||||
// CHECK: %[[VAL_23:.*]] = fir.load %[[VAL_22]] : !fir.ref<f32>
|
||||
// CHECK: %[[VAL_24:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_21]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
|
||||
// CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_24]] : !fir.ref<f32>
|
||||
// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_23]], %[[VAL_25]] : f32
|
||||
// CHECK: %[[VAL_27:.*]] = fir.coordinate_of %[[VAL_19]], %[[VAL_21]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
|
||||
// CHECK: fir.store %[[VAL_26]] to %[[VAL_27]] : !fir.ref<f32>
|
||||
// CHECK: }
|
||||
// CHECK: return
|
||||
// CHECK: }
|
Loading…
Reference in New Issue