[mlir][tosa] Add conv2d lowering to linalg.conv2d operator for FP

Handles lowering tosa.conv2d to linalg's convolution operator
(linalg.conv_2d_input_nhwc_filter_hwcf). This implementation supports only
floating-point values, but handles all stride, dilation, and padding values.
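
In effect, a single tosa.conv2d becomes a short sequence of ops: pad the input
if needed, transpose the kernel into HWCF order, broadcast the bias into the
output tensor, and convolve into that accumulator. A rough sketch of the
emitted IR for the unpadded case (shapes and attributes borrowed from the
conv2d_f32 test in this commit; the intermediate tosa.transpose is itself
lowered by the existing transpose pattern, which is why the test checks for a
linalg.generic instead):

  %init = linalg.init_tensor [1, 45, 40, 28] : tensor<1x45x40x28xf32>
  %bcast = linalg.generic ... ins(%bias : tensor<28xf32>) outs(%init : tensor<1x45x40x28xf32>) ...
  %kernel = "tosa.transpose"(%weights, %perm) : (tensor<28x3x3x28xf32>, tensor<4xi64>) -> tensor<3x3x28x28xf32>
  %0 = linalg.conv_2d_input_nhwc_filter_hwcf
         {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
         ins(%input, %kernel : tensor<1x49x42x28xf32>, tensor<3x3x28x28xf32>)
         outs(%bcast : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>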

Differential Revision: https://reviews.llvm.org/D100061
Author: Rob Suderman
Date:   2021-04-06 18:31:51 -07:00
Commit: 7e1fb9a0d2 (parent 204aaf8795)
2 changed files with 129 additions and 0 deletions

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

@@ -740,6 +740,109 @@ public:
  }
};

class Conv2DConverter : public OpConversionPattern<tosa::Conv2DOp> {
public:
  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::Conv2DOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    Value weight = op.weight();
    Value bias = op.bias();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op.getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(op,
                                         "tosa.conv2d requires static shapes");

    auto inputShape = inputTy.getShape();
    auto weightShape = weightTy.getShape();

    // TODO(suderman): Support other types.
    if (!inputETy.isF32() || !weightETy.isF32() || !biasETy.isF32() ||
        !resultETy.isF32())
      return failure();
    // Broadcast the initial value to the output tensor before convolving.
    // The first indexing map reads the 1-D bias along the output-channel
    // dimension (d3); the second is the identity over the 4-D output.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(3)},
                                          rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultTy.getElementType());
    Value biasBroadcast =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, bias, initTensor, indexingMaps,
                getNParallelLoopsAttrs(resultTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
                })
            .getResult(0);
    // Transpose weights tensor to be in dim order: spatial dims,
    // input channels, and output channels (OHWI -> HWCF).
    SmallVector<int64_t> permutation{1, 2, 3, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4}, rewriter.getI64Type()), permutation);
    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newKernelShape{weightShape[1], weightShape[2],
                                        weightShape[3], weightShape[0]};
    Type newKernelTy = RankedTensorType::get(newKernelShape, weightETy);
    Value transposedKernel = rewriter.create<tosa::TransposeOp>(
        loc, newKernelTy, weight, permutationValue);
    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation, pad;
    getValuesFromIntArrayAttribute(op.stride(), stride);
    getValuesFromIntArrayAttribute(op.dilation(), dilation);
    getValuesFromIntArrayAttribute(op.pad(), pad);

    // Input should be padded if necessary. TOSA pad is [top, bottom, left,
    // right]; only the spatial (H, W) dimensions of the NHWC input are padded.
    if (llvm::any_of(pad, [](int64_t p) { return p; })) {
      llvm::SmallVector<int64_t, 8> newPad{0,      0,      pad[0], pad[1],
                                           pad[2], pad[3], 0,      0};
      auto padAttr = DenseIntElementsAttr::get(
          RankedTensorType::get({4, 2}, rewriter.getI64Type()), newPad);
      Value padValue = rewriter.create<ConstantOp>(loc, padAttr);

      SmallVector<int64_t, 4> paddedShape{
          inputShape[0], inputShape[1] + pad[0] + pad[1],
          inputShape[2] + pad[2] + pad[3], inputShape[3]};
      Type paddedTy = RankedTensorType::get(paddedShape, inputETy);
      input = rewriter.create<tosa::PadOp>(loc, paddedTy, input, padValue);
    }
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

    auto convOp = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
        loc, resultTy, ValueRange{input, transposedKernel},
        ValueRange{biasBroadcast}, dilationAttr, strideAttr);
    rewriter.replaceOp(op, convOp.getResult(0));
    return success();
  }
};
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
@@ -1693,6 +1796,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      ConcatConverter,
      Conv2DConverter,
      PadConverter,
      ReshapeConverter,
      RescaleConverter,

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

@@ -923,3 +923,28 @@ func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
%0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
return
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
func @conv2d_f32(%input: tensor<1x49x42x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 45, 40, 28] : tensor<1x45x40x28xf32>
  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>)
  // CHECK: ^bb0(%arg3: f32, %arg4: f32):
  // CHECK: linalg.yield %arg3 : f32
  // CHECK: %[[INITKERNEL:.+]] = linalg.init_tensor [3, 3, 28, 28]
  // CHECK: %[[TRANSPOSEKERNEL:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x28xf32>) outs(%[[INITKERNEL]] : tensor<3x3x28x28xf32>)
  // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSEKERNEL]] : tensor<1x49x42x28xf32>, tensor<3x3x28x28xf32>) outs(%[[BROADCAST]] : tensor<1x45x40x28xf32>)
  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
  return
}
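// Sanity check on the shapes above: with dilation [2, 1] the effective kernel
// is ((3 - 1) * 2 + 1) x 3 = 5 x 3, so the output is H = 49 - 5 + 1 = 45 and
// W = 42 - 3 + 1 = 40, i.e. tensor<1x45x40x28xf32>.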
// -----

func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
  // CHECK: linalg.pad_tensor %arg0
  // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [2, 1]} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
  return
}
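
Neither hunk shows the test file's RUN line. Assuming the pass flag mirrors
populateTosaToLinalgOnTensorsConversionPatterns, these cases would be driven by
something like:

  // RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s -o - | FileCheck %s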