forked from OSchip/llvm-project
[mlir][tosa] Add tosa.gather lowering to linalg.indexed_generic
Lowering gather operation to linalg dialect. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D101200
This commit is contained in:
parent
2205286095
commit
6f720d5eca
mlir
|
@ -1781,6 +1781,59 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Lowers `tosa.gather` to a `linalg.indexed_generic` that iterates over the
/// result tensor.
///
/// The indices tensor is the only linalg input and is read through the map
/// (d0, d1, ...) -> (d0, d1); the output uses the identity map. For every
/// result element the body casts the loaded index value to `index` and yields
/// `input[d0, idx, d2]` via `tensor.extract`.
///
/// The pattern fails to match unless the input, indices and result types all
/// have static shapes.
// NOTE(review): the body uses loop indices 0 and 2 and a two-result indexing
// map, so it presumably relies on the op verifier guaranteeing a rank-3
// input/result and rank-2 indices -- confirm against the tosa.gather
// definition.
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = args[0];
    auto indices = args[1];

    auto inputTy = input.getType().cast<ShapedType>();
    auto indicesTy = indices.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    // The init tensor below is created from resultTy.getShape() with no
    // dynamic size operands, so a dynamic result shape cannot be lowered.
    if (!resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require result type to have static shape");

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    // First map: how the indices operand is read (only the first two loop
    // dimensions). Second map: identity access for the output tensor.
    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        // Parameter names are kept distinct from the enclosing `loc` and
        // `indices` so the captured operand is never confused with the loop
        // induction values.
        [&](OpBuilder &b, Location nestedLoc, ValueRange loopIndices,
            ValueRange blockArgs) {
          // blockArgs[0] is the element loaded from the indices tensor; it
          // replaces the middle coordinate when reading from the input.
          Value gatherIndex = b.create<IndexCastOp>(
              nestedLoc, b.getIndexType(), blockArgs[0]);
          Value extract = b.create<tensor::ExtractOp>(
              nestedLoc, input,
              ValueRange{loopIndices[0], gatherIndex, loopIndices[2]});
          b.create<linalg::YieldOp>(nestedLoc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
|
||||||
|
|
||||||
// Lowers the TableOp to a series of gathers and numerical operations. This
|
// Lowers the TableOp to a series of gathers and numerical operations. This
|
||||||
// includes interpolation between the high/low values. For the I8 variant, this
|
// includes interpolation between the high/low values. For the I8 variant, this
|
||||||
// simplifies to a single gather operation.
|
// simplifies to a single gather operation.
|
||||||
|
@ -2085,6 +2138,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
|
||||||
ArgMaxConverter,
|
ArgMaxConverter,
|
||||||
ConcatConverter,
|
ConcatConverter,
|
||||||
Conv2DConverter,
|
Conv2DConverter,
|
||||||
|
GatherConverter,
|
||||||
PadConverter,
|
PadConverter,
|
||||||
ReshapeConverter,
|
ReshapeConverter,
|
||||||
RescaleConverter,
|
RescaleConverter,
|
||||||
|
|
|
@ -833,6 +833,32 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Test: tosa.gather on an f32 input (2x3x2) with i32 indices (2x3). The expected lowering is an init_tensor followed by a linalg.indexed_generic whose body index_casts the loaded index and extracts from %arg0 at [loop index 0, cast index, loop index 2].
// CHECK-LABEL: @gather_float
|
||||||
|
func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
|
||||||
|
// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
|
||||||
|
// CHECK: %[[GENERIC:.+]] = linalg.indexed_generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xf32>)
|
||||||
|
// CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[ARG0:.+]]: i32, %[[ARG1:.+]]: f32)
|
||||||
|
// CHECK: %[[CAST:.+]] = index_cast %[[ARG0]]
|
||||||
|
// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<2x3x2xf32>
|
||||||
|
// CHECK: linalg.yield %[[EXTRACT]]
|
||||||
|
%0 = "tosa.gather"(%arg0, %arg1) : (tensor<2x3x2xf32>, tensor<2x3xi32>) -> (tensor<2x3x2xf32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test: same gather lowering as the f32 case above in this file, but with an i32 input/result element type; the expected structure (init_tensor, indexed_generic, index_cast, tensor.extract, yield) is otherwise identical.
// CHECK-LABEL: @gather_int
|
||||||
|
func @gather_int(%arg0: tensor<2x3x2xi32>, %arg1: tensor<2x3xi32>) -> () {
|
||||||
|
// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
|
||||||
|
// CHECK: %[[GENERIC:.+]] = linalg.indexed_generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%[[INIT]] : tensor<2x3x2xi32>)
|
||||||
|
// CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index, %[[IDX2:.+]]: index, %[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
|
||||||
|
// CHECK: %[[CAST:.+]] = index_cast %[[ARG0]]
|
||||||
|
// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<2x3x2xi32>
|
||||||
|
// CHECK: linalg.yield %[[EXTRACT]]
|
||||||
|
%0 = "tosa.gather"(%arg0, %arg1) : (tensor<2x3x2xi32>, tensor<2x3xi32>) -> (tensor<2x3x2xi32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @table8
|
// CHECK-LABEL: @table8
|
||||||
func @table8(%arg0: tensor<6xi8>, %arg1: tensor<513xi8>) -> () {
|
func @table8(%arg0: tensor<6xi8>, %arg1: tensor<513xi8>) -> () {
|
||||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
|
// CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
|
||||||
|
|
Loading…
Reference in New Issue