forked from OSchip/llvm-project
[mlir][tosa] Moves constant folding operations out of the Canonicalizer
Transpose operations on constant data were getting folded during the canonicalization process. This has compile time cost proportional to the constant size. Moving this to a separate pass to enable optionality and flexibility of how such scenarios can be handled. Reviewed By: rsuderman, jpienaar, stellaraccident Differential Revision: https://reviews.llvm.org/D124685
This commit is contained in:
parent
a392a39f75
commit
3bcaf2eb93
|
@ -34,6 +34,17 @@ namespace tosa {
|
|||
} // namespace tosa
|
||||
} // namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility Functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
namespace mlir {
|
||||
namespace tosa {
|
||||
/// Appends the canonicalization patterns for all the TOSA ops to the `patterns`
|
||||
void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
|
||||
RewritePatternSet &patterns);
|
||||
} // namespace tosa
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
|
||||
|
||||
|
|
|
@ -26,7 +26,10 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
|
|||
RewritePatternSet &patterns);
|
||||
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
|
||||
RewritePatternSet &patterns);
|
||||
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
|
||||
std::unique_ptr<Pass> createTosaInferShapesPass();
|
||||
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
|
||||
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
|
||||
|
|
|
@ -15,6 +15,15 @@
|
|||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> {
|
||||
let summary = "Fold layerwise operations on constant tensors";
|
||||
let description = [{
|
||||
Pass that enables folding of full-layer operations on constant tensors.
|
||||
}];
|
||||
|
||||
let constructor = "createTosaLayerwiseConstantFoldPass()";
|
||||
}
|
||||
|
||||
def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
|
||||
let summary = "Propagate shapes across TOSA operations";
|
||||
let description = [{
|
||||
|
|
|
@ -76,6 +76,8 @@ void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
|
|||
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
|
||||
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// TODO: Remove pass that operates on const tensor and enable optionality
|
||||
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
|
||||
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
|
||||
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
|
||||
}
|
||||
|
|
|
@ -94,6 +94,20 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
|
|||
// Operator Canonicalizers.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename... Args>
|
||||
void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
|
||||
(void)std::initializer_list<int>{
|
||||
0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...};
|
||||
}
|
||||
|
||||
void mlir::tosa::populateTosaOpsCanonicalizationPatterns(
|
||||
MLIRContext *ctx, RewritePatternSet &patterns) {
|
||||
addOpsCanonicalizations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
|
||||
>(ctx, patterns);
|
||||
}
|
||||
|
||||
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
|
||||
using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
|
||||
|
||||
|
@ -189,70 +203,6 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
|
|||
return success();
|
||||
}
|
||||
|
||||
struct ConstantTransposeOptimization
|
||||
: public OpRewritePattern<tosa::TransposeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::TransposeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto outputType = op.getType().cast<ShapedType>();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
// TOSA supports quantized types.
|
||||
if (!outputType.getElementType().isIntOrIndexOrFloat())
|
||||
return failure();
|
||||
|
||||
DenseElementsAttr inputValues;
|
||||
if (!matchPattern(op.input1(), m_Constant(&inputValues)))
|
||||
return failure();
|
||||
// Make sure the input is a constant that has a single user.
|
||||
if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
|
||||
return failure();
|
||||
|
||||
DenseIntElementsAttr permAttr;
|
||||
if (!matchPattern(op.perms(), m_Constant(&permAttr)))
|
||||
return failure();
|
||||
auto permValues = llvm::to_vector<6>(llvm::map_range(
|
||||
// TOSA allows both 32- and 64-bit integer tensors here.
|
||||
permAttr.getValues<APInt>(),
|
||||
[](const APInt &val) { return val.getZExtValue(); }));
|
||||
|
||||
auto inputType = op.input1().getType().cast<ShapedType>();
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
int64_t numElements = inputType.getNumElements();
|
||||
|
||||
SmallVector<Attribute, 4> outputValues;
|
||||
outputValues.resize(numElements);
|
||||
|
||||
// Transpose the input constant. Because we don't know its rank in advance,
|
||||
// we need to loop over the range [0, element count) and delinearize the
|
||||
// index.
|
||||
auto attrValues = inputValues.getValues<Attribute>();
|
||||
for (int srcLinearIndex = 0; srcLinearIndex < numElements;
|
||||
++srcLinearIndex) {
|
||||
SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
|
||||
int totalCount = srcLinearIndex;
|
||||
for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
|
||||
srcIndices[dim] = totalCount % inputShape[dim];
|
||||
totalCount /= inputShape[dim];
|
||||
}
|
||||
|
||||
SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
|
||||
for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
|
||||
dstIndices[dim] = srcIndices[permValues[dim]];
|
||||
|
||||
uint64_t dstLinearIndex = dstIndices.front();
|
||||
for (int dim = 1; dim < outputType.getRank(); ++dim)
|
||||
dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
|
||||
|
||||
outputValues[dstLinearIndex] = attrValues[srcIndices];
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(
|
||||
op, outputType, DenseElementsAttr::get(outputType, outputValues));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
|
@ -282,7 +232,6 @@ struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
|
|||
|
||||
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<ConstantTransposeOptimization>(context);
|
||||
results.add<NoOpOptimization>(context);
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,9 @@ add_mlir_dialect_library(MLIRTosaTransforms
|
|||
TosaDecomposeTransposeConv.cpp
|
||||
TosaDecomposeConv2D.cpp
|
||||
TosaDecomposeDepthwise.cpp
|
||||
TosaFoldConstantTranspose.cpp
|
||||
TosaInferShapes.cpp
|
||||
TosaLayerwiseConstantFoldPass.cpp
|
||||
TosaMakeBroadcastable.cpp
|
||||
TosaOptionalDecompositions.cpp
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- TosaDecomposeConv2D.cpp ------------------------------------------===//
|
||||
//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
//===- TosaDecomposeDepthwise.cpp
|
||||
//------------------------------------------===//
|
||||
//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
//===- TosaDecomposeTransposeConv.cpp
|
||||
//------------------------------------------===//
|
||||
//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
//===- TosaFoldConstantTranspose.cpp --------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Fold TOSA Transpose operation on constant data
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tosa;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::TransposeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto outputType = op.getType().cast<ShapedType>();
|
||||
// TOSA supports quantized types.
|
||||
if (!outputType.getElementType().isIntOrIndexOrFloat())
|
||||
return failure();
|
||||
|
||||
DenseElementsAttr inputValues;
|
||||
if (!matchPattern(op.input1(), m_Constant(&inputValues)))
|
||||
return failure();
|
||||
// Make sure the input is a constant that has a single user.
|
||||
if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
|
||||
return failure();
|
||||
|
||||
DenseIntElementsAttr permAttr;
|
||||
if (!matchPattern(op.perms(), m_Constant(&permAttr)))
|
||||
return failure();
|
||||
auto permValues = llvm::to_vector<6>(llvm::map_range(
|
||||
// TOSA allows both 32- and 64-bit integer tensors here.
|
||||
permAttr.getValues<APInt>(),
|
||||
[](const APInt &val) { return val.getZExtValue(); }));
|
||||
|
||||
auto inputType = op.input1().getType().cast<ShapedType>();
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
int64_t numElements = inputType.getNumElements();
|
||||
|
||||
SmallVector<Attribute, 4> outputValues;
|
||||
outputValues.resize(numElements);
|
||||
|
||||
// Transpose the input constant. Because we don't know its rank in advance,
|
||||
// we need to loop over the range [0, element count) and delinearize the
|
||||
// index.
|
||||
auto attrValues = inputValues.getValues<Attribute>();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
for (int srcLinearIndex = 0; srcLinearIndex < numElements;
|
||||
++srcLinearIndex) {
|
||||
SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
|
||||
int totalCount = srcLinearIndex;
|
||||
for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
|
||||
srcIndices[dim] = totalCount % inputShape[dim];
|
||||
totalCount /= inputShape[dim];
|
||||
}
|
||||
|
||||
SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
|
||||
for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
|
||||
dstIndices[dim] = srcIndices[permValues[dim]];
|
||||
|
||||
uint64_t dstLinearIndex = dstIndices.front();
|
||||
for (int dim = 1; dim < outputType.getRank(); ++dim)
|
||||
dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
|
||||
|
||||
outputValues[dstLinearIndex] = attrValues[srcIndices];
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(
|
||||
op, outputType, DenseElementsAttr::get(outputType, outputValues));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::tosa::populateTosaFoldConstantTransposePatterns(
|
||||
MLIRContext *ctx, RewritePatternSet &patterns) {
|
||||
patterns.add<TosaFoldConstantTranspose>(ctx);
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
//===- TosaInferShapes.cpp ------------------------------------------===//
|
||||
//===- TosaInferShapes.cpp ------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
//===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements constant folding transformations on TOSA operations
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tosa;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TosaLayerwiseConstantFoldPass
|
||||
: public TosaLayerwiseConstantFoldPassBase<TosaLayerwiseConstantFoldPass> {
|
||||
void runOnOperation() override {
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
auto func = getOperation();
|
||||
|
||||
mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
|
||||
mlir::tosa::populateTosaOpsCanonicalizationPatterns(ctx, patterns);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
|
||||
return std::make_unique<TosaLayerwiseConstantFoldPass>();
|
||||
}
|
|
@ -1,5 +1,4 @@
|
|||
//===- TosaOptionalDecompositions.cpp
|
||||
//------------------------------------------===//
|
||||
//===- TosaOptionalDecompositions.cpp -------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
|
|
@ -391,104 +391,6 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
|
|||
return %0 : tensor<3x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_fold
|
||||
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
|
||||
// CHECK: return %arg0
|
||||
%0 = arith.constant dense<[0, 1]> : tensor<2xi32>
|
||||
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32>
|
||||
return %1 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold
|
||||
func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
|
||||
// CHECK: "tosa.transpose"
|
||||
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
|
||||
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
return %1 : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_shape
|
||||
func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
|
||||
// CHECK: "tosa.transpose"
|
||||
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
|
||||
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_fold_splat
|
||||
func.func @transpose_fold_splat() -> tensor<3x2xf32> {
|
||||
%input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: %[[CST:.+]] = "tosa.const"()
|
||||
// CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_fold_2d_float
|
||||
func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
|
||||
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: %[[CST:.+]] = "tosa.const"()
|
||||
// CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_fold_4d_int
|
||||
func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
|
||||
%input = "tosa.const"() {value = dense<[[
|
||||
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
|
||||
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
|
||||
]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
|
||||
%perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
// CHECK: %[[CST:.+]] = "tosa.const"()
|
||||
// CHECK-SAME{LITERAL}: value = dense<[
|
||||
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
|
||||
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
|
||||
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
|
||||
// CHECK-SAME{LITERAL}: ]>
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : tensor<3x1x4x2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_non_cst_input
|
||||
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: tosa.transpose
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_non_cst_perms
|
||||
func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
|
||||
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
// CHECK: tosa.transpose
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_multi_users
|
||||
func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
|
||||
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: tosa.transpose
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_quantized_types
|
||||
func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
|
||||
%perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8>
|
||||
// CHECK: tosa.transpose
|
||||
%0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
|
||||
return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_no_op
|
||||
func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
|
||||
// CHECK: return %arg0
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @transpose_fold
|
||||
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
|
||||
// CHECK: return %arg0
|
||||
%0 = arith.constant dense<[0, 1]> : tensor<2xi32>
|
||||
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32>
|
||||
return %1 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold
|
||||
func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
|
||||
// CHECK: "tosa.transpose"
|
||||
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
|
||||
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
return %1 : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_shape
|
||||
func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
|
||||
// CHECK: "tosa.transpose"
|
||||
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
|
||||
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_fold_splat
|
||||
func.func @transpose_fold_splat() -> tensor<3x2xf32> {
|
||||
%input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: %[[CST:.+]] = "tosa.const"()
|
||||
// CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_fold_2d_float
|
||||
func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
|
||||
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: %[[CST:.+]] = "tosa.const"()
|
||||
// CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_fold_4d_int
|
||||
func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
|
||||
%input = "tosa.const"() {value = dense<[[
|
||||
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
|
||||
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
|
||||
]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
|
||||
%perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
// CHECK: %[[CST:.+]] = "tosa.const"()
|
||||
// CHECK-SAME{LITERAL}: value = dense<[
|
||||
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
|
||||
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
|
||||
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
|
||||
// CHECK-SAME{LITERAL}: ]>
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : tensor<3x1x4x2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_non_cst_input
|
||||
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: tosa.transpose
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_non_cst_perms
|
||||
func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
|
||||
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
// CHECK: tosa.transpose
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_multi_users
|
||||
func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
|
||||
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: tosa.transpose
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_quantized_types
|
||||
func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
|
||||
%perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8>
|
||||
// CHECK: tosa.transpose
|
||||
%0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
|
||||
return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
|
||||
}
|
Loading…
Reference in New Issue