[mlir][linalg][bufferize] Support tensor.from_elements

This is mostly a copy of the existing tensor.from_elements bufferization. Once TensorInterfaceImpl.cpp is moved to the tensor dialect, the existing rewrite pattern can be deleted.

Differential Revision: https://reviews.llvm.org/D117775
This commit is contained in:
Matthias Springer 2022-01-25 22:05:35 +09:00
parent 71bbb78b8f
commit d581c94d6b
2 changed files with 98 additions and 0 deletions

View File

@ -229,6 +229,82 @@ struct ExtractOpInterface
}
};
/// Emit one `memref.store` into `buffer` for every scalar in
/// op.elements(), walking the output indices in row-major order.
///
/// `indices` holds the current (partially fixed) index vector; the levels
/// above `dim` were already set by the enclosing recursive calls.
/// `constants` caches arith.constant index values 0..max(shape_i)-1 so that
/// every dimension can reuse them. `elementIt` advances across the flat
/// operand range as stores are created.
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  const int lastDim = static_cast<int>(shape.size()) - 1;
  if (dim < lastDim) {
    // Outer dimension: fix this coordinate and recurse one level deeper.
    for (int64_t i = 0, e = shape[dim]; i < e; ++i) {
      indices[dim] = constants[i];
      createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                   indices);
    }
    return;
  }
  // Innermost dimension: emit one store per element and consume it.
  for (int64_t i = 0, e = shape.back(); i < e; ++i) {
    indices.back() = constants[i];
    rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
    ++elementIt;
  }
}
/// Bufferization of tensor.from_elements: allocate a result buffer and emit
/// one `memref.store` per scalar operand.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto fromOp = cast<tensor::FromElementsOp>(op);
    Location loc = op->getLoc();

    // Allocate a buffer with the same (static) shape and element type as the
    // result tensor.
    auto tensorType = fromOp.getType().cast<RankedTensorType>();
    ArrayRef<int64_t> shape = tensorType.getShape();
    MemRefType bufferType = MemRefType::get(shape, tensorType.getElementType());
    FailureOr<Value> allocated =
        createAlloc(rewriter, loc, bufferType, {},
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(allocated))
      return failure();
    Value buffer = *allocated;

    // Degenerate case tensor<0xelem_type>: no elements, nothing to store.
    if (fromOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Rank-0 case tensor<elem_type>: a single store with no indices.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromOp.elements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Materialize index constants 0 .. max(shape_i)-1 once; they are shared
    // by every dimension of the store loop nest.
    int64_t maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int64_t i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Walk all elements in row-major order and emit the stores.
    auto elementIt = fromOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};
/// Bufferization of tensor.generate.
struct GenerateOpInterface
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
@ -562,6 +638,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
registry.addOpInterface<DimOp, DimOpInterface>();
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
registry.addOpInterface<GenerateOp, GenerateOpInterface>();
registry.addOpInterface<InsertOp, InsertOpInterface>();
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();

View File

@ -1379,3 +1379,24 @@ func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
// CHECK: }
return %result : tensor<16x?xindex>
}
// -----

// Bufferization of a 2-D tensor.from_elements: expect a single memref.alloc
// plus one memref.store per element, with the six operands visited in
// row-major order over the 3x2 result shape.
// CHECK-LABEL: func @tensor_from_elements_2d(
// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index
func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
%0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
    : tensor<3x2xindex>
// CHECK: return %[[MEMREF]]
return %0 : tensor<3x2xindex>
}