[mlir][tosa] Moves constant folding operations out of the Canonicalizer

Transpose operations on constant data were getting folded during the
canonicalization process. This has compile time cost proportional to
the constant size. Moving this to a separate pass to enable optionality
and flexibility of how such scenarios can be handled.

Reviewed By: rsuderman, jpienaar, stellaraccident

Differential Revision: https://reviews.llvm.org/D124685
This commit is contained in:
Georgios Pinitas 2022-06-06 22:10:08 +00:00 committed by Robert Suderman
parent a392a39f75
commit 3bcaf2eb93
15 changed files with 279 additions and 171 deletions

View File

@ -34,6 +34,17 @@ namespace tosa {
} // namespace tosa
} // namespace mlir
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
namespace mlir {
namespace tosa {
/// Appends the canonicalization patterns for all the TOSA ops to the `patterns`
void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
} // namespace tosa
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"

View File

@ -26,7 +26,10 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
std::unique_ptr<Pass> createTosaInferShapesPass();
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();

View File

@ -15,6 +15,15 @@
include "mlir/Pass/PassBase.td"
def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> {
let summary = "Fold layerwise operations on constant tensors";
let description = [{
Pass that enables folding of full-layer operations on constant tensors.
}];
let constructor = "createTosaLayerwiseConstantFoldPass()";
}
def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
let summary = "Propagate shapes across TOSA operations";
let description = [{

View File

@ -76,6 +76,8 @@ void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// TODO: Remove pass that operates on const tensor and enable optionality
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
}

View File

@ -94,6 +94,20 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//
template <typename... Args>
void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
(void)std::initializer_list<int>{
0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...};
}
void mlir::tosa::populateTosaOpsCanonicalizationPatterns(
MLIRContext *ctx, RewritePatternSet &patterns) {
addOpsCanonicalizations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
>(ctx, patterns);
}
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
@ -189,70 +203,6 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
return success();
}
struct ConstantTransposeOptimization
: public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
auto outputType = op.getType().cast<ShapedType>();
ArrayRef<int64_t> outputShape = outputType.getShape();
// TOSA supports quantized types.
if (!outputType.getElementType().isIntOrIndexOrFloat())
return failure();
DenseElementsAttr inputValues;
if (!matchPattern(op.input1(), m_Constant(&inputValues)))
return failure();
// Make sure the input is a constant that has a single user.
if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
return failure();
DenseIntElementsAttr permAttr;
if (!matchPattern(op.perms(), m_Constant(&permAttr)))
return failure();
auto permValues = llvm::to_vector<6>(llvm::map_range(
// TOSA allows both 32- and 64-bit integer tensors here.
permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getZExtValue(); }));
auto inputType = op.input1().getType().cast<ShapedType>();
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t numElements = inputType.getNumElements();
SmallVector<Attribute, 4> outputValues;
outputValues.resize(numElements);
// Transpose the input constant. Because we don't know its rank in advance,
// we need to loop over the range [0, element count) and delinearize the
// index.
auto attrValues = inputValues.getValues<Attribute>();
for (int srcLinearIndex = 0; srcLinearIndex < numElements;
++srcLinearIndex) {
SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
int totalCount = srcLinearIndex;
for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
srcIndices[dim] = totalCount % inputShape[dim];
totalCount /= inputShape[dim];
}
SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
dstIndices[dim] = srcIndices[permValues[dim]];
uint64_t dstLinearIndex = dstIndices.front();
for (int dim = 1; dim < outputType.getRank(); ++dim)
dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
outputValues[dstLinearIndex] = attrValues[srcIndices];
}
rewriter.replaceOpWithNewOp<tosa::ConstOp>(
op, outputType, DenseElementsAttr::get(outputType, outputValues));
return success();
}
};
struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
@ -282,7 +232,6 @@ struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConstantTransposeOptimization>(context);
results.add<NoOpOptimization>(context);
}

View File

@ -2,7 +2,9 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeTransposeConv.cpp
TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
TosaFoldConstantTranspose.cpp
TosaInferShapes.cpp
TosaLayerwiseConstantFoldPass.cpp
TosaMakeBroadcastable.cpp
TosaOptionalDecompositions.cpp

View File

@ -1,4 +1,4 @@
//===- TosaDecomposeConv2D.cpp ------------------------------------------===//
//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@ -1,5 +1,4 @@
//===- TosaDecomposeDepthwise.cpp
//------------------------------------------===//
//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@ -1,5 +1,4 @@
//===- TosaDecomposeTransposeConv.cpp
//------------------------------------------===//
//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@ -0,0 +1,91 @@
//===- TosaFoldConstantTranspose.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA Transpose operation on constant data
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::tosa;
namespace {
struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
auto outputType = op.getType().cast<ShapedType>();
// TOSA supports quantized types.
if (!outputType.getElementType().isIntOrIndexOrFloat())
return failure();
DenseElementsAttr inputValues;
if (!matchPattern(op.input1(), m_Constant(&inputValues)))
return failure();
// Make sure the input is a constant that has a single user.
if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
return failure();
DenseIntElementsAttr permAttr;
if (!matchPattern(op.perms(), m_Constant(&permAttr)))
return failure();
auto permValues = llvm::to_vector<6>(llvm::map_range(
// TOSA allows both 32- and 64-bit integer tensors here.
permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getZExtValue(); }));
auto inputType = op.input1().getType().cast<ShapedType>();
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t numElements = inputType.getNumElements();
SmallVector<Attribute, 4> outputValues;
outputValues.resize(numElements);
// Transpose the input constant. Because we don't know its rank in advance,
// we need to loop over the range [0, element count) and delinearize the
// index.
auto attrValues = inputValues.getValues<Attribute>();
ArrayRef<int64_t> outputShape = outputType.getShape();
for (int srcLinearIndex = 0; srcLinearIndex < numElements;
++srcLinearIndex) {
SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
int totalCount = srcLinearIndex;
for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
srcIndices[dim] = totalCount % inputShape[dim];
totalCount /= inputShape[dim];
}
SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
dstIndices[dim] = srcIndices[permValues[dim]];
uint64_t dstLinearIndex = dstIndices.front();
for (int dim = 1; dim < outputType.getRank(); ++dim)
dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
outputValues[dstLinearIndex] = attrValues[srcIndices];
}
rewriter.replaceOpWithNewOp<tosa::ConstOp>(
op, outputType, DenseElementsAttr::get(outputType, outputValues));
return success();
}
};
} // namespace
void mlir::tosa::populateTosaFoldConstantTransposePatterns(
MLIRContext *ctx, RewritePatternSet &patterns) {
patterns.add<TosaFoldConstantTranspose>(ctx);
}

View File

@ -1,4 +1,4 @@
//===- TosaInferShapes.cpp ------------------------------------------===//
//===- TosaInferShapes.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@ -0,0 +1,43 @@
//===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements constant folding transformations on TOSA operations
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::tosa;
namespace {
struct TosaLayerwiseConstantFoldPass
: public TosaLayerwiseConstantFoldPassBase<TosaLayerwiseConstantFoldPass> {
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
auto func = getOperation();
mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
mlir::tosa::populateTosaOpsCanonicalizationPatterns(ctx, patterns);
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
signalPassFailure();
}
};
} // namespace
std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
return std::make_unique<TosaLayerwiseConstantFoldPass>();
}

View File

@ -1,5 +1,4 @@
//===- TosaOptionalDecompositions.cpp
//------------------------------------------===//
//===- TosaOptionalDecompositions.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@ -391,104 +391,6 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
return %0 : tensor<3x8xf32>
}
// CHECK-LABEL: @transpose_fold
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
%0 = arith.constant dense<[0, 1]> : tensor<2xi32>
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
}
// CHECK-LABEL: @transpose_nofold
func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
// CHECK: "tosa.transpose"
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %1 : tensor<3x3xf32>
}
// CHECK-LABEL: @transpose_nofold_shape
func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
// CHECK: "tosa.transpose"
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @transpose_fold_splat
func.func @transpose_fold_splat() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"()
// CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_fold_2d_float
func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"()
// CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_fold_4d_int
func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
%input = "tosa.const"() {value = dense<[[
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
%perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
// CHECK: %[[CST:.+]] = "tosa.const"()
// CHECK-SAME{LITERAL}: value = dense<[
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
// CHECK-SAME{LITERAL}: ]>
%1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
// CHECK: return %[[CST]]
return %1 : tensor<3x1x4x2xi32>
}
// CHECK-LABEL: @transpose_nofold_non_cst_input
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tosa.transpose
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_nofold_non_cst_perms
func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
// CHECK: tosa.transpose
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_nofold_multi_users
func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tosa.transpose
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
}
// CHECK-LABEL: @transpose_nofold_quantized_types
func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
%perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
%input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8>
// CHECK: tosa.transpose
%0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
}
// CHECK-LABEL: @transpose_no_op
func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
// CHECK: return %arg0

View File

@ -0,0 +1,99 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
// CHECK-LABEL: @transpose_fold
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
%0 = arith.constant dense<[0, 1]> : tensor<2xi32>
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
}
// CHECK-LABEL: @transpose_nofold
func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
// CHECK: "tosa.transpose"
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %1 : tensor<3x3xf32>
}
// CHECK-LABEL: @transpose_nofold_shape
func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
// CHECK: "tosa.transpose"
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @transpose_fold_splat
func.func @transpose_fold_splat() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"()
// CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_fold_2d_float
func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"()
// CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_fold_4d_int
func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
%input = "tosa.const"() {value = dense<[[
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
%perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
// CHECK: %[[CST:.+]] = "tosa.const"()
// CHECK-SAME{LITERAL}: value = dense<[
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
// CHECK-SAME{LITERAL}: ]>
%1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
// CHECK: return %[[CST]]
return %1 : tensor<3x1x4x2xi32>
}
// CHECK-LABEL: @transpose_nofold_non_cst_input
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tosa.transpose
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_nofold_non_cst_perms
func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
// CHECK: tosa.transpose
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// CHECK-LABEL: @transpose_nofold_multi_users
func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tosa.transpose
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
}
// CHECK-LABEL: @transpose_nofold_quantized_types
func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
%perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
%input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8>
// CHECK: tosa.transpose
%0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
}