forked from OSchip/llvm-project
[mlir][tosa] Allow optional TOSA decompositions to be populated separately
Moved all TOSA decomposition patterns so that they can be optionally populated and used by external rewrites. This avoids decomposing TOSa operations when backends may benefit from the non-decomposed version. Reviewed By: rsuderman, mehdi_amini Differential Revision: https://reviews.llvm.org/D116526
This commit is contained in:
parent
0a8d15ad55
commit
dfd070820c
|
@ -19,11 +19,18 @@
|
|||
namespace mlir {
|
||||
namespace tosa {
|
||||
|
||||
std::unique_ptr<Pass> createTosaDecomposeTransposeConvPass();
|
||||
// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
|
||||
// The rewrites can be selectively added to a conversion pass.
|
||||
void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
|
||||
void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
|
||||
RewritePatternSet &patterns);
|
||||
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
std::unique_ptr<Pass> createTosaInferShapesPass();
|
||||
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
|
||||
std::unique_ptr<Pass> createTosaOptimizationPass();
|
||||
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
|
||||
std::unique_ptr<Pass> createTosaOptionalDecompositions();
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
|
||||
|
|
|
@ -15,21 +15,6 @@
|
|||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TosaDecomposeTransposeConv : FunctionPass<"tosa-decompose-transpose-conv"> {
|
||||
let summary = "Deompose transpose convolutiions into standard convolutions.";
|
||||
let description = [{
|
||||
Pass that uses shape manipulation and convolution operations to transform
|
||||
a transpose convolution into a regular convolution.
|
||||
}];
|
||||
|
||||
let constructor = "createTosaDecomposeTransposeConvPass()";
|
||||
let dependentDialects = [
|
||||
"StandardOpsDialect",
|
||||
"tensor::TensorDialect",
|
||||
"tosa::TosaDialect",
|
||||
];
|
||||
}
|
||||
|
||||
def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> {
|
||||
let summary = "Propagate shapes across TOSA operations";
|
||||
let description = [{
|
||||
|
@ -58,13 +43,14 @@ def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
|
|||
let constructor = "createTosaMakeBroadcastablePass()";
|
||||
}
|
||||
|
||||
def TosaOptimization : FunctionPass<"tosa-optimization"> {
|
||||
let summary = "TOSA operation optimizations";
|
||||
def TosaOptionalDecompositions : FunctionPass<"tosa-optional-decompositions"> {
|
||||
let summary = "Applies Tosa operations optional decompositions";
|
||||
let description = [{
|
||||
"Pass to perform optimizations on TOSA operations"
|
||||
Pass to apply the Tosa operations decompositions
|
||||
exposed as populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h
|
||||
}];
|
||||
|
||||
let constructor = "createTosaOptimizationPass()";
|
||||
let constructor = "tosa::createTosaOptionalDecompositions()";
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
|
||||
|
|
|
@ -68,6 +68,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
|
|||
}
|
||||
|
||||
void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm) {
|
||||
pm.addNestedPass<FuncOp>(mlir::tosa::createTosaOptionalDecompositions());
|
||||
pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
|
||||
pm.addNestedPass<FuncOp>(createTosaToLinalgNamed());
|
||||
pm.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
add_mlir_dialect_library(MLIRTosaTransforms
|
||||
TosaDecomposeTransposeConv.cpp
|
||||
TosaDecomposeConv2D.cpp
|
||||
TosaDecomposeDepthwise.cpp
|
||||
TosaInferShapes.cpp
|
||||
TosaMakeBroadcastable.cpp
|
||||
TosaOptimization.cpp
|
||||
TosaOptionalDecompositions.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
//===- TosaDecomposeConv2D.cpp ------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
|
||||
// (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tosa;
|
||||
|
||||
namespace {
|
||||
|
||||
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
|
||||
explicit Conv2DIsFullyConnected(MLIRContext *context)
|
||||
: OpRewritePattern(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::Conv2DOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.input();
|
||||
Value weight = op.weight();
|
||||
ShapedType inputType = input.getType().cast<ShapedType>();
|
||||
ShapedType weightType = weight.getType().cast<ShapedType>();
|
||||
ShapedType resultType = op.getType().cast<ShapedType>();
|
||||
|
||||
if (!inputType.hasStaticShape() || !weightType.hasRank()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Stride must be 1 for this optimization.
|
||||
for (Attribute stride : op.stride().getValue()) {
|
||||
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Only works for a 1x1 kernel.
|
||||
ArrayRef<int64_t> weightShape = weightType.getShape();
|
||||
if (weightShape[1] != 1 || weightShape[2] != 1) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
llvm::SmallVector<int64_t, 2> revisedInputShape{
|
||||
inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
|
||||
auto revisedInputShapeType = RankedTensorType::get(
|
||||
revisedInputShape,
|
||||
input.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto reshapedInput = rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op.getLoc(), revisedInputShapeType, input,
|
||||
rewriter.getI64ArrayAttr(revisedInputShape))
|
||||
.getResult();
|
||||
|
||||
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
|
||||
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
|
||||
weightShape[3]};
|
||||
auto revisedWeightShapeType = RankedTensorType::get(
|
||||
revisedWeightShape,
|
||||
weight.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto reshapedWeight = rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op.getLoc(), revisedWeightShapeType, weight,
|
||||
rewriter.getI64ArrayAttr(revisedWeightShape))
|
||||
.getResult();
|
||||
|
||||
// Perform a fully connected network over the reshaped input and weight.
|
||||
llvm::SmallVector<int64_t, 2> fullyConnectedShape{
|
||||
inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
|
||||
auto fullyConnectedShapeType = RankedTensorType::get(
|
||||
fullyConnectedShape,
|
||||
resultType.dyn_cast<ShapedType>().getElementType());
|
||||
|
||||
Value fullyConnectedValue;
|
||||
if (op.quantization_info()) {
|
||||
fullyConnectedValue =
|
||||
rewriter
|
||||
.create<tosa::FullyConnectedOp>(
|
||||
op.getLoc(), fullyConnectedShapeType, reshapedInput,
|
||||
reshapedWeight, op.bias(), op.quantization_info().getValue())
|
||||
.getResult();
|
||||
} else {
|
||||
fullyConnectedValue = rewriter
|
||||
.create<tosa::FullyConnectedOp>(
|
||||
op.getLoc(), fullyConnectedShapeType,
|
||||
reshapedInput, reshapedWeight, op.bias())
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// Reshape output to [N, IH, IW, OC].
|
||||
llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
|
||||
inputShape[2], weightShape[0]};
|
||||
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
||||
op, resultType, fullyConnectedValue,
|
||||
rewriter.getI64ArrayAttr(outputShape));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.insert<Conv2DIsFullyConnected>(ctx);
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
//===- TosaDecomposeDepthwise.cpp
|
||||
//------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
|
||||
// (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tosa;
|
||||
|
||||
namespace {
|
||||
|
||||
struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
|
||||
explicit DepthwiseConv2DIsMul(MLIRContext *context)
|
||||
: OpRewritePattern(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.input();
|
||||
Value weight = op.weight();
|
||||
ShapedType inputType = input.getType().cast<ShapedType>();
|
||||
ShapedType weightType = weight.getType().cast<ShapedType>();
|
||||
ShapedType resultType = op.output().getType().cast<ShapedType>();
|
||||
Type inputEType = inputType.getElementType();
|
||||
|
||||
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
|
||||
resultType.hasStaticShape())) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Quantization information needs to still be performed.
|
||||
if (op.quantization_info() || !inputEType.isa<FloatType>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Stride must be 1 for this optimization.
|
||||
for (Attribute stride : op.stride().getValue()) {
|
||||
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Only works for a 1x1 kernel.
|
||||
ArrayRef<int64_t> weightShape = weightType.getShape();
|
||||
if (weightShape[0] != 1 || weightShape[1] != 1) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
llvm::SmallVector<int64_t, 2> revisedInputShape{
|
||||
inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
|
||||
auto revisedInputShapeType = RankedTensorType::get(
|
||||
revisedInputShape,
|
||||
input.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto reshapedInput = rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op.getLoc(), revisedInputShapeType, input,
|
||||
rewriter.getI64ArrayAttr(revisedInputShape))
|
||||
.getResult();
|
||||
|
||||
// Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
|
||||
llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
|
||||
weightShape[3]};
|
||||
auto revisedWeightShapeType = RankedTensorType::get(
|
||||
revisedWeightShape,
|
||||
weight.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto reshapedWeight = rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op.getLoc(), revisedWeightShapeType, weight,
|
||||
rewriter.getI64ArrayAttr(revisedWeightShape))
|
||||
.getResult();
|
||||
|
||||
// Perform an elementwise mul over the reshaped input and weight.
|
||||
llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
|
||||
inputShape[2], inputShape[3],
|
||||
weightShape[3]};
|
||||
auto mulShapeType = RankedTensorType::get(
|
||||
mulShape,
|
||||
weight.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
Value mulValue =
|
||||
rewriter
|
||||
.create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
|
||||
reshapedWeight, /*shift=*/0)
|
||||
.getResult();
|
||||
|
||||
// Reshape output to [N, H, W, C * M].
|
||||
auto outputShape = op.output().getType().cast<ShapedType>().getShape();
|
||||
auto outputShapeType = RankedTensorType::get(
|
||||
outputShape,
|
||||
input.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto outputValue =
|
||||
rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
|
||||
rewriter.getI64ArrayAttr(outputShape));
|
||||
|
||||
// Add in the bias.
|
||||
rewriter
|
||||
.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
|
||||
op.bias())
|
||||
.getResult();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.insert<DepthwiseConv2DIsMul>(ctx);
|
||||
}
|
|
@ -7,17 +7,19 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Insert reshape to binary op's input if needed to match rank
|
||||
// Decompose TOSA TransposeConv operation to a series of TOSA Ops specifically
|
||||
// (1) Convert a Dilated TransposeConv2D to Conv2D including reversing/reshaping
|
||||
// etc.. of the weights (2) Convert a Strided TransposeConv2D to Conv2D
|
||||
// including transposing/reversing/reshaping etc..
|
||||
// of the weights and input/output tenors and reversing/reshaping etc .. of
|
||||
// the weights
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tosa/IR//TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tosa;
|
||||
|
@ -369,22 +371,10 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Pass that enables broadcast by making all input arrays have the same
|
||||
/// number of dimensions. Insert RESHAPE operations to lower rank operand
|
||||
struct TosaDecomposeTransposeConv
|
||||
: public TosaDecomposeTransposeConvBase<TosaDecomposeTransposeConv> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
auto func = getFunction();
|
||||
RewritePatternSet patterns(func.getContext());
|
||||
patterns
|
||||
.insert<TransposeConvDilatedConverter, TransposeConvStridedConverter>(
|
||||
func.getContext());
|
||||
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::tosa::createTosaDecomposeTransposeConvPass() {
|
||||
return std::make_unique<TosaDecomposeTransposeConv>();
|
||||
void mlir::tosa::populateTosaDecomposeTransposeConv(
|
||||
MLIRContext *ctx, RewritePatternSet &patterns) {
|
||||
patterns.insert<TransposeConvDilatedConverter>(ctx);
|
||||
patterns.insert<TransposeConvStridedConverter>(ctx);
|
||||
}
|
||||
|
|
|
@ -1,243 +0,0 @@
|
|||
//===- TosaOptimization.cpp ------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Pass to perform optimizations on TOSA operations
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tosa;
|
||||
|
||||
#define PASS_NAME "tosa-optimization"
|
||||
#define DEBUG_TYPE PASS_NAME
|
||||
|
||||
namespace {
|
||||
|
||||
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
|
||||
explicit Conv2DIsFullyConnected(MLIRContext *context)
|
||||
: OpRewritePattern(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::Conv2DOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.input();
|
||||
Value weight = op.weight();
|
||||
ShapedType inputType = input.getType().cast<ShapedType>();
|
||||
ShapedType weightType = weight.getType().cast<ShapedType>();
|
||||
ShapedType resultType = op.getType().cast<ShapedType>();
|
||||
|
||||
if (!inputType.hasStaticShape() || !weightType.hasRank()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Stride must be 1 for this optimization.
|
||||
for (Attribute stride : op.stride().getValue()) {
|
||||
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Only works for a 1x1 kernel.
|
||||
ArrayRef<int64_t> weightShape = weightType.getShape();
|
||||
if (weightShape[1] != 1 || weightShape[2] != 1) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
llvm::SmallVector<int64_t, 2> revisedInputShape{
|
||||
inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
|
||||
auto revisedInputShapeType = RankedTensorType::get(
|
||||
revisedInputShape,
|
||||
input.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto reshapedInput = rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op.getLoc(), revisedInputShapeType, input,
|
||||
rewriter.getI64ArrayAttr(revisedInputShape))
|
||||
.getResult();
|
||||
|
||||
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
|
||||
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
|
||||
weightShape[3]};
|
||||
auto revisedWeightShapeType = RankedTensorType::get(
|
||||
revisedWeightShape,
|
||||
weight.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto reshapedWeight = rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op.getLoc(), revisedWeightShapeType, weight,
|
||||
rewriter.getI64ArrayAttr(revisedWeightShape))
|
||||
.getResult();
|
||||
|
||||
// Perform a fully connected network over the reshaped input and weight.
|
||||
llvm::SmallVector<int64_t, 2> fullyConnectedShape{
|
||||
inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
|
||||
auto fullyConnectedShapeType = RankedTensorType::get(
|
||||
fullyConnectedShape,
|
||||
resultType.dyn_cast<ShapedType>().getElementType());
|
||||
|
||||
Value fullyConnectedValue;
|
||||
if (op.quantization_info()) {
|
||||
fullyConnectedValue =
|
||||
rewriter
|
||||
.create<tosa::FullyConnectedOp>(
|
||||
op.getLoc(), fullyConnectedShapeType, reshapedInput,
|
||||
reshapedWeight, op.bias(), op.quantization_info().getValue())
|
||||
.getResult();
|
||||
} else {
|
||||
fullyConnectedValue = rewriter
|
||||
.create<tosa::FullyConnectedOp>(
|
||||
op.getLoc(), fullyConnectedShapeType,
|
||||
reshapedInput, reshapedWeight, op.bias())
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// Reshape output to [N, IH, IW, OC].
|
||||
llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
|
||||
inputShape[2], weightShape[0]};
|
||||
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
||||
op, resultType, fullyConnectedValue,
|
||||
rewriter.getI64ArrayAttr(outputShape));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
|
||||
explicit DepthwiseConv2DIsMul(MLIRContext *context)
|
||||
: OpRewritePattern(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.input();
|
||||
Value weight = op.weight();
|
||||
ShapedType inputType = input.getType().cast<ShapedType>();
|
||||
ShapedType weightType = weight.getType().cast<ShapedType>();
|
||||
ShapedType resultType = op.output().getType().cast<ShapedType>();
|
||||
Type inputEType = inputType.getElementType();
|
||||
|
||||
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
|
||||
resultType.hasStaticShape())) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Quantization information needs to still be performed.
|
||||
if (op.quantization_info() || !inputEType.isa<FloatType>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Stride must be 1 for this optimization.
|
||||
for (Attribute stride : op.stride().getValue()) {
|
||||
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Only works for a 1x1 kernel.
|
||||
ArrayRef<int64_t> weightShape = weightType.getShape();
|
||||
if (weightShape[0] != 1 || weightShape[1] != 1) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
llvm::SmallVector<int64_t, 2> revisedInputShape{
|
||||
inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
|
||||
auto revisedInputShapeType = RankedTensorType::get(
|
||||
revisedInputShape,
|
||||
input.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto reshapedInput = rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op.getLoc(), revisedInputShapeType, input,
|
||||
rewriter.getI64ArrayAttr(revisedInputShape))
|
||||
.getResult();
|
||||
|
||||
// Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
|
||||
llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
|
||||
weightShape[3]};
|
||||
auto revisedWeightShapeType = RankedTensorType::get(
|
||||
revisedWeightShape,
|
||||
weight.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto reshapedWeight = rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op.getLoc(), revisedWeightShapeType, weight,
|
||||
rewriter.getI64ArrayAttr(revisedWeightShape))
|
||||
.getResult();
|
||||
|
||||
// Perform an elementwise mul over the reshaped input and weight.
|
||||
llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
|
||||
inputShape[2], inputShape[3],
|
||||
weightShape[3]};
|
||||
auto mulShapeType = RankedTensorType::get(
|
||||
mulShape,
|
||||
weight.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
Value mulValue =
|
||||
rewriter
|
||||
.create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
|
||||
reshapedWeight, /*shift=*/0)
|
||||
.getResult();
|
||||
|
||||
// Reshape output to [N, H, W, C * M].
|
||||
auto outputShape = op.output().getType().cast<ShapedType>().getShape();
|
||||
auto outputShapeType = RankedTensorType::get(
|
||||
outputShape,
|
||||
input.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto outputValue =
|
||||
rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
|
||||
rewriter.getI64ArrayAttr(outputShape));
|
||||
|
||||
// Add in the bias.
|
||||
rewriter
|
||||
.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
|
||||
op.bias())
|
||||
.getResult();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class TosaOptimization : public PassWrapper<TosaOptimization, FunctionPass> {
|
||||
public:
|
||||
explicit TosaOptimization() = default;
|
||||
void runOnFunction() override;
|
||||
|
||||
StringRef getArgument() const final { return PASS_NAME; }
|
||||
StringRef getDescription() const final {
|
||||
return "Applies TOSA Operation Optimizations";
|
||||
}
|
||||
};
|
||||
|
||||
void TosaOptimization::runOnFunction() {
|
||||
OwningRewritePatternList patterns(&getContext());
|
||||
|
||||
patterns.insert<Conv2DIsFullyConnected>(&getContext());
|
||||
patterns.insert<DepthwiseConv2DIsMul>(&getContext());
|
||||
|
||||
auto func = getFunction();
|
||||
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::tosa::createTosaOptimizationPass() {
|
||||
return std::make_unique<TosaOptimization>();
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
//===- TosaOptionalDecompositions.cpp
|
||||
//------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Pass to apply the Tosa operations decompositions
|
||||
// exposed as populate functions in
|
||||
// include/mlir/Dialect/Tosa/Transforms/Passes.h
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TosaOptionalDecompositions
|
||||
: public TosaOptionalDecompositionsBase<TosaOptionalDecompositions> {
|
||||
void runOnFunction() {
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
auto func = getFunction();
|
||||
|
||||
mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns);
|
||||
mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
|
||||
mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::tosa::createTosaOptionalDecompositions() {
|
||||
return std::make_unique<TosaOptionalDecompositions>();
|
||||
}
|
|
@ -1,69 +1,40 @@
|
|||
// RUN: mlir-opt --split-input-file --tosa-optimization %s | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @conv2d_as_fully_connected
|
||||
func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
|
||||
// CHECK-NOT: "tosa.conv2d"
|
||||
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
|
||||
// CHECK-SAME: -> tensor<400x2xf32>
|
||||
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
|
||||
// CHECK-SAME: -> tensor<3x2xf32>
|
||||
// CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
|
||||
// CHECK-SAME: -> tensor<400x3xf32>
|
||||
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
|
||||
// CHECK-SAME: -> tensor<4x10x10x3xf32>
|
||||
// CHECK: return %[[VAR3]]
|
||||
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
|
||||
return %0 : tensor<4x10x10x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @conv2d_as_fully_connected_quant
|
||||
func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
|
||||
// CHECK-NOT: "tosa.conv2d"
|
||||
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
|
||||
// CHECK-SAME: -> tensor<400x2xi8>
|
||||
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
|
||||
// CHECK-SAME: -> tensor<3x2xi8>
|
||||
// CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
|
||||
// CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
|
||||
// CHECK-SAME: -> tensor<400x3xi32>
|
||||
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
|
||||
// CHECK-SAME: -> tensor<4x10x10x3xi32>
|
||||
// CHECK: return %[[VAR3]]
|
||||
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
|
||||
return %0 : tensor<4x10x10x3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @depthwise_conv2d_as_mul
|
||||
func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
|
||||
// CHECK-NOT: "tosa.depthwise_conv2d"
|
||||
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
|
||||
// CHECK-SAME: -> tensor<4x10x10x2x1xf32>
|
||||
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
|
||||
// CHECK-SAME: -> tensor<1x1x1x2x3xf32>
|
||||
// CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
|
||||
// CHECK-SAME: -> tensor<4x10x10x2x3xf32>
|
||||
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
|
||||
// CHECK-SAME: -> tensor<4x10x10x6xf32>
|
||||
// CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
|
||||
// CHECK-SAME: -> tensor<4x10x10x6xf32>
|
||||
// CHECK: return %[[VAR4]]
|
||||
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
|
||||
return %0 : tensor<4x10x10x6xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @depthwise_conv2d_as_mul_q
|
||||
func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
|
||||
// CHECK: "tosa.depthwise_conv2d"
|
||||
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
|
||||
return %0 : tensor<4x10x10x6xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @conv2d_as_fully_connected
|
||||
func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
|
||||
// CHECK-NOT: "tosa.conv2d"
|
||||
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
|
||||
// CHECK-SAME: -> tensor<400x2xf32>
|
||||
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
|
||||
// CHECK-SAME: -> tensor<3x2xf32>
|
||||
// CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
|
||||
// CHECK-SAME: -> tensor<400x3xf32>
|
||||
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
|
||||
// CHECK-SAME: -> tensor<4x10x10x3xf32>
|
||||
// CHECK: return %[[VAR3]]
|
||||
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
|
||||
return %0 : tensor<4x10x10x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @conv2d_as_fully_connected_quant
|
||||
func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
|
||||
// CHECK-NOT: "tosa.conv2d"
|
||||
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
|
||||
// CHECK-SAME: -> tensor<400x2xi8>
|
||||
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
|
||||
// CHECK-SAME: -> tensor<3x2xi8>
|
||||
// CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
|
||||
// CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}
|
||||
// CHECK-SAME: -> tensor<400x3xi32>
|
||||
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
|
||||
// CHECK-SAME: -> tensor<4x10x10x3xi32>
|
||||
// CHECK: return %[[VAR3]]
|
||||
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
|
||||
return %0 : tensor<4x10x10x3xi32>
|
||||
}
|
||||
|
||||
// -----
|
|
@ -0,0 +1,32 @@
|
|||
// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @depthwise_conv2d_as_mul
|
||||
func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
|
||||
// CHECK-NOT: "tosa.depthwise_conv2d"
|
||||
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
|
||||
// CHECK-SAME: -> tensor<4x10x10x2x1xf32>
|
||||
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
|
||||
// CHECK-SAME: -> tensor<1x1x1x2x3xf32>
|
||||
// CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
|
||||
// CHECK-SAME: -> tensor<4x10x10x2x3xf32>
|
||||
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
|
||||
// CHECK-SAME: -> tensor<4x10x10x6xf32>
|
||||
// CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
|
||||
// CHECK-SAME: -> tensor<4x10x10x6xf32>
|
||||
// CHECK: return %[[VAR4]]
|
||||
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
|
||||
return %0 : tensor<4x10x10x6xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @depthwise_conv2d_as_mul_q
|
||||
func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
|
||||
// CHECK: "tosa.depthwise_conv2d"
|
||||
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
|
||||
return %0 : tensor<4x10x10x6xi32>
|
||||
}
|
||||
|
||||
// -----
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt --split-input-file --tosa-decompose-transpose-conv %s | FileCheck %s
|
||||
// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @transpose_conv2d
|
||||
func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
|
||||
|
|
Loading…
Reference in New Issue