forked from OSchip/llvm-project
[mlir][linalg][bufferize] Support tensor.from_elements
This is mostly a copy of the existing tensor.from_elements bufferization. Once TensorInterfaceImpl.cpp is moved to the tensor dialect, the existing rewrite pattern can be deleted. Differential Revision: https://reviews.llvm.org/D117775
This commit is contained in:
parent
71bbb78b8f
commit
d581c94d6b
|
@ -229,6 +229,82 @@ struct ExtractOpInterface
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Implements backtracking to traverse indices of the output buffer while
|
||||||
|
// iterating over op.elements().
|
||||||
|
static void createStores(RewriterBase &rewriter, Location loc, int dim,
|
||||||
|
Value buffer, ArrayRef<int64_t> shape,
|
||||||
|
ArrayRef<Value> constants,
|
||||||
|
OperandRange::iterator &elementIt,
|
||||||
|
SmallVectorImpl<Value> &indices) {
|
||||||
|
if (dim == static_cast<int>(shape.size()) - 1) {
|
||||||
|
for (int i = 0; i < shape.back(); ++i) {
|
||||||
|
indices.back() = constants[i];
|
||||||
|
rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
|
||||||
|
++elementIt;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < shape[dim]; ++i) {
|
||||||
|
indices[dim] = constants[i];
|
||||||
|
createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
|
||||||
|
indices);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Bufferization of tensor.from_elements.
|
||||||
|
struct FromElementsOpInterface
|
||||||
|
: public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
|
||||||
|
tensor::FromElementsOp> {
|
||||||
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
|
const BufferizationState &state) const {
|
||||||
|
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
|
||||||
|
|
||||||
|
// Allocate a buffer for the result.
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
|
||||||
|
auto shape = tensorType.getShape();
|
||||||
|
MemRefType resultType =
|
||||||
|
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||||
|
FailureOr<Value> maybeBuffer =
|
||||||
|
createAlloc(rewriter, loc, resultType, {},
|
||||||
|
/*deallocMemref=*/state.getOptions().createDeallocs,
|
||||||
|
state.getOptions());
|
||||||
|
if (failed(maybeBuffer))
|
||||||
|
return failure();
|
||||||
|
Value buffer = *maybeBuffer;
|
||||||
|
|
||||||
|
// Case: tensor<0xelem_type>.
|
||||||
|
if (fromElementsOp.elements().empty()) {
|
||||||
|
replaceOpWithBufferizedValues(rewriter, op, buffer);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case: tensor<elem_type>.
|
||||||
|
if (shape.empty()) {
|
||||||
|
rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
|
||||||
|
buffer);
|
||||||
|
replaceOpWithBufferizedValues(rewriter, op, buffer);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create constants for the range of possible indices [0, max{shape_i}).
|
||||||
|
auto maxDim = *std::max_element(shape.begin(), shape.end());
|
||||||
|
SmallVector<Value, 2> constants;
|
||||||
|
constants.reserve(maxDim);
|
||||||
|
for (int i = 0; i < maxDim; ++i)
|
||||||
|
constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
|
||||||
|
|
||||||
|
// Traverse all `elements` and create `memref.store` ops.
|
||||||
|
auto elementIt = fromElementsOp.elements().begin();
|
||||||
|
SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
|
||||||
|
createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
|
||||||
|
indices);
|
||||||
|
|
||||||
|
replaceOpWithBufferizedValues(rewriter, op, buffer);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Bufferization of tensor.generate.
|
/// Bufferization of tensor.generate.
|
||||||
struct GenerateOpInterface
|
struct GenerateOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
|
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
|
||||||
|
@ -562,6 +638,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
|
||||||
registry.addOpInterface<DimOp, DimOpInterface>();
|
registry.addOpInterface<DimOp, DimOpInterface>();
|
||||||
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
|
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
|
||||||
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
|
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
|
||||||
|
registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
|
||||||
registry.addOpInterface<GenerateOp, GenerateOpInterface>();
|
registry.addOpInterface<GenerateOp, GenerateOpInterface>();
|
||||||
registry.addOpInterface<InsertOp, InsertOpInterface>();
|
registry.addOpInterface<InsertOp, InsertOpInterface>();
|
||||||
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
|
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
|
||||||
|
|
|
@ -1379,3 +1379,24 @@ func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
return %result : tensor<16x?xindex>
|
return %result : tensor<16x?xindex>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tensor_from_elements_2d(
|
||||||
|
// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index
|
||||||
|
func @tensor_from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> {
|
||||||
|
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
|
||||||
|
// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<3x2xindex>
|
||||||
|
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]]
|
||||||
|
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]]
|
||||||
|
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]]
|
||||||
|
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]]
|
||||||
|
// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]]
|
||||||
|
// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]]
|
||||||
|
%0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1
|
||||||
|
: tensor<3x2xindex>
|
||||||
|
// CHECK: return %[[MEMREF]]
|
||||||
|
return %0 : tensor<3x2xindex>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue