Adds affine.min operation which returns the minimum value from a multi-result affine map. This operation is useful for things like computing the dynamic value of affine loop bounds, and is trivial to constant fold.

PiperOrigin-RevId: 279959714
This commit is contained in:
Andy Davis 2019-11-12 07:08:23 -08:00 committed by A. Unique TensorFlower
parent f51a155337
commit 82d2c43eca
6 changed files with 193 additions and 1 deletions

View File

@ -560,6 +560,29 @@ Example:
```
#### 'affine.min' operation
Syntax:
``` {.ebnf}
operation ::= ssa-id `=` `affine.min` affine-map dim-and-symbol-use-list
```
The `affine.min` operation applies an
[affine mapping](#affine-expressions) to a list of SSA values, and returns the
minimum value of all result expressions. The number of dimension and symbol
arguments to affine.min must be equal to the respective number of dimensional
and symbolic inputs to the affine mapping; the `affine.min` operation always
returns one value. The input operands and result must all have 'index' type.
Example:
```mlir {.mlir}
%0 = affine.min (d0)[s0] -> (1000, d0 + 512, s0) (%arg0)[%arg1]
```
#### `affine.terminator` operation
Syntax:

View File

@ -248,6 +248,24 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
let hasCanonicalizer = 1;
}
def AffineMinOp : Affine_Op<"min"> {
let summary = "min operation";
let description = [{
The "min" operation computes the minimum value result from a multi-result
affine map.
Example:
%0 = affine.min (d0) -> (1000, d0 + 512) (%i0) : index
}];
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
let results = (outs Index);
let extraClassDeclaration = [{
static StringRef getMapAttrName() { return "map"; }
}];
let hasFolder = 1;
}
def AffineTerminatorOp :
Affine_Op<"terminator", [Terminator]> {
let summary = "affine terminator operation";

View File

@ -1937,5 +1937,80 @@ void AffineStoreOp::getCanonicalizationPatterns(
results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
}
//===----------------------------------------------------------------------===//
// AffineMinOp
//===----------------------------------------------------------------------===//
//
// %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
//
static ParseResult parseAffineMinOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
auto indexType = builder.getIndexType();
SmallVector<OpAsmParser::OperandType, 8> dim_infos;
SmallVector<OpAsmParser::OperandType, 8> sym_infos;
AffineMapAttr mapAttr;
return failure(
parser.parseAttribute(mapAttr, AffineMinOp::getMapAttrName(),
result.attributes) ||
parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
parser.parseOperandList(sym_infos,
OpAsmParser::Delimiter::OptionalSquare) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.resolveOperands(dim_infos, indexType, result.operands) ||
parser.resolveOperands(sym_infos, indexType, result.operands) ||
parser.addTypeToList(indexType, result.types));
}
static void print(OpAsmPrinter &p, AffineMinOp op) {
p << op.getOperationName() << ' '
<< op.getAttr(AffineMinOp::getMapAttrName());
auto begin = op.operand_begin();
auto end = op.operand_end();
unsigned numDims = op.map().getNumDims();
p << '(';
p.printOperands(begin, begin + numDims);
p << ')';
if (begin + numDims != end) {
p << '[';
p.printOperands(begin + numDims, end);
p << ']';
}
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
}
static LogicalResult verify(AffineMinOp op) {
// Verify that operand count matches affine map dimension and symbol count.
if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
return op.emitOpError(
"operand count and affine map dimension and symbol count must match");
return success();
}
OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
// Fold the affine map.
// TODO(andydavis, ntv) Fold more cases: partial static information,
// min(some_affine, some_affine + constant, ...).
SmallVector<Attribute, 2> results;
if (failed(map().constantFold(operands, results)))
return {};
// Compute and return min of folded map results.
int64_t min = std::numeric_limits<int64_t>::max();
int minIndex = -1;
for (unsigned i = 0, e = results.size(); i < e; ++i) {
auto intAttr = results[i].cast<IntegerAttr>();
if (intAttr.getInt() < min) {
min = intAttr.getInt();
minIndex = i;
}
}
if (minIndex < 0)
return {};
return results[minIndex];
}
#define GET_OP_CLASSES
#include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"

View File

@ -500,3 +500,29 @@ func @compose_into_affine_load_store(%A : memref<1024xf32>, %u : index) {
}
return
}
// -----
func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
%c511 = constant 511 : index
%c1 = constant 0 : index
%0 = affine.min (d0)[s0] -> (1000, d0 + 512, s0 + 1) (%c1)[%c511]
"op0"(%0) : (index) -> ()
// CHECK: %[[CST:.*]] = constant 512 : index
// CHECK-NEXT: "op0"(%[[CST]]) : (index) -> ()
// CHECK-NEXT: return
return
}
// -----
func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
%c3 = constant 3 : index
%c20 = constant 20 : index
%0 = affine.min (d0)[s0] -> (1000, d0 floordiv 4, (s0 mod 5) + 1) (%c20)[%c3]
"op0"(%0) : (index) -> ()
// CHECK: %[[CST:.*]] = constant 4 : index
// CHECK-NEXT: "op0"(%[[CST]]) : (index) -> ()
// CHECK-NEXT: return
return
}

View File

@ -151,3 +151,33 @@ func @affine_store_missing_l_square(%C: memref<4096x4096xf32>) {
affine.store %9, %C : memref<4096x4096xf32>
return
}
// -----
// CHECK-LABEL: @affine_min
func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
// expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
%0 = affine.min (d0) -> (d0) (%arg0, %arg1)
return
}
// -----
// CHECK-LABEL: @affine_min
func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
// expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
%0 = affine.min ()[s0] -> (s0) (%arg0, %arg1)
return
}
// -----
// CHECK-LABEL: @affine_min
func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
// expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
%0 = affine.min (d0) -> (d0) ()
return
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s | FileCheck %s
// RUN: mlir-opt -split-input-file %s | FileCheck %s
// RUN: mlir-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
// Check that the attributes for the affine operations are round-tripped.
@ -58,3 +58,23 @@ func @affine_terminator() {
}
return
}
// -----
// CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0)[s0] -> (1000, d0 + 512, s0)
// CHECK-DAG: #[[MAP1:map[0-9]+]] = (d0, d1)[s0] -> (d0 - d1, s0 + 512)
// CHECK-DAG: #[[MAP2:map[0-9]+]] = ()[s0, s1] -> (s0 - s1, 11)
// CHECK-DAG: #[[MAP3:map[0-9]+]] = () -> (77, 78, 79)
// CHECK-LABEL: @affine_min
func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK: affine.min #[[MAP0]](%arg0)[%arg1]
%0 = affine.min (d0)[s0] -> (1000, d0 + 512, s0) (%arg0)[%arg1]
// CHECK: affine.min #[[MAP1]](%arg0, %arg1)[%arg2]
%1 = affine.min (d0, d1)[s0] -> (d0 - d1, s0 + 512) (%arg0, %arg1)[%arg2]
// CHECK: affine.min #[[MAP2]]()[%arg1, %arg2]
%2 = affine.min ()[s0, s1] -> (s0 - s1, 11) ()[%arg1, %arg2]
// CHECK: affine.min #[[MAP3]]()
%3 = affine.min ()[] -> (77, 78, 79) ()[]
return
}