forked from OSchip/llvm-project
[tosa] Add legalization for conv3d
Update the existing implementation to match TOSA spec. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D133062
This commit is contained in:
parent
3f965818b6
commit
eb04f321c3
|
@ -1070,7 +1070,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
|
|||
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
|
||||
Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
|
||||
Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);
|
||||
|
||||
int32_t inputWidth = ShapedType::kDynamicSize;
|
||||
int32_t inputHeight = ShapedType::kDynamicSize;
|
||||
|
@ -1084,55 +1084,54 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
|
|||
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
|
||||
if (inputShape.hasRank()) {
|
||||
outputShape[0] = inputShape.getDimSize(0);
|
||||
inputHeight = inputShape.getDimSize(1);
|
||||
inputWidth = inputShape.getDimSize(2);
|
||||
inputDepth = inputShape.getDimSize(3);
|
||||
inputDepth = inputShape.getDimSize(1);
|
||||
inputHeight = inputShape.getDimSize(2);
|
||||
inputWidth = inputShape.getDimSize(3);
|
||||
}
|
||||
|
||||
// Weight shapes describes the filter width/height and the output channels.
|
||||
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
|
||||
if (weightShape.hasRank()) {
|
||||
outputShape[4] = weightShape.getDimSize(0);
|
||||
weightHeight = weightShape.getDimSize(1);
|
||||
weightWidth = weightShape.getDimSize(2);
|
||||
weightDepth = weightShape.getDimSize(3);
|
||||
weightDepth = weightShape.getDimSize(1);
|
||||
weightHeight = weightShape.getDimSize(2);
|
||||
weightWidth = weightShape.getDimSize(3);
|
||||
}
|
||||
|
||||
// Bias shape can describe the output channels.
|
||||
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
|
||||
if (biasShape.hasRank()) {
|
||||
outputShape[4] =
|
||||
(outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4];
|
||||
if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
|
||||
outputShape[4] = biasShape.getDimSize(0);
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t> dilation;
|
||||
llvm::SmallVector<int64_t> padding;
|
||||
llvm::SmallVector<int64_t> pad;
|
||||
llvm::SmallVector<int64_t> stride;
|
||||
|
||||
getI64Values(adaptor.getDilation(), dilation);
|
||||
getI64Values(adaptor.getPad(), padding);
|
||||
getI64Values(adaptor.getPad(), pad);
|
||||
getI64Values(adaptor.getStride(), stride);
|
||||
|
||||
if (!ShapedType::isDynamic(inputHeight) &&
|
||||
!ShapedType::isDynamic(weightHeight)) {
|
||||
int32_t inputSize = inputHeight + padding[0] + padding[1];
|
||||
int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
|
||||
if (!ShapedType::isDynamic(inputDepth) &&
|
||||
!ShapedType::isDynamic(weightDepth)) {
|
||||
int32_t inputSize = inputDepth + pad[0] + pad[1];
|
||||
int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
|
||||
int32_t unstridedResult = inputSize - filterSize + 1;
|
||||
outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
|
||||
}
|
||||
|
||||
if (!ShapedType::isDynamic(inputWidth) &&
|
||||
!ShapedType::isDynamic(weightWidth)) {
|
||||
int32_t inputSize = inputWidth + padding[2] + padding[3];
|
||||
int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
|
||||
if (!ShapedType::isDynamic(inputHeight) &&
|
||||
!ShapedType::isDynamic(weightHeight)) {
|
||||
int32_t inputSize = inputHeight + pad[2] + pad[3];
|
||||
int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
|
||||
int32_t unstridedResult = inputSize - filterSize + 1;
|
||||
outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
|
||||
}
|
||||
|
||||
if (!ShapedType::isDynamic(inputDepth) &&
|
||||
!ShapedType::isDynamic(weightDepth)) {
|
||||
int32_t inputSize = inputDepth + padding[4] + padding[5];
|
||||
int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
|
||||
if (!ShapedType::isDynamic(inputWidth) &&
|
||||
!ShapedType::isDynamic(weightWidth)) {
|
||||
int32_t inputSize = inputWidth + pad[4] + pad[5];
|
||||
int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
|
||||
int32_t unstridedResult = inputSize - filterSize + 1;
|
||||
outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue