[mlir] Support partial folding of affine.min/max

Originally, these operations were folded only if all expressions in their
affine maps could be folded to a constant expression that can be then subject
to numeric min/max computation. This introduces a more advanced version that
partially folds the affine map by lifting individual constant expressions in it
even if some of the expressions remain variable. The folding can update the
operation in place to use a simpler map. Note that this is not as powerful as
canonicalization, in particular this does not remove dimensions or symbols that
became useless. This allows for better composition of Linalg tiling and
promotion transformations, where the latter can handle some canonical forms of
affine.min that the folding can now produce.

Differential Revision: https://reviews.llvm.org/D79502
This commit is contained in:
Alex Zinenko 2020-05-07 12:29:12 +02:00
parent 717bef6623
commit a87db48e6f
6 changed files with 104 additions and 69 deletions

View File

@ -144,6 +144,16 @@ public:
LogicalResult constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute> &results) const;
/// Propagates the constant operands into this affine map. Operands are
/// allowed to be null, at which point they are treated as non-constant. This
/// does not change the number of symbols and dimensions. Returns a new map,
/// which may be equal to the old map if no folding happened. If `results` is
/// provided and if all expressions in the map were folded to constants,
/// `results` will contain the values of these constants; otherwise `results`
/// is left empty.
AffineMap
partialConstantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<int64_t> *results = nullptr) const;
/// Returns the AffineMap resulting from composing `this` with `map`.
/// The resulting AffineMap has as many AffineDimExpr as `map` and as many
/// AffineSymbolExpr as the concatenation of `this` and `map` (in which case

View File

@ -2089,6 +2089,38 @@ static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
parser.addTypeToList(indexType, result.types));
}
/// Fold an affine min or max operation with the given operands. The operand
/// list may contain nulls, which are interpreted as the operand not being a
/// constant. Returns a constant attribute when the whole map folds, the op's
/// result when only the map could be simplified in place, and null otherwise.
template <typename T>
OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
                "expected affine min or max op");

  // Fold the affine map.
  // TODO(andydavis, ntv) Fold more cases:
  // min(some_affine, some_affine + constant, ...), etc.
  SmallVector<int64_t, 2> constants;
  AffineMap simplified = op.map().partialConstantFold(operands, &constants);

  // Not every expression folded to a constant: the best we can do is install
  // the (possibly) simpler map on the op in place.
  if (constants.empty()) {
    // If the map did not change, report that no folding happened.
    if (simplified == op.map())
      return {};
    op.setAttr("map", AffineMapAttr::get(simplified));
    return op.getResult();
  }

  // Every expression folded: select the extremal constant and fold the whole
  // op away.
  bool takeMin = std::is_same<T, AffineMinOp>::value;
  auto extremum = takeMin
                      ? std::min_element(constants.begin(), constants.end())
                      : std::max_element(constants.begin(), constants.end());
  if (extremum == constants.end())
    return {};
  return IntegerAttr::get(IndexType::get(op.getContext()), *extremum);
}
//===----------------------------------------------------------------------===//
// AffineMinOp
//===----------------------------------------------------------------------===//
@ -2097,26 +2129,7 @@ static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
//
/// Fold affine.min to a constant (or simplify its map in place) by
/// delegating to the shared min/max folding helper. NOTE(review): the
/// previous hand-rolled constant-fold loop was left in front of the
/// delegating return as unreachable residue, which also made the helper call
/// dead; it is removed here.
OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp(*this, operands);
}
void AffineMinOp::getCanonicalizationPatterns(
@ -2132,26 +2145,7 @@ void AffineMinOp::getCanonicalizationPatterns(
//
/// Fold affine.max to a constant (or simplify its map in place) by
/// delegating to the shared min/max folding helper. NOTE(review): the
/// previous hand-rolled constant-fold loop was left in front of the
/// delegating return as unreachable residue, which also made the helper call
/// dead; it is removed here, mirroring AffineMinOp::fold.
OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp(*this, operands);
}
void AffineMaxOp::getCanonicalizationPatterns(

View File

@ -234,22 +234,51 @@ AffineExpr AffineMap::getResult(unsigned idx) const {
/// Constant-folds every result expression of this map. Succeeds (and fills
/// `results` with IntegerAttrs of index type) only when all expressions fold
/// to constants; otherwise leaves `results` untouched and fails.
LogicalResult
AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
                        SmallVectorImpl<Attribute> &results) const {
  // Delegate to partial folding and accept only a fully-folded map.
  SmallVector<int64_t, 2> folded;
  partialConstantFold(operandConstants, &folded);
  if (folded.empty())
    return failure();

  // All expressions folded; wrap the integers into index-typed attributes.
  for (int64_t value : folded)
    results.push_back(IntegerAttr::get(IndexType::get(getContext()), value));
  return success();
}
/// Propagates constant operands into the map, folding each result expression
/// that becomes constant and keeping the others intact. If `results` is
/// non-null it receives the folded integers only when every expression
/// folded; otherwise it is cleared. NOTE(review): stale lines from the
/// replaced implementation (`return failure()`, `results.push_back(folded)`,
/// `results.size()` on a pointer, `return success()`) were fused into this
/// body and do not compile; they are removed here.
AffineMap
AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
                               SmallVectorImpl<int64_t> *results) const {
  assert(getNumInputs() == operandConstants.size());

  // Fold each of the result expressions.
  AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(getNumResults());

  for (auto expr : getResults()) {
    auto folded = exprFolder.constantFold(expr);
    // If did not fold to a constant, keep the original expression, and clear
    // the integer results vector (it only reports fully-folded maps).
    if (folded) {
      exprs.push_back(
          getAffineConstantExpr(folded.getInt(), folded.getContext()));
      if (results)
        results->push_back(folded.getInt());
    } else {
      exprs.push_back(expr);
      if (results) {
        results->clear();
        results = nullptr;
      }
    }
  }

  // Dimension and symbol counts are preserved; only result exprs changed.
  return get(getNumDims(), getNumSymbols(), exprs, getContext());
}
/// Walk all of the AffineExpr's in this mapping. Each node in an expression

View File

@ -13,10 +13,12 @@
// TILE-002-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// TILE-234-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// TILE-2-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
// TILE-02-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
// TILE-002-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
// TILE-234-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
// TILE-2-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)>
// TILE-02-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)>
// TILE-002-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)>
// TILE-234-DAG: #[[bound_map_2:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)>
// TILE-234-DAG: #[[bound_map_3:.*]] = affine_map<(d0, d1, d2) -> (3, d1 - d2)>
// TILE-234-DAG: #[[bound_map_4:.*]] = affine_map<(d0, d1, d2) -> (4, d1 - d2)>
// TILE-2-DAG: #[[strided1D_dynamic:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
// TILE-02-DAG: #[[strided1D_dynamic:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
@ -97,19 +99,19 @@ func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
// TILE-234: loop.for %[[J:.*]] = %{{.*}}{{.*}} to %[[ubN]] step %{{.*}} {
// TILE-234: loop.for %[[K:.*]] = %{{.*}}{{.*}} to %[[ubK]] step %{{.*}} {
// TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[localK:.*]] = dim %{{.*}}, 1
// TILE-234: %[[szK:.*]] = affine.min #[[bound_map]](%[[C4]], %[[localK]], %[[K]])
// TILE-234: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[localK]], %[[K]])
// TILE-234: %[[sAik:.*]] = subview %{{.*}}[%[[I]], %[[K]]] [%[[szM]], %[[szK]]] [%[[C1]], %[[C1]]] : memref<?x?xf32, #[[strided2D]]> to memref<?x?xf32, #[[strided2D_dynamic]]>
// TILE-234: %[[localK:.*]] = dim %{{.*}}, 0
// TILE-234: %[[szK:.*]] = affine.min #[[bound_map]](%[[C4]], %[[localK]], %[[K]])
// TILE-234: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[localK]], %[[K]])
// TILE-234: %[[localN:.*]] = dim %{{.*}}, 1
// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]])
// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]])
// TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%[[szK]], %[[szN]]] [%[[C1]], %[[C1]]] : memref<?x?xf32, #[[strided2D]]> to memref<?x?xf32, #[[strided2D_dynamic]]>
// TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[localN:.*]] = dim %{{.*}}, 1
// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]])
// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]])
// TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [%[[C1]], %[[C1]]] : memref<?x?xf32, #[[strided2D]]> to memref<?x?xf32, #[[strided2D_dynamic]]>
//
// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
@ -230,15 +232,15 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
// TILE-234: loop.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} {
// TILE-234: loop.for %[[J:.*]] = %{{.*}}{{.*}} to %[[K]] step %{{.*}} {
// TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[localN:.*]] = dim %{{.*}}, 1
// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]])
// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]])
// TILE-234: %[[sAij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [%[[C1]], %[[C1]]] : memref<?x?xf32, #[[strided2D]]> to memref<?x?xf32, #[[strided2D_dynamic]]>
// TILE-234: %[[localN:.*]] = dim %{{.*}}, 0
// TILE-234: %[[szN:.*]] = affine.min #[[bound_map]](%[[C3]], %[[localN]], %[[J]])
// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]])
// TILE-234: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [%[[C1]]] : memref<?xf32, #[[strided1D]]> to memref<?xf32, #[[strided1D_dynamic]]>
// TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref<?xf32, #[[strided1D]]> to memref<?xf32, #[[strided1D_dynamic]]>
//
// TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?xf32, #[[strided1D_dynamic]]>, memref<?xf32, #[[strided1D_dynamic]]>
@ -274,10 +276,10 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
// TILE-234: %[[ubK:.*]] = dim %{{.*}}, 0 : memref<?xf32, #[[strided1D]]>
// TILE-234: loop.for %[[I:.*]] = %{{.*}} to %[[ubK]] step %{{.*}} {
// TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[sAi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref<?xf32, #[[strided1D]]> to memref<?xf32, #[[strided1D_dynamic]]>
// TILE-234: %[[localM:.*]] = dim %{{.*}}, 0
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]])
// TILE-234: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref<?xf32, #[[strided1D]]> to memref<?xf32, #[[strided1D_dynamic]]>
// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref<?xf32, #[[strided1D_dynamic]]>, memref<?xf32, #[[strided1D_dynamic]]>, memref<f32>

View File

@ -4,7 +4,7 @@
// TILE-23004-DAG: #[[S0x10p90:.*]] = affine_map<()[s0] -> (s0 * 10 + 90)>
// TILE-23004-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
// TILE-23004-DAG: #[[strided4D_dynamic:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
// TILE-23004-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
// TILE-23004-DAG: #[[bound_map_4:.*]] = affine_map<(d0, d1, d2) -> (4, d1 - d2)>
func @conv(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg1: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg2: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>) {
linalg.conv(%arg0, %arg1, %arg2) {dilations = [10, 20], strides = [30, 40]} : memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>
@ -27,7 +27,7 @@ func @conv(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg1:
// TILE-23004: %[[Z0:.*]] = dim %{{.*}}, 0 : memref<?x?x?x?xf32, #[[strided4D]]>
// TILE-23004: %[[Z1:.*]] = dim %{{.*}}, 1 : memref<?x?x?x?xf32, #[[strided4D]]>
// TILE-23004: %[[Z2:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32, #[[strided4D]]>
// TILE-23004: %[[szK:.*]] = affine.min #[[bound_map]](%[[C4]], %[[Z2]], %[[ivK]])
// TILE-23004: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[Z2]], %[[ivK]])
// TILE-23004: %[[K:.*]] = dim %{{.*}}, 3 : memref<?x?x?x?xf32, #[[strided4D]]>
// TILE-23004: %[[FilterView:.*]] = subview %{{.*}}[%[[C0]], %[[C0]], %[[ivK]], %[[C0]]] [%[[Z0]], %[[Z1]], %[[szK]], %[[K]]] [%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<?x?x?x?xf32, #[[strided4D]]> to memref<?x?x?x?xf32, #[[strided4D_dynamic]]>
//
@ -35,7 +35,7 @@ func @conv(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg1:
// T__ILE-23004: %[[I1pStep:.*]] = affine.apply #[[S0x10p90]]()[%[[I1]]]
// TILE-23004: %[[SZ2:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32, #[[strided4D]]>
// TILE-23004: %[[dim3:.*]] = dim %{{.*}}, 3
// TILE-23004: %[[sz3:.*]] = affine.min #[[bound_map]](%[[C4]], %[[dim3]], %[[ivK]]
// TILE-23004: %[[sz3:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[dim3]], %[[ivK]]
// TILE-23004: %[[InputView:.*]] = subview %{{.*}}[%[[ivI]], %[[J1]], %[[C0]], %[[ivK]]] [%{{.*}}, %{{.*}}, %[[SZ2]], %[[sz3]]] [%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<?x?x?x?xf32, #[[strided4D]]> to memref<?x?x?x?xf32, #[[strided4D_dynamic]]>
//
// TILE-23004: %[[X0:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32, #[[strided4D]]>

View File

@ -3,7 +3,7 @@
// TILE-23004-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
// TILE-20000-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
// TILE-20000-DAG: #[[minmap:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
// TILE-20000-DAG: #[[minmap:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)>
// TILE-20000-DAG: #[[subviewstride:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
func @conv_padding(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg1: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg2: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>) {