forked from OSchip/llvm-project
[mlir] Move ComposeSubView+ExpandOps from Standard to MemRef
These transformations already operate on memref operations (as part of splitting up the standard dialect). Now that the operations have moved, it's time for these transformations to move as well. Differential Revision: https://reviews.llvm.org/D118285
This commit is contained in:
parent
1372d53639
commit
7d0426dd95
|
@ -1,4 +1,4 @@
|
|||
//===- ComposeSubView.h - Combining composed subview ops --------*- C++ -*-===//
|
||||
//===- ComposeSubView.h - Combining composed memref ops ---------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -10,19 +10,20 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
|
||||
#define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
|
||||
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_
|
||||
#define MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// Forward declarations.
|
||||
class MLIRContext;
|
||||
class RewritePatternSet;
|
||||
using OwningRewritePatternList = RewritePatternSet;
|
||||
|
||||
namespace memref {
|
||||
|
||||
void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
|
||||
MLIRContext *context);
|
||||
|
||||
} // namespace memref
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
|
||||
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_
|
|
@ -18,6 +18,7 @@
|
|||
namespace mlir {
|
||||
|
||||
class AffineDialect;
|
||||
class StandardOpsDialect;
|
||||
namespace tensor {
|
||||
class TensorDialect;
|
||||
} // namespace tensor
|
||||
|
@ -31,6 +32,9 @@ namespace memref {
|
|||
// Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Collects a set of patterns to rewrite ops within the memref dialect.
|
||||
void populateExpandOpsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns for folding memref.subview ops into consumer load/store ops
|
||||
/// into `patterns`.
|
||||
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
|
||||
|
@ -51,6 +55,11 @@ void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
|
|||
// Passes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Creates an instance of the ExpandOps pass that legalizes memref dialect ops
|
||||
/// to be convertible to LLVM. For example, `memref.reshape` gets converted to
|
||||
/// `memref_reinterpret_cast`.
|
||||
std::unique_ptr<Pass> createExpandOpsPass();
|
||||
|
||||
/// Creates an operation pass to fold memref.subview ops into consumer
|
||||
/// load/store ops into `patterns`.
|
||||
std::unique_ptr<Pass> createFoldSubViewOpsPass();
|
||||
|
|
|
@ -11,6 +11,12 @@
|
|||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def ExpandOps : Pass<"memref-expand", "FuncOp"> {
|
||||
let summary = "Legalize memref operations to be convertible to LLVM.";
|
||||
let constructor = "mlir::memref::createExpandOpsPass()";
|
||||
let dependentDialects = ["StandardOpsDialect"];
|
||||
}
|
||||
|
||||
def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
|
||||
let summary = "Fold memref.subview ops into consumer load/store ops";
|
||||
let description = [{
|
||||
|
|
|
@ -45,16 +45,6 @@ void populateTensorConstantBufferizePatterns(
|
|||
/// Creates an instance of tensor constant bufferization pass.
|
||||
std::unique_ptr<Pass> createTensorConstantBufferizePass(unsigned alignment = 0);
|
||||
|
||||
/// Creates an instance of the StdExpand pass that legalizes Std
|
||||
/// dialect ops to be convertible to LLVM. For example,
|
||||
/// `std.arith.ceildivsi` gets transformed to a number of std operations,
|
||||
/// which can be lowered to LLVM; `memref.reshape` gets converted to
|
||||
/// `memref_reinterpret_cast`.
|
||||
std::unique_ptr<Pass> createStdExpandOpsPass();
|
||||
|
||||
/// Collects a set of patterns to rewrite ops within the Std dialect.
|
||||
void populateStdExpandOpsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -18,11 +18,6 @@ def StdBufferize : Pass<"std-bufferize", "FuncOp"> {
|
|||
"memref::MemRefDialect", "scf::SCFDialect"];
|
||||
}
|
||||
|
||||
def StdExpandOps : Pass<"std-expand", "FuncOp"> {
|
||||
let summary = "Legalize std operations to be convertible to LLVM.";
|
||||
let constructor = "mlir::createStdExpandOpsPass()";
|
||||
}
|
||||
|
||||
def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
|
||||
let summary = "Bufferize func/call/return ops";
|
||||
let description = [{
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
add_mlir_dialect_library(MLIRMemRefTransforms
|
||||
ComposeSubView.cpp
|
||||
ExpandOps.cpp
|
||||
FoldSubViewOps.cpp
|
||||
NormalizeMemRefs.cpp
|
||||
ResolveShapedTypeResultDims.cpp
|
||||
|
|
|
@ -11,8 +11,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
|
||||
|
||||
#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
|
@ -21,7 +20,7 @@
|
|||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
namespace mlir {
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -128,9 +127,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
|
|||
|
||||
} // namespace
|
||||
|
||||
void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
|
||||
MLIRContext *context) {
|
||||
void mlir::memref::populateComposeSubViewPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.insert<ComposeSubViewOpPattern>(context);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
|
@ -17,8 +17,8 @@
|
|||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
|
@ -120,13 +120,13 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
|
||||
struct ExpandOpsPass : public ExpandOpsBase<ExpandOpsPass> {
|
||||
void runOnOperation() override {
|
||||
MLIRContext &ctx = getContext();
|
||||
|
||||
RewritePatternSet patterns(&ctx);
|
||||
populateStdExpandOpsPatterns(patterns);
|
||||
ConversionTarget target(getContext());
|
||||
memref::populateExpandOpsPatterns(patterns);
|
||||
ConversionTarget target(ctx);
|
||||
|
||||
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect>();
|
||||
|
@ -146,11 +146,11 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
|
|||
|
||||
} // namespace
|
||||
|
||||
void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) {
|
||||
void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
|
||||
patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {
|
||||
return std::make_unique<StdExpandOpsPass>();
|
||||
std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
|
||||
return std::make_unique<ExpandOpsPass>();
|
||||
}
|
|
@ -14,6 +14,7 @@
|
|||
namespace mlir {
|
||||
|
||||
class AffineDialect;
|
||||
class StandardOpsDialect;
|
||||
|
||||
// Forward declaration from Dialect.h
|
||||
template <typename ConcreteDialect>
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
add_mlir_dialect_library(MLIRStandardOpsTransforms
|
||||
Bufferize.cpp
|
||||
ComposeSubView.cpp
|
||||
DecomposeCallGraphTypes.cpp
|
||||
ExpandOps.cpp
|
||||
FuncBufferize.cpp
|
||||
FuncConversions.cpp
|
||||
TensorConstantBufferize.cpp
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt -std-expand %s -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt -memref-expand %s -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @atomic_rmw_to_generic
|
||||
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
|
|
@ -3,6 +3,7 @@ add_subdirectory(DLTI)
|
|||
add_subdirectory(GPU)
|
||||
add_subdirectory(Linalg)
|
||||
add_subdirectory(Math)
|
||||
add_subdirectory(MemRef)
|
||||
add_subdirectory(SCF)
|
||||
add_subdirectory(Shape)
|
||||
add_subdirectory(SPIRV)
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Exclude tests from libMLIR.so
|
||||
add_mlir_library(MLIRMemRefTestPasses
|
||||
TestComposeSubView.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRPass
|
||||
MLIRMemRefTransforms
|
||||
MLIRTestDialect
|
||||
)
|
||||
|
||||
target_include_directories(MLIRMemRefTestPasses
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../Test
|
||||
${CMAKE_CURRENT_BINARY_DIR}/../Test
|
||||
)
|
|
@ -11,7 +11,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
|
@ -35,7 +35,7 @@ void TestComposeSubViewPass::getDependentDialects(
|
|||
|
||||
void TestComposeSubViewPass::runOnOperation() {
|
||||
OwningRewritePatternList patterns(&getContext());
|
||||
populateComposeSubViewPatterns(patterns, &getContext());
|
||||
memref::populateComposeSubViewPatterns(patterns, &getContext());
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
} // namespace
|
|
@ -1,7 +1,6 @@
|
|||
# Exclude tests from libMLIR.so
|
||||
add_mlir_library(MLIRStandardOpsTestPasses
|
||||
TestDecomposeCallGraphTypes.cpp
|
||||
TestComposeSubView.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -convert-scf-to-std -std-expand -convert-arith-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts \
|
||||
// RUN: mlir-opt %s -convert-scf-to-std -memref-expand -convert-arith-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts \
|
||||
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
|
||||
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
|
||||
// RUN: | FileCheck %s
|
||||
|
|
|
@ -18,6 +18,7 @@ if(MLIR_INCLUDE_TESTS)
|
|||
MLIRGPUTestPasses
|
||||
MLIRLinalgTestPasses
|
||||
MLIRMathTestPasses
|
||||
MLIRMemRefTestPasses
|
||||
MLIRSCFTestPasses
|
||||
MLIRShapeTestPasses
|
||||
MLIRSPIRVTestPasses
|
||||
|
|
Loading…
Reference in New Issue