[MLIR][LINALG] Add canonicalization pattern in `linalg.generic` op for static shape inference.
This commit adds a canonicalization pattern to the `linalg.generic` op for static shape inference. If any input or output has a static shape, or is cast from a tensor with a static shape, then the shapes of all inputs and outputs can be inferred through the affine map of that static-shaped operand.

Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D118929
commit 1a2bb03eda (parent c1e4e01945)
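To make the intended rewrite concrete, here is a minimal hand-written sketch (not part of the patch; the function name, value names, and shapes are hypothetical, and the identity indexing map is assumed). Because the static operand binds d0, d1, d2 to 2, 3, 4 through its indexing map, the pattern can cast the remaining operands to the inferred static type and cast the result back to its original type so existing uses stay valid.

#id = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// Before canonicalization: only %arg0 carries static shape information.
func @example(%arg0: tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>,
              %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %0 = linalg.generic {
    indexing_maps = [#id, #id, #id],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
    outs(%init : tensor<?x?x?xf32>) {
  ^bb0(%a : f32, %b : f32, %c : f32):
    %sum = arith.addf %a, %b : f32
    linalg.yield %sum : f32
  } -> (tensor<?x?x?xf32>)
  return %0 : tensor<?x?x?xf32>
}

// After canonicalization, the body of @example would roughly become:
//   %cast   = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
//   %outc   = tensor.cast %init : tensor<?x?x?xf32> to tensor<2x3x4xf32>
//   %static = linalg.generic ... ins(%arg0, %cast : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
//                                outs(%outc : tensor<2x3x4xf32>) ... -> (tensor<2x3x4xf32>)
//   %0      = tensor.cast %static : tensor<2x3x4xf32> to tensor<?x?x?xf32>

The tests added below exercise this behavior for static inputs, cast inputs, and cast outputs.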
@@ -841,11 +841,169 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
+
+/// For each operand in `operands`, this function maps the static sizes of its
+/// dimensions to their affine dim expressions.
+static void populateMap(GenericOp genericOp, ArrayRef<OpOperand *> operands,
+                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
+  for (OpOperand *opOperand : operands) {
+    if (genericOp.isScalar(opOperand))
+      continue;
+    Value src = opOperand->get();
+    auto sourceType = src.getType().cast<RankedTensorType>();
+    auto sourceMap = genericOp.getTiedIndexingMap(opOperand);
+
+    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
+    // a `tensor.cast` operation and the source of the cast operation has a
+    // static shape, then assign it to the `sourceShape`.
+    auto parentOp = src.getDefiningOp();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    if (parentOp) {
+      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
+        Value castSource = castOp.source();
+        auto castSourceType = castSource.getType().cast<RankedTensorType>();
+        if (castSourceType.hasStaticShape())
+          sourceShape = castSourceType.getShape();
+      }
+    }
+
+    // If a dimension of the source shape is static, map the affine dim
+    // expression to the known static size.
+    for (unsigned i = 0; i < sourceShape.size(); i++) {
+      if (sourceType.isDynamicDim(i))
+        continue;
+      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
+        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
+    }
+  }
+}
+
+/// Creates a new operand w.r.t. `opOperand` of `genericOp` with the static
+/// sizes mapped in `affineExprToSize`. New operands are created in
+/// `newOperands` and their result types are stored in `resultTypes`. If
+/// `opOperand` requires no change then `changeNeeded` is false and the same
+/// operand is added to the `newOperands` list.
+static void createNewOperandWithStaticSizes(
+    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
+    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, GenericOp genericOp,
+    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
+    bool &changeNeeded) {
+  Value src = opOperand->get();
+  newOperands.push_back(src);
+  if (genericOp.isScalar(opOperand))
+    return;
+  auto sourceType = src.getType().cast<RankedTensorType>();
+  Type resultType = sourceType;
+  if (sourceType.hasStaticShape() && genericOp.isOutputTensor(opOperand)) {
+    resultTypes.push_back(resultType);
+    return;
+  }
+  ArrayRef<int64_t> sourceShape = sourceType.getShape();
+  AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
+  SmallVector<int64_t> newShape;
+  // If the operand is updated with a new shape, `newOperandNeeded` will be
+  // true.
+  bool newOperandNeeded = false;
+  for (unsigned i = 0; i < sourceShape.size(); i++) {
+    int64_t dimShape = sourceShape[i];
+    AffineExpr dimExpr = sourceMap.getResult(i);
+    if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
+        !sourceType.isDynamicDim(i)) {
+      newShape.push_back(dimShape);
+      continue;
+    }
+    // The dimension has a dynamic shape and the corresponding affine dim
+    // expression is present in the map. So assign the size for the given
+    // affine dim expression to the dimension.
+    newShape.push_back(affineExprToSize[dimExpr]);
+    newOperandNeeded = true;
+  }
+  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
+  if (newOperandNeeded) {
+    changeNeeded = true;
+    // Get the new operand value given its size and element type by
+    // casting it.
+    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
+    unsigned index = opOperand->getOperandNumber();
+    newOperands[index] = newOperand;
+  }
+  if (genericOp.isOutputTensor(opOperand))
+    resultTypes.push_back(resultType);
+}
+
+/// Static shapes for the operands can be inferred if any one of the operands
+/// has a static shape. This can be done by referring to the affine dim
+/// expressions for the operand.
+struct InferStaticShapeOfOperands : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    // Maps must be projected permutations.
+    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
+          return !map.isProjectedPermutation();
+        }))
+      return failure();
+
+    // Maps affine dim expressions to the static size of that dimension.
+    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
+    Location loc = genericOp.getLoc();
+
+    // For each affine dim expression, check if the size is known. If it is
+    // known, add it to the map.
+    populateMap(genericOp, genericOp.getInputAndOutputOperands(),
+                affineExprToSize);
+
+    SmallVector<Value> newOperands;
+    SmallVector<Type> resultTypes;
+
+    // `changeNeeded` is `false` if the operands of `genericOp` require no
+    // change in their types.
+    bool changeNeeded = false;
+    newOperands.reserve(genericOp.getNumInputsAndOutputs());
+    resultTypes.reserve(genericOp.getNumOutputs());
+
+    // Iterate over all the operands and update the static sizes.
+    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
+      createNewOperandWithStaticSizes(loc, rewriter, opOperand,
+                                      affineExprToSize, genericOp, newOperands,
+                                      resultTypes, changeNeeded);
+    }
+
+    // If the generic op already has all the required static information, no
+    // canonicalization is needed.
+    if (!changeNeeded)
+      return failure();
+
+    // Clone the op.
+    Operation *newOp =
+        cast<linalg::LinalgOp>(genericOp.getOperation())
+            .clone(rewriter, genericOp->getLoc(), resultTypes, newOperands);
+    SmallVector<Value> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto it : llvm::zip(genericOp->getResults(), newOp->getResults())) {
+      Value newResult = std::get<1>(it);
+      Value oldResult = std::get<0>(it);
+      Type newType = newResult.getType();
+      Type oldType = oldResult.getType();
+      replacements.push_back(
+          (newType != oldType)
+              ? rewriter.create<tensor::CastOp>(loc, newType, newResult)
+              : newResult);
+    }
+    rewriter.replaceOp(genericOp, replacements);
+    return success();
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
+              InferStaticShapeOfOperands>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -650,3 +650,133 @@ func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
   } : tensor<400x273xf32> to tensor<412x276xf32>
   return %pad : tensor<412x276xf32>
 }
+
+// -----
+
+// The tests below verify that static shape information is propagated through all the operands of a generic op:
+// 1. When one of the inputs of the generic op has static shape info and has no cast source.
+// 2. When one of the inputs of the generic op has static shape info and comes from a tensor.cast operation.
+// 3. When one of the outputs of the generic op has static shape info and comes from a tensor.cast operation.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_without_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %5 = tensor.cast %4 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %5 : tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_with_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %5 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %4 : tensor<2x3x4xf32>, tensor<2x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %6 = tensor.cast %5 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %6 : tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_output_with_cast
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg2, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg2, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg2, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %3 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  %5 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %5 : tensor<?x?x?xf32>, tensor<2x?x?xf32>)
+    outs(%4 : tensor<2x3x4xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %9 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<2x3x4xf32>)
+  return %6 : tensor<2x3x4xf32>
+  // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+// This test checks the folding of the tensor.cast operation when the source value of the cast
+// has more static information than the destination value.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @cast_source
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg0 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
+  %5 = tensor.cast %arg1 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%4, %5 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %7 = tensor.cast %6 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %7 : tensor<2x3x4xf32>
+  // CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+  // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
@@ -533,27 +533,28 @@ func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
 
 // -----
 
-func @no_fuse_mismatched_dynamism(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
-  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
-  %1 = linalg.init_tensor [1] : tensor<1xi64>
+func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi64>) -> tensor<2xi64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64>
+  %1 = linalg.init_tensor [2] : tensor<2xi64>
   %2 = linalg.generic
     {indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
-    ins(%0, %arg1 : tensor<1xi64>, tensor<?xi64>)
-    outs(%1 : tensor<1xi64>) {
+    ins(%0, %arg1 : tensor<2xi64>, tensor<?xi64>)
+    outs(%1 : tensor<2xi64>) {
   ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):
     %3 = arith.addi %arg4, %arg5 : i64
     linalg.yield %3 : i64
-  } -> tensor<1xi64>
-  return %2 : tensor<1xi64>
+  } -> tensor<2xi64>
+  return %2 : tensor<2xi64>
 }
 
 // CHECK: func @no_fuse_mismatched_dynamism
-// CHECK-SAME:  %[[ARG0:.+]]: tensor<1x1xi64>
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<2x1xi64>
 // CHECK-SAME:  %[[ARG1:.+]]: tensor<?xi64>
 // CHECK:       %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK:       %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?xi64> to tensor<2xi64>
 // CHECK:       %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME:    ins(%[[RESHAPE]], %[[ARG1]] : tensor<1xi64>, tensor<?xi64>)
+// CHECK-SAME:    ins(%[[RESHAPE]], %[[CAST]] : tensor<2xi64>, tensor<2xi64>)
 // CHECK:       return %[[GENERIC]]