[tosa][mlir] Refactor tosa.reshape lowering to linalg for dynamic cases.

Split the tosa.reshape lowering into three individual patterns: collapse,
expand, and a combination of both. Add simple dynamic shape support.
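
For example, a rank-preserving reshape with dynamic dimensions, such as

  %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]}
         : (tensor<?x2xf32>) -> tensor<2x?xf32>

now lowers (roughly, following the updated tests below) to a collapse into a
rank-1 tensor followed by an expand:

  %1 = linalg.tensor_collapse_shape %arg0 [[0, 1]]
         : tensor<?x2xf32> into tensor<?xf32>
  %2 = linalg.tensor_expand_shape %1 [[0, 1]]
         : tensor<?xf32> into tensor<2x?xf32>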

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D113936
natashaknk, 2021-11-15 15:10:36 -08:00 (committed by Rob Suderman)
commit 381677dfbf, parent 833cdb0a07
2 changed files with 248 additions and 85 deletions


@@ -946,6 +946,112 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   return success();
 }
 
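+// Finds a shape that both lhsShape and rhsShape can be collapsed to, grouping
+// as few dimensions as possible; e.g. lhsShape [2, 3, 4] and rhsShape [6, 4]
+// give the intermediate shape [6, 4]. For dynamic shapes it conservatively
+// falls back to the rank-1 shape {-1}. Returns false if no such shape exists.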
+static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
+                                  ArrayRef<int64_t> rhsShape,
+                                  SmallVector<int64_t> &intermediateShape,
+                                  bool isDynamic) {
+  if (isDynamic) {
+    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
+    intermediateShape = {-1};
+    return true;
+  }
+
+  if (lhsShape.empty() || rhsShape.empty()) {
+    intermediateShape = {};
+    return true;
+  }
+
+  unsigned currLhsDim = 0, currRhsDim = 0;
+  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
+    int64_t rhsSize = rhsShape[currRhsDim];
+    int64_t lhsSize = lhsShape[currLhsDim];
+    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
+           currRhsDim < rhsShape.size()) {
+      if (lhsSize < rhsSize) {
+        currLhsDim++;
+        lhsSize *= lhsShape[currLhsDim];
+      } else {
+        currRhsDim++;
+        rhsSize *= rhsShape[currRhsDim];
+      }
+    }
+    if (lhsSize == rhsSize) {
+      intermediateShape.push_back(lhsSize);
+    }
+    currRhsDim++;
+    currLhsDim++;
+  }
+
+  // If an iterator didn't reach the end, the remaining dimensions must all be
+  // 1; otherwise no intermediate shape was found.
+  while (currLhsDim < lhsShape.size()) {
+    if (lhsShape[currLhsDim++] != 1) {
+      return false;
+    }
+  }
+
+  while (currRhsDim < rhsShape.size()) {
+    if (rhsShape[currRhsDim++] != 1) {
+      return false;
+    }
+  }
+
+  return true;
+}
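+
+// Builds the reassociation groups that collapse srcShape into dstShape: each
+// entry of reassociationMap lists the source dims that merge into one
+// destination dim; e.g. collapsing [1, 2, 3, 5, 7, 11] into [6, 5, 77] yields
+// [[0, 1, 2], [3], [4, 5]]. For dynamic shapes, all source dims form one group.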
+static bool createReassociationMapsForCollapse(
+    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
+    ArrayRef<int64_t> dstShape,
+    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+
+  // If the shape is dynamic, create a map for collapsing into one dimension.
+  if (isDynamic) {
+    SmallVector<AffineExpr, 2> exprs;
+    for (int i = 0, s = srcShape.size(); i < s; ++i)
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    reassociationMap = {exprs};
+    return true;
+  }
+
+  if (dstShape.empty()) {
+    reassociationMap = {};
+    return true;
+  }
+
+  reassociationMap.resize(dstShape.size());
+  unsigned currSrcDim = 0, currDstDim = 0;
+  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+    int64_t dstSize = dstShape[currDstDim];
+    int64_t srcSize = srcShape[currSrcDim];
+    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      srcSize *= srcShape[currSrcDim];
+    }
+    if (srcSize == dstSize) {
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      // If the next dim in dstShape is not 1, treat subsequent dims in
+      // srcShape which are 1 as part of the current group.
+      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
+        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+          reassociationMap[currDstDim].push_back(
+              rewriter.getAffineDimExpr(currSrcDim++));
+        }
+      }
+    }
+    currDstDim++;
+  }
+
+  // If either iterator didn't reach the end, we have leftover dimensions,
+  // which implies a mismatch in shape.
+  if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) {
+    return false;
+  }
+  return true;
+}
 
 namespace {
 
 template <typename SrcOp>
@@ -1534,7 +1640,7 @@ public:
   }
 };
 
-class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
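+// Lowers tosa.reshape to linalg.tensor_collapse_shape when the result shape
+// only merges consecutive operand dimensions. Dynamic operands are supported
+// only when collapsing to a rank-1 result.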
+class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
@@ -1543,103 +1649,116 @@ public:
                   ConversionPatternRewriter &rewriter) const final {
     ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (isDynamic && resultTy.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          reshape, "Cannot collapse dynamic dims to more than one dimension");
+    }
 
     if (operandTy == resultTy) {
       rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
 
-    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
-      return failure();
-
-    // Compute the reassociation maps for the linalg operation.
-    ArrayRef<int64_t> expandedShape =
-        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
-                                                  : resultTy.getShape());
-    ArrayRef<int64_t> collapsedShape =
-        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
-                                                  : operandTy.getShape());
-    unsigned currSrcDim = 0, currDstDim = 0;
-    SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());
-
-    // First scan all dimensions in the source shapes to see whether we have a
-    // perfect case where consecutive dimensions in source are collapsed. For
-    // such case we can just generate one single linalg.reshape.
-    bool isCollapsingSource = true;
-    while (currSrcDim < expandedShape.size() &&
-           currDstDim < collapsedShape.size()) {
-      int64_t dstSize = collapsedShape[currDstDim];
-      int64_t srcSize = expandedShape[currSrcDim];
-      while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        srcSize *= expandedShape[currSrcDim];
-      }
-      if (srcSize == dstSize) {
-        reassociationMap[currDstDim].push_back(
-            rewriter.getAffineDimExpr(currSrcDim++));
-        // If the next dim in collapsedShape is not 1, treat subsequent dims in
-        // expandedShape which are 1 to be collapsed.
-        if (currDstDim == collapsedShape.size() - 1 ||
-            collapsedShape[currDstDim + 1] != 1) {
-          while (currSrcDim < expandedShape.size() &&
-                 expandedShape[currSrcDim] == 1) {
-            reassociationMap[currDstDim].push_back(
-                rewriter.getAffineDimExpr(currSrcDim++));
-          }
-        }
-      } else {
-        isCollapsingSource = false;
-        break;
-      }
-      currDstDim++;
-    }
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
+                                            resultTy.getShape(),
+                                            reassociationMap, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape,
+          "tosa.reshape Attempting to collapse into an incompatible shape");
+    }
+
-
-    // Check if any remaining dimensions exist. If either is rank-0 we only
-    // require the directly lowering.
-    if (currSrcDim != expandedShape.size() ||
-        currDstDim != collapsedShape.size())
-      isCollapsingSource = collapsedShape.empty() || expandedShape.empty();
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot collapse into given shape");
+    }
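+
+    // Note: findIntermediateShape is used here only to validate that the
+    // operand can in fact collapse to the result shape; the computed
+    // intermediate shape itself is not used by this pattern.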
-
-    // Otherwise, we need to first reduce all source dimensions into one and
-    // then expand to the destination dimensions.
-    if (!isCollapsingSource) {
-      auto getIdentityExprs = [&rewriter](int n) {
-        SmallVector<AffineExpr, 4> exprs;
-        for (int i = 0; i < n; ++i)
-          exprs.push_back(rewriter.getAffineDimExpr(i));
-        return exprs;
-      };
-      Location loc = reshape.getLoc();
-      int64_t totalElems =
-          std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
-                          std::multiplies<int64_t>());
-      auto elemTy = operandTy.getElementType();
-      SmallVector<ReassociationExprs, 4> collapsingMap = {
-          // Use operandTy here because we need to collapse all operands
-          // dimensions.
-          getIdentityExprs(operandTy.getShape().size())};
-      SmallVector<ReassociationExprs, 4> expandingMap = {
-          // Use resultTy here because we need to expand to all result
-          // dimensions.
-          getIdentityExprs(resultTy.getShape().size())};
+
+    rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+
+    return success();
+  }
+};
-      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
-      Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
-          loc, collapsedTy, adaptor.getOperands()[0], collapsingMap);
-      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
-          reshape, resultTy, collapsedOp, expandingMap);
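+
+// Lowers tosa.reshape to linalg.tensor_expand_shape when the result shape only
+// splits apart consecutive operand dimensions. Dynamic shapes are supported
+// only when expanding from a rank-1 operand.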
+class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (operandTy == resultTy) {
+      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
       return success();
     }
-    if (resultTy.getRank() <
-        adaptor.getOperands()[0].getType().cast<ShapedType>().getRank())
-      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
-          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
-    else
-      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
-          reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+
+    if (isDynamic && operandTy.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          reshape, "Cannot expand dynamic dims from more than one dimension");
+    }
+
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
+                                            operandTy.getShape(),
+                                            reassociationMap, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape,
+          "tosa.reshape Attempting to expand into an incompatible shape");
+    }
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
+                               intermediateShape, isDynamic) ||
+        intermediateShape != operandTy.getShape()) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot expand into given shape");
+    }
+
+    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
+        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
 
     return success();
   }
 };
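+
+// Lowers reshapes that both collapse and expand dimensions by splitting the op
+// into a tosa.reshape down to the intermediate shape followed by a
+// tosa.reshape back up to the result, which the two patterns above then lower.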
+class ReshapeConverterCollapseExpand
+    : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+    bool isDynamic = !operandTy.hasStaticShape();
+
+    if (operandTy == resultTy) {
+      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
+      return success();
+    }
+
+    SmallVector<int64_t> intermediateShape;
+    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
+                               intermediateShape, isDynamic)) {
+      return rewriter.notifyMatchFailure(
+          reshape, "tosa.reshape Cannot identify an intermediate shape between "
+                   "the given two shapes");
+    }
+
+    Value collapse = rewriter.create<tosa::ReshapeOp>(
+        reshape.getLoc(),
+        RankedTensorType::get(intermediateShape,
+                              reshape.getType().getElementType()),
+        adaptor.input1());
+    Value expand =
+        rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
+    rewriter.replaceOp(reshape, expand);
+
+    return success();
+  }
+};
@@ -3072,7 +3191,9 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       TransposeConvConverter,
       GatherConverter,
       PadConverter,
-      ReshapeConverter,
+      ReshapeConverterCollapse,
+      ReshapeConverterExpand,
+      ReshapeConverterCollapseExpand,
       RescaleConverter,
       ResizeConverter,
       ReverseConverter,


@@ -541,6 +541,16 @@ func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
 
 // -----
 
+// CHECK-LABEL: @test_reshape_downrank_dyn
+func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<2x?xf32>) -> tensor<?xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_reshape_uprank
 func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
   // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
@@ -551,6 +561,16 @@ func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
 
 // -----
 
+// CHECK-LABEL: @test_reshape_uprank_dyn
+func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<?xf32>) -> tensor<2x?xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<2x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_reshape_samerank
 func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
   // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
@@ -563,6 +583,18 @@ func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
 
 // -----
 
+// CHECK-LABEL: @test_reshape_samerank_dyn
+func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
+  // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
+  // CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<?x2xf32>) -> tensor<2x?xf32>
+  // CHECK-NEXT: return %[[RESHAPE2]]
+  return %0 : tensor<2x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_reshape_downrank_6D
 func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
   // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
@@ -572,6 +604,16 @@ func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
 
 // -----
 
+// CHECK-LABEL: @test_reshape_downrank_6D_dyn
+func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
+  // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]]
+  // CHECK: linalg.tensor_expand_shape %0 {{\[}}[0, 1, 2]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [-1, 5, 77]} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
+  return %0 : tensor<?x5x77xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_identity
 func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) {
   %0 = "tosa.identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>