[mlir][tensor] Replace tensor-bufferize with BufferizableOpInterface impl

This commit switches the `tensor-bufferize` pass over to BufferizableOpInterface-based bufferization.

Differential Revision: https://reviews.llvm.org/D118246
parent d58757e522
commit daf18108ec
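For orientation, the heart of the change is that the pass no longer builds a BufferizeTypeConverter and a dialect-conversion target; it drives the BufferizableOpInterface machinery instead. Below is a condensed sketch of the new pass body, pieced together from the updated TensorBufferizePass further down in this diff (not a standalone translation unit; the surrounding pass class and its getDependentDialects override are omitted):

  void runOnOperation() override {
    // "No analysis, always copy" options suited for partial bufferization.
    std::unique_ptr<BufferizationOptions> options =
        getPartialBufferizationOptions();
    // Restrict bufferization to ops from the tensor dialect.
    options->addToDialectFilter<tensor::TensorDialect>();
    // Greedily bufferize every filtered op that implements
    // BufferizableOpInterface; fail the pass otherwise.
    if (failed(bufferizeOp(getOperation(), *options)))
      signalPassFailure();
  }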
@@ -255,7 +255,7 @@ public:
   const BufferizationOptions &getOptions() const { return options; }

 protected:
-  BufferizationState(const BufferizationOptions &options);
+  explicit BufferizationState(const BufferizationOptions &options);

   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;

@@ -270,6 +270,24 @@ private:
   const BufferizationOptions &options;
 };

+/// This is a "no analysis, always copy" BufferizationState. In the absence of an
+/// analysis, a buffer must be copied each time it is written to. Therefore, all
+/// OpOperands that bufferize to a memory write must bufferize out-of-place.
+class AlwaysCopyBufferizationState : public BufferizationState {
+public:
+  explicit AlwaysCopyBufferizationState(const BufferizationOptions &options);
+
+  AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
+
+  virtual ~AlwaysCopyBufferizationState() = default;
+
+  /// Return `true` if the given OpResult has been decided to bufferize inplace.
+  bool isInPlace(OpOperand &opOperand) const override;
+
+  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+  bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
+};
+
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
 /// must be replaced with memref values.
 void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,

@@ -69,6 +69,21 @@ void populateEliminateBufferizeMaterializationsPatterns(
 // TODO: Extract `options` from `state` and pass as separate argument.
 LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);

+/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
+/// Buffers are duplicated and copied before any tensor use that bufferizes to
+/// a memory write.
+///
+/// Note: This function bufferizes ops without utilizing analysis results. It
+/// can be used to implement partial bufferization passes.
+LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
+
+/// Populate the pattern set with a pattern that bufferizes ops that implement
+/// `BufferizableOpInterface`.
+void populateBufferizationPattern(const BufferizationState &state,
+                                  RewritePatternSet &patterns);
+
+std::unique_ptr<BufferizationOptions> getPartialBufferizationOptions();
+
 } // namespace bufferization
 } // namespace mlir

@@ -12,16 +12,6 @@
 #include "mlir/Pass/Pass.h"

 namespace mlir {
-namespace bufferization {
-class BufferizeTypeConverter;
-} // namespace bufferization
-
-class RewritePatternSet;
-
-void populateTensorBufferizePatterns(
-    bufferization::BufferizeTypeConverter &typeConverter,
-    RewritePatternSet &patterns);
-
 /// Creates an instance of `tensor` dialect bufferization pass.
 std::unique_ptr<Pass> createTensorBufferizePass();

@@ -14,11 +14,6 @@ include "mlir/Pass/PassBase.td"
 def TensorBufferize : Pass<"tensor-bufferize", "FuncOp"> {
   let summary = "Bufferize the `tensor` dialect";
   let constructor = "mlir::createTensorBufferizePass()";
-  let dependentDialects = [
-    "bufferization::BufferizationDialect",
-    "memref::MemRefDialect",
-    "scf::SCFDialect"
-  ];
 }

 #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES

@@ -318,6 +318,25 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
   rewriter.eraseOp(op);
 }

+AlwaysCopyBufferizationState::AlwaysCopyBufferizationState(
+    const BufferizationOptions &options)
+    : BufferizationState(options) {}
+
+/// Return `true` if the given OpResult has been decided to bufferize inplace.
+bool AlwaysCopyBufferizationState::isInPlace(OpOperand &opOperand) const {
+  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
+  // alloc and copy is inserted.
+  return !bufferizesToMemoryWrite(opOperand);
+}
+
+/// Return true if `v1` and `v2` bufferize to equivalent buffers.
+bool AlwaysCopyBufferizationState::areEquivalentBufferizedValues(
+    Value v1, Value v2) const {
+  // There is no analysis, so we do not know if the values are equivalent. The
+  // conservative answer is "false".
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific scoped alloc/dealloc insertion support.
 //===----------------------------------------------------------------------===//

@@ -207,9 +207,59 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationState &state) {
   // Bufferize the op and its nested ops.
   RewritePatternSet patterns(op->getContext());
-  patterns.add<BufferizationPattern>(op->getContext(), state);
+  populateBufferizationPattern(state, patterns);
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     return failure();

   return checkBufferizationResult(op, state.getOptions());
 }

+namespace {
+/// This is a "no analysis, always copy" BufferizationState. In the absence of an
+/// analysis, a buffer must be copied each time it is written to. Therefore, all
+/// OpOperands that bufferize to a memory write must bufferize out-of-place.
+class AlwaysCopyBufferizationState : public BufferizationState {
+public:
+  AlwaysCopyBufferizationState(const BufferizationOptions &options)
+      : BufferizationState(options) {}
+
+  AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
+
+  virtual ~AlwaysCopyBufferizationState() = default;
+
+  /// Return `true` if the given OpResult has been decided to bufferize inplace.
+  bool isInPlace(OpOperand &opOperand) const override {
+    // OpOperands that bufferize to a memory write are out-of-place, i.e., an
+    // alloc and copy is inserted.
+    return !bufferizesToMemoryWrite(opOperand);
+  }
+
+  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+  bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
+    // There is no analysis, so we do not know if the values are equivalent. The
+    // conservative answer is "false".
+    return false;
+  }
+};
+} // namespace
+
+LogicalResult bufferization::bufferizeOp(Operation *op,
+                                         const BufferizationOptions &options) {
+  AlwaysCopyBufferizationState state(options);
+  return bufferizeOp(op, state);
+}
+
+void bufferization::populateBufferizationPattern(
+    const BufferizationState &state, RewritePatternSet &patterns) {
+  patterns.add<BufferizationPattern>(patterns.getContext(), state);
+}
+
+std::unique_ptr<BufferizationOptions>
+bufferization::getPartialBufferizationOptions() {
+  auto options = std::make_unique<BufferizationOptions>();
+  options->allowReturnMemref = true;
+  options->allowUnknownOps = true;
+  options->createDeallocs = false;
+  options->fullyDynamicLayoutMaps = false;
+  return options;
+}

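The options-only overload of bufferizeOp is a thin convenience wrapper: it builds the "no analysis, always copy" state and forwards to the state-based overload, so every OpOperand that bufferizes to a memory write is treated as out-of-place (a new allocation plus a copy). Calling it is equivalent to the following sketch, assuming `op` and `options` are in scope:

  // Equivalent to bufferizeOp(op, options). Writes bufferize out-of-place
  // because AlwaysCopyBufferizationState::isInPlace returns false for any
  // operand that bufferizes to a memory write.
  AlwaysCopyBufferizationState state(options);
  LogicalResult status = bufferizeOp(op, state);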
@@ -13,223 +13,40 @@
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "PassDetail.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/DialectConversion.h"

 using namespace mlir;
+using namespace bufferization;

 namespace {
-struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto resultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
-                                                adaptor.getOperands()[0]);
-    return success();
-  }
-};
-
-struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
-                                               adaptor.index());
-    return success();
-  }
-};
-
-struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
-                                                adaptor.indices());
-    return success();
-  }
-};
-
-struct BufferizeFromElementsOp
-    : public OpConversionPattern<tensor::FromElementsOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    auto tensorType = op.getType().cast<RankedTensorType>();
-    auto shape = tensorType.getShape();
-
-    // Allocate a buffer for the result.
-    auto resultType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-    Value buffer = rewriter.create<memref::AllocOp>(loc, resultType);
-
-    // Case: tensor<0xelem_type>.
-    if (op.elements().empty()) {
-      rewriter.replaceOp(op, {buffer});
-      return success();
-    }
-
-    // Case: tensor<elem_type>.
-    if (shape.empty()) {
-      rewriter.create<memref::StoreOp>(loc, op.elements().front(), buffer);
-      rewriter.replaceOp(op, {buffer});
-      return success();
-    }
-
-    // Create constants for the range of possible indices [0, max{shape_i}).
-    auto maxDim = *std::max_element(shape.begin(), shape.end());
-    SmallVector<Value, 2> constants;
-    constants.reserve(maxDim);
-    for (int i = 0; i < maxDim; ++i)
-      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
-
-    // Traverse all `elements` and create `memref.store` ops.
-    ImplicitLocOpBuilder b(loc, rewriter);
-    auto elementIt = adaptor.elements().begin();
-    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
-    createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b);
-
-    rewriter.replaceOp(op, {buffer});
-    return success();
-  }
-
-private:
-  // Implements backtracking to traverse indices of the output buffer while
-  // iterating over op.elements().
-  void createStores(int dim, Value buffer, ArrayRef<int64_t> shape,
-                    ArrayRef<Value> constants, ValueRange::iterator &elementIt,
-                    SmallVectorImpl<Value> &indices,
-                    ImplicitLocOpBuilder b) const {
-    if (dim == static_cast<int>(shape.size()) - 1) {
-      for (int i = 0; i < shape.back(); ++i) {
-        indices.back() = constants[i];
-        b.create<memref::StoreOp>(*elementIt, buffer, indices);
-        ++elementIt;
-      }
-      return;
-    }
-    for (int i = 0; i < shape[dim]; ++i) {
-      indices[dim] = constants[i];
-      createStores(dim + 1, buffer, shape, constants, elementIt, indices, b);
-    }
-  }
-};
-
-struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    // Allocate memory.
-    Location loc = op.getLoc();
-    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
-    MemRefType memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-    Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
-                                                    adaptor.dynamicExtents());
-
-    // Collect loop bounds.
-    int64_t rank = tensorType.getRank();
-    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    SmallVector<Value, 4> lowerBounds(rank, zero);
-    SmallVector<Value, 4> steps(rank, one);
-    SmallVector<Value, 4> upperBounds;
-    int nextDynamicIndex = 0;
-    for (int i = 0; i < rank; i++) {
-      Value upperBound = tensorType.isDynamicDim(i)
-                             ? adaptor.dynamicExtents()[nextDynamicIndex++]
-                             : rewriter.create<arith::ConstantIndexOp>(
-                                   loc, memrefType.getDimSize(i));
-      upperBounds.push_back(upperBound);
-    }
-
-    // Generate tensor elements with a parallel loop that stores into
-    // each element of the resulting memref.
-    //
-    // This is a bit tricky. We cannot simply clone the ops because when an op
-    // is cloned, it must be legalized. However, we want to allow arbitrary ops
-    // in the body that we don't necessarily have legalization patterns for as
-    // part of this dialect conversion invocation.
-    //
-    // To accomplish this, we use mergeBlockBefore to "move" this op's body
-    // into the scf.parallel's body.
-    auto parallel =
-        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
-    Block *parallelBody = parallel.getBody();
-    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
-                              parallelBody->getArguments());
-    // Replace the inlined yield op with a store op. The scf.parallel's builder
-    // already populated an scf.yield at the end, so we don't need to worry
-    // about creating that.
-    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
-    rewriter.setInsertionPointAfter(elementYield);
-    rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        elementYield, elementYield->getOperands()[0], result,
-        parallelBody->getArguments());
-
-    rewriter.replaceOp(op, {result});
-    return success();
-  }
-};
-
-struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
-                                                adaptor.tensor());
-    return success();
-  }
-};
-
 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
   void runOnOperation() override {
-    auto *context = &getContext();
-    bufferization::BufferizeTypeConverter typeConverter;
-
-    ConversionTarget target(*context);
-    target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
-    target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
-                                      StandardOpsDialect>(
-        [&](Operation *op) { return typeConverter.isLegal(op); });
-    target.addLegalOp<CallOp, ReturnOp>();
-    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
-                        tensor::FromElementsOp, tensor::GenerateOp>();
-    bufferization::populateBufferizeMaterializationLegality(target);
-
-    RewritePatternSet patterns(context);
-    populateTensorBufferizePatterns(typeConverter, patterns);
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    std::unique_ptr<BufferizationOptions> options =
+        getPartialBufferizationOptions();
+    options->addToDialectFilter<tensor::TensorDialect>();
+
+    if (failed(bufferizeOp(getOperation(), *options)))
       signalPassFailure();
   }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    tensor::TensorDialect, scf::SCFDialect,
+                    arith::ArithmeticDialect>();
+    tensor::registerBufferizableOpInterfaceExternalModels(registry);
+  }
 };

 } // namespace

-void mlir::populateTensorBufferizePatterns(
-    bufferization::BufferizeTypeConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
-               BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
-      typeConverter, patterns.getContext());
-}
-
 std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
   return std::make_unique<TensorBufferizePass>();
 }

@@ -1355,55 +1355,3 @@ func @write_after_select_read_one(
 // CHECK: return %[[f]], %[[select]]
   return %f, %w : f32, tensor<?xf32>
 }
-
-// -----
-
-// CHECK-LABEL: func @tensor_rank(
-// CHECK-SAME: %[[arg0:.*]]: memref<*xf32>
-func @tensor_rank(%arg0: tensor<*xf32>) -> index {
-  // CHECK: %[[r:.*]] = memref.rank %[[arg0]]
-  %0 = tensor.rank %arg0 : tensor<*xf32>
-  // CHECK: return %[[r]] : index
-  return %0 : index
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor_generate_static_and_dynamic(
-// CHECK-SAME: %[[arg0:.*]]: index
-func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
-  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
-  // CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex>
-  // CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} {
-  %result = tensor.generate %arg0 {
-  ^bb0(%i: index, %j: index):
-    %sum = arith.addi %i, %j : index
-    // CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]]
-    // CHECK: scf.yield
-    tensor.yield %sum : index
-  } : tensor<16x?xindex>
-  // CHECK: }
-  return %result : tensor<16x?xindex>
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor_from_elements_2d(
-// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index
-func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
-  // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
-  // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
-  // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
-  // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
-  // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
-  // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
-  %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
-      : tensor<3x2xindex>
-  // CHECK: return %[[MEMREF]]
-  return %0 : tensor<3x2xindex>
-}

@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s

+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+
 // CHECK-LABEL: func @dim(
 // CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
 // CHECK-SAME: %[[INDEX:.*]]: index) -> index {

@@ -66,8 +68,7 @@ func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
 }

 // CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<0xindex>
-// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
+// CHECK: %[[RET:.*]] = arith.constant dense<> : tensor<0xindex>
 // CHECK: return %[[RET]] : tensor<0xindex>
 func @tensor.from_elements_no_elements() -> tensor<0xindex> {
   %0 = tensor.from_elements : tensor<0xindex>

@@ -76,7 +77,7 @@ func @tensor.from_elements_no_elements() -> tensor<0xindex> {

 // CHECK-LABEL: func @tensor.from_elements_0d(
 // CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor<index> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<index>
+// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
 // CHECK: store %[[ELEM0]], %[[MEMREF]]
 // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
 // CHECK: return %[[RET]] : tensor<index>

@@ -88,9 +89,9 @@ func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
 // CHECK-LABEL: func @tensor.from_elements_1d(
 // CHECK-SAME: %[[ELEM0:.*]]: index,
 // CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<2xindex>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
 // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
 // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
 // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]

@@ -103,10 +104,10 @@ func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
 // CHECK-LABEL: func @tensor.from_elements_2d(
 // CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
 // CHECK-SAME: -> tensor<3x2xindex> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2xindex>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
 // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
 // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
 // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]

@@ -121,9 +122,9 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
   return %0 : tensor<3x2xindex>
 }

-// CHECK-LABEL: func @tensor.from_elements_3d()
-// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
+// CHECK-LABEL: func @tensor.from_elements_3d(
+// CHECK-SAME: %[[F0:.*]]: f32
 // CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
 // CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
 // CHECK-DAG: %[[F3:.*]] = arith.constant 3.0

@@ -136,11 +137,11 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
 // CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
 // CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01

-// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2x2xf32>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>

 // CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
 // CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]

@@ -157,8 +158,7 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {

 // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
 // CHECK: return %[[RET]] : tensor<3x2x2xf32>
-func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
-  %f0 = arith.constant 0.0 : f32
+func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
   %f1 = arith.constant 1.0 : f32
   %f2 = arith.constant 2.0 : f32
   %f3 = arith.constant 3.0 : f32

@@ -179,9 +179,9 @@ func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
 // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
 // CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
 // CHECK: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
-// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
 // CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
 // CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
 // CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>

@@ -204,10 +204,10 @@ func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xi
 //
 // CHECK-LABEL: func @tensor.generate_static_and_dynamic(
 // CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
-// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
 // CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
 // CHECK: %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
 // CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>

@@ -225,12 +225,6 @@ func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
   return %result : tensor<16x?xindex>
 }

-// The tensor.generate op needs to put its body into the
-// resulting scf.parallel. To handle unknown ops in the body, it cannot clone
-// the body because that would require the cloned ops to be legalized
-// immediately, which is usually not possible since they might be from various
-// other dialects.
-//
 // CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
 func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
   // CHECK-NOT: tensor.generate

@@ -242,3 +236,68 @@ func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
   } : tensor<?xindex>
   return %tensor : tensor<?xindex>
 }
+
+// CHECK-LABEL: func @tensor.extract_slice(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[idx1:.*]]: index, %[[idx2:.*]]: index
+func @tensor.extract_slice(
+    %t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
+  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
+  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP]]>
+  %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
+      : tensor<?x?xf32> to tensor<?x10xf32>
+  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+  // CHECK: return %[[r_tensor]]
+  return %0 : tensor<?x10xf32>
+}
+
+// CHECK-LABEL: func @tensor.extract_slice_rank_reducing(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10x?xf32>, %[[idx1:.*]]: index,
+// CHECK-SAME: %[[idx2:.*]]: index
+func @tensor.extract_slice_rank_reducing(
+    %t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10x?xf32>
+  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP]]>
+  %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
+      : tensor<?x10x?xf32> to tensor<?x15xf32>
+  // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+  // CHECK: return %[[r_tensor]]
+  return %0 : tensor<?x15xf32>
+}
+
+// CHECK-LABEL: func @tensor.insert_slice(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
+// CHECK-SAME: %[[idx1:.*]]: index, %[[idx2:.*]]: index
+func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
+                          %idx1: index, %idx2: index) -> tensor<?x?xf32> {
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
+  // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref<?x10xf32>
+  // CHECK: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
+  // CHECK: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
+  // CHECK: memref.copy %[[m1]], %[[alloc]]
+  // CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]
+  // CHECK: memref.copy %[[m2]], %[[subview]]
+  %0 = tensor.insert_slice %t2 into %t1[%idx1, 5][%idx2, 10][1, 1]
+      : tensor<?x10xf32> into tensor<?x?xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @tensor.insert(
+// CHECK-SAME: %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
+// CHECK-SAME: %[[f:.*]]: f32
+func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32>
+  // CHECK: memref.copy %[[m1]], %[[alloc]]
+  // CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]]
+  %0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<5xf32>
+}