forked from OSchip/llvm-project
[mlir][sparse] Introduce new sparse_tensor.storage_get/set to access memory that stores the handle of a sparse tensor
Introduce new sparse_tensor.storage_get/set to access memory that stores the handle of a sparse tensor. The sparse tensor storage are represented as a tuple, these operation will later be eliminated and the tuple will be flattened after sparse tensor codegen Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D133049
This commit is contained in:
parent
f767f09252
commit
7ea643c06d
|
@ -623,4 +623,57 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
|
|||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sparse Tensor Storage Operation. These operations are used internally by
|
||||
// sparse tensor codegen to progressively lower sparse tensors.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
|
||||
Arguments<(ins AnyTuple:$storage,
|
||||
IndexAttr:$idx)>,
|
||||
Results<(outs AnyType:$result)> {
|
||||
let summary = "Get the data stored in the sparse tensor storage at the given index";
|
||||
let description = [{
|
||||
Get the data stored in the sparse tensor storage (represented as a tuple)
|
||||
at the given index.
|
||||
|
||||
The result type should match the corresponding element type in the tuple.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = sparse_tensor.storage_get %arg0[0] : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
|
||||
```
|
||||
}];
|
||||
|
||||
let assemblyFormat = " $storage attr-dict `[`$idx`]` `:` type($storage) `to` type($result)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SparseTensor_StorageSetOp : SparseTensor_Op<"storage_set", []>,
|
||||
Arguments<(ins AnyTuple:$storage,
|
||||
AnyType:$value,
|
||||
IndexAttr:$idx)>,
|
||||
Results<(outs AnyTuple:$result)> {
|
||||
let summary = "Set the data stored in the sparse tensor storage at given index";
|
||||
let description = [{
|
||||
Set the data stored in the sparse tensor storage (represented as a tuple)
|
||||
at the given index. Return a new SSA value with the corresponding element
|
||||
updated (others remain unchanged).
|
||||
|
||||
The result type should match the original tuple type with only the updated
|
||||
element type changed accordingly.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = sparse_tensor.storage_set %arg0, %arg1 at 0 : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
```
|
||||
}];
|
||||
|
||||
let assemblyFormat = " $storage attr-dict `[`$idx`]``,` $value `:` type($storage) `,` type($value) `to` type($result)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
||||
#endif // SPARSETENSOR_OPS
|
||||
|
|
|
@ -482,6 +482,48 @@ LogicalResult YieldOp::verify() {
|
|||
"expected parent op to be sparse_tensor unary, binary, or reduce");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sparse Tensor Storage Operation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult StorageGetOp::verify() {
|
||||
uint64_t extractIdx = getIdx().getZExtValue();
|
||||
auto innerTypeArray = getStorage().getType().getTypes();
|
||||
if (extractIdx >= innerTypeArray.size())
|
||||
return emitError(llvm::formatv(
|
||||
"Out-of-bound access with index={0} on tuple with length={1}",
|
||||
extractIdx, innerTypeArray.size()));
|
||||
|
||||
auto expectedTy = getStorage().getType().getType(extractIdx);
|
||||
auto returnTy = getResult().getType();
|
||||
if (expectedTy != returnTy)
|
||||
return emitError(llvm::formatv(
|
||||
"Type mismatch between the returning type (type={0}) and the "
|
||||
"corresponding element type at index {1} (type={2})",
|
||||
expectedTy, extractIdx, returnTy));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult StorageSetOp::verify() {
|
||||
uint64_t setIdx = getIdx().getZExtValue();
|
||||
SmallVector<Type, 8> expectedElemTy(getStorage().getType().getTypes());
|
||||
if (setIdx >= expectedElemTy.size())
|
||||
return emitError(llvm::formatv(
|
||||
"Out-of-bound access with index = {0} on tuple with length={1}", setIdx,
|
||||
expectedElemTy.size()));
|
||||
|
||||
// Updates the element type after storage_set.
|
||||
expectedElemTy[setIdx] = getValue().getType();
|
||||
auto expectedTy = TupleType::get(getContext(), expectedElemTy);
|
||||
auto returnTy = getResult().getType();
|
||||
if (expectedTy != returnTy)
|
||||
return emitError(
|
||||
llvm::formatv("Type mismatch between the returning type "
|
||||
"(type={0}) and the expected type (type={1})",
|
||||
returnTy, expectedTy));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorDialect Methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -443,3 +443,42 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
|
|||
return %0 : tensor<9x4xf64, #DC>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
|
||||
// expected-error@+1{{Out-of-bound access}}
|
||||
%0 = sparse_tensor.storage_get %arg0[3]
|
||||
: tuple<memref<?xf64>, memref<?xf64>, f64> to
|
||||
memref<?xf64>
|
||||
return %0 : memref<?xf64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
|
||||
// expected-error@+1{{Type mismatch}}
|
||||
%0 = sparse_tensor.storage_get %arg0[2]
|
||||
: tuple<memref<?xf64>, memref<?xf64>, f64> to
|
||||
memref<?xf64>
|
||||
return %0 : memref<?xf64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
|
||||
// expected-error@+1{{Out-of-bound access}}
|
||||
%0 = sparse_tensor.storage_set %arg0[3], %arg1
|
||||
: tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
|
||||
tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
|
||||
// expected-error@+1{{Type mismatch}}
|
||||
%0 = sparse_tensor.storage_set %arg0[2], %arg1
|
||||
: tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
|
||||
tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
}
|
||||
|
|
|
@ -314,3 +314,34 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
|
|||
tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
|
||||
return %0 : tensor<9x4xf64, #SparseMatrix>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sparse_storage_get(
|
||||
// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
|
||||
// CHECK-SAME: tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
// CHECK-SAME: to memref<?xf64>
|
||||
// CHECK: return %[[TMP0]] : memref<?xf64>
|
||||
func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
|
||||
%0 = sparse_tensor.storage_get %arg0[0]
|
||||
: tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
|
||||
return %0 : memref<?xf64>
|
||||
}
|
||||
|
||||
// ----
|
||||
|
||||
// CHECK-LABEL: func @sparse_storage_set(
|
||||
// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,
|
||||
// CHECK-SAME: %[[A1:.*]]: memref<?xf64>
|
||||
// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_set %[[A0]][0], %[[A1]] :
|
||||
// CHECK-SAME: tuple<memref<?xf64>, memref<?xf64>, f64>,
|
||||
// CHECK-SAME: memref<?xf64>
|
||||
// CHECK-SAME: to tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
// CHECK: return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
|
||||
%0 = sparse_tensor.storage_set %arg0[0], %arg1
|
||||
: tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
|
||||
tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue