diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 1760390ebb41..1a9d28834e3a 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -304,7 +304,7 @@ static void getI64Values(ArrayAttr arrayAttr, SmallVector &values) { LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapedType inputTy = operands[0].getType().cast(); IntegerAttr axis = attributes.get("axis").cast(); @@ -329,7 +329,7 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents( LogicalResult tosa::ConcatOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // Infer all dimension sizes by reducing based on inputs. int32_t axis = @@ -386,7 +386,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents( LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapedType inputTy = operands[0].getType().cast(); ShapedType weightTy = operands[1].getType().cast(); @@ -414,7 +414,7 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents( LogicalResult tosa::MatMulOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapedType lhsTy = operands[0].getType().cast(); ShapedType rhsTy = operands[1].getType().cast(); @@ -439,7 +439,7 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents( LogicalResult tosa::PadOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapedType inputTy = operands[0].getType().cast(); ShapedType paddingTy = operands[1].getType().cast(); @@ -495,7 +495,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents( LogicalResult tosa::SliceOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { auto sizes = attributes.get("size").cast().getValue(); SmallVector outputShape; @@ -510,7 +510,7 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents( LogicalResult tosa::TableOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapedType inputTy = operands[0].getType().cast(); @@ -525,7 +525,7 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents( LogicalResult tosa::TileOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { auto multiples = attributes.get("multiples").cast().getValue(); ShapedType inputTy = operands[0].getType().cast(); @@ -558,7 +558,7 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents( LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapedType type = operands.front().getType().cast(); @@ -596,7 +596,7 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( LogicalResult tosa::TransposeOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapedType inputTy = operands[0].getType().cast(); ShapedType permsTy = operands[1].getType().cast(); @@ -662,7 +662,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents( LogicalResult tosa::GatherOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(3, -1); @@ -685,7 +685,7 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents( LogicalResult tosa::ScatterOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(3, -1); @@ -889,7 +889,7 @@ static LogicalResult poolingInferReturnTypes( LogicalResult Conv2DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamicSize); Conv2DOp::Adaptor adaptor(operands); @@ -950,7 +950,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents( LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(5, ShapedType::kDynamicSize); Conv2DOp::Adaptor adaptor(operands); @@ -1023,21 +1023,21 @@ LogicalResult Conv3DOp::inferReturnTypeComponents( LogicalResult AvgPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { return poolingInferReturnTypes(operands, attributes, inferredReturnShapes); } LogicalResult MaxPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { return poolingInferReturnTypes(operands, attributes, inferredReturnShapes); } LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamicSize); DepthwiseConv2DOp::Adaptor adaptor(operands); @@ -1112,7 +1112,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( LogicalResult TransposeConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::llvm::Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { TransposeConv2DOp::Adaptor adaptor(operands); llvm::SmallVector outputShape; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 7c470dd30d05..02ad79ad085d 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -780,7 +780,7 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( } LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( - MLIRContext *context, Optional location, ValueRange operands, + MLIRContext *context, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // Create return type consisting of the last element of the first operand.