[mlir][linalg][bufferize] Support tensor.generate

This is mostly a copy of the existing tensor.generate bufferization. Once TensorInterfaceImpl.cpp is moved to the tensor dialect, the existing rewrite pattern can be deleted.

Differential Revision: https://reviews.llvm.org/D117770
This commit is contained in:
Matthias Springer 2022-01-25 22:05:10 +09:00
parent 6a008de82a
commit 71bbb78b8f
2 changed files with 81 additions and 0 deletions

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
@ -228,6 +229,65 @@ struct ExtractOpInterface
}
};
/// Bufferization of tensor.generate.
struct GenerateOpInterface
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
tensor::GenerateOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const {
auto generateOp = cast<tensor::GenerateOp>(op);
// Allocate memory.
Location loc = op->getLoc();
MemRefType memrefType =
getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
FailureOr<Value> maybeResult =
createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
/*deallocMemref=*/state.getOptions().createDeallocs,
state.getOptions());
if (failed(maybeResult))
return failure();
Value result = *maybeResult;
// Collect loop bounds.
int64_t rank = memrefType.getRank();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
SmallVector<Value, 4> lowerBounds(rank, zero);
SmallVector<Value, 4> steps(rank, one);
SmallVector<Value, 4> upperBounds;
int nextDynamicIndex = 0;
for (int i = 0; i < rank; i++) {
Value upperBound = memrefType.isDynamicDim(i)
? generateOp.dynamicExtents()[nextDynamicIndex++]
: rewriter.create<arith::ConstantIndexOp>(
loc, memrefType.getDimSize(i));
upperBounds.push_back(upperBound);
}
// Generate tensor elements with a parallel loop that stores into
// each element of the resulting memref. We use mergeBlockBefore to "move"
// this op's body into the scf.parallel's body.
auto parallel =
rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
Block *parallelBody = parallel.getBody();
rewriter.mergeBlockBefore(generateOp.getBody(),
parallelBody->getTerminator(),
parallelBody->getArguments());
// Replace the inlined yield op with a store op. The scf.parallel's builder
// already populated an scf.yield at the end, so we don't need to worry
// about creating that.
Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
rewriter.setInsertionPointAfter(elementYield);
rewriter.replaceOpWithNewOp<memref::StoreOp>(
elementYield, elementYield->getOperands()[0], result,
parallelBody->getArguments());
replaceOpWithBufferizedValues(rewriter, op, result);
return success();
}
};
/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
@ -502,6 +562,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
registry.addOpInterface<DimOp, DimOpInterface>();
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
registry.addOpInterface<GenerateOp, GenerateOpInterface>();
registry.addOpInterface<InsertOp, InsertOpInterface>();
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
registry.addOpInterface<RankOp, RankOpInterface>();

View File

@ -1359,3 +1359,23 @@ func @tensor_rank(%arg0: tensor<*xf32>) -> index {
// CHECK: return %[[r]] : index
return %0 : index
}
// -----
// CHECK-LABEL: func @tensor_generate_static_and_dynamic(
// CHECK-SAME: %[[arg0:.*]]: index
func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex>
// CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} {
%result = tensor.generate %arg0 {
^bb0(%i: index, %j: index):
%sum = arith.addi %i, %j : index
// CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]]
// CHECK: scf.yield
tensor.yield %sum : index
} : tensor<16x?xindex>
// CHECK: }
return %result : tensor<16x?xindex>
}