CombineContractBroadcast should not create dims unused in LHS+RHS

Differential Revision: https://reviews.llvm.org/D129087
This commit is contained in:
Benoit Jacob 2022-07-04 15:33:50 +00:00
parent 9eb6572786
commit c3839c0b46
6 changed files with 68 additions and 22 deletions

View File

@ -18,6 +18,7 @@
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/SmallBitVector.h"
namespace llvm { namespace llvm {
class SmallBitVector; class SmallBitVector;
@ -584,6 +585,11 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
map.print(os); map.print(os);
return os; return os;
} }
// Return a bitvector where each bit set indicates a dimension that is not used
// by any of the maps in the input array `maps`.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
} // namespace mlir } // namespace mlir
namespace llvm { namespace llvm {

View File

@ -687,6 +687,9 @@ static LogicalResult verifyOutputShape(
MLIRContext *ctx = op.getContext(); MLIRContext *ctx = op.getContext();
AffineMap lhsMap = op.getIndexingMaps()[0]; AffineMap lhsMap = op.getIndexingMaps()[0];
AffineMap rhsMap = op.getIndexingMaps()[1]; AffineMap rhsMap = op.getIndexingMaps()[1];
if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
return op.emitOpError(
"expected all dimensions to be either a LHS or a RHS dimension");
SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs()); SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
for (auto pair : for (auto pair :
{std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) { {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
@ -699,8 +702,8 @@ static LogicalResult verifyOutputShape(
} }
} }
if (!llvm::all_of(extents, [](AffineExpr e) { return e; })) if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
return op.emitOpError("expected all input dimensions to be used by " return op.emitOpError("expected all dimensions to get an extent as "
"either the LHS or the RHS"); "either a LHS or a RHS dimension");
AffineMap resMap = op.getIndexingMaps()[2]; AffineMap resMap = op.getIndexingMaps()[2];
auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(), auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),

View File

@ -32,7 +32,6 @@
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h" #include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/CommandLine.h" #include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
@ -1155,21 +1154,14 @@ struct CombineContractBroadcast
// Determine which dims are usused, now that the maps have been composed // Determine which dims are usused, now that the maps have been composed
// with the broadcast maps. // with the broadcast maps.
unsigned numDims = maps[0].getNumDims(); llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
llvm::SmallBitVector unusedDims(numDims, true);
for (const auto &m : maps) {
for (unsigned i = 0; i < numDims; ++i) {
if (m.isFunctionOfDim(i))
unusedDims.reset(i);
}
}
// Compress unused dims. // Compress unused dims.
for (auto &m : maps) for (auto &m : maps)
m = compressDims(m, unusedDims); m = compressDims(m, unusedDimsBitVector);
// Compute the combined iterators. // Compute the combined iterators.
SmallVector<Attribute, 4> iterators; SmallVector<Attribute, 4> iterators;
for (unsigned i = 0; i < numDims; ++i) { for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
if (!unusedDims.test(i)) if (!unusedDimsBitVector.test(i))
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]); iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
} }
// Check that compressing unused dims isn't removing all reduction // Check that compressing unused dims isn't removing all reduction
@ -1179,7 +1171,10 @@ struct CombineContractBroadcast
// a reduction iterator. // a reduction iterator.
if (!llvm::any_of(iterators, isReductionIterator)) if (!llvm::any_of(iterators, isReductionIterator))
return failure(); return failure();
// If the compressed maps have a dimension that is not used by either LHS or
// RHS then the ContractionOp verifier would fail.
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
return failure();
rewriter.replaceOpWithNewOp<vector::ContractionOp>( rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, lhs, rhs, contractOp.getAcc(), contractOp, lhs, rhs, contractOp.getAcc(),
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators)); rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));

View File

@ -560,12 +560,7 @@ AffineMap mlir::compressDims(AffineMap map,
} }
AffineMap mlir::compressUnusedDims(AffineMap map) { AffineMap mlir::compressUnusedDims(AffineMap map) {
llvm::SmallBitVector unusedDims(map.getNumDims(), true); return compressDims(map, getUnusedDimsBitVector({map}));
map.walkExprs([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
unusedDims.reset(dimExpr.getPosition());
});
return compressDims(map, unusedDims);
} }
static SmallVector<AffineMap> static SmallVector<AffineMap>
@ -722,6 +717,18 @@ AffineMap mlir::getProjectedMap(AffineMap map,
return compressUnusedSymbols(compressDims(map, unusedDims)); return compressUnusedSymbols(compressDims(map, unusedDims));
} }
llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
unsigned numDims = maps[0].getNumDims();
llvm::SmallBitVector numDimsBitVector(numDims, true);
for (const auto &m : maps) {
for (unsigned i = 0; i < numDims; ++i) {
if (m.isFunctionOfDim(i))
numDimsBitVector.reset(i);
}
}
return numDimsBitVector;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MutableAffineMap. // MutableAffineMap.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -875,7 +875,7 @@ func.func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: ve
// ----- // -----
func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> { func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
// expected-error@+1 {{'vector.contract' op expected all input dimensions to be used by either the LHS or the RHS}} // expected-error@+1 {{'vector.contract' op expected all dimensions to be either a LHS or a RHS dimension}}
%result = vector.contract { %result = vector.contract {
indexing_maps = [ indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>,

View File

@ -159,6 +159,10 @@ func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : ve
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction // CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
// CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>) // CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32> // CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
@ -178,6 +182,37 @@ func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vecto
return %result : vector<8x8xi32> return %result : vector<8x8xi32>
} }
// -----
// Test that CombineContractBroadcast is not combining this case, as that would
// result in a dimension being unused in the LHS and RHS maps, which is illegal.
#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs
// CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>)
// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32>
// CHECK: vector.contract
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
%1 = vector.broadcast %arg1 : vector<2xi32> to vector<1x1x2xi32>
%result = vector.contract {
indexing_maps = [#map0, #map1, #map2],
iterator_types = ["reduction", "parallel", "reduction"],
kind = #vector.kind<add>
} %arg0, %1, %arg2 : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
return %result : vector<1xi32>
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Reorder casting ops and vector ops. The casting ops have almost identical // Reorder casting ops and vector ops. The casting ops have almost identical
// pattern, so only arith.extsi op is tested. // pattern, so only arith.extsi op is tested.