[mlir][Linalg] NFC - Refactor vector.broadcast op verification logic and make it available as a precondition in Linalg vectorization.
Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D111558
commit 8f1650cb65 (parent 40d85f16c4)
mlir/include/mlir/Dialect/Vector/VectorOps.h

@@ -40,6 +40,18 @@ namespace detail {
 struct BitmaskEnumStorage;
 } // namespace detail
 
+/// Return whether `srcType` can be broadcast to `dstVectorType` under the
+/// semantics of the `vector.broadcast` op.
+enum class BroadcastableToResult {
+  Success = 0,
+  SourceRankHigher = 1,
+  DimensionMismatch = 2,
+  SourceTypeNotAVector = 3
+};
+BroadcastableToResult
+isBroadcastableTo(Type srcType, VectorType dstVectorType,
+                  std::pair<int, int> *mismatchingDims = nullptr);
+
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(
     RewritePatternSet &patterns);
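As a quick usage sketch (not part of the diff): a hypothetical caller could use the new helper to explain why a broadcast would be rejected. `srcType` and `dstVectorType` are assumed to be in scope, and the surrounding names are illustrative only.

    // Hypothetical caller of the new API; all surrounding names are assumed.
    std::pair<int, int> mismatch;
    mlir::vector::BroadcastableToResult res =
        mlir::vector::isBroadcastableTo(srcType, dstVectorType, &mismatch);
    if (res == mlir::vector::BroadcastableToResult::DimensionMismatch)
      llvm::errs() << "dimension mismatch: " << mismatch.first << " vs. "
                   << mismatch.second << "\n";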
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

@@ -147,24 +147,20 @@ matchLinalgReduction(OpOperand *outputOperand) {
   return getKindForOp(combinerOps[0]);
 }
 
-/// If `value` of assumed VectorType has a shape different than `shape`, try to
-/// build and return a new vector.broadcast to `shape`.
-/// Otherwise, just return `value`.
-// TODO: this is best effort atm and there is currently no guarantee of
-// correctness for the broadcast semantics.
+/// Broadcast `value` to a vector of `shape` if possible. Return value
+/// otherwise.
 static Value broadcastIfNeeded(OpBuilder &b, Value value,
                                ArrayRef<int64_t> shape) {
-  unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
-                                        [](int64_t val) { return val > 1; });
-  auto vecType = value.getType().dyn_cast<VectorType>();
-  if (shape.empty() ||
-      (vecType != nullptr &&
-       (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
+  // If no shape to broadcast to, just return `value`.
+  if (shape.empty())
     return value;
-  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
-                                                   : value.getType());
-  return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
-                                       newVecType, value);
+  VectorType targetVectorType =
+      VectorType::get(shape, getElementTypeOrSelf(value));
+  if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
+      vector::BroadcastableToResult::Success)
+    return value;
+  Location loc = b.getInsertionPoint()->getLoc();
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
 /// If value of assumed VectorType has a shape different than `shape`, build and
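Worth noting as a design choice: instead of re-deriving broadcast legality locally (the old `numDimsGtOne` heuristic, whose TODO admitted there was no correctness guarantee), vectorization now consults `vector::isBroadcastableTo`, the same predicate the `vector.broadcast` verifier uses below, so the precondition and the verifier can no longer drift apart. Switching to `createOrFold` also lets a broadcast to the value's own type fold away immediately instead of materializing a no-op op.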
@@ -688,7 +684,8 @@ struct GenericPadTensorOpVectorizationPattern
     // by TransferReadOp, but TransferReadOp supports only constant padding.
     auto padValue = padOp.getConstantPaddingValue();
     if (!padValue) {
-      if (!sourceType.hasStaticShape()) return failure();
+      if (!sourceType.hasStaticShape())
+        return failure();
       // Create dummy padding value.
       auto elemType = sourceType.getElementType();
       padValue = rewriter.create<ConstantOp>(padOp.getLoc(), elemType,
@@ -733,14 +730,14 @@ struct GenericPadTensorOpVectorizationPattern
 
     // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
     // tensor, write directly to the FillOp's operand.
-    if (llvm::equal(vecShape, resultType.getShape())
-        && llvm::all_of(writeInBounds, [](bool b) { return b; }))
+    if (llvm::equal(vecShape, resultType.getShape()) &&
+        llvm::all_of(writeInBounds, [](bool b) { return b; }))
       if (auto fill = dest.getDefiningOp<FillOp>())
         dest = fill.output();
 
     // Generate TransferWriteOp.
-    auto writeIndices = ofrToIndexValues(
-        rewriter, padOp.getLoc(), padOp.getMixedLowPad());
+    auto writeIndices =
+        ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         padOp, read, dest, writeIndices, writeInBounds);
 
@@ -764,9 +761,9 @@ struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
     return success(changed);
   }
 
-  protected:
-  virtual LogicalResult rewriteUser(
-      PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
+protected:
+  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
+                                    PadTensorOp padOp, OpTy op) const = 0;
 };
 
 /// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
@@ -790,18 +787,21 @@ struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
 /// - Single, scalar padding value.
 struct PadTensorOpVectorizationWithTransferReadPattern
     : public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
-  using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
-      ::VectorizePadTensorOpUserPattern;
+  using VectorizePadTensorOpUserPattern<
+      vector::TransferReadOp>::VectorizePadTensorOpUserPattern;
 
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             vector::TransferReadOp xferOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) return failure();
+    if (!padValue)
+      return failure();
     // Padding value of existing `xferOp` is unused.
-    if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
+    if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
+      return failure();
 
     rewriter.updateRootInPlace(xferOp, [&]() {
       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
@@ -847,24 +847,30 @@ struct PadTensorOpVectorizationWithTransferReadPattern
 /// - Single, scalar padding value.
 struct PadTensorOpVectorizationWithTransferWritePattern
     : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
-  using VectorizePadTensorOpUserPattern<vector::TransferWriteOp>
-      ::VectorizePadTensorOpUserPattern;
+  using VectorizePadTensorOpUserPattern<
+      vector::TransferWriteOp>::VectorizePadTensorOpUserPattern;
 
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             vector::TransferWriteOp xferOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) return failure();
+    if (!padValue)
+      return failure();
     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
-    if (!xferOp->hasOneUse()) return failure();
+    if (!xferOp->hasOneUse())
+      return failure();
     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
-    if (!trimPadding) return failure();
+    if (!trimPadding)
+      return failure();
     // Only static zero offsets supported when trimming padding.
-    if (!trimPadding.hasZeroOffset()) return failure();
+    if (!trimPadding.hasZeroOffset())
+      return failure();
     // trimPadding must remove the amount of padding that was added earlier.
-    if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
+    if (!hasSameTensorSize(padOp.source(), trimPadding))
+      return failure();
 
     // Insert the new TransferWriteOp at position of the old TransferWriteOp.
     rewriter.setInsertionPoint(xferOp);
@@ -894,14 +900,17 @@ struct PadTensorOpVectorizationWithTransferWritePattern
     // If the input to PadTensorOp is a CastOp, try with both CastOp result
     // and CastOp operand.
     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
-      if (hasSameTensorSize(castOp.source(), afterTrimming)) return true;
+      if (hasSameTensorSize(castOp.source(), afterTrimming))
+        return true;
 
     auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
     auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
     // Only RankedTensorType supported.
-    if (!t1 || !t2) return false;
+    if (!t1 || !t2)
+      return false;
     // Rank of both values must be the same.
-    if (t1.getRank() != t2.getRank()) return false;
+    if (t1.getRank() != t2.getRank())
+      return false;
 
     // All static dimensions must be the same. Mixed cases (e.g., dimension
     // static in `t1` but dynamic in `t2`) are not supported.
@@ -913,7 +922,8 @@ struct PadTensorOpVectorizationWithTransferWritePattern
     }
 
     // Nothing more to check if all dimensions are static.
-    if (t1.getNumDynamicDims() == 0) return true;
+    if (t1.getNumDynamicDims() == 0)
+      return true;
 
     // All dynamic sizes must be the same. The only supported case at the moment
     // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
@@ -925,29 +935,33 @@ struct PadTensorOpVectorizationWithTransferWritePattern
 
     assert(static_cast<size_t>(t1.getRank()) ==
            beforeSlice.getMixedSizes().size());
-    assert(static_cast<size_t>(t2.getRank())
-           == afterTrimming.getMixedSizes().size());
+    assert(static_cast<size_t>(t2.getRank()) ==
+           afterTrimming.getMixedSizes().size());
 
     for (unsigned i = 0; i < t1.getRank(); ++i) {
       // Skip static dimensions.
-      if (!t1.isDynamicDim(i)) continue;
+      if (!t1.isDynamicDim(i))
+        continue;
       auto size1 = beforeSlice.getMixedSizes()[i];
       auto size2 = afterTrimming.getMixedSizes()[i];
 
       // Case 1: Same value or same constant int.
-      if (isEqualConstantIntOrValue(size1, size2)) continue;
+      if (isEqualConstantIntOrValue(size1, size2))
+        continue;
 
       // Other cases: Take a deeper look at defining ops of values.
       auto v1 = size1.dyn_cast<Value>();
       auto v2 = size2.dyn_cast<Value>();
-      if (!v1 || !v2) return false;
+      if (!v1 || !v2)
+        return false;
 
       // Case 2: Both values are identical AffineMinOps. (Should not happen if
       // CSE is run.)
       auto minOp1 = v1.getDefiningOp<AffineMinOp>();
       auto minOp2 = v2.getDefiningOp<AffineMinOp>();
-      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap()
-          && minOp1.operands() == minOp2.operands()) continue;
+      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
+          minOp1.operands() == minOp2.operands())
+        continue;
 
       // Add additional cases as needed.
     }
@@ -987,9 +1001,11 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             tensor::InsertSliceOp insertOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Only unit stride supported.
-    if (!insertOp.hasUnitStride()) return failure();
+    if (!insertOp.hasUnitStride())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
     if (!padValue)
@@ -1038,8 +1054,8 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
 
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  patterns.add<GenericPadTensorOpVectorizationPattern>(
-      patterns.getContext(), baseBenefit);
+  patterns.add<GenericPadTensorOpVectorizationPattern>(patterns.getContext(),
+                                                       baseBenefit);
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
                PadTensorOpVectorizationWithTransferWritePattern,
mlir/lib/Dialect/Vector/VectorOps.cpp

@@ -1321,31 +1321,59 @@ Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(BroadcastOp op) {
-  VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
-  VectorType dstVectorType = op.getVectorType();
-  // Scalar to vector broadcast is always valid. A vector
-  // to vector broadcast needs some additional checking.
-  if (srcVectorType) {
-    int64_t srcRank = srcVectorType.getRank();
-    int64_t dstRank = dstVectorType.getRank();
-    if (srcRank > dstRank)
-      return op.emitOpError("source rank higher than destination rank");
-    // Source has an exact match or singleton value for all trailing dimensions
-    // (all leading dimensions are simply duplicated).
-    int64_t lead = dstRank - srcRank;
-    for (int64_t r = 0; r < srcRank; ++r) {
-      int64_t srcDim = srcVectorType.getDimSize(r);
-      int64_t dstDim = dstVectorType.getDimSize(lead + r);
-      if (srcDim != 1 && srcDim != dstDim)
-        return op.emitOpError("dimension mismatch (")
-               << srcDim << " vs. " << dstDim << ")";
+BroadcastableToResult
+mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
+                                std::pair<int, int> *mismatchingDims) {
+  // Broadcast scalar to vector of the same element type.
+  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
+      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
+    return BroadcastableToResult::Success;
+  // From now on, only vectors broadcast.
+  VectorType srcVectorType = srcType.dyn_cast<VectorType>();
+  if (!srcVectorType)
+    return BroadcastableToResult::SourceTypeNotAVector;
+
+  int64_t srcRank = srcVectorType.getRank();
+  int64_t dstRank = dstVectorType.getRank();
+  if (srcRank > dstRank)
+    return BroadcastableToResult::SourceRankHigher;
+  // Source has an exact match or singleton value for all trailing dimensions
+  // (all leading dimensions are simply duplicated).
+  int64_t lead = dstRank - srcRank;
+  for (int64_t r = 0; r < srcRank; ++r) {
+    int64_t srcDim = srcVectorType.getDimSize(r);
+    int64_t dstDim = dstVectorType.getDimSize(lead + r);
+    if (srcDim != 1 && srcDim != dstDim) {
+      if (mismatchingDims) {
+        mismatchingDims->first = srcDim;
+        mismatchingDims->second = dstDim;
+      }
+      return BroadcastableToResult::DimensionMismatch;
     }
   }
-  return success();
+
+  return BroadcastableToResult::Success;
+}
+
+static LogicalResult verify(BroadcastOp op) {
+  std::pair<int, int> mismatchingDims;
+  BroadcastableToResult res = isBroadcastableTo(
+      op.getSourceType(), op.getVectorType(), &mismatchingDims);
+  if (res == BroadcastableToResult::Success)
+    return success();
+  if (res == BroadcastableToResult::SourceRankHigher)
+    return op.emitOpError("source rank higher than destination rank");
+  if (res == BroadcastableToResult::DimensionMismatch)
+    return op.emitOpError("dimension mismatch (")
+           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+  if (res == BroadcastableToResult::SourceTypeNotAVector)
+    return op.emitOpError("source type is not a vector");
+  llvm_unreachable("unexpected vector.broadcast op error");
 }
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   if (getSourceType() == getVectorType())
     return source();
   if (!operands[0])
     return {};
   auto vectorType = getVectorType();
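To make the trailing-dimension rule concrete, here is a small, MLIR-free sketch of the same check over plain shape vectors; the function name and shapes are illustrative, not part of the commit:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Mirrors the loop in isBroadcastableTo: each source dimension must be 1
    // or equal to the corresponding *trailing* destination dimension; extra
    // leading destination dimensions are simply duplicated.
    static bool broadcastCompatible(const std::vector<int64_t> &src,
                                    const std::vector<int64_t> &dst) {
      if (src.size() > dst.size())
        return false; // SourceRankHigher
      std::size_t lead = dst.size() - src.size();
      for (std::size_t r = 0; r < src.size(); ++r)
        if (src[r] != 1 && src[r] != dst[lead + r])
          return false; // DimensionMismatch
      return true;
    }

    int main() {
      assert(broadcastCompatible({4}, {3, 4}));     // vector<4xf32> -> vector<3x4xf32>
      assert(broadcastCompatible({1, 8}, {2, 8}));  // singleton dim stretches
      assert(!broadcastCompatible({2, 8}, {3, 8})); // 2 vs. 3: mismatch
      assert(!broadcastCompatible({4, 8}, {8}));    // source rank higher
    }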
mlir/test/Dialect/Vector/invalid.mlir

@@ -30,6 +30,13 @@ func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
 
 // -----
 
+func @broadcast_unknown(%arg0: memref<4x8xf32>) {
+  // expected-error@+1 {{'vector.broadcast' op source type is not a vector}}
+  %1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>
+}
+
+// -----
+
 func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) {
   // expected-error@+1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}}
   %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xi32>
mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

@@ -493,7 +493,6 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x
 func @cast_away_broadcast_leading_one_dims(
   %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
   (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
-  // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32>
   // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
   %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
   // CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32>
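The dropped CHECK line above is consistent with folding broadcasts at creation time: a `vector.broadcast` whose source and destination types match (`vector<8xf32>` to `vector<8xf32>`) folds to its operand (see the `fold` shown in the previous hunk), so the test now only expects the `vector.shape_cast`.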