[mlir][linalg][bufferize] Support tensor.generate

This is mostly a copy of the existing tensor.generate bufferization. Once
TensorInterfaceImpl.cpp is moved to the tensor dialect, the existing rewrite
pattern can be deleted.

Differential Revision: https://reviews.llvm.org/D117770
commit 71bbb78b8f (parent 6a008de82a)
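To make the intent concrete, here is a hand-written before/after sketch (the function and SSA names are invented for illustration; the FileCheck test at the bottom of this diff is the authoritative expected output). tensor.generate is bufferized into a memref.alloc sized by the dynamic extents, plus an scf.parallel loop that stores each generated element:

  // Before: generate a dynamically sized 1-D tensor.
  func @gen(%size: index) -> tensor<?xindex> {
    %0 = tensor.generate %size {
    ^bb0(%i: index):
      tensor.yield %i : index
    } : tensor<?xindex>
    return %0 : tensor<?xindex>
  }

  // After bufferization (sketch):
  //   %c0 = arith.constant 0 : index
  //   %c1 = arith.constant 1 : index
  //   %alloc = memref.alloc(%size) : memref<?xindex>
  //   scf.parallel (%i) = (%c0) to (%size) step (%c1) {
  //     memref.store %i, %alloc[%i] : memref<?xindex>
  //     scf.yield
  //   }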
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
@@ -228,6 +229,65 @@ struct ExtractOpInterface
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
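For the fully static case, a second sketch (same caveats as above: invented names, not taken from the test suite). Every upper bound collected by the loop in the code above then comes from an arith.constant rather than from dynamicExtents(), and the inlined tensor.yield is rewritten into the memref.store that fills the buffer:

  // Hypothetical input with a fully static shape:
  //   %0 = tensor.generate {
  //   ^bb0(%i: index, %j: index):
  //     %v = arith.muli %i, %j : index
  //     tensor.yield %v : index
  //   } : tensor<4x8xindex>
  //
  // Expected shape of the bufferized output (sketch):
  //   %m = memref.alloc() : memref<4x8xindex>
  //   scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %c8) step (%c1, %c1) {
  //     %v = arith.muli %i, %j : index
  //     memref.store %v, %m[%i, %j] : memref<4x8xindex>
  //     scf.yield
  //   }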
@@ -502,6 +562,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
@@ -1359,3 +1359,23 @@ func @tensor_rank(%arg0: tensor<*xf32>) -> index {
  // CHECK: return %[[r]] : index
  return %0 : index
}

// -----

// CHECK-LABEL: func @tensor_generate_static_and_dynamic(
// CHECK-SAME:    %[[arg0:.*]]: index
func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
  // CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex>
  // CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} {
  %result = tensor.generate %arg0 {
  ^bb0(%i: index, %j: index):
    %sum = arith.addi %i, %j : index
    // CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]]
    // CHECK: scf.yield
    tensor.yield %sum : index
  } : tensor<16x?xindex>
  // CHECK: }
  return %result : tensor<16x?xindex>
}