forked from OSchip/llvm-project
[mlir][tosa] Fix tosa.reshape failures due to implicit broadcasting
Make broadcastable needs the output shape to determine whether the operation includes additional broadcasting. Include some canonicalizations for TOSA to remove unneeded reshape. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D106846
This commit is contained in:
parent
cf36ab1d6c
commit
2d0ba5e144
|
@ -1414,6 +1414,8 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
|
|||
No data conversion happens during a reshape operation.
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let arguments = (ins
|
||||
Tosa_Tensor:$input1,
|
||||
I64ArrayAttr:$new_shape
|
||||
|
|
|
@ -638,7 +638,8 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
|
|||
|
||||
if (newShape.size() != rank) {
|
||||
operand = rewriter.create<tosa::ReshapeOp>(
|
||||
loc, RankedTensorType::get(newShape, type.getElementType()), operand);
|
||||
loc, RankedTensorType::get(newShape, type.getElementType()), operand,
|
||||
rewriter.getI64ArrayAttr(newShape));
|
||||
}
|
||||
|
||||
operands.push_back(operand);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
|
@ -101,6 +102,48 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator Canonicalizers.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct RemoveReshapeNoop : public OpRewritePattern<tosa::ReshapeOp> {
|
||||
using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::ReshapeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (op.input1().getType() != op.getType())
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(op, op.input1());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
|
||||
using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::ReshapeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.input1();
|
||||
Operation *definingOp = input.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
|
||||
if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
|
||||
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
||||
op, op.getType(), reshapeOp.input1(), op.new_shape());
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<ReshapeReshapeOptimization, RemoveReshapeNoop>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator Folders.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -143,7 +143,8 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
|
|||
|
||||
SmallVector<int64_t, 4> reshapeOutputShape;
|
||||
|
||||
computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape);
|
||||
computeReshapeOutput(outputType.getShape(), lowerRankShape,
|
||||
reshapeOutputShape);
|
||||
|
||||
auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
|
||||
auto reshapeOutputType = RankedTensorType::get(
|
||||
|
|
|
@ -136,6 +136,15 @@ func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tens
|
|||
return %0 : tensor<14x15xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: broadcast19
|
||||
func @broadcast19(%arg0: tensor<64x64x1xf32>, %arg1: tensor<1x17xf32>) -> (tensor<64x64x17xf32> ) {
|
||||
// CHECK: reshape
|
||||
// CHECK: sub
|
||||
%0 = "tosa.sub"(%arg0, %arg1) : (tensor<64x64x1xf32>, tensor<1x17xf32>) -> tensor<64x64x17xf32>
|
||||
return %0 : tensor<64x64x17xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: broadcast_mul
|
||||
func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
|
||||
|
|
Loading…
Reference in New Issue