[tosa] Add legalization for conv3d

Update the existing implementation to match TOSA spec.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D133062
This commit is contained in:
TatWai Chong 2022-10-06 12:50:12 -07:00 committed by Rob Suderman
parent 3f965818b6
commit eb04f321c3
1 changed files with 23 additions and 24 deletions

View File

@ -1070,7 +1070,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize); llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
Conv2DOp::Adaptor adaptor(operands.getValues(), attributes); Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);
int32_t inputWidth = ShapedType::kDynamicSize; int32_t inputWidth = ShapedType::kDynamicSize;
int32_t inputHeight = ShapedType::kDynamicSize; int32_t inputHeight = ShapedType::kDynamicSize;
@ -1084,55 +1084,54 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput()); ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
if (inputShape.hasRank()) { if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0); outputShape[0] = inputShape.getDimSize(0);
inputHeight = inputShape.getDimSize(1); inputDepth = inputShape.getDimSize(1);
inputWidth = inputShape.getDimSize(2); inputHeight = inputShape.getDimSize(2);
inputDepth = inputShape.getDimSize(3); inputWidth = inputShape.getDimSize(3);
} }
// Weight shapes describes the filter width/height and the output channels. // Weight shapes describes the filter width/height and the output channels.
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight()); ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
if (weightShape.hasRank()) { if (weightShape.hasRank()) {
outputShape[4] = weightShape.getDimSize(0); outputShape[4] = weightShape.getDimSize(0);
weightHeight = weightShape.getDimSize(1); weightDepth = weightShape.getDimSize(1);
weightWidth = weightShape.getDimSize(2); weightHeight = weightShape.getDimSize(2);
weightDepth = weightShape.getDimSize(3); weightWidth = weightShape.getDimSize(3);
} }
// Bias shape can describe the output channels. // Bias shape can describe the output channels.
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias()); ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
if (biasShape.hasRank()) { if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
outputShape[4] = outputShape[4] = biasShape.getDimSize(0);
(outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4];
} }
llvm::SmallVector<int64_t> dilation; llvm::SmallVector<int64_t> dilation;
llvm::SmallVector<int64_t> padding; llvm::SmallVector<int64_t> pad;
llvm::SmallVector<int64_t> stride; llvm::SmallVector<int64_t> stride;
getI64Values(adaptor.getDilation(), dilation); getI64Values(adaptor.getDilation(), dilation);
getI64Values(adaptor.getPad(), padding); getI64Values(adaptor.getPad(), pad);
getI64Values(adaptor.getStride(), stride); getI64Values(adaptor.getStride(), stride);
if (!ShapedType::isDynamic(inputHeight) && if (!ShapedType::isDynamic(inputDepth) &&
!ShapedType::isDynamic(weightHeight)) { !ShapedType::isDynamic(weightDepth)) {
int32_t inputSize = inputHeight + padding[0] + padding[1]; int32_t inputSize = inputDepth + pad[0] + pad[1];
int32_t filterSize = (weightHeight - 1) * dilation[0] + 1; int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
int32_t unstridedResult = inputSize - filterSize + 1; int32_t unstridedResult = inputSize - filterSize + 1;
outputShape[1] = (unstridedResult - 1) / stride[0] + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
} }
if (!ShapedType::isDynamic(inputWidth) && if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(weightWidth)) { !ShapedType::isDynamic(weightHeight)) {
int32_t inputSize = inputWidth + padding[2] + padding[3]; int32_t inputSize = inputHeight + pad[2] + pad[3];
int32_t filterSize = (weightWidth - 1) * dilation[1] + 1; int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
int32_t unstridedResult = inputSize - filterSize + 1; int32_t unstridedResult = inputSize - filterSize + 1;
outputShape[2] = (unstridedResult - 1) / stride[1] + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
} }
if (!ShapedType::isDynamic(inputDepth) && if (!ShapedType::isDynamic(inputWidth) &&
!ShapedType::isDynamic(weightDepth)) { !ShapedType::isDynamic(weightWidth)) {
int32_t inputSize = inputDepth + padding[4] + padding[5]; int32_t inputSize = inputWidth + pad[4] + pad[5];
int32_t filterSize = (weightDepth - 1) * dilation[2] + 1; int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
int32_t unstridedResult = inputSize - filterSize + 1; int32_t unstridedResult = inputSize - filterSize + 1;
outputShape[3] = (unstridedResult - 1) / stride[2] + 1; outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
} }