[mlir][spirv] Add lowering for composite std.constant.

Add a lowering that converts std.constant operations with ranked tensor
type to spv.constant operations with spv.array type.
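
For example (adapted from the tests added in this commit; the SSA name is
illustrative), a multi-dimensional dense constant is linearized to match the
flat spv.array layout:

  %0 = constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>

lowers to

  %0 = spv.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>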

Differential Revision: https://reviews.llvm.org/D73022
Denis Khalikov 2020-01-22 08:05:27 -05:00, committed by Lei Zhang
parent 178562fb35
commit 4460cb5bcd
3 changed files with 92 additions and 1 deletion


@@ -25,6 +25,17 @@ using namespace mlir;
namespace {
/// Convert a composite constant operation to the SPIR-V dialect.
// TODO(denis0x0D): Move to DRR.
class ConstantCompositeOpConversion final
    : public SPIRVOpLowering<ConstantOp> {
public:
  using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;

  PatternMatchResult
  matchAndRewrite(ConstantOp constCompositeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
/// Convert a constant operation with an IndexType return to a SPIR-V constant
/// operation. Since IndexType is not used within the SPIR-V dialect, this needs
/// special handling to make sure the result type and the type of the value
@@ -172,6 +183,39 @@ static spirv::AccessChainOp getElementPtr(OpBuilder &builder,
  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
//===----------------------------------------------------------------------===//
// ConstantOp with composite type.
//===----------------------------------------------------------------------===//
PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite(
    ConstantOp constCompositeOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto compositeType =
      constCompositeOp.getResult().getType().dyn_cast<RankedTensorType>();
  if (!compositeType)
    return matchFailure();
  auto spirvCompositeType = typeConverter.convertType(compositeType);
  if (!spirvCompositeType)
    return matchFailure();
  auto linearizedElements =
      constCompositeOp.value().dyn_cast<DenseElementsAttr>();
  if (!linearizedElements)
    return matchFailure();
  // If the composite type has rank greater than one, linearize the element
  // attribute into a rank-1 tensor so that it matches the flat layout of the
  // spv.array type produced by the type converter.
  if (compositeType.getRank() > 1) {
    auto linearizedType = RankedTensorType::get(compositeType.getNumElements(),
                                                compositeType.getElementType());
    linearizedElements = linearizedElements.reshape(linearizedType);
  }
  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
      constCompositeOp, spirvCompositeType, linearizedElements);
  return matchSuccess();
}
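
// For illustration, mirroring a test case added in this commit (the SSA name
// here is just illustrative), the pattern above rewrites
//   %cst = constant dense<1> : tensor<2x3xi32>
// into
//   %cst = spv.constant dense<1> : tensor<6xi32> : !spv.array<6 x i32 [4]>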
//===----------------------------------------------------------------------===//
// ConstantOp with index type.
//===----------------------------------------------------------------------===//
@@ -354,7 +398,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                     OwningRewritePatternList &patterns) {
  // Add patterns that lower operations into SPIR-V dialect.
  populateWithGenerated(context, &patterns);
  patterns.insert<ConstantCompositeOpConversion, ConstantIndexOpConversion,
                  CmpFOpConversion, CmpIOpConversion,
                  IntegerOpConversion<AddIOp, spirv::IAddOp>,
                  IntegerOpConversion<MulIOp, spirv::IMulOp>,
                  IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,


@@ -80,6 +80,19 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
      memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
    }
    return (offset + memrefSize) * elementSize.getValue();
  } else if (auto tensorType = t.dyn_cast<TensorType>()) {
    if (!tensorType.hasStaticShape()) {
      return llvm::None;
    }
    auto elementSize = getTypeNumBytes(tensorType.getElementType());
    if (!elementSize) {
      return llvm::None;
    }
    // The tensor size is the element size times the product of the static
    // dimensions. For example, tensor<2x3xi32> has a 4-byte element type, so
    // its size is 4 * 2 * 3 = 24 bytes.
    int64_t size = elementSize.getValue();
    for (auto dim : tensorType.getShape()) {
      size *= dim;
    }
    return size;
  }
  // TODO: Add size computation for other types.
  return llvm::None;
@@ -131,6 +144,27 @@ static Type convertStdType(Type type) {
    }
  }
  if (auto tensorType = type.dyn_cast<TensorType>()) {
    // TODO(ravishankarm): Handle dynamic shapes.
    if (!tensorType.hasStaticShape()) {
      return Type();
    }
    auto elementType = convertStdType(tensorType.getElementType());
    if (!elementType) {
      return Type();
    }
    auto elementSize = getTypeNumBytes(elementType);
    if (!elementSize) {
      return Type();
    }
    auto tensorSize = getTypeNumBytes(tensorType);
    if (!tensorSize) {
      return Type();
    }
    // Convert to a spv.array whose element count is the total tensor size
    // divided by the element size, using the element size as the array
    // stride; e.g. tensor<2x3xi32> becomes !spv.array<6 x i32 [4]>.
    return spirv::ArrayType::get(elementType,
                                 tensorSize.getValue() / elementSize.getValue(),
                                 elementSize.getValue());
  }
  return Type();
}


@@ -220,6 +220,18 @@ func @constant() {
  %3 = constant dense<[2, 3]> : vector<2xi32>
  // CHECK: spv.constant 1 : i32
  %4 = constant 1 : index
  // CHECK: spv.constant dense<1> : tensor<6xi32> : !spv.array<6 x i32 [4]>
  %5 = constant dense<1> : tensor<2x3xi32>
  // CHECK: spv.constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32 [4]>
  %6 = constant dense<1.0> : tensor<2x3xf32>
  // CHECK: spv.constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32 [4]>
  %7 = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
  // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
  %8 = constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
  // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
  %9 = constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
  // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
  %10 = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
  return
}