forked from OSchip/llvm-project
[mlir][tosa] Add conv2d lowering to linalg.conv2d operator for FP
Handles lowering conv2d to linalg's convolution operator. This implementation only supports floating-point values but handles all strides, dilations, and padding values.

Differential Revision: https://reviews.llvm.org/D100061
parent 204aaf8795
commit 7e1fb9a0d2
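In short, the new pattern rewrites a single TOSA convolution such as the one below (shapes and attribute values borrowed from the test cases added further down in this diff) into a bias broadcast, a kernel transpose, and a linalg convolution:

  %0 = "tosa.conv2d"(%input, %weights, %bias)
         {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]}
         : (tensor<1x49x42x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>)
        -> tensor<1x45x40x28xf32>

The broadcast is a linalg.init_tensor followed by a linalg.generic over %bias, the kernel transpose is a tosa.transpose from the OHWI layout TOSA uses to the HWCF layout that linalg.conv_2d_input_nhwc_filter_hwcf expects, and any nonzero pad values are materialized beforehand with tosa.pad.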
@@ -740,6 +740,109 @@ public:
  }
};

class Conv2DConverter : public OpConversionPattern<tosa::Conv2DOp> {
public:
  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::Conv2DOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    Value weight = op.weight();
    Value bias = op.bias();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op.getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(op,
                                         "tosa.conv2d requires static shapes");

    auto inputShape = inputTy.getShape();
    auto weightShape = weightTy.getShape();

    // TODO(suderman): Support other types.
    if (!inputETy.isF32() || !weightETy.isF32() || !biasETy.isF32() ||
        !resultETy.isF32())
      return failure();

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(3)},
                                          rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultTy.getElementType());
    Value biasBroadcast =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, bias, initTensor, indexingMaps,
                getNParallelLoopsAttrs(resultTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
                })
            .getResult(0);

    // Transpose weights tensor to be in dim order: spatial dims,
    // input channels, and output channels.
    SmallVector<int64_t> permutation{1, 2, 3, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4}, rewriter.getI64Type()), permutation);
    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newKernelShape{weightShape[1], weightShape[2],
                                        weightShape[3], weightShape[0]};
    Type newKernelTy = RankedTensorType::get(newKernelShape, biasETy);

    Value transposedKernel = rewriter.create<tosa::TransposeOp>(
        loc, newKernelTy, weight, permutationValue);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation, pad;
    getValuesFromIntArrayAttribute(op.stride(), stride);
    getValuesFromIntArrayAttribute(op.dilation(), dilation);
    getValuesFromIntArrayAttribute(op.pad(), pad);

    // Input should be padded if necessary.
    if (llvm::any_of(pad, [](int64_t p) { return p; })) {
      llvm::SmallVector<int64_t, 8> newPad{0, 0, pad[0], pad[1],
                                           pad[2], pad[3], 0, 0};
      auto padAttr = DenseIntElementsAttr::get(
          RankedTensorType::get({4, 2}, rewriter.getI64Type()), newPad);
      Value padValue = rewriter.create<ConstantOp>(loc, padAttr);

      SmallVector<int64_t, 4> paddedShape{
          inputShape[0], inputShape[1] + pad[0] + pad[1],
          inputShape[2] + pad[2] + pad[3], inputShape[3]};
      Type paddedTy = RankedTensorType::get(paddedShape, inputETy);
      input = rewriter.create<tosa::PadOp>(loc, paddedTy, input, padValue);
    }

    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

    auto convOp = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
        loc, resultTy, ValueRange{input, transposedKernel},
        ValueRange{biasBroadcast}, dilationAttr, strideAttr);

    rewriter.replaceOp(op, convOp.getResult(0));
    return success();
  }
};

class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
@@ -1693,6 +1796,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      ConcatConverter,
      Conv2DConverter,
      PadConverter,
      ReshapeConverter,
      RescaleConverter,
@@ -923,3 +923,28 @@ func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
  return
}

// -----

// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>

func @conv2d_f32(%input: tensor<1x49x42x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 45, 40, 28] : tensor<1x45x40x28xf32>
  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>)
  // CHECK: ^bb0(%arg3: f32, %arg4: f32):
  // CHECK:   linalg.yield %arg3 : f32
  // CHECK: %[[INITKERNEL:.+]] = linalg.init_tensor [3, 3, 28, 28]
  // CHECK: %[[TRANSPOSEKERNEL:.+]] = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x28xf32>) outs(%[[INITKERNEL]] : tensor<3x3x28x28xf32>)
  // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSEKERNEL]] : tensor<1x49x42x28xf32>, tensor<3x3x28x28xf32>) outs(%[[BROADCAST]] : tensor<1x45x40x28xf32>)
  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
  return
}

func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
  // CHECK: linalg.pad_tensor %arg0
  // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [2, 1]} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
  return
}