[mlir][tosa] Fix tosa.reshape failures due to implicit broadcasting

The make-broadcastable pass needs the output shape to determine whether the
operation requires additional broadcasting. Also include some TOSA
canonicalizations to remove unneeded reshape ops.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D106846
This commit is contained in:
Rob Suderman 2021-07-29 14:38:30 -07:00
parent cf36ab1d6c
commit 2d0ba5e144
5 changed files with 58 additions and 2 deletions

View File

@@ -1414,6 +1414,8 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
No data conversion happens during a reshape operation.
}];
let hasCanonicalizer = 1;
let arguments = (ins
Tosa_Tensor:$input1,
I64ArrayAttr:$new_shape

View File

@@ -638,7 +638,8 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
if (newShape.size() != rank) {
operand = rewriter.create<tosa::ReshapeOp>(
loc, RankedTensorType::get(newShape, type.getElementType()), operand);
loc, RankedTensorType::get(newShape, type.getElementType()), operand,
rewriter.getI64ArrayAttr(newShape));
}
operands.push_back(operand);

View File

@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -101,6 +102,48 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}
//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//
/// Erases a tosa.reshape whose result type equals its operand type: such a
/// reshape is a no-op, so the result can be replaced by the operand directly.
struct RemoveReshapeNoop : public OpRewritePattern<tosa::ReshapeOp> {
  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Value source = op.input1();
    // Only an exact type match makes the reshape redundant.
    if (source.getType() != op.getType())
      return failure();

    rewriter.replaceOp(op, source);
    return success();
  }
};
/// Folds reshape(reshape(x)) into a single reshape of x. Only the outer
/// reshape's result shape matters, so the intermediate reshape can be
/// bypassed by rewiring the outer op to the inner op's input.
struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    // Value::getDefiningOp<OpTy>() folds the null-defining-op check (block
    // arguments have no defining op) and the dyn_cast into one step.
    auto producer = op.input1().getDefiningOp<tosa::ReshapeOp>();
    if (!producer)
      return failure();

    // Rebuild the reshape directly from the producer's input, keeping this
    // op's result type and new_shape attribute unchanged.
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), producer.input1(), op.new_shape());
    return success();
  }
};
/// Registers the reshape canonicalization patterns defined above with the
/// pattern set used by the canonicalizer.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<ReshapeReshapeOptimization>(context);
  results.insert<RemoveReshapeNoop>(context);
}
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

View File

@@ -143,7 +143,8 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
SmallVector<int64_t, 4> reshapeOutputShape;
computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape);
computeReshapeOutput(outputType.getShape(), lowerRankShape,
reshapeOutputShape);
auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
auto reshapeOutputType = RankedTensorType::get(

View File

@@ -136,6 +136,15 @@ func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tens
return %0 : tensor<14x15xf32>
}
// -----
// CHECK-LABEL: broadcast19
// Regression test: broadcasting tensor<64x64x1> with tensor<1x17> needs the
// lower-ranked operand rank-extended before the elementwise op, so the
// lowering must emit a reshape ahead of the sub (checked below).
func @broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) {
// CHECK: reshape
// CHECK: sub
%0 = "tosa.sub"(%arg0, %arg1) : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32>
return %0 : tensor<64x64x17xf32>
}
// -----
// CHECK-LABEL: broadcast_mul
func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {