forked from OSchip/llvm-project
[mlir] [VectorOps] Replace zero-scalar + splat with direct zero vector constant
Summary: The scalar zero + splat yields more intermediate code than the direct dense zero constant, and both are ultimately lowered to exactly the same LLVM IR operations, so there is no point in generating the extra intermediate code.

Reviewers: nicolasvasilache, andydavis1, reidtatge

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79758
This commit is contained in:
parent
2b8b783b1a
commit
40f56c8cf1
|
@ -929,10 +929,10 @@ public:
|
||||||
/// One:
|
/// One:
|
||||||
/// %x = vector.insert_slices %0
|
/// %x = vector.insert_slices %0
|
||||||
/// is replaced by:
|
/// is replaced by:
|
||||||
/// %r0 = vector.splat 0
|
/// %r0 = zero-result
|
||||||
// %t1 = vector.tuple_get %0, 0
|
/// %t1 = vector.tuple_get %0, 0
|
||||||
/// %r1 = vector.insert_strided_slice %r0, %t1
|
/// %r1 = vector.insert_strided_slice %r0, %t1
|
||||||
// %t2 = vector.tuple_get %0, 1
|
/// %t2 = vector.tuple_get %0, 1
|
||||||
/// %r2 = vector.insert_strided_slice %r1, %t2
|
/// %r2 = vector.insert_strided_slice %r1, %t2
|
||||||
/// ..
|
/// ..
|
||||||
/// %x = ..
|
/// %x = ..
|
||||||
|
@ -953,10 +953,8 @@ public:
|
||||||
op.getStrides(strides); // all-ones at the moment
|
op.getStrides(strides); // all-ones at the moment
|
||||||
|
|
||||||
// Prepare result.
|
// Prepare result.
|
||||||
auto elemType = vectorType.getElementType();
|
Value result = rewriter.create<ConstantOp>(
|
||||||
Value zero = rewriter.create<ConstantOp>(loc, elemType,
|
loc, vectorType, rewriter.getZeroAttr(vectorType));
|
||||||
rewriter.getZeroAttr(elemType));
|
|
||||||
Value result = rewriter.create<SplatOp>(loc, vectorType, zero);
|
|
||||||
|
|
||||||
// For each element in the tuple, extract the proper strided slice.
|
// For each element in the tuple, extract the proper strided slice.
|
||||||
TupleType tupleType = op.getSourceTupleType();
|
TupleType tupleType = op.getSourceTupleType();
|
||||||
|
@ -1015,9 +1013,8 @@ public:
|
||||||
VectorType::get(dstType.getShape().drop_front(), eltType);
|
VectorType::get(dstType.getShape().drop_front(), eltType);
|
||||||
Value bcst =
|
Value bcst =
|
||||||
rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
|
rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
|
||||||
Value zero = rewriter.create<ConstantOp>(loc, eltType,
|
Value result = rewriter.create<ConstantOp>(loc, dstType,
|
||||||
rewriter.getZeroAttr(eltType));
|
rewriter.getZeroAttr(dstType));
|
||||||
Value result = rewriter.create<SplatOp>(loc, dstType, zero);
|
|
||||||
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
|
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
|
||||||
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
|
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
|
||||||
rewriter.replaceOp(op, result);
|
rewriter.replaceOp(op, result);
|
||||||
|
@ -1064,9 +1061,8 @@ public:
|
||||||
// %x = [%a,%b,%c,%d]
|
// %x = [%a,%b,%c,%d]
|
||||||
VectorType resType =
|
VectorType resType =
|
||||||
VectorType::get(dstType.getShape().drop_front(), eltType);
|
VectorType::get(dstType.getShape().drop_front(), eltType);
|
||||||
Value zero = rewriter.create<ConstantOp>(loc, eltType,
|
Value result = rewriter.create<ConstantOp>(loc, dstType,
|
||||||
rewriter.getZeroAttr(eltType));
|
rewriter.getZeroAttr(dstType));
|
||||||
Value result = rewriter.create<SplatOp>(loc, dstType, zero);
|
|
||||||
if (m == 0) {
|
if (m == 0) {
|
||||||
// Stretch at start.
|
// Stretch at start.
|
||||||
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
|
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
|
||||||
|
@ -1104,7 +1100,6 @@ public:
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
VectorType resType = op.getResultType();
|
VectorType resType = op.getResultType();
|
||||||
Type eltType = resType.getElementType();
|
|
||||||
|
|
||||||
// Set up convenience transposition table.
|
// Set up convenience transposition table.
|
||||||
SmallVector<int64_t, 4> transp;
|
SmallVector<int64_t, 4> transp;
|
||||||
|
@ -1112,9 +1107,8 @@ public:
|
||||||
transp.push_back(attr.cast<IntegerAttr>().getInt());
|
transp.push_back(attr.cast<IntegerAttr>().getInt());
|
||||||
|
|
||||||
// Generate fully unrolled extract/insert ops.
|
// Generate fully unrolled extract/insert ops.
|
||||||
Value zero = rewriter.create<ConstantOp>(loc, eltType,
|
Value result = rewriter.create<ConstantOp>(loc, resType,
|
||||||
rewriter.getZeroAttr(eltType));
|
rewriter.getZeroAttr(resType));
|
||||||
Value result = rewriter.create<SplatOp>(loc, resType, zero);
|
|
||||||
SmallVector<int64_t, 4> lhs(transp.size(), 0);
|
SmallVector<int64_t, 4> lhs(transp.size(), 0);
|
||||||
SmallVector<int64_t, 4> rhs(transp.size(), 0);
|
SmallVector<int64_t, 4> rhs(transp.size(), 0);
|
||||||
rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
|
rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
|
||||||
|
@ -1173,9 +1167,8 @@ public:
|
||||||
Type eltType = resType.getElementType();
|
Type eltType = resType.getElementType();
|
||||||
Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
|
Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
|
||||||
|
|
||||||
Value zero = rewriter.create<ConstantOp>(loc, eltType,
|
Value result = rewriter.create<ConstantOp>(loc, resType,
|
||||||
rewriter.getZeroAttr(eltType));
|
rewriter.getZeroAttr(resType));
|
||||||
Value result = rewriter.create<SplatOp>(loc, resType, zero);
|
|
||||||
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
|
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
|
||||||
auto pos = rewriter.getI64ArrayAttr(d);
|
auto pos = rewriter.getI64ArrayAttr(d);
|
||||||
Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
|
Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
|
||||||
|
@ -1346,7 +1339,8 @@ private:
|
||||||
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
|
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
|
||||||
// Unroll into a series of lower dimensional vector.contract ops.
|
// Unroll into a series of lower dimensional vector.contract ops.
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value result = zeroVector(loc, resType, rewriter);
|
Value result = rewriter.create<ConstantOp>(loc, resType,
|
||||||
|
rewriter.getZeroAttr(resType));
|
||||||
for (int64_t d = 0; d < dimSize; ++d) {
|
for (int64_t d = 0; d < dimSize; ++d) {
|
||||||
auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
|
auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
|
||||||
auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
|
auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
|
||||||
|
@ -1381,7 +1375,8 @@ private:
|
||||||
// Base case.
|
// Base case.
|
||||||
if (lhsType.getRank() == 1) {
|
if (lhsType.getRank() == 1) {
|
||||||
assert(rhsType.getRank() == 1 && "corrupt contraction");
|
assert(rhsType.getRank() == 1 && "corrupt contraction");
|
||||||
Value zero = zeroVector(loc, lhsType, rewriter);
|
Value zero = rewriter.create<ConstantOp>(loc, lhsType,
|
||||||
|
rewriter.getZeroAttr(lhsType));
|
||||||
Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
|
Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
|
||||||
StringAttr kind = rewriter.getStringAttr("add");
|
StringAttr kind = rewriter.getStringAttr("add");
|
||||||
return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
|
return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
|
||||||
|
@ -1409,15 +1404,6 @@ private:
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper method to construct a zero vector.
|
|
||||||
static Value zeroVector(Location loc, VectorType vType,
|
|
||||||
PatternRewriter &rewriter) {
|
|
||||||
Type eltType = vType.getElementType();
|
|
||||||
Value zero = rewriter.create<ConstantOp>(loc, eltType,
|
|
||||||
rewriter.getZeroAttr(eltType));
|
|
||||||
return rewriter.create<SplatOp>(loc, vType, zero);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to find an index in an affine map.
|
// Helper to find an index in an affine map.
|
||||||
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
|
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
|
||||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||||
|
@ -1493,7 +1479,8 @@ private:
|
||||||
// Unroll leading dimensions.
|
// Unroll leading dimensions.
|
||||||
VectorType vType = lowType.cast<VectorType>();
|
VectorType vType = lowType.cast<VectorType>();
|
||||||
VectorType resType = adjustType(type, index).cast<VectorType>();
|
VectorType resType = adjustType(type, index).cast<VectorType>();
|
||||||
Value result = zeroVector(loc, resType, rewriter);
|
Value result = rewriter.create<ConstantOp>(loc, resType,
|
||||||
|
rewriter.getZeroAttr(resType));
|
||||||
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
|
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
|
||||||
auto posAttr = rewriter.getI64ArrayAttr(d);
|
auto posAttr = rewriter.getI64ArrayAttr(d);
|
||||||
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
|
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
|
||||||
|
@ -1555,10 +1542,8 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto elemType = sourceVectorType.getElementType();
|
Value desc = rewriter.create<ConstantOp>(
|
||||||
Value zero = rewriter.create<ConstantOp>(loc, elemType,
|
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
|
||||||
rewriter.getZeroAttr(elemType));
|
|
||||||
Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
|
|
||||||
unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
|
unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
|
||||||
for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
|
for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
|
||||||
Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
|
Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
|
||||||
|
@ -1589,10 +1574,8 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto elemType = sourceVectorType.getElementType();
|
Value desc = rewriter.create<ConstantOp>(
|
||||||
Value zero = rewriter.create<ConstantOp>(loc, elemType,
|
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
|
||||||
rewriter.getZeroAttr(elemType));
|
|
||||||
Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
|
|
||||||
unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
|
unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
|
||||||
for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
|
for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
|
||||||
Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
|
Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
|
||||||
|
|
Loading…
Reference in New Issue