CombineContractBroadcast should not create dims unused in LHS+RHS
Differential Revision: https://reviews.llvm.org/D129087
parent 9eb6572786
commit c3839c0b46
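In short: CombineContractBroadcast composes the contraction's indexing maps with the maps of the folded broadcasts, and that composition can leave an iteration dimension that neither the LHS nor the RHS map uses (for example, a dimension only the accumulator map refers to). Such a vector.contract is rejected by the verifier, which this commit makes explicit, so the rewrite now bails out instead. The check itself is factored into a new AffineMap helper, getUnusedDimsBitVector. The standalone sketch below is not part of the patch; it assumes an MLIR development tree and re-implements the helper's computation on two illustrative maps:

// Sketch only: recomputes "which dims are unused by these maps" the same way
// as the getUnusedDimsBitVector helper added by this commit.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// A bit stays set iff no map in `maps` refers to that dimension.
static llvm::SmallBitVector computeUnusedDims(ArrayRef<AffineMap> maps) {
  unsigned numDims = maps[0].getNumDims();
  llvm::SmallBitVector unused(numDims, true);
  for (AffineMap m : maps)
    for (unsigned i = 0; i < numDims; ++i)
      if (m.isFunctionOfDim(i))
        unused.reset(i);
  return unused;
}

int main() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);
  // Illustrative maps: lhs = rhs = (d0, d1, d2) -> (d0, d2). d1 appears in
  // neither map, so its bit stays set and the new guard would reject this.
  AffineMap lhs = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d0, d2}, &ctx);
  AffineMap rhs = lhs;
  llvm::SmallBitVector bits = computeUnusedDims({lhs, rhs});
  llvm::outs() << "d1 unused by LHS and RHS: " << (bits.test(1) ? "yes" : "no") << "\n";
  return 0;
}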

@@ -18,6 +18,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 namespace llvm {
 class SmallBitVector;

@@ -584,6 +585,11 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
   map.print(os);
   return os;
 }
+
+// Return a bitvector where each bit set indicates a dimension that is not used
+// by any of the maps in the input array `maps`.
+llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
+
 } // namespace mlir
 
 namespace llvm {
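The helper declared above is also what the existing compression entry points are rewritten around further down (compressUnusedDims becomes a one-liner over it). A small usage sketch, not from the patch, assuming the declaration above and the MLIR IR library are available:

// Usage sketch for the new helper: the unused-dims bitvector of a single map
// feeds directly into compressDims, which is exactly how compressUnusedDims
// is reimplemented in this commit.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallBitVector.h"
#include <cassert>

using namespace mlir;

int main() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);
  // map: (d0, d1, d2) -> (d0, d2); only d1 is unused.
  AffineMap map = AffineMap::get(3, 0, {d0, d2}, &ctx);

  llvm::SmallBitVector unused = getUnusedDimsBitVector({map});
  assert(unused.count() == 1 && unused.test(1));

  // Dropping d1 renumbers the remaining dims: (d0, d1) -> (d0, d1).
  AffineMap compressed = compressDims(map, unused);
  assert(compressed == compressUnusedDims(map));
  assert(compressed.getNumDims() == 2);
  (void)compressed;
  return 0;
}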

@@ -687,6 +687,9 @@ static LogicalResult verifyOutputShape(
   MLIRContext *ctx = op.getContext();
   AffineMap lhsMap = op.getIndexingMaps()[0];
   AffineMap rhsMap = op.getIndexingMaps()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+    return op.emitOpError(
+        "expected all dimensions to be either a LHS or a RHS dimension");
   SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
   for (auto pair :
        {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {

@@ -699,8 +702,8 @@ static LogicalResult verifyOutputShape(
     }
   }
   if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
-    return op.emitOpError("expected all input dimensions to be used by "
-                          "either the LHS or the RHS");
+    return op.emitOpError("expected all dimensions to get an extent as "
+                          "either a LHS or a RHS dimension");
 
   AffineMap resMap = op.getIndexingMaps()[2];
   auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),

@@ -32,7 +32,6 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"

@@ -1155,21 +1154,14 @@ struct CombineContractBroadcast
 
     // Determine which dims are usused, now that the maps have been composed
    // with the broadcast maps.
-    unsigned numDims = maps[0].getNumDims();
-    llvm::SmallBitVector unusedDims(numDims, true);
-    for (const auto &m : maps) {
-      for (unsigned i = 0; i < numDims; ++i) {
-        if (m.isFunctionOfDim(i))
-          unusedDims.reset(i);
-      }
-    }
+    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
     // Compress unused dims.
     for (auto &m : maps)
-      m = compressDims(m, unusedDims);
+      m = compressDims(m, unusedDimsBitVector);
     // Compute the combined iterators.
     SmallVector<Attribute, 4> iterators;
-    for (unsigned i = 0; i < numDims; ++i) {
-      if (!unusedDims.test(i))
+    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
+      if (!unusedDimsBitVector.test(i))
         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
     }
     // Check that compressing unused dims isn't removing all reduction

@@ -1179,7 +1171,10 @@ struct CombineContractBroadcast
     // a reduction iterator.
     if (!llvm::any_of(iterators, isReductionIterator))
       return failure();
 
+    // If the compressed maps have a dimension that is not used by either LHS or
+    // RHS then the ContractionOp verifier would fail.
+    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
+      return failure();
     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
         contractOp, lhs, rhs, contractOp.getAcc(),
         rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
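Note that the bail-out above inspects only maps[0] and maps[1], while the compression step earlier considers all of the maps. The difference is exactly the problem case: a dimension used solely by the accumulator map is not compressed away, yet the verifier change earlier in this commit rejects it. A hedged sketch with illustrative maps (not the pattern's actual inputs):

// Sketch: a dim used only by the accumulator map survives compression (which
// consults all maps) but trips the LHS/RHS-only guard, so the rewrite bails.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallBitVector.h"
#include <cassert>

using namespace mlir;

int main() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);
  AffineMap lhs = AffineMap::get(3, 0, {d0, d2}, &ctx); // (d0, d1, d2) -> (d0, d2)
  AffineMap rhs = AffineMap::get(3, 0, {d2}, &ctx);     // (d0, d1, d2) -> (d2)
  AffineMap acc = AffineMap::get(3, 0, {d1}, &ctx);     // (d0, d1, d2) -> (d1)

  // Over all three maps every dim is used somewhere, so compression keeps d1.
  assert(getUnusedDimsBitVector({lhs, rhs, acc}).none());

  // Over LHS and RHS alone, d1 is unused, which the contract verifier rejects.
  assert(getUnusedDimsBitVector({lhs, rhs}).test(1));
  return 0;
}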

@@ -560,12 +560,7 @@ AffineMap mlir::compressDims(AffineMap map,
 }
 
 AffineMap mlir::compressUnusedDims(AffineMap map) {
-  llvm::SmallBitVector unusedDims(map.getNumDims(), true);
-  map.walkExprs([&](AffineExpr expr) {
-    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
-      unusedDims.reset(dimExpr.getPosition());
-  });
-  return compressDims(map, unusedDims);
+  return compressDims(map, getUnusedDimsBitVector({map}));
 }
 
 static SmallVector<AffineMap>

@@ -722,6 +717,18 @@ AffineMap mlir::getProjectedMap(AffineMap map,
   return compressUnusedSymbols(compressDims(map, unusedDims));
 }
 
+llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
+  unsigned numDims = maps[0].getNumDims();
+  llvm::SmallBitVector numDimsBitVector(numDims, true);
+  for (const auto &m : maps) {
+    for (unsigned i = 0; i < numDims; ++i) {
+      if (m.isFunctionOfDim(i))
+        numDimsBitVector.reset(i);
+    }
+  }
+  return numDimsBitVector;
+}
+
 //===----------------------------------------------------------------------===//
 // MutableAffineMap.
 //===----------------------------------------------------------------------===//

@@ -875,7 +875,7 @@ func.func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: ve
 // -----
 
 func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
-  // expected-error@+1 {{'vector.contract' op expected all input dimensions to be used by either the LHS or the RHS}}
+  // expected-error@+1 {{'vector.contract' op expected all dimensions to be either a LHS or a RHS dimension}}
   %result = vector.contract {
     indexing_maps = [
         affine_map<(d0, d1, d2) -> (d0, d2)>,

@@ -159,6 +159,10 @@ func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : ve
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+
 // CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
 // CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
 // CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>

@@ -178,6 +182,37 @@ func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vecto
   return %result : vector<8x8xi32>
 }
 
+// -----
+
+// Test that CombineContractBroadcast is not combining this case, as that would
+// result in a dimension being unused in the LHS and RHS maps, which is illegal.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>)
+// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32>
+// CHECK: vector.contract
+// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
+// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
+
+func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
+  %1 = vector.broadcast %arg1 : vector<2xi32> to vector<1x1x2xi32>
+  %result = vector.contract {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["reduction", "parallel", "reduction"],
+      kind = #vector.kind<add>
+  } %arg0, %1, %arg2 : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
+  return %result : vector<1xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.