From eb04f321c344175e4510c3747d83a308bde96d68 Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Thu, 6 Oct 2022 12:50:12 -0700 Subject: [PATCH] [tosa] Add legalization for conv3d Update the existing implementation to match TOSA spec. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D133062 --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 47 ++++++++++++++-------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 8df30279cb7e..841a27479bad 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1070,7 +1070,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents( ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(5, ShapedType::kDynamicSize); - Conv2DOp::Adaptor adaptor(operands.getValues(), attributes); + Conv3DOp::Adaptor adaptor(operands.getValues(), attributes); int32_t inputWidth = ShapedType::kDynamicSize; int32_t inputHeight = ShapedType::kDynamicSize; @@ -1084,55 +1084,54 @@ LogicalResult Conv3DOp::inferReturnTypeComponents( ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); if (inputShape.hasRank()) { outputShape[0] = inputShape.getDimSize(0); - inputHeight = inputShape.getDimSize(1); - inputWidth = inputShape.getDimSize(2); - inputDepth = inputShape.getDimSize(3); + inputDepth = inputShape.getDimSize(1); + inputHeight = inputShape.getDimSize(2); + inputWidth = inputShape.getDimSize(3); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight()); if (weightShape.hasRank()) { outputShape[4] = weightShape.getDimSize(0); - weightHeight = weightShape.getDimSize(1); - weightWidth = weightShape.getDimSize(2); - weightDepth = weightShape.getDimSize(3); + weightDepth = weightShape.getDimSize(1); + weightHeight = weightShape.getDimSize(2); + weightWidth = weightShape.getDimSize(3); } // Bias shape can describe the output channels. ShapeAdaptor biasShape = operands.getShape(adaptor.getBias()); - if (biasShape.hasRank()) { - outputShape[4] = - (outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4]; + if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) { + outputShape[4] = biasShape.getDimSize(0); } llvm::SmallVector dilation; - llvm::SmallVector padding; + llvm::SmallVector pad; llvm::SmallVector stride; getI64Values(adaptor.getDilation(), dilation); - getI64Values(adaptor.getPad(), padding); + getI64Values(adaptor.getPad(), pad); getI64Values(adaptor.getStride(), stride); - if (!ShapedType::isDynamic(inputHeight) && - !ShapedType::isDynamic(weightHeight)) { - int32_t inputSize = inputHeight + padding[0] + padding[1]; - int32_t filterSize = (weightHeight - 1) * dilation[0] + 1; + if (!ShapedType::isDynamic(inputDepth) && + !ShapedType::isDynamic(weightDepth)) { + int32_t inputSize = inputDepth + pad[0] + pad[1]; + int32_t filterSize = (weightDepth - 1) * dilation[0] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } - if (!ShapedType::isDynamic(inputWidth) && - !ShapedType::isDynamic(weightWidth)) { - int32_t inputSize = inputWidth + padding[2] + padding[3]; - int32_t filterSize = (weightWidth - 1) * dilation[1] + 1; + if (!ShapedType::isDynamic(inputHeight) && + !ShapedType::isDynamic(weightHeight)) { + int32_t inputSize = inputHeight + pad[2] + pad[3]; + int32_t filterSize = (weightHeight - 1) * dilation[1] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1; } - if (!ShapedType::isDynamic(inputDepth) && - !ShapedType::isDynamic(weightDepth)) { - int32_t inputSize = inputDepth + padding[4] + padding[5]; - int32_t filterSize = (weightDepth - 1) * dilation[2] + 1; + if (!ShapedType::isDynamic(inputWidth) && + !ShapedType::isDynamic(weightWidth)) { + int32_t inputSize = inputWidth + pad[4] + pad[5]; + int32_t filterSize = (weightWidth - 1) * dilation[2] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[3] = (unstridedResult - 1) / stride[2] + 1; }