[mlir][Linalg] NFC - Refactor vector.broadcast op verification logic and make it available as a precondition in Linalg vectorization.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D111558
This commit is contained in:
Nicolas Vasilache 2021-10-12 11:04:11 +00:00
parent 40d85f16c4
commit 8f1650cb65
5 changed files with 134 additions and 72 deletions

View File

@ -40,6 +40,18 @@ namespace detail {
struct BitmaskEnumStorage;
} // namespace detail
/// Return whether `srcType` can be broadcast to `dstVectorType` under the
/// semantics of the `vector.broadcast` op.
enum class BroadcastableToResult {
Success = 0,
SourceRankHigher = 1,
DimensionMismatch = 2,
SourceTypeNotAVector = 3
};
BroadcastableToResult
isBroadcastableTo(Type srcType, VectorType dstVectorType,
std::pair<int, int> *mismatchingDims = nullptr);
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
RewritePatternSet &patterns);

View File

@ -147,24 +147,20 @@ matchLinalgReduction(OpOperand *outputOperand) {
return getKindForOp(combinerOps[0]);
}
/// If `value` of assumed VectorType has a shape different than `shape`, try to
/// build and return a new vector.broadcast to `shape`.
/// Otherwise, just return `value`.
// TODO: this is best effort atm and there is currently no guarantee of
// correctness for the broadcast semantics.
/// Broadcast `value` to a vector of `shape` if possible. Return value
/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value,
ArrayRef<int64_t> shape) {
unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
[](int64_t val) { return val > 1; });
auto vecType = value.getType().dyn_cast<VectorType>();
if (shape.empty() ||
(vecType != nullptr &&
(vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
// If no shape to broadcast to, just return `value`.
if (shape.empty())
return value;
auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
: value.getType());
return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
newVecType, value);
VectorType targetVectorType =
VectorType::get(shape, getElementTypeOrSelf(value));
if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
vector::BroadcastableToResult::Success)
return value;
Location loc = b.getInsertionPoint()->getLoc();
return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}
/// If value of assumed VectorType has a shape different than `shape`, build and
@ -688,7 +684,8 @@ struct GenericPadTensorOpVectorizationPattern
// by TransferReadOp, but TransferReadOp supports only constant padding.
auto padValue = padOp.getConstantPaddingValue();
if (!padValue) {
if (!sourceType.hasStaticShape()) return failure();
if (!sourceType.hasStaticShape())
return failure();
// Create dummy padding value.
auto elemType = sourceType.getElementType();
padValue = rewriter.create<ConstantOp>(padOp.getLoc(), elemType,
@ -733,14 +730,14 @@ struct GenericPadTensorOpVectorizationPattern
// If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
// tensor, write directly to the FillOp's operand.
if (llvm::equal(vecShape, resultType.getShape())
&& llvm::all_of(writeInBounds, [](bool b) { return b; }))
if (llvm::equal(vecShape, resultType.getShape()) &&
llvm::all_of(writeInBounds, [](bool b) { return b; }))
if (auto fill = dest.getDefiningOp<FillOp>())
dest = fill.output();
// Generate TransferWriteOp.
auto writeIndices = ofrToIndexValues(
rewriter, padOp.getLoc(), padOp.getMixedLowPad());
auto writeIndices =
ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
padOp, read, dest, writeIndices, writeInBounds);
@ -764,9 +761,9 @@ struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
return success(changed);
}
protected:
virtual LogicalResult rewriteUser(
PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
protected:
virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
PadTensorOp padOp, OpTy op) const = 0;
};
/// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
@ -790,18 +787,21 @@ struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
/// - Single, scalar padding value.
struct PadTensorOpVectorizationWithTransferReadPattern
: public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
::VectorizePadTensorOpUserPattern;
using VectorizePadTensorOpUserPattern<
vector::TransferReadOp>::VectorizePadTensorOpUserPattern;
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
vector::TransferReadOp xferOp) const override {
// Low padding must be static 0.
if (!padOp.hasZeroLowPad()) return failure();
if (!padOp.hasZeroLowPad())
return failure();
// Pad value must be a constant.
auto padValue = padOp.getConstantPaddingValue();
if (!padValue) return failure();
if (!padValue)
return failure();
// Padding value of existing `xferOp` is unused.
if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
return failure();
rewriter.updateRootInPlace(xferOp, [&]() {
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
@ -847,24 +847,30 @@ struct PadTensorOpVectorizationWithTransferReadPattern
/// - Single, scalar padding value.
struct PadTensorOpVectorizationWithTransferWritePattern
: public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
using VectorizePadTensorOpUserPattern<vector::TransferWriteOp>
::VectorizePadTensorOpUserPattern;
using VectorizePadTensorOpUserPattern<
vector::TransferWriteOp>::VectorizePadTensorOpUserPattern;
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
vector::TransferWriteOp xferOp) const override {
// Low padding must be static 0.
if (!padOp.hasZeroLowPad()) return failure();
if (!padOp.hasZeroLowPad())
return failure();
// Pad value must be a constant.
auto padValue = padOp.getConstantPaddingValue();
if (!padValue) return failure();
if (!padValue)
return failure();
// TransferWriteOp result must be directly consumed by an ExtractSliceOp.
if (!xferOp->hasOneUse()) return failure();
if (!xferOp->hasOneUse())
return failure();
auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
if (!trimPadding) return failure();
if (!trimPadding)
return failure();
// Only static zero offsets supported when trimming padding.
if (!trimPadding.hasZeroOffset()) return failure();
if (!trimPadding.hasZeroOffset())
return failure();
// trimPadding must remove the amount of padding that was added earlier.
if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
if (!hasSameTensorSize(padOp.source(), trimPadding))
return failure();
// Insert the new TransferWriteOp at position of the old TransferWriteOp.
rewriter.setInsertionPoint(xferOp);
@ -894,14 +900,17 @@ struct PadTensorOpVectorizationWithTransferWritePattern
// If the input to PadTensorOp is a CastOp, try with with both CastOp result
// and CastOp operand.
if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
if (hasSameTensorSize(castOp.source(), afterTrimming)) return true;
if (hasSameTensorSize(castOp.source(), afterTrimming))
return true;
auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
// Only RankedTensorType supported.
if (!t1 || !t2) return false;
if (!t1 || !t2)
return false;
// Rank of both values must be the same.
if (t1.getRank() != t2.getRank()) return false;
if (t1.getRank() != t2.getRank())
return false;
// All static dimensions must be the same. Mixed cases (e.g., dimension
// static in `t1` but dynamic in `t2`) are not supported.
@ -913,7 +922,8 @@ struct PadTensorOpVectorizationWithTransferWritePattern
}
// Nothing more to check if all dimensions are static.
if (t1.getNumDynamicDims() == 0) return true;
if (t1.getNumDynamicDims() == 0)
return true;
// All dynamic sizes must be the same. The only supported case at the moment
// is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
@ -925,29 +935,33 @@ struct PadTensorOpVectorizationWithTransferWritePattern
assert(static_cast<size_t>(t1.getRank()) ==
beforeSlice.getMixedSizes().size());
assert(static_cast<size_t>(t2.getRank())
== afterTrimming.getMixedSizes().size());
assert(static_cast<size_t>(t2.getRank()) ==
afterTrimming.getMixedSizes().size());
for (unsigned i = 0; i < t1.getRank(); ++i) {
// Skip static dimensions.
if (!t1.isDynamicDim(i)) continue;
if (!t1.isDynamicDim(i))
continue;
auto size1 = beforeSlice.getMixedSizes()[i];
auto size2 = afterTrimming.getMixedSizes()[i];
// Case 1: Same value or same constant int.
if (isEqualConstantIntOrValue(size1, size2)) continue;
if (isEqualConstantIntOrValue(size1, size2))
continue;
// Other cases: Take a deeper look at defining ops of values.
auto v1 = size1.dyn_cast<Value>();
auto v2 = size2.dyn_cast<Value>();
if (!v1 || !v2) return false;
if (!v1 || !v2)
return false;
// Case 2: Both values are identical AffineMinOps. (Should not happen if
// CSE is run.)
auto minOp1 = v1.getDefiningOp<AffineMinOp>();
auto minOp2 = v2.getDefiningOp<AffineMinOp>();
if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap()
&& minOp1.operands() == minOp2.operands()) continue;
if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
minOp1.operands() == minOp2.operands())
continue;
// Add additional cases as needed.
}
@ -987,9 +1001,11 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
tensor::InsertSliceOp insertOp) const override {
// Low padding must be static 0.
if (!padOp.hasZeroLowPad()) return failure();
if (!padOp.hasZeroLowPad())
return failure();
// Only unit stride supported.
if (!insertOp.hasUnitStride()) return failure();
if (!insertOp.hasUnitStride())
return failure();
// Pad value must be a constant.
auto padValue = padOp.getConstantPaddingValue();
if (!padValue)
@ -1038,8 +1054,8 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
void mlir::linalg::populatePadTensorOpVectorizationPatterns(
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
patterns.add<GenericPadTensorOpVectorizationPattern>(
patterns.getContext(), baseBenefit);
patterns.add<GenericPadTensorOpVectorizationPattern>(patterns.getContext(),
baseBenefit);
// Try these specialized patterns first before resorting to the generic one.
patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
PadTensorOpVectorizationWithTransferWritePattern,

View File

@ -1321,31 +1321,59 @@ Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
// BroadcastOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(BroadcastOp op) {
VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
VectorType dstVectorType = op.getVectorType();
// Scalar to vector broadcast is always valid. A vector
// to vector broadcast needs some additional checking.
if (srcVectorType) {
int64_t srcRank = srcVectorType.getRank();
int64_t dstRank = dstVectorType.getRank();
if (srcRank > dstRank)
return op.emitOpError("source rank higher than destination rank");
// Source has an exact match or singleton value for all trailing dimensions
// (all leading dimensions are simply duplicated).
int64_t lead = dstRank - srcRank;
for (int64_t r = 0; r < srcRank; ++r) {
int64_t srcDim = srcVectorType.getDimSize(r);
int64_t dstDim = dstVectorType.getDimSize(lead + r);
if (srcDim != 1 && srcDim != dstDim)
return op.emitOpError("dimension mismatch (")
<< srcDim << " vs. " << dstDim << ")";
BroadcastableToResult
mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
std::pair<int, int> *mismatchingDims) {
// Broadcast scalar to vector of the same element type.
if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
return BroadcastableToResult::Success;
// From now on, only vectors broadcast.
VectorType srcVectorType = srcType.dyn_cast<VectorType>();
if (!srcVectorType)
return BroadcastableToResult::SourceTypeNotAVector;
int64_t srcRank = srcVectorType.getRank();
int64_t dstRank = dstVectorType.getRank();
if (srcRank > dstRank)
return BroadcastableToResult::SourceRankHigher;
// Source has an exact match or singleton value for all trailing dimensions
// (all leading dimensions are simply duplicated).
int64_t lead = dstRank - srcRank;
for (int64_t r = 0; r < srcRank; ++r) {
int64_t srcDim = srcVectorType.getDimSize(r);
int64_t dstDim = dstVectorType.getDimSize(lead + r);
if (srcDim != 1 && srcDim != dstDim) {
if (mismatchingDims) {
mismatchingDims->first = srcDim;
mismatchingDims->second = dstDim;
}
return BroadcastableToResult::DimensionMismatch;
}
}
return success();
return BroadcastableToResult::Success;
}
static LogicalResult verify(BroadcastOp op) {
std::pair<int, int> mismatchingDims;
BroadcastableToResult res = isBroadcastableTo(
op.getSourceType(), op.getVectorType(), &mismatchingDims);
if (res == BroadcastableToResult::Success)
return success();
if (res == BroadcastableToResult::SourceRankHigher)
return op.emitOpError("source rank higher than destination rank");
if (res == BroadcastableToResult::DimensionMismatch)
return op.emitOpError("dimension mismatch (")
<< mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
if (res == BroadcastableToResult::SourceTypeNotAVector)
return op.emitOpError("source type is not a vector");
llvm_unreachable("unexpected vector.broadcast op error");
}
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
if (getSourceType() == getVectorType())
return source();
if (!operands[0])
return {};
auto vectorType = getVectorType();

View File

@ -30,6 +30,13 @@ func @broadcast_dim2_mismatch(%arg0: vector<4x8xf32>) {
// -----
func @broadcast_unknown(%arg0: memref<4x8xf32>) {
// expected-error@+1 {{'vector.broadcast' op source type is not a vector}}
%1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>
}
// -----
func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) {
// expected-error@+1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}}
%1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xi32>

View File

@ -493,7 +493,6 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x
func @cast_away_broadcast_leading_one_dims(
%arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
(vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
// CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32>
// CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
%0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
// CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32>