[mlir][sparse] Introduce new sparse_tensor.storage_get/set to access memory that stores the handle of a sparse tensor

Introduce new sparse_tensor.storage_get/set to access memory that stores the handle of a sparse tensor. The sparse tensor storage are represented as a tuple, these operation will later be eliminated and the tuple will be flattened after sparse tensor codegen

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133049
This commit is contained in:
Peiming Liu 2022-08-31 20:44:41 +00:00
parent f767f09252
commit 7ea643c06d
4 changed files with 165 additions and 0 deletions

View File

@ -623,4 +623,57 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// Sparse Tensor Storage Operation. These operations are used internally by
// sparse tensor codegen to progressively lower sparse tensors.
//===----------------------------------------------------------------------===//
def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
Arguments<(ins AnyTuple:$storage,
IndexAttr:$idx)>,
Results<(outs AnyType:$result)> {
let summary = "Get the data stored in the sparse tensor storage at the given index";
let description = [{
Get the data stored in the sparse tensor storage (represented as a tuple)
at the given index.
The result type should match the corresponding element type in the tuple.
Example:
```mlir
%0 = sparse_tensor.storage_get %arg0[0] : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
```
}];
let assemblyFormat = " $storage attr-dict `[`$idx`]` `:` type($storage) `to` type($result)";
let hasVerifier = 1;
}
def SparseTensor_StorageSetOp : SparseTensor_Op<"storage_set", []>,
Arguments<(ins AnyTuple:$storage,
AnyType:$value,
IndexAttr:$idx)>,
Results<(outs AnyTuple:$result)> {
let summary = "Set the data stored in the sparse tensor storage at given index";
let description = [{
Set the data stored in the sparse tensor storage (represented as a tuple)
at the given index. Return a new SSA value with the corresponding element
updated (others remain unchanged).
The result type should match the original tuple type with only the updated
element type changed accordingly.
Example:
```mlir
%0 = sparse_tensor.storage_set %arg0, %arg1 at 0 : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, f64>
```
}];
let assemblyFormat = " $storage attr-dict `[`$idx`]``,` $value `:` type($storage) `,` type($value) `to` type($result)";
let hasVerifier = 1;
}
#endif // SPARSETENSOR_OPS

View File

@ -482,6 +482,48 @@ LogicalResult YieldOp::verify() {
"expected parent op to be sparse_tensor unary, binary, or reduce");
}
//===----------------------------------------------------------------------===//
// Sparse Tensor Storage Operation.
//===----------------------------------------------------------------------===//
LogicalResult StorageGetOp::verify() {
uint64_t extractIdx = getIdx().getZExtValue();
auto innerTypeArray = getStorage().getType().getTypes();
if (extractIdx >= innerTypeArray.size())
return emitError(llvm::formatv(
"Out-of-bound access with index={0} on tuple with length={1}",
extractIdx, innerTypeArray.size()));
auto expectedTy = getStorage().getType().getType(extractIdx);
auto returnTy = getResult().getType();
if (expectedTy != returnTy)
return emitError(llvm::formatv(
"Type mismatch between the returning type (type={0}) and the "
"corresponding element type at index {1} (type={2})",
expectedTy, extractIdx, returnTy));
return success();
}
LogicalResult StorageSetOp::verify() {
uint64_t setIdx = getIdx().getZExtValue();
SmallVector<Type, 8> expectedElemTy(getStorage().getType().getTypes());
if (setIdx >= expectedElemTy.size())
return emitError(llvm::formatv(
"Out-of-bound access with index = {0} on tuple with length={1}", setIdx,
expectedElemTy.size()));
// Updates the element type after storage_set.
expectedElemTy[setIdx] = getValue().getType();
auto expectedTy = TupleType::get(getContext(), expectedElemTy);
auto returnTy = getResult().getType();
if (expectedTy != returnTy)
return emitError(
llvm::formatv("Type mismatch between the returning type "
"(type={0}) and the expected type (type={1})",
returnTy, expectedTy));
return success();
}
//===----------------------------------------------------------------------===//
// TensorDialect Methods.
//===----------------------------------------------------------------------===//

View File

@ -443,3 +443,42 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
return %0 : tensor<9x4xf64, #DC>
}
// -----
func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
// expected-error@+1{{Out-of-bound access}}
%0 = sparse_tensor.storage_get %arg0[3]
: tuple<memref<?xf64>, memref<?xf64>, f64> to
memref<?xf64>
return %0 : memref<?xf64>
}
// -----
func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
// expected-error@+1{{Type mismatch}}
%0 = sparse_tensor.storage_get %arg0[2]
: tuple<memref<?xf64>, memref<?xf64>, f64> to
memref<?xf64>
return %0 : memref<?xf64>
}
// -----
func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
// expected-error@+1{{Out-of-bound access}}
%0 = sparse_tensor.storage_set %arg0[3], %arg1
: tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
tuple<memref<?xf64>, memref<?xf64>, f64>
return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
}
// -----
func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
// expected-error@+1{{Type mismatch}}
%0 = sparse_tensor.storage_set %arg0[2], %arg1
: tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
tuple<memref<?xf64>, memref<?xf64>, f64>
return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
}

View File

@ -314,3 +314,34 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
return %0 : tensor<9x4xf64, #SparseMatrix>
}
// -----
// CHECK-LABEL: func @sparse_storage_get(
// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
// CHECK-SAME: tuple<memref<?xf64>, memref<?xf64>, f64>
// CHECK-SAME: to memref<?xf64>
// CHECK: return %[[TMP0]] : memref<?xf64>
func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
%0 = sparse_tensor.storage_get %arg0[0]
: tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
return %0 : memref<?xf64>
}
// ----
// CHECK-LABEL: func @sparse_storage_set(
// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,
// CHECK-SAME: %[[A1:.*]]: memref<?xf64>
// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_set %[[A0]][0], %[[A1]] :
// CHECK-SAME: tuple<memref<?xf64>, memref<?xf64>, f64>,
// CHECK-SAME: memref<?xf64>
// CHECK-SAME: to tuple<memref<?xf64>, memref<?xf64>, f64>
// CHECK: return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
%0 = sparse_tensor.storage_set %arg0[0], %arg1
: tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
tuple<memref<?xf64>, memref<?xf64>, f64>
return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
}