forked from OSchip/llvm-project
[mlir] Move memref.subview patterns to MemRef/Transforms/
These patterns have been used as a prerequisite step for lowering to SPIR-V, but they don't involve SPIR-V dialect ops; they are pure memref/vector op transformations. Now that we have a dedicated MemRef dialect, move them to MemRef/Transforms/, which is a more suitable place to host them, so that they can be used by others. This commit just moves code around and renames patterns/passes accordingly. The CMakeLists.txt files for the existing MemRef libraries are also improved along the way. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D100326
This commit is contained in:
parent
05df5c54e8
commit
0deeaaca39
|
@ -410,17 +410,6 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
|
|||
// StandardToSPIRV
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> {
|
||||
let summary = "Legalize standard ops for SPIR-V lowering";
|
||||
let description = [{
|
||||
The pass contains certain intra standard op conversions that are meant for
|
||||
lowering to SPIR-V ops, e.g., folding subviews loads/stores to the original
|
||||
loads/stores from/to the original memref.
|
||||
}];
|
||||
let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()";
|
||||
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||
}
|
||||
|
||||
def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
|
||||
let summary = "Convert Standard dialect to SPIR-V dialect";
|
||||
let constructor = "mlir::createConvertStandardToSPIRVPass()";
|
||||
|
|
|
@ -40,11 +40,6 @@ void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|||
int64_t byteCountThreshold,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Appends to a pattern list patterns to legalize ops that are not directly
|
||||
/// lowered to SPIR-V.
|
||||
void populateStdLegalizationPatternsForSPIRVLowering(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRV_H
|
||||
|
|
|
@ -20,9 +20,6 @@ namespace mlir {
|
|||
/// Creates a pass to convert standard ops to SPIR-V ops.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertStandardToSPIRVPass();
|
||||
|
||||
/// Creates a pass to legalize ops that are not directly lowered to SPIR-V.
|
||||
std::unique_ptr<Pass> createLegalizeStdOpsForSPIRVLoweringPass();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_STANDARDTOSPIRVPASS_H
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name MemRef)
|
||||
add_public_tablegen_target(MLIRMemRefPassIncGen)
|
||||
add_dependencies(mlir-headers MLIRMemRefPassIncGen)
|
||||
|
||||
add_mlir_doc(Passes -gen-pass-doc MemRefPasses ./)
|
|
@ -0,0 +1,47 @@
|
|||
//===- Passes.h - MemRef Patterns and Passes --------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header declares patterns and passes on MemRef operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
|
||||
#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace memref {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Appends patterns for folding memref.subview ops into consumer load/store ops
|
||||
/// into `patterns`.
|
||||
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Passes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Creates an operation pass to fold memref.subview ops into consumer
|
||||
/// load/store ops.
|
||||
std::unique_ptr<Pass> createFoldSubViewOpsPass();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace memref
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
|
|
@ -0,0 +1,26 @@
|
|||
//===-- Passes.td - MemRef transformation definition file --*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
|
||||
#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
|
||||
let summary = "Fold memref.subview ops into consumer load/store ops";
|
||||
let description = [{
|
||||
The pass folds loading/storing from/to subview ops to loading/storing
|
||||
from/to the original memref.
|
||||
}];
|
||||
let constructor = "mlir::memref::createFoldSubViewOpsPass()";
|
||||
let dependentDialects = ["memref::MemRefDialect", "vector::VectorDialect"];
|
||||
}
|
||||
|
||||
|
||||
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
|
||||
|
|
@ -20,6 +20,7 @@
|
|||
#include "mlir/Dialect/GPU/Passes.h"
|
||||
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Quant/Passes.h"
|
||||
#include "mlir/Dialect/SCF/Passes.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
|
||||
|
@ -55,6 +56,7 @@ inline void registerAllPasses() {
|
|||
registerGpuSerializeToHsacoPass();
|
||||
registerLinalgPasses();
|
||||
LLVM::registerLLVMPasses();
|
||||
memref::registerMemRefPasses();
|
||||
quant::registerQuantPasses();
|
||||
registerSCFPasses();
|
||||
registerShapePasses();
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
add_mlir_conversion_library(MLIRStandardToSPIRV
|
||||
LegalizeStandardForSPIRV.cpp
|
||||
StandardToSPIRV.cpp
|
||||
StandardToSPIRVPass.cpp
|
||||
|
||||
|
|
|
@ -1,23 +1,3 @@
|
|||
add_mlir_dialect_library(MLIRMemRef
|
||||
IR/MemRefDialect.cpp
|
||||
IR/MemRefOps.cpp
|
||||
Utils/MemRefUtils.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
|
||||
|
||||
DEPENDS
|
||||
MLIRStandardOpsIncGen
|
||||
MLIRMemRefOpsIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRDialect
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
MLIRStandard
|
||||
MLIRTensor
|
||||
MLIRViewLikeInterface
|
||||
)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(Utils)
|
||||
|
|
|
@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRMemRef
|
|||
MLIRDialect
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
MLIRMemRefUtils
|
||||
MLIRStandard
|
||||
MLIRTensor
|
||||
MLIRViewLikeInterface
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
add_mlir_dialect_library(MLIRMemRefTransforms
|
||||
FoldSubViewOps.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
|
||||
|
||||
DEPENDS
|
||||
MLIRMemRefPassIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRMemRef
|
||||
MLIRPass
|
||||
MLIRStandard
|
||||
MLIRTransforms
|
||||
MLIRVector
|
||||
)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
|
||||
//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -6,16 +6,13 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This transformation pass legalizes operations before the conversion to SPIR-V
|
||||
// dialect to handle ops that cannot be lowered directly.
|
||||
// This transformation pass folds loading/storing from/to subview ops into
|
||||
// loading/storing from/to the original memref.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
|
||||
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
@ -23,6 +20,49 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Given the 'indices' of a load/store operation where the memref is a result
|
||||
/// of a subview op, returns the indices w.r.t. the source memref of the
|
||||
/// subview op. For example
|
||||
///
|
||||
/// %0 = ... : memref<12x42xf32>
|
||||
/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
|
||||
/// memref<4x4xf32, offset=?, strides=[?, ?]>
|
||||
/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
|
||||
///
|
||||
/// could be folded into
|
||||
///
|
||||
/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
|
||||
/// memref<12x42xf32>
|
||||
static LogicalResult
|
||||
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
|
||||
memref::SubViewOp subViewOp, ValueRange indices,
|
||||
SmallVectorImpl<Value> &sourceIndices) {
|
||||
// TODO: Aborting when the offsets are static. There might be a way to fold
|
||||
// the subview op with load even if the offsets have been canonicalized
|
||||
// away.
|
||||
SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
|
||||
auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
|
||||
auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
|
||||
assert(opRanges.size() == indices.size() &&
|
||||
"expected as many indices as rank of subview op result type");
|
||||
|
||||
// New indices for the load are the current indices * subview_stride +
|
||||
// subview_offset.
|
||||
sourceIndices.resize(indices.size());
|
||||
for (auto index : llvm::enumerate(indices)) {
|
||||
auto offset = *(opOffsets.begin() + index.index());
|
||||
auto stride = *(opStrides.begin() + index.index());
|
||||
auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
|
||||
sourceIndices[index.index()] =
|
||||
rewriter.create<AddIOp>(loc, offset, mul).getResult();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Helpers to access the memref operand for each op.
|
||||
static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
|
||||
|
||||
|
@ -34,6 +74,10 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
|
|||
return op.source();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// Merges subview operation with load/transferRead operation.
|
||||
template <typename OpTy>
|
||||
|
@ -101,62 +145,15 @@ void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
|
|||
}
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility functions for op legalization.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Given the 'indices' of a load/store operation where the memref is a result
|
||||
/// of a subview op, returns the indices w.r.t. the source memref of the
|
||||
/// subview op. For example
|
||||
///
|
||||
/// %0 = ... : memref<12x42xf32>
|
||||
/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
|
||||
/// memref<4x4xf32, offset=?, strides=[?, ?]>
|
||||
/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
|
||||
///
|
||||
/// could be folded into
|
||||
///
|
||||
/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
|
||||
/// memref<12x42xf32>
|
||||
static LogicalResult
|
||||
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
|
||||
memref::SubViewOp subViewOp, ValueRange indices,
|
||||
SmallVectorImpl<Value> &sourceIndices) {
|
||||
// TODO: Aborting when the offsets are static. There might be a way to fold
|
||||
// the subview op with load even if the offsets have been canonicalized
|
||||
// away.
|
||||
SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
|
||||
auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
|
||||
auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
|
||||
assert(opRanges.size() == indices.size() &&
|
||||
"expected as many indices as rank of subview op result type");
|
||||
|
||||
// New indices for the load are the current indices * subview_stride +
|
||||
// subview_offset.
|
||||
sourceIndices.resize(indices.size());
|
||||
for (auto index : llvm::enumerate(indices)) {
|
||||
auto offset = *(opOffsets.begin() + index.index());
|
||||
auto stride = *(opStrides.begin() + index.index());
|
||||
auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
|
||||
sourceIndices[index.index()] =
|
||||
rewriter.create<AddIOp>(loc, offset, mul).getResult();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Folding SubViewOp and LoadOp/TransferReadOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename OpTy>
|
||||
LogicalResult
|
||||
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto subViewOp =
|
||||
getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
|
||||
if (!subViewOp) {
|
||||
if (!subViewOp)
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> sourceIndices;
|
||||
if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
|
||||
loadOp.indices(), sourceIndices)))
|
||||
|
@ -166,19 +163,15 @@ LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Folding SubViewOp and StoreOp/TransferWriteOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename OpTy>
|
||||
LogicalResult
|
||||
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto subViewOp =
|
||||
getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
|
||||
if (!subViewOp) {
|
||||
if (!subViewOp)
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> sourceIndices;
|
||||
if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
|
||||
storeOp.indices(), sourceIndices)))
|
||||
|
@ -188,12 +181,7 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Hook for adding patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::populateStdLegalizationPatternsForSPIRVLowering(
|
||||
RewritePatternSet &patterns) {
|
||||
void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
|
||||
patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
|
||||
LoadOpOfSubViewFolder<vector::TransferReadOp>,
|
||||
StoreOpOfSubViewFolder<memref::StoreOp>,
|
||||
|
@ -202,23 +190,28 @@ void mlir::populateStdLegalizationPatternsForSPIRVLowering(
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass for testing just the legalization patterns.
|
||||
// Pass registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct SPIRVLegalization final
|
||||
: public LegalizeStandardForSPIRVBase<SPIRVLegalization> {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
|
||||
|
||||
struct FoldSubViewOpsPass final
|
||||
: public FoldSubViewOpsBase<FoldSubViewOpsPass> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void SPIRVLegalization::runOnOperation() {
|
||||
void FoldSubViewOpsPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateStdLegalizationPatternsForSPIRVLowering(patterns);
|
||||
memref::populateFoldSubViewOpPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
|
||||
std::move(patterns));
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
|
||||
return std::make_unique<SPIRVLegalization>();
|
||||
std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
|
||||
return std::make_unique<FoldSubViewOpsPass>();
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
add_mlir_dialect_library(MLIRMemRefUtils
|
||||
MemRefUtils.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRSideEffectInterfaces
|
||||
)
|
||||
|
|
@ -11,7 +11,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
// RUN: mlir-opt -legalize-std-for-spirv %s -o - | FileCheck %s
|
||||
|
||||
module {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// memref.subview
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: @fold_static_stride_subview
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<12x32xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: index
|
||||
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: index
|
||||
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: index
|
||||
func @fold_static_stride_subview
|
||||
(%arg0 : memref<12x32xf32>, %arg1 : index,
|
||||
%arg2 : index, %arg3 : index, %arg4 : index) {
|
||||
// CHECK-DAG: %[[C2:.*]] = constant 2
|
||||
// CHECK-DAG: %[[C3:.*]] = constant 3
|
||||
// CHECK: %[[T0:.*]] = muli %[[ARG3]], %[[C3]]
|
||||
// CHECK: %[[T1:.*]] = addi %[[ARG1]], %[[T0]]
|
||||
// CHECK: %[[T2:.*]] = muli %[[ARG4]], %[[ARG2]]
|
||||
// CHECK: %[[T3:.*]] = addi %[[T2]], %[[C2]]
|
||||
// CHECK: %[[LOADVAL:.*]] = memref.load %[[ARG0]][%[[T1]], %[[T3]]]
|
||||
// CHECK: %[[STOREVAL:.*]] = math.sqrt %[[LOADVAL]]
|
||||
// CHECK: %[[T6:.*]] = muli %[[ARG3]], %[[C3]]
|
||||
// CHECK: %[[T7:.*]] = addi %[[ARG1]], %[[T6]]
|
||||
// CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[ARG2]]
|
||||
// CHECK: %[[T9:.*]] = addi %[[T8]], %[[C2]]
|
||||
// CHECK: memref.store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]]
|
||||
%0 = memref.subview %arg0[%arg1, 2][4, 4][3, %arg2] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [96, ?]>
|
||||
%1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]>
|
||||
%2 = math.sqrt %1 : f32
|
||||
memref.store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]>
|
||||
return
|
||||
}
|
||||
|
||||
} // end module
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt -legalize-std-for-spirv -verify-diagnostics %s -o - | FileCheck %s
|
||||
// RUN: mlir-opt -fold-memref-subview-ops -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @fold_static_stride_subview_with_load
|
||||
// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
|
|
@ -20,6 +20,7 @@
|
|||
#include "mlir/Dialect/GPU/Passes.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
|
||||
|
@ -40,7 +41,7 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
|
|||
applyPassManagerCLOptions(passManager);
|
||||
|
||||
passManager.addPass(createGpuKernelOutliningPass());
|
||||
passManager.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
|
||||
passManager.addPass(memref::createFoldSubViewOpsPass());
|
||||
passManager.addPass(createConvertGPUToSPIRVPass());
|
||||
OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
|
||||
modulePM.addPass(spirv::createLowerABIAttributesPass());
|
||||
|
|
Loading…
Reference in New Issue