forked from OSchip/llvm-project
[mlir][shape] Migrate bufferization to BufferizableOpInterface
Differential Revision: https://reviews.llvm.org/D121043
This commit is contained in:
parent
df6c26fd34
commit
93e663273b
|
@ -0,0 +1,20 @@
|
|||
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
#define MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace shape {
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
} // namespace shape
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
|
|
@ -40,21 +40,6 @@ void populateShapeRewritePatterns(RewritePatternSet &patterns);
|
|||
void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns);
|
||||
std::unique_ptr<OperationPass<FuncOp>> createRemoveShapeConstraintsPass();
|
||||
|
||||
/// Populates patterns for shape dialect structural type conversions and sets up
|
||||
/// the provided ConversionTarget with the appropriate legality configuration
|
||||
/// for the ops to get converted properly.
|
||||
///
|
||||
/// A "structural" type conversion is one where the underlying ops are
|
||||
/// completely agnostic to the actual types involved and simply need to update
|
||||
/// their types consistently. An example of this is shape.assuming -- the
|
||||
/// shape.assuming op and the corresponding shape.assuming_yield op need to have
|
||||
/// consistent types, but the exact types don't matter. So all that we need to
|
||||
/// do for a structural type conversion is to update both of their types
|
||||
/// consistently to the new types prescribed by the TypeConverter.
|
||||
void populateShapeStructuralTypeConversionsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
|
||||
// Bufferizes shape dialect ops.
|
||||
//
|
||||
// Note that most shape dialect ops must be converted to std before
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::bufferization;
|
||||
using namespace mlir::shape;
|
||||
|
||||
namespace mlir {
|
||||
namespace shape {
|
||||
namespace {
|
||||
|
||||
/// Bufferization of shape.assuming.
|
||||
struct AssumingOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
|
||||
shape::AssumingOp> {
|
||||
SmallVector<OpOperand *>
|
||||
getAliasingOpOperand(Operation *op, OpResult opResult,
|
||||
const BufferizationState &state) const {
|
||||
// AssumingOps do not have tensor OpOperands. The yielded value can be any
|
||||
// SSA value that is in scope. To allow for use-def chain traversal through
|
||||
// AssumingOps in the analysis, the corresponding yield value is considered
|
||||
// to be aliasing with the result.
|
||||
auto assumingOp = cast<shape::AssumingOp>(op);
|
||||
size_t resultNum = std::distance(op->getOpResults().begin(),
|
||||
llvm::find(op->getOpResults(), opResult));
|
||||
// TODO: Support multiple blocks.
|
||||
assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
|
||||
"expected exactly 1 block");
|
||||
auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
|
||||
assumingOp.getDoRegion().front().getTerminator());
|
||||
assert(yieldOp && "expected shape.assuming_yield terminator");
|
||||
return {&yieldOp->getOpOperand(resultNum)};
|
||||
}
|
||||
|
||||
// TODO: For better bufferization results, this could return `true` only if
|
||||
// there is a memory write in the region.
|
||||
bool isMemoryWrite(Operation *op, OpResult opResult,
|
||||
const BufferizationState &state) const {
|
||||
// Similar to scf.if, results of this op are always considered memory writes
|
||||
// in the analysis. This is a useful pattern for all ops that have tensor
|
||||
// OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
|
||||
// implemented in terms of `bufferizesToMemoryWrite`, which does not work on
|
||||
// ops without OpOperands.
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
const BufferizationState &state) const {
|
||||
auto assumingOp = cast<shape::AssumingOp>(op);
|
||||
|
||||
// Compute new result types.
|
||||
SmallVector<Type> newResultTypes;
|
||||
for (Type type : assumingOp->getResultTypes()) {
|
||||
if (auto tensorType = type.dyn_cast<TensorType>()) {
|
||||
newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
|
||||
} else {
|
||||
newResultTypes.push_back(type);
|
||||
}
|
||||
}
|
||||
|
||||
// Create new op and move over region.
|
||||
auto newOp = rewriter.create<shape::AssumingOp>(
|
||||
op->getLoc(), newResultTypes, assumingOp.getWitness());
|
||||
newOp.getDoRegion().takeBody(assumingOp.getRegion());
|
||||
|
||||
// Update terminator.
|
||||
assert(newOp.getDoRegion().getBlocks().size() == 1 &&
|
||||
"only 1 block supported");
|
||||
Block *newBlock = &newOp.getDoRegion().front();
|
||||
auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator());
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
SmallVector<Value> newYieldValues;
|
||||
for (const auto &it : llvm::enumerate(yieldOp.operands())) {
|
||||
Value val = it.value();
|
||||
if (val.getType().isa<TensorType>()) {
|
||||
newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
|
||||
yieldOp.getLoc(), newResultTypes[it.index()], val));
|
||||
} else {
|
||||
newYieldValues.push_back(val);
|
||||
}
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp,
|
||||
newYieldValues);
|
||||
|
||||
// Update all uses of the old op.
|
||||
rewriter.setInsertionPointAfter(newOp);
|
||||
SmallVector<Value> newResults;
|
||||
for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
|
||||
if (it.value().isa<TensorType>()) {
|
||||
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
|
||||
assumingOp.getLoc(), newOp->getResult(it.index())));
|
||||
} else {
|
||||
newResults.push_back(newOp->getResult(it.index()));
|
||||
}
|
||||
}
|
||||
|
||||
// Replace old op.
|
||||
rewriter.replaceOp(assumingOp, newResults);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
};
|
||||
|
||||
/// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
|
||||
/// ops, so this is for analysis only.
|
||||
struct AssumingYieldOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
|
||||
shape::AssumingOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const BufferizationState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
const BufferizationState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult>
|
||||
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
const BufferizationState &state) const {
|
||||
assert(isa<shape::AssumingOp>(op->getParentOp()) &&
|
||||
"expected that parent is an AssumingOp");
|
||||
return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
|
||||
}
|
||||
|
||||
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
|
||||
const BufferizationState &state) const {
|
||||
// Yield operands always bufferize inplace. Otherwise, an alloc + copy
|
||||
// may be generated inside the block. We should not return/yield allocations
|
||||
// when possible.
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
const BufferizationState &state) const {
|
||||
// Op is bufferized as part of AssumingOp.
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace shape
|
||||
} // namespace mlir
|
||||
|
||||
void mlir::shape::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<shape::AssumingOp, AssumingOpInterface>();
|
||||
registry.addOpInterface<shape::AssumingYieldOp, AssumingYieldOpInterface>();
|
||||
}
|
|
@ -8,30 +8,32 @@
|
|||
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
|
||||
namespace {
|
||||
struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
MLIRContext &ctx = getContext();
|
||||
BufferizationOptions options = getPartialBufferizationOptions();
|
||||
options.allowDialectInFilter<shape::ShapeDialect>();
|
||||
|
||||
RewritePatternSet patterns(&ctx);
|
||||
bufferization::BufferizeTypeConverter typeConverter;
|
||||
ConversionTarget target(ctx);
|
||||
|
||||
bufferization::populateBufferizeMaterializationLegality(target);
|
||||
populateShapeStructuralTypeConversionsAndLegality(typeConverter, patterns,
|
||||
target);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
if (failed(bufferizeOp(getOperation(), options)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
|
||||
shape::ShapeDialect>();
|
||||
shape::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
add_mlir_dialect_library(MLIRShapeOpsTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
Bufferize.cpp
|
||||
RemoveShapeConstraints.cpp
|
||||
ShapeToShapeLowering.cpp
|
||||
StructuralTypeConversions.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms
|
||||
|
@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms
|
|||
target_link_libraries(MLIRShapeOpsTransforms
|
||||
PUBLIC
|
||||
MLIRArithmetic
|
||||
MLIRBufferization
|
||||
MLIRBufferizationTransforms
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
|
|
|
@ -1,70 +0,0 @@
|
|||
//===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::shape;
|
||||
|
||||
namespace {
|
||||
class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AssumingOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
SmallVector<Type, 2> newResultTypes;
|
||||
newResultTypes.reserve(op.getNumResults());
|
||||
for (auto result : op.getResults()) {
|
||||
auto originalType = result.getType();
|
||||
Type convertedType = getTypeConverter()->convertType(originalType);
|
||||
newResultTypes.push_back(convertedType);
|
||||
}
|
||||
|
||||
auto newAssumingOp = rewriter.create<AssumingOp>(
|
||||
op.getLoc(), newResultTypes, op.getWitness());
|
||||
rewriter.inlineRegionBefore(op.getDoRegion(), newAssumingOp.getDoRegion(),
|
||||
newAssumingOp.getDoRegion().end());
|
||||
rewriter.replaceOp(op, newAssumingOp.getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAssumingYieldOpTypes
|
||||
: public OpConversionPattern<AssumingYieldOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AssumingYieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateShapeStructuralTypeConversionsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
patterns.add<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
|
||||
typeConverter, patterns.getContext());
|
||||
target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
|
||||
return typeConverter.isLegal(op.getResultTypes());
|
||||
});
|
||||
target.addDynamicallyLegalOp<AssumingYieldOp>([&](AssumingYieldOp op) {
|
||||
return typeConverter.isLegal(op.getOperandTypes());
|
||||
});
|
||||
}
|
|
@ -2702,7 +2702,10 @@ cc_library(
|
|||
"lib/Dialect/Shape/Transforms/*.cpp",
|
||||
"lib/Dialect/Shape/Transforms/*.h",
|
||||
]),
|
||||
hdrs = ["include/mlir/Dialect/Shape/Transforms/Passes.h"],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h",
|
||||
"include/mlir/Dialect/Shape/Transforms/Passes.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":ArithmeticDialect",
|
||||
|
|
Loading…
Reference in New Issue