diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h b/mlir/include/mlir/Dialect/MemRef/Transforms/ComposeSubView.h similarity index 71% rename from mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h rename to mlir/include/mlir/Dialect/MemRef/Transforms/ComposeSubView.h index 7a5ae3e8417b..20aa1c02db17 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/ComposeSubView.h @@ -1,4 +1,4 @@ -//===- ComposeSubView.h - Combining composed subview ops --------*- C++ -*-===// +//===- ComposeSubView.h - Combining composed memref ops ---------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -10,19 +10,20 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_ -#define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_ +#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_ +#define MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_ namespace mlir { - -// Forward declarations. class MLIRContext; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; +namespace memref { + void populateComposeSubViewPatterns(OwningRewritePatternList &patterns, MLIRContext *context); +} // namespace memref } // namespace mlir -#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_ +#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_COMPOSESUBVIEW_H_ diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h index 23d12508b65c..6e9f6fb10a66 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -18,6 +18,7 @@ namespace mlir { class AffineDialect; +class StandardOpsDialect; namespace tensor { class TensorDialect; } // namespace tensor @@ -31,6 +32,9 @@ namespace memref { // Patterns //===----------------------------------------------------------------------===// +/// Collects a set of patterns to rewrite ops within the memref dialect. +void populateExpandOpsPatterns(RewritePatternSet &patterns); + /// Appends patterns for folding memref.subview ops into consumer load/store ops /// into `patterns`. void populateFoldSubViewOpPatterns(RewritePatternSet &patterns); @@ -51,6 +55,11 @@ void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns); // Passes //===----------------------------------------------------------------------===// +/// Creates an instance of the ExpandOps pass that legalizes memref dialect ops +/// to be convertible to LLVM. For example, `memref.reshape` gets converted to +/// `memref_reinterpret_cast`. +std::unique_ptr createExpandOpsPass(); + /// Creates an operation pass to fold memref.subview ops into consumer /// load/store ops into `patterns`. std::unique_ptr createFoldSubViewOpsPass(); diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td index d67746b9c603..0f2e3a91a255 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -11,6 +11,12 @@ include "mlir/Pass/PassBase.td" +def ExpandOps : Pass<"memref-expand", "FuncOp"> { + let summary = "Legalize memref operations to be convertible to LLVM."; + let constructor = "mlir::memref::createExpandOpsPass()"; + let dependentDialects = ["StandardOpsDialect"]; +} + def FoldSubViewOps : Pass<"fold-memref-subview-ops"> { let summary = "Fold memref.subview ops into consumer load/store ops"; let description = [{ diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h index b47303c70250..3f723133b2a6 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -45,16 +45,6 @@ void populateTensorConstantBufferizePatterns( /// Creates an instance of tensor constant bufferization pass. std::unique_ptr createTensorConstantBufferizePass(unsigned alignment = 0); -/// Creates an instance of the StdExpand pass that legalizes Std -/// dialect ops to be convertible to LLVM. For example, -/// `std.arith.ceildivsi` gets transformed to a number of std operations, -/// which can be lowered to LLVM; `memref.reshape` gets converted to -/// `memref_reinterpret_cast`. -std::unique_ptr createStdExpandOpsPass(); - -/// Collects a set of patterns to rewrite ops within the Std dialect. -void populateStdExpandOpsPatterns(RewritePatternSet &patterns); - //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td index 8ac2b64a8006..339c1b1194cc 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -18,11 +18,6 @@ def StdBufferize : Pass<"std-bufferize", "FuncOp"> { "memref::MemRefDialect", "scf::SCFDialect"]; } -def StdExpandOps : Pass<"std-expand", "FuncOp"> { - let summary = "Legalize std operations to be convertible to LLVM."; - let constructor = "mlir::createStdExpandOpsPass()"; -} - def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { let summary = "Bufferize func/call/return ops"; let description = [{ diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt index 319f9bbb95a3..99bc552548f9 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,4 +1,6 @@ add_mlir_dialect_library(MLIRMemRefTransforms + ComposeSubView.cpp + ExpandOps.cpp FoldSubViewOps.cpp NormalizeMemRefs.cpp ResolveShapedTypeResultDims.cpp diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp similarity index 96% rename from mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp rename to mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp index cabaa614fa2a..29ba5060d167 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -11,8 +11,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h" - +#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinAttributes.h" @@ -21,7 +20,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -namespace mlir { +using namespace mlir; namespace { @@ -128,9 +127,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern { } // namespace -void populateComposeSubViewPatterns(OwningRewritePatternList &patterns, - MLIRContext *context) { +void mlir::memref::populateComposeSubViewPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } - -} // namespace mlir diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp similarity index 93% rename from mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp rename to mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp index e5cf08da4904..293fb58d4e70 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -17,8 +17,8 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" @@ -120,13 +120,13 @@ public: } }; -struct StdExpandOpsPass : public StdExpandOpsBase { +struct ExpandOpsPass : public ExpandOpsBase { void runOnOperation() override { MLIRContext &ctx = getContext(); RewritePatternSet patterns(&ctx); - populateStdExpandOpsPatterns(patterns); - ConversionTarget target(getContext()); + memref::populateExpandOpsPatterns(patterns); + ConversionTarget target(ctx); target.addLegalDialect(); @@ -146,11 +146,11 @@ struct StdExpandOpsPass : public StdExpandOpsBase { } // namespace -void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) { +void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) { patterns.add( patterns.getContext()); } -std::unique_ptr mlir::createStdExpandOpsPass() { - return std::make_unique(); +std::unique_ptr mlir::memref::createExpandOpsPass() { + return std::make_unique(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h index d15631526817..d1e5baa798fd 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h @@ -14,6 +14,7 @@ namespace mlir { class AffineDialect; +class StandardOpsDialect; // Forward declaration from Dialect.h template diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt index 82e25923840d..f8082601b48b 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -1,8 +1,6 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms Bufferize.cpp - ComposeSubView.cpp DecomposeCallGraphTypes.cpp - ExpandOps.cpp FuncBufferize.cpp FuncConversions.cpp TensorConstantBufferize.cpp diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir similarity index 96% rename from mlir/test/Dialect/Standard/expand-ops.mlir rename to mlir/test/Dialect/MemRef/expand-ops.mlir index 2a1c367ff80f..bcf83042184f 100644 --- a/mlir/test/Dialect/Standard/expand-ops.mlir +++ b/mlir/test/Dialect/MemRef/expand-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -std-expand %s -split-input-file | FileCheck %s +// RUN: mlir-opt -memref-expand %s -split-input-file | FileCheck %s // CHECK-LABEL: func @atomic_rmw_to_generic // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt index d8219057f939..d02dc61045fa 100644 --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(DLTI) add_subdirectory(GPU) add_subdirectory(Linalg) add_subdirectory(Math) +add_subdirectory(MemRef) add_subdirectory(SCF) add_subdirectory(Shape) add_subdirectory(SPIRV) diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt new file mode 100644 index 000000000000..c43dec48bdab --- /dev/null +++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt @@ -0,0 +1,17 @@ +# Exclude tests from libMLIR.so +add_mlir_library(MLIRMemRefTestPasses + TestComposeSubView.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRPass + MLIRMemRefTransforms + MLIRTestDialect + ) + +target_include_directories(MLIRMemRefTestPasses + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../Test + ${CMAKE_CURRENT_BINARY_DIR}/../Test + ) diff --git a/mlir/test/lib/Dialect/StandardOps/TestComposeSubView.cpp b/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp similarity index 92% rename from mlir/test/lib/Dialect/StandardOps/TestComposeSubView.cpp rename to mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp index 1638ee3debe0..20add4cc94c8 100644 --- a/mlir/test/lib/Dialect/StandardOps/TestComposeSubView.cpp +++ b/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp @@ -11,7 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h" +#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -35,7 +35,7 @@ void TestComposeSubViewPass::getDependentDialects( void TestComposeSubViewPass::runOnOperation() { OwningRewritePatternList patterns(&getContext()); - populateComposeSubViewPatterns(patterns, &getContext()); + memref::populateComposeSubViewPatterns(patterns, &getContext()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } } // namespace diff --git a/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt index 22d5818e3413..b85de09e10bf 100644 --- a/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/test/lib/Dialect/StandardOps/CMakeLists.txt @@ -1,7 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRStandardOpsTestPasses TestDecomposeCallGraphTypes.cpp - TestComposeSubView.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/mlir-cpu-runner/memref-reshape.mlir b/mlir/test/mlir-cpu-runner/memref-reshape.mlir index 6d0397399cca..e74d6219a1f3 100644 --- a/mlir/test/mlir-cpu-runner/memref-reshape.mlir +++ b/mlir/test/mlir-cpu-runner/memref-reshape.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-scf-to-std -std-expand -convert-arith-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts \ +// RUN: mlir-opt %s -convert-scf-to-std -memref-expand -convert-arith-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts \ // RUN: | mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 7a16a497b6f8..c03d6403a74e 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -18,6 +18,7 @@ if(MLIR_INCLUDE_TESTS) MLIRGPUTestPasses MLIRLinalgTestPasses MLIRMathTestPasses + MLIRMemRefTestPasses MLIRSCFTestPasses MLIRShapeTestPasses MLIRSPIRVTestPasses