forked from OSchip/llvm-project
[MLIR] Matrix: support matrix-vector multiplication
This just moves in the implementation from LinearTransform. Reviewed By: Groverkss, bondhugula Differential Revision: https://reviews.llvm.org/D118479
This commit is contained in:
parent
0c3d22a592
commit
255494144f
|
@ -39,12 +39,16 @@ public:
|
|||
|
||||
// The given vector is interpreted as a row vector v. Post-multiply v with
|
||||
// this transform, say T, and return vT.
|
||||
SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const;
|
||||
SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
|
||||
return matrix.preMultiplyWithRow(rowVec);
|
||||
}
|
||||
|
||||
// The given vector is interpreted as a column vector v. Pre-multiply v with
|
||||
// this transform, say T, and return Tv.
|
||||
SmallVector<int64_t, 8>
|
||||
postMultiplyWithColumn(ArrayRef<int64_t> colVec) const;
|
||||
postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
|
||||
return matrix.postMultiplyWithColumn(colVec);
|
||||
}
|
||||
|
||||
private:
|
||||
Matrix matrix;
|
||||
|
|
|
@ -117,6 +117,15 @@ public:
|
|||
/// Negate the specified column.
|
||||
void negateColumn(unsigned column);
|
||||
|
||||
/// The given vector is interpreted as a row vector v. Post-multiply v with
|
||||
/// this matrix, say M, and return vM.
|
||||
SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const;
|
||||
|
||||
/// The given vector is interpreted as a column vector v. Pre-multiply v with
|
||||
/// this matrix, say M, and return Mv.
|
||||
SmallVector<int64_t, 8>
|
||||
postMultiplyWithColumn(ArrayRef<int64_t> colVec) const;
|
||||
|
||||
/// Resize the matrix to the specified dimensions. If a dimension is smaller,
|
||||
/// the values are truncated; if it is bigger, the new values are initialized
|
||||
/// to zero.
|
||||
|
|
|
@ -111,30 +111,6 @@ LinearTransform::makeTransformToColumnEchelon(Matrix m) {
|
|||
return {echelonCol, LinearTransform(std::move(resultMatrix))};
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 8>
|
||||
LinearTransform::preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
|
||||
assert(rowVec.size() == matrix.getNumRows() &&
|
||||
"row vector dimension should match transform output dimension");
|
||||
|
||||
SmallVector<int64_t, 8> result(matrix.getNumColumns(), 0);
|
||||
for (unsigned col = 0, e = matrix.getNumColumns(); col < e; ++col)
|
||||
for (unsigned i = 0, e = matrix.getNumRows(); i < e; ++i)
|
||||
result[col] += rowVec[i] * matrix(i, col);
|
||||
return result;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 8>
|
||||
LinearTransform::postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
|
||||
assert(matrix.getNumColumns() == colVec.size() &&
|
||||
"column vector dimension should match transform input dimension");
|
||||
|
||||
SmallVector<int64_t, 8> result(matrix.getNumRows(), 0);
|
||||
for (unsigned row = 0, e = matrix.getNumRows(); row < e; row++)
|
||||
for (unsigned i = 0, e = matrix.getNumColumns(); i < e; i++)
|
||||
result[row] += matrix(row, i) * colVec[i];
|
||||
return result;
|
||||
}
|
||||
|
||||
IntegerPolyhedron
|
||||
LinearTransform::applyTo(const IntegerPolyhedron &poly) const {
|
||||
IntegerPolyhedron result(poly.getNumIds());
|
||||
|
|
|
@ -203,6 +203,29 @@ void Matrix::negateColumn(unsigned column) {
|
|||
at(row, column) = -at(row, column);
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 8>
|
||||
Matrix::preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
|
||||
assert(rowVec.size() == getNumRows() && "Invalid row vector dimension!");
|
||||
|
||||
SmallVector<int64_t, 8> result(getNumColumns(), 0);
|
||||
for (unsigned col = 0, e = getNumColumns(); col < e; ++col)
|
||||
for (unsigned i = 0, e = getNumRows(); i < e; ++i)
|
||||
result[col] += rowVec[i] * at(i, col);
|
||||
return result;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 8>
|
||||
Matrix::postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
|
||||
assert(getNumColumns() == colVec.size() &&
|
||||
"Invalid column vector dimension!");
|
||||
|
||||
SmallVector<int64_t, 8> result(getNumRows(), 0);
|
||||
for (unsigned row = 0, e = getNumRows(); row < e; row++)
|
||||
for (unsigned i = 0, e = getNumColumns(); i < e; i++)
|
||||
result[row] += at(row, i) * colVec[i];
|
||||
return result;
|
||||
}
|
||||
|
||||
void Matrix::print(raw_ostream &os) const {
|
||||
for (unsigned row = 0; row < nRows; ++row) {
|
||||
for (unsigned column = 0; column < nColumns; ++column)
|
||||
|
|
Loading…
Reference in New Issue