[mlir][tensor] Replace tensor-bufferize with BufferizableOpInterface impl

This commit switches the `tensor-bufferize` pass over to BufferizableOpInterface-based bufferization.

Differential Revision: https://reviews.llvm.org/D118246
This commit is contained in:
Matthias Springer 2022-01-27 19:18:59 +09:00
parent d58757e522
commit daf18108ec
9 changed files with 207 additions and 296 deletions

View File

@ -255,7 +255,7 @@ public:
const BufferizationOptions &getOptions() const { return options; } const BufferizationOptions &getOptions() const { return options; }
protected: protected:
BufferizationState(const BufferizationOptions &options); explicit BufferizationState(const BufferizationOptions &options);
// BufferizationState should be passed as a reference. // BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete; BufferizationState(const BufferizationState &) = delete;
@ -270,6 +270,24 @@ private:
const BufferizationOptions &options; const BufferizationOptions &options;
}; };
/// This a "no analysis, always copy" BufferizationState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore, all
/// OpOperands that bufferize to a memory write must bufferize out-of-place.
class AlwaysCopyBufferizationState : public BufferizationState {
public:
explicit AlwaysCopyBufferizationState(const BufferizationOptions &options);
AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
virtual ~AlwaysCopyBufferizationState() = default;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const override;
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
};
/// Replace an op with replacement values. The op is deleted. Tensor OpResults /// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values. /// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,

View File

@ -69,6 +69,21 @@ void populateEliminateBufferizeMaterializationsPatterns(
// TODO: Extract `options` from `state` and pass as separate argument. // TODO: Extract `options` from `state` and pass as separate argument.
LogicalResult bufferizeOp(Operation *op, const BufferizationState &state); LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
/// Buffers are duplicated and copied before any tensor use that bufferizes to
/// a memory write.
///
/// Note: This function bufferizes ops without utilizing analysis results. It
/// can be used to implement partial bufferization passes.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
/// Populate the pattern set with a pattern that bufferizes ops that implement
/// `BufferizableOpInterface`.
void populateBufferizationPattern(const BufferizationState &state,
RewritePatternSet &patterns);
std::unique_ptr<BufferizationOptions> getPartialBufferizationOptions();
} // namespace bufferization } // namespace bufferization
} // namespace mlir } // namespace mlir

View File

@ -12,16 +12,6 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // namespace bufferization
class RewritePatternSet;
void populateTensorBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Creates an instance of `tensor` dialect bufferization pass. /// Creates an instance of `tensor` dialect bufferization pass.
std::unique_ptr<Pass> createTensorBufferizePass(); std::unique_ptr<Pass> createTensorBufferizePass();

View File

@ -14,11 +14,6 @@ include "mlir/Pass/PassBase.td"
def TensorBufferize : Pass<"tensor-bufferize", "FuncOp"> { def TensorBufferize : Pass<"tensor-bufferize", "FuncOp"> {
let summary = "Bufferize the `tensor` dialect"; let summary = "Bufferize the `tensor` dialect";
let constructor = "mlir::createTensorBufferizePass()"; let constructor = "mlir::createTensorBufferizePass()";
let dependentDialects = [
"bufferization::BufferizationDialect",
"memref::MemRefDialect",
"scf::SCFDialect"
];
} }
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES

View File

@ -318,6 +318,25 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
rewriter.eraseOp(op); rewriter.eraseOp(op);
} }
AlwaysCopyBufferizationState::AlwaysCopyBufferizationState(
const BufferizationOptions &options)
: BufferizationState(options) {}
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool AlwaysCopyBufferizationState::isInPlace(OpOperand &opOperand) const {
// OpOperands that bufferize to a memory write are out-of-place, i.e., an
// alloc and copy is inserted.
return !bufferizesToMemoryWrite(opOperand);
}
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool AlwaysCopyBufferizationState::areEquivalentBufferizedValues(
Value v1, Value v2) const {
// There is no analysis, so we do not know if the values are equivalent. The
// conservative answer is "false".
return false;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support. // Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -207,9 +207,59 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
const BufferizationState &state) { const BufferizationState &state) {
// Bufferize the op and its nested ops. // Bufferize the op and its nested ops.
RewritePatternSet patterns(op->getContext()); RewritePatternSet patterns(op->getContext());
patterns.add<BufferizationPattern>(op->getContext(), state); populateBufferizationPattern(state, patterns);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return failure(); return failure();
return checkBufferizationResult(op, state.getOptions()); return checkBufferizationResult(op, state.getOptions());
} }
namespace {
/// This a "no analysis, always copy" BufferizationState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore, all
/// OpOperands that bufferize to a memory write must bufferize out-of-place.
class AlwaysCopyBufferizationState : public BufferizationState {
public:
AlwaysCopyBufferizationState(const BufferizationOptions &options)
: BufferizationState(options) {}
AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
virtual ~AlwaysCopyBufferizationState() = default;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const override {
// OpOperands that bufferize to a memory write are out-of-place, i.e., an
// alloc and copy is inserted.
return !bufferizesToMemoryWrite(opOperand);
}
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override {
// There is no analysis, so we do not know if the values are equivalent. The
// conservative answer is "false".
return false;
}
};
} // namespace
LogicalResult bufferization::bufferizeOp(Operation *op,
const BufferizationOptions &options) {
AlwaysCopyBufferizationState state(options);
return bufferizeOp(op, state);
}
void bufferization::populateBufferizationPattern(
const BufferizationState &state, RewritePatternSet &patterns) {
patterns.add<BufferizationPattern>(patterns.getContext(), state);
}
std::unique_ptr<BufferizationOptions>
bufferization::getPartialBufferizationOptions() {
auto options = std::make_unique<BufferizationOptions>();
options->allowReturnMemref = true;
options->allowUnknownOps = true;
options->createDeallocs = false;
options->fullyDynamicLayoutMaps = false;
return options;
}

View File

@ -13,223 +13,40 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
using namespace mlir; using namespace mlir;
using namespace bufferization;
namespace { namespace {
struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
adaptor.getOperands()[0]);
return success();
}
};
struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
adaptor.index());
return success();
}
};
struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
adaptor.indices());
return success();
}
};
struct BufferizeFromElementsOp
: public OpConversionPattern<tensor::FromElementsOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto tensorType = op.getType().cast<RankedTensorType>();
auto shape = tensorType.getShape();
// Allocate a buffer for the result.
auto resultType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
Value buffer = rewriter.create<memref::AllocOp>(loc, resultType);
// Case: tensor<0xelem_type>.
if (op.elements().empty()) {
rewriter.replaceOp(op, {buffer});
return success();
}
// Case: tensor<elem_type>.
if (shape.empty()) {
rewriter.create<memref::StoreOp>(loc, op.elements().front(), buffer);
rewriter.replaceOp(op, {buffer});
return success();
}
// Create constants for the range of possible indices [0, max{shape_i}).
auto maxDim = *std::max_element(shape.begin(), shape.end());
SmallVector<Value, 2> constants;
constants.reserve(maxDim);
for (int i = 0; i < maxDim; ++i)
constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
// Traverse all `elements` and create `memref.store` ops.
ImplicitLocOpBuilder b(loc, rewriter);
auto elementIt = adaptor.elements().begin();
SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b);
rewriter.replaceOp(op, {buffer});
return success();
}
private:
// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
void createStores(int dim, Value buffer, ArrayRef<int64_t> shape,
ArrayRef<Value> constants, ValueRange::iterator &elementIt,
SmallVectorImpl<Value> &indices,
ImplicitLocOpBuilder b) const {
if (dim == static_cast<int>(shape.size()) - 1) {
for (int i = 0; i < shape.back(); ++i) {
indices.back() = constants[i];
b.create<memref::StoreOp>(*elementIt, buffer, indices);
++elementIt;
}
return;
}
for (int i = 0; i < shape[dim]; ++i) {
indices[dim] = constants[i];
createStores(dim + 1, buffer, shape, constants, elementIt, indices, b);
}
}
};
struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// Allocate memory.
Location loc = op.getLoc();
RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
MemRefType memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
adaptor.dynamicExtents());
// Collect loop bounds.
int64_t rank = tensorType.getRank();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
SmallVector<Value, 4> lowerBounds(rank, zero);
SmallVector<Value, 4> steps(rank, one);
SmallVector<Value, 4> upperBounds;
int nextDynamicIndex = 0;
for (int i = 0; i < rank; i++) {
Value upperBound = tensorType.isDynamicDim(i)
? adaptor.dynamicExtents()[nextDynamicIndex++]
: rewriter.create<arith::ConstantIndexOp>(
loc, memrefType.getDimSize(i));
upperBounds.push_back(upperBound);
}
// Generate tensor elements with a parallel loop that stores into
// each element of the resulting memref.
//
// This is a bit tricky. We cannot simply clone the ops because when an op
// is cloned, it must be legalized. However, we want to allow arbitrary ops
// in the body that we don't necessarily have legalization patterns for as
// part of this dialect conversion invocation.
//
// To accomplish this, we use mergeBlockBefore to "move" this op's body
// into the scf.parallel's body.
auto parallel =
rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
Block *parallelBody = parallel.getBody();
rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
parallelBody->getArguments());
// Replace the inlined yield op with a store op. The scf.parallel's builder
// already populated an scf.yield at the end, so we don't need to worry
// about creating that.
Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
rewriter.setInsertionPointAfter(elementYield);
rewriter.replaceOpWithNewOp<memref::StoreOp>(
elementYield, elementYield->getOperands()[0], result,
parallelBody->getArguments());
rewriter.replaceOp(op, {result});
return success();
}
};
struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
adaptor.tensor());
return success();
}
};
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnOperation() override { void runOnOperation() override {
auto *context = &getContext(); std::unique_ptr<BufferizationOptions> options =
bufferization::BufferizeTypeConverter typeConverter; getPartialBufferizationOptions();
options->addToDialectFilter<tensor::TensorDialect>();
ConversionTarget target(*context); if (failed(bufferizeOp(getOperation(), *options)))
target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
StandardOpsDialect>(
[&](Operation *op) { return typeConverter.isLegal(op); });
target.addLegalOp<CallOp, ReturnOp>();
target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
tensor::FromElementsOp, tensor::GenerateOp>();
bufferization::populateBufferizeMaterializationLegality(target);
RewritePatternSet patterns(context);
populateTensorBufferizePatterns(typeConverter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure(); signalPassFailure();
} }
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
tensor::TensorDialect, scf::SCFDialect,
arith::ArithmeticDialect>();
tensor::registerBufferizableOpInterfaceExternalModels(registry);
}
}; };
} // namespace } // namespace
void mlir::populateTensorBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
typeConverter, patterns.getContext());
}
std::unique_ptr<Pass> mlir::createTensorBufferizePass() { std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
return std::make_unique<TensorBufferizePass>(); return std::make_unique<TensorBufferizePass>();
} }

View File

@ -1355,55 +1355,3 @@ func @write_after_select_read_one(
// CHECK: return %[[f]], %[[select]] // CHECK: return %[[f]], %[[select]]
return %f, %w : f32, tensor<?xf32> return %f, %w : f32, tensor<?xf32>
} }
// -----
// CHECK-LABEL: func @tensor_rank(
// CHECK-SAME: %[[arg0:.*]]: memref<*xf32>
func @tensor_rank(%arg0: tensor<*xf32>) -> index {
// CHECK: %[[r:.*]] = memref.rank %[[arg0]]
%0 = tensor.rank %arg0 : tensor<*xf32>
// CHECK: return %[[r]] : index
return %0 : index
}
// -----
// CHECK-LABEL: func @tensor_generate_static_and_dynamic(
// CHECK-SAME: %[[arg0:.*]]: index
func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex>
// CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} {
%result = tensor.generate %arg0 {
^bb0(%i: index, %j: index):
%sum = arith.addi %i, %j : index
// CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]]
// CHECK: scf.yield
tensor.yield %sum : index
} : tensor<16x?xindex>
// CHECK: }
return %result : tensor<16x?xindex>
}
// -----
// CHECK-LABEL: func @tensor_from_elements_2d(
// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index
func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
%0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
: tensor<3x2xindex>
// CHECK: return %[[MEMREF]]
return %0 : tensor<3x2xindex>
}

View File

@ -1,5 +1,7 @@
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-LABEL: func @dim( // CHECK-LABEL: func @dim(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>, // CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME: %[[INDEX:.*]]: index) -> index { // CHECK-SAME: %[[INDEX:.*]]: index) -> index {
@ -66,8 +68,7 @@ func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
} }
// CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> { // CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> {
// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<0xindex> // CHECK: %[[RET:.*]] = arith.constant dense<> : tensor<0xindex>
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<0xindex> // CHECK: return %[[RET]] : tensor<0xindex>
func @tensor.from_elements_no_elements() -> tensor<0xindex> { func @tensor.from_elements_no_elements() -> tensor<0xindex> {
%0 = tensor.from_elements : tensor<0xindex> %0 = tensor.from_elements : tensor<0xindex>
@ -76,7 +77,7 @@ func @tensor.from_elements_no_elements() -> tensor<0xindex> {
// CHECK-LABEL: func @tensor.from_elements_0d( // CHECK-LABEL: func @tensor.from_elements_0d(
// CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor<index> { // CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor<index> {
// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<index> // CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
// CHECK: store %[[ELEM0]], %[[MEMREF]] // CHECK: store %[[ELEM0]], %[[MEMREF]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<index> // CHECK: return %[[RET]] : tensor<index>
@ -88,9 +89,9 @@ func @tensor.from_elements_0d(%arg0: index) -> tensor<index> {
// CHECK-LABEL: func @tensor.from_elements_1d( // CHECK-LABEL: func @tensor.from_elements_1d(
// CHECK-SAME: %[[ELEM0:.*]]: index, // CHECK-SAME: %[[ELEM0:.*]]: index,
// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> { // CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<2xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<2xindex>
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]] // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]] // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
@ -103,10 +104,10 @@ func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> {
// CHECK-LABEL: func @tensor.from_elements_2d( // CHECK-LABEL: func @tensor.from_elements_2d(
// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index) // CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index)
// CHECK-SAME: -> tensor<3x2xindex> { // CHECK-SAME: -> tensor<3x2xindex> {
// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2xindex> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]] // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]] // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]] // CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
@ -121,9 +122,9 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
return %0 : tensor<3x2xindex> return %0 : tensor<3x2xindex>
} }
// CHECK-LABEL: func @tensor.from_elements_3d() // CHECK-LABEL: func @tensor.from_elements_3d(
// CHECK-SAME: %[[F0:.*]]: f32
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00 // CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0 // CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0 // CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
@ -136,11 +137,11 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01 // CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01 // CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2x2xf32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2x2xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]] // CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]] // CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]]
@ -157,8 +158,7 @@ func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]]
// CHECK: return %[[RET]] : tensor<3x2x2xf32> // CHECK: return %[[RET]] : tensor<3x2x2xf32>
func @tensor.from_elements_3d() -> tensor<3x2x2xf32> { func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
%f0 = arith.constant 0.0 : f32
%f1 = arith.constant 1.0 : f32 %f1 = arith.constant 1.0 : f32
%f2 = arith.constant 2.0 : f32 %f2 = arith.constant 2.0 : f32
%f3 = arith.constant 3.0 : f32 %f3 = arith.constant 3.0 : f32
@ -179,9 +179,9 @@ func @tensor.from_elements_3d() -> tensor<3x2x2xf32> {
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>, // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> { // CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32> // CHECK: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) { // CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32> // CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex> // CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
@ -204,10 +204,10 @@ func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xi
// //
// CHECK-LABEL: func @tensor.generate_static_and_dynamic( // CHECK-LABEL: func @tensor.generate_static_and_dynamic(
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> { // CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index // CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) { // CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
// CHECK: %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index // CHECK: %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex> // CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
@ -225,12 +225,6 @@ func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
return %result : tensor<16x?xindex> return %result : tensor<16x?xindex>
} }
// The tensor.generate op needs to put its body into the
// resulting scf.parallel. To handle unknown ops in the body, it cannot clone
// the body because that would require the cloned ops to be legalized
// immediately, which is usually not possible since they might be from various
// other dialects.
//
// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body // CHECK-LABEL: func @tensor.generate_unknown_ops_in_body
func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> { func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
// CHECK-NOT: tensor.generate // CHECK-NOT: tensor.generate
@ -242,3 +236,68 @@ func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
} : tensor<?xindex> } : tensor<?xindex>
return %tensor : tensor<?xindex> return %tensor : tensor<?xindex>
} }
// CHECK-LABEL: func @tensor.extract_slice(
// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[idx1:.*]]: index, %[[idx2:.*]]: index
func @tensor.extract_slice(
%t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
// CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP]]>
%0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
: tensor<?x?xf32> to tensor<?x10xf32>
// CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
// CHECK: return %[[r_tensor]]
return %0 : tensor<?x10xf32>
}
// CHECK-LABEL: func @tensor.extract_slice_rank_reducing(
// CHECK-SAME: %[[t1:.*]]: tensor<?x10x?xf32>, %[[idx1:.*]]: index,
// CHECK-SAME: %[[idx2:.*]]: index
func @tensor.extract_slice_rank_reducing(
%t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10x?xf32>
// CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP]]>
%0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
: tensor<?x10x?xf32> to tensor<?x15xf32>
// CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
// CHECK: return %[[r_tensor]]
return %0 : tensor<?x15xf32>
}
// CHECK-LABEL: func @tensor.insert_slice(
// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
// CHECK-SAME: %[[idx1:.*]]: index, %[[idx2:.*]]: index
func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
%idx1: index, %idx2: index) -> tensor<?x?xf32> {
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
// CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref<?x10xf32>
// CHECK: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
// CHECK: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
// CHECK: memref.copy %[[m1]], %[[alloc]]
// CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]
// CHECK: memref.copy %[[m2]], %[[subview]]
%0 = tensor.insert_slice %t2 into %t1[%idx1, 5][%idx2, 10][1, 1]
: tensor<?x10xf32> into tensor<?x?xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
// CHECK: return %[[r]]
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @tensor.insert(
// CHECK-SAME: %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
// CHECK-SAME: %[[f:.*]]: f32
func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32> {
// CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32>
// CHECK: memref.copy %[[m1]], %[[alloc]]
// CHECK: memref.store %[[f]], %[[alloc]][%[[idx1]]]
%0 = tensor.insert %f into %t1[%idx1] : tensor<5xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
// CHECK: return %[[r]]
return %0 : tensor<5xf32>
}