forked from OSchip/llvm-project
[mlir][linalg][bufferize] Support tensor.from_elements
This is mostly a copy of the existing tensor.from_elements bufferization. Once TensorInterfaceImpl.cpp is moved to the tensor dialect, the existing rewrite pattern can be deleted. Differential Revision: https://reviews.llvm.org/D117775
This commit is contained in:
parent
71bbb78b8f
commit
d581c94d6b
|
@ -229,6 +229,82 @@ struct ExtractOpInterface
|
|||
}
|
||||
};
|
||||
|
||||
// Implements backtracking to traverse indices of the output buffer while
|
||||
// iterating over op.elements().
|
||||
static void createStores(RewriterBase &rewriter, Location loc, int dim,
|
||||
Value buffer, ArrayRef<int64_t> shape,
|
||||
ArrayRef<Value> constants,
|
||||
OperandRange::iterator &elementIt,
|
||||
SmallVectorImpl<Value> &indices) {
|
||||
if (dim == static_cast<int>(shape.size()) - 1) {
|
||||
for (int i = 0; i < shape.back(); ++i) {
|
||||
indices.back() = constants[i];
|
||||
rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
|
||||
++elementIt;
|
||||
}
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < shape[dim]; ++i) {
|
||||
indices[dim] = constants[i];
|
||||
createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
|
||||
indices);
|
||||
}
|
||||
}
|
||||
|
||||
/// Bufferization of tensor.from_elements.
|
||||
struct FromElementsOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
|
||||
tensor::FromElementsOp> {
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
const BufferizationState &state) const {
|
||||
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
|
||||
|
||||
// Allocate a buffer for the result.
|
||||
Location loc = op->getLoc();
|
||||
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
|
||||
auto shape = tensorType.getShape();
|
||||
MemRefType resultType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
FailureOr<Value> maybeBuffer =
|
||||
createAlloc(rewriter, loc, resultType, {},
|
||||
/*deallocMemref=*/state.getOptions().createDeallocs,
|
||||
state.getOptions());
|
||||
if (failed(maybeBuffer))
|
||||
return failure();
|
||||
Value buffer = *maybeBuffer;
|
||||
|
||||
// Case: tensor<0xelem_type>.
|
||||
if (fromElementsOp.elements().empty()) {
|
||||
replaceOpWithBufferizedValues(rewriter, op, buffer);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Case: tensor<elem_type>.
|
||||
if (shape.empty()) {
|
||||
rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
|
||||
buffer);
|
||||
replaceOpWithBufferizedValues(rewriter, op, buffer);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Create constants for the range of possible indices [0, max{shape_i}).
|
||||
auto maxDim = *std::max_element(shape.begin(), shape.end());
|
||||
SmallVector<Value, 2> constants;
|
||||
constants.reserve(maxDim);
|
||||
for (int i = 0; i < maxDim; ++i)
|
||||
constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
|
||||
|
||||
// Traverse all `elements` and create `memref.store` ops.
|
||||
auto elementIt = fromElementsOp.elements().begin();
|
||||
SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
|
||||
createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
|
||||
indices);
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, buffer);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Bufferization of tensor.generate.
|
||||
struct GenerateOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
|
||||
|
@ -562,6 +638,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
|
|||
registry.addOpInterface<DimOp, DimOpInterface>();
|
||||
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
|
||||
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
|
||||
registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
|
||||
registry.addOpInterface<GenerateOp, GenerateOpInterface>();
|
||||
registry.addOpInterface<InsertOp, InsertOpInterface>();
|
||||
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
|
||||
|
|
|
@ -1379,3 +1379,24 @@ func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
|
|||
// CHECK: }
|
||||
return %result : tensor<16x?xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @tensor_from_elements_2d(
|
||||
// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index
|
||||
func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
|
||||
// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
|
||||
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
|
||||
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
|
||||
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
|
||||
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
|
||||
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
|
||||
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
|
||||
%0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
|
||||
: tensor<3x2xindex>
|
||||
// CHECK: return %[[MEMREF]]
|
||||
return %0 : tensor<3x2xindex>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue