forked from OSchip/llvm-project
[mlir][spirv] Add lowering for composite std.constant.
Add lowering for constant operation with ranked tensor type to spv.constant with spv.array type. Differential Revision: https://reviews.llvm.org/D73022
This commit is contained in:
parent
178562fb35
commit
4460cb5bcd
|
@ -25,6 +25,17 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
|
||||
/// Convert composite constant operation to SPIR-V dialect.
|
||||
// TODO(denis0x0D) : move to DRR.
|
||||
class ConstantCompositeOpConversion final : public SPIRVOpLowering<ConstantOp> {
|
||||
public:
|
||||
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(ConstantOp constCompositeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Convert constant operation with IndexType return to SPIR-V constant
|
||||
/// operation. Since IndexType is not used within SPIR-V dialect, this needs
|
||||
/// special handling to make sure the result type and the type of the value
|
||||
|
@ -172,6 +183,39 @@ static spirv::AccessChainOp getElementPtr(OpBuilder &builder,
|
|||
return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOp with composite type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite(
|
||||
ConstantOp constCompositeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto compositeType =
|
||||
constCompositeOp.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!compositeType)
|
||||
return matchFailure();
|
||||
|
||||
auto spirvCompositeType = typeConverter.convertType(compositeType);
|
||||
if (!spirvCompositeType)
|
||||
return matchFailure();
|
||||
|
||||
auto linearizedElements =
|
||||
constCompositeOp.value().dyn_cast<DenseElementsAttr>();
|
||||
if (!linearizedElements)
|
||||
return matchFailure();
|
||||
|
||||
// If composite type has rank greater than one, then perform linearization.
|
||||
if (compositeType.getRank() > 1) {
|
||||
auto linearizedType = RankedTensorType::get(compositeType.getNumElements(),
|
||||
compositeType.getElementType());
|
||||
linearizedElements = linearizedElements.reshape(linearizedType);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
|
||||
constCompositeOp, spirvCompositeType, linearizedElements);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOp with index type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -354,7 +398,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
|
|||
OwningRewritePatternList &patterns) {
|
||||
// Add patterns that lower operations into SPIR-V dialect.
|
||||
populateWithGenerated(context, &patterns);
|
||||
patterns.insert<ConstantIndexOpConversion, CmpFOpConversion, CmpIOpConversion,
|
||||
patterns.insert<ConstantCompositeOpConversion, ConstantIndexOpConversion,
|
||||
CmpFOpConversion, CmpIOpConversion,
|
||||
IntegerOpConversion<AddIOp, spirv::IAddOp>,
|
||||
IntegerOpConversion<MulIOp, spirv::IMulOp>,
|
||||
IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,
|
||||
|
|
|
@ -80,6 +80,19 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
|
|||
memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
|
||||
}
|
||||
return (offset + memrefSize) * elementSize.getValue();
|
||||
} else if (auto tensorType = t.dyn_cast<TensorType>()) {
|
||||
if (!tensorType.hasStaticShape()) {
|
||||
return llvm::None;
|
||||
}
|
||||
auto elementSize = getTypeNumBytes(tensorType.getElementType());
|
||||
if (!elementSize) {
|
||||
return llvm::None;
|
||||
}
|
||||
int64_t size = elementSize.getValue();
|
||||
for (auto shape : tensorType.getShape()) {
|
||||
size *= shape;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
// TODO: Add size computation for other types.
|
||||
return llvm::None;
|
||||
|
@ -131,6 +144,27 @@ static Type convertStdType(Type type) {
|
|||
}
|
||||
}
|
||||
|
||||
if (auto tensorType = type.dyn_cast<TensorType>()) {
|
||||
// TODO(ravishankarm) : Handle dynamic shapes.
|
||||
if (!tensorType.hasStaticShape()) {
|
||||
return Type();
|
||||
}
|
||||
auto elementType = convertStdType(tensorType.getElementType());
|
||||
if (!elementType) {
|
||||
return Type();
|
||||
}
|
||||
auto elementSize = getTypeNumBytes(elementType);
|
||||
if (!elementSize) {
|
||||
return Type();
|
||||
}
|
||||
auto tensorSize = getTypeNumBytes(tensorType);
|
||||
if (!tensorSize) {
|
||||
return Type();
|
||||
}
|
||||
return spirv::ArrayType::get(elementType,
|
||||
tensorSize.getValue() / elementSize.getValue(),
|
||||
elementSize.getValue());
|
||||
}
|
||||
return Type();
|
||||
}
|
||||
|
||||
|
|
|
@ -220,6 +220,18 @@ func @constant() {
|
|||
%3 = constant dense<[2, 3]> : vector<2xi32>
|
||||
// CHECK: spv.constant 1 : i32
|
||||
%4 = constant 1 : index
|
||||
// CHECK: spv.constant dense<1> : tensor<6xi32> : !spv.array<6 x i32 [4]>
|
||||
%5 = constant dense<1> : tensor<2x3xi32>
|
||||
// CHECK: spv.constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32 [4]>
|
||||
%6 = constant dense<1.0> : tensor<2x3xf32>
|
||||
// CHECK: spv.constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32 [4]>
|
||||
%7 = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
|
||||
// CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
|
||||
%8 = constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
|
||||
// CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
|
||||
%9 = constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
|
||||
// CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
|
||||
%10 = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue