forked from OSchip/llvm-project
[mlir][linalg][bufferize] Support tensor.generate
This is mostly a copy of the existing tensor.generate bufferization. Once TensorInterfaceImpl.cpp is moved to the tensor dialect, the existing rewrite pattern can be deleted. Differential Revision: https://reviews.llvm.org/D117770
This commit is contained in:
parent 6a008de82a
commit 71bbb78b8f
@ -9,6 +9,7 @@
|
||||||
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
@ -228,6 +229,65 @@ struct ExtractOpInterface
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Bufferization of tensor.generate.
|
||||||
|
struct GenerateOpInterface
|
||||||
|
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
|
||||||
|
tensor::GenerateOp> {
|
||||||
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
|
const BufferizationState &state) const {
|
||||||
|
auto generateOp = cast<tensor::GenerateOp>(op);
|
||||||
|
|
||||||
|
// Allocate memory.
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
MemRefType memrefType =
|
||||||
|
getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
|
||||||
|
FailureOr<Value> maybeResult =
|
||||||
|
createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
|
||||||
|
/*deallocMemref=*/state.getOptions().createDeallocs,
|
||||||
|
state.getOptions());
|
||||||
|
if (failed(maybeResult))
|
||||||
|
return failure();
|
||||||
|
Value result = *maybeResult;
|
||||||
|
|
||||||
|
// Collect loop bounds.
|
||||||
|
int64_t rank = memrefType.getRank();
|
||||||
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
SmallVector<Value, 4> lowerBounds(rank, zero);
|
||||||
|
SmallVector<Value, 4> steps(rank, one);
|
||||||
|
SmallVector<Value, 4> upperBounds;
|
||||||
|
int nextDynamicIndex = 0;
|
||||||
|
for (int i = 0; i < rank; i++) {
|
||||||
|
Value upperBound = memrefType.isDynamicDim(i)
|
||||||
|
? generateOp.dynamicExtents()[nextDynamicIndex++]
|
||||||
|
: rewriter.create<arith::ConstantIndexOp>(
|
||||||
|
loc, memrefType.getDimSize(i));
|
||||||
|
upperBounds.push_back(upperBound);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate tensor elements with a parallel loop that stores into
|
||||||
|
// each element of the resulting memref. We use mergeBlockBefore to "move"
|
||||||
|
// this op's body into the scf.parallel's body.
|
||||||
|
auto parallel =
|
||||||
|
rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
|
||||||
|
Block *parallelBody = parallel.getBody();
|
||||||
|
rewriter.mergeBlockBefore(generateOp.getBody(),
|
||||||
|
parallelBody->getTerminator(),
|
||||||
|
parallelBody->getArguments());
|
||||||
|
// Replace the inlined yield op with a store op. The scf.parallel's builder
|
||||||
|
// already populated an scf.yield at the end, so we don't need to worry
|
||||||
|
// about creating that.
|
||||||
|
Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
|
||||||
|
rewriter.setInsertionPointAfter(elementYield);
|
||||||
|
rewriter.replaceOpWithNewOp<memref::StoreOp>(
|
||||||
|
elementYield, elementYield->getOperands()[0], result,
|
||||||
|
parallelBody->getArguments());
|
||||||
|
|
||||||
|
replaceOpWithBufferizedValues(rewriter, op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Bufferization of tensor.insert. Replace with memref.store.
|
/// Bufferization of tensor.insert. Replace with memref.store.
|
||||||
struct InsertOpInterface
|
struct InsertOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
|
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
|
||||||
|
@ -502,6 +562,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
|
||||||
registry.addOpInterface<DimOp, DimOpInterface>();
|
registry.addOpInterface<DimOp, DimOpInterface>();
|
||||||
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
|
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
|
||||||
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
|
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
|
||||||
|
registry.addOpInterface<GenerateOp, GenerateOpInterface>();
|
||||||
registry.addOpInterface<InsertOp, InsertOpInterface>();
|
registry.addOpInterface<InsertOp, InsertOpInterface>();
|
||||||
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
|
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
|
||||||
registry.addOpInterface<RankOp, RankOpInterface>();
|
registry.addOpInterface<RankOp, RankOpInterface>();
|
||||||
|
|
|
@ -1359,3 +1359,23 @@ func @tensor_rank(%arg0: tensor<*xf32>) -> index {
|
||||||
// CHECK: return %[[r]] : index
|
// CHECK: return %[[r]] : index
|
||||||
return %0 : index
|
return %0 : index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----

// Bufferization of tensor.generate with one static (16) and one dynamic
// (%arg0) extent: expect a memref.alloc sized by the dynamic extent and an
// scf.parallel loop storing each generated element.
// CHECK-LABEL: func @tensor_generate_static_and_dynamic(
//  CHECK-SAME:     %[[arg0:.*]]: index
func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
  // CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex>
  // CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} {
  %result = tensor.generate %arg0 {
  ^bb0(%i: index, %j: index):
    %sum = arith.addi %i, %j : index
    // CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]]
    // CHECK: scf.yield
    tensor.yield %sum : index
  } : tensor<16x?xindex>
  // CHECK: }
  return %result : tensor<16x?xindex>
}
|
|
Loading…
Reference in New Issue