forked from OSchip/llvm-project
[mlir] Add canonicalization from `tensor_cast` to `dim` op.
Fold a `tensor_cast` -> `dim` to take the `dim` of the original tensor. Differential Revision: https://reviews.llvm.org/D93492
This commit is contained in:
parent
3d56644f18
commit
de031216bf
|
@ -1472,11 +1472,29 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Fold dim of a dim of a cast into the the dim of the source of the tensor
|
||||
/// cast.
|
||||
template <typename CastOpTy>
|
||||
struct DimOfCastOp : public OpRewritePattern<DimOp> {
|
||||
using OpRewritePattern<DimOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
|
||||
if (!castOp)
|
||||
return failure();
|
||||
Value newSource = castOp.getOperand();
|
||||
rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace.
|
||||
|
||||
void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<DimOfMemRefReshape>(context);
|
||||
results.insert<DimOfMemRefReshape, DimOfCastOp<TensorCastOp>>(context);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
@ -115,3 +115,19 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
|
|||
%1 = dim %0, %c3 : memref<*xf32>
|
||||
return %1 : index
|
||||
}
|
||||
|
||||
// Test case: Folding dim(tensor_cast %0, %idx) -> dim %0, %idx
|
||||
// CHECK-LABEL: func @fold_dim_of_tensor_cast
|
||||
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
|
||||
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
|
||||
// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C1]]
|
||||
// CHECK-NEXT: return %[[C4]], %[[T0]]
|
||||
func @fold_dim_of_tensor_cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%0 = tensor_cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
|
||||
%1 = dim %0, %c0 : tensor<?x?xf32>
|
||||
%2 = dim %0, %c1 : tensor<?x?xf32>
|
||||
return %1, %2: index, index
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue