forked from OSchip/llvm-project
[MLIR][Shape] Concretize broadcast result type if possible
As a canonicalization, infer the resulting shape rank if possible. Differential Revision: https://reviews.llvm.org/D102068
This commit is contained in:
parent
541f107871
commit
a81e45b8bc
|
@ -29,7 +29,8 @@ class PatternRewriter;
|
|||
namespace shape {
|
||||
|
||||
/// Alias type for extent tensors.
|
||||
RankedTensorType getExtentTensorType(MLIRContext *ctx);
|
||||
RankedTensorType getExtentTensorType(MLIRContext *ctx,
|
||||
int64_t rank = ShapedType::kDynamicSize);
|
||||
|
||||
// Check if a type is an extent tensor, e.g., tensor<?xindex>.
|
||||
bool isExtentTensorType(Type);
|
||||
|
|
|
@ -27,8 +27,8 @@ namespace {
|
|||
#include "ShapeCanonicalization.inc"
|
||||
}
|
||||
|
||||
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
|
||||
return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
|
||||
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
|
||||
return RankedTensorType::get({rank}, IndexType::get(ctx));
|
||||
}
|
||||
|
||||
bool shape::isExtentTensorType(Type type) {
|
||||
|
@ -660,11 +660,42 @@ struct CanonicalizeCastExtentTensorOperandsPattern
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct BroadcastConcretizeResultTypePattern
|
||||
: public OpRewritePattern<BroadcastOp> {
|
||||
using OpRewritePattern<BroadcastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(BroadcastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only concretize dynamic extent tensor result types.
|
||||
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultTy || !resultTy.isDynamicDim(0))
|
||||
return failure();
|
||||
|
||||
// Infer resulting shape rank if possible.
|
||||
int64_t maxRank = 0;
|
||||
for (Value shape : op.shapes()) {
|
||||
if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
|
||||
// Cannot infer resulting shape rank if any operand is dynamically
|
||||
// ranked.
|
||||
if (extentTensorTy.isDynamicDim(0))
|
||||
return failure();
|
||||
maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
|
||||
}
|
||||
}
|
||||
|
||||
auto newOp = rewriter.create<BroadcastOp>(
|
||||
op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<BroadcastFoldConstantOperandsPattern,
|
||||
patterns.add<BroadcastConcretizeResultTypePattern,
|
||||
BroadcastFoldConstantOperandsPattern,
|
||||
BroadcastForwardSingleOperandPattern,
|
||||
CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
|
||||
RemoveDuplicateOperandsPattern<BroadcastOp>,
|
||||
|
|
|
@ -1344,7 +1344,8 @@ func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
|
|||
%arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
|
||||
// CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
|
||||
// CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
|
||||
// CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
|
||||
// CHECK: %[[UNCAST_RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
|
||||
// CHECK: %[[RES:.*]] = tensor.cast %[[UNCAST_RES]] : tensor<3xindex> to tensor<?xindex>
|
||||
// CHECK: return %[[WIT]], %[[RES]]
|
||||
%0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
|
||||
%1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
|
||||
|
@ -1353,3 +1354,17 @@ func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
|
|||
-> tensor<?xindex>
|
||||
return %2, %3 : !shape.witness, tensor<?xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @concretize_broadcast_result_type
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2xindex>, %[[ARG1:.*]]: tensor<3xindex>)
|
||||
func @concretize_broadcast_result_type(%arg0 : tensor<2xindex>,
|
||||
%arg1 : tensor<3xindex>) -> tensor<?xindex> {
|
||||
// CHECK: %[[CONCR:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
|
||||
// CHECK: %[[RES:.*]] = tensor.cast %[[CONCR]] : tensor<3xindex> to tensor<?xindex>
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = shape.broadcast %arg0, %arg1 : tensor<2xindex>, tensor<3xindex>
|
||||
-> tensor<?xindex>
|
||||
return %0 : tensor<?xindex>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue