[mlir][bufferize] Add vector-bufferize pass and remove obsolete patterns from Linalg Bufferize

Differential Revision: https://reviews.llvm.org/D119444
This commit is contained in:
Matthias Springer 2022-02-15 21:16:50 +09:00
parent 8527859d89
commit 73e880fbf1
14 changed files with 223 additions and 63 deletions

View File

@ -130,8 +130,34 @@ struct BufferizationOptions {
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
return op->getName().getStringRef() == opName;
};
opFilter.push_back(
OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
allowOperationInFilter(filterFn);
}
/// Deny the given op and activate the filter (`hasFilter`).
///
/// This function adds a DENY filter.
void denyOperationInFilter(StringRef opName) {
hasFilter = true;
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
return op->getName().getStringRef() == opName;
};
denyOperationInFilter(filterFn);
}
/// Allow ops that are matched by `fn` and activate the filter (`hasFilter`).
///
/// This function adds an ALLOW filter.
void allowOperationInFilter(OpFilterEntry::FilterFn fn) {
hasFilter = true;
opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::ALLOW});
}
/// Deny ops that are matched by `fn` and activate the filter (`hasFilter`).
///
/// This function adds a DENY filter.
void denyOperationInFilter(OpFilterEntry::FilterFn fn) {
hasFilter = true;
opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::DENY});
}
/// Try to cast the given op to BufferizableOpInterface if the op is allow

View File

@ -1 +1,5 @@
# This dialect does currently not have any passes.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Vector)
add_public_tablegen_target(MLIRVectorTransformsIncGen)
add_mlir_doc(Passes VectorPasses ./ -gen-pass-doc)

View File

@ -0,0 +1,30 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace vector {
/// Creates an instance of the `vector` dialect bufferization pass.
std::unique_ptr<Pass> createVectorBufferizePass();
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_

View File

@ -0,0 +1,19 @@
//===-- Passes.td - Vector pass definition file ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
include "mlir/Pass/PassBase.td"
def VectorBufferize : Pass<"vector-bufferize", "FuncOp"> {
let summary = "Bufferize Vector dialect ops";
let constructor = "mlir::vector::createVectorBufferizePass()";
}
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES

View File

@ -32,6 +32,7 @@
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"
#include <cstdlib>
@ -71,6 +72,7 @@ inline void registerAllPasses() {
registerStandardPasses();
tensor::registerTensorPasses();
tosa::registerTosaOptPasses();
vector::registerVectorPasses();
// Dialect pipelines
sparse_tensor::registerSparseTensorPipelines();

View File

@ -268,43 +268,6 @@ public:
return success();
}
};
class VectorTransferReadOpConverter
: public OpConversionPattern<vector::TransferReadOp> {
public:
using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (readOp.getShapedType().isa<MemRefType>())
return failure();
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
adaptor.permutation_mapAttr(), adaptor.padding(), adaptor.mask(),
adaptor.in_boundsAttr());
return success();
}
};
class VectorTransferWriteOpConverter
: public OpConversionPattern<vector::TransferWriteOp> {
public:
using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (writeOp.getShapedType().isa<MemRefType>())
return failure();
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
adaptor.permutation_mapAttr(),
adaptor.in_bounds() ? adaptor.in_boundsAttr() : ArrayAttr());
rewriter.replaceOp(writeOp, adaptor.source());
return success();
}
};
} // namespace
namespace {
@ -329,9 +292,6 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
return typeConverter.isLegal(op);
};
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
target
.addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
isLegalOperation);
RewritePatternSet patterns(&context);
populateLinalgBufferizePatterns(typeConverter, patterns);
@ -358,9 +318,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
BufferizeTensorReshapeOp<tensor::ExpandShapeOp>,
BufferizeTensorReshapeOp<tensor::CollapseShapeOp>,
ExtractSliceOpConverter,
InsertSliceOpConverter,
VectorTransferReadOpConverter,
VectorTransferWriteOpConverter
InsertSliceOpConverter
>(typeConverter, patterns.getContext());
// clang-format on
patterns.add<GeneralizePadOpPattern>(patterns.getContext());

View File

@ -19,4 +19,5 @@ add_mlir_dialect_library(MLIRSparseTensorPipelines
MLIRStandardOpsTransforms
MLIRTensorTransforms
MLIRVectorToLLVM
MLIRVectorTransforms
)

View File

@ -16,6 +16,7 @@
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
using namespace mlir;
@ -31,6 +32,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
pm.addPass(createSparsificationPass(options.sparsificationOptions()));
pm.addPass(createSparseTensorConversionPass());
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
pm.addNestedPass<FuncOp>(vector::createVectorBufferizePass());
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
pm.addNestedPass<FuncOp>(createConvertSCFToCFPass());

View File

@ -0,0 +1,46 @@
//===- Bufferize.cpp - Bufferization for `vector` dialect ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of `vector` dialect ops
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
using namespace mlir;
using namespace bufferization;
namespace {
struct VectorBufferizePass : public VectorBufferizeBase<VectorBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
options.allowDialectInFilter<vector::VectorDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
tensor::TensorDialect, vector::VectorDialect>();
vector::registerBufferizableOpInterfaceExternalModels(registry);
}
};
} // namespace
std::unique_ptr<Pass> mlir::vector::createVectorBufferizePass() {
return std::make_unique<VectorBufferizePass>();
}

View File

@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
VectorDropLeadUnitDim.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorMultiDimReductionTransforms.cpp
@ -12,17 +13,22 @@ add_mlir_dialect_library(MLIRVectorTransforms
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
DEPENDS
MLIRVectorTransformsIncGen
LINK_LIBS PUBLIC
MLIRAffine
MLIRAffineAnalysis
MLIRAffineUtils
MLIRArithmetic
MLIRBufferization
MLIRBufferizationTransforms
MLIRDialectUtils
MLIRIR
MLIRLinalg
MLIRMemRef
MLIRSCF
MLIRTransforms
MLIRVector
MLIRVectorInterfaces
MLIRVectorUtils

View File

@ -0,0 +1,29 @@
//===- PassDetail.h - Vector Pass class details -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_
#define DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace bufferization {
class BufferizationDialect;
} // namespace bufferization
namespace memref {
class MemRefDialect;
} // namespace memref
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace mlir
#endif // DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_

View File

@ -303,23 +303,6 @@ func @pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tens
// CHECK: return %[[OUT_TENSOR]] : tensor<4x?x?x?xf32>
// CHECK: }
// -----
// CHECK-LABEL: func @vector_transfer
func @vector_transfer(%in: tensor<4xf32>, %out: tensor<4xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%read = vector.transfer_read %in[%c0], %cst {in_bounds = [true]}
: tensor<4xf32>, vector<4xf32>
%tanh = math.tanh %read : vector<4xf32>
%write = vector.transfer_write %tanh, %out[%c0] {in_bounds = [true]}
: vector<4xf32>, tensor<4xf32>
return
// CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32>
}
// -----
// CHECK-LABEL: func @bufferize_dot

View File

@ -0,0 +1,30 @@
// RUN: mlir-opt %s -vector-bufferize -split-input-file | FileCheck %s
// CHECK-LABEL: func @transfer_read(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[pad:.*]]: f32)
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?x?xf32>
// CHECK: %[[r:.*]] = vector.transfer_read %[[m]][%[[o1]], %[[o2]]], %[[pad]] {in_bounds = [true, false]} : memref<?x?xf32>, vector<5x6xf32>
// CHECK: return %[[r]]
func @transfer_read(%t: tensor<?x?xf32>, %o1: index,
%o2: index, %pad: f32) -> vector<5x6xf32> {
%0 = vector.transfer_read %t[%o1, %o2], %pad {in_bounds = [true, false]}
: tensor<?x?xf32>, vector<5x6xf32>
return %0 : vector<5x6xf32>
}
// -----
// CHECK-LABEL: func @transfer_write(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[vec:.*]]: vector<5x6xf32>)
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?x?xf32>
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}, %{{.*}}) {{.*}} : memref<?x?xf32>
// CHECK: memref.copy %[[m]], %[[alloc]]
// CHECK: vector.transfer_write %[[vec]], %[[alloc]][%[[o1]], %[[o2]]] {in_bounds = [true, false]} : vector<5x6xf32>, memref<?x?xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] : memref<?x?xf32>
// CHECK: return %[[r]]
func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
%o2: index, %vec: vector<5x6xf32>) -> tensor<?x?xf32> {
%0 = vector.transfer_write %vec, %t[%o1, %o2] {in_bounds = [true, false]}
: vector<5x6xf32>, tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

View File

@ -1994,6 +1994,7 @@ cc_library(
":StandardOpsTransforms",
":TensorTransforms",
":VectorToLLVM",
":VectorTransforms",
],
)
@ -2906,11 +2907,29 @@ cc_library(
],
)
gentbl_cc_library(
name = "VectorPassIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=Vector",
],
"include/mlir/Dialect/Vector/Transforms/Passes.h.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Vector/Transforms/Passes.td",
deps = [":PassBaseTdFiles"],
)
cc_library(
name = "VectorTransforms",
srcs = glob(
[
"lib/Dialect/Vector/Transforms/*.cpp",
"lib/Dialect/Vector/Transforms/*.h",
],
),
hdrs = glob([
@ -2923,16 +2942,20 @@ cc_library(
":Analysis",
":ArithmeticDialect",
":BufferizationDialect",
":BufferizationTransforms",
":DialectUtils",
":IR",
":LinalgOps",
":MemRefDialect",
":Pass",
":SCFDialect",
":StandardOps",
":Support",
":TensorDialect",
":Transforms",
":VectorInterfaces",
":VectorOps",
":VectorPassIncGen",
":VectorUtils",
"//llvm:Support",
],
@ -5911,6 +5934,7 @@ cc_library(
":VectorToROCDL",
":VectorToSCF",
":VectorToSPIRV",
":VectorTransforms",
":X86Vector",
":X86VectorTransforms",
],