[mlir][linalg][bufferize][NFC] Move vector interface impl to new build target

The vector dialect's BufferizableOpInterface implementation now lives in its own build target (MLIRVectorBufferizableOpInterfaceImpl), making ComprehensiveBufferize entirely independent of the vector dialect.

Differential Revision: https://reviews.llvm.org/D114218
commit ca9d149e07
parent cf40ca026f
Author: Matthias Springer
Date:   2021-11-24 19:32:33 +09:00

7 changed files with 182 additions and 103 deletions
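
A minimal usage sketch (not part of the change itself; the helper function name below is illustrative): after this commit, a client that wants vector transfer ops to participate in comprehensive bufferization includes the new header and registers the external models explicitly, mirroring the pass-registration hunk further down, and links MLIRVectorBufferizableOpInterfaceImpl (CMake) or :VectorBufferizableOpInterfaceImpl (Bazel) instead of getting the models implicitly from MLIRComprehensiveBufferize.

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/IR/Dialect.h"

// Illustrative opt-in: register the vector external models on the registry
// that will be appended to the MLIRContext before bufferization runs.
static void registerVectorBufferization(mlir::DialectRegistry &registry) {
  mlir::linalg::comprehensive_bufferize::vector_ext::
      registerBufferizableOpInterfaceExternalModels(registry);
}

This follows the same opt-in pattern already used for the linalg and tensor interface implementations in the same directory.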

@@ -0,0 +1,27 @@
//===- VectorInterfaceImpl.h - Vector Impl. of BufferizableOpInterface ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_VECTOR_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_VECTOR_INTERFACE_IMPL_H
namespace mlir {
class DialectRegistry;
namespace linalg {
namespace comprehensive_bufferize {
namespace vector_ext {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace vector_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_VECTOR_INTERFACE_IMPL_H

@@ -3,6 +3,7 @@ set(LLVM_OPTIONAL_SOURCES
ComprehensiveBufferize.cpp
LinalgInterfaceImpl.cpp
TensorInterfaceImpl.cpp
VectorInterfaceImpl.cpp
)
add_mlir_dialect_library(MLIRBufferizableOpInterface
@@ -36,6 +37,15 @@ add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
MLIRTensor
)
add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
VectorInterfaceImpl.cpp
LINK_LIBS PUBLIC
MLIRBufferizableOpInterface
MLIRIR
MLIRVector
)
add_mlir_dialect_library(MLIRComprehensiveBufferize
ComprehensiveBufferize.cpp
@@ -48,5 +58,4 @@ add_mlir_dialect_library(MLIRComprehensiveBufferize
MLIRSCF
MLIRStandard
MLIRStandardOpsTransforms
MLIRVector
)

@@ -114,7 +114,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Operation.h"
@@ -1926,102 +1925,6 @@ struct ReturnOpInterface
} // namespace std_ext
namespace vector_ext {
struct TransferReadOpInterface
: public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
vector::TransferReadOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
assert(opOperand.get().getType().isa<RankedTensorType>() &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
assert(opOperand.get().getType().isa<RankedTensorType>() &&
"only tensor types expected");
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto transferReadOp = cast<vector::TransferReadOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
// TransferReadOp always reads from the bufferized op.source().
assert(transferReadOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
Value v = state.lookupBuffer(transferReadOp.source());
transferReadOp.sourceMutable().assign(v);
return success();
}
};
struct TransferWriteOpInterface
: public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
vector::TransferWriteOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {&op->getOpOperand(1)};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return op->getOpResult(0);
}
BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
// Create a new transfer_write on buffer that doesn't have a return value.
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
assert(writeOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
if (!resultBuffer)
return failure();
b.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_map(),
writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
state.mapBuffer(op->getResult(0), resultBuffer);
return success();
}
};
} // namespace vector_ext
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
registry.addOpInterface<scf::ExecuteRegionOp,
@@ -2031,10 +1934,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<scf::YieldOp, scf_ext::YieldOpInterface>();
registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
registry.addOpInterface<vector::TransferReadOp,
vector_ext::TransferReadOpInterface>();
registry.addOpInterface<vector::TransferWriteOp,
vector_ext::TransferWriteOpInterface>();
// Ops that are not bufferizable but are allocation hoisting barriers.
registry.addOpInterface<FuncOp, AllocationHoistingBarrierOnly<FuncOp>>();

@@ -0,0 +1,123 @@
//===- VectorInterfaceImpl.cpp - Vector Impl. of BufferizableOpInterface --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace vector_ext {
struct TransferReadOpInterface
: public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
vector::TransferReadOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
assert(opOperand.get().getType().isa<RankedTensorType>() &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
assert(opOperand.get().getType().isa<RankedTensorType>() &&
"only tensor types expected");
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto transferReadOp = cast<vector::TransferReadOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
// TransferReadOp always reads from the bufferized op.source().
assert(transferReadOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
Value v = state.lookupBuffer(transferReadOp.source());
transferReadOp.sourceMutable().assign(v);
return success();
}
};
struct TransferWriteOpInterface
: public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
vector::TransferWriteOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {&op->getOpOperand(1)};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return op->getOpResult(0);
}
BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
// Create a new transfer_write on buffer that doesn't have a return value.
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
assert(writeOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
Value resultBuffer = getResultBuffer(b, op->getResult(0), state);
if (!resultBuffer)
return failure();
b.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_map(),
writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
state.mapBuffer(op->getResult(0), resultBuffer);
return success();
}
};
} // namespace vector_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
void mlir::linalg::comprehensive_bufferize::vector_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<vector::TransferReadOp,
vector_ext::TransferReadOpInterface>();
registry.addOpInterface<vector::TransferWriteOp,
vector_ext::TransferWriteOpInterface>();
}

@@ -53,6 +53,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRTransforms
MLIRTransformUtils
MLIRVector
MLIRVectorBufferizableOpInterfaceImpl
MLIRX86VectorTransforms
MLIRVectorToSCF
)

@@ -11,6 +11,7 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -40,6 +41,7 @@ struct LinalgComprehensiveModuleBufferize
registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
}
};
} // end namespace

@@ -6345,6 +6345,24 @@ cc_library(
],
)
cc_library(
name = "VectorBufferizableOpInterfaceImpl",
srcs = [
"lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h",
],
includes = ["include"],
deps = [
":BufferizableOpInterface",
":IR",
":Support",
":VectorOps",
"//llvm:Support",
],
)
td_library(
name = "LinalgDocTdFiles",
srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -6567,6 +6585,7 @@ cc_library(
":TensorBufferizableOpInterfaceImpl",
":TensorDialect",
":TransformUtils",
":VectorBufferizableOpInterfaceImpl",
":VectorOps",
":VectorToSCF",
":X86VectorTransforms",
@@ -6596,7 +6615,6 @@ cc_library(
":StandardOps",
":Support",
":TransformUtils",
":VectorOps",
"//llvm:Support",
],
)