[mlir][sparse] Introduce new reduce op

A new sparse_tensor operation allows for
custom reduction code to be injected during
linalg.generic lowering for sparse tensors.
An identity value is provided to indicate
the starting value of the reduction. A single
block region is required to contain the
custom reduce computation.

Reviewed by: aartbik

Differential Revision: https://reviews.llvm.org/D128004
This commit is contained in:
Jim Kitchen 2022-07-15 15:26:41 -05:00
parent 6ab686eb86
commit 2b8a4d9ce1
4 changed files with 170 additions and 4 deletions

View File

@ -544,6 +544,56 @@ def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [NoSideEffect]>,
let hasVerifier = 1; let hasVerifier = 1;
} }
// Custom reduction op for use inside a linalg.generic body. The operands
// `x`, `y`, `identity` and the result all share one type (enforced by
// SameOperandsAndResultType); the single-block region holds the user-supplied
// reduction formula and is checked by the op's C++ verifier.
def SparseTensor_ReduceOp : SparseTensor_Op<"reduce", [NoSideEffect, SameOperandsAndResultType]>,
    Arguments<(ins AnyType:$x, AnyType:$y, AnyType:$identity)>,
    Results<(outs AnyType:$output)> {
  let summary = "Custom reduction operation utilized within linalg.generic";
  let description = [{
      Defines a computation within a `linalg.generic` operation that takes two
      operands and an identity value and reduces all values down to a single
      result based on the computation in the region.

      The region must contain exactly one block taking two arguments. The block
      must end with a sparse_tensor.yield and the output must match the input
      argument types.

      Note that this operation is only required for custom reductions beyond the
      standard operations (add, mul, and, or, etc). The `linalg.generic`
      `iterator_types` defines which indices are being reduced. When the associated
      operands are used in an operation, a reduction will occur. The use of this
      explicit `reduce` operation is not required in most cases.

      Example of Matrix->Vector reduction using min(product(x_i), 100), i.e. a
      running product that is capped at 100:

      ```mlir
      %cf1 = arith.constant 1.0 : f64
      %cf100 = arith.constant 100.0 : f64
      %C = bufferization.alloc_tensor...
      %0 = linalg.generic #trait
         ins(%A: tensor<?x?xf64, #SparseMatrix>)
        outs(%C: tensor<?xf64, #SparseVec>) {
        ^bb0(%a: f64, %c: f64) :
          %result = sparse_tensor.reduce %c, %a, %cf1 : f64 {
              ^bb0(%arg0: f64, %arg1: f64):
                %prod = arith.mulf %arg0, %arg1 : f64
                %cmp = arith.cmpf "ogt", %prod, %cf100 : f64
                %ret = arith.select %cmp, %cf100, %prod : f64
                sparse_tensor.yield %ret : f64
            }
          linalg.yield %result : f64
      } -> tensor<?xf64, #SparseVec>
      ```
  }];

  // Exactly one block is required; its signature is validated in verify().
  let regions = (region SizedRegion<1>:$region);

  let assemblyFormat = [{
        $x `,` $y `,` $identity attr-dict `:` type($output) $region
  }];
  let hasVerifier = 1;
}
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>, def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
Arguments<(ins AnyType:$result)> { Arguments<(ins AnyType:$result)> {
let summary = "Yield from sparse_tensor set-like operations"; let summary = "Yield from sparse_tensor set-like operations";

View File

@ -357,15 +357,31 @@ LogicalResult UnaryOp::verify() {
return success(); return success();
} }
// Verifies the signature of the reduce op's region: the block must take two
// arguments of the operand type and yield a value of that same type. The
// actual checks (and their diagnostics) are delegated to verifyNumBlockArgs.
LogicalResult ReduceOp::verify() {
  Region &body = region();
  // Nothing to validate when the region holds no block.
  if (body.empty())
    return success();

  // Both block arguments and the yielded value use the type of operand `x`.
  Type elemType = x().getType();
  return verifyNumBlockArgs(this, body, "reduce",
                            TypeRange{elemType, elemType}, elemType);
}
LogicalResult YieldOp::verify() { LogicalResult YieldOp::verify() {
// Check for compatible parent. // Check for compatible parent.
auto *parentOp = (*this)->getParentOp(); auto *parentOp = (*this)->getParentOp();
if (auto binaryOp = dyn_cast<BinaryOp>(parentOp)) if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
return success(); isa<ReduceOp>(parentOp))
if (auto unaryOp = dyn_cast<UnaryOp>(parentOp))
return success(); return success();
return emitOpError("expected parent op to be sparse_tensor binary or unary"); return emitOpError(
"expected parent op to be sparse_tensor unary, binary, or reduce");
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -253,6 +253,20 @@ func.func @invalid_binary_wrong_identity_type(%arg0: i64, %arg1: f64) -> f64 {
// ----- // -----
// Negative test: the `left` region of sparse_tensor.binary is terminated by
// tensor.yield instead of sparse_tensor.yield, so verification must report
// the diagnostic expected below.
func.func @invalid_binary_wrong_yield(%arg0: f64, %arg1: f64) -> f64 {
// expected-error@+1 {{left region must end with sparse_tensor.yield}}
%0 = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
overlap={}
left={
^bb0(%x: f64):
tensor.yield %x : f64
}
right=identity
return %0 : f64
}
// -----
func.func @invalid_unary_argtype_mismatch(%arg0: f64) -> f64 { func.func @invalid_unary_argtype_mismatch(%arg0: f64) -> f64 {
// expected-error@+1 {{present region argument 1 type mismatch}} // expected-error@+1 {{present region argument 1 type mismatch}}
%r = sparse_tensor.unary %arg0 : f64 to f64 %r = sparse_tensor.unary %arg0 : f64 to f64
@ -290,3 +304,67 @@ func.func @invalid_unary_wrong_return_type(%arg0: f64) -> f64 {
absent={} absent={}
return %0 : f64 return %0 : f64
} }
// -----
// Negative test: the `present` region of sparse_tensor.unary is terminated by
// tensor.yield instead of sparse_tensor.yield, so verification must report
// the diagnostic expected below.
func.func @invalid_unary_wrong_yield(%arg0: f64) -> f64 {
// expected-error@+1 {{present region must end with sparse_tensor.yield}}
%0 = sparse_tensor.unary %arg0 : f64 to f64
present={
^bb0(%x: f64):
tensor.yield %x : f64
}
absent={}
return %0 : f64
}
// -----
// Negative test: the reduce region's block declares only one argument, so
// verification must fail with the arity diagnostic expected below.
func.func @invalid_reduce_num_args_mismatch(%arg0: f64, %arg1: f64) -> f64 {
%cf1 = arith.constant 1.0 : f64
// expected-error@+1 {{reduce region must have exactly 2 arguments}}
%r = sparse_tensor.reduce %arg0, %arg1, %cf1 : f64 {
^bb0(%x: f64):
sparse_tensor.yield %x : f64
}
return %r : f64
}
// -----
// Negative test: the op works on i64 but the region's block arguments are
// f64, so verification must fail with the argument-type diagnostic expected
// below. The yielded value is i64, so the yield type itself is not at fault.
func.func @invalid_reduce_block_arg_type_mismatch(%arg0: i64, %arg1: i64) -> i64 {
%ci1 = arith.constant 1 : i64
// expected-error@+1 {{reduce region argument 1 type mismatch}}
%r = sparse_tensor.reduce %arg0, %arg1, %ci1 : i64 {
^bb0(%x: f64, %y: f64):
%cst = arith.constant 2 : i64
sparse_tensor.yield %cst : i64
}
return %r : i64
}
// -----
// Negative test: the block arguments match the f64 op type, but the region
// yields an i64 constant, so verification must fail with the yield-type
// diagnostic expected below.
func.func @invalid_reduce_return_type_mismatch(%arg0: f64, %arg1: f64) -> f64 {
%cf1 = arith.constant 1.0 : f64
// expected-error@+1 {{reduce region yield type mismatch}}
%r = sparse_tensor.reduce %arg0, %arg1, %cf1 : f64 {
^bb0(%x: f64, %y: f64):
%cst = arith.constant 2 : i64
sparse_tensor.yield %cst : i64
}
return %r : f64
}
// -----
// Negative test: the reduce region is terminated by tensor.yield instead of
// sparse_tensor.yield. The yielded value is the correctly typed block argument
// so that the wrong terminator is the only defect in the region.
func.func @invalid_reduce_wrong_yield(%arg0: f64, %arg1: f64) -> f64 {
  %cf1 = arith.constant 1.0 : f64
  // expected-error@+1 {{reduce region must end with sparse_tensor.yield}}
  %r = sparse_tensor.reduce %arg0, %arg1, %cf1 : f64 {
    ^bb0(%x: f64, %y: f64):
      tensor.yield %x : f64
  }
  return %r : f64
}

View File

@ -268,3 +268,25 @@ func.func @sparse_unary(%arg0: f64) -> i64 {
absent={} absent={}
return %r : i64 return %r : i64
} }
// -----
#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
// CHECK-LABEL: func @sparse_reduce_2d_to_1d(
// CHECK-SAME: %[[A:.*]]: f64, %[[B:.*]]: f64) -> f64 {
// CHECK: %[[Z:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[C1:.*]] = sparse_tensor.reduce %[[A]], %[[B]], %[[Z]] : f64 {
// CHECK: ^bb0(%[[A1:.*]]: f64, %[[B1:.*]]: f64):
// CHECK: sparse_tensor.yield %[[A1]] : f64
// CHECK: }
// CHECK: return %[[C1]] : f64
// CHECK: }
// Round-trip input for the CHECK lines above: a sparse_tensor.reduce on scalar
// f64 operands with a zero identity, whose region yields its first block
// argument unchanged.
func.func @sparse_reduce_2d_to_1d(%arg0: f64, %arg1: f64) -> f64 {
%cf0 = arith.constant 0.0 : f64
%r = sparse_tensor.reduce %arg0, %arg1, %cf0 : f64 {
^bb0(%x: f64, %y: f64):
sparse_tensor.yield %x : f64
}
return %r : f64
}