Moved getStaticLoopRanges and getStaticShape methods to LinalgInterfaces.td to add static shape verification

It is to use the methods in LinalgInterfaces.cpp for additional static shape verification to match the shaped operands and loop on linalgOps. If I used the existing methods, I would face circular dependency linking issue. Now we can use them as methods of LinalgOp.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D98163
This commit is contained in:
Inho Seo 2021-03-10 03:56:14 -08:00 committed by Hanhan Wang
parent a94ac467c2
commit 2ce4caf414
4 changed files with 39 additions and 29 deletions

View File

@ -1101,6 +1101,44 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
getNumInputs() + resultIdx, dim);
}]
>,
InterfaceMethod<
/*desc=*/[{
Like `getShape`, but only returns statically-known information, without
generating any new IR. For each shape dimension, returns >=0 if that
dimension is statically known, or ShapeType::kDynamicSize otherwise.
}],
/*retTy=*/"SmallVector<int64_t, 8>",
/*methodName=*/"getStaticShape",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<int64_t, 8> res;
for (Value v : getShapedOperands()) {
auto shape = v.getType().cast<ShapedType>().getShape();
res.append(shape.begin(), shape.end());
}
return res;
}]
>,
InterfaceMethod<
/*desc=*/[{
Returns the statically-known loop ranges. Composes
`getShapesToLoopsMap()` with the result of `getStaticShape`.
Returns None if `getShapesToLoopsMap()` fails. Returns
ShapeType::kDynamicSize for non-statically-known loop ranges.
}],
/*retTy=*/"Optional<SmallVector<int64_t, 4>>",
/*methodName=*/"getStaticLoopRanges",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<int64_t, 8> viewSizes = getStaticShape();
AffineMap invertedMap = getShapesToLoopsMap();
if (!invertedMap)
return {};
return invertedMap.compose(viewSizes);
}]
>,
//===------------------------------------------------------------------===//
// Other static interface methods.

View File

@ -118,17 +118,6 @@ Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
OpOperand &consumerOpOperand);
/// Like `getShape`, but only returns statically-known information, without
/// generating any new IR. For each shape dimension, returns >=0 if that
/// dimension is statically known, or -1 otherwise.
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp);
/// Returns the statically-known loop ranges of the `linalgOp`. Composes
/// `linalgOp.getShapesToLoopsMap()` with the result of `getStaticShape`.
/// Returns None if `linalgOp.getShapesToLoopsMap()` fails. Returns -1
/// for non-statically-known loop ranges.
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp);
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector

View File

@ -499,7 +499,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
Optional<SmallVector<int64_t, 4>> originalLoopRange =
getStaticLoopRanges(linalgOp);
linalgOp.getStaticLoopRanges();
if (!originalLoopRange)
return linalgOp.emitError("unable to find loop range for operation");

View File

@ -98,23 +98,6 @@ static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
namespace mlir {
namespace linalg {
SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
SmallVector<int64_t, 8> res;
for (Value v : linalgOp.getShapedOperands()) {
auto shape = v.getType().cast<ShapedType>().getShape();
res.append(shape.begin(), shape.end());
}
return res;
}
Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
AffineMap invertedMap = linalgOp.getShapesToLoopsMap();
if (!invertedMap)
return {};
return invertedMap.compose(viewSizes);
}
/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
/// is a constant then return a new value set to the smallest such constant.
/// Otherwise returngetSmallestBoundingIndex nullptr.