[tosa] Add legalization for conv3d

Update the existing implementation to match TOSA spec.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D133062
This commit is contained in:
TatWai Chong 2022-10-06 12:50:12 -07:00 committed by Rob Suderman
parent 3f965818b6
commit eb04f321c3
1 changed files with 23 additions and 24 deletions

View File

@ -1070,7 +1070,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);
int32_t inputWidth = ShapedType::kDynamicSize;
int32_t inputHeight = ShapedType::kDynamicSize;
@ -1084,55 +1084,54 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
inputHeight = inputShape.getDimSize(1);
inputWidth = inputShape.getDimSize(2);
inputDepth = inputShape.getDimSize(3);
inputDepth = inputShape.getDimSize(1);
inputHeight = inputShape.getDimSize(2);
inputWidth = inputShape.getDimSize(3);
}
// Weight shapes describes the filter width/height and the output channels.
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
if (weightShape.hasRank()) {
outputShape[4] = weightShape.getDimSize(0);
weightHeight = weightShape.getDimSize(1);
weightWidth = weightShape.getDimSize(2);
weightDepth = weightShape.getDimSize(3);
weightDepth = weightShape.getDimSize(1);
weightHeight = weightShape.getDimSize(2);
weightWidth = weightShape.getDimSize(3);
}
// Bias shape can describe the output channels.
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
if (biasShape.hasRank()) {
outputShape[4] =
(outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4];
if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
outputShape[4] = biasShape.getDimSize(0);
}
llvm::SmallVector<int64_t> dilation;
llvm::SmallVector<int64_t> padding;
llvm::SmallVector<int64_t> pad;
llvm::SmallVector<int64_t> stride;
getI64Values(adaptor.getDilation(), dilation);
getI64Values(adaptor.getPad(), padding);
getI64Values(adaptor.getPad(), pad);
getI64Values(adaptor.getStride(), stride);
if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(weightHeight)) {
int32_t inputSize = inputHeight + padding[0] + padding[1];
int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
if (!ShapedType::isDynamic(inputDepth) &&
!ShapedType::isDynamic(weightDepth)) {
int32_t inputSize = inputDepth + pad[0] + pad[1];
int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
int32_t unstridedResult = inputSize - filterSize + 1;
outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
}
if (!ShapedType::isDynamic(inputWidth) &&
!ShapedType::isDynamic(weightWidth)) {
int32_t inputSize = inputWidth + padding[2] + padding[3];
int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(weightHeight)) {
int32_t inputSize = inputHeight + pad[2] + pad[3];
int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
int32_t unstridedResult = inputSize - filterSize + 1;
outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
}
if (!ShapedType::isDynamic(inputDepth) &&
!ShapedType::isDynamic(weightDepth)) {
int32_t inputSize = inputDepth + padding[4] + padding[5];
int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
if (!ShapedType::isDynamic(inputWidth) &&
!ShapedType::isDynamic(weightWidth)) {
int32_t inputSize = inputWidth + pad[4] + pad[5];
int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
int32_t unstridedResult = inputSize - filterSize + 1;
outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
}