forked from OSchip/llvm-project
[tosa] Add legalization for conv3d
Update the existing implementation to match the TOSA spec.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D133062
This commit is contained in:
parent
3f965818b6
commit
eb04f321c3
|
@ -1070,7 +1070,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
|
||||||
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
|
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
|
||||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||||
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
|
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
|
||||||
Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
|
Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);
|
||||||
|
|
||||||
int32_t inputWidth = ShapedType::kDynamicSize;
|
int32_t inputWidth = ShapedType::kDynamicSize;
|
||||||
int32_t inputHeight = ShapedType::kDynamicSize;
|
int32_t inputHeight = ShapedType::kDynamicSize;
|
||||||
|
@ -1084,55 +1084,54 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
|
||||||
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
|
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
|
||||||
if (inputShape.hasRank()) {
|
if (inputShape.hasRank()) {
|
||||||
outputShape[0] = inputShape.getDimSize(0);
|
outputShape[0] = inputShape.getDimSize(0);
|
||||||
inputHeight = inputShape.getDimSize(1);
|
inputDepth = inputShape.getDimSize(1);
|
||||||
inputWidth = inputShape.getDimSize(2);
|
inputHeight = inputShape.getDimSize(2);
|
||||||
inputDepth = inputShape.getDimSize(3);
|
inputWidth = inputShape.getDimSize(3);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Weight shapes describes the filter width/height and the output channels.
|
// Weight shape describes the filter depth/height/width and the output channels.
|
||||||
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
|
ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
|
||||||
if (weightShape.hasRank()) {
|
if (weightShape.hasRank()) {
|
||||||
outputShape[4] = weightShape.getDimSize(0);
|
outputShape[4] = weightShape.getDimSize(0);
|
||||||
weightHeight = weightShape.getDimSize(1);
|
weightDepth = weightShape.getDimSize(1);
|
||||||
weightWidth = weightShape.getDimSize(2);
|
weightHeight = weightShape.getDimSize(2);
|
||||||
weightDepth = weightShape.getDimSize(3);
|
weightWidth = weightShape.getDimSize(3);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bias shape can describe the output channels.
|
// Bias shape can describe the output channels.
|
||||||
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
|
ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
|
||||||
if (biasShape.hasRank()) {
|
if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
|
||||||
outputShape[4] =
|
outputShape[4] = biasShape.getDimSize(0);
|
||||||
(outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<int64_t> dilation;
|
llvm::SmallVector<int64_t> dilation;
|
||||||
llvm::SmallVector<int64_t> padding;
|
llvm::SmallVector<int64_t> pad;
|
||||||
llvm::SmallVector<int64_t> stride;
|
llvm::SmallVector<int64_t> stride;
|
||||||
|
|
||||||
getI64Values(adaptor.getDilation(), dilation);
|
getI64Values(adaptor.getDilation(), dilation);
|
||||||
getI64Values(adaptor.getPad(), padding);
|
getI64Values(adaptor.getPad(), pad);
|
||||||
getI64Values(adaptor.getStride(), stride);
|
getI64Values(adaptor.getStride(), stride);
|
||||||
|
|
||||||
if (!ShapedType::isDynamic(inputHeight) &&
|
if (!ShapedType::isDynamic(inputDepth) &&
|
||||||
!ShapedType::isDynamic(weightHeight)) {
|
!ShapedType::isDynamic(weightDepth)) {
|
||||||
int32_t inputSize = inputHeight + padding[0] + padding[1];
|
int32_t inputSize = inputDepth + pad[0] + pad[1];
|
||||||
int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
|
int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
|
||||||
int32_t unstridedResult = inputSize - filterSize + 1;
|
int32_t unstridedResult = inputSize - filterSize + 1;
|
||||||
outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
|
outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ShapedType::isDynamic(inputWidth) &&
|
if (!ShapedType::isDynamic(inputHeight) &&
|
||||||
!ShapedType::isDynamic(weightWidth)) {
|
!ShapedType::isDynamic(weightHeight)) {
|
||||||
int32_t inputSize = inputWidth + padding[2] + padding[3];
|
int32_t inputSize = inputHeight + pad[2] + pad[3];
|
||||||
int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
|
int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
|
||||||
int32_t unstridedResult = inputSize - filterSize + 1;
|
int32_t unstridedResult = inputSize - filterSize + 1;
|
||||||
outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
|
outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ShapedType::isDynamic(inputDepth) &&
|
if (!ShapedType::isDynamic(inputWidth) &&
|
||||||
!ShapedType::isDynamic(weightDepth)) {
|
!ShapedType::isDynamic(weightWidth)) {
|
||||||
int32_t inputSize = inputDepth + padding[4] + padding[5];
|
int32_t inputSize = inputWidth + pad[4] + pad[5];
|
||||||
int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
|
int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
|
||||||
int32_t unstridedResult = inputSize - filterSize + 1;
|
int32_t unstridedResult = inputSize - filterSize + 1;
|
||||||
outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
|
outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue