[mlir][tosa] Allow optional TOSA decompositions to be populated separately

Moved all TOSA decomposition patterns so that they can be optionally populated
and used by external rewrites. This avoids decomposing TOSA operations when
backends may benefit from the non-decomposed version.
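
For downstream users this means the decompositions can be applied selectively.
A minimal sketch (the helper name and the choice to keep tosa.depthwise_conv2d
intact are illustrative, not part of this change):

#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Decompose tosa.conv2d and tosa.transpose_conv2d, but leave
// tosa.depthwise_conv2d intact for a backend that lowers it natively.
static void applySelectedTosaDecompositions(mlir::FuncOp func) {
  mlir::MLIRContext *ctx = func.getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns);
  mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
  (void)mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
}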

Reviewed By: rsuderman, mehdi_amini

Differential Revision: https://reviews.llvm.org/D116526
Commit dfd070820c (parent 0a8d15ad55)
Authored by Aaron DeBattista on 2022-01-11 10:16:01 -08:00; committed by Rob Suderman
12 changed files with 384 additions and 356 deletions

View File

@@ -19,11 +19,18 @@
namespace mlir {
namespace tosa {
std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
// The rewrites can be selectively added to a conversion pass.
void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
RewritePatternSet &patterns);
std::unique_ptr<Pass> createTosaInferShapesPass();
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
std::unique_ptr<Pass> createTosaOptimizationPass();
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
std::unique_ptr<Pass> createTosaOptionalDecompositions();
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

View File

@@ -15,21 +15,6 @@
include "mlir/Pass/PassBase.td"
def TosaDecomposeTransposeConv : FunctionPass<"tosa-decompose-transpose-conv"> {
let summary = "Deompose transpose convolutiions into standard convolutions.";
let description = [{
Pass that uses shape manipulation and convolution operations to transform
a transpose convolution into a regular convolution.
}];
let constructor = "createTosaDecomposeTransposeConvPass()";
let dependentDialects = [
"StandardOpsDialect",
"tensor::TensorDialect",
"tosa::TosaDialect",
];
}
def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> {
let summary = "Propagate shapes across TOSA operations";
let description = [{
@@ -58,13 +43,14 @@ def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
let constructor = "createTosaMakeBroadcastablePass()";
}
def TosaOptimization : FunctionPass<"tosa-optimization"> {
let summary = "TOSA operation optimizations";
def TosaOptionalDecompositions : FunctionPass<"tosa-optional-decompositions"> {
let summary = "Applies Tosa operations optional decompositions";
let description = [{
"Pass to perform optimizations on TOSA operations"
Pass that applies the optional decompositions of TOSA operations
exposed as populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h
}];
let constructor = "createTosaOptimizationPass()";
let constructor = "tosa::createTosaOptionalDecompositions()";
}
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
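
The rewrites can also be applied as a standalone pass via the flag registered
above (a minimal usage sketch; input.mlir is a placeholder file name):

mlir-opt --tosa-optional-decompositions input.mlir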

View File

@@ -68,6 +68,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
}
void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm) {
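// Run the optional TOSA decompositions first so that the subsequent
// lowering passes operate on the decomposed operations.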
pm.addNestedPass<FuncOp>(mlir::tosa::createTosaOptionalDecompositions());
pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
pm.addNestedPass<FuncOp>(createTosaToLinalgNamed());
pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());

View File

@@ -1,8 +1,10 @@
add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeTransposeConv.cpp
TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
TosaInferShapes.cpp
TosaMakeBroadcastable.cpp
TosaOptimization.cpp
TosaOptionalDecompositions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms

View File

@@ -0,0 +1,115 @@
//===- TosaDecomposeConv2D.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose TOSA Conv2D operation to a series of TOSA Ops, specifically
// (1) Convert a 1x1 Conv2D to a Reshape -> FullyConnected -> Reshape chain.
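// For example, with the shapes used in the accompanying test: a conv2d with
// input tensor<4x10x10x2xf32> and weight tensor<3x1x1x2xf32> becomes a reshape
// to tensor<400x2xf32>, a fully_connected producing tensor<400x3xf32>, and a
// reshape back to tensor<4x10x10x3xf32>.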
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::tosa;
namespace {
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
explicit Conv2DIsFullyConnected(MLIRContext *context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(tosa::Conv2DOp op,
PatternRewriter &rewriter) const override {
Value input = op.input();
Value weight = op.weight();
ShapedType inputType = input.getType().cast<ShapedType>();
ShapedType weightType = weight.getType().cast<ShapedType>();
ShapedType resultType = op.getType().cast<ShapedType>();
if (!inputType.hasStaticShape() || !weightType.hasRank()) {
return failure();
}
// Stride must be 1 for this optimization.
for (Attribute stride : op.stride().getValue()) {
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
return failure();
}
}
// Only works for a 1x1 kernel.
ArrayRef<int64_t> weightShape = weightType.getShape();
if (weightShape[1] != 1 || weightShape[2] != 1) {
return failure();
}
// Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::SmallVector<int64_t, 2> revisedInputShape{
inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
auto revisedInputShapeType = RankedTensorType::get(
revisedInputShape,
input.getType().dyn_cast<RankedTensorType>().getElementType());
auto reshapedInput = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedInputShapeType, input,
rewriter.getI64ArrayAttr(revisedInputShape))
.getResult();
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
weightShape[3]};
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
weight.getType().dyn_cast<RankedTensorType>().getElementType());
auto reshapedWeight = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedWeightShapeType, weight,
rewriter.getI64ArrayAttr(revisedWeightShape))
.getResult();
// Perform a fully connected network over the reshaped input and weight.
llvm::SmallVector<int64_t, 2> fullyConnectedShape{
inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
auto fullyConnectedShapeType = RankedTensorType::get(
fullyConnectedShape,
resultType.dyn_cast<ShapedType>().getElementType());
Value fullyConnectedValue;
if (op.quantization_info()) {
fullyConnectedValue =
rewriter
.create<tosa::FullyConnectedOp>(
op.getLoc(), fullyConnectedShapeType, reshapedInput,
reshapedWeight, op.bias(), op.quantization_info().getValue())
.getResult();
} else {
fullyConnectedValue = rewriter
.create<tosa::FullyConnectedOp>(
op.getLoc(), fullyConnectedShapeType,
reshapedInput, reshapedWeight, op.bias())
.getResult();
}
// Reshape output to [N, IH, IW, OC].
llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
inputShape[2], weightShape[0]};
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, resultType, fullyConnectedValue,
rewriter.getI64ArrayAttr(outputShape));
return success();
}
};
} // namespace
void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
RewritePatternSet &patterns) {
patterns.insert<Conv2DIsFullyConnected>(ctx);
}

View File

@@ -0,0 +1,121 @@
//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose TOSA DepthwiseConv2D operation to a series of TOSA Ops, specifically
// (1) Convert a 1x1 DepthwiseConv2D to a Reshape -> Mul -> Reshape -> Add chain.
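// For example, with the shapes used in the accompanying test: a
// depthwise_conv2d with input tensor<4x10x10x2xf32> and weight
// tensor<1x1x2x3xf32> becomes reshapes to tensor<4x10x10x2x1xf32> and
// tensor<1x1x1x2x3xf32>, a mul producing tensor<4x10x10x2x3xf32>, a reshape
// to tensor<4x10x10x6xf32>, and an add of the bias.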
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::tosa;
namespace {
struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
explicit DepthwiseConv2DIsMul(MLIRContext *context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
PatternRewriter &rewriter) const override {
Value input = op.input();
Value weight = op.weight();
ShapedType inputType = input.getType().cast<ShapedType>();
ShapedType weightType = weight.getType().cast<ShapedType>();
ShapedType resultType = op.output().getType().cast<ShapedType>();
Type inputEType = inputType.getElementType();
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
resultType.hasStaticShape())) {
return failure();
}
// Only handle float types; quantized ops are not decomposed here since
// their quantization information still needs to be applied.
if (op.quantization_info() || !inputEType.isa<FloatType>()) {
return failure();
}
// Stride must be 1 for this optimization.
for (Attribute stride : op.stride().getValue()) {
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
return failure();
}
}
// Only works for a 1x1 kernel.
ArrayRef<int64_t> weightShape = weightType.getShape();
if (weightShape[0] != 1 || weightShape[1] != 1) {
return failure();
}
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::SmallVector<int64_t, 2> revisedInputShape{
inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
auto revisedInputShapeType = RankedTensorType::get(
revisedInputShape,
input.getType().dyn_cast<RankedTensorType>().getElementType());
auto reshapedInput = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedInputShapeType, input,
rewriter.getI64ArrayAttr(revisedInputShape))
.getResult();
// Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
weightShape[3]};
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
weight.getType().dyn_cast<RankedTensorType>().getElementType());
auto reshapedWeight = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedWeightShapeType, weight,
rewriter.getI64ArrayAttr(revisedWeightShape))
.getResult();
// Perform an elementwise mul over the reshaped input and weight.
llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
inputShape[2], inputShape[3],
weightShape[3]};
auto mulShapeType = RankedTensorType::get(
mulShape,
weight.getType().dyn_cast<RankedTensorType>().getElementType());
Value mulValue =
rewriter
.create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
reshapedWeight, /*shift=*/0)
.getResult();
// Reshape output to [N, H, W, C * M].
auto outputShape = op.output().getType().cast<ShapedType>().getShape();
auto outputShapeType = RankedTensorType::get(
outputShape,
input.getType().dyn_cast<RankedTensorType>().getElementType());
auto outputValue =
rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
rewriter.getI64ArrayAttr(outputShape));
// Add in the bias.
rewriter
.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
op.bias())
.getResult();
return success();
}
};
} // namespace
void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
RewritePatternSet &patterns) {
patterns.insert<DepthwiseConv2DIsMul>(ctx);
}

View File

@@ -7,17 +7,19 @@
//
//===----------------------------------------------------------------------===//
//
// Insert reshape to binary op's input if needed to match rank
// Decompose TOSA TransposeConv operation to a series of TOSA Ops, specifically
// (1) Convert a dilated TransposeConv2D to Conv2D, including reversing and
//     reshaping of the weights.
// (2) Convert a strided TransposeConv2D to Conv2D, including transposing,
//     reversing, and reshaping of the weights and the input/output tensors.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR//TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::tosa;
@@ -369,22 +371,10 @@ public:
}
};
/// Pass that enables broadcast by making all input arrays have the same
/// number of dimensions. Insert RESHAPE operations to lower rank operand
struct TosaDecomposeTransposeConv
: public TosaDecomposeTransposeConvBase<TosaDecomposeTransposeConv> {
public:
void runOnFunction() override {
auto func = getFunction();
RewritePatternSet patterns(func.getContext());
patterns
.insert<TransposeConvDilatedConverter, TransposeConvStridedConverter>(
func.getContext());
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
};
} // namespace
std::unique_ptr<Pass> mlir::tosa::createTosaDecomposeTransposeConvPass() {
return std::make_unique<TosaDecomposeTransposeConv>();
void mlir::tosa::populateTosaDecomposeTransposeConv(
MLIRContext *ctx, RewritePatternSet &patterns) {
patterns.insert<TransposeConvDilatedConverter>(ctx);
patterns.insert<TransposeConvStridedConverter>(ctx);
}

View File

@@ -1,243 +0,0 @@
//===- TosaOptimization.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Pass to perform optimizations on TOSA operations
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
using namespace mlir::tosa;
#define PASS_NAME "tosa-optimization"
#define DEBUG_TYPE PASS_NAME
namespace {
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
explicit Conv2DIsFullyConnected(MLIRContext *context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(tosa::Conv2DOp op,
PatternRewriter &rewriter) const override {
Value input = op.input();
Value weight = op.weight();
ShapedType inputType = input.getType().cast<ShapedType>();
ShapedType weightType = weight.getType().cast<ShapedType>();
ShapedType resultType = op.getType().cast<ShapedType>();
if (!inputType.hasStaticShape() || !weightType.hasRank()) {
return failure();
}
// Stride must be 1 for this optimization.
for (Attribute stride : op.stride().getValue()) {
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
return failure();
}
}
// Only works for a 1x1 kernel.
ArrayRef<int64_t> weightShape = weightType.getShape();
if (weightShape[1] != 1 || weightShape[2] != 1) {
return failure();
}
// Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::SmallVector<int64_t, 2> revisedInputShape{
inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
auto revisedInputShapeType = RankedTensorType::get(
revisedInputShape,
input.getType().dyn_cast<RankedTensorType>().getElementType());
auto reshapedInput = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedInputShapeType, input,
rewriter.getI64ArrayAttr(revisedInputShape))
.getResult();
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
weightShape[3]};
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
weight.getType().dyn_cast<RankedTensorType>().getElementType());
auto reshapedWeight = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedWeightShapeType, weight,
rewriter.getI64ArrayAttr(revisedWeightShape))
.getResult();
// Perform a fully connected network over the reshaped input and weight.
llvm::SmallVector<int64_t, 2> fullyConnectedShape{
inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
auto fullyConnectedShapeType = RankedTensorType::get(
fullyConnectedShape,
resultType.dyn_cast<ShapedType>().getElementType());
Value fullyConnectedValue;
if (op.quantization_info()) {
fullyConnectedValue =
rewriter
.create<tosa::FullyConnectedOp>(
op.getLoc(), fullyConnectedShapeType, reshapedInput,
reshapedWeight, op.bias(), op.quantization_info().getValue())
.getResult();
} else {
fullyConnectedValue = rewriter
.create<tosa::FullyConnectedOp>(
op.getLoc(), fullyConnectedShapeType,
reshapedInput, reshapedWeight, op.bias())
.getResult();
}
// Reshape output to [N, IH, IW, OC].
llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
inputShape[2], weightShape[0]};
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, resultType, fullyConnectedValue,
rewriter.getI64ArrayAttr(outputShape));
return success();
}
};
struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
explicit DepthwiseConv2DIsMul(MLIRContext *context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
PatternRewriter &rewriter) const override {
Value input = op.input();
Value weight = op.weight();
ShapedType inputType = input.getType().cast<ShapedType>();
ShapedType weightType = weight.getType().cast<ShapedType>();
ShapedType resultType = op.output().getType().cast<ShapedType>();
Type inputEType = inputType.getElementType();
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
resultType.hasStaticShape())) {
return failure();
}
// Only handle float types; quantized ops are not decomposed here since
// their quantization information still needs to be applied.
if (op.quantization_info() || !inputEType.isa<FloatType>()) {
return failure();
}
// Stride must be 1 for this optimization.
for (Attribute stride : op.stride().getValue()) {
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
return failure();
}
}
// Only works for a 1x1 kernel.
ArrayRef<int64_t> weightShape = weightType.getShape();
if (weightShape[0] != 1 || weightShape[1] != 1) {
return failure();
}
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::SmallVector<int64_t, 2> revisedInputShape{
inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
auto revisedInputShapeType = RankedTensorType::get(
revisedInputShape,
input.getType().dyn_cast<RankedTensorType>().getElementType());
auto reshapedInput = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedInputShapeType, input,
rewriter.getI64ArrayAttr(revisedInputShape))
.getResult();
// Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
weightShape[3]};
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
weight.getType().dyn_cast<RankedTensorType>().getElementType());
auto reshapedWeight = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedWeightShapeType, weight,
rewriter.getI64ArrayAttr(revisedWeightShape))
.getResult();
// Perform an elementwise mul over the reshaped input and weight.
llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
inputShape[2], inputShape[3],
weightShape[3]};
auto mulShapeType = RankedTensorType::get(
mulShape,
weight.getType().dyn_cast<RankedTensorType>().getElementType());
Value mulValue =
rewriter
.create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
reshapedWeight, /*shift=*/0)
.getResult();
// Reshape output to [N, H, W, C * M].
auto outputShape = op.output().getType().cast<ShapedType>().getShape();
auto outputShapeType = RankedTensorType::get(
outputShape,
input.getType().dyn_cast<RankedTensorType>().getElementType());
auto outputValue =
rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
rewriter.getI64ArrayAttr(outputShape));
// Add in the bias.
rewriter
.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
op.bias())
.getResult();
return success();
}
};
class TosaOptimization : public PassWrapper<TosaOptimization, FunctionPass> {
public:
explicit TosaOptimization() = default;
void runOnFunction() override;
StringRef getArgument() const final { return PASS_NAME; }
StringRef getDescription() const final {
return "Applies TOSA Operation Optimizations";
}
};
void TosaOptimization::runOnFunction() {
OwningRewritePatternList patterns(&getContext());
patterns.insert<Conv2DIsFullyConnected>(&getContext());
patterns.insert<DepthwiseConv2DIsMul>(&getContext());
auto func = getFunction();
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) {
signalPassFailure();
}
}
} // namespace
std::unique_ptr<Pass> mlir::tosa::createTosaOptimizationPass() {
return std::make_unique<TosaOptimization>();
}

View File

@@ -0,0 +1,46 @@
//===- TosaOptionalDecompositions.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Pass that applies the optional TOSA operation decompositions exposed as
// populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
struct TosaOptionalDecompositions
: public TosaOptionalDecompositionsBase<TosaOptionalDecompositions> {
void runOnFunction() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
auto func = getFunction();
mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns);
mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns);
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
signalPassFailure();
}
};
} // namespace
std::unique_ptr<Pass> mlir::tosa::createTosaOptionalDecompositions() {
return std::make_unique<TosaOptionalDecompositions>();
}

View File

@@ -1,69 +1,40 @@
// RUN: mlir-opt --split-input-file --tosa-optimization %s | FileCheck %s
// -----
// CHECK-LABEL: @conv2d_as_fully_connected
func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
// CHECK-NOT: "tosa.conv2d"
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
// CHECK-SAME: -> tensor<400x2xf32>
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
// CHECK-SAME: -> tensor<3x2xf32>
// CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
// CHECK-SAME: -> tensor<400x3xf32>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
// CHECK-SAME: -> tensor<4x10x10x3xf32>
// CHECK: return %[[VAR3]]
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
return %0 : tensor<4x10x10x3xf32>
}
// -----
// CHECK-LABEL: @conv2d_as_fully_connected_quant
func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
// CHECK-NOT: "tosa.conv2d"
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
// CHECK-SAME: -> tensor<400x2xi8>
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
// CHECK-SAME: -> tensor<3x2xi8>
// CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
// CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
// CHECK-SAME: -> tensor<400x3xi32>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
// CHECK-SAME: -> tensor<4x10x10x3xi32>
// CHECK: return %[[VAR3]]
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
return %0 : tensor<4x10x10x3xi32>
}
// -----
// CHECK-LABEL: @depthwise_conv2d_as_mul
func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK-NOT: "tosa.depthwise_conv2d"
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
// CHECK-SAME: -> tensor<4x10x10x2x1xf32>
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
// CHECK-SAME: -> tensor<1x1x1x2x3xf32>
// CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
// CHECK-SAME: -> tensor<4x10x10x2x3xf32>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
// CHECK-SAME: -> tensor<4x10x10x6xf32>
// CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
// CHECK-SAME: -> tensor<4x10x10x6xf32>
// CHECK: return %[[VAR4]]
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
return %0 : tensor<4x10x10x6xf32>
}
// -----
// CHECK-LABEL: @depthwise_conv2d_as_mul_q
func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
// CHECK: "tosa.depthwise_conv2d"
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
return %0 : tensor<4x10x10x6xi32>
}
// -----
// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
// -----
// CHECK-LABEL: @conv2d_as_fully_connected
func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
// CHECK-NOT: "tosa.conv2d"
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
// CHECK-SAME: -> tensor<400x2xf32>
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
// CHECK-SAME: -> tensor<3x2xf32>
// CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
// CHECK-SAME: -> tensor<400x3xf32>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
// CHECK-SAME: -> tensor<4x10x10x3xf32>
// CHECK: return %[[VAR3]]
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
return %0 : tensor<4x10x10x3xf32>
}
// -----
// CHECK-LABEL: @conv2d_as_fully_connected_quant
func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
// CHECK-NOT: "tosa.conv2d"
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
// CHECK-SAME: -> tensor<400x2xi8>
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
// CHECK-SAME: -> tensor<3x2xi8>
// CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
// CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
// CHECK-SAME: -> tensor<400x3xi32>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
// CHECK-SAME: -> tensor<4x10x10x3xi32>
// CHECK: return %[[VAR3]]
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
return %0 : tensor<4x10x10x3xi32>
}
// -----

View File

@@ -0,0 +1,32 @@
// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
// -----
// CHECK-LABEL: @depthwise_conv2d_as_mul
func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK-NOT: "tosa.depthwise_conv2d"
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
// CHECK-SAME: -> tensor<4x10x10x2x1xf32>
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
// CHECK-SAME: -> tensor<1x1x1x2x3xf32>
// CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
// CHECK-SAME: -> tensor<4x10x10x2x3xf32>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
// CHECK-SAME: -> tensor<4x10x10x6xf32>
// CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
// CHECK-SAME: -> tensor<4x10x10x6xf32>
// CHECK: return %[[VAR4]]
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
return %0 : tensor<4x10x10x6xf32>
}
// -----
// CHECK-LABEL: @depthwise_conv2d_as_mul_q
func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
// CHECK: "tosa.depthwise_conv2d"
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
return %0 : tensor<4x10x10x6xi32>
}
// -----

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt --split-input-file --tosa-decompose-transpose-conv %s | FileCheck %s
// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
// CHECK-LABEL: @transpose_conv2d
func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {