forked from OSchip/llvm-project
[mlir][bufferize] Add vector-bufferize pass and remove obsolete patterns from Linalg Bufferize
Differential Revision: https://reviews.llvm.org/D119444
This commit is contained in:
parent
8527859d89
commit
73e880fbf1
|
@ -130,8 +130,34 @@ struct BufferizationOptions {
|
|||
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
|
||||
return op->getName().getStringRef() == opName;
|
||||
};
|
||||
opFilter.push_back(
|
||||
OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
|
||||
allowOperationInFilter(filterFn);
|
||||
}
|
||||
|
||||
/// Deny the given op and activate the filter (`hasFilter`).
|
||||
///
|
||||
/// This function adds a DENY filter.
|
||||
void denyOperationInFilter(StringRef opName) {
|
||||
hasFilter = true;
|
||||
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
|
||||
return op->getName().getStringRef() == opName;
|
||||
};
|
||||
denyOperationInFilter(filterFn);
|
||||
}
|
||||
|
||||
/// Allow ops that are matched by `fn` and activate the filter (`hasFilter`).
|
||||
///
|
||||
/// This function adds an ALLOW filter.
|
||||
void allowOperationInFilter(OpFilterEntry::FilterFn fn) {
|
||||
hasFilter = true;
|
||||
opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::ALLOW});
|
||||
}
|
||||
|
||||
/// Deny ops that are matched by `fn` and activate the filter (`hasFilter`).
|
||||
///
|
||||
/// This function adds a DENY filter.
|
||||
void denyOperationInFilter(OpFilterEntry::FilterFn fn) {
|
||||
hasFilter = true;
|
||||
opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::DENY});
|
||||
}
|
||||
|
||||
/// Try to cast the given op to BufferizableOpInterface if the op is allow
|
||||
|
|
|
@ -1 +1,5 @@
|
|||
# This dialect does currently not have any passes.
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Vector)
|
||||
add_public_tablegen_target(MLIRVectorTransformsIncGen)
|
||||
|
||||
add_mlir_doc(Passes VectorPasses ./ -gen-pass-doc)
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
|
||||
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace vector {
|
||||
/// Creates an instance of the `vector` dialect bufferization pass.
|
||||
std::unique_ptr<Pass> createVectorBufferizePass();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
|
||||
} // namespace vector
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
|
|
@ -0,0 +1,19 @@
|
|||
//===-- Passes.td - Vector pass definition file ------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
|
||||
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def VectorBufferize : Pass<"vector-bufferize", "FuncOp"> {
|
||||
let summary = "Bufferize Vector dialect ops";
|
||||
let constructor = "mlir::vector::createVectorBufferizePass()";
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
|
|
@ -32,6 +32,7 @@
|
|||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include <cstdlib>
|
||||
|
@ -71,6 +72,7 @@ inline void registerAllPasses() {
|
|||
registerStandardPasses();
|
||||
tensor::registerTensorPasses();
|
||||
tosa::registerTosaOptPasses();
|
||||
vector::registerVectorPasses();
|
||||
|
||||
// Dialect pipelines
|
||||
sparse_tensor::registerSparseTensorPipelines();
|
||||
|
|
|
@ -268,43 +268,6 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorTransferReadOpConverter
|
||||
: public OpConversionPattern<vector::TransferReadOp> {
|
||||
public:
|
||||
using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
if (readOp.getShapedType().isa<MemRefType>())
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
|
||||
readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
|
||||
adaptor.permutation_mapAttr(), adaptor.padding(), adaptor.mask(),
|
||||
adaptor.in_boundsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorTransferWriteOpConverter
|
||||
: public OpConversionPattern<vector::TransferWriteOp> {
|
||||
public:
|
||||
using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
if (writeOp.getShapedType().isa<MemRefType>())
|
||||
return failure();
|
||||
rewriter.create<vector::TransferWriteOp>(
|
||||
writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
|
||||
adaptor.permutation_mapAttr(),
|
||||
adaptor.in_bounds() ? adaptor.in_boundsAttr() : ArrayAttr());
|
||||
rewriter.replaceOp(writeOp, adaptor.source());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
@ -329,9 +292,6 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
|
|||
return typeConverter.isLegal(op);
|
||||
};
|
||||
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
|
||||
target
|
||||
.addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
|
||||
isLegalOperation);
|
||||
|
||||
RewritePatternSet patterns(&context);
|
||||
populateLinalgBufferizePatterns(typeConverter, patterns);
|
||||
|
@ -358,9 +318,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
|
|||
BufferizeTensorReshapeOp<tensor::ExpandShapeOp>,
|
||||
BufferizeTensorReshapeOp<tensor::CollapseShapeOp>,
|
||||
ExtractSliceOpConverter,
|
||||
InsertSliceOpConverter,
|
||||
VectorTransferReadOpConverter,
|
||||
VectorTransferWriteOpConverter
|
||||
InsertSliceOpConverter
|
||||
>(typeConverter, patterns.getContext());
|
||||
// clang-format on
|
||||
patterns.add<GeneralizePadOpPattern>(patterns.getContext());
|
||||
|
|
|
@ -19,4 +19,5 @@ add_mlir_dialect_library(MLIRSparseTensorPipelines
|
|||
MLIRStandardOpsTransforms
|
||||
MLIRTensorTransforms
|
||||
MLIRVectorToLLVM
|
||||
MLIRVectorTransforms
|
||||
)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/Passes.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -31,6 +32,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
|
|||
pm.addPass(createSparsificationPass(options.sparsificationOptions()));
|
||||
pm.addPass(createSparseTensorConversionPass());
|
||||
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(vector::createVectorBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertSCFToCFPass());
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
//===- Bufferize.cpp - Bufferization for `vector` dialect ops -------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements bufferization of `vector` dialect ops
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
|
||||
namespace {
|
||||
struct VectorBufferizePass : public VectorBufferizeBase<VectorBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
BufferizationOptions options = getPartialBufferizationOptions();
|
||||
options.allowDialectInFilter<vector::VectorDialect>();
|
||||
|
||||
if (failed(bufferizeOp(getOperation(), options)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
|
||||
tensor::TensorDialect, vector::VectorDialect>();
|
||||
vector::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::vector::createVectorBufferizePass() {
|
||||
return std::make_unique<VectorBufferizePass>();
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
add_mlir_dialect_library(MLIRVectorTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
Bufferize.cpp
|
||||
VectorDropLeadUnitDim.cpp
|
||||
VectorInsertExtractStridedSliceRewritePatterns.cpp
|
||||
VectorMultiDimReductionTransforms.cpp
|
||||
|
@ -12,17 +13,22 @@ add_mlir_dialect_library(MLIRVectorTransforms
|
|||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
|
||||
|
||||
DEPENDS
|
||||
MLIRVectorTransformsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAffine
|
||||
MLIRAffineAnalysis
|
||||
MLIRAffineUtils
|
||||
MLIRArithmetic
|
||||
MLIRBufferization
|
||||
MLIRBufferizationTransforms
|
||||
MLIRDialectUtils
|
||||
MLIRIR
|
||||
MLIRLinalg
|
||||
MLIRMemRef
|
||||
MLIRSCF
|
||||
MLIRTransforms
|
||||
MLIRVector
|
||||
MLIRVectorInterfaces
|
||||
MLIRVectorUtils
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
//===- PassDetail.h - Vector Pass class details -----------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_
|
||||
#define DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace bufferization {
|
||||
class BufferizationDialect;
|
||||
} // namespace bufferization
|
||||
|
||||
namespace memref {
|
||||
class MemRefDialect;
|
||||
} // namespace memref
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_
|
|
@ -303,23 +303,6 @@ func @pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tens
|
|||
// CHECK: return %[[OUT_TENSOR]] : tensor<4x?x?x?xf32>
|
||||
// CHECK: }
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @vector_transfer
|
||||
func @vector_transfer(%in: tensor<4xf32>, %out: tensor<4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%read = vector.transfer_read %in[%c0], %cst {in_bounds = [true]}
|
||||
: tensor<4xf32>, vector<4xf32>
|
||||
%tanh = math.tanh %read : vector<4xf32>
|
||||
%write = vector.transfer_write %tanh, %out[%c0] {in_bounds = [true]}
|
||||
: vector<4xf32>, tensor<4xf32>
|
||||
return
|
||||
// CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @bufferize_dot
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
// RUN: mlir-opt %s -vector-bufferize -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @transfer_read(
|
||||
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[pad:.*]]: f32)
|
||||
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?x?xf32>
|
||||
// CHECK: %[[r:.*]] = vector.transfer_read %[[m]][%[[o1]], %[[o2]]], %[[pad]] {in_bounds = [true, false]} : memref<?x?xf32>, vector<5x6xf32>
|
||||
// CHECK: return %[[r]]
|
||||
func @transfer_read(%t: tensor<?x?xf32>, %o1: index,
|
||||
%o2: index, %pad: f32) -> vector<5x6xf32> {
|
||||
%0 = vector.transfer_read %t[%o1, %o2], %pad {in_bounds = [true, false]}
|
||||
: tensor<?x?xf32>, vector<5x6xf32>
|
||||
return %0 : vector<5x6xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transfer_write(
|
||||
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[vec:.*]]: vector<5x6xf32>)
|
||||
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?x?xf32>
|
||||
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}, %{{.*}}) {{.*}} : memref<?x?xf32>
|
||||
// CHECK: memref.copy %[[m]], %[[alloc]]
|
||||
// CHECK: vector.transfer_write %[[vec]], %[[alloc]][%[[o1]], %[[o2]]] {in_bounds = [true, false]} : vector<5x6xf32>, memref<?x?xf32>
|
||||
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] : memref<?x?xf32>
|
||||
// CHECK: return %[[r]]
|
||||
func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
|
||||
%o2: index, %vec: vector<5x6xf32>) -> tensor<?x?xf32> {
|
||||
%0 = vector.transfer_write %vec, %t[%o1, %o2] {in_bounds = [true, false]}
|
||||
: vector<5x6xf32>, tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
|
@ -1994,6 +1994,7 @@ cc_library(
|
|||
":StandardOpsTransforms",
|
||||
":TensorTransforms",
|
||||
":VectorToLLVM",
|
||||
":VectorTransforms",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -2906,11 +2907,29 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "VectorPassIncGen",
|
||||
strip_include_prefix = "include",
|
||||
tbl_outs = [
|
||||
(
|
||||
[
|
||||
"-gen-pass-decls",
|
||||
"-name=Vector",
|
||||
],
|
||||
"include/mlir/Dialect/Vector/Transforms/Passes.h.inc",
|
||||
),
|
||||
],
|
||||
tblgen = ":mlir-tblgen",
|
||||
td_file = "include/mlir/Dialect/Vector/Transforms/Passes.td",
|
||||
deps = [":PassBaseTdFiles"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "VectorTransforms",
|
||||
srcs = glob(
|
||||
[
|
||||
"lib/Dialect/Vector/Transforms/*.cpp",
|
||||
"lib/Dialect/Vector/Transforms/*.h",
|
||||
],
|
||||
),
|
||||
hdrs = glob([
|
||||
|
@ -2923,16 +2942,20 @@ cc_library(
|
|||
":Analysis",
|
||||
":ArithmeticDialect",
|
||||
":BufferizationDialect",
|
||||
":BufferizationTransforms",
|
||||
":DialectUtils",
|
||||
":IR",
|
||||
":LinalgOps",
|
||||
":MemRefDialect",
|
||||
":Pass",
|
||||
":SCFDialect",
|
||||
":StandardOps",
|
||||
":Support",
|
||||
":TensorDialect",
|
||||
":Transforms",
|
||||
":VectorInterfaces",
|
||||
":VectorOps",
|
||||
":VectorPassIncGen",
|
||||
":VectorUtils",
|
||||
"//llvm:Support",
|
||||
],
|
||||
|
@ -5911,6 +5934,7 @@ cc_library(
|
|||
":VectorToROCDL",
|
||||
":VectorToSCF",
|
||||
":VectorToSPIRV",
|
||||
":VectorTransforms",
|
||||
":X86Vector",
|
||||
":X86VectorTransforms",
|
||||
],
|
||||
|
|
Loading…
Reference in New Issue