forked from OSchip/llvm-project
[tosa][mlir] Refactor tosa.reshape lowering to linalg for dynamic cases.
Split tosa.reshape into three individual lowerings: collapse, expand and a combination of both. Add simple dynamic shape support. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D113936
This commit is contained in:
parent
833cdb0a07
commit
381677dfbf
|
@ -946,6 +946,112 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
|
||||||
|
ArrayRef<int64_t> rhsShape,
|
||||||
|
SmallVector<int64_t> &intermediateShape,
|
||||||
|
bool isDynamic) {
|
||||||
|
if (isDynamic) {
|
||||||
|
// TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
|
||||||
|
intermediateShape = {-1};
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lhsShape.empty() || rhsShape.empty()) {
|
||||||
|
intermediateShape = {};
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned currLhsDim = 0, currRhsDim = 0;
|
||||||
|
while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
|
||||||
|
int64_t rhsSize = rhsShape[currRhsDim];
|
||||||
|
int64_t lhsSize = lhsShape[currLhsDim];
|
||||||
|
while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
|
||||||
|
currRhsDim < rhsShape.size()) {
|
||||||
|
if (lhsSize < rhsSize) {
|
||||||
|
currLhsDim++;
|
||||||
|
lhsSize *= lhsShape[currLhsDim];
|
||||||
|
} else {
|
||||||
|
currRhsDim++;
|
||||||
|
rhsSize *= rhsShape[currRhsDim];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (lhsSize == rhsSize) {
|
||||||
|
intermediateShape.push_back(lhsSize);
|
||||||
|
}
|
||||||
|
currRhsDim++;
|
||||||
|
currLhsDim++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the iterators didn't reach the end and their leftover dimensions are not
|
||||||
|
// equal to 1 an intermediate shape was not found.
|
||||||
|
while (currLhsDim < lhsShape.size()) {
|
||||||
|
if (lhsShape[currLhsDim++] != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while (currRhsDim < rhsShape.size()) {
|
||||||
|
if (rhsShape[currRhsDim++] != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool createReassociationMapsForCollapse(
|
||||||
|
PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
|
||||||
|
ArrayRef<int64_t> dstShape,
|
||||||
|
SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
|
||||||
|
|
||||||
|
// If the shape is dynamic, create a map for collapsing into one dimension.
|
||||||
|
if (isDynamic) {
|
||||||
|
SmallVector<AffineExpr, 2> exprs;
|
||||||
|
for (int i = 0, s = srcShape.size(); i < s; ++i)
|
||||||
|
exprs.push_back(rewriter.getAffineDimExpr(i));
|
||||||
|
reassociationMap = {exprs};
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dstShape.empty()) {
|
||||||
|
reassociationMap = {};
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
reassociationMap.resize(dstShape.size());
|
||||||
|
unsigned currSrcDim = 0, currDstDim = 0;
|
||||||
|
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
|
||||||
|
int64_t dstSize = dstShape[currDstDim];
|
||||||
|
int64_t srcSize = srcShape[currSrcDim];
|
||||||
|
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
|
||||||
|
reassociationMap[currDstDim].push_back(
|
||||||
|
rewriter.getAffineDimExpr(currSrcDim++));
|
||||||
|
srcSize *= srcShape[currSrcDim];
|
||||||
|
}
|
||||||
|
if (srcSize == dstSize) {
|
||||||
|
reassociationMap[currDstDim].push_back(
|
||||||
|
rewriter.getAffineDimExpr(currSrcDim++));
|
||||||
|
// If the next dim in collapsedShape is not 1, treat subsequent dims in
|
||||||
|
// expandedShape which are 1 to be collapsed.
|
||||||
|
if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
|
||||||
|
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
|
||||||
|
reassociationMap[currDstDim].push_back(
|
||||||
|
rewriter.getAffineDimExpr(currSrcDim++));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
currDstDim++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If both iterators didn't reach the end, we have leftover dimentions which
|
||||||
|
// implies that we have a mismatch in shape.
|
||||||
|
if (currSrcDim != srcShape.size() || currDstDim != dstShape.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename SrcOp>
|
template <typename SrcOp>
|
||||||
|
@ -1534,7 +1640,7 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
|
class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
@ -1543,103 +1649,116 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
|
ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
|
||||||
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
|
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
|
||||||
|
bool isDynamic = !operandTy.hasStaticShape();
|
||||||
|
|
||||||
|
if (isDynamic && resultTy.getRank() != 1) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
reshape, "Cannot collapse dynamic dims to more than one dimension");
|
||||||
|
}
|
||||||
|
|
||||||
if (operandTy == resultTy) {
|
if (operandTy == resultTy) {
|
||||||
rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
|
rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
|
SmallVector<ReassociationExprs, 4> reassociationMap;
|
||||||
return failure();
|
if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
|
||||||
|
resultTy.getShape(),
|
||||||
// Compute the reassociation maps for the linalg operation.
|
reassociationMap, isDynamic)) {
|
||||||
ArrayRef<int64_t> expandedShape =
|
return rewriter.notifyMatchFailure(
|
||||||
(operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
|
reshape,
|
||||||
: resultTy.getShape());
|
"tosa.reshape Attempting to collapse into an incompatible shape");
|
||||||
ArrayRef<int64_t> collapsedShape =
|
|
||||||
(operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
|
|
||||||
: operandTy.getShape());
|
|
||||||
unsigned currSrcDim = 0, currDstDim = 0;
|
|
||||||
SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());
|
|
||||||
|
|
||||||
// First scan all dimensions in the source shapes to see whether we have a
|
|
||||||
// perfect case where consecutive dimensions in source are collapsed. For
|
|
||||||
// such case we can just generate one single linalg.reshape.
|
|
||||||
bool isCollapsingSource = true;
|
|
||||||
while (currSrcDim < expandedShape.size() &&
|
|
||||||
currDstDim < collapsedShape.size()) {
|
|
||||||
int64_t dstSize = collapsedShape[currDstDim];
|
|
||||||
int64_t srcSize = expandedShape[currSrcDim];
|
|
||||||
while (srcSize < dstSize && currSrcDim < expandedShape.size()) {
|
|
||||||
reassociationMap[currDstDim].push_back(
|
|
||||||
rewriter.getAffineDimExpr(currSrcDim++));
|
|
||||||
srcSize *= expandedShape[currSrcDim];
|
|
||||||
}
|
|
||||||
if (srcSize == dstSize) {
|
|
||||||
reassociationMap[currDstDim].push_back(
|
|
||||||
rewriter.getAffineDimExpr(currSrcDim++));
|
|
||||||
// If the next dim in collapsedShape is not 1, treat subsequent dims in
|
|
||||||
// expandedShape which are 1 to be collapsed.
|
|
||||||
if (currDstDim == collapsedShape.size() - 1 ||
|
|
||||||
collapsedShape[currDstDim + 1] != 1) {
|
|
||||||
while (currSrcDim < expandedShape.size() &&
|
|
||||||
expandedShape[currSrcDim] == 1) {
|
|
||||||
reassociationMap[currDstDim].push_back(
|
|
||||||
rewriter.getAffineDimExpr(currSrcDim++));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
isCollapsingSource = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
currDstDim++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if any remaining dimensions exist. If either is rank-0 we only
|
SmallVector<int64_t> intermediateShape;
|
||||||
// require the directly lowering.
|
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
|
||||||
if (currSrcDim != expandedShape.size() ||
|
intermediateShape, isDynamic)) {
|
||||||
currDstDim != collapsedShape.size())
|
return rewriter.notifyMatchFailure(
|
||||||
isCollapsingSource = collapsedShape.empty() || expandedShape.empty();
|
reshape, "tosa.reshape Cannot collapse into given shape");
|
||||||
|
}
|
||||||
|
|
||||||
// Otherwise, we need to first reduce all source dimensions into one and
|
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
|
||||||
// then expand to the destination dimensions.
|
reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
|
||||||
if (!isCollapsingSource) {
|
return success();
|
||||||
auto getIdentityExprs = [&rewriter](int n) {
|
}
|
||||||
SmallVector<AffineExpr, 4> exprs;
|
};
|
||||||
for (int i = 0; i < n; ++i)
|
|
||||||
exprs.push_back(rewriter.getAffineDimExpr(i));
|
|
||||||
return exprs;
|
|
||||||
};
|
|
||||||
Location loc = reshape.getLoc();
|
|
||||||
int64_t totalElems =
|
|
||||||
std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
|
|
||||||
std::multiplies<int64_t>());
|
|
||||||
auto elemTy = operandTy.getElementType();
|
|
||||||
SmallVector<ReassociationExprs, 4> collapsingMap = {
|
|
||||||
// Use operandTy here because we need to collapse all operands
|
|
||||||
// dimensions.
|
|
||||||
getIdentityExprs(operandTy.getShape().size())};
|
|
||||||
SmallVector<ReassociationExprs, 4> expandingMap = {
|
|
||||||
// Use resultTy here because we need to expand to all result
|
|
||||||
// dimensions.
|
|
||||||
getIdentityExprs(resultTy.getShape().size())};
|
|
||||||
|
|
||||||
auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
|
class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
|
||||||
Value collapsedOp = rewriter.create<linalg::TensorCollapseShapeOp>(
|
public:
|
||||||
loc, collapsedTy, adaptor.getOperands()[0], collapsingMap);
|
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
||||||
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
|
|
||||||
reshape, resultTy, collapsedOp, expandingMap);
|
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
|
||||||
|
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
|
||||||
|
bool isDynamic = !operandTy.hasStaticShape();
|
||||||
|
|
||||||
|
if (operandTy == resultTy) {
|
||||||
|
rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (resultTy.getRank() <
|
if (isDynamic && operandTy.getRank() != 1) {
|
||||||
adaptor.getOperands()[0].getType().cast<ShapedType>().getRank())
|
return rewriter.notifyMatchFailure(
|
||||||
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
|
reshape, "Cannot expand dynamic dims from more than one dimension");
|
||||||
reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
|
}
|
||||||
else
|
|
||||||
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
|
SmallVector<ReassociationExprs, 4> reassociationMap;
|
||||||
reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
|
if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
|
||||||
|
operandTy.getShape(),
|
||||||
|
reassociationMap, isDynamic)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
reshape,
|
||||||
|
"tosa.reshape Attempting to expand into an incompatible shape");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> intermediateShape;
|
||||||
|
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
|
||||||
|
intermediateShape, isDynamic) ||
|
||||||
|
intermediateShape != operandTy.getShape()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
reshape, "tosa.reshape Cannot expand into given shape");
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
|
||||||
|
reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReshapeConverterCollapseExpand
|
||||||
|
: public OpConversionPattern<tosa::ReshapeOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
|
||||||
|
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
|
||||||
|
bool isDynamic = !operandTy.hasStaticShape();
|
||||||
|
|
||||||
|
if (operandTy == resultTy) {
|
||||||
|
rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> intermediateShape;
|
||||||
|
if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
|
||||||
|
intermediateShape, isDynamic)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
reshape, "tosa.reshape Cannot identify an intermediate shape between "
|
||||||
|
"the given two shapes");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value collapse = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
reshape.getLoc(),
|
||||||
|
RankedTensorType::get(intermediateShape,
|
||||||
|
reshape.getType().getElementType()),
|
||||||
|
adaptor.input1());
|
||||||
|
Value expand =
|
||||||
|
rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
|
||||||
|
rewriter.replaceOp(reshape, expand);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -3072,7 +3191,9 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
|
||||||
TransposeConvConverter,
|
TransposeConvConverter,
|
||||||
GatherConverter,
|
GatherConverter,
|
||||||
PadConverter,
|
PadConverter,
|
||||||
ReshapeConverter,
|
ReshapeConverterCollapse,
|
||||||
|
ReshapeConverterExpand,
|
||||||
|
ReshapeConverterCollapseExpand,
|
||||||
RescaleConverter,
|
RescaleConverter,
|
||||||
ResizeConverter,
|
ResizeConverter,
|
||||||
ReverseConverter,
|
ReverseConverter,
|
||||||
|
|
|
@ -541,6 +541,16 @@ func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reshape_downrank_dyn
|
||||||
|
func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
|
||||||
|
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1]]
|
||||||
|
%0 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<2x?xf32>) -> tensor<?xf32>
|
||||||
|
// CHECK: return [[RESHAPE]]
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_reshape_uprank
|
// CHECK-LABEL: @test_reshape_uprank
|
||||||
func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
|
func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
|
||||||
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
|
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
|
||||||
|
@ -551,6 +561,16 @@ func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reshape_uprank_dyn
|
||||||
|
func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
|
||||||
|
// CHECK: [[RESHAPE:%.+]] = linalg.tensor_expand_shape %arg0 {{\[}}[0, 1]]
|
||||||
|
%0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<?xf32>) -> tensor<2x?xf32>
|
||||||
|
// CHECK: return [[RESHAPE]]
|
||||||
|
return %0 : tensor<2x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_reshape_samerank
|
// CHECK-LABEL: @test_reshape_samerank
|
||||||
func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
|
func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
|
||||||
|
@ -563,6 +583,18 @@ func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reshape_samerank_dyn
|
||||||
|
func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
|
||||||
|
// CHECK-NEXT: %[[RESHAPE1:.*]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1]]
|
||||||
|
// CHECK-NEXT: %[[RESHAPE2:.*]] = linalg.tensor_expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
|
||||||
|
%0 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<?x2xf32>) -> tensor<2x?xf32>
|
||||||
|
// CHECK-NEXT: return %[[RESHAPE2]]
|
||||||
|
return %0 : tensor<2x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_reshape_downrank_6D
|
// CHECK-LABEL: @test_reshape_downrank_6D
|
||||||
func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
|
func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
|
||||||
// CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
|
// CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
|
||||||
|
@ -572,6 +604,16 @@ func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reshape_downrank_6D_dyn
|
||||||
|
func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
|
||||||
|
// CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]]
|
||||||
|
// CHECK: linalg.tensor_expand_shape %0 {{\[}}[0, 1, 2]]
|
||||||
|
%0 = "tosa.reshape"(%arg0) {new_shape = [-1, 5, 77]} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
|
||||||
|
return %0 : tensor<?x5x77xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_identity
|
// CHECK-LABEL: @test_identity
|
||||||
func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) {
|
func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32>, tensor<1xi32>) {
|
||||||
%0 = "tosa.identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
%0 = "tosa.identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||||
|
|
Loading…
Reference in New Issue