forked from OSchip/llvm-project
CombineContractBroadcast should not create dims unused in LHS+RHS
Differential Revision: https://reviews.llvm.org/D129087
This commit is contained in:
parent
9eb6572786
commit
c3839c0b46
|
@ -18,6 +18,7 @@
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/DenseMapInfo.h"
|
#include "llvm/ADT/DenseMapInfo.h"
|
||||||
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
|
||||||
namespace llvm {
|
namespace llvm {
|
||||||
class SmallBitVector;
|
class SmallBitVector;
|
||||||
|
@ -584,6 +585,11 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
|
||||||
map.print(os);
|
map.print(os);
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return a bitvector where each bit set indicates a dimension that is not used
|
||||||
|
// by any of the maps in the input array `maps`.
|
||||||
|
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
namespace llvm {
|
namespace llvm {
|
||||||
|
|
|
@ -687,6 +687,9 @@ static LogicalResult verifyOutputShape(
|
||||||
MLIRContext *ctx = op.getContext();
|
MLIRContext *ctx = op.getContext();
|
||||||
AffineMap lhsMap = op.getIndexingMaps()[0];
|
AffineMap lhsMap = op.getIndexingMaps()[0];
|
||||||
AffineMap rhsMap = op.getIndexingMaps()[1];
|
AffineMap rhsMap = op.getIndexingMaps()[1];
|
||||||
|
if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
|
||||||
|
return op.emitOpError(
|
||||||
|
"expected all dimensions to be either a LHS or a RHS dimension");
|
||||||
SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
|
SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
|
||||||
for (auto pair :
|
for (auto pair :
|
||||||
{std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
|
{std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
|
||||||
|
@ -699,8 +702,8 @@ static LogicalResult verifyOutputShape(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
|
if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
|
||||||
return op.emitOpError("expected all input dimensions to be used by "
|
return op.emitOpError("expected all dimensions to get an extent as "
|
||||||
"either the LHS or the RHS");
|
"either a LHS or a RHS dimension");
|
||||||
|
|
||||||
AffineMap resMap = op.getIndexingMaps()[2];
|
AffineMap resMap = op.getIndexingMaps()[2];
|
||||||
auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
|
auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
|
||||||
|
|
|
@ -32,7 +32,6 @@
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/ADT/MapVector.h"
|
#include "llvm/ADT/MapVector.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
@ -1155,21 +1154,14 @@ struct CombineContractBroadcast
|
||||||
|
|
||||||
// Determine which dims are usused, now that the maps have been composed
|
// Determine which dims are usused, now that the maps have been composed
|
||||||
// with the broadcast maps.
|
// with the broadcast maps.
|
||||||
unsigned numDims = maps[0].getNumDims();
|
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
|
||||||
llvm::SmallBitVector unusedDims(numDims, true);
|
|
||||||
for (const auto &m : maps) {
|
|
||||||
for (unsigned i = 0; i < numDims; ++i) {
|
|
||||||
if (m.isFunctionOfDim(i))
|
|
||||||
unusedDims.reset(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Compress unused dims.
|
// Compress unused dims.
|
||||||
for (auto &m : maps)
|
for (auto &m : maps)
|
||||||
m = compressDims(m, unusedDims);
|
m = compressDims(m, unusedDimsBitVector);
|
||||||
// Compute the combined iterators.
|
// Compute the combined iterators.
|
||||||
SmallVector<Attribute, 4> iterators;
|
SmallVector<Attribute, 4> iterators;
|
||||||
for (unsigned i = 0; i < numDims; ++i) {
|
for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
|
||||||
if (!unusedDims.test(i))
|
if (!unusedDimsBitVector.test(i))
|
||||||
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
|
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
|
||||||
}
|
}
|
||||||
// Check that compressing unused dims isn't removing all reduction
|
// Check that compressing unused dims isn't removing all reduction
|
||||||
|
@ -1179,7 +1171,10 @@ struct CombineContractBroadcast
|
||||||
// a reduction iterator.
|
// a reduction iterator.
|
||||||
if (!llvm::any_of(iterators, isReductionIterator))
|
if (!llvm::any_of(iterators, isReductionIterator))
|
||||||
return failure();
|
return failure();
|
||||||
|
// If the compressed maps have a dimension that is not used by either LHS or
|
||||||
|
// RHS then the ContractionOp verifier would fail.
|
||||||
|
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
|
||||||
|
return failure();
|
||||||
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
|
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
|
||||||
contractOp, lhs, rhs, contractOp.getAcc(),
|
contractOp, lhs, rhs, contractOp.getAcc(),
|
||||||
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
|
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
|
||||||
|
|
|
@ -560,12 +560,7 @@ AffineMap mlir::compressDims(AffineMap map,
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineMap mlir::compressUnusedDims(AffineMap map) {
|
AffineMap mlir::compressUnusedDims(AffineMap map) {
|
||||||
llvm::SmallBitVector unusedDims(map.getNumDims(), true);
|
return compressDims(map, getUnusedDimsBitVector({map}));
|
||||||
map.walkExprs([&](AffineExpr expr) {
|
|
||||||
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
|
|
||||||
unusedDims.reset(dimExpr.getPosition());
|
|
||||||
});
|
|
||||||
return compressDims(map, unusedDims);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<AffineMap>
|
static SmallVector<AffineMap>
|
||||||
|
@ -722,6 +717,18 @@ AffineMap mlir::getProjectedMap(AffineMap map,
|
||||||
return compressUnusedSymbols(compressDims(map, unusedDims));
|
return compressUnusedSymbols(compressDims(map, unusedDims));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
|
||||||
|
unsigned numDims = maps[0].getNumDims();
|
||||||
|
llvm::SmallBitVector numDimsBitVector(numDims, true);
|
||||||
|
for (const auto &m : maps) {
|
||||||
|
for (unsigned i = 0; i < numDims; ++i) {
|
||||||
|
if (m.isFunctionOfDim(i))
|
||||||
|
numDimsBitVector.reset(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return numDimsBitVector;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MutableAffineMap.
|
// MutableAffineMap.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -875,7 +875,7 @@ func.func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: ve
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
|
func.func @contract_with_dim_unused_by_lhs_and_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
|
||||||
// expected-error@+1 {{'vector.contract' op expected all input dimensions to be used by either the LHS or the RHS}}
|
// expected-error@+1 {{'vector.contract' op expected all dimensions to be either a LHS or a RHS dimension}}
|
||||||
%result = vector.contract {
|
%result = vector.contract {
|
||||||
indexing_maps = [
|
indexing_maps = [
|
||||||
affine_map<(d0, d1, d2) -> (d0, d2)>,
|
affine_map<(d0, d1, d2) -> (d0, d2)>,
|
||||||
|
|
|
@ -159,6 +159,10 @@ func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : ve
|
||||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
|
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||||
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
|
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||||
|
|
||||||
|
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||||
|
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||||
|
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||||
|
|
||||||
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
|
// CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
|
||||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
|
// CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
|
||||||
// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
|
// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
|
||||||
|
@ -178,6 +182,37 @@ func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vecto
|
||||||
return %result : vector<8x8xi32>
|
return %result : vector<8x8xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test that CombineContractBroadcast is not combining this case, as that would
|
||||||
|
// result in a dimension being unused in the LHS and RHS maps, which is illegal.
|
||||||
|
|
||||||
|
#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||||
|
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
|
#map2 = affine_map<(d0, d1, d2) -> (d1)>
|
||||||
|
|
||||||
|
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||||
|
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
|
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
|
||||||
|
|
||||||
|
// CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs
|
||||||
|
// CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>)
|
||||||
|
// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32>
|
||||||
|
// CHECK: vector.contract
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
|
||||||
|
// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
|
||||||
|
// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
|
||||||
|
|
||||||
|
func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
|
||||||
|
%1 = vector.broadcast %arg1 : vector<2xi32> to vector<1x1x2xi32>
|
||||||
|
%result = vector.contract {
|
||||||
|
indexing_maps = [#map0, #map1, #map2],
|
||||||
|
iterator_types = ["reduction", "parallel", "reduction"],
|
||||||
|
kind = #vector.kind<add>
|
||||||
|
} %arg0, %1, %arg2 : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
|
||||||
|
return %result : vector<1xi32>
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Reorder casting ops and vector ops. The casting ops have almost identical
|
// Reorder casting ops and vector ops. The casting ops have almost identical
|
||||||
// pattern, so only arith.extsi op is tested.
|
// pattern, so only arith.extsi op is tested.
|
||||||
|
|
Loading…
Reference in New Issue