[mlir] [VectorOps] Replace zero-scalar + splat into direct zero vector constant

Summary:
The scalar zero + splat yields more intermediate code than the direct
dense zero constant, and ultimately is lowered to exactly the same
LLVM IR operations, so no point wasting the intermediate code.

Reviewers: nicolasvasilache, andydavis1, reidtatge

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79758
This commit is contained in:
aartbik 2020-05-11 18:22:59 -07:00
parent 2b8b783b1a
commit 40f56c8cf1
1 changed files with 23 additions and 40 deletions

View File

@ -929,10 +929,10 @@ public:
/// One: /// One:
/// %x = vector.insert_slices %0 /// %x = vector.insert_slices %0
/// is replaced by: /// is replaced by:
/// %r0 = vector.splat 0 /// %r0 = zero-result
// %t1 = vector.tuple_get %0, 0 /// %t1 = vector.tuple_get %0, 0
/// %r1 = vector.insert_strided_slice %r0, %t1 /// %r1 = vector.insert_strided_slice %r0, %t1
// %t2 = vector.tuple_get %0, 1 /// %t2 = vector.tuple_get %0, 1
/// %r2 = vector.insert_strided_slice %r1, %t2 /// %r2 = vector.insert_strided_slice %r1, %t2
/// .. /// ..
/// %x = .. /// %x = ..
@ -953,10 +953,8 @@ public:
op.getStrides(strides); // all-ones at the moment op.getStrides(strides); // all-ones at the moment
// Prepare result. // Prepare result.
auto elemType = vectorType.getElementType(); Value result = rewriter.create<ConstantOp>(
Value zero = rewriter.create<ConstantOp>(loc, elemType, loc, vectorType, rewriter.getZeroAttr(vectorType));
rewriter.getZeroAttr(elemType));
Value result = rewriter.create<SplatOp>(loc, vectorType, zero);
// For each element in the tuple, extract the proper strided slice. // For each element in the tuple, extract the proper strided slice.
TupleType tupleType = op.getSourceTupleType(); TupleType tupleType = op.getSourceTupleType();
@ -1015,9 +1013,8 @@ public:
VectorType::get(dstType.getShape().drop_front(), eltType); VectorType::get(dstType.getShape().drop_front(), eltType);
Value bcst = Value bcst =
rewriter.create<vector::BroadcastOp>(loc, resType, op.source()); rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
Value zero = rewriter.create<ConstantOp>(loc, eltType, Value result = rewriter.create<ConstantOp>(loc, dstType,
rewriter.getZeroAttr(eltType)); rewriter.getZeroAttr(dstType));
Value result = rewriter.create<SplatOp>(loc, dstType, zero);
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d); result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
rewriter.replaceOp(op, result); rewriter.replaceOp(op, result);
@ -1064,9 +1061,8 @@ public:
// %x = [%a,%b,%c,%d] // %x = [%a,%b,%c,%d]
VectorType resType = VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType); VectorType::get(dstType.getShape().drop_front(), eltType);
Value zero = rewriter.create<ConstantOp>(loc, eltType, Value result = rewriter.create<ConstantOp>(loc, dstType,
rewriter.getZeroAttr(eltType)); rewriter.getZeroAttr(dstType));
Value result = rewriter.create<SplatOp>(loc, dstType, zero);
if (m == 0) { if (m == 0) {
// Stetch at start. // Stetch at start.
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0); Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
@ -1104,7 +1100,6 @@ public:
auto loc = op.getLoc(); auto loc = op.getLoc();
VectorType resType = op.getResultType(); VectorType resType = op.getResultType();
Type eltType = resType.getElementType();
// Set up convenience transposition table. // Set up convenience transposition table.
SmallVector<int64_t, 4> transp; SmallVector<int64_t, 4> transp;
@ -1112,9 +1107,8 @@ public:
transp.push_back(attr.cast<IntegerAttr>().getInt()); transp.push_back(attr.cast<IntegerAttr>().getInt());
// Generate fully unrolled extract/insert ops. // Generate fully unrolled extract/insert ops.
Value zero = rewriter.create<ConstantOp>(loc, eltType, Value result = rewriter.create<ConstantOp>(loc, resType,
rewriter.getZeroAttr(eltType)); rewriter.getZeroAttr(resType));
Value result = rewriter.create<SplatOp>(loc, resType, zero);
SmallVector<int64_t, 4> lhs(transp.size(), 0); SmallVector<int64_t, 4> lhs(transp.size(), 0);
SmallVector<int64_t, 4> rhs(transp.size(), 0); SmallVector<int64_t, 4> rhs(transp.size(), 0);
rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs, rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
@ -1173,9 +1167,8 @@ public:
Type eltType = resType.getElementType(); Type eltType = resType.getElementType();
Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
Value zero = rewriter.create<ConstantOp>(loc, eltType, Value result = rewriter.create<ConstantOp>(loc, resType,
rewriter.getZeroAttr(eltType)); rewriter.getZeroAttr(resType));
Value result = rewriter.create<SplatOp>(loc, resType, zero);
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
auto pos = rewriter.getI64ArrayAttr(d); auto pos = rewriter.getI64ArrayAttr(d);
Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos); Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
@ -1346,7 +1339,8 @@ private:
rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
// Unroll into a series of lower dimensional vector.contract ops. // Unroll into a series of lower dimensional vector.contract ops.
Location loc = op.getLoc(); Location loc = op.getLoc();
Value result = zeroVector(loc, resType, rewriter); Value result = rewriter.create<ConstantOp>(loc, resType,
rewriter.getZeroAttr(resType));
for (int64_t d = 0; d < dimSize; ++d) { for (int64_t d = 0; d < dimSize; ++d) {
auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
@ -1381,7 +1375,8 @@ private:
// Base case. // Base case.
if (lhsType.getRank() == 1) { if (lhsType.getRank() == 1) {
assert(rhsType.getRank() == 1 && "corrupt contraction"); assert(rhsType.getRank() == 1 && "corrupt contraction");
Value zero = zeroVector(loc, lhsType, rewriter); Value zero = rewriter.create<ConstantOp>(loc, lhsType,
rewriter.getZeroAttr(lhsType));
Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero); Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
StringAttr kind = rewriter.getStringAttr("add"); StringAttr kind = rewriter.getStringAttr("add");
return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma, return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
@ -1409,15 +1404,6 @@ private:
return result; return result;
} }
// Helper method to construct a zero vector.
static Value zeroVector(Location loc, VectorType vType,
PatternRewriter &rewriter) {
Type eltType = vType.getElementType();
Value zero = rewriter.create<ConstantOp>(loc, eltType,
rewriter.getZeroAttr(eltType));
return rewriter.create<SplatOp>(loc, vType, zero);
}
// Helper to find an index in an affine map. // Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) { static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
@ -1493,7 +1479,8 @@ private:
// Unroll leading dimensions. // Unroll leading dimensions.
VectorType vType = lowType.cast<VectorType>(); VectorType vType = lowType.cast<VectorType>();
VectorType resType = adjustType(type, index).cast<VectorType>(); VectorType resType = adjustType(type, index).cast<VectorType>();
Value result = zeroVector(loc, resType, rewriter); Value result = rewriter.create<ConstantOp>(loc, resType,
rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) { for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d); auto posAttr = rewriter.getI64ArrayAttr(d);
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr); Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
@ -1555,10 +1542,8 @@ public:
return failure(); return failure();
auto loc = op.getLoc(); auto loc = op.getLoc();
auto elemType = sourceVectorType.getElementType(); Value desc = rewriter.create<ConstantOp>(
Value zero = rewriter.create<ConstantOp>(loc, elemType, loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
rewriter.getZeroAttr(elemType));
Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
unsigned mostMinorVectorSize = sourceVectorType.getShape()[1]; unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) { for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i); Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
@ -1589,10 +1574,8 @@ public:
return failure(); return failure();
auto loc = op.getLoc(); auto loc = op.getLoc();
auto elemType = sourceVectorType.getElementType(); Value desc = rewriter.create<ConstantOp>(
Value zero = rewriter.create<ConstantOp>(loc, elemType, loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
rewriter.getZeroAttr(elemType));
Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
unsigned mostMinorVectorSize = resultVectorType.getShape()[1]; unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) { for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
Value vec = rewriter.create<vector::ExtractStridedSliceOp>( Value vec = rewriter.create<vector::ExtractStridedSliceOp>(