[tosa][mlir] Add dynamic shape support for remaining ops

Added dynamic shape support for the concat, tile, pad, argmax and table ops in the TOSA-to-Linalg lowering.
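
For example, a concat with a dynamic non-axis dimension now lowers without the static-shape bailout; the converter materializes the dynamic extent with tensor.dim and forwards it to linalg.init_tensor. A minimal sketch modeled on the new concat test (function name is illustrative; the lowering shown in the comment is abbreviated):

  func @concat_dyn_example(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> tensor<11x?xf32> {
    // Lowered roughly to:
    //   %c1   = arith.constant 1 : index
    //   %d    = tensor.dim %arg0, %c1 : tensor<5x?xf32>
    //   %init = linalg.init_tensor [11, %d] : tensor<11x?xf32>
    //   ... followed by a zero fill and a tensor.insert_slice per operand ...
    %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64 } : (tensor<5x?xf32>, tensor<6x?xf32>) -> tensor<11x?xf32>
    return %0 : tensor<11x?xf32>
  }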

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D118397
Authored by natashaknk on 2022-01-27 11:25:26 -08:00; committed by Rob Suderman
parent 9021f3682c
commit 024a1fab5c
2 changed files with 236 additions and 34 deletions

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

@@ -1681,11 +1681,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
LogicalResult
matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
auto resultType = op.getType().dyn_cast<RankedTensorType>();
if (!resultType || !resultType.hasStaticShape()) {
return rewriter.notifyMatchFailure(op,
"expected static shaped tensor type");
}
Location loc = op.getLoc();
int axis = op.axis();
@@ -1697,9 +1694,14 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
SmallVector<Value> dynDims;
for (int i = 0; i < rank; ++i) {
sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
loc, adaptor.getOperands()[0], i));
if (inputType.isDynamicDim(i)) {
dynDims.push_back(
rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
}
}
Value resultDimSize = sizes[axis];
@@ -1711,7 +1713,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
sizes[axis] = resultDimSize;
Value init = rewriter.create<linalg::InitTensorOp>(
loc, resultType.getShape(), resultType.getElementType());
loc, dynDims, resultType.getShape(), resultType.getElementType());
Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
loc, rewriter.getZeroAttr(resultType.getElementType()));
@@ -1815,9 +1817,6 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
auto elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();
if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
return failure();
SmallVector<int64_t> multiples;
getValuesFromIntArrayAttribute(op.multiples(), multiples);
@@ -1828,8 +1827,15 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
genericShape.push_back(inputShape[i]);
}
SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
}
}
auto initTensor = rewriter.create<linalg::InitTensorOp>(
op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
op.getLoc(), dynDims, genericShape, elementTy);
// We need to map the input shape to the non-broadcasted dimensions.
SmallVector<AffineExpr, 4> dimExprs;
@@ -1870,16 +1876,9 @@ public:
auto padding = padOp.padding();
ShapedType inputTy = input.getType().cast<ShapedType>();
ShapedType paddingTy = padding.getType().cast<ShapedType>();
Type elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();
if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
return rewriter.notifyMatchFailure(
padOp,
"Pad converter requires static shaped input / padding values.");
}
// Setup the default constantAttr.
Value padConstant;
@@ -1970,21 +1969,23 @@ public:
int axis = argmaxOp.axis();
auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
if (!inputTy.hasStaticShape())
return rewriter.notifyMatchFailure(
argmaxOp,
"tosa.arg_max to linalg.* requires statically shaped input");
if (!outElementTy.isa<IntegerType>())
return rewriter.notifyMatchFailure(
argmaxOp,
"tosa.arg_max to linalg.* requires integer-like result type");
SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i) && i != axis) {
dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
}
}
// First fill the output buffer for the index.
auto initTensorIdx =
rewriter
.create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
resultTy.getShape(), outElementTy)
.create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
outElementTy)
.result();
auto fillValueIdx = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(outElementTy, 0));
@@ -1993,11 +1994,10 @@ public:
.result();
// Second fill the output buffer for the running max.
auto initTensorMax =
rewriter
.create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
resultTy.getShape(), inElementTy)
.result();
auto initTensorMax = rewriter
.create<linalg::InitTensorOp>(
loc, dynDims, resultTy.getShape(), inElementTy)
.result();
auto fillValueMaxAttr =
createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
@@ -2138,18 +2138,22 @@ public:
auto tableTy = table.getType().cast<ShapedType>();
auto resultTy = op.getType().cast<ShapedType>();
if (!inputTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "require input type to have static shape");
auto inputElementTy = inputTy.getElementType();
auto tableElementTy = tableTy.getElementType();
auto resultElementTy = resultTy.getElementType();
SmallVector<Value> dynDims;
for (int i = 0; i < resultTy.getRank(); ++i) {
if (inputTy.isDynamicDim(i)) {
dynDims.push_back(
rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
}
}
auto initTensor =
rewriter
.create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
resultTy.getShape(), resultElementTy)
.create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
resultElementTy)
.result();
SmallVector<AffineMap, 2> affineMaps = {

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

@@ -910,6 +910,50 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
// -----
// CHECK-LABEL: @concat_non_axis_dyn
func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
// CHECK: %[[AXIS:.+]] = arith.constant 0
// CHECK: %[[STRIDE:.+]] = arith.constant 1
// CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
// CHECK: %[[IDX0:.+]] = arith.constant 0 : index
// CHECK: %[[IDX1:.+]] = arith.constant 1 : index
// CHECK: %[[SIZE:.+]] = tensor.dim %arg0, %[[IDX1]]
// CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
// CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX1_2]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [11, %[[DYN]]]
// CHECK: %[[CST:.+]] = arith.constant 0.0
// CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]])
// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [5, %[[SIZE]]] [1, 1]
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %arg1 into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
%0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>)
return
}
// -----
// CHECK-LABEL: @concat_axis_dyn
func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
// CHECK: %[[AXIS:.+]] = arith.constant 0
// CHECK: %[[STRIDE:.+]] = arith.constant 1
// CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
// CHECK: %[[IDX0:.+]] = arith.constant 0 : index
// CHECK: %[[SIZE:.+]] = tensor.dim %arg0, %[[IDX0]]
// CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
// CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[IDX0_2]]
// CHECK: %[[IDX1:.+]] = arith.constant 1 : index
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]], 3]
// CHECK: %[[CST:.+]] = arith.constant 0.0
// CHECK: %[[FILL:.+]] = linalg.fill(%[[CST]], %[[INIT]])
// CHECK: %[[DYN1:.+]] = tensor.dim %arg0, %[[AXIS]]
// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %arg0 into %[[FILL]][0, 0] [%[[DYN1]], 3] [1, 1]
// CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]]
// CHECK: %[[DYN2:.+]] = tensor.dim %arg1, %[[AXIS]]
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %arg1 into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
%0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>) -> (tensor<?x3xf32>)
return
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @rescale_i8
@@ -1150,6 +1194,44 @@ func @tile(%arg0 : tensor<2x3xi8>) -> () {
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: @tile_dyn_input
func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
// CHECK: %[[CST0:.+]] = arith.constant 0
// CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST0]] : tensor<?x3xi8>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, %[[DYN]], 1, 3]
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x3xi8>) outs(%[[INIT]] : tensor<2x?x1x3xi8>)
// CHECK: linalg.yield %arg1 : i8
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
// CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
%0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<?x3xi8>) -> (tensor<?x3xi8>)
return
}
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: @tile_dyn_multiples
func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
// CHECK: %[[CST1:.+]] = arith.constant 1
// CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]] : tensor<2x3xi8>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 2, %[[DYN]], 3]
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3xi8>) outs(%[[INIT]] : tensor<2x2x?x3xi8>)
// CHECK: linalg.yield %arg1 : i8
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1, 2, 3]]
// CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1]]
%0 = "tosa.tile"(%arg0) {multiples = [2, -1]} : (tensor<2x3xi8>) -> (tensor<2x?xi8>)
return
}
// -----
func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
%0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
// TODO: Output contains multiple "arith.constant 1 : index".
@@ -1205,6 +1287,40 @@ func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
// -----
func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
%0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
// TODO: Output contains multiple "arith.constant 1 : index".
// CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
// CHECK: tensor.pad %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: ^bb0(%arg1: index, %arg2: index):
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
%1 = "tosa.pad"(%arg0, %0) : (tensor<?x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
%0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
// TODO: Output contains multiple "arith.constant 1 : index".
// CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
// CHECK: tensor.pad %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]] {
// CHECK: ^bb0(%arg1: index, %arg2: index):
// CHECK: tensor.yield [[CST]]
// CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
%1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<?x9xf32>)
return %1 : tensor<?x9xf32>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -1256,6 +1372,54 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
func @argmax_dyn_non_axis(%arg0 : tensor<3x?xi32>) -> () {
// CHECK: %[[CST1:.+]] = arith.constant 1
// CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST1]]
// CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [%[[DYN]]]
// CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
// CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]])
// CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [%[[DYN]]]
// CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
// CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]])
// CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<?xi32>, tensor<?xi32>)
// CHECK: %[[IDX:.+]] = linalg.index 0
// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]]
// CHECK: %[[CMP:.+]] = arith.cmpi sgt, %arg1, %arg3
// CHECK: %[[SELECT_VAL:.+]] = select %[[CMP]], %arg1, %arg3
// CHECK: %[[SELECT_IDX:.+]] = select %[[CMP]], %[[CAST]], %arg2
// CHECK: linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
%0 = "tosa.argmax"(%arg0) { axis = 0 : i64} : (tensor<3x?xi32>) -> (tensor<?xi32>)
return
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
func @argmax_dyn_axis(%arg0 : tensor<3x?xi32>) -> () {
// CHECK: %[[IDX_INIT:.+]] = linalg.init_tensor [3]
// CHECK: %[[IDX_MIN:.+]] = arith.constant 0 : i32
// CHECK: %[[IDX_FILL:.+]] = linalg.fill(%[[IDX_MIN]], %[[IDX_INIT]])
// CHECK: %[[VAL_INIT:.+]] = linalg.init_tensor [3]
// CHECK: %[[VAL_MIN:.+]] = arith.constant -2147483648
// CHECK: %[[VAL_FILL:.+]] = linalg.fill(%[[VAL_MIN]], %[[VAL_INIT]])
// CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<3x?xi32>) outs(%[[IDX_FILL]], %[[VAL_FILL]] : tensor<3xi32>, tensor<3xi32>)
// CHECK: %[[IDX:.+]] = linalg.index 1
// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]]
// CHECK: %[[CMP:.+]] = arith.cmpi sgt, %arg1, %arg3
// CHECK: %[[SELECT_VAL:.+]] = select %[[CMP]], %arg1, %arg3
// CHECK: %[[SELECT_IDX:.+]] = select %[[CMP]], %[[CAST]], %arg2
// CHECK: linalg.yield %[[SELECT_IDX]], %[[SELECT_VAL]]
%0 = "tosa.argmax"(%arg0) { axis = 1 : i64} : (tensor<3x?xi32>) -> (tensor<3xi32>)
return
}
// -----
// CHECK-LABEL: @gather_float
func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
@@ -1349,6 +1513,40 @@ func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
// -----
// CHECK-LABEL: @table8_dyn
func @table8_dyn(%arg0: tensor<?xi8>, %arg1: tensor<512xi8>) -> () {
// CHECK: %[[CST0:.+]] = arith.constant 0
// CHECK: %[[DYN:.+]] = tensor.dim %arg0, %[[CST0]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DYN]]]
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<?xi8>) outs(%[[INIT]] : tensor<?xi8>)
// CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
// CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
// CHECK: %[[OFFSET:.+]] = arith.constant 128
// CHECK: %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[ADD]]]
// CHECK: linalg.yield %[[EXTRACT]]
%0 = "tosa.table"(%arg0, %arg1) : (tensor<?xi8>, tensor<512xi8>) -> (tensor<?xi8>)
return
}
// -----
// CHECK-LABEL: @table8_dyn_table
func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>)
// CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
// CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG_IN]]
// CHECK: %[[OFFSET:.+]] = arith.constant 128
// CHECK: %[[ADD:.+]] = arith.addi %[[CAST]], %[[OFFSET]]
// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[ADD]]]
// CHECK: linalg.yield %[[EXTRACT]]
%0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi8>, tensor<?xi8>) -> (tensor<6xi8>)
return
}
// -----
// CHECK-LABEL: @resize_nearest
func @resize_nearest(%input: tensor<1x2x2x1xf32>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 4, 4, 1]