[mlir][tosa] Resubmit add tosa.conv2d as tosa.fully_connected canonicalization

Fixed the tosa.conv2d to tosa.fully_connected canonicalization, which previously produced an incorrect number of output channels, and updated the tests to check the result shapes during canonicalization. This allows a 1x1-kernel, stride-1 conv2d to be transformed into the simpler fully_connected operation.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D115170
commit 05e33d846f (parent 13278efd0c)
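For context, the rewrite this patch performs can be sketched in TOSA IR. This mirrors the shapes in the new @conv2d_as_fully_connected test below; the SSA names are illustrative only, not part of the patch.

    // Before: a 1x1-kernel, stride-1 convolution over NHWC input.
    %0 = "tosa.conv2d"(%input, %weight, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]}
        : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>

    // After canonicalization: collapse [N, IH, IW, IC] to [N*IH*IW, IC] = [400, 2],
    // flatten the weights [OC, 1, 1, IC] to [OC, IC] = [3, 2], apply the
    // matmul-like fully_connected, then restore the NHWC result shape.
    %in = "tosa.reshape"(%input) {new_shape = [400, 2]} : (tensor<4x10x10x2xf32>) -> tensor<400x2xf32>
    %wt = "tosa.reshape"(%weight) {new_shape = [3, 2]} : (tensor<3x1x1x2xf32>) -> tensor<3x2xf32>
    %fc = "tosa.fully_connected"(%in, %wt, %bias) : (tensor<400x2xf32>, tensor<3x2xf32>, tensor<3xf32>) -> tensor<400x3xf32>
    %out = "tosa.reshape"(%fc) {new_shape = [4, 10, 10, 3]} : (tensor<400x3xf32>) -> tensor<4x10x10x3xf32>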
@@ -118,6 +118,8 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
   let builders = [Tosa_ConvOpQuantInfoBuilder];
 
   let verifier = [{ return verifyConvOp(*this); }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -423,6 +423,98 @@ void PadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<MaterializePadValue>(context);
 }
 
+struct Conv2DFullyConnectedOptimization
+    : public OpRewritePattern<tosa::Conv2DOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value weight = op.weight();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType weightType = weight.getType().cast<ShapedType>();
+    ShapedType resultType = op.getType().cast<ShapedType>();
+
+    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
+      return failure();
+    }
+
+    // Stride must be 1 for this optimization.
+    for (Attribute stride : op.stride().getValue()) {
+      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+        return failure();
+      }
+    }
+
+    // Only works for a 1x1 kernel.
+    ArrayRef<int64_t> weightShape = weightType.getShape();
+    if (weightShape[1] != 1 || weightShape[2] != 1) {
+      return failure();
+    }
+
+    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    llvm::SmallVector<int64_t, 2> revisedInputShape{
+        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
+    auto revisedInputShapeType = RankedTensorType::get(
+        revisedInputShape,
+        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedInput = rewriter
+                             .create<tosa::ReshapeOp>(
+                                 op.getLoc(), revisedInputShapeType, input,
+                                 rewriter.getI64ArrayAttr(revisedInputShape))
+                             .getResult();
+
+    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
+    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
+                                                     weightShape[3]};
+    auto revisedWeightShapeType = RankedTensorType::get(
+        revisedWeightShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+    auto reshapedWeight = rewriter
+                              .create<tosa::ReshapeOp>(
+                                  op.getLoc(), revisedWeightShapeType, weight,
+                                  rewriter.getI64ArrayAttr(revisedWeightShape))
+                              .getResult();
+
+    // Perform a fully connected network over the reshaped input and weight.
+    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
+        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
+    auto fullyConnectedShapeType = RankedTensorType::get(
+        fullyConnectedShape,
+        weight.getType().dyn_cast<RankedTensorType>().getElementType());
+
+    Value fullyConnectedValue;
+    if (op.quantization_info()) {
+      fullyConnectedValue =
+          rewriter
+              .create<tosa::FullyConnectedOp>(
+                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
+                  reshapedWeight, op.bias(), op.quantization_info().getValue())
+              .getResult();
+    } else {
+      fullyConnectedValue = rewriter
+                                .create<tosa::FullyConnectedOp>(
+                                    op.getLoc(), fullyConnectedShapeType,
+                                    reshapedInput, reshapedWeight, op.bias())
+                                .getResult();
+    }
+
+    // Reshape output to [N, IH, IW, OC].
+    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
+                                              inputShape[2], weightShape[0]};
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, resultType, fullyConnectedValue,
+        rewriter.getI64ArrayAttr(outputShape));
+    return success();
+  }
+};
+
+void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<Conv2DFullyConnectedOptimization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//
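The pattern only fires when the rewrite is provably shape-preserving: static input shape, ranked weights, unit strides, and a 1x1 kernel. As a hypothetical example (not part of this patch), a dynamically batched conv2d is rejected by the hasStaticShape() guard, since N*IH*IW cannot be folded into a constant reshape:

    // Left untouched by the canonicalization: the batch dimension is dynamic,
    // so the collapsed [N*IH*IW, IC] shape cannot be computed statically.
    %0 = "tosa.conv2d"(%arg0, %weight, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]}
        : (tensor<?x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<?x10x10x3xf32>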
@@ -66,12 +66,52 @@ func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
   return %0 : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: @conv2d_as_fully_connected
+func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
+  // CHECK-NOT: "tosa.conv2d"
+  // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]}
+  // CHECK-SAME: -> tensor<400x2xf32>
+  // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]}
+  // CHECK-SAME: -> tensor<3x2xf32>
+  // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2)
+  // CHECK-SAME: -> tensor<400x3xf32>
+  // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]}
+  // CHECK-SAME: -> tensor<4x10x10x3xf32>
+  // CHECK: return %[[VAR3]]
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
+  return %0 : tensor<4x10x10x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_stride_2
+func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
+  // CHECK: "tosa.conv2d"
+  %weight = "tosa.const"() {value = dense<[[[[1.0, 1.0]]], [[[1.0, 1.0]]], [[[1.0, 1.0]]]]> : tensor<3x1x1x2xf32>} : ()-> tensor<3x1x1x2xf32>
+  %bias = "tosa.const"() {value = dense<0.0> : tensor<3xf32>} : ()-> tensor<3xf32>
+  %0 = "tosa.conv2d"(%arg0, %weight, %bias) {pad = [0, 0, 0, 0], stride = [2, 2], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
+  return %0 : tensor<4x10x10x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_weight_2x2
+func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
+  // CHECK: "tosa.conv2d"
+  %weight = "tosa.const"() {value = dense<[[[[1.0], [1.0]], [[1.0], [1.0]]]]> : tensor<1x2x2x1xf32>} : ()-> tensor<1x2x2x1xf32>
+  %bias = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : ()-> tensor<1xf32>
+  %0 = "tosa.conv2d"(%arg0, %weight, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32>
+  return %0 : tensor<4x10x10x1xf32>
+}
+
 // ----
 
 // CHECK-LABEL: @pad_noop
 func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %arg0
   %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
   %1 = "tosa.pad"(%arg0, %0) : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
@@ -82,7 +122,7 @@ func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
   // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
   %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
   %1 = "tosa.pad"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }

@@ -93,7 +133,7 @@ func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) ->
 func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
   // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
   %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
   %1 = "tosa.pad"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }

@@ -104,7 +144,7 @@ func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) ->
 func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
   // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<42> : tensor<i32>}
   // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]])
   %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
   %1 = "tosa.pad"(%arg0, %arg1) { quantization_info = {input_zp = 42:i32} } : (tensor<?x?xi32>, tensor<2x2xi32>) -> tensor<?x?xi32>
   return %1 : tensor<?x?xi32>
 }