forked from OSchip/llvm-project
[mlir] Propagate arith.index_cast past tensor.extract
If we are extracting, it is more useful to push the index_cast past the extraction. This increases the chance that the tensor.extract can be evaluated at compile time. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D118204
This commit is contained in:
parent
223f9dea3d
commit
7c984be21a
|
@ -425,11 +425,51 @@ struct ExtractElementFromTensorFromElements
|
|||
}
|
||||
};
|
||||
|
||||
// Pushes the index_casts that occur before extractions to after the extract.
|
||||
// This minimizes type conversion in some cases and enables the extract
|
||||
// canonicalizer. This changes:
|
||||
//
|
||||
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
|
||||
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
|
||||
//
|
||||
// to the following:
|
||||
//
|
||||
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
|
||||
// %cast = arith.index_cast %extract : i32 to index
|
||||
//
|
||||
// so that the extract reads directly from the original, uncast tensor.
|
||||
//
|
||||
// Consider expanding this to a template and handle all tensor cast operations.
|
||||
struct ExtractElementFromIndexCast
|
||||
: public OpRewritePattern<tensor::ExtractOp> {
|
||||
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Location loc = extract.getLoc();
|
||||
auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>();
|
||||
if (!indexCast)
|
||||
return failure();
|
||||
|
||||
Type elementTy = getElementTypeOrSelf(indexCast.getIn());
|
||||
|
||||
auto newExtract = rewriter.create<tensor::ExtractOp>(
|
||||
loc, elementTy, indexCast.getIn(), extract.indices());
|
||||
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
|
||||
newExtract);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Adds the canonicalization patterns for FromElementsOp to `results`.
//
// Both patterns are registered through a single add<> call so that each one
// is added exactly once; a separate add<> call for
// ExtractElementFromTensorFromElements alongside this one would register that
// pattern twice.
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results
      .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
          context);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1200,3 +1200,17 @@ func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
|
|||
%1 = tensor.expand_shape %0 [] : tensor<i32> into tensor<1xi32>
|
||||
return %1 : tensor<1xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @propogate_index_cast
|
||||
func @propogate_index_cast(%arg0: tensor<1xi32>) -> index {
  // The index_cast on the whole tensor is expected to move past the extract:
  // the extract reads %arg0 directly and only the extracted element is cast.
  // CHECK: %[[IDX:.+]] = arith.constant 0
  // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32>
  // CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]]
  // CHECK: return %[[CAST]] : index
  %zero = arith.constant 0 : index
  %casted = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex>
  %elem = tensor.extract %casted[%zero] : tensor<1xindex>
  return %elem : index
}
|
||||
|
|
Loading…
Reference in New Issue