[mlir][tosa] Enable decomposing Conv2D also where 1 input dim is dynamic
Restricted to just one dynamic input dim, as that case worked all the way through to codegen.

Differential Revision: https://reviews.llvm.org/D129334
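The decomposition rewrites a stride-1, 1x1-kernel tosa.conv2d into reshape -> fully_connected -> reshape; with one dynamic input dimension, the collapsed leading extent simply stays dynamic. Stripped of the FileCheck capture noise, the new test added below boils down to this rewrite (value numbers follow the test's captures; attribute dicts elided with {...}):

    %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {...} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor<?x14x14x384xi32>

becoming

    %3 = "tosa.reshape"(%arg0) {new_shape = [-1, 64]} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
    %4 = "tosa.reshape"(%arg1) {new_shape = [384, 64]} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
    %5 = "tosa.fully_connected"(%3, %4, %arg2) {...} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
    %6 = "tosa.reshape"(%5) {new_shape = [-1, 14, 14, 384]} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>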
parent b12930e133
commit e08a991f56
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp

@@ -13,7 +13,6 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Pass/Pass.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -32,30 +31,34 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     ShapedType weightType = weight.getType().cast<ShapedType>();
     ShapedType resultType = op.getType().cast<ShapedType>();
 
-    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
-      return failure();
-    }
+    auto numDynamic = llvm::count_if(inputType.getShape(), [](int64_t d) {
+      return ShapedType::isDynamic(d);
+    });
+    if (numDynamic > 1)
+      return rewriter.notifyMatchFailure(
+          op, "at most one dim in input may be dynamic");
+    if (!weightType.hasRank())
+      return rewriter.notifyMatchFailure(op, "unranked weight input");
 
     // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+    for (APInt stride : op.stride().getAsValueRange<IntegerAttr>()) {
+      if (!stride.isOne())
         return failure();
-      }
     }
 
     // Only works for a 1x1 kernel.
     ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1) {
+    if (weightShape[1] != 1 || weightShape[2] != 1)
       return failure();
-    }
 
     // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
     ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    int64_t combined = inputShape[0] * inputShape[1] * inputShape[2];
+    if (combined < 0)
+      combined = ShapedType::kDynamicSize;
+    llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
+    auto revisedInputShapeType =
+        RankedTensorType::get(revisedInputShape, inputType.getElementType());
     auto reshapedInput = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedInputShapeType, input,
@@ -75,11 +78,9 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
                              .getResult();
 
     // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
-        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
-    auto fullyConnectedShapeType = RankedTensorType::get(
-        fullyConnectedShape,
-        resultType.dyn_cast<ShapedType>().getElementType());
+    llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
+    auto fullyConnectedShapeType =
+        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
 
     Value fullyConnectedValue;
     if (op.quantization_info()) {
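A note on the combined-extent arithmetic in the hunks above: in the MLIR of this era, ShapedType::kDynamicSize is a negative sentinel, so the product N * IH * IW comes out negative exactly when one factor is dynamic. With two dynamic factors the sentinels would multiply back to a positive value and masquerade as a static extent, which is why the pattern bails out when numDynamic > 1. A minimal standalone sketch of that folding logic (plain C++; names are hypothetical, not from the patch):

    #include <cstdint>

    // Stand-in for ShapedType::kDynamicSize, a negative sentinel in this era.
    constexpr int64_t kDynamic = -1;

    // Collapse three extents into one, keeping "dynamic" sticky. Sound only
    // when at most one factor is dynamic: two sentinels would multiply back
    // to a positive value and silently look like a static size.
    constexpr int64_t collapse(int64_t n, int64_t h, int64_t w) {
      int64_t combined = n * h * w;
      return combined < 0 ? kDynamic : combined;
    }

    static_assert(collapse(2, 8, 8) == 128, "all static");
    static_assert(collapse(kDynamic, 8, 8) == kDynamic, "one dynamic dim stays dynamic");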
mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir

@@ -38,3 +38,19 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
 }
+
+// -----
+// CHECK-LABEL: func.func @conv_with_dynamic_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x14x14x64xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<384x1x1x64xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
+func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
+// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_0]]) {new_shape = [-1, 64]} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
+// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [384, 64]} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
+// CHECK: %[[VAL_5:.*]] = "tosa.fully_connected"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]]) {quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
+// CHECK: %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = [-1, 14, 14, 384]} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
+// CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32>
+// CHECK: }
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>, stride = [1, 1]} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor<?x14x14x384xi32>
+  return %0 : tensor<?x14x14x384xi32>
+}
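To exercise just this test, the usual FileCheck harness applies; assuming the file's RUN line of the period, something like:

    mlir-opt --split-input-file --tosa-optional-decompositions mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir | FileCheck mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir

(--tosa-optional-decompositions is assumed here from the pass registration of that era; it applies the optional TOSA decomposition patterns, including Conv2DIsFullyConnected.)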