[mlir] Move memref.subview patterns to MemRef/Transforms/

These patterns have been used as a prerequisite step for lowering
to SPIR-V. But they don't involve SPIR-V dialect ops; they are
pure memref/vector op transformations. Given now we have a dedicated
MemRef dialect, moving them to Memref/Transforms/, which is a more
suitable place to host them, to allow used by others.

This commit just moves code around and renames patterns/passes
accordingly. CMakeLists.txt for existing MemRef libraries are
also improved along the way.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D100326
This commit is contained in:
Lei Zhang 2021-04-12 16:38:04 -04:00
parent 05df5c54e8
commit 0deeaaca39
18 changed files with 186 additions and 159 deletions

View File

@ -410,17 +410,6 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
// StandardToSPIRV
//===----------------------------------------------------------------------===//
def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> {
let summary = "Legalize standard ops for SPIR-V lowering";
let description = [{
The pass contains certain intra standard op conversions that are meant for
lowering to SPIR-V ops, e.g., folding subviews loads/stores to the original
loads/stores from/to the original memref.
}];
let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
}
def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
let summary = "Convert Standard dialect to SPIR-V dialect";
let constructor = "mlir::createConvertStandardToSPIRVPass()";

View File

@ -40,11 +40,6 @@ void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
int64_t byteCountThreshold,
RewritePatternSet &patterns);
/// Appends to a pattern list patterns to legalize ops that are not directly
/// lowered to SPIR-V.
void populateStdLegalizationPatternsForSPIRVLowering(
RewritePatternSet &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRV_H

View File

@ -20,9 +20,6 @@ namespace mlir {
/// Creates a pass to convert standard ops to SPIR-V ops.
std::unique_ptr<OperationPass<ModuleOp>> createConvertStandardToSPIRVPass();
/// Creates a pass to legalize ops that are not directly lowered to SPIR-V.
std::unique_ptr<Pass> createLegalizeStdOpsForSPIRVLoweringPass();
} // namespace mlir
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRVPASS_H

View File

@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name MemRef)
add_public_tablegen_target(MLIRMemRefPassIncGen)
add_dependencies(mlir-headers MLIRMemRefPassIncGen)
add_mlir_doc(Passes -gen-pass-doc MemRefPasses ./)

View File

@ -0,0 +1,47 @@
//===- Passes.h - MemRef Patterns and Passes --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header declares patterns and passes on MemRef operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace memref {
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
/// Appends patterns for folding memref.subview ops into consumer load/store ops
/// into `patterns`.
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
/// Creates an operation pass to fold memref.subview ops into consumer
/// load/store ops into `patterns`.
std::unique_ptr<Pass> createFoldSubViewOpsPass();
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H

View File

@ -0,0 +1,26 @@
//===-- Passes.td - MemRef transformation definition file --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
include "mlir/Pass/PassBase.td"
def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
let summary = "Fold memref.subview ops into consumer load/store ops";
let description = [{
The pass folds loading/storing from/to subview ops to loading/storing
from/to the original memref.
}];
let constructor = "mlir::memref::createFoldSubViewOpsPass()";
let dependentDialects = ["memref::MemRefDialect", "vector::VectorDialect"];
}
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES

View File

@ -20,6 +20,7 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@ -55,6 +56,7 @@ inline void registerAllPasses() {
registerGpuSerializeToHsacoPass();
registerLinalgPasses();
LLVM::registerLLVMPasses();
memref::registerMemRefPasses();
quant::registerQuantPasses();
registerSCFPasses();
registerShapePasses();

View File

@ -1,5 +1,4 @@
add_mlir_conversion_library(MLIRStandardToSPIRV
LegalizeStandardForSPIRV.cpp
StandardToSPIRV.cpp
StandardToSPIRVPass.cpp

View File

@ -1,23 +1,3 @@
add_mlir_dialect_library(MLIRMemRef
IR/MemRefDialect.cpp
IR/MemRefOps.cpp
Utils/MemRefUtils.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
DEPENDS
MLIRStandardOpsIncGen
MLIRMemRefOpsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRStandard
MLIRTensor
MLIRViewLikeInterface
)
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)

View File

@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRMemRef
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRMemRefUtils
MLIRStandard
MLIRTensor
MLIRViewLikeInterface

View File

@ -0,0 +1,17 @@
add_mlir_dialect_library(MLIRMemRefTransforms
FoldSubViewOps.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
DEPENDS
MLIRMemRefPassIncGen
LINK_LIBS PUBLIC
MLIRMemRef
MLIRPass
MLIRStandard
MLIRTransforms
MLIRVector
)

View File

@ -1,4 +1,4 @@
//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,16 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
// This transformation pass legalizes operations before the conversion to SPIR-V
// dialect to handle ops that cannot be lowered directly.
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
#include "../PassDetail.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
@ -23,6 +20,49 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Given the 'indices' of an load/store operation where the memref is a result
/// of a subview op, returns the indices w.r.t to the source memref of the
/// subview op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
/// memref<4x4xf32, offset=?, strides=[?, ?]>
/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into
///
/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
/// memref<12x42xf32>
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
memref::SubViewOp subViewOp, ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
// TODO: Aborting when the offsets are static. There might be a way to fold
// the subview op with load even if the offsets have been canonicalized
// away.
SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
assert(opRanges.size() == indices.size() &&
"expected as many indices as rank of subview op result type");
// New indices for the load are the current indices * subview_stride +
// subview_offset.
sourceIndices.resize(indices.size());
for (auto index : llvm::enumerate(indices)) {
auto offset = *(opOffsets.begin() + index.index());
auto stride = *(opStrides.begin() + index.index());
auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
sourceIndices[index.index()] =
rewriter.create<AddIOp>(loc, offset, mul).getResult();
}
return success();
}
/// Helpers to access the memref operand for each op.
static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
@ -34,6 +74,10 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
return op.source();
}
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
@ -101,62 +145,15 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
}
} // namespace
//===----------------------------------------------------------------------===//
// Utility functions for op legalization.
//===----------------------------------------------------------------------===//
/// Given the 'indices' of an load/store operation where the memref is a result
/// of a subview op, returns the indices w.r.t to the source memref of the
/// subview op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
/// memref<4x4xf32, offset=?, strides=[?, ?]>
/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into
///
/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
/// memref<12x42xf32>
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
memref::SubViewOp subViewOp, ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
// TODO: Aborting when the offsets are static. There might be a way to fold
// the subview op with load even if the offsets have been canonicalized
// away.
SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
assert(opRanges.size() == indices.size() &&
"expected as many indices as rank of subview op result type");
// New indices for the load are the current indices * subview_stride +
// subview_offset.
sourceIndices.resize(indices.size());
for (auto index : llvm::enumerate(indices)) {
auto offset = *(opOffsets.begin() + index.index());
auto stride = *(opStrides.begin() + index.index());
auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
sourceIndices[index.index()] =
rewriter.create<AddIOp>(loc, offset, mul).getResult();
}
return success();
}
//===----------------------------------------------------------------------===//
// Folding SubViewOp and LoadOp/TransferReadOp.
//===----------------------------------------------------------------------===//
template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const {
auto subViewOp =
getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
if (!subViewOp) {
if (!subViewOp)
return failure();
}
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
loadOp.indices(), sourceIndices)))
@ -166,19 +163,15 @@ LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
return success();
}
//===----------------------------------------------------------------------===//
// Folding SubViewOp and StoreOp/TransferWriteOp.
//===----------------------------------------------------------------------===//
template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const {
auto subViewOp =
getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
if (!subViewOp) {
if (!subViewOp)
return failure();
}
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
storeOp.indices(), sourceIndices)))
@ -188,12 +181,7 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
return success();
}
//===----------------------------------------------------------------------===//
// Hook for adding patterns.
//===----------------------------------------------------------------------===//
void mlir::populateStdLegalizationPatternsForSPIRVLowering(
RewritePatternSet &patterns) {
void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
LoadOpOfSubViewFolder<vector::TransferReadOp>,
StoreOpOfSubViewFolder<memref::StoreOp>,
@ -202,23 +190,28 @@ void mlir::populateStdLegalizationPatternsForSPIRVLowering(
}
//===----------------------------------------------------------------------===//
// Pass for testing just the legalization patterns.
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
struct SPIRVLegalization final
: public LegalizeStandardForSPIRVBase<SPIRVLegalization> {
#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
struct FoldSubViewOpsPass final
: public FoldSubViewOpsBase<FoldSubViewOpsPass> {
void runOnOperation() override;
};
} // namespace
void SPIRVLegalization::runOnOperation() {
void FoldSubViewOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateStdLegalizationPatternsForSPIRVLowering(patterns);
memref::populateFoldSubViewOpPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
}
std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
return std::make_unique<SPIRVLegalization>();
std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
return std::make_unique<FoldSubViewOpsPass>();
}

View File

@ -0,0 +1,11 @@
add_mlir_dialect_library(MLIRMemRefUtils
MemRefUtils.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRSideEffectInterfaces
)

View File

@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
using namespace mlir;

View File

@ -1,38 +0,0 @@
// RUN: mlir-opt -legalize-std-for-spirv %s -o - | FileCheck %s
module {
//===----------------------------------------------------------------------===//
// memref.subview
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @fold_static_stride_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: index
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: index
func @fold_static_stride_subview
(%arg0 : memref<12x32xf32>, %arg1 : index,
%arg2 : index, %arg3 : index, %arg4 : index) {
// CHECK-DAG: %[[C2:.*]] = constant 2
// CHECK-DAG: %[[C3:.*]] = constant 3
// CHECK: %[[T0:.*]] = muli %[[ARG3]], %[[C3]]
// CHECK: %[[T1:.*]] = addi %[[ARG1]], %[[T0]]
// CHECK: %[[T2:.*]] = muli %[[ARG4]], %[[ARG2]]
// CHECK: %[[T3:.*]] = addi %[[T2]], %[[C2]]
// CHECK: %[[LOADVAL:.*]] = memref.load %[[ARG0]][%[[T1]], %[[T3]]]
// CHECK: %[[STOREVAL:.*]] = math.sqrt %[[LOADVAL]]
// CHECK: %[[T6:.*]] = muli %[[ARG3]], %[[C3]]
// CHECK: %[[T7:.*]] = addi %[[ARG1]], %[[T6]]
// CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[ARG2]]
// CHECK: %[[T9:.*]] = addi %[[T8]], %[[C2]]
// CHECK: memref.store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]]
%0 = memref.subview %arg0[%arg1, 2][4, 4][3, %arg2] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [96, ?]>
%1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]>
%2 = math.sqrt %1 : f32
memref.store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]>
return
}
} // end module

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -legalize-std-for-spirv -verify-diagnostics %s -o - | FileCheck %s
// RUN: mlir-opt -fold-memref-subview-ops -verify-diagnostics %s -o - | FileCheck %s
// CHECK-LABEL: @fold_static_stride_subview_with_load
// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index

View File

@ -20,6 +20,7 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@ -40,7 +41,7 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
applyPassManagerCLOptions(passManager);
passManager.addPass(createGpuKernelOutliningPass());
passManager.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
passManager.addPass(memref::createFoldSubViewOpsPass());
passManager.addPass(createConvertGPUToSPIRVPass());
OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
modulePM.addPass(spirv::createLowerABIAttributesPass());