forked from OSchip/llvm-project
[tosa][mlir] Refactor tosa.reshape lowering to linalg for dynamic cases.
Split tosa.reshape into three individual lowerings: collapse, expand and a combination of both. Add simple dynamic shape support. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D113936
This commit is contained in:
parent
833cdb0a07
commit
381677dfbf
|
@ -946,6 +946,112 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
|
|||
return success();
|
||||
}
|
||||
|
||||
static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
|
||||
ArrayRef<int64_t> rhsShape,
|
||||
SmallVector<int64_t> &intermediateShape,
|
||||
bool isDynamic) {
|
||||
if (isDynamic) {
|
||||
// TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
|
||||
intermediateShape = {-1};
|
||||
return true;
|
||||
}
|
||||
|
||||
if (lhsShape.empty() || rhsShape.empty()) {
|
||||
intermediateShape = {};
|
||||
return true;
|
||||
}
|
||||
|
||||
unsigned currLhsDim = 0, currRhsDim = 0;
|
||||
while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
|
||||
int64_t rhsSize = rhsShape[currRhsDim];
|
||||
int64_t lhsSize = lhsShape[currLhsDim];
|
||||
while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
|
||||
currRhsDim < rhsShape.size()) {
|
||||
if (lhsSize < rhsSize) {
|
||||
currLhsDim++;
|
||||
lhsSize *= lhsShape[currLhsDim];
|
||||
} else {
|
||||
currRhsDim++;
|
||||
rhsSize *= rhsShape[currRhsDim];
|
||||
}
|
||||
}
|
||||
if (lhsSize == rhsSize) {
|
||||
intermediateShape.push_back(lhsSize);
|
||||
}
|
||||
currRhsDim++;
|
||||
currLhsDim++;
|
||||
}
|
||||
|
||||
// If the iterators didn't reach the end and their leftover dimensions are not
|
||||
// equal to 1 an intermediate shape was not found.
|
||||
while (currLhsDim < lhsShape.size()) {
|
||||
if (lhsShape[currLhsDim++] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
while (currRhsDim < rhsShape.size()) {
|
||||
if (rhsShape[currRhsDim++] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool createReassociationMapsForCollapse(
|
||||
PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
|
||||
ArrayRef<int64_t> dstShape,
|
||||
SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
|
||||
|
||||
// If the shape is dynamic, create a map for collapsing into one dimension.
|
||||
if (isDynamic) {
|
||||
SmallVector<AffineExpr, 2> exprs;
|
||||
for (int i = 0, s = srcShape.size(); i < s; ++i)
|
||||
exprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
reassociationMap = {exprs};
|
||||
return true;
|
||||
}
|
||||
|
||||
if (dstShape.empty()) {
|
||||
reassociationMap = {};
|
||||
return true;
|
||||
}
|
||||
|
||||
reassociationMap.resize(dstShape.size());
|
||||
unsigned currSrcDim = 0, currDstDim = 0;
|
||||
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
|
||||
int64_t dstSize = dstShape[currDstDim];
|
||||
int64_t srcSize = srcShape[currSrcDim];
|
||||
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
srcSize *= srcShape[currSrcDim];
|
||||
}
|
||||
if (srcSize == dstSize) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
// If the next dim in collapsedShape is not 1, treat subsequent dims in
|
||||
// expandedShape which are 1 to be collapsed.
|
||||
if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
|
||||
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
}
|
||||
}
|
||||
}
|
||||
currDstDim++;
|
||||
}
|
||||
|
||||
// If both iterators didn't reach the end, we have leftover dimentions which
|
||||
// implies that we have a mismatch in shape.
|
||||
if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename SrcOp>
|
||||
|
@ -1534,7 +1640,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
|
||||
class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
|
||||
public:
|
||||
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
||||
|
||||
|
@ -1543,103 +1649,116 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const final {
|
||||
ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
|
||||
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
|
||||
bool isDynamic = !operandTy.hasStaticShape();
|
||||
|
||||
if (isDynamic && resultTy.getRank() != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape, "Cannot collapse dynamic dims to more than one dimension");
|
||||
}
|
||||
|
||||
if (operandTy == resultTy) {
|
||||
rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
|
||||
return success();
|
||||
}
|
||||
|
||||
if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// Compute the reassociation maps for the linalg operation.
|
||||
ArrayRef<int64_t> expandedShape =
|
||||
(operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
|
||||
: resultTy.getShape());
|
||||
ArrayRef<int64_t> collapsedShape =
|
||||
(operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
|
||||
: operandTy.getShape());
|
||||
unsigned currSrcDim = 0, currDstDim = 0;
|
||||
SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());
|
||||
|
||||
// First scan all dimensions in the source shapes to see whether we have a
|
||||
// perfect case where consecutive dimensions in source are collapsed. For
|
||||
// such case we can just generate one single linalg.reshape.
|
||||
bool isCollapsingSource = true;
|
||||
while (currSrcDim < expandedShape.size() &&
|
||||
currDstDim < collapsedShape.size()) {
|
||||
int64_t dstSize = collapsedShape[currDstDim];
|
||||
int64_t srcSize = expandedShape[currSrcDim];
|
||||
while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
srcSize *= expandedShape[currSrcDim];
|
||||
}
|
||||
if (srcSize == dstSize) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
// If the next dim in collapsedShape is not 1, treat subsequent dims in
|
||||
// expandedShape which are 1 to be collapsed.
|
||||
if (currDstDim == collapsedShape.size() - 1 ||
|
||||
collapsedShape[currDstDim + 1] != 1) {
|
||||
while (currSrcDim < expandedShape.size() &&
|
||||
expandedShape[currSrcDim] == 1) {
|
||||
reassociationMap[currDstDim].push_back(
|
||||
rewriter.getAffineDimExpr(currSrcDim++));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
isCollapsingSource = false;
|
||||
break;
|
||||
}
|
||||
currDstDim++;
|
||||
SmallVector<ReassociationExprs, 4> reassociationMap;
|
||||
if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
|
||||
resultTy.getShape(),
|
||||
reassociationMap, isDynamic)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape,
|
||||
"tosa.reshape Attempting to collapse into an incompatible shape");
|
||||
}
|
||||
|
||||
// Check if any remaining dimensions exist. If either is rank-0 we only
|
||||
// require the directly lowering.
|
||||
if (currSrcDim != expandedShape.size() ||
|
||||
currDstDim != collapsedShape.size())
|
||||
isCollapsingSource = collapsedShape.empty() || expandedShape.empty();
|
||||
SmallVector<int64_t> intermediateShape;
|
||||
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
|
||||
intermediateShape, isDynamic)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape, "tosa.reshape Cannot collapse into given shape");
|
||||
}
|
||||
|
||||
// Otherwise, we need to first reduce all source dimensions into one and
|
||||
// then expand to the destination dimensions.
|
||||
if (!isCollapsingSource) {
|
||||
auto getIdentityExprs = [&rewriter](int n) {
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
for (int i = 0; i < n; ++i)
|
||||
exprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
return exprs;
|
||||
};
|
||||
Location loc = reshape.getLoc();
|
||||
int64_t totalElems =
|
||||
std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
|
||||
std::multiplies<int64_t>());
|
||||
auto elemTy = operandTy.getElementType();
|
||||
SmallVector<ReassociationExprs, 4> collapsingMap = {
|
||||
// Use operandTy here because we need to collapse all operands
|
||||
// dimensions.
|
||||
getIdentityExprs(operandTy.getShape().size())};
|
||||
SmallVector<ReassociationExprs, 4> expandingMap = {
|
||||
// Use resultTy here because we need to expand to all result
|
||||
// dimensions.
|
||||
getIdentityExprs(resultTy.getShape().size())};
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
|
||||
reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
|
||||
Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
|
||||
loc, collapsedTy, adaptor.getOperands()[0], collapsingMap);
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
|
||||
reshape, resultTy, collapsedOp, expandingMap);
|
||||
class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
|
||||
public:
|
||||
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
|
||||
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
|
||||
bool isDynamic = !operandTy.hasStaticShape();
|
||||
|
||||
if (operandTy == resultTy) {
|
||||
rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
|
||||
return success();
|
||||
}
|
||||
|
||||
if (resultTy.getRank() <
|
||||
adaptor.getOperands()[0].getType().cast<ShapedType>().getRank())
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
|
||||
reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
|
||||
reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
|
||||
if (isDynamic && operandTy.getRank() != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape, "Cannot expand dynamic dims from more than one dimension");
|
||||
}
|
||||
|
||||
SmallVector<ReassociationExprs, 4> reassociationMap;
|
||||
if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
|
||||
operandTy.getShape(),
|
||||
reassociationMap, isDynamic)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape,
|
||||
"tosa.reshape Attempting to expand into an incompatible shape");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> intermediateShape;
|
||||
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
|
||||
intermediateShape, isDynamic) ||
|
||||
intermediateShape != operandTy.getShape()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape, "tosa.reshape Cannot expand into given shape");
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
|
||||
reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ReshapeConverterCollapseExpand
|
||||
: public OpConversionPattern<tosa::ReshapeOp> {
|
||||
public:
|
||||
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
|
||||
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
|
||||
bool isDynamic = !operandTy.hasStaticShape();
|
||||
|
||||
if (operandTy == resultTy) {
|
||||
rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> intermediateShape;
|
||||
if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
|
||||
intermediateShape, isDynamic)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
reshape, "tosa.reshape Cannot identify an intermediate shape between "
|
||||
"the given two shapes");
|
||||
}
|
||||
|
||||
Value collapse = rewriter.create<tosa::ReshapeOp>(
|
||||
reshape.getLoc(),
|
||||
RankedTensorType::get(intermediateShape,
|
||||
reshape.getType().getElementType()),
|
||||
adaptor.input1());
|
||||
Value expand =
|
||||
rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
|
||||
rewriter.replaceOp(reshape, expand);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -3072,7 +3191,9 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
|
|||
TransposeConvConverter,
|
||||
GatherConverter,
|
||||
PadConverter,
|
||||
ReshapeConverter,
|
||||
ReshapeConverterCollapse,
|
||||
ReshapeConverterExpand,
|
||||
ReshapeConverterCollapseExpand,
|
||||
RescaleConverter,
|
||||
ResizeConverter,
|
||||
ReverseConverter,
|
||||
|
|
|
@ -541,6 +541,16 @@ func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_reshape_downrank_dyn
|
||||
func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]]
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<2x?xf32>) -> tensor<?xf32>
|
||||
// CHECK: return [[RESHAPE]]
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_reshape_uprank
|
||||
func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
|
||||
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
|
||||
|
@ -551,6 +561,16 @@ func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_reshape_uprank_dyn
|
||||
func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
|
||||
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<?xf32>) -> tensor<2x?xf32>
|
||||
// CHECK: return [[RESHAPE]]
|
||||
return %0 : tensor<2x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_reshape_samerank
|
||||
func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
|
||||
|
@ -563,6 +583,18 @@ func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_reshape_samerank_dyn
|
||||
func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
|
||||
// CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]]
|
||||
// CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<?x2xf32>) -> tensor<2x?xf32>
|
||||
// CHECK-NEXT: return %[[RESHAPE2]]
|
||||
return %0 : tensor<2x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_reshape_downrank_6D
|
||||
func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
|
||||
// CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
|
||||
|
@ -572,6 +604,16 @@ func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_reshape_downrank_6D_dyn
|
||||
func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
|
||||
// CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]]
|
||||
// CHECK: linalg.tensor_expand_shape %0 {{\[}}[0, 1, 2]]
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = [-1, 5, 77]} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
|
||||
return %0 : tensor<?x5x77xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_identity
|
||||
func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) {
|
||||
%0 = "tosa.identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
|
|
Loading…
Reference in New Issue