forked from OSchip/llvm-project
[mlir][tosa] Add tosa.depthwise_conv2d as tosa.mul canonicalization
For a 1x1 weight and stride of 1, the input/weight can be reshaped and multiplied elementwise, then reshaped back. Reviewed By: rsuderman, KoolJBlack. Differential Revision: https://reviews.llvm.org/D115207
This commit is contained in:
parent
d9941f7454
commit
5911a29aa9
|
@ -187,6 +187,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
|
|||
let builders = [Tosa_ConvOpQuantInfoBuilder];
|
||||
|
||||
let verifier = [{ return verifyConvOp(*this); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -515,6 +515,97 @@ void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
results.insert<Conv2DFullyConnectedOptimization>(context);
|
||||
}
|
||||
|
||||
struct DepthwiseConv2DMulOptimization
|
||||
: public OpRewritePattern<tosa::DepthwiseConv2DOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.input();
|
||||
Value weight = op.weight();
|
||||
ShapedType inputType = input.getType().cast<ShapedType>();
|
||||
ShapedType weightType = weight.getType().cast<ShapedType>();
|
||||
ShapedType resultType = op.output().getType().cast<ShapedType>();
|
||||
|
||||
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
|
||||
resultType.hasStaticShape())) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Stride must be 1 for this optimization.
|
||||
for (Attribute stride : op.stride().getValue()) {
|
||||
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// Only works for a 1x1 kernel.
|
||||
ArrayRef<int64_t> weightShape = weightType.getShape();
|
||||
if (weightShape[0] != 1 || weightShape[1] != 1) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
llvm::SmallVector<int64_t, 2> revisedInputShape{
|
||||
inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
|
||||
auto revisedInputShapeType = RankedTensorType::get(
|
||||
revisedInputShape,
|
||||
input.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto reshapedInput = rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op.getLoc(), revisedInputShapeType, input,
|
||||
rewriter.getI64ArrayAttr(revisedInputShape))
|
||||
.getResult();
|
||||
|
||||
// Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
|
||||
llvm::SmallVector<int64_t, 2> revisedWeightShape{1, 1, 1, weightShape[2],
|
||||
weightShape[3]};
|
||||
auto revisedWeightShapeType = RankedTensorType::get(
|
||||
revisedWeightShape,
|
||||
weight.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto reshapedWeight = rewriter
|
||||
.create<tosa::ReshapeOp>(
|
||||
op.getLoc(), revisedWeightShapeType, weight,
|
||||
rewriter.getI64ArrayAttr(revisedWeightShape))
|
||||
.getResult();
|
||||
|
||||
// Perform an elementwise mul over the reshaped input and weight.
|
||||
llvm::SmallVector<int64_t, 2> mulShape{inputShape[0], inputShape[1],
|
||||
inputShape[2], inputShape[3],
|
||||
weightShape[3]};
|
||||
auto mulShapeType = RankedTensorType::get(
|
||||
mulShape,
|
||||
weight.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
Value mulValue =
|
||||
rewriter
|
||||
.create<tosa::MulOp>(op.getLoc(), mulShapeType, reshapedInput,
|
||||
reshapedWeight, /*shift=*/0)
|
||||
.getResult();
|
||||
|
||||
// Reshape output to [N, H, W, C * M].
|
||||
auto outputShape = op.output().getType().cast<ShapedType>().getShape();
|
||||
auto outputShapeType = RankedTensorType::get(
|
||||
outputShape,
|
||||
input.getType().dyn_cast<RankedTensorType>().getElementType());
|
||||
auto outputValue =
|
||||
rewriter.create<tosa::ReshapeOp>(op.getLoc(), outputShapeType, mulValue,
|
||||
rewriter.getI64ArrayAttr(outputShape));
|
||||
|
||||
// Add in the bias.
|
||||
rewriter
|
||||
.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
|
||||
op.bias())
|
||||
.getResult();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Registers the canonicalization patterns for tosa.depthwise_conv2d: a
// 1x1/stride-1 depthwise convolution folds into an elementwise multiply.
void DepthwiseConv2DOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.add<DepthwiseConv2DMulOptimization>(context);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator Folders.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -106,6 +106,44 @@ func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
|
|||
return %0 : tensor<4x10x10x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive case: 1x1 kernel, stride [1, 1], zero padding. The
// depthwise_conv2d must be replaced by reshape (rank 4 -> 5 on input and
// weight), a broadcasting elementwise mul, a reshape back to rank 4, and a
// final add of the bias.
// CHECK-LABEL: @depthwise_conv2d_as_mul
func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK-NOT: "tosa.depthwise_conv2d"
// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]}
// CHECK-SAME: -> tensor<4x10x10x2x1xf32>
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]}
// CHECK-SAME: -> tensor<1x1x1x2x3xf32>
// CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]])
// CHECK-SAME: -> tensor<4x10x10x2x3xf32>
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]}
// CHECK-SAME: -> tensor<4x10x10x6xf32>
// CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2)
// CHECK-SAME: -> tensor<4x10x10x6xf32>
// CHECK: return %[[VAR4]]
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
return %0 : tensor<4x10x10x6xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: stride is [2, 2], so the mul canonicalization must not
// fire and the depthwise_conv2d is left in place.
// CHECK-LABEL: @depthwise_conv2d_stride_2
func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK: "tosa.depthwise_conv2d"
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [2, 2], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
return %0 : tensor<4x10x10x6xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: the kernel is 2x2 rather than 1x1, so the mul
// canonicalization must not fire and the depthwise_conv2d is left in place.
// CHECK-LABEL: @depthwise_conv2d_weight_2x2
func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK: "tosa.depthwise_conv2d"
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32>
return %0 : tensor<4x10x10x6xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @pad_noop
|
||||
|
|
Loading…
Reference in New Issue