forked from OSchip/llvm-project
[mlir][sparse] codegen for trivial tensor cast
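A tensor.cast whose source and destination types carry the same sparse encoding is now lowered during sparse codegen by simply forwarding the converted operand, so the cast disappears. A minimal sketch of the kind of IR this covers (the function name and the exact encoding fields are illustrative; the updated tests below exercise the same pattern):

    #SparseVector = #sparse_tensor.encoding<{
      dimLevelType = [ "compressed" ], pointerBitWidth = 32, indexBitWidth = 64
    }>

    func.func @cast_is_nop(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
      // Same encoding on both sides: after sparse codegen this cast folds away and
      // the function returns its (type-converted) argument directly.
      %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
      return %0 : tensor<?xf32, #SparseVector>
    }

Casts where the encodings differ, or where the destination is not sparse, make the pattern fail to apply. On the type-converter side, the static-size linearization of dense dimensions is dropped, so the values buffer in the converted storage is always dynamically sized, which keeps the converted type independent of static dimension sizes.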
Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133176
parent 1b726f0a4c
commit f27b806df5
@@ -33,16 +33,6 @@ namespace {
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Reorders stored dimension to original dimension.
-static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
-  auto order = enc.getDimOrdering();
-  if (order) {
-    assert(order.isPermutation());
-    return order.getDimPosition(i);
-  }
-  return i;
-}
-
 /// Reorders original dimension to stored dimension.
 static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
   auto order = enc.getDimOrdering();
@@ -67,7 +57,6 @@ static Optional<Type> convertSparseTensorType(Type type) {
   Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
   Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
   Type eltType = rType.getElementType();
-  ArrayRef<int64_t> shape = rType.getShape();
   //
   // Sparse tensor storage for rank-dimensional tensor is organized as a
   // single compound type with the following fields:
@@ -85,27 +74,18 @@ static Optional<Type> convertSparseTensorType(Type type) {
   //   memref<? x eltType> values        ; values
   // };
   //
-  int64_t linear = 1;
-  bool allDense = true;
   unsigned rank = rType.getShape().size();
   SmallVector<Type, 8> fields;
   // The dimSizes array.
   fields.push_back(MemRefType::get({rank}, indexType));
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
-    // Get the original dimension (ro) for the current stored dimension (r).
-    unsigned ro = toOrig(enc, r);
     // Dimension level types apply in order to the reordered dimension.
     // As a result, the compound type can be constructed directly in the given
     // order. Clients of this type know what field is what from the sparse
     // tensor type.
     switch (enc.getDimLevelType()[r]) {
     case SparseTensorEncodingAttr::DimLevelType::Dense:
-      // Linearize the size of consecutive dense dimensions.
-      if (ShapedType::isDynamic(shape[ro]) || ShapedType::isDynamic(linear))
-        linear = ShapedType::kDynamicSize;
-      else
-        linear *= shape[ro];
       break;
     case SparseTensorEncodingAttr::DimLevelType::Compressed:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
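For reference, the compound type this builds for the 1-D compressed vector used in the tests below is the tuple that appears verbatim in the CHECK lines; a sketch with the field meanings taken from the struct comment above:

    // Converted storage for tensor<?xf64, #SparseVector> (as checked in the tests below):
    //
    //   tuple<memref<1xindex>,   // dimSizes
    //         memref<?xi32>,     // pointers of the compressed dimension (32-bit)
    //         memref<?xi64>,     // indices of the compressed dimension (64-bit)
    //         memref<?xf64>>     // values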
@@ -113,23 +93,17 @@ static Optional<Type> convertSparseTensorType(Type type) {
     case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      allDense = false;
-      linear = 1;
       break;
     case SparseTensorEncodingAttr::DimLevelType::Singleton:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      allDense = false;
-      linear = 1;
       break;
     }
   }
   // The values array.
-  int64_t nnz =
-      (rType.hasStaticShape() && allDense) ? linear : ShapedType::kDynamicSize;
-  fields.push_back(MemRefType::get({nnz}, eltType));
+  fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
   // Sparse tensor storage (temporarily) lives in a tuple. This allows a
   // simple 1:1 type conversion during codegen. A subsequent pass uses
   // a 1:N type conversion to expand the tuple into its fields.
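While the storage lives in this tuple, individual fields are read back with sparse_tensor.storage_get, as the updated checks at the bottom of the test show; a small illustrative access (the value names %a and %c0 are placeholders, the type matches the 3-D dense test):

    // Fetch the dimSizes array (field 0) out of the storage tuple, then read one size.
    %sizes = sparse_tensor.storage_get %a[0]
        : tuple<memref<3xindex>, memref<?xf64>> to memref<3xindex>
    %d = memref.load %sizes[%c0] : memref<3xindex>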
@@ -241,6 +215,23 @@ public:
   }
 };
 
+/// Sparse codegen rule for trivial tensor casts.
+class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only rewrite identically annotated source/dest.
+    auto encDst = getSparseTensorEncoding(op.getType());
+    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
+    if (!encDst || encDst != encSrc)
+      return failure();
+    rewriter.replaceOp(op, adaptor.getOperands());
+    return success();
+  }
+};
+
 /// Sparse conversion rule for pointer accesses.
 class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
 public:
@@ -314,7 +305,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter, SparseDimOpConverter,
+  patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
               SparseToPointersConverter, SparseToIndicesConverter,
               SparseToValuesConverter>(typeConverter, patterns.getContext());
 }
@@ -36,12 +36,28 @@
 }>
 
 // CHECK-LABEL: func @sparse_nop(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
+// CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
 // CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
 
+// CHECK-LABEL: func @sparse_nop_cast(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>)
+// CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>
+func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_nop_cast_3d(
+// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf32>>)
+// CHECK: return %[[A]] : tuple<memref<3xindex>, memref<?xf32>>
+func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?x?x?xf32, #Dense3D> {
+  %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor<?x?x?xf32, #Dense3D>
+  return %0 : tensor<?x?x?xf32, #Dense3D>
+}
+
 // CHECK-LABEL: func @sparse_dense_2d(
 // CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xf64>>)
 func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
@@ -71,7 +87,7 @@ func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
 //   fold using the original static dimension sizes.
 //
 // CHECK-LABEL: func @sparse_dense_3d(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<6000xf64>>) -> index {
+// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
 // CHECK: %[[C:.*]] = arith.constant 20 : index
 // CHECK: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -86,7 +102,7 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
 //   since the latter honors the dimOrdering.
 //
 // CHECK-LABEL: func @sparse_dense_3d_dyn(
-// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>) -> index {
+// CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
 // CHECK: %[[C:.*]] = arith.constant 2 : index
 // CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<3xindex>, memref<?xf64>> to memref<3xindex>
 // CHECK: %[[L:.*]] = memref.load %[[F]][%[[C]]] : memref<3xindex>