[mlir] Propagate arith.index_cast past tensor.extract

If we are extracting it is more useful to push the index_cast past the
extraction. This increases the chance the tensor.extract can evaluated at
compile time.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D118204
This commit is contained in:
Rob Suderman 2022-01-25 22:15:55 -08:00
parent 223f9dea3d
commit 7c984be21a
2 changed files with 55 additions and 1 deletions

View File

@ -425,11 +425,51 @@ struct ExtractElementFromTensorFromElements
}
};
// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
//
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
//
// to the following:
//
// %extract = tensor.extract %tensor[%index] : tensor<1xindex>
// %cast = arith.index_cast %extract : i32 to index
//
// to just %element.
//
// Consider expanding this to a template and handle all tensor cast operations.
struct ExtractElementFromIndexCast
: public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
Location loc = extract.getLoc();
auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>();
if (!indexCast)
return failure();
Type elementTy = getElementTypeOrSelf(indexCast.getIn());
auto newExtract = rewriter.create<tensor::ExtractOp>(
loc, elementTy, indexCast.getIn(), extract.indices());
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
newExtract);
return success();
}
};
} // namespace
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractElementFromTensorFromElements>(context);
results
.add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
context);
}
//===----------------------------------------------------------------------===//

View File

@ -1200,3 +1200,17 @@ func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
%1 = tensor.expand_shape %0 [] : tensor<i32> into tensor<1xi32>
return %1 : tensor<1xi32>
}
// -----
// CHECK-LABEL: func @propogate_index_cast
func @propogate_index_cast(%arg0: tensor<1xi32>) -> index {
// CHECK: %[[IDX:.+]] = arith.constant 0
// CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32>
// CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]]
// CHECK: return %[[CAST]] : index
%c0 = arith.constant 0 : index
%0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex>
%1 = tensor.extract %0[%c0] : tensor<1xindex>
return %1 : index
}